"""3D oriented bounding box overlap using the Separating Axis Theorem."""
import torch
from torch import Tensor
from vision3d.tensors import BoundingBox3DFormat
from ._points_in_boxes_3d import _extract_box_params
[docs]
def box3d_overlap(
boxes1: Tensor,
boxes2: Tensor,
format: BoundingBox3DFormat,
) -> Tensor:
"""Check 3D overlap between two sets of oriented bounding boxes.
Uses the Separating Axis Theorem (SAT) with 15 potential separating
axes (3 face normals per box + 9 edge cross products).
Args:
boxes1: First set of boxes ``[N, K]``.
boxes2: Second set of boxes ``[M, K]``.
format: Format of both box sets.
Returns:
Boolean matrix ``[N, M]`` where True indicates overlap.
"""
centers1, half1, rot1 = _extract_box_params(boxes1, format)
centers2, half2, rot2 = _extract_box_params(boxes2, format)
# Pairwise center difference: [N, M, 3]
diff = centers2.unsqueeze(0) - centers1.unsqueeze(1)
# Precompute R1^T @ R2 and projections of diff onto each rotation frame.
# dot1[n, m, i] = diff[n,m] . rot1[n, :, i]
dot1 = torch.einsum("nmk,nik->nmi", diff, rot1) # [N, M, 3]
# dot2[n, m, j] = diff[n,m] . rot2[m, :, j]
dot2 = torch.einsum("nmk,mjk->nmj", diff, rot2) # [N, M, 3]
# c[n, m, i, j] = rot1[n, :, i] . rot2[m, :, j]
c = torch.einsum("nik,mjk->nmij", rot1, rot2) # [N, M, 3, 3]
abs_c = c.abs()
overlap = torch.ones(
centers1.shape[0], centers2.shape[0], dtype=torch.bool, device=boxes1.device
)
# Face normals of box1 (axes i=0,1,2)
for i in range(3):
d = dot1[:, :, i].abs()
r1 = half1[:, i].unsqueeze(1)
r2 = (abs_c[:, :, i, :] * half2.unsqueeze(0)).sum(dim=-1)
overlap &= d <= r1 + r2
# Face normals of box2 (axes j=0,1,2)
for j in range(3):
d = dot2[:, :, j].abs()
r1 = (abs_c[:, :, :, j] * half1.unsqueeze(1)).sum(dim=-1)
r2 = half2[:, j].unsqueeze(0)
overlap &= d <= r1 + r2
# Edge cross products: rot1[:,:,i] x rot2[:,:,j]
# For axis a_i x b_j, the projections simplify using the c matrix.
for i in range(3):
i1 = (i + 1) % 3
i2 = (i + 2) % 3
for j in range(3):
j1 = (j + 1) % 3
j2 = (j + 2) % 3
# d = |diff . (a_i x b_j)| = |dot1_i2 * c_i1j - dot1_i1 * c_i2j|
# (using triple product expansion)
d = (
dot1[:, :, i1] * c[:, :, i2, j] - dot1[:, :, i2] * c[:, :, i1, j]
).abs()
r1 = (
half1[:, i1].unsqueeze(1) * abs_c[:, :, i2, j]
+ half1[:, i2].unsqueeze(1) * abs_c[:, :, i1, j]
)
r2 = (
half2[:, j1].unsqueeze(0) * abs_c[:, :, i, j2]
+ half2[:, j2].unsqueeze(0) * abs_c[:, :, i, j1]
)
overlap &= d <= r1 + r2
return overlap