Source code for vision3d.transforms._point_cloud
"""Point cloud transform classes."""
from typing import Any, override
import torch
from torch import Tensor
from ._transform import RandomTransform, Transform
from .functional._point_cloud import (
jitter_points,
sample_points,
shuffle_points,
)
[docs]
class PointShuffle(RandomTransform):
"""Randomly permute point order with probability ``p``.
Args:
p: Probability of applying. Default: ``0.5``.
"""
[docs]
@override
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
"""Sample a random permutation.
Returns:
Dict with ``"perm"`` key.
"""
n = max(inpt.shape[0] for inpt in flat_inputs if isinstance(inpt, Tensor))
return {"perm": torch.randperm(n)}
[docs]
class PointSample(Transform):
"""Subsample (or oversample with replacement) to exactly ``n`` points.
If the point cloud has more than ``n`` points, a random subset is
selected. If fewer, points are sampled with replacement to reach
``n``.
Args:
n: Target number of points.
"""
def __init__(self, n: int) -> None:
super().__init__()
self.n = n
[docs]
@override
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
"""Sample indices to reach exactly ``n`` points.
Returns:
Dict with ``"indices"`` key.
"""
num_points = max(
inpt.shape[0] for inpt in flat_inputs if isinstance(inpt, Tensor)
)
if num_points >= self.n:
indices = torch.randperm(num_points)[: self.n]
else:
indices = torch.randint(0, num_points, (self.n,))
return {"indices": indices}
[docs]
class PointJitter(RandomTransform):
"""Add Gaussian noise to point xyz coordinates with probability ``p``.
Args:
sigma: Standard deviation of the Gaussian noise. Default: ``0.01``.
p: Probability of applying. Default: ``0.5``.
"""
def __init__(self, sigma: float = 0.01, p: float = 0.5) -> None:
super().__init__(p=p)
self.sigma = sigma
[docs]
@override
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
"""Sample Gaussian noise.
Returns:
Dict with ``"noise"`` key containing ``[N, 3]`` tensor.
"""
n = max(inpt.shape[0] for inpt in flat_inputs if isinstance(inpt, Tensor))
return {"noise": torch.randn(n, 3) * self.sigma}