Source code for vision3d.transforms._transform
"""Base class for vision3d transforms."""
import enum
from collections.abc import Callable
from typing import Any, override
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.tv_tensors import TVTensor
from .functional._registry import _get_kernel
def _needs_transform(inpt: Any) -> bool:
"""Only TVTensor subclasses are transformed. Plain tensors pass through.
Returns:
True if ``inpt`` is a TVTensor subclass.
"""
return isinstance(inpt, TVTensor)
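
# Illustration (assumed sample layout, not part of the module): in a sample like
#   {"points": PointCloud3D(...), "labels": torch.tensor([2, 7])}
# only the ``PointCloud3D`` entry is transformable; the plain label tensor is
# returned untouched by ``Transform.forward``.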
class Transform(nn.Module):
    """Base class for vision3d transforms.

    Only :class:`~torchvision.tv_tensors.TVTensor` subclasses (e.g.
    :class:`~vision3d.tensors.BoundingBoxes3D`,
    :class:`~vision3d.tensors.PointCloud3D`) are transformed. Plain tensors
    (labels, scores, etc.) pass through unchanged.

    Subclasses should override :meth:`transform` and use ``_call_kernel``
    to dispatch to the correct kernel for each input type.
    """

    def __init__(self) -> None:
        super().__init__()
    def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
        """Sample random parameters. Override for randomised transforms.

        Returns:
            Parameter dict passed to :meth:`transform`.
        """
        return {}
    def _call_kernel(
        self, functional: Callable[..., Any], inpt: Any, *args: Any, **kwargs: Any
    ) -> Any:
        """Dispatch ``functional`` to the kernel registered for ``type(inpt)``."""
        kernel = _get_kernel(functional, type(inpt))
        return kernel(inpt, *args, **kwargs)
    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        """Apply the transform to a single input. Must be overridden."""
        raise NotImplementedError
    @override
    def forward(self, *inputs: Any) -> Any:
        """Apply the transform to one or more inputs (dicts, tuples, etc.).

        Returns:
            Transformed inputs in the same structure as the input.
        """
        # Flatten arbitrarily nested inputs (dicts, tuples, lists) into a list.
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs = [_needs_transform(inpt) for inpt in flat_inputs]
        # Sample parameters once so every transformable input sees the same ones.
        params = self.make_params([inpt for inpt, nt in zip(flat_inputs, needs) if nt])
        flat_outputs = [
            self.transform(inpt, params) if nt else inpt
            for inpt, nt in zip(flat_inputs, needs)
        ]
        return tree_unflatten(flat_outputs, spec)
    @override
    def extra_repr(self) -> str:
        """Auto-generate repr from public attributes.

        Returns:
            Comma-separated ``key=value`` string.
        """
        extra = []
        for name, value in self.__dict__.items():
            # Skip private attributes and nn.Module bookkeeping.
            if name.startswith("_") or name == "training":
                continue
            # Only include simple value types in the repr.
            if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
                continue
            extra.append(f"{name}={value}")
        return ", ".join(extra)
class RandomTransform(Transform):
    """Base class for transforms applied with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        if not (0.0 <= p <= 1.0):
            msg = "`p` should be a float in [0.0, 1.0]."
            raise ValueError(msg)
        super().__init__()
        self.p = p
    @override
    def forward(self, *inputs: Any) -> Any:
        """Apply the transform with probability ``p``.

        Returns:
            Transformed inputs, or the original inputs if skipped.
        """
        inputs = inputs if len(inputs) > 1 else inputs[0]
        # Skip the transform entirely with probability 1 - p.
        if torch.rand(1) >= self.p:
            return inputs
        flat_inputs, spec = tree_flatten(inputs)
        needs = [_needs_transform(inpt) for inpt in flat_inputs]
        params = self.make_params([inpt for inpt, nt in zip(flat_inputs, needs) if nt])
        flat_outputs = [
            self.transform(inpt, params) if nt else inpt
            for inpt, nt in zip(flat_inputs, needs)
        ]
        return tree_unflatten(flat_outputs, spec)
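
# --- Call-site sketch (illustration only) ----------------------------------
# How a sample flows through ``forward``: the container structure is preserved
# and only TVTensor entries are transformed. ``RandomTranslate3D`` refers to
# the hypothetical subclass sketched above; the ``PointCloud3D`` constructor
# call is an assumption.
#
#     sample = {
#         "points": PointCloud3D(torch.randn(1024, 3)),
#         "labels": torch.tensor([1, 3, 3, 7]),   # passes through unchanged
#     }
#     out = RandomTranslate3D(max_shift=0.5)(sample)
#     assert out.keys() == sample.keys()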