Metadata-Version: 2.4
Name: ml_atlas_sdk
Version: 0.2.0
Summary: Atlas artifact bundle SDK — pushes a PyTorch model + validation data to MLflow in the format atlas (TRT/Triton optimizer) consumes.
Project-URL: Homepage, https://github.com/Meesho/BharatMLStack
Project-URL: Bug Tracker, https://github.com/Meesho/BharatMLStack/issues
Project-URL: Source, https://github.com/Meesho/BharatMLStack/tree/main/py-sdk/ml_atlas_sdk
Author-email: BharatML Stack Team <bharatml@meesho.com>
License: BharatMLStack Business Source License 1.1
Classifier: License :: Other/Proprietary License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Requires-Dist: mlflow==2.9
Requires-Dist: numpy==1.24
Requires-Dist: onnx==1.14
Requires-Dist: onnxruntime==1.16
Requires-Dist: pandas==2.0
Requires-Dist: pydantic==2.0
Requires-Dist: pyyaml==6.0
Requires-Dist: torch==2.0
Provides-Extra: architecture
Requires-Dist: anthropic==0.39; extra == 'architecture'
Provides-Extra: dev
Requires-Dist: black; extra == 'dev'
Requires-Dist: flake8; extra == 'dev'
Requires-Dist: isort; extra == 'dev'
Requires-Dist: mypy; extra == 'dev'
Requires-Dist: pytest==6.0; extra == 'dev'
Provides-Extra: gcs
Requires-Dist: google-cloud-storage==2.10; extra == 'gcs'
Description-Content-Type: text/markdown

# ml_atlas_sdk

Pushes a PyTorch model + validation data to MLflow in the format atlas (the
TensorRT/Triton optimizer) consumes.

## Why

Atlas needs 9 specific input artifacts on GCS in a strict format. Without this
SDK, data scientists hand-write ONNX exports, val_inputs.pt, metadata.json,
etc. — and silently get them wrong (FP16/GPU-generated reference, bad
constant-folding, missing dynamic axes). Failures only surface inside atlas
hours later.

`ml_atlas_sdk` consolidates all 9 artifacts behind one `push_to_mlflow(bundle, run)`
call, validates the model+data locally before upload, and runs an
ORT round-trip self-check that catches export bugs before they ship.

## Install

```bash
pip install ml_atlas_sdk
# optional extras
pip install "ml_atlas_sdk[architecture]"   # LLM-driven architecture.md
pip install "ml_atlas_sdk[gcs]"            # for atlas reverse-side resolver
```

> **LLM architecture autogen (`architecture_doc="auto"`, the default)**
> requires both the `[architecture]` extra and `ANTHROPIC_API_KEY` in the
> environment when `push_to_mlflow` runs. If either is missing, or the API
> call fails, the SDK silently falls back to a structural+ONNX-summary
> markdown (PyTorch `repr` + op histogram); it never raises.
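
Both preconditions can be checked before a push so the fallback does not come as a surprise. The following is a minimal sketch, not SDK code: `llm_autogen_available` is a hypothetical helper, and it assumes that installing the `[architecture]` extra is what makes the `anthropic` package importable (it is the only dependency that extra lists).

```python
import importlib.util
import os


def llm_autogen_available(env=None):
    """Return True only if both preconditions for LLM autogen hold.

    Hypothetical pre-flight helper; not part of ml_atlas_sdk.
    """
    env = os.environ if env is None else env
    # Assumption: the [architecture] extra is detected via the anthropic package.
    has_extra = importlib.util.find_spec("anthropic") is not None
    # An empty string counts as unset.
    has_key = bool(env.get("ANTHROPIC_API_KEY"))
    return has_extra and has_key


if not llm_autogen_available():
    print("architecture.md will use the structural fallback, not the LLM")
```

This check cannot predict a failed API call, so a `True` result still does not guarantee LLM-generated output.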

## Minimal usage

```python
import mlflow
import torch
from ml_atlas_sdk import (
    AtlasArtifactBundle,
    InputBinding,
    InputSchema,
    ModelMetadata,
    push_to_mlflow,
)

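# Assumed to exist already: `my_model` (your trained nn.Module) and
# `val_in_df` (a pandas DataFrame with one column per ONNX input).
bundle = AtlasArtifactBundle(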
    model_name="my_model",
    model=my_model,                                  # nn.Module
    metadata=ModelMetadata(
        model_name="my_model",
        input_names=["input__0", "input__1"],
        output_names=["output__0"],
        input_shapes={"input__0": [-1, 4, 8], "input__1": [-1, 16]},
        output_shape={"output__0": [-1, 3]},
        dynamic_axes={                               # paste from your torch.onnx.export call
            "input__0": {0: "batch_size"},
            "input__1": {0: "batch_size"},
            "output__0": {0: "batch_size"},
        },
        onnx_opset=17,
        torch_version=torch.__version__,
        architecture="MyModel",
        task_type="multitask_ranking",
    ),
    input_schema=InputSchema(inputs=[                # MANDATORY — predator-shaped feature bindings
        InputBinding(name="input__0", features=["ONLINE_FEATURE|user:age_bucket"]),
        InputBinding(name="input__1", features=["OFFLINE_FEATURE|user_segments"]),
    ]),
    val_inputs_df=val_in_df,                         # one column per ONNX input
    export_batch_size=64,                            # rows to slice for the ONNX trace sample
)

mlflow.set_tracking_uri("...")
mlflow.set_experiment("ranking_models")
with mlflow.start_run() as run:
    run_id = push_to_mlflow(bundle, run)
```

That's it. The SDK:
- forces `model.eval().cpu().float()` + `torch.no_grad()` for export
- generates canonical val_outputs by running the model on val_inputs
- exports ONNX with proper dynamic_axes
- runs an ORT round-trip self-check (with 5-step escalation if it fails)
- uploads the required atlas artifacts under `atlas/` in the run, including
  `input_schema.json` (the predator-shaped feature bindings, structurally
  validated), plus any optional ones the bundle supplied (`traced.pt`,
  `calibration_inputs.pt`, `val_outputs_user.pt`, `drift_report.json`,
  `architecture.md`)
- logs all params/tags atlas needs

See `PLAN.md` for the full design.

## Full usage

Every field on `AtlasArtifactBundle` (and the nested `ModelMetadata`) shown
with realistic values, including a non-batch dynamic axis (`seq_len`). Use
this when you want INT8 calibration, drift comparison against your own
reference outputs, or to override SDK defaults.

```python
from pathlib import Path

import mlflow
import pandas as pd
import torch
from ml_atlas_sdk import (
    AtlasArtifactBundle,
    InputBinding,
    InputSchema,
    ModelMetadata,
    push_to_mlflow,
)

# ── 1. Metadata: ONNX-side schema.
#       In input_shapes / output_shape: use -1 for every dynamic dim and
#       positive ints for static dims.
#       In dynamic_axes: paste your torch.onnx.export `dynamic_axes` dict
#       verbatim — same format, {io_name: {axis_idx: axis_name}}. Static
#       axes are NOT listed here; axis 0 (batch) MUST be listed for every
#       input AND every output.
metadata = ModelMetadata(
    model_name="my_model",
    input_names=["tokens", "dense"],
    output_names=["logits"],
    input_shapes={
        "tokens": [-1, -1],         # axes 0 and 1 both dynamic; names below
        "dense":  [-1, 16],         # axis 1 static (16)
    },
    output_shape={
        "logits": [-1, 3],          # axis 1 static (3)
    },
    dynamic_axes={
        "tokens": {0: "batch_size", 1: "seq_len"},
        "dense":  {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    onnx_opset=17,                  # must be >= 17 (LayerNorm/SDPA precision)
    torch_version=torch.__version__,
    architecture="MyModel",
    task_type="multitask_ranking",
)

# ── 2. Input schema: predator-shaped feature bindings (MANDATORY).
#       One binding per ONNX input; each feature is "TYPE|name" where
#       TYPE is one of ALLOWED_FEATURE_TYPES. Validation is structural —
#       the SDK does not reach out to OFS / offline APIs.
input_schema = InputSchema(inputs=[
    InputBinding(name="tokens", features=["ONLINE_FEATURE|user:tokens"]),
    InputBinding(name="dense",  features=["OFFLINE_FEATURE|user_dense"]),
])
# Equivalent accepted forms:
#   input_schema = {"inputs": [...]}                     # dict
#   input_schema = Path("input_schema.json")             # file path on disk
#   input_schema = "input_schema.json"                   # str path on disk

# ── 3. Validation data: one column per ONNX input, ndarray cells.
val_in_df: pd.DataFrame = ...        # cols: "tokens", "dense"

# Optional: your own reference outputs. SDK still generates canonical
# val_outputs internally; yours is used ONLY for the drift report
# (val_outputs_user.pt is uploaded for debugging, never as the atlas artifact).
val_out_df: pd.DataFrame = ...       # cols: "logits"

# Optional: calibration sample for INT8. >= 1000 rows recommended.
calib_in_df: pd.DataFrame = ...      # cols: "tokens", "dense"

bundle = AtlasArtifactBundle(
    model_name="my_model",
    model=my_model,                                  # nn.Module — SDK forces .eval().cpu().float()
    metadata=metadata,
    input_schema=input_schema,                       # mandatory; see step 2
    val_inputs_df=val_in_df,
    val_outputs_df=val_out_df,                       # optional; drives drift report
    export_batch_size=64,                            # rows to slice for the ONNX trace sample (default 64)
    calibration_inputs_df=calib_in_df,               # optional; needed only for INT8 path in atlas
    architecture_doc="auto",                         # "auto" (LLM-generated), a str/Path, or None to skip
    skip_sdk_validations=False,                      # see Bypass section below
    drift_tolerance_mse=1e-3,                        # raises drift error if user vs canonical mse exceeds this
    drift_tolerance_max_abs=5e-2,                    # raises drift error if max|user-canonical| exceeds this
    bundle_schema_version="1",
)

mlflow.set_tracking_uri("...")
mlflow.set_experiment("ranking_models")
with mlflow.start_run() as run:
    run_id = push_to_mlflow(bundle, run)
```
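
The shape and `dynamic_axes` rules from step 1 of the example above can be checked on plain dicts before a `ModelMetadata` is built. This is a sketch of those rules only; `check_dynamic_axes` is a hypothetical helper and not the SDK's own validator, which may check more or report differently.

```python
def check_dynamic_axes(input_shapes, output_shape, dynamic_axes):
    """Return a list of problems; an empty list means the dicts are consistent.

    Hypothetical helper mirroring the documented rules; not part of ml_atlas_sdk.
    """
    problems = []
    shapes = {**input_shapes, **output_shape}
    for name, shape in shapes.items():
        declared = set(dynamic_axes.get(name, {}))
        # Rule: axis 0 (batch) must be listed for every input and output.
        if 0 not in declared:
            problems.append(f"{name}: axis 0 (batch) is not listed in dynamic_axes")
        for axis, dim in enumerate(shape):
            # Rule: -1 marks a dynamic dim, which needs a named entry.
            if dim == -1 and axis not in declared:
                problems.append(f"{name}: axis {axis} is -1 but has no dynamic_axes entry")
            # Rule: static axes are not listed in dynamic_axes.
            if dim != -1 and axis in declared:
                problems.append(f"{name}: axis {axis} is static ({dim}) but listed as dynamic")
    return problems


problems = check_dynamic_axes(
    input_shapes={"tokens": [-1, -1], "dense": [-1, 16]},
    output_shape={"logits": [-1, 3]},
    dynamic_axes={
        "tokens": {0: "batch_size", 1: "seq_len"},
        "dense": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)
print(problems)  # [] for the metadata used in the example above
```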

Notes:
- `input_schema` is **mandatory**. It mirrors the predator-side feature
  schema (`{inputs: [{name, features: ["TYPE|name"]}]}`) and is written
  pass-through as `atlas/input_schema.json` in the MLflow run. Accepted
  forms on the bundle: `InputSchema` instance, a `dict` of the same
  shape, or a `Path`/`str` pointing to a JSON file on disk (auto-loaded).
  Valid feature-type prefixes are exported as `ALLOWED_FEATURE_TYPES`:
  `ONLINE_FEATURE`, `PARENT_ONLINE_FEATURE`, `OFFLINE_FEATURE`,
  `PARENT_OFFLINE_FEATURE`, `RTP_FEATURE`, `PARENT_RTP_FEATURE`,
  `DEFAULT_FEATURE`, `PARENT_DEFAULT_FEATURE`, `MODEL_FEATURE`,
  `CALIBRATION`, `PCTR_CALIBRATION`, `PCVR_CALIBRATION`. Validation is
  structural only — the SDK does not call OFS or offline APIs.
- `dynamic_axes` is the same dict you'd pass to `torch.onnx.export` — copy it
  in verbatim. Every input AND every output must have axis 0 listed (batch is
  always dynamic for atlas). Non-batch dynamic axes are named here too.
- Canonical value lists for non-batch dynamic axes (e.g. `seq_len = [64, 128, 256]`)
  AND the deployment-time batch list are supplied as part of the atlas
  trigger-run request, not on the bundle — the bundle only declares *which*
  axes are dynamic and what they're called.
- `export_batch_size` is **only** the SDK-side trace sample size — how many
  rows of `val_inputs_df` get sliced when running `torch.onnx.export`. It has
  nothing to do with deployment batch sizes. Default 64 works for most models;
  bump it if the trace needs more variance, drop it for tiny models / sparse
  data. If `val_inputs_df` has fewer rows, all rows are used.
- `val_outputs_df` is **optional**; the SDK always generates the canonical
  `val_outputs.pt` itself by running the FP32-CPU model under `torch.no_grad()`
  on `val_inputs_df`. If you supply yours, it's saved as `val_outputs_user.pt`
  for debugging and compared against the canonical (mse / max_abs) — that's
  the drift report.
- `calibration_inputs_df` is **optional**; omit it unless atlas will run the
  INT8 quantization path for this model.
- `architecture_doc="auto"` is the default. To get LLM-generated docs, install
  with `pip install ml_atlas_sdk[architecture]` and set `ANTHROPIC_API_KEY` in the
  environment. If the extra is missing, the API call fails, or the env var is
  unset, the SDK silently falls back to a structural+ONNX-summary markdown
  (PyTorch `repr` + op histogram). Pass a `str`/`Path` to supply your own
  markdown, or `None` to skip the file entirely.
- `skip_sdk_validations` accepts `True` (skip all), `False` (skip none), or a
  list/set of the check names documented in the **Validations** section below.
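
The structural rules for `input_schema` in the first note can be approximated on the plain-dict form before the bundle is built. The sketch below is an assumption-laden illustration: `check_input_schema` is a hypothetical helper, and the set of feature types is copied from the list above rather than imported from the SDK so the block stays self-contained. The SDK's own validation may be stricter.

```python
# Local copy of the documented feature-type prefixes (the SDK exports the
# real list as ALLOWED_FEATURE_TYPES).
ALLOWED_FEATURE_TYPES = {
    "ONLINE_FEATURE", "PARENT_ONLINE_FEATURE",
    "OFFLINE_FEATURE", "PARENT_OFFLINE_FEATURE",
    "RTP_FEATURE", "PARENT_RTP_FEATURE",
    "DEFAULT_FEATURE", "PARENT_DEFAULT_FEATURE",
    "MODEL_FEATURE", "CALIBRATION",
    "PCTR_CALIBRATION", "PCVR_CALIBRATION",
}


def check_input_schema(schema, input_names):
    """Return a list of structural problems in a dict-form input schema.

    Hypothetical helper; not part of ml_atlas_sdk.
    """
    problems = []
    bindings = schema.get("inputs", [])
    bound_names = {b.get("name") for b in bindings}
    # One binding is expected per ONNX input.
    for name in input_names:
        if name not in bound_names:
            problems.append(f"no binding for ONNX input {name!r}")
    for binding in bindings:
        for feature in binding.get("features", []):
            ftype, sep, fname = feature.partition("|")
            if not sep or not fname:
                problems.append(f"{feature!r}: expected 'TYPE|name'")
            elif ftype not in ALLOWED_FEATURE_TYPES:
                problems.append(f"{feature!r}: unknown feature type {ftype!r}")
    return problems


schema = {"inputs": [
    {"name": "tokens", "features": ["ONLINE_FEATURE|user:tokens"]},
    {"name": "dense", "features": ["OFFLINE_FEATURE|user_dense"]},
]}
print(check_input_schema(schema, ["tokens", "dense"]))  # []
```

Like the SDK's validation, this is structural only: it does not confirm that a named feature exists in any feature store.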

## Validations

Every check below runs by default during `push_to_mlflow`. Pass the check's
name in `skip_sdk_validations` to skip it. Atlas re-validates everything on its
side regardless — skipping here only changes whether you get fast local
feedback or slower atlas-side feedback after the upload.

| Check                | What it does                                                                                                                                  | Skip when                                                                                       | If you skip                                                          |
|----------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------|----------------------------------------------------------------------|
| `df_shape_dtype`     | Confirms every cell in `val_inputs_df` / `val_outputs_df` / `calibration_inputs_df` has shape `[d for d in metadata.input_shapes[name] if d != -1]` and the right dtype (mirrors atlas `validate_val_data`). | You've already verified DF shapes match metadata and are debugging a different failure.         | Wrong-shape cells slip through; atlas rejects the bundle after upload. |
| `nan_inf`            | Hard-fails on NaN/Inf in `val_inputs_df` and in the canonical `val_outputs` the SDK generates; warns (doesn't fail) on NaN/Inf in user-supplied `val_outputs_df`. | Input data is *known* to contain NaN/Inf sentinels that your model handles correctly.            | Bad-numerics inputs/outputs ship and TRT engine builds may NaN-out in prod. |
| `calib_sample_count` | Warns if `calibration_inputs_df` has < 1000 rows (recommended floor for INT8 quantizer stability).                                            | You've intentionally supplied a small calibration sample (e.g. for a tiny model or test path).  | INT8 quantization may pick poor scale/zero-point and drop accuracy. |
| `forward_scan`       | Greps `inspect.getsource(model.forward)` for tracing footguns: `.item()`, `.tolist()`, `int(tensor.x(...))`, `range(tensor.size(...))`, `if x.training`, `.cpu().numpy()`, `if tensor.any()/.all()`. Warns per match. | Forward method genuinely uses one of these (e.g. eval-only `.item()` outside the traced path).  | You lose the early heads-up about Python-side control flow that bakes constants into ONNX. |
| `roundtrip`          | Loads the freshly-exported ONNX in ORT (CPU), runs `val_inputs` through it, and compares to the canonical PyTorch outputs (`rtol=1e-2`, `atol=5e-2`). On failure, the 5-step escalation loop walks opset bumps + const-fold off + dynamo before giving up. | The roundtrip flakes on a known-tolerated op and you've already verified accuracy off-band.     | Export bugs (FP16 contamination, opset-decomposition drift, broken constant-folding) ship silently. |
| `sample_invariance`  | Exports the model twice on disjoint val_input slices, hashes both ONNX byte streams *after stripping initializer values*. Different hashes → tracing was sample-dependent; warns naming the divergent op type. | Known harmless per-sample variation (rare).                                                     | A Python-control-flow bug in `forward` produces sample-dependent ONNX with no warning. |
| `onnx_cross_check`   | Post-export: `onnx.checker.check_model` passes, ONNX input/output names equal `metadata.input_names`/`output_names`, every I/O has `dim_param` at axis 0 (dynamic batch), opset version is exactly `metadata.onnx_opset` (or higher if escalation bumped it). Mirrors atlas `validate_onnx`. | Investigating exporter internals where the cross-check is the symptom, not the cause.           | Schema mismatches between your declared metadata and the actual ONNX slip through. |
| `drift`              | Computes `mse`, `max_abs_err`, `mean_abs_err`, `max_pct_err`, `mean_pct_err`, and (for last-axis size ≥ 8) `cosine_sim` between user-supplied `val_outputs_df` and the SDK's canonical outputs. Fails if `mse > drift_tolerance_mse` or `max_abs > drift_tolerance_max_abs`. | Tolerances are intentionally tight for diagnostics and you accept the noise.                    | A discrepancy between the outputs you think your model produces and what it actually produces ships unnoticed. |
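
To make the `drift` pass/fail rule concrete, here is a minimal pure-Python sketch of two of the metrics and the threshold test. It works on flat lists of floats rather than tensors, and the function names are hypothetical; the tolerance values are the ones shown in the full-usage example, not confirmed SDK defaults.

```python
def drift_metrics(user, canonical):
    """Compute mse and max_abs_err between two equal-length lists of floats.

    Hypothetical helper; the SDK computes more metrics and works on tensors.
    """
    if not user or len(user) != len(canonical):
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [u - c for u, c in zip(user, canonical)]
    return {
        "mse": sum(d * d for d in diffs) / len(diffs),
        "max_abs_err": max(abs(d) for d in diffs),
    }


def drift_exceeded(metrics, tolerance_mse=1e-3, tolerance_max_abs=5e-2):
    """Apply the documented rule: fail if either metric exceeds its tolerance."""
    return (
        metrics["mse"] > tolerance_mse
        or metrics["max_abs_err"] > tolerance_max_abs
    )


canonical = [0.10, 0.20, 0.30]
close = drift_metrics([0.10, 0.21, 0.30], canonical)  # one value off by ~0.01
far = drift_metrics([0.10, 0.30, 0.30], canonical)    # one value off by ~0.1
print(drift_exceeded(close), drift_exceeded(far))
```

In the second case the mean squared error alone would also fail, but a single large outlier in a long output vector typically trips `max_abs` while leaving `mse` small, which is why both thresholds are checked.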

## Bypass for emergency pushes

```python
bundle = AtlasArtifactBundle(..., skip_sdk_validations=True)              # all
bundle = AtlasArtifactBundle(..., skip_sdk_validations={"roundtrip"})     # granular
```

Atlas re-validates on its side regardless; bypassing only trades fast local
feedback for slower atlas-side feedback after the upload. See `PLAN.md`
§Bypass validations.
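
The three accepted forms of `skip_sdk_validations` (`True`, `False`, or a list/set of check names) can be reduced to one set of skipped checks. The sketch below shows one way to do that; `resolve_skipped` is a hypothetical helper, and rejecting unknown names is a choice made here for illustration, not documented SDK behavior.

```python
# Check names as listed in the Validations table.
CHECK_NAMES = frozenset({
    "df_shape_dtype", "nan_inf", "calib_sample_count", "forward_scan",
    "roundtrip", "sample_invariance", "onnx_cross_check", "drift",
})


def resolve_skipped(skip):
    """Normalize True / False / iterable-of-names into a set of skipped checks.

    Hypothetical helper; not part of ml_atlas_sdk.
    """
    if skip is True:
        return set(CHECK_NAMES)
    if skip is False:
        return set()
    requested = set(skip)
    unknown = requested - CHECK_NAMES
    if unknown:
        # Assumption: a typo in a check name should fail loudly rather
        # than silently leave the check enabled.
        raise ValueError(f"unknown check names: {sorted(unknown)}")
    return requested


print(sorted(resolve_skipped({"roundtrip", "drift"})))  # ['drift', 'roundtrip']
```

A side effect of the unknown-name guard is that passing a bare string such as `"roundtrip"` instead of `{"roundtrip"}` is rejected, because the string is iterated character by character.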

## Layout

```
ml_atlas_sdk/
  models.py            AtlasArtifactBundle, ModelMetadata, InputSchema, InputBinding
  errors.py            AtlasBundleError
  materialize.py       orchestrator (14-step deterministic order)
  validators/
    df_check.py        DataFrame schema check (mirrors atlas validate_val_data)
    onnx_check.py      post-export ONNX cross-check (mirrors atlas validate_onnx)
    roundtrip.py       ORT round-trip vs canonical PyTorch outputs
  exporters/
    df_to_tensor.py    DataFrame -> dict[name, Tensor]
    onnx_export.py     export-env pin + variance-aware sample + dynamic_axes
    escalation.py      5-attempt round-trip escalation loop
    forward_scan.py    inspect.getsource footgun warnings
    sample_invariance.py  two-sample export-hash divergence probe
    drift.py           DS-supplied vs canonical metrics (mse, max_abs, ...)
  architecture/
    autogen.py         LLM-driven architecture.md
    inspector.py       PyTorch + ONNX graph signal extraction
    cache.py           ~/.cache/ml-atlas-sdk/architecture/
  upload/
    mlflow_upload.py   push_to_mlflow public entrypoint
tests/
  smoke_e2e.py         happy-path e2e against file:// MLflow tracker
  smoke_bypass_drift.py  bypass + drift report paths
  fixtures/tiny_model.py
```
