# Source code for vision3d.transforms._range_filter

"""Range-based filtering for points and boxes."""

from typing import Any, override

import torch
from torch import Tensor

from vision3d.ops._points_in_boxes_3d import _extract_box_params
from vision3d.tensors import BoundingBoxes3D, PointCloud3D

from ._transform import Transform


class RangeFilter3D(Transform):
    """Drop points and boxes outside an axis-aligned 3D region.

    Points are filtered by their xyz coordinates; boxes are filtered by
    their **center** (format-agnostic). Labels in ``targets`` are filtered
    in sync with boxes.

    The region is half-open on every axis: a coordinate ``v`` is kept when
    ``min <= v < max``.

    **Must** be applied after spatial augmentations (rotate / scale /
    translate can push data out of the sensor range) and before the model
    sees the data.

    Args:
        point_cloud_range: Axis-aligned bounds
            ``(x_min, y_min, z_min, x_max, y_max, z_max)``.

    Raises:
        ValueError: If ``point_cloud_range`` does not have exactly 6
            elements, or if any ``*_min`` is not strictly less than the
            matching ``*_max`` (such a range would silently drop every
            point and box).
    """

    def __init__(self, point_cloud_range: tuple[float, ...]) -> None:
        super().__init__()
        # Materialize first so any iterable (list, generator, ...) can be
        # length-checked and sliced.
        bounds = tuple(point_cloud_range)
        if len(bounds) != 6:
            msg = (
                "point_cloud_range must have 6 elements "
                "(x_min, y_min, z_min, x_max, y_max, z_max)"
            )
            raise ValueError(msg)
        # With the half-open test below, min >= max on any axis keeps nothing,
        # which is almost certainly a configuration mistake.
        if any(lo >= hi for lo, hi in zip(bounds[:3], bounds[3:], strict=True)):
            msg = f"point_cloud_range must satisfy min < max on every axis, got {bounds}"
            raise ValueError(msg)
        self.point_cloud_range = bounds

    @override
    def forward(self, *inputs: Any) -> Any:
        """Filter points and boxes outside the configured range.

        Accepts both a single sample dict and an ``(inputs, targets)`` pair.

        Returns:
            Filtered sample in the same structure as the input.

        Raises:
            ValueError: If called with anything other than one or two
                positional arguments.
        """
        if len(inputs) == 1:
            return self._filter_sample(inputs[0])
        if len(inputs) != 2:
            msg = f"RangeFilter3D expects 1 or 2 positional inputs, got {len(inputs)}"
            raise ValueError(msg)
        inputs_dict, targets_dict = inputs
        return self._filter_inputs(inputs_dict), self._filter_targets(targets_dict)

    def _filter_inputs(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Return a shallow copy of ``inputs`` with ``"points"`` range-filtered."""
        out = dict(inputs)
        if "points" in out:
            out["points"] = self._filter_points(out["points"])
        return out

    def _filter_targets(self, targets: dict[str, Any]) -> dict[str, Any]:
        """Return a shallow copy of ``targets`` with boxes/labels range-filtered."""
        out = dict(targets)
        self._apply_box_mask(out)
        return out

    def _filter_sample(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Return a shallow copy of a combined sample with points, boxes and labels filtered."""
        out = self._filter_inputs(sample)
        self._apply_box_mask(out)
        return out

    def _apply_box_mask(self, d: dict[str, Any]) -> None:
        """Filter ``d["boxes"]`` (and ``d["labels"]`` if present) in place by box center.

        No-op when ``d`` has no ``"boxes"`` entry.
        """
        if "boxes" not in d:
            return
        boxes = d["boxes"]
        keep = self._box_keep_mask(boxes)
        # Re-wrap the plain-tensor slice; only ``format`` is carried over.
        d["boxes"] = BoundingBoxes3D(boxes.as_subclass(Tensor)[keep], format=boxes.format)
        if "labels" in d:
            # NOTE(review): assumes labels supports boolean-tensor indexing and
            # is aligned row-for-row with boxes — confirm against callers.
            d["labels"] = d["labels"][keep]

    def _in_range_mask(self, xyz: Tensor) -> Tensor:
        """Return a boolean mask of rows of ``xyz`` inside the half-open range.

        Args:
            xyz: Tensor whose last dimension holds ``(x, y, z)``.

        Returns:
            Boolean tensor over the leading dimensions; ``True`` where
            ``min <= coord < max`` holds on all three axes.
        """
        min_bound = torch.tensor(
            self.point_cloud_range[:3], dtype=xyz.dtype, device=xyz.device
        )
        max_bound = torch.tensor(
            self.point_cloud_range[3:], dtype=xyz.dtype, device=xyz.device
        )
        return ((xyz >= min_bound) & (xyz < max_bound)).all(dim=-1)

    def _filter_points(self, points: PointCloud3D) -> PointCloud3D:
        """Keep only points whose first three columns lie inside the range."""
        pts = points.as_subclass(Tensor)
        keep = self._in_range_mask(pts[:, :3])
        # All columns (including any beyond xyz) of the kept rows are retained.
        return PointCloud3D(pts[keep])

    def _box_keep_mask(self, boxes: BoundingBoxes3D) -> Tensor:
        """Return a boolean mask of boxes whose center lies inside the range."""
        raw = boxes.as_subclass(Tensor)
        centers, _, _ = _extract_box_params(raw, boxes.format)
        return self._in_range_mask(centers)