# Source code for vision3d.ops._points_in_boxes_3d

"""Point-in-box tests for 3D bounding boxes."""

import torch
from torch import Tensor

from vision3d.tensors import BoundingBox3DFormat


def points_in_boxes_3d(
    points: Tensor,
    boxes: Tensor,
    format: BoundingBox3DFormat,
) -> Tensor:
    """Compute a boolean point-in-box membership mask.

    Every box format is supported, from axis-aligned boxes up to fully
    rotated 9-DOF boxes (yaw, pitch, roll).

    Args:
        points: Point cloud ``[N, 3+C]``. Only the leading x, y, z columns
            are used; any extra feature columns are ignored.
        boxes: 3D bounding boxes ``[M, K]``; ``K`` depends on ``format``.
        format: Layout of the rows of ``boxes``.

    Returns:
        Boolean tensor ``[N, M]``; entry ``(i, j)`` is True when point ``i``
        lies inside box ``j``.

    Raises:
        ValueError: If ``format`` is not a supported format.
    """
    # Decompose every box into (center, half-extent, rotation) first, so an
    # unsupported format is reported before the points are touched.
    box_centers, box_half_extents, box_rotations = _extract_box_params(
        boxes, format
    )
    xyz = points[:, :3]
    return _points_in_rotated_boxes(
        xyz, box_centers, box_half_extents, box_rotations
    )
def points_in_boxes_3d_indices(
    points: Tensor,
    boxes: Tensor,
    format: BoundingBox3DFormat,
) -> Tensor:
    """Return per-point box assignment.

    If a point is inside multiple boxes, the first (lowest index) box wins.

    Args:
        points: Point cloud coordinates ``[N, 3+C]``.
        boxes: 3D bounding boxes ``[M, K]``.
        format: Format of the bounding boxes.

    Returns:
        Integer (``torch.long``) tensor ``[N]`` with the index of the box each
        point belongs to, or ``-1`` if the point is not in any box.
    """
    mask = points_in_boxes_3d(points, boxes, format)  # [N, M]
    if mask.shape[1] == 0:
        # No boxes at all: argmax over a zero-sized dim would raise, so
        # every point is reported as unassigned directly.
        return torch.full(
            (points.shape[0],), -1, dtype=torch.long, device=points.device
        )
    # argmax returns the index of the *first* maximal value, i.e. the
    # lowest-index box that contains the point. For a row with no True entry
    # it returns 0 (not an out-of-range index), so such rows must be replaced
    # with -1 explicitly below.
    first_box = mask.to(torch.uint8).argmax(dim=1)
    in_any = mask.any(dim=1)
    return torch.where(in_any, first_box, torch.full_like(first_box, -1))
def _build_rotation_matrix( yaw: Tensor, pitch: Tensor | None = None, roll: Tensor | None = None, ) -> Tensor: """Build ``[M, 3, 3]`` rotation matrices from Tait-Bryan ZY'X'' angles. When pitch and roll are None, builds a yaw-only Rz rotation (avoids unnecessary trig for the common case). Args: yaw: Yaw angles ``[M]`` in radians. pitch: Pitch angles ``[M]`` in radians, or None. roll: Roll angles ``[M]`` in radians, or None. Returns: Rotation matrices ``[M, 3, 3]``. """ m = yaw.shape[0] cy = torch.cos(yaw) sy = torch.sin(yaw) if pitch is None or roll is None: # Yaw-only: Rz(yaw) rot = torch.zeros(m, 3, 3, dtype=yaw.dtype, device=yaw.device) rot[:, 0, 0] = cy rot[:, 0, 1] = -sy rot[:, 1, 0] = sy rot[:, 1, 1] = cy rot[:, 2, 2] = 1.0 return rot # Full Tait-Bryan ZY'X'': R = Rz(yaw) @ Ry(pitch) @ Rx(roll) cp = torch.cos(pitch) sp = torch.sin(pitch) cr = torch.cos(roll) sr = torch.sin(roll) rot = torch.empty(m, 3, 3, dtype=yaw.dtype, device=yaw.device) rot[:, 0, 0] = cy * cp rot[:, 0, 1] = cy * sp * sr - sy * cr rot[:, 0, 2] = cy * sp * cr + sy * sr rot[:, 1, 0] = sy * cp rot[:, 1, 1] = sy * sp * sr + cy * cr rot[:, 1, 2] = sy * sp * cr - cy * sr rot[:, 2, 0] = -sp rot[:, 2, 1] = cp * sr rot[:, 2, 2] = cp * cr return rot def _extract_box_params( boxes: Tensor, format: BoundingBox3DFormat ) -> tuple[Tensor, Tensor, Tensor]: """Extract centers, half-dimensions, and rotation matrix from boxes. Returns: ``(centers, half_dims, rot)`` where ``centers`` and ``half_dims`` are ``[M, 3]`` and ``rot`` is ``[M, 3, 3]``. Raises: ValueError: If ``format`` is not a supported format. 
""" if format is BoundingBox3DFormat.XYZXYZ: mins = boxes[:, :3] maxs = boxes[:, 3:6] centers = (mins + maxs) / 2 half_dims = (maxs - mins) / 2 yaw = torch.zeros(boxes.shape[0], dtype=boxes.dtype, device=boxes.device) rot = _build_rotation_matrix(yaw) elif format is BoundingBox3DFormat.XYZLWH: centers = boxes[:, :3] half_dims = boxes[:, 3:6] / 2 yaw = torch.zeros(boxes.shape[0], dtype=boxes.dtype, device=boxes.device) rot = _build_rotation_matrix(yaw) elif format is BoundingBox3DFormat.XYZLWHY: centers = boxes[:, :3] half_dims = boxes[:, 3:6] / 2 rot = _build_rotation_matrix(boxes[:, 6]) elif format is BoundingBox3DFormat.XYZLWHYPR: centers = boxes[:, :3] half_dims = boxes[:, 3:6] / 2 rot = _build_rotation_matrix(boxes[:, 6], boxes[:, 7], boxes[:, 8]) else: msg = f"Unsupported format: {format}" raise ValueError(msg) return centers, half_dims, rot def _points_in_rotated_boxes( xyz: Tensor, centers: Tensor, half_dims: Tensor, rot: Tensor ) -> Tensor: """Check if points are inside arbitrarily rotated boxes. Args: xyz: Point positions ``[N, 3]``. centers: Box centers ``[M, 3]``. half_dims: Box half-extents ``[M, 3]`` (half_l, half_w, half_h). rot: Rotation matrices ``[M, 3, 3]``. Returns: Boolean ``[N, M]``. """ # Relative positions: [N, 1, 3] - [1, M, 3] -> [N, M, 3] rel = xyz.unsqueeze(1) - centers.unsqueeze(0) # Rotate into box local frame by R^T (inverse rotation) # rel: [N, M, 3], rot^T: [M, 3, 3] -> local: [N, M, 3] # Einstein: local_j = rel_k * R_jk (R^T has j,k swapped) local = torch.einsum("nmk,mjk->nmj", rel, rot) return (local.abs() <= half_dims.unsqueeze(0)).all(dim=-1)