Metadata-Version: 2.4
Name: uni-transform
Version: 0.2.1
Summary: Unified 3D transform utilities supporting both NumPy and PyTorch backends
Author: Junhao Tu
License-Expression: MIT
Project-URL: Homepage, https://github.com/junhaotu/uni-transform
Project-URL: Repository, https://github.com/junhaotu/uni-transform
Project-URL: Documentation, https://github.com/junhaotu/uni-transform#readme
Project-URL: Issues, https://github.com/junhaotu/uni-transform/issues
Keywords: robotics,transforms,rotation,quaternion,euler,pytorch,numpy,SE3,SO3,rigid-body
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Typing :: Typed
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.20.0
Requires-Dist: pytest>=8.3.5
Requires-Dist: rerun-sdk>=0.22.1
Requires-Dist: scipy>=1.7.0
Requires-Dist: torch>=1.9.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
Requires-Dist: black>=23.0.0; extra == "dev"
Requires-Dist: isort>=5.12.0; extra == "dev"
Requires-Dist: mypy>=1.0.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: pre-commit>=3.5.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx>=6.0.0; extra == "docs"
Requires-Dist: sphinx-rtd-theme>=1.3.0; extra == "docs"
Requires-Dist: myst-parser>=2.0.0; extra == "docs"
Dynamic: license-file

# uni-transform

A Python library for 3D rigid body transformations with **NumPy** and **PyTorch** backends.

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

**Key Features:**
- **Dual API** - `Transform` for poses (rotation + translation), `Rotation` for pure rotations
- **Batch Operations** - All functions accept arbitrary leading batch dimensions `(...)`
- **PyTorch Gradients** - Fully differentiable for deep learning
- **Dual Backend** - The same API works on NumPy arrays and PyTorch tensors
- **Extra State** - Store gripper width, joint states alongside poses
- **Unit Support** - Explicit translation units (meters/millimeters)
- **Visualization** - 3D visualization with [Rerun](https://rerun.io/) (optional)

## Installation

```bash
uv add uni-transform
```

From source:

```bash
git clone https://github.com/junhaotu/uni-transform.git
cd uni-transform
uv pip install -e .
```

## Quick Start

### Transform Class (Rotation + Translation)

```python
import numpy as np
from uni_transform import Transform

# Create from translation + Euler angles: [x, y, z, roll, pitch, yaw]
tf = Transform.from_rep(np.array([1.0, 2.0, 3.0, 0.1, 0.2, 0.3]), from_rep="euler")

# Convert representations
quat_rep = tf.to_rep("quat")      # [x, y, z, qx, qy, qz, qw]
matrix = tf.as_matrix()            # 4x4 homogeneous matrix

# Compose & transform
tf_composed = tf @ tf.inverse()
points = tf.transform_point(np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]))
```
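For reference, the 4x4 matrix from `as_matrix()` is conventionally laid out as `[[R, t], [0, 1]]`, and a rigid transform's inverse has the closed form `[[R^T, -R^T t], [0, 1]]`. The NumPy sketch below shows that math on its own; `make_homogeneous` and `invert_homogeneous` are hypothetical helpers, not part of `uni_transform`.

```python
import numpy as np


def make_homogeneous(rotation: np.ndarray, translation: np.ndarray) -> np.ndarray:
    """Pack a (..., 3, 3) rotation and (..., 3) translation into (..., 4, 4)."""
    batch = rotation.shape[:-2]
    T = np.zeros(batch + (4, 4), dtype=rotation.dtype)
    T[..., :3, :3] = rotation
    T[..., :3, 3] = translation
    T[..., 3, 3] = 1.0
    return T


def invert_homogeneous(T: np.ndarray) -> np.ndarray:
    """Closed-form rigid inverse: [R^T, -R^T t; 0, 1] (no general matrix inverse needed)."""
    R = T[..., :3, :3]
    t = T[..., :3, 3]
    R_inv = np.swapaxes(R, -1, -2)
    t_inv = -np.einsum("...ij,...j->...i", R_inv, t)
    return make_homogeneous(R_inv, t_inv)


# 90 degree rotation about z plus a translation
Rz = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
T = make_homogeneous(Rz, np.array([1.0, 2.0, 3.0]))
T_inv = invert_homogeneous(T)

# Composing with the inverse gives the identity
print(np.allclose(T @ T_inv, np.eye(4)))  # True

# Transforming a point: rotate first, then translate
p_world = Rz @ np.array([1.0, 0.0, 0.0]) + T[:3, 3]  # -> [1, 3, 3]
```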

### Rotation Class (Pure Rotation)

```python
import numpy as np
from uni_transform import Rotation

# Create from various representations
rot = Rotation.from_euler(np.array([0.1, 0.2, 0.3]), seq="ZYX")
rot = Rotation.from_quat(np.array([0, 0, 0.707, 0.707]))
rot = Rotation.from_rotvec(np.array([0, 0, np.pi/2]))
rot = Rotation.from_rotation_6d(np.array([1, 0, 0, 0, 1, 0]))

# Or use generic from_rep
rot = Rotation.from_rep(np.array([0.1, 0.2, 0.3]), from_rep="euler", seq="ZYX")

# Convert to any representation
quat = rot.to_rep("quat")
euler = rot.as_euler(seq="ZYX", degrees=True)
matrix = rot.as_matrix()

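# Compose rotations
rot1 = Rotation.from_rotvec(np.array([0.0, 0.0, 0.1]))
rot2 = Rotation.from_rotvec(np.array([0.0, 0.0, 0.4]))
combined = rot1 @ rot2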

# Apply to vectors
rotated = rot.apply(np.array([1.0, 0.0, 0.0]))

# Interpolation
rot_mid = rot1.slerp(rot2, t=0.5)

# Relative rotation
rot_rel = rot2.relative_to(rot1)  # rot1 @ rot_rel = rot2
```
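The `from_quat` and `from_rotvec` calls above describe the same rotation: a quarter turn about z, since `[0, 0, 0.707, 0.707]` is approximately `[0, 0, sin(π/4), cos(π/4)]` in xyzw order. A standalone NumPy sketch of the two textbook conversions (Rodrigues' formula and the unit-quaternion formula) confirms this; the helper names are hypothetical and not `uni_transform` functions.

```python
import numpy as np


def rotvec_to_matrix(rotvec: np.ndarray) -> np.ndarray:
    """Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2, with a = |rotvec|."""
    angle = np.linalg.norm(rotvec)
    if angle < 1e-12:
        return np.eye(3)
    x, y, z = rotvec / angle
    # Skew-symmetric cross-product matrix of the unit axis
    K = np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)


def quat_xyzw_to_matrix(q: np.ndarray) -> np.ndarray:
    """Quaternion in xyzw order to a 3x3 rotation matrix."""
    # Normalize first: 0.707 is only approximately 1/sqrt(2)
    x, y, z, w = q / np.linalg.norm(q)
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ])


R_from_rotvec = rotvec_to_matrix(np.array([0.0, 0.0, np.pi / 2]))
R_from_quat = quat_xyzw_to_matrix(np.array([0.0, 0.0, 0.707, 0.707]))
print(np.allclose(R_from_rotvec, R_from_quat))  # True
```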

### With Extra State (Gripper Width, etc.)

```python
import numpy as np
from uni_transform import Transform, interpolate_transform

# Robot pose with gripper: [x, y, z, qx, qy, qz, qw, gripper_width]
pose = np.array([0.5, 0.2, 0.1, 0, 0, 0, 1, 0.04])
tf = Transform.from_rep(pose, from_rep="quat", extra_dims=1)

tf.translation  # [0.5, 0.2, 0.1]
tf.extra        # [0.04] - gripper width preserved

# Extra is included in conversions
euler_pose = tf.to_rep("euler")  # [x, y, z, roll, pitch, yaw, gripper]

# Extra is interpolated too
tf_start = tf
tf_end = Transform.from_rep(np.array([0.5, 0.2, 0.3, 0, 0, 0, 1, 0.0]), from_rep="quat", extra_dims=1)
tf_mid = interpolate_transform(tf_start, tf_end, t=0.5)
tf_mid.extra  # Linearly interpolated gripper width: [0.02]
```
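The flat `quat` layout with `extra_dims=1` is `[translation (3), quaternion xyzw (4), extra (K)]` along the last axis. The NumPy sketch below makes that slicing and the linear blend of the extra channel concrete; `split_quat_pose` and `lerp_extra` are hypothetical helpers, not the library's internals.

```python
import numpy as np


def split_quat_pose(pose: np.ndarray, extra_dims: int = 0):
    """Split (..., 7 + K) into translation (..., 3), quaternion (..., 4), extra (..., K) or None."""
    if pose.shape[-1] != 7 + extra_dims:
        raise ValueError(f"expected last axis of size {7 + extra_dims}, got {pose.shape[-1]}")
    translation = pose[..., :3]
    quat_xyzw = pose[..., 3:7]
    extra = pose[..., 7:] if extra_dims > 0 else None
    return translation, quat_xyzw, extra


def lerp_extra(extra_start: np.ndarray, extra_end: np.ndarray, t: float) -> np.ndarray:
    """Linear interpolation of the extra state (e.g. gripper width)."""
    return (1.0 - t) * extra_start + t * extra_end


pose = np.array([0.5, 0.2, 0.1, 0, 0, 0, 1, 0.04])
translation, quat, extra = split_quat_pose(pose, extra_dims=1)
# translation -> [0.5, 0.2, 0.1], quat -> [0, 0, 0, 1], extra -> [0.04]

# Works the same with batch dimensions: (T, 8) -> (T, 3), (T, 4), (T, 1)
batch = np.tile(pose, (5, 1))
batch_translation, batch_quat, batch_extra = split_quat_pose(batch, extra_dims=1)

# Gripper closing from 4 cm to 0 cm: halfway is 2 cm
mid_width = lerp_extra(np.array([0.04]), np.array([0.0]), t=0.5)
```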

### Translation Units

```python
# Explicit unit tracking prevents accidental mixing
tf_m = Transform.from_pos_quat([1.0, 2.0, 3.0], [0, 0, 0, 1], translation_unit="m")
tf_mm = tf_m.to_unit("mm")  # Translation becomes [1000, 2000, 3000]

# Unit mismatch raises error (catches bugs early)
tf_mm @ tf_m  # Raises UnitMismatchError

# Stack also requires matching units
Transform.stack([tf_m, tf_m.clone()])  # OK
Transform.stack([tf_m, tf_mm])         # Raises UnitMismatchError
```
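Conceptually, unit tracking means carrying a unit tag with the translation and checking it before any operation that combines two poses. The toy class below illustrates that idea on translations only; `UnitPose`, its `add` method, and this local `UnitMismatchError` are hypothetical stand-ins for the sketch and do not mirror `uni_transform`'s implementation.

```python
from dataclasses import dataclass

import numpy as np

# Scale factor from each unit to meters
_SCALE_TO_METERS = {"m": 1.0, "mm": 1e-3}


class UnitMismatchError(ValueError):
    """Raised when two translations with different units are combined."""


@dataclass
class UnitPose:
    translation: np.ndarray
    unit: str = "m"

    def to_unit(self, unit: str) -> "UnitPose":
        # Only the translation is rescaled; rotations are unitless
        factor = _SCALE_TO_METERS[self.unit] / _SCALE_TO_METERS[unit]
        return UnitPose(self.translation * factor, unit)

    def add(self, other: "UnitPose") -> "UnitPose":
        # Refuse to mix units instead of silently producing wrong numbers
        if self.unit != other.unit:
            raise UnitMismatchError(f"{self.unit!r} vs {other.unit!r}")
        return UnitPose(self.translation + other.translation, self.unit)


p_m = UnitPose(np.array([1.0, 2.0, 3.0]), "m")
p_mm = p_m.to_unit("mm")  # translation -> [1000, 2000, 3000]

try:
    p_mm.add(p_m)
except UnitMismatchError as err:
    print("caught:", err)
```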

### Rotation Conversions (Functional API)

```python
import numpy as np
from uni_transform import Rotation, Transform, convert_rotation, matrix_to_euler, quaternion_to_matrix

# Direct conversions
quat = np.array([0, 0, 0.707, 0.707])  # xyzw format
matrix = quaternion_to_matrix(quat)
euler = matrix_to_euler(matrix, seq="ZYX")

# Generic conversion
euler = convert_rotation(quat, from_rep="quat", to_rep="euler", seq="ZYX")

# Static method on classes (no instance needed)
euler = Rotation.convert(quat, from_rep="quat", to_rep="euler", seq="ZYX")
pose_euler = np.array([1.0, 2.0, 3.0, 0.1, 0.2, 0.3])  # [x, y, z, roll, pitch, yaw]
pose_quat = Transform.convert(pose_euler, from_rep="euler", to_rep="quat")
```
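For the default `ZYX` sequence (yaw-pitch-roll), a common reading is the intrinsic z-y'-x'' convention, `R = Rz(yaw) @ Ry(pitch) @ Rx(roll)`, from which the angles can be read back with `atan2` and `arcsin`. The sketch below assumes that convention and returns `(roll, pitch, yaw)`; whether `matrix_to_euler` uses the same convention and output order is an assumption to check against the library. It also ignores the gimbal-lock case `|pitch| = π/2`.

```python
import numpy as np


def rot_x(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_y(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_z(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def zyx_euler_from_matrix(R: np.ndarray):
    """Recover (roll, pitch, yaw) from R = Rz(yaw) @ Ry(pitch) @ Rx(roll); assumes |pitch| < pi/2."""
    yaw = np.arctan2(R[1, 0], R[0, 0])  # sin(yaw)cos(pitch), cos(yaw)cos(pitch)
    pitch = -np.arcsin(np.clip(R[2, 0], -1.0, 1.0))  # R[2, 0] = -sin(pitch)
    roll = np.arctan2(R[2, 1], R[2, 2])  # cos(pitch)sin(roll), cos(pitch)cos(roll)
    return roll, pitch, yaw


roll, pitch, yaw = 0.1, 0.2, 0.3
R = rot_z(yaw) @ rot_y(pitch) @ rot_x(roll)
recovered = zyx_euler_from_matrix(R)
print(np.allclose(recovered, (roll, pitch, yaw)))  # True
```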

### PyTorch Gradients

```python
import torch
from uni_transform import Transform, Rotation, geodesic_distance

# Transform with gradients: [x, y, z] + 6D rotation = 9 values per pose
pred_params = torch.randn(100, 9, requires_grad=True)
pred = Transform.from_rep(pred_params, from_rep="rotation_6d")
target = Transform.from_rep(torch.randn(100, 9), from_rep="rotation_6d")

loss = geodesic_distance(pred.rotation, target.rotation).mean()
loss.backward()  # Gradients flow back to pred_params

# Rotation with gradients
rot = Rotation.from_euler(torch.tensor([0.1, 0.2, 0.3], requires_grad=True))
rotated = rot.apply(torch.tensor([1.0, 0.0, 0.0]))
rotated.sum().backward()
```
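A geodesic distance between rotations is usually defined as the angle of the relative rotation, `θ = arccos((trace(R1ᵀ R2) - 1) / 2)`. The NumPy sketch below assumes that definition; `uni_transform`'s `geodesic_distance` may use a different but equivalent formula (for example, quaternion-based). When the same expression is written with `torch` ops for a loss, clamping the cosine slightly inside `[-1, 1]` keeps the `arccos` gradient finite when the error is near zero.

```python
import numpy as np


def geodesic_angle(R1: np.ndarray, R2: np.ndarray) -> np.ndarray:
    """Angle in radians of the relative rotation R1^T R2, batched over leading dims."""
    R_rel = np.swapaxes(R1, -1, -2) @ R2
    trace = np.trace(R_rel, axis1=-2, axis2=-1)
    # Round-off can push the cosine just outside [-1, 1]; clip to stay in arccos's domain
    cos_angle = np.clip((trace - 1.0) / 2.0, -1.0, 1.0)
    return np.arccos(cos_angle)


def rot_z(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


# Two rotations about z differ by 0.3 rad
angle = geodesic_angle(rot_z(0.2), rot_z(0.5))

# Batched: (2, 3, 3) vs (2, 3, 3) -> (2,)
batch_angles = geodesic_angle(np.stack([rot_z(0.0), rot_z(1.0)]), np.stack([rot_z(0.5), rot_z(1.0)]))
```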

### Batch Dimensions

```python
# All operations support arbitrary batch dimensions
batch_tf = Transform.from_rep(np.random.randn(10, 50, 7), from_rep="quat")  # (10, 50) batch
batch_tf.rotation.shape   # (10, 50, 3, 3)
batch_tf.translation.shape  # (10, 50, 3)

# Same for Rotation
batch_rot = Rotation.from_rep(np.random.randn(10, 50, 4), from_rep="quat")
batch_rot.batch_shape  # (10, 50)

# Compose batched transforms
result = batch_tf @ batch_tf.inverse()  # Element-wise over the (10, 50) batch; broadcastable shapes also work
```
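Batch handling here follows NumPy-style broadcasting over the leading `(...)` dimensions, with the trailing axes treated as matrices. The plain NumPy sketch below shows what that means for composing homogeneous transforms with `np.matmul`; it does not use `uni_transform`, and it assumes the library's `@` broadcasts the same way. The helper `rz_transforms` is hypothetical.

```python
import numpy as np


def rz_transforms(angles: np.ndarray) -> np.ndarray:
    """Rotations about z as homogeneous matrices, shape angles.shape + (4, 4)."""
    c, s = np.cos(angles), np.sin(angles)
    T = np.zeros(angles.shape + (4, 4))
    T[..., 0, 0], T[..., 0, 1] = c, -s
    T[..., 1, 0], T[..., 1, 1] = s, c
    T[..., 2, 2] = 1.0
    T[..., 3, 3] = 1.0
    return T


# Same-shape batches compose element-wise: (10, 50, 4, 4) @ (10, 50, 4, 4)
a = rz_transforms(np.linspace(0.0, 1.0, 500).reshape(10, 50))
same_shape = a @ a  # shape (10, 50, 4, 4)

# Broadcasting: (10, 1, 4, 4) @ (1, 50, 4, 4) -> (10, 50, 4, 4)
left_angles = np.linspace(0.0, 1.0, 10)[:, None]
right_angles = np.linspace(0.0, 1.0, 50)[None, :]
combined = rz_transforms(left_angles) @ rz_transforms(right_angles)

# Rotations about the same axis add their angles, which gives an easy check
expected = rz_transforms(left_angles + right_angles)
print(np.allclose(combined, expected))  # True
```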

### Interpolation

```python
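import numpy as np
from uni_transform import (
    Rotation, Transform,
    interpolate, interpolate_sequence,
    interpolate_transform, interpolate_transform_sequence,
    compute_spline, quaternion_slerp
)

# Example inputs
start, end = np.zeros(3), np.array([1.0, 2.0, 3.0])
waypoints = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
times = np.array([0.0, 1.0, 2.0, 3.0])
query_times = np.array([0.5, 1.5, 2.5])
q0 = np.array([0.0, 0.0, 0.0, 1.0])              # identity (xyzw)
q1 = np.array([0.0, 0.0, 0.7071068, 0.7071068])  # 90° about z (xyzw)
rot_start, rot_end = Rotation.from_quat(q0), Rotation.from_quat(q1)
tf_start = Transform.from_pos_quat(start, q0)
tf_end = Transform.from_pos_quat(end, q1)

# Vector interpolation
pos = interpolate(start, end, t=0.5)  # Linear
pos = interpolate(start, end, t=0.5, method="minimum_jerk", duration=2.0)  # Smooth

# Multi-point vector interpolation
positions = interpolate_sequence(waypoints, times, query_times, method="cubic_spline")

# Transform interpolation (rotation + translation)
tf_mid = interpolate_transform(tf_start, tf_end, t=0.5)
tf_mid = interpolate_transform(tf_start, tf_end, t=0.5,
    rotation_method="nlerp",           # "slerp" or "nlerp"
    translation_method="minimum_jerk", # "linear" or "minimum_jerk"
    duration=2.0
)

# Multi-point transform interpolation
keyframes = Transform.stack([Transform.from_pos_quat(w, q0) for w in waypoints])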
result = interpolate_transform_sequence(keyframes, times, query_times,
    rotation_method="squad",           # "slerp", "nlerp", or "squad" (smooth)
    translation_method="cubic_spline"  # "linear", "minimum_jerk", or "cubic_spline"
)

# Rotation interpolation (via class methods)
rot_mid = rot_start.slerp(rot_end, t=0.5)  # Spherical interpolation
rot_mid = rot_start.nlerp(rot_end, t=0.5)  # Faster, approximate

# Reusable spline (compute once, evaluate many times)
spline = compute_spline(waypoints, times)
positions = spline.evaluate(query_times)
velocities = spline.derivative(query_times, order=1)

# Low-level quaternion SLERP
q_mid = quaternion_slerp(q0, q1, t=0.5)
```
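Two of the building blocks above have well-known closed forms. Quaternion SLERP is `q(t) = (sin((1-t)Ω) q0 + sin(tΩ) q1) / sin Ω` with `cos Ω = q0·q1`, and a minimum-jerk profile scales the displacement by `s(τ) = 10τ³ - 15τ⁴ + 6τ⁵`. The NumPy sketch below implements those textbook formulas (xyzw quaternions, shortest-path sign flip, NLERP fallback for nearly parallel inputs). It is not the library's code, and treating `t` as absolute time with `τ = t / duration` is an assumption about how `duration` is used.

```python
import numpy as np


def slerp_xyzw(q0: np.ndarray, q1: np.ndarray, t: float) -> np.ndarray:
    """Spherical linear interpolation between quaternions in xyzw order."""
    q0 = q0 / np.linalg.norm(q0)
    q1 = q1 / np.linalg.norm(q1)
    dot = np.dot(q0, q1)
    if dot < 0.0:  # q and -q are the same rotation; take the shorter arc
        q1, dot = -q1, -dot
    if dot > 0.9995:  # nearly parallel: fall back to normalized lerp (NLERP)
        q = (1.0 - t) * q0 + t * q1
        return q / np.linalg.norm(q)
    omega = np.arccos(dot)
    return (np.sin((1.0 - t) * omega) * q0 + np.sin(t * omega) * q1) / np.sin(omega)


def minimum_jerk(start: np.ndarray, end: np.ndarray, t: float, duration: float) -> np.ndarray:
    """Minimum-jerk blend: zero velocity and acceleration at both endpoints."""
    tau = np.clip(t / duration, 0.0, 1.0)
    s = 10 * tau**3 - 15 * tau**4 + 6 * tau**5
    return start + s * (end - start)


# Halfway between identity and a 90 degree turn about z is a 45 degree turn
q_identity = np.array([0.0, 0.0, 0.0, 1.0])
q_z90 = np.array([0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)])
q_mid = slerp_xyzw(q_identity, q_z90, t=0.5)  # ~[0, 0, sin(pi/8), cos(pi/8)]

# The profile is symmetric, so the temporal midpoint is the spatial midpoint
pos = minimum_jerk(np.zeros(3), np.array([1.0, 2.0, 3.0]), t=1.0, duration=2.0)  # [0.5, 1.0, 1.5]
```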

### Relative Transforms & Deltas

```python
# Express tf in the frame of reference_tf
tf_in_ref = tf.relative_to(reference_tf)  # = reference_tf.inverse() @ tf

# Apply incremental delta
tf_new = tf.apply_delta(delta_tf, in_world_frame=True)   # = delta_tf @ tf
tf_new = tf.apply_delta(delta_tf, in_world_frame=False)  # = tf @ delta_tf

# Same for Rotation
rot_rel = rot.relative_to(ref_rot)
rot_new = rot.apply_delta(delta_rot, in_body_frame=False)  # World frame: = delta_rot @ rot
rot_new = rot.apply_delta(delta_rot, in_body_frame=True)   # Body frame: = rot @ delta_rot
```
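The comments above pin down the algebra: `relative_to(ref)` is `ref⁻¹ @ tf`, a world-frame delta pre-multiplies, and a body-frame delta post-multiplies. The NumPy check below verifies those identities on plain 4x4 matrices rather than `Transform` objects; the helper `pose` is hypothetical.

```python
import numpy as np


def pose(yaw: float, xyz) -> np.ndarray:
    """Homogeneous transform: rotation `yaw` about z, then translation `xyz`."""
    c, s = np.cos(yaw), np.sin(yaw)
    T = np.eye(4)
    T[:3, :3] = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    T[:3, 3] = xyz
    return T


tf = pose(np.pi / 2, [1.0, 0.0, 0.0])        # body x axis points along world +y
reference_tf = pose(0.0, [0.0, 0.0, 1.0])
delta_tf = pose(0.0, [0.1, 0.0, 0.0])        # 10 cm step along x

# relative_to: tf expressed in the reference frame; composing back recovers tf
tf_in_ref = np.linalg.inv(reference_tf) @ tf
recovered = reference_tf @ tf_in_ref

# World-frame delta moves along the world x axis
tf_world = delta_tf @ tf  # translation -> [1.1, 0.0, 0.0]

# Body-frame delta moves along tf's own x axis, i.e. world +y
tf_body = tf @ delta_tf   # translation -> approximately [1.0, 0.1, 0.0]
```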

### Visualization (Rerun)

Visualize transforms, trajectories, and point clouds with [Rerun](https://rerun.io/). Requires `pip install rerun-sdk`.

```python
from uni_transform import Transform
from uni_transform.visualization import RerunVisualizer, log_transform, log_trajectory

# High-level API
viz = RerunVisualizer("my_app", spawn=True)  # spawn=True launches the Rerun viewer
viz.show_transform("robot/base", base_tf)
viz.show_trajectory("robot/path", trajectory)
viz.show_points("scene/cloud", points)
```

## Rotation Representations

| Name | Shape | Description |
|------|-------|-------------|
| `matrix` | `(..., 3, 3)` | SO(3) rotation matrix |
| `quat` | `(..., 4)` | Quaternion (**xyzw** format) |
| `euler` | `(..., 3)` | Euler angles (default: ZYX) |
| `rotation_6d` | `(..., 6)` | 6D continuous rotation |
| `rot_vec` | `(..., 3)` | Rotation vector (axis × angle) |

> `...` = arbitrary batch dimensions, e.g. `(B, T, 3, 3)` for batched trajectories
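The `rotation_6d` representation stores two 3-vectors and maps them back to SO(3) with Gram-Schmidt, which makes it continuous and well suited to regression. The sketch below assumes the common convention that the six numbers are the first two *columns* of the rotation matrix; `uni_transform` may use rows instead, so check `matrix_to_rotation_6d` before relying on the layout. Either way, the `[1, 0, 0, 0, 1, 0]` example used earlier maps to the identity. Helper names are hypothetical.

```python
import numpy as np


def rotation_6d_to_matrix(d6: np.ndarray) -> np.ndarray:
    """(..., 6) -> (..., 3, 3) via Gram-Schmidt; assumes d6 = [col0, col1]."""
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    # Remove the b1 component from a2, then normalize
    a2_perp = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = a2_perp / np.linalg.norm(a2_perp, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)  # right-handed third axis, so det = +1
    return np.stack([b1, b2, b3], axis=-1)  # b1, b2, b3 become the columns


def matrix_to_rotation_6d(R: np.ndarray) -> np.ndarray:
    """(..., 3, 3) -> (..., 6): first two columns, concatenated."""
    return np.concatenate([R[..., :, 0], R[..., :, 1]], axis=-1)


# Any two non-parallel vectors map to a valid rotation matrix
raw = np.array([[2.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                [0.3, -1.2, 0.5, 0.7, 0.1, 2.0]])  # batch of 2
R = rotation_6d_to_matrix(raw)  # shape (2, 3, 3)

# Round trip through the 6D form is exact for valid rotations
round_trip = rotation_6d_to_matrix(matrix_to_rotation_6d(R))
```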

## Transform Class API

### Factory Methods

```python
Transform.identity(backend="numpy", extra_dims=0, translation_unit="m")
Transform.from_matrix(matrix_4x4, translation_unit="m")
Transform.from_rep(data, from_rep="quat", extra_dims=0, translation_unit="m")
Transform.from_pos_quat(position, quaternion, extra=None, translation_unit="m")
Transform.stack(transforms, axis=0)  # Requires matching units
Transform.convert(data, from_rep, to_rep, ...)  # Static conversion
```

### Properties

```python
tf.rotation          # (..., 3, 3) rotation matrix
tf.translation       # (..., 3) translation vector
tf.extra             # (..., K) extra state or None
tf.translation_unit  # TranslationUnit.METER or TranslationUnit.MILLIMETER
tf.backend           # "numpy" or "torch"
tf.batch_shape       # tuple of batch dimensions
tf.extra_dims        # number of extra dimensions (0 if None)
```

### Methods

```python
tf.as_matrix()           # 4x4 homogeneous matrix
tf.to_rep("quat")        # Convert to [translation, rotation, extra]
tf.to_unit("mm")         # Convert translation unit
tf.inverse()             # Inverse transform
tf.transform_point(p)    # Apply to point(s)
tf.transform_vector(v)   # Apply rotation only
tf.apply_delta(delta)    # Compose with delta
tf.relative_to(ref)      # Express in reference frame
tf.clone()               # Deep copy
tf.to(device, dtype)     # Move to device (PyTorch)
tf.detach()              # Detach from graph (PyTorch)
```

## Rotation Class API

### Factory Methods

```python
Rotation.identity(backend="numpy")
Rotation.from_matrix(matrix_3x3)
Rotation.from_rep(data, from_rep="quat", seq="ZYX", degrees=False)
Rotation.from_quat(quaternion)           # xyzw format
Rotation.from_euler(euler, seq="ZYX")
Rotation.from_rotvec(rotation_vector)
Rotation.from_rotation_6d(rot_6d)
Rotation.stack(rotations, axis=0)
Rotation.convert(data, from_rep, to_rep, ...)  # Static conversion
```

### Properties

```python
rot.matrix       # (..., 3, 3) rotation matrix
rot.backend      # "numpy" or "torch"
rot.batch_shape  # tuple of batch dimensions
rot.is_batched   # bool
```

### Methods

```python
rot.to_rep("euler", seq="ZYX")   # Convert to representation
rot.as_matrix()                   # (..., 3, 3)
rot.as_quat()                     # (..., 4) xyzw
rot.as_euler(seq="ZYX")           # (..., 3)
rot.as_rotvec()                   # (..., 3)
rot.as_rotation_6d()              # (..., 6)
rot.inverse()                     # Inverse rotation
rot.apply(vectors)                # Rotate vectors
rot.apply_delta(delta)            # Compose with delta
rot.relative_to(ref)              # Relative rotation
rot.slerp(other, t)               # Spherical interpolation
rot.nlerp(other, t)               # Normalized linear interpolation
rot.clone()                       # Deep copy
rot.to(device, dtype)             # Move to device (PyTorch)
rot.detach()                      # Detach from graph (PyTorch)
```

## Key Functions

```python
# Conversions
quaternion_to_matrix, matrix_to_quaternion
euler_to_matrix, matrix_to_euler, matrix_to_euler_differentiable
rotvec_to_matrix, matrix_to_rotvec
quaternion_to_rotvec, rotvec_to_quaternion
rotation_6d_to_matrix, matrix_to_rotation_6d
convert_rotation, rotation_to_matrix, matrix_to_rotation

# Quaternion ops
quaternion_multiply, quaternion_apply, quaternion_inverse, quaternion_conjugate

# Distance (for loss functions)
geodesic_distance, translation_distance, transform_distance

# SE(3) Lie group
se3_log, se3_exp

# Interpolation (unified API)
interpolate, interpolate_sequence                    # Vector/scalar
interpolate_rotation, interpolate_rotation_sequence # Rotation
interpolate_transform, interpolate_transform_sequence # Transform
compute_spline, SplineCoefficients                   # Reusable spline

# Interpolation (low-level)
quaternion_slerp, quaternion_nlerp, quaternion_squad
minimum_jerk_interpolate, minimum_jerk_velocity, minimum_jerk_acceleration
cubic_spline_coefficients, cubic_spline_interpolate, cubic_spline_derivative

# Utilities
orthogonalize_rotation, xyz_rotation_6d_to_matrix

# Visualization (requires rerun-sdk)
from uni_transform.visualization import (
    init, connect, save,                    # Setup
    log_transform, log_trajectory,          # Core logging
    log_trajectory_animated,                # Animation
    log_transform_manager,                  # TransformManager graph
    log_points, log_text, log_scalar,       # Primitives
    set_time,                               # Timeline control
    RerunVisualizer,                        # High-level API
)
```

## Conventions

- **Quaternion**: `xyzw` format (matches SciPy/ROS)
- **Euler default**: `ZYX` sequence (yaw-pitch-roll)
- **Composition**: `tf1 @ tf2` applies `tf2` first, then `tf1`
- **Translation unit**: Default is meters (`"m"`), also supports millimeters (`"mm"`)
- **Extra**: Preserved through `inverse()`, `clone()`, `to_unit()`; interpolated in `interpolate_transform()`
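The composition rule `tf1 @ tf2` (apply `tf2` first) is the usual matrix convention for column vectors: `(T1 @ T2) @ p == T1 @ (T2 @ p)`. The NumPy check below uses plain homogeneous matrices, where the order visibly matters because a translation and a rotation do not commute.

```python
import numpy as np

# T1: translate by +1 along x
T1 = np.eye(4)
T1[0, 3] = 1.0

# T2: rotate 90 degrees about z
T2 = np.eye(4)
T2[:2, :2] = [[0.0, -1.0], [1.0, 0.0]]

p = np.array([1.0, 0.0, 0.0, 1.0])  # homogeneous point (1, 0, 0)

# T1 @ T2: rotate first (-> (0, 1, 0)), then translate (-> (1, 1, 0))
rotate_then_translate = (T1 @ T2 @ p)[:3]

# T2 @ T1: translate first (-> (2, 0, 0)), then rotate (-> (0, 2, 0))
translate_then_rotate = (T2 @ T1 @ p)[:3]
```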

## Module Structure

```
uni_transform/
├── __init__.py              # Public API exports
├── _core.py                 # Types, constants, backend utilities
├── rotation_conversions.py  # All rotation conversion functions
├── rotation.py              # Rotation class
├── transform.py             # Transform class
├── interpolation.py         # Interpolation functions
├── se3.py                   # SE(3) Lie group operations
├── metrics.py               # Distance functions
└── visualization.py         # Rerun visualization (optional)
```

## License

MIT License
