Source code for vision3d.metrics._mean_average_precision_3d

"""3D detection mean Average Precision metric."""

from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, NotRequired, TypedDict

import torch
from torch import Tensor

from vision3d.metrics._types import Prediction3D, Target3D
from vision3d.ops import box3d_iou
from vision3d.ops._points_in_boxes_3d import _extract_box_params

if TYPE_CHECKING:
    from vision3d.tensors import BoundingBox3DFormat


class APInterpolation(Enum):
    """AP interpolation mode."""

    R40 = "r40"
    """40-point interpolation (modern KITTI default)."""

    R11 = "r11"
    """11-point interpolation (legacy KITTI, PASCAL VOC07)."""

    R101 = "r101"
    """101-point interpolation (COCO)."""

    ALL_POINTS = "all_points"
    """Exact area under the curve at every recall change (PASCAL VOC 2010+)."""
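
# Illustrative sketch (not part of the metric API): the recall grids that the
# sampled modes above evaluate precision at, exactly as built in _compute_ap
# below.
#
#     torch.linspace(0.0, 1.0, 11)     # R11:  0.0, 0.1, ..., 1.0
#     torch.arange(1, 41) / 40.0       # R40:  0.025, 0.05, ..., 1.0
#     torch.linspace(0.0, 1.0, 101)    # R101: 0.00, 0.01, ..., 1.00
#
# ALL_POINTS skips the grid entirely and integrates the enveloped PR curve.
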
_RangeBin = tuple[float, float]


@dataclass(frozen=True)
class _DetectionStatsKey:
    """Key into the per-bucket accumulator dict.

    Attributes:
        class_id: Integer class ID this bucket scores.
        iou_threshold: IoU threshold at which TP/FP were decided.
        range_bin: ``(low, high)`` distance bounds in meters, or ``None``
            when range bucketing is disabled.
    """

    class_id: int
    iou_threshold: float
    range_bin: _RangeBin | None


@dataclass
class _DetectionStats:
    """Per-bucket accumulator for AP computation.

    Attributes:
        scores: Per-frame prediction score chunks.
        is_tp: Per-frame true-positive flag chunks; ``is_tp[f][i]`` is
            ``True`` iff prediction ``i`` of frame ``f`` was matched to a
            ground-truth box at this bucket's IoU threshold.
        num_gt: Total ground-truth boxes seen for this bucket.
    """

    scores: list[Tensor] = field(default_factory=list)
    is_tp: list[Tensor] = field(default_factory=list)
    num_gt: int = 0
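
# Bucketing sketch (hypothetical values): one accumulator is kept per
# (class, IoU threshold, range bin) triple, e.g.
#
#     key = _DetectionStatsKey(class_id=0, iou_threshold=0.7, range_bin=(0.0, 30.0))
#
# and all frames contribute scores/TP flags into the _DetectionStats at that key.
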
class MeanAveragePrecision3DResult(TypedDict):
    """Structured result returned by :meth:`MeanAveragePrecision3D.compute`.

    Undefined slots (buckets with no ground-truth boxes accumulated) are
    reported as ``-1.0`` and callers can filter them with ``x >= 0``.

    Attributes:
        mAP: Overall mean AP, taken over every defined ``(class, iou, bin)``
            bucket.
        mAP_per_class: AP per class, averaged over the other axes.
        AP_per_iou: AP per IoU threshold, averaged over the other axes.
        AP_per_class_per_iou: AP per ``(class, iou)`` pair, averaged over
            range bins (or a single value when range bucketing is disabled).
        AP_per_range: AP per range bin, averaged over the other axes. Only
            present when ``range_bins`` was set on the metric.
        AP_per_class_per_range: AP per ``(class, range_bin)`` pair, averaged
            over IoU thresholds. Only present when ``range_bins`` was set on
            the metric.
    """

    mAP: float
    mAP_per_class: dict[int, float]
    AP_per_iou: dict[float, float]
    AP_per_class_per_iou: dict[tuple[int, float], float]
    AP_per_range: NotRequired[dict[_RangeBin, float]]
    AP_per_class_per_range: NotRequired[dict[tuple[int, _RangeBin], float]]
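
# Sketch of filtering the -1.0 "undefined" sentinel out of a result slot
# (``result`` is a hypothetical value returned by ``compute()``):
#
#     defined = {c: ap for c, ap in result["mAP_per_class"].items() if ap >= 0}
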
class MeanAveragePrecision3D:
    """3D detection mAP metric.

    Matching is greedy by descending score, one prediction to one ground
    truth, with precision/recall accumulated globally across frames (KITTI
    convention).

    Args:
        class_ids: Integer class IDs to score. Predictions and GTs with
            labels outside this set are ignored.
        iou_thresholds: IoU thresholds to report AP at. Default ``(0.5, 0.7)``.
        ap_interpolation: Interpolation mode. Default
            :attr:`APInterpolation.R40`.
        range_bins: Optional distance bins ``[low, high)`` in meters from the
            sensor origin. When set, AP is also broken down per bin; boxes
            are bucketed by their center's distance.
    """

    def __init__(
        self,
        class_ids: list[int],
        iou_thresholds: tuple[float, ...] = (0.5, 0.7),
        ap_interpolation: APInterpolation = APInterpolation.R40,
        range_bins: tuple[_RangeBin, ...] | None = None,
    ) -> None:
        if not class_ids:
            msg = "class_ids must be non-empty"
            raise ValueError(msg)
        if not iou_thresholds:
            msg = "iou_thresholds must be non-empty"
            raise ValueError(msg)
        self.class_ids = list(class_ids)
        self.iou_thresholds = tuple(iou_thresholds)
        self.ap_interpolation = ap_interpolation
        self.range_bins = tuple(range_bins) if range_bins is not None else None
        self._state: dict[_DetectionStatsKey, _DetectionStats] = {}
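
    # Construction sketch (class IDs, thresholds, and bins are made-up
    # example values, not module defaults beyond those documented above):
    #
    #     metric = MeanAveragePrecision3D(
    #         class_ids=[0, 1],
    #         iou_thresholds=(0.5, 0.7),
    #         ap_interpolation=APInterpolation.R40,
    #         range_bins=((0.0, 30.0), (30.0, 50.0)),
    #     )
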
    def update(
        self,
        preds: list[Prediction3D],
        targets: list[Target3D],
    ) -> None:
        """Accumulate one or more frames of predictions vs ground truth.

        Args:
            preds: List of per-frame :class:`Prediction3D` dicts.
            targets: List of per-frame :class:`Target3D` dicts.

        Raises:
            ValueError: If ``preds`` and ``targets`` differ in length.
        """
        if len(preds) != len(targets):
            msg = (
                f"preds and targets must have the same length; "
                f"got {len(preds)} vs {len(targets)}"
            )
            raise ValueError(msg)
        for pred, target in zip(preds, targets):
            self._update_frame(pred, target)
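
    # Evaluation-loop sketch (``model`` and ``eval_loader`` are assumptions,
    # not part of this module; each model output frame must be a
    # Prediction3D with "boxes", "scores", and "labels"):
    #
    #     for frames, targets in eval_loader:
    #         metric.update(model(frames), targets)
    #     result = metric.compute()
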
    def _update_frame(self, pred: Prediction3D, target: Target3D) -> None:
        pred_boxes = pred["boxes"]
        pred_scores = pred["scores"]
        pred_labels = pred["labels"]
        gt_boxes = target["boxes"]
        gt_labels = target["labels"]
        fmt: BoundingBox3DFormat = gt_boxes.format

        # Bucket boxes by the distance of their center from the sensor origin.
        pred_centers, _, _ = _extract_box_params(pred_boxes, fmt)
        gt_centers, _, _ = _extract_box_params(gt_boxes, fmt)
        pred_dist = pred_centers.norm(dim=-1)
        gt_dist = gt_centers.norm(dim=-1)

        for range_bin in self.range_bins or (None,):
            if range_bin is None:
                pred_in_bin = torch.ones_like(pred_dist, dtype=torch.bool)
                gt_in_bin = torch.ones_like(gt_dist, dtype=torch.bool)
            else:
                low, high = range_bin
                pred_in_bin = (pred_dist >= low) & (pred_dist < high)
                gt_in_bin = (gt_dist >= low) & (gt_dist < high)

            for cls in self.class_ids:
                p_mask = pred_in_bin & (pred_labels == cls)
                g_mask = gt_in_bin & (gt_labels == cls)
                cls_pred_boxes = pred_boxes[p_mask]
                cls_pred_scores = pred_scores[p_mask]
                cls_gt_boxes = gt_boxes[g_mask]
                n_pred = cls_pred_boxes.shape[0]
                n_gt = cls_gt_boxes.shape[0]
                if n_pred == 0 and n_gt == 0:
                    continue
                if n_pred > 0 and n_gt > 0:
                    iou = box3d_iou(cls_pred_boxes, cls_gt_boxes, fmt)
                else:
                    iou = torch.zeros(
                        n_pred, n_gt, dtype=torch.float32, device=pred_boxes.device
                    )
                # The IoU matrix is shared across thresholds; only the TP/FP
                # decision is re-made per threshold.
                for thresh in self.iou_thresholds:
                    key = _DetectionStatsKey(cls, thresh, range_bin)
                    state = self._state.setdefault(key, _DetectionStats())
                    state.num_gt += n_gt
                    if n_pred == 0:
                        continue
                    is_tp = _greedy_match(cls_pred_scores, iou, thresh)
                    state.scores.append(cls_pred_scores.detach())
                    state.is_tp.append(is_tp)
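
    # Range-bucketing sketch (hypothetical numbers): a box centered at
    # (3.0, 4.0, 0.0) has sensor distance
    #
    #     torch.tensor([3.0, 4.0, 0.0]).norm(dim=-1)   # 5.0 m
    #
    # so with range_bins=((0.0, 10.0), (10.0, 20.0)) it falls in the first
    # bin only; bins are half-open, so a box at exactly 10.0 m goes to the
    # second bin.
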
    def compute(self) -> MeanAveragePrecision3DResult:
        """Compute the aggregated metric.

        Returns:
            Populated :class:`MeanAveragePrecision3DResult`.
        """
        ap_by_key: dict[_DetectionStatsKey, float] = {}
        for key, state in self._state.items():
            if state.num_gt == 0:
                continue
            scores = (
                torch.cat(state.scores)
                if state.scores
                else torch.empty(0, dtype=torch.float32)
            )
            is_tp = (
                torch.cat(state.is_tp)
                if state.is_tp
                else torch.empty(0, dtype=torch.bool)
            )
            ap_by_key[key] = _compute_ap(
                scores, is_tp, state.num_gt, self.ap_interpolation
            )

        result: MeanAveragePrecision3DResult = {
            "mAP": _mean_defined(ap_by_key.values()),
            "mAP_per_class": {},
            "AP_per_iou": {},
            "AP_per_class_per_iou": {},
        }

        per_class: dict[int, list[float]] = {c: [] for c in self.class_ids}
        for key, ap in ap_by_key.items():
            per_class[key.class_id].append(ap)
        result["mAP_per_class"] = {c: _mean_defined(v) for c, v in per_class.items()}

        per_iou: dict[float, list[float]] = {t: [] for t in self.iou_thresholds}
        for key, ap in ap_by_key.items():
            per_iou[key.iou_threshold].append(ap)
        result["AP_per_iou"] = {t: _mean_defined(v) for t, v in per_iou.items()}

        per_class_iou: dict[tuple[int, float], list[float]] = {}
        for key, ap in ap_by_key.items():
            per_class_iou.setdefault((key.class_id, key.iou_threshold), []).append(ap)
        result["AP_per_class_per_iou"] = {
            k: _mean_defined(v) for k, v in per_class_iou.items()
        }

        if self.range_bins is not None:
            per_range: dict[_RangeBin, list[float]] = {b: [] for b in self.range_bins}
            for key, ap in ap_by_key.items():
                if key.range_bin is not None:
                    per_range[key.range_bin].append(ap)
            result["AP_per_range"] = {
                b: _mean_defined(v) for b, v in per_range.items()
            }

            per_class_range: dict[tuple[int, _RangeBin], list[float]] = {}
            for key, ap in ap_by_key.items():
                if key.range_bin is not None:
                    per_class_range.setdefault(
                        (key.class_id, key.range_bin), []
                    ).append(ap)
            result["AP_per_class_per_range"] = {
                k: _mean_defined(v) for k, v in per_class_range.items()
            }

        return result
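
    # Readout sketch (hypothetical bucket keys): each slot is a plain dict
    # keyed by class ID, IoU threshold, or a (class, threshold) tuple,
    #
    #     result = metric.compute()
    #     car_ap_strict = result["AP_per_class_per_iou"].get((0, 0.7), -1.0)
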
    def reset(self) -> None:
        """Clear all accumulated state."""
        self._state.clear()
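
# Epoch-loop sketch (``epochs`` and ``evaluate_one_epoch`` are hypothetical):
# the metric object is reusable, but state accumulates globally, so call
# ``reset()`` between evaluation passes to keep epochs independent.
#
#     for epoch in range(epochs):
#         evaluate_one_epoch(metric)   # many metric.update(...) calls
#         print(metric.compute()["mAP"])
#         metric.reset()
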
def _greedy_match(scores: Tensor, iou: Tensor, threshold: float) -> Tensor:
    """Greedy one-to-one matching, preds ordered by descending score.

    Args:
        scores: ``[N_pred]`` prediction confidences.
        iou: ``[N_pred, N_gt]`` IoU matrix.
        threshold: IoU threshold below which no match is made.

    Returns:
        Boolean ``[N_pred]`` mask where ``True`` means the prediction was
        assigned a GT above the threshold.
    """
    n_pred, n_gt = iou.shape
    is_tp = torch.zeros(n_pred, dtype=torch.bool, device=iou.device)
    if n_pred == 0 or n_gt == 0:
        return is_tp
    gt_matched = torch.zeros(n_gt, dtype=torch.bool, device=iou.device)
    order = scores.argsort(descending=True).tolist()
    # IoU is non-negative, so -1.0 safely masks out already-matched GTs.
    matched_sentinel = torch.full((), -1.0, dtype=iou.dtype, device=iou.device)
    for i in order:
        row = torch.where(gt_matched, matched_sentinel, iou[i])
        best_val, best_j = row.max(dim=0)
        if best_val.item() >= threshold:
            is_tp[i] = True
            gt_matched[best_j] = True
    return is_tp


def _compute_ap(
    scores: Tensor,
    is_tp: Tensor,
    num_gt: int,
    interpolation: APInterpolation,
) -> float:
    """Compute AP from per-prediction (score, is_tp) tensors and num_gt.

    Returns:
        AP in ``[0, 1]``, or ``-1.0`` if ``num_gt == 0`` (undefined).

    Raises:
        ValueError: If ``interpolation`` is not a known
            :class:`APInterpolation`.
    """
    if num_gt == 0:
        return -1.0
    if scores.numel() == 0:
        return 0.0
    scores = scores.to(torch.float32).cpu()
    is_tp_f = is_tp.to(torch.float32).cpu()
    order = scores.argsort(descending=True)
    tp_cum = is_tp_f[order].cumsum(dim=0)
    fp_cum = (1.0 - is_tp_f[order]).cumsum(dim=0)
    precisions = tp_cum / (tp_cum + fp_cum)
    recalls = tp_cum / num_gt
    # Right envelope: precisions[i] = max(precisions[i:]).
    precisions = precisions.flip(0).cummax(dim=0).values.flip(0)
    if interpolation == APInterpolation.R11:
        targets = torch.linspace(
            0.0, 1.0, 11, dtype=torch.float32, device=recalls.device
        )
        return _sample_ap(recalls, precisions, targets)
    if interpolation == APInterpolation.R40:
        targets = (
            torch.arange(1, 41, dtype=torch.float32, device=recalls.device) / 40.0
        )
        return _sample_ap(recalls, precisions, targets)
    if interpolation == APInterpolation.R101:
        targets = torch.linspace(
            0.0, 1.0, 101, dtype=torch.float32, device=recalls.device
        )
        return _sample_ap(recalls, precisions, targets)
    if interpolation == APInterpolation.ALL_POINTS:
        return _all_points_ap(recalls, precisions)
    msg = f"unknown interpolation: {interpolation}"
    raise ValueError(msg)


def _sample_ap(recalls: Tensor, precisions: Tensor, targets: Tensor) -> float:
    """Sample precision at each target recall level and average.

    Returns:
        Mean of sampled precisions at the target recall levels.
    """
    # First index where recall >= target; recall levels the detector never
    # reaches contribute a precision of zero.
    idx = torch.searchsorted(recalls, targets)
    sampled = torch.zeros_like(targets)
    valid = idx < recalls.numel()
    sampled[valid] = precisions[idx[valid]]
    return float(sampled.mean().item())


def _all_points_ap(recalls: Tensor, precisions: Tensor) -> float:
    """Area-under-curve AP at every recall change (PASCAL VOC 2010+).

    Returns:
        Area under the right-enveloped precision-recall curve.
    """
    zero = torch.zeros(1, dtype=recalls.dtype, device=recalls.device)
    one = torch.ones(1, dtype=recalls.dtype, device=recalls.device)
    mrec = torch.cat([zero, recalls, one])
    mpre = torch.cat([zero, precisions, zero])
    mpre = mpre.flip(0).cummax(dim=0).values.flip(0)
    deltas = mrec[1:] - mrec[:-1]
    return float((deltas * mpre[1:]).sum().item())


def _mean_defined(values: Iterable[float]) -> float:
    """Mean over non-sentinel values, dropping ``-1`` entries.

    Returns:
        Mean of the valid entries, or ``-1.0`` if none are valid.
    """
    vs = [v for v in values if v >= 0]
    if not vs:
        return -1.0
    return sum(vs) / len(vs)
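
# Worked sketch of the helpers above (illustration only; the functions are
# private). With scores [0.9, 0.8, 0.7], TP flags [True, False, True], and
# num_gt=2: tp_cum=[1,1,2] and fp_cum=[0,1,1], so precision=[1.0, 0.5, 0.667]
# at recall=[0.5, 0.5, 1.0]; the right envelope lifts precision to
# [1.0, 0.667, 0.667], and the all-points area is
# 0.5 * 1.0 + 0.5 * 0.667 = 5/6.
#
#     >>> scores = torch.tensor([0.9, 0.8, 0.7])
#     >>> is_tp = torch.tensor([True, False, True])
#     >>> round(_compute_ap(scores, is_tp, 2, APInterpolation.ALL_POINTS), 4)
#     0.8333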