# Source code for vision3d.transforms.functional._registry

"""Kernel dispatch for vision3d transforms.

Minimal reimplementation of torchvision's kernel registry, since the public
``register_kernel`` only allows registering kernels for torchvision's own
functionals.
"""

import functools
from collections.abc import Callable
from typing import Any

from torch import Tensor
from torchvision.tv_tensors import TVTensor

# Module-level dispatch table, populated by ``register_kernel`` and read by
# ``_get_kernel``. Layout: {functional: {input_type: kernel}}
KERNEL_REGISTRY: dict[Callable[..., Any], dict[type, Callable[..., Any]]] = {}


def register_kernel(
    functional: Callable[..., Any],
    tv_tensor_cls: type[TVTensor],
    *,
    tv_tensor_wrapper: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Register a kernel for a functional and TVTensor type.

    Args:
        functional: The functional to register a kernel for.
        tv_tensor_cls: The TVTensor subclass this kernel handles.
        tv_tensor_wrapper: If True (default), the kernel receives an unwrapped
            pure tensor and the output is automatically re-wrapped. If False,
            the kernel receives the full TVTensor and must handle wrapping
            itself.

    Returns:
        Decorator that registers the kernel. The decorator returns the
        original ``kernel`` unchanged (not the re-wrapping closure), so the
        decorated name keeps referring to the raw kernel.

    Raises:
        ValueError: Raised when the decorator is applied and ``functional``
            already has a kernel registered for ``tv_tensor_cls``.
    """
    # The per-functional table is created as soon as ``register_kernel`` is
    # called, i.e. before the returned decorator is applied.
    registry = KERNEL_REGISTRY.setdefault(functional, {})

    def decorator(kernel: Callable[..., Any]) -> Callable[..., Any]:
        if tv_tensor_cls in registry:
            # Not every callable carries ``__name__`` (functools.partial
            # objects and callable instances do not), so fall back to repr()
            # rather than masking the ValueError with an AttributeError.
            functional_name = getattr(functional, "__name__", repr(functional))
            msg = (
                f"{functional_name} already has a kernel "
                f"registered for {tv_tensor_cls.__name__}."
            )
            raise ValueError(msg)

        if tv_tensor_wrapper:

            @functools.wraps(kernel)
            def wrapper(inpt: TVTensor, *args: Any, **kwargs: Any) -> TVTensor:
                # Imported at call time rather than module import time
                # (presumably to avoid a circular import -- TODO confirm).
                from vision3d.tensors import wrap

                # Hand the kernel a plain Tensor view, then re-wrap its
                # output using ``inpt`` as the template.
                output = kernel(inpt.as_subclass(Tensor), *args, **kwargs)
                return wrap(output, like=inpt)

            registry[tv_tensor_cls] = wrapper
        else:
            registry[tv_tensor_cls] = kernel

        return kernel

    return decorator


def _get_kernel(functional: Callable[..., Any], input_type: type) -> Callable[..., Any]:
    """Resolve the kernel that handles ``input_type`` for ``functional``.

    The method resolution order of ``input_type`` is walked from most to
    least specific, and the first class with a registered kernel wins. When
    nothing matches (labels, plain tensors, etc.) a passthrough is returned
    instead of raising.

    Args:
        functional: The functional to look up.
        input_type: The type of the input.

    Returns:
        The registered kernel, or a passthrough callable that returns its
        first argument unchanged and ignores any further arguments.
    """
    kernels_by_type = KERNEL_REGISTRY.get(functional, {})

    for ancestor in input_type.__mro__:
        if ancestor in kernels_by_type:
            return kernels_by_type[ancestor]
        if ancestor is TVTensor:
            # Stop once TVTensor itself has been checked: bases that follow
            # it in the MRO are never consulted. Types whose MRO does not
            # contain TVTensor have their whole MRO searched.
            break

    # Passthrough for plain tensors, labels, etc.
    return lambda inpt, *args, **kwargs: inpt