Source code for vision3d.ops._nms_3d
"""3D non-maximum suppression."""
import torch
from torch import Tensor
from vision3d.tensors import BoundingBox3DFormat
from ._box3d_iou import box3d_iou
[docs]
@torch.no_grad()
def nms_3d(
    boxes: Tensor,
    scores: Tensor,
    iou_threshold: float,
    format: BoundingBox3DFormat,
) -> Tensor:
    """Greedy, class-agnostic non-maximum suppression on 3D bounding boxes.

    Boxes are visited in decreasing order of score. Each box that is still
    kept removes every lower-scoring box whose IoU with it is strictly
    greater than ``iou_threshold``.

    Args:
        boxes: ``[N, K]`` boxes to perform NMS on. ``K`` depends on
            ``format``.
        scores: ``[N]`` prediction confidences.
        iou_threshold: Discard any box whose IoU with a higher-scoring
            kept box is strictly greater than this value. A NaN IoU is
            never treated as an overlap.
        format: Format of ``boxes``.

    Returns:
        ``int64`` tensor of indices into ``boxes`` that survived, sorted
        in decreasing order of score. Boxes with equal scores are ordered
        by ascending index into ``boxes``.

    Raises:
        ValueError: If ``boxes`` and ``scores`` disagree on ``N``.
    """
    n = boxes.shape[0]
    if scores.shape[0] != n:
        raise ValueError(
            f"boxes and scores must have the same length, got {n} boxes "
            f"and {scores.shape[0]} scores"
        )
    if n == 0:
        return torch.empty(0, dtype=torch.long, device=boxes.device)

    # A stable sort makes tie-breaking deterministic (lower index wins) and
    # identical across devices; plain ``argsort`` gives no such guarantee.
    order = torch.sort(scores, descending=True, stable=True).indices
    boxes_sorted = boxes[order]
    iou = box3d_iou(boxes_sorted, boxes_sorted, format)  # [N, N]
    # Threshold once up front. Testing ``> threshold`` (instead of keeping
    # ``<= threshold``) means a NaN IoU never suppresses a box, matching
    # the "strictly greater" contract documented above.
    overlaps = iou > iou_threshold

    keep_mask = torch.ones(n, dtype=torch.bool, device=boxes.device)
    # The last box has nobody below it to suppress, so stop at n - 2.
    for i in range(n - 1):
        if not keep_mask[i]:
            # Already suppressed by a higher-scoring box; it suppresses nothing.
            continue
        keep_mask[i + 1 :] &= ~overlaps[i, i + 1 :]
    return order[keep_mask]
[docs]
@torch.no_grad()
def batched_nms_3d(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
    format: BoundingBox3DFormat,
) -> Tensor:
    """Class-aware 3D NMS: runs :func:`nms_3d` independently per class.

    Boxes whose labels in ``idxs`` differ never suppress each other.

    Args:
        boxes: ``[N, K]`` boxes.
        scores: ``[N]`` prediction confidences.
        idxs: ``[N]`` integer class labels.
        iou_threshold: See :func:`nms_3d`.
        format: Format of ``boxes``.

    Returns:
        ``int64`` tensor of indices into ``boxes`` that survived, sorted
        in decreasing order of score. Boxes with equal scores are ordered
        by ascending index into ``boxes``.

    Raises:
        ValueError: If ``boxes``, ``scores`` and ``idxs`` disagree on ``N``.
    """
    n = boxes.shape[0]
    if scores.shape[0] != n or idxs.shape[0] != n:
        raise ValueError(
            f"boxes, scores and idxs must have the same length, got {n} "
            f"boxes, {scores.shape[0]} scores and {idxs.shape[0]} idxs"
        )
    if n == 0:
        return torch.empty(0, dtype=torch.long, device=boxes.device)

    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        in_class = torch.where(idxs == class_id)[0]
        class_keep = nms_3d(boxes[in_class], scores[in_class], iou_threshold, format)
        # ``class_keep`` indexes the per-class subset; map back to ``boxes``.
        keep_mask[in_class[class_keep]] = True

    # ``torch.where`` yields ascending indices; a stable sort preserves that
    # order among equal scores, so the output is deterministic.
    keep_indices = torch.where(keep_mask)[0]
    by_score = torch.sort(scores[keep_indices], descending=True, stable=True).indices
    return keep_indices[by_score]