Source code for vision3d.tensors._wrap

from typing import Any

from torch import Tensor
from torchvision.tv_tensors import TVTensor

from ._bounding_boxes_3d import BoundingBoxes3D


def wrap(
    wrappee: Tensor,
    *,
    like: TVTensor,
    **kwargs: Any,
) -> TVTensor:
    """Re-wrap a plain :class:`~torch.Tensor` as the TVTensor subclass of ``like``.

    When ``like`` is a :class:`~vision3d.tensors.BoundingBoxes3D`, the result
    also receives a ``format``: the one given in ``kwargs`` if present,
    otherwise the one stored on ``like``.

    Args:
        wrappee: The tensor to convert.
        like: The reference object. The returned tensor has the same
            subclass as ``like``.
        kwargs: May contain ``"format"`` when ``like`` is a
            :class:`~vision3d.tensors.BoundingBoxes3D`. Not consulted for any
            other kind of ``like``.

    Returns:
        A TVTensor of the same subclass as ``like``.
    """
    target_cls = type(like)

    # Anything that is not a 3D box is converted with ``Tensor.as_subclass``;
    # no metadata is passed along and ``kwargs`` is not looked at.
    if not isinstance(like, BoundingBoxes3D):
        return wrappee.as_subclass(target_cls)

    # An explicit ``format`` keyword takes precedence over the format that
    # ``like`` carries.
    box_format = kwargs.get("format", like.format)
    return target_cls._wrap(wrappee, format=box_format)