"""3D copy-paste data augmentation with lazy object database."""

import math
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Any, override

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from torch import Tensor
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.tv_tensors import TVTensor

from vision3d.ops import (
    box3d_corners,
    box3d_overlap,
    points_in_boxes_3d,
    points_in_boxes_3d_indices,
    project_to_image,
)
from vision3d.ops._points_in_boxes_3d import _extract_box_params
from vision3d.tensors import (
    BoundingBox3DFormat,
    BoundingBoxes3D,
    CameraExtrinsics,
    CameraImages,
    CameraIntrinsics,
    PointCloud3D,
)
from vision3d.transforms._transform import Transform


@dataclass
class CameraCrop:
    """Image crop and convex-hull mask for one camera view of an object.

    Attributes:
        crop: Cropped image region ``[C, crop_h, crop_w]``.
        mask: Boolean hull mask ``[crop_h, crop_w]``.
        bbox: Bounding box in image coords ``(x_min, y_min, x_max, y_max)``.
    """

    crop: Tensor
    mask: Tensor
    bbox: tuple[int, int, int, int]


@dataclass
class ObjectEntry:
    """A single object extracted from a scene.

    Attributes:
        points: Points in scene frame ``[M, 3+C]``, or ``None`` for
            camera-only entries.
        box: Full box tensor ``[K]`` in its original format.
        label: Integer class label.
        camera_crops: Per-camera crops, or ``None`` when no camera data is
            available. ``camera_crops[i]`` is ``None`` if the object is not
            visible in camera ``i``.
    """

    points: Tensor | None
    box: Tensor
    label: int
    camera_crops: list[CameraCrop | None] | None = field(default=None, repr=False)


def _convex_hull_2d(
    points: list[tuple[float, float]],
) -> list[tuple[float, float]]:
    """Compute 2-D convex hull (Andrew's monotone chain).

    Pure-Python implementation for small point sets (≤ 8 points).

    Args:
        points: List of ``(x, y)`` pairs.

    Returns:
        Hull vertices in counter-clockwise order.
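
    Example (illustrative; interior points are dropped)::

        >>> _convex_hull_2d([(0, 0), (2, 0), (2, 2), (0, 2), (1, 1)])
        [(0, 0), (2, 0), (2, 2), (0, 2)]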
    """
    pts = sorted(points)

    # Lower hull: pop while the turn o -> a -> p is clockwise or collinear
    lower: list[tuple[float, float]] = []
    for p in pts:
        while len(lower) >= 2:
            o, a = lower[-2], lower[-1]
            if (a[0] - o[0]) * (p[1] - o[1]) - (a[1] - o[1]) * (p[0] - o[0]) > 0:
                break
            lower.pop()
        lower.append(p)

    # Upper hull: same test over the points in reverse order
    upper: list[tuple[float, float]] = []
    for p in reversed(pts):
        while len(upper) >= 2:
            o, a = upper[-2], upper[-1]
            if (a[0] - o[0]) * (p[1] - o[1]) - (a[1] - o[1]) * (p[0] - o[0]) > 0:
                break
            upper.pop()
        upper.append(p)

    return lower[:-1] + upper[:-1]


def _fill_convex_polygon(
    vertices: list[tuple[float, float]],
    height: int,
    width: int,
    device: torch.device,
) -> Tensor:
    """Rasterise a convex polygon into a boolean mask using Pillow.

    Args:
        vertices: CCW-ordered hull vertices ``(x, y)`` in crop-local coords.
        height: Mask height (pixels).
        width: Mask width (pixels).
        device: Device for the output mask.

    Returns:
        Boolean mask ``[height, width]``.
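
    Example (illustrative)::

        >>> m = _fill_convex_polygon(
        ...     [(0, 0), (3, 0), (3, 3)], 4, 4, torch.device("cpu")
        ... )
        >>> m.shape, m.dtype
        (torch.Size([4, 4]), torch.bool)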
    """
    img = Image.new("L", (width, height), 0)
    ImageDraw.Draw(img).polygon(vertices, fill=1)
    mask_np = np.frombuffer(img.tobytes(), dtype=np.uint8).reshape(height, width)
    return torch.from_numpy(mask_np.copy()).bool().to(device)


_HullMaskResult = tuple[Tensor, tuple[int, int, int, int], float]


def _project_boxes_to_camera(
    boxes: Tensor,
    fmt: BoundingBox3DFormat,
    extrinsic: Tensor,
    intrinsic: Tensor,
) -> tuple[Tensor, Tensor]:
    """Project all box corners into a single camera at once.

    Args:
        boxes: ``[M, K]`` 3-D bounding boxes.
        fmt: Box format.
        extrinsic: ``[4, 4]`` lidar-to-camera.
        intrinsic: ``[3, 3]`` camera K.

    Returns:
        Tuple of ``(uv, depth)`` where ``uv`` is ``[M, 8, 2]`` pixel
        coordinates and ``depth`` is ``[M, 8]``.
    """
    corners = box3d_corners(boxes, fmt)  # [M, 8, 3]
    m = corners.shape[0]
    flat = corners.reshape(m * 8, 3)
    uv_flat, depth_flat = project_to_image(flat, extrinsic, intrinsic)
    return uv_flat.reshape(m, 8, 2), depth_flat.reshape(m, 8)


def _hull_mask_from_projected(
    uv: Tensor,
    depth: Tensor,
    img_h: int,
    img_w: int,
) -> _HullMaskResult | None:
    """Compute a convex-hull mask from pre-projected corners.

    Converts to plain Python immediately and does all geometry in floats
    to avoid torch dispatch overhead on tiny (8-element) tensors.

    Args:
        uv: ``[8, 2]`` pixel coordinates for one box.
        depth: ``[8]`` depth values for one box.
        img_h: Image height.
        img_w: Image width.

    Returns:
        ``(mask, bbox, depth)`` or ``None`` if the object is not visible.
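
    Example (illustrative; eight corners collapsed onto one square)::

        >>> uv = torch.tensor([[0.0, 0.0], [4.0, 0.0], [4.0, 4.0], [0.0, 4.0]] * 2)
        >>> depth = torch.full((8,), 2.0)
        >>> mask, bbox, mean_depth = _hull_mask_from_projected(uv, depth, 10, 10)
        >>> bbox, mean_depth
        ((0, 0, 4, 4), 2.0)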
    """
    uv_list: list[list[float]] = uv.tolist()
    depth_list: list[float] = depth.tolist()

    visible: list[tuple[float, float]] = []
    depth_sum = 0.0
    for i in range(len(depth_list)):
        if depth_list[i] > 0:
            visible.append((uv_list[i][0], uv_list[i][1]))
            depth_sum += depth_list[i]

    if len(visible) < 3:
        return None

    hull = _convex_hull_2d(visible)
    if len(hull) < 3:
        return None

    # Bounding rect of hull, clipped to image
    hull_xs = [p[0] for p in hull]
    hull_ys = [p[1] for p in hull]
    x_min = max(math.floor(min(hull_xs)), 0)
    y_min = max(math.floor(min(hull_ys)), 0)
    x_max = min(math.ceil(max(hull_xs)), img_w - 1)
    y_max = min(math.ceil(max(hull_ys)), img_h - 1)

    crop_h = y_max - y_min + 1
    crop_w = x_max - x_min + 1
    if crop_h < 1 or crop_w < 1:
        return None

    # Shift hull to crop-local coordinates and fill
    local_hull = [(x - x_min, y - y_min) for x, y in hull]
    mask = _fill_convex_polygon(local_hull, crop_h, crop_w, uv.device)

    mean_depth = depth_sum / len(visible)
    return mask, (x_min, y_min, x_max, y_max), mean_depth


def _batch_hull_masks(
    boxes: Tensor,
    fmt: BoundingBox3DFormat,
    extrinsic: Tensor,
    intrinsic: Tensor,
    img_h: int,
    img_w: int,
) -> list[_HullMaskResult | None]:
    """Compute hull masks for multiple boxes in one camera (batched projection).

    Args:
        boxes: ``[M, K]`` boxes.
        fmt: Box format.
        extrinsic: ``[4, 4]``.
        intrinsic: ``[3, 3]``.
        img_h: Image height.
        img_w: Image width.

    Returns:
        List of length ``M``, each element a ``(mask, bbox, depth)`` tuple
        or ``None``.
    """
    if boxes.shape[0] == 0:
        return []
    uv_all, depth_all = _project_boxes_to_camera(boxes, fmt, extrinsic, intrinsic)
    return [
        _hull_mask_from_projected(uv_all[i], depth_all[i], img_h, img_w)
        for i in range(boxes.shape[0])
    ]


class CopyPaste3D(Transform):
    """Batch-level 3D copy-paste data augmentation.

    Maintains a lazy object database that grows as batches pass through.
    For each sample, pastes additional objects from the database to reach
    a target count per class. Objects are pasted at their original scene
    position from the source frame.

    Operates on collated batches ``(tuple_of_inputs, tuple_of_targets)``,
    not individual samples. Each instance should be used with only one
    dataset to avoid cross-contamination.

    :class:`CopyPaste3D` **must** be the first transform in any pipeline,
    before any 3D spatial transform (:class:`RandomFlip3D`,
    :class:`RandomRotate3D`, :class:`RandomScale3D`,
    :class:`RandomTranslate3D`). Pasted objects are extracted and
    re-inserted in the source-frame geometry of the scene they came from.
    If a scene transform has already mutated the frame, the pasted objects
    will disagree with the rest of the scene and the resulting boxes/points
    will be inconsistent.

    Args:
        target_counts: Dict mapping integer class label to desired object
            count per sample. E.g. ``{0: 15, 1: 10}``.
        min_points: Minimum number of points an extracted object must have
            to be stored in the database. Default: ``5``.
        max_database_size: Maximum entries per class. None means unlimited.
            Default: ``None``.
        p: Probability of applying the augmentation. Default: ``1.0``.
    """

    def __init__(
        self,
        target_counts: dict[int, int],
        min_points: int = 5,
        max_database_size: int | None = None,
        p: float = 1.0,
    ) -> None:
        super().__init__()
        if not (0.0 <= p <= 1.0):
            msg = "`p` should be a float in [0.0, 1.0]."
            raise ValueError(msg)
        if min_points < 1:
            msg = "`min_points` should be a positive integer."
            raise ValueError(msg)
        if max_database_size is not None and max_database_size < 1:
            msg = "`max_database_size` should be a positive integer or None."
            raise ValueError(msg)
        self.target_counts = target_counts
        self.min_points = min_points
        self.max_database_size = max_database_size
        self.p = p
        self._database: dict[int, deque[ObjectEntry]] = defaultdict(
            lambda: deque(maxlen=self.max_database_size)
        )

    @override
    def forward(self, *inputs: Any) -> Any:
        """Apply copy-paste augmentation to a collated batch.

        Accepts any pytree structure containing
        :class:`~vision3d.tensors.PointCloud3D`,
        :class:`~vision3d.tensors.BoundingBoxes3D`, and optionally camera
        tensors and plain-tensor labels.

        Returns:
            The same pytree structure with modified leaves.
        """
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        batch_inputs, batch_targets = self._extract_samples(flat_inputs)
        for inp, tgt in zip(batch_inputs, batch_targets):
            self._extract_objects(inp, tgt)

        if torch.rand(1).item() >= self.p:
            return tree_unflatten(flat_inputs, spec)

        output_inputs = []
        output_targets = []
        for inp, tgt in zip(batch_inputs, batch_targets):
            new_inp, new_tgt = self._paste_objects(inp, tgt)
            output_inputs.append(new_inp)
            output_targets.append(new_tgt)

        self._insert_outputs(flat_inputs, output_inputs, output_targets)
        return tree_unflatten(flat_inputs, spec)

    def _extract_samples(
        self, flat_inputs: list[Any]
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """Group flat pytree leaves into per-sample input and target dicts.

        Returns:
            ``(batch_inputs, batch_targets)``: Lists of per-sample dicts.

        Raises:
            TypeError: If required types are missing or counts don't match.
        """
        points: list[PointCloud3D] = []
        boxes: list[BoundingBoxes3D] = []
        images: list[CameraImages] = []
        extrinsics: list[CameraExtrinsics] = []
        intrinsics: list[CameraIntrinsics] = []
        labels: list[Tensor] = []
        for obj in flat_inputs:
            if isinstance(obj, PointCloud3D):
                points.append(obj)
            elif isinstance(obj, BoundingBoxes3D):
                boxes.append(obj)
            elif isinstance(obj, CameraImages):
                images.append(obj)
            elif isinstance(obj, CameraExtrinsics):
                extrinsics.append(obj)
            elif isinstance(obj, CameraIntrinsics):
                intrinsics.append(obj)
            elif isinstance(obj, Tensor) and not isinstance(obj, TVTensor):
                labels.append(obj)

        n = len(boxes)
        has_points = len(points) > 0
        has_cameras = len(images) > 0
        has_labels = len(labels) > 0

        mismatched: list[str] = []
        if has_points and len(points) != n:
            mismatched.append(f"PointCloud3D ({len(points)})")
        if has_cameras and len(images) != n:
            mismatched.append(f"CameraImages ({len(images)})")
        if has_cameras and len(extrinsics) != n:
            mismatched.append(f"CameraExtrinsics ({len(extrinsics)})")
        if has_cameras and len(intrinsics) != n:
            mismatched.append(f"CameraIntrinsics ({len(intrinsics)})")
        if has_labels and len(labels) != n:
            mismatched.append(f"plain tensors ({len(labels)})")
        if mismatched:
            raise TypeError(
                f"{type(self).__name__}() requires equal sized lists of "
                f"inputs per sample. Got {n} BoundingBoxes3D but "
                f"{', '.join(mismatched)}."
            )

        batch_inputs: list[dict[str, Any]] = []
        batch_targets: list[dict[str, Any]] = []
        for i in range(n):
            inp: dict[str, Any] = {}
            if has_points:
                inp["points"] = points[i]
            if has_cameras:
                inp["images"] = images[i]
                inp["extrinsics"] = extrinsics[i]
                inp["intrinsics"] = intrinsics[i]
            tgt: dict[str, Any] = {"boxes": boxes[i]}
            if has_labels:
                tgt["labels"] = labels[i]
            batch_inputs.append(inp)
            batch_targets.append(tgt)
        return batch_inputs, batch_targets

    def _insert_outputs(
        self,
        flat_inputs: list[Any],
        output_inputs: list[dict[str, Any]],
        output_targets: list[dict[str, Any]],
    ) -> None:
        """Replace modified leaves in *flat_inputs* in-place.

        Uses per-type counters to walk through the flat list and replace
        each leaf with the corresponding value from the output dicts.
""" c_pts = 0 c_img = 0 c_box = 0 c_lbl = 0 for i, obj in enumerate(flat_inputs): if isinstance(obj, PointCloud3D): flat_inputs[i] = output_inputs[c_pts]["points"] c_pts += 1 elif isinstance(obj, CameraImages): new_img = output_inputs[c_img].get("images") if new_img is not None: flat_inputs[i] = new_img c_img += 1 elif isinstance(obj, BoundingBoxes3D): flat_inputs[i] = output_targets[c_box]["boxes"] c_box += 1 elif isinstance(obj, (CameraExtrinsics, CameraIntrinsics)): pass # never modified by copy-paste elif isinstance(obj, Tensor) and not isinstance(obj, TVTensor): new_lbl = output_targets[c_lbl].get("labels") if new_lbl is not None: flat_inputs[i] = new_lbl c_lbl += 1 def _has_camera_data(self, inputs: dict[str, Any]) -> bool: return "images" in inputs and "extrinsics" in inputs and "intrinsics" in inputs def _extract_objects(self, inputs: dict[str, Any], targets: dict[str, Any]) -> None: """Extract per-object point clouds and store in database.""" points = inputs.get("points") boxes = targets["boxes"] labels = targets.get("labels", torch.zeros(0, dtype=torch.long)) if boxes.shape[0] == 0 or labels.shape[0] == 0: return fmt = boxes.format # Find valid objects: With point clouds this means meeting min_points, # for camera-only inputs all labeled boxes are valid. valid: list[tuple[int, Tensor | None]] = [] if points is not None: indices = points_in_boxes_3d_indices(points, boxes, fmt) for j in range(boxes.shape[0]): if j >= labels.shape[0]: break mask = indices == j obj_points = points[mask] if obj_points.shape[0] >= self.min_points: valid.append((j, obj_points)) else: for j in range(min(boxes.shape[0], labels.shape[0])): valid.append((j, None)) # Batch camera crop extraction for all valid objects at once has_cameras = self._has_camera_data(inputs) camera_crops_map: dict[int, list[CameraCrop | None]] = {} if has_cameras and valid: camera_crops_map = self._extract_all_camera_crops( boxes, fmt, inputs, [j for j, _ in valid] ) for j, obj_points in valid: label = int(labels[j].item()) entry = ObjectEntry( points=obj_points.detach().cpu() if obj_points is not None else None, box=boxes[j].detach().cpu(), label=label, camera_crops=camera_crops_map.get(j), ) self._database[label].append(entry) def _extract_all_camera_crops( self, boxes: Tensor, fmt: BoundingBox3DFormat, inputs: dict[str, Any], valid_indices: list[int], ) -> dict[int, list[CameraCrop | None]]: """Extract image crops for multiple objects from all camera views. Uses batched projection per camera to avoid per-object overhead. Args: boxes: All boxes ``[M, K]``. fmt: Box format. inputs: Input dict with camera data. valid_indices: Indices into ``boxes`` for objects that passed the min_points filter. Returns: Dict mapping box index to per-camera crop list. 
""" images = inputs["images"] # [N, C, H, W] extrinsics = inputs["extrinsics"] # [N, 4, 4] intrinsics = inputs["intrinsics"] # [N, 3, 3] n_cams = images.shape[0] img_h, img_w = images.shape[2], images.shape[3] if not valid_indices: return {} valid_boxes = boxes[valid_indices] # [V, K] result: dict[int, list[CameraCrop | None]] = { j: [None] * n_cams for j in valid_indices } for cam_idx in range(n_cams): uv_all, depth_all = _project_boxes_to_camera( valid_boxes, fmt, extrinsics[cam_idx], intrinsics[cam_idx] ) for vi, j in enumerate(valid_indices): hull_result = _hull_mask_from_projected( uv_all[vi], depth_all[vi], img_h, img_w ) if hull_result is None: continue mask, (x_min, y_min, x_max, y_max), _depth = hull_result crop = images[cam_idx, :, y_min : y_max + 1, x_min : x_max + 1] result[j][cam_idx] = CameraCrop( crop=crop.detach().cpu(), mask=mask.detach().cpu(), bbox=(x_min, y_min, x_max, y_max), ) return result def _paste_objects( self, inputs: dict[str, Any], targets: dict[str, Any], ) -> tuple[dict[str, Any], dict[str, Any]]: """Paste objects from database into a single sample. Returns: Modified ``(inputs, targets)`` dicts. """ points = inputs.get("points") boxes = targets["boxes"] labels = targets.get("labels", torch.zeros(0, dtype=torch.long)) fmt = boxes.format # Count existing objects per label existing_counts: dict[int, int] = {} for lbl in labels.tolist(): existing_counts[lbl] = existing_counts.get(lbl, 0) + 1 pasted_entries: list[ObjectEntry] = [] pasted_boxes: list[Tensor] = [] pasted_points: list[Tensor] = [] pasted_labels: list[int] = [] all_boxes = boxes device = boxes.device for label_id, target_count in self.target_counts.items(): n_paste = max(0, target_count - existing_counts.get(label_id, 0)) db = self._database.get(label_id) if not db or n_paste == 0: continue perm = torch.randperm(len(db)).tolist() candidates = [db[i] for i in perm[:n_paste]] cand_boxes = torch.stack([c.box for c in candidates]).to(device) # Candidates vs existing scene boxes. if all_boxes.shape[0] > 0: safe = ~box3d_overlap(cand_boxes, all_boxes, fmt).any(dim=1) else: safe = torch.ones(cand_boxes.shape[0], dtype=torch.bool, device=device) # Candidates vs each other. 
            cc = box3d_overlap(cand_boxes, cand_boxes, fmt)
            cc.fill_diagonal_(False)

            safe_cpu = safe.cpu()
            cc_cpu = cc.cpu()
            accepted_k: list[int] = []
            for k in range(len(candidates)):
                if not safe_cpu[k].item():
                    continue
                if cc_cpu[k, accepted_k].any().item():
                    continue
                accepted_k.append(k)
            if not accepted_k:
                continue

            for k in accepted_k:
                entry = candidates[k]
                pasted_entries.append(entry)
                pasted_boxes.append(cand_boxes[k])
                if entry.points is not None and points is not None:
                    pasted_points.append(entry.points.to(points.device))
                pasted_labels.append(entry.label)
            all_boxes = torch.cat([all_boxes, cand_boxes[accepted_k]])

        if not pasted_boxes:
            return inputs, targets

        pasted_boxes_tensor = torch.stack(pasted_boxes)
        new_inputs: dict[str, Any] = {**inputs}

        # Point cloud update: remove scene points in pasted regions, add pasted points
        if points is not None:
            remove_mask = points_in_boxes_3d(points, pasted_boxes_tensor, fmt).any(
                dim=1
            )
            kept_points = points[~remove_mask]
            if pasted_points:
                new_points = torch.cat([kept_points, torch.cat(pasted_points)])
            else:
                new_points = kept_points
            new_inputs["points"] = PointCloud3D(new_points)

        # Box and label update
        new_boxes = torch.cat([boxes, pasted_boxes_tensor])
        new_labels = torch.cat(
            [
                labels,
                torch.tensor(pasted_labels, dtype=labels.dtype, device=labels.device),
            ]
        )
        new_targets: dict[str, Any] = {
            **targets,
            "boxes": BoundingBoxes3D(new_boxes, format=fmt),
            "labels": new_labels,
        }

        # Camera image paste
        if self._has_camera_data(inputs):
            new_images = self._paste_camera_images(inputs, boxes, fmt, pasted_entries)
            if new_images is not None:
                new_inputs["images"] = new_images

        return new_inputs, new_targets

    def _paste_camera_images(
        self,
        inputs: dict[str, Any],
        existing_boxes: Tensor,
        fmt: BoundingBox3DFormat,
        pasted_entries: list[ObjectEntry],
    ) -> CameraImages | None:
        """Paste object image crops into camera views with depth-aware occlusion.

        Returns:
            Updated ``CameraImages`` or ``None`` if nothing to paste.
        """
        images = inputs["images"]  # [N, C, H, W]; cloned lazily on first write
        extrinsics = inputs["extrinsics"]  # [N, 4, 4]
        intrinsics = inputs["intrinsics"]  # [N, 3, 3]
        n_cams = images.shape[0]
        img_h, img_w = images.shape[2], images.shape[3]

        # Pre-compute centers (format-aware) once outside the camera loop.
        # Database entries live on CPU so we must bring them onto the working device.
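        # Homogeneous [x, y, z, 1] centers let one 4x4 matmul per camera give
        # every object's camera-frame depth (row 2 of ``ext @ center``).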
        pasted_box_stack = torch.stack([e.box for e in pasted_entries]).to(
            images.device
        )
        p_centers, _, _ = _extract_box_params(pasted_box_stack, fmt)
        p_ones = torch.ones(
            p_centers.shape[0], 1, dtype=p_centers.dtype, device=p_centers.device
        )
        pasted_centers_hom = torch.cat([p_centers, p_ones], dim=-1)  # [P, 4]

        has_existing = existing_boxes.shape[0] > 0
        e_centers_hom = torch.zeros(
            0, 4, dtype=existing_boxes.dtype, device=existing_boxes.device
        )
        if has_existing:
            e_centers, _, _ = _extract_box_params(existing_boxes, fmt)
            e_ones = torch.ones(
                e_centers.shape[0], 1, dtype=e_centers.dtype, device=e_centers.device
            )
            e_centers_hom = torch.cat([e_centers, e_ones], dim=-1)

        any_pasted = False
        cloned = False
        for cam_idx in range(n_cams):
            ext = extrinsics[cam_idx]  # [4, 4]
            K = intrinsics[cam_idx]  # [3, 3]

            # Depths of pasted objects in this camera; paste far-to-near so
            # nearer objects end up on top.
            paste_depths = (ext @ pasted_centers_hom.T).T[:, 2]  # [P]
            order = paste_depths.argsort(descending=True)

            # Batched hull masks for existing scene boxes (for occlusion)
            existing_masks: list[_HullMaskResult | None] = []
            existing_depths: list[float] = []
            if has_existing:
                existing_depths_t = (ext @ e_centers_hom.T).T[:, 2]
                existing_depths = existing_depths_t.tolist()
                existing_masks = _batch_hull_masks(
                    existing_boxes, fmt, ext, K, img_h, img_w
                )

            # Batched hull masks for pasted objects (target projection)
            pasted_masks = _batch_hull_masks(
                pasted_box_stack, fmt, ext, K, img_h, img_w
            )

            for idx in order:
                idx_int = int(idx.item())
                entry = pasted_entries[idx_int]
                if entry.camera_crops is None:
                    continue
                if cam_idx >= len(entry.camera_crops):
                    continue
                cam_crop = entry.camera_crops[cam_idx]
                if cam_crop is None:
                    continue
                result = pasted_masks[idx_int]
                if result is None:
                    continue

                target_mask, (tx_min, ty_min, tx_max, ty_max), _target_depth = result
                target_h = ty_max - ty_min + 1
                target_w = tx_max - tx_min + 1

                # Resize stored crop and mask to target size.
                # Stored on CPU so we must bring onto the working device before resize.
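                # Bilinear resampling for the image crop; nearest for the
                # boolean mask so resized values stay strictly 0/1.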
                src_crop = (
                    cam_crop.crop.unsqueeze(0).float().to(images.device)
                )  # [1, C, sh, sw]
                resized_crop = F.interpolate(
                    src_crop,
                    size=(target_h, target_w),
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(0)  # [C, th, tw]

                src_mask = (
                    cam_crop.mask.unsqueeze(0).unsqueeze(0).float().to(images.device)
                )
                resized_mask = (
                    F.interpolate(src_mask, size=(target_h, target_w), mode="nearest")
                    .squeeze(0)
                    .squeeze(0)
                    .bool()
                )

                # Intersect target hull with resized source mask
                paste_mask = target_mask & resized_mask

                # Depth-aware occlusion: subtract closer existing boxes
                paste_depth = float(paste_depths[idx_int].item())
                for e_idx in range(len(existing_masks)):
                    if (
                        existing_depths[e_idx] <= 0
                        or existing_depths[e_idx] >= paste_depth
                    ):
                        continue
                    e_result = existing_masks[e_idx]
                    if e_result is None:
                        continue
                    e_mask, (ex_min, ey_min, ex_max, ey_max), _e_depth = e_result

                    # Compute overlap region
                    ox_min = max(tx_min, ex_min)
                    oy_min = max(ty_min, ey_min)
                    ox_max = min(tx_max, ex_max)
                    oy_max = min(ty_max, ey_max)
                    if ox_min > ox_max or oy_min > oy_max:
                        continue

                    # Subtract overlapping part of existing mask
                    overlap_h = oy_max - oy_min + 1
                    overlap_w = ox_max - ox_min + 1
                    p_oy = oy_min - ty_min
                    p_ox = ox_min - tx_min
                    e_oy = oy_min - ey_min
                    e_ox = ox_min - ex_min
                    paste_mask[
                        p_oy : p_oy + overlap_h, p_ox : p_ox + overlap_w
                    ] &= ~e_mask[e_oy : e_oy + overlap_h, e_ox : e_ox + overlap_w]

                if not paste_mask.any():
                    continue

                # Clone on first write to avoid mutating the input
                if not cloned:
                    images = images.clone()
                    cloned = True

                # Paste into image
                region = images[cam_idx, :, ty_min : ty_max + 1, tx_min : tx_max + 1]
                region[:, paste_mask] = resized_crop[:, paste_mask].to(region.dtype)
                any_pasted = True

        if not any_pasted:
            return None
        return CameraImages(images)
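

# Illustrative usage sketch (not part of the module API; ``train_loader`` and
# the downstream spatial transforms are hypothetical stand-ins):
#
#     copy_paste = CopyPaste3D(target_counts={0: 15, 1: 10}, min_points=5)
#     for inputs, targets in train_loader:  # collated batch pytree
#         inputs, targets = copy_paste(inputs, targets)
#         # ...RandomFlip3D / RandomRotate3D etc. must run after this point.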