Source code for vision3d.ops._box3d_corners

"""Compute 3D bounding box corners."""

import torch
from torch import Tensor

from vision3d.tensors import BoundingBox3DFormat

from ._points_in_boxes_3d import _extract_box_params


# Axis signs for each of the 8 corners, one (x, y, z) row per corner index.
# Rows 0-3 carry -z (bottom face) and rows 4-7 carry +z (top face); inside
# each face the (x, y) signs run (-,-), (+,-), (+,+), (-,+).
_CORNER_SIGNS = (
    (-1, -1, -1),
    (+1, -1, -1),
    (+1, +1, -1),
    (-1, +1, -1),
    (-1, -1, +1),
    (+1, -1, +1),
    (+1, +1, +1),
    (-1, +1, +1),
)


def box3d_corners(
    boxes: Tensor,
    format: BoundingBox3DFormat,
) -> Tensor:
    r"""Return the 8 world-space corner points of each 3D bounding box.

    Decoding of ``boxes`` is delegated to ``_extract_box_params``; this
    function only scales a fixed sign pattern by the half-extents it returns,
    rotates the result with the returned rotation matrix and translates it by
    the returned centers. Because a full rotation matrix is applied, any
    orientation the helper can decode is handled here (the original
    documentation lists all rotation formats up to 9-DOF yaw/pitch/roll).

    Corner ordering::

        4 -------- 5         z   x
        |\         |\        |  /
        | 7 -------| 6       | /
        | |        | |       |/
        0 |--------1 |       +------ y
         \|         \|
          3 -------- 2

    Bottom face (z-): {0, 1, 2, 3}. Top face (z+): {4, 5, 6, 7}.

    The index-to-sign mapping actually implemented is, in box-local axes::

        0: (-x, -y, -z)    1: (+x, -y, -z)    2: (+x, +y, -z)    3: (-x, +y, -z)
        4: (-x, -y, +z)    5: (+x, -y, +z)    6: (+x, +y, +z)    7: (-x, +y, +z)

    NOTE(review): the axis legend in the sketch was not checked against the
    sign table; treat the table as authoritative.

    Args:
        boxes: 3D bounding boxes ``[N, K]``.
        format: Format of the bounding boxes.

    Returns:
        Corner coordinates ``[N, 8, 3]``.
    """
    box_centers, half_extents, rotation = _extract_box_params(boxes, format)

    # Sign pattern materialized with the dtype/device of the input: [8, 3]
    corner_signs = torch.tensor(
        _CORNER_SIGNS,
        dtype=boxes.dtype,
        device=boxes.device,
    )

    # Axis-aligned corner offsets in each box's own frame: [N, 8, 3]
    # (assumes half_extents is [N, 3], as the helper's name suggests).
    offsets_local = half_extents[:, None] * corner_signs[None]

    # Apply the per-box rotation to every corner:
    # out[n, k, i] = sum_j rotation[n, i, j] * offsets_local[n, k, j]
    offsets_world = torch.einsum("nij,nkj->nki", rotation, offsets_local)

    # Shift from box-centered offsets to absolute positions.
    return offsets_world + box_centers[:, None]