Metadata-Version: 2.4
Name: strands-vjepa
Version: 0.1.1
Summary: V-JEPA 2 world model & action-conditioned planner for Strands Agents
Author-email: Cagatay Cali <cagataycali@icloud.com>
License: Apache-2.0
Project-URL: Homepage, https://github.com/cagataycali/strands-vjepa
Project-URL: Repository, https://github.com/cagataycali/strands-vjepa
Project-URL: Issues, https://github.com/cagataycali/strands-vjepa/issues
Keywords: strands,agents,vjepa,vjepa2,world-model,physical-ai,robotics,planning,cem
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: strands-agents
Requires-Dist: cc-vjepa2>=0.0.3
Requires-Dist: torch>=2.0.0
Requires-Dist: torchvision
Requires-Dist: timm>=0.9.0
Requires-Dist: einops>=0.7.0
Requires-Dist: numpy
Requires-Dist: pillow>=10.0.0
Requires-Dist: pyyaml>=6.0
Provides-Extra: video
Requires-Dist: decord>=0.6.0; platform_machine == "x86_64" and extra == "video"
Requires-Dist: decord2>=3.3.0; platform_machine == "aarch64" and extra == "video"
Requires-Dist: decord2>=3.3.0; (platform_system == "Darwin" and platform_machine == "arm64") and extra == "video"
Requires-Dist: av>=11.0.0; extra == "video"
Provides-Extra: hf
Requires-Dist: transformers>=4.40.0; extra == "hf"
Requires-Dist: huggingface_hub>=0.20.0; extra == "hf"
Provides-Extra: robots
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pytest-asyncio; extra == "dev"
Requires-Dist: ruff; extra == "dev"
Provides-Extra: all
Requires-Dist: strands-agents-tools; extra == "all"
Requires-Dist: transformers>=4.40.0; extra == "all"
Requires-Dist: huggingface_hub>=0.20.0; extra == "all"
Requires-Dist: av>=11.0.0; extra == "all"
Requires-Dist: decord>=0.6.0; platform_machine == "x86_64" and extra == "all"
Requires-Dist: decord2>=3.3.0; platform_machine == "aarch64" and extra == "all"
Requires-Dist: decord2>=3.3.0; (platform_system == "Darwin" and platform_machine == "arm64") and extra == "all"
Requires-Dist: pytest>=7.0; extra == "all"
Requires-Dist: ruff; extra == "all"
Dynamic: license-file

<div align="center">
  <img src="docs/strands-vjepa-logo.svg" alt="strands-vjepa" width="140">

  # strands-vjepa

  **V-JEPA 2 as a world model, planner, and robot policy for [Strands Agents](https://github.com/strands-agents/strands-agents).**

  [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
  [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)]()
</div>

---

## Install

```bash
pip install strands-vjepa
```
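
The package metadata also declares optional extras (`video`, `hf`, `robots`, `dev`, `all`), which can be requested in the usual way, for example `pip install "strands-vjepa[hf,video]"`. The sketch below is one way to check which optional backends are importable in the current environment. `optional_backends` is a hypothetical helper, not part of the package, and the module names are the usual import names of the declared dependencies (`decord2` is assumed to install under the `decord` import name).

```python
from importlib.util import find_spec

# Import names for the dependencies declared by the `hf` and `video` extras.
# Assumption: decord2 is importable as `decord`, like the original decord.
OPTIONAL_MODULES = {
    "hf": ["transformers", "huggingface_hub"],
    "video": ["av", "decord"],
}


def optional_backends():
    """Return {extra: {module: is_importable}} for the current environment."""
    return {
        extra: {name: find_spec(name) is not None for name in modules}
        for extra, modules in OPTIONAL_MODULES.items()
    }


if __name__ == "__main__":
    for extra, modules in optional_backends().items():
        print(extra, modules)
```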

## What you get

| Layer | What it does |
|-------|-------------|
| `WorldModel` / `HFWorldModel` | Encode video → latents, predict next state |
| `CEMPlanner` | Cross-Entropy Method search over action space |
| `VJepaPolicy` | Drop-in `strands-robots` Policy (zero-shot action head) |
| 7 agent tools | `vjepa_encode`, `vjepa_plan`, `vjepa_rollout`, `vjepa_train`, `vjepa_pretrain`, `vjepa_eval`, `vjepa_inspect` |
| Training | Action-conditioned (AC) fine-tuning + JEPA self-supervised pretraining |
| Evaluation | Frozen-backbone probes (SSv2, K400, IN1K) |

## Quick start

### Encode + plan (zero-shot, no training)

```python
from strands_vjepa import WorldModel, CEMPlanner
from strands_vjepa.planner import CEMConfig
import torch

wm = WorldModel.from_pretrained("facebook/vjepa2-ac-vitl", device="cuda")
obs = wm.encode(torch.rand(1, 3, 16, 256, 256))
goal = wm.encode(torch.rand(1, 3, 16, 256, 256))

planner = CEMPlanner(wm, CEMConfig(horizon=8, n_samples=256))
result = planner.plan(obs, goal)
print(result.actions.shape)  # [8, 7]
```
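
For readers new to the Cross-Entropy Method, the search behind a planner like this can be sketched in plain Python. This is a minimal illustration, not the package's `CEMPlanner`: the additive `rollout` is a toy stand-in for a learned world model, and `cem_plan` with all its parameters is a hypothetical name chosen for this example.

```python
import random
import statistics


def rollout(state, actions):
    """Toy stand-in for a world model: each action is added to the state."""
    for action in actions:
        state = [s + a for s, a in zip(state, action)]
    return state


def cem_plan(start, goal, horizon=4, action_dim=2,
             n_samples=128, n_elites=16, n_iters=10, seed=0):
    """Return the mean action sequence found by CEM (list of horizon x action_dim)."""
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]

    def cost(seq):
        # L1 distance between the predicted final state and the goal.
        final = rollout(start, seq)
        return sum(abs(f - g) for f, g in zip(final, goal))

    for _ in range(n_iters):
        # 1. Sample candidate action sequences from the current Gaussian.
        candidates = [
            [[rng.gauss(mean[t][d], std[t][d]) for d in range(action_dim)]
             for t in range(horizon)]
            for _ in range(n_samples)
        ]
        # 2. Keep the lowest-cost candidates (the elites).
        elites = sorted(candidates, key=cost)[:n_elites]
        # 3. Refit the Gaussian to the elites.
        for t in range(horizon):
            for d in range(action_dim):
                values = [e[t][d] for e in elites]
                mean[t][d] = statistics.fmean(values)
                std[t][d] = statistics.pstdev(values) + 1e-6
    return mean


if __name__ == "__main__":
    plan = cem_plan(start=[0.0, 0.0], goal=[1.0, -2.0])
    print(rollout([0.0, 0.0], plan))
```

In a real planner the cost would presumably be a distance between predicted and goal latents rather than between raw states.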

### Hugging Face pretrained (real weights)

```python
from strands_vjepa import HFWorldModel

wm = HFWorldModel.from_hf("facebook/vjepa2-vitl-fpc16-256-ssv2")
# `video` is a preprocessed clip tensor prepared elsewhere (not defined in this snippet)
latents = wm.encode(video)          # [1, 2048, 1024]
logits = wm.classify(video)         # [1, 174] (SSv2 classes)
```
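
The `[1, 2048, 1024]` latent shape is consistent with a 16-frame, 256×256 clip if the encoder groups frames into 2-frame tubelets and splits each frame into 16×16 patches; 1024 is the ViT-L hidden width. The tubelet and patch sizes are assumptions taken from the usual V-JEPA 2 ViT-L configuration, not from this README, and `num_tokens` is a hypothetical helper.

```python
def num_tokens(frames, height, width, tubelet=2, patch=16):
    """Number of spatio-temporal tokens for a clip (assumed tubelet/patch sizes)."""
    return (frames // tubelet) * (height // patch) * (width // patch)


if __name__ == "__main__":
    print(num_tokens(16, 256, 256))  # 8 * 16 * 16 = 2048
```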

### As Strands Agent tools

```python
from strands import Agent
from strands_vjepa.tools import vjepa_encode, vjepa_plan, vjepa_rollout

agent = Agent(tools=[vjepa_encode, vjepa_plan, vjepa_rollout])
agent("Plan 8 actions from obs.jpg to goal.jpg for a 7-DoF arm")
```

### As a robot policy

```python
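import asyncio

import strands_vjepa  # auto-registers with strands-robots
from strands_robots.policies import create_policy

# `wm` is a loaded world model and `obs_dict` the current observation (both defined elsewhere)
policy = create_policy("vjepa", world_model=wm, action_dim=7, horizon=8)
# get_actions is a coroutine; outside async code it needs an event loop
actions = asyncio.run(policy.get_actions(obs_dict, "goal.jpg"))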
```

### Fine-tune on your data

```python
from strands_vjepa.training import ACFineTuner, ACTrainingConfig

config = ACTrainingConfig(data_path="/data/trajectories", epochs=100)
ACFineTuner(config).train()
```
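
This README does not show what `ACFineTuner` optimizes. As a rough illustration only, an action-conditioned latent objective can combine a teacher-forced one-step term with a multi-step rollout term. The sketch below assumes L1 errors for both terms, and every name in it (`l1`, `teacher_forcing_loss`, `rollout_loss`, `predict`) is hypothetical; the package's actual loss may differ.

```python
def l1(a, b):
    """Mean absolute error between two equal-length vectors."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def teacher_forcing_loss(predict, latents, actions):
    """Average one-step error, each prediction conditioned on the true latent."""
    losses = [
        l1(predict(latents[t], actions[t]), latents[t + 1])
        for t in range(len(actions))
    ]
    return sum(losses) / len(losses)


def rollout_loss(predict, latents, actions):
    """Final-step error when the predictor is fed its own outputs."""
    z = latents[0]
    for action in actions:
        z = predict(z, action)
    return l1(z, latents[-1])


if __name__ == "__main__":
    # Toy data: the "true" dynamics add the action to the latent.
    latents = [[0.0, 0.0], [1.0, 0.0], [1.0, 2.0]]
    actions = [[1.0, 0.0], [0.0, 2.0]]
    perfect = lambda z, a: [zi + ai for zi, ai in zip(z, a)]
    total = teacher_forcing_loss(perfect, latents, actions) + rollout_loss(perfect, latents, actions)
    print(total)
```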

## Architecture

```
strands_vjepa/
├── world_model.py      # Encoder + AC predictor
├── planner.py          # CEM planner
├── policy.py           # strands-robots Policy
├── tools/              # 7 @tool-decorated agent tools
├── training/           # AC fine-tune + JEPA pretrain + DROID dataset
├── eval/               # Frozen-backbone classification probes
└── vendored/           # V-JEPA 2 model code (MIT, from Meta)
```

## Upstream

Built on [V-JEPA 2](https://github.com/facebookresearch/vjepa2) (Meta AI, MIT).  
This package depends on [`cagataycali/vjepa2`](https://github.com/cagataycali/vjepa2), a fork with 10 community patches (MPS, ST-A², bugfixes).

## License

Apache 2.0
