Metadata-Version: 2.4
Name: tgraphx
Version: 0.4.4
Summary: Tensor-aware graph neural networks preserving spatial node feature layouts
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
Maintainer-email: Arash Sajjadi <arash.sajjadi@usask.ca>
License: MIT
Project-URL: Homepage, https://github.com/arashsajjadi/TGraphX
Project-URL: Repository, https://github.com/arashsajjadi/TGraphX
Project-URL: Documentation, https://github.com/arashsajjadi/TGraphX/tree/main/docs
Project-URL: Issues, https://github.com/arashsajjadi/TGraphX/issues
Project-URL: Changelog, https://github.com/arashsajjadi/TGraphX/blob/main/CHANGELOG.md
Project-URL: Paper, https://arxiv.org/abs/2504.03953
Keywords: graph neural networks,GNN,CNN,spatial,PyTorch,computer vision,message passing,tensor-aware,dashboard,graph learning
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Image Recognition
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.13
Requires-Dist: torchvision>=0.14
Requires-Dist: pyyaml>=5.4
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: matplotlib>=3.7; extra == "dev"
Requires-Dist: build>=0.10; extra == "dev"
Requires-Dist: twine>=4.0; extra == "dev"
Provides-Extra: monitoring
Requires-Dist: psutil>=5.9; extra == "monitoring"
Requires-Dist: pynvml>=11.0; extra == "monitoring"
Provides-Extra: tracking
Requires-Dist: tensorboard>=2.11; extra == "tracking"
Provides-Extra: mlflow
Requires-Dist: mlflow>=2.0; extra == "mlflow"
Provides-Extra: pyg
Requires-Dist: torch-geometric>=2.0; extra == "pyg"
Provides-Extra: ogb
Requires-Dist: ogb>=1.3.6; extra == "ogb"
Provides-Extra: pillow
Requires-Dist: Pillow>=9.0; extra == "pillow"
Dynamic: license-file

<p align="center">
  <img src="https://raw.githubusercontent.com/arashsajjadi/TGraphX/main/TGRAPHX.png" alt="TGraphX logo" width="220">
</p>

# TGraphX

[![Tests](https://github.com/arashsajjadi/TGraphX/actions/workflows/tests.yml/badge.svg)](https://github.com/arashsajjadi/TGraphX/actions/workflows/tests.yml)
[![PyPI version](https://img.shields.io/pypi/v/tgraphx.svg)](https://pypi.org/project/tgraphx/)
[![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/)
[![PyTorch 1.13+](https://img.shields.io/badge/pytorch-1.13%2B-orange.svg)](https://pytorch.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)

**Preprint:** [TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning](https://arxiv.org/abs/2504.03953) · *Sajjadi & Eramian, arXiv 2025*

Developed by **Arash Sajjadi**, PhD Candidate in Computer Science, University of Saskatchewan. Academic supervision: **Mark Eramian**.

TGraphX is a PyTorch library for graph neural networks whose **node features are multi-dimensional tensors** — `[C, H, W]` image-patch feature maps, `[C, D, H, W]` volumetric maps, or plain vectors `[D]`. Convolutional message passing operates directly on spatial node features, so local structure is never destroyed by flattening.

---

## Install

```bash
pip install tgraphx
```

Optional extras (all lazy-imported, none required for the base package):

```bash
pip install "tgraphx[tracking]"     # TensorBoard
pip install "tgraphx[mlflow]"       # MLflow
pip install "tgraphx[monitoring]"   # psutil + pynvml (dashboard hardware panel)
pip install "tgraphx[pyg]"          # PyTorch Geometric dataset adapter
pip install "tgraphx[ogb]"          # OGB dataset adapter
pip install "tgraphx[pillow]"       # ImageFolder dataset
pip install "tgraphx[dev]"          # pytest + build + twine
```

DGL has platform-sensitive wheels and is not packaged as a TGraphX extra. Follow the [DGL install guide](https://www.dgl.ai/pages/start.html); however you install it, the DGL adapter becomes available as soon as DGL is importable.

For a custom PyTorch build (CPU-only or specific CUDA), install PyTorch first, then TGraphX:

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install tgraphx
```

---

## 60-second quickstart

**Vector node features:**

```python
import torch
from tgraphx import Graph, LinearMessagePassing

x = torch.randn(8, 32)
edge_index = torch.stack([torch.arange(8), (torch.arange(8) + 1) % 8])
g = Graph(x, edge_index)

layer = LinearMessagePassing(in_shape=(32,), out_shape=(64,))
out = layer(g.node_features, g.edge_index)   # [8, 64]
out.sum().backward()
```

**Spatial `[C, H, W]` node features — preserved through message passing:**

```python
import torch
from tgraphx import Graph, ConvMessagePassing

N, C, H, W = 6, 16, 8, 8
g = Graph(torch.randn(N, C, H, W), torch.stack(
    [torch.arange(N), (torch.arange(N) + 1) % N]))

layer = ConvMessagePassing(in_shape=(C, H, W), out_shape=(32, H, W))
out = layer(g.node_features, g.edge_index)   # [6, 32, 8, 8]
out.sum().backward()
```

A Colab tutorial walks through every workflow:
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1agls1xtqE5WxbWthcG0HEa3Gbk3fvoCD?usp=sharing)

---

## Why tensor-aware GNNs

Standard GNN frameworks expect flat vector node features. Flattening a `[C, H, W]` feature map into a `[C·H·W]` vector discards the spatial structure that makes CNNs effective. TGraphX keeps each node's representation as a tensor and applies 1×1 convolutions during message passing, so every neighbourhood aggregation step acts like a miniature CNN across neighbouring feature maps.

```
M_{i→j} = φ( X_i, X_j, E_{i→j} )
A_j     = AGG_{i ∈ N(j)} M_{i→j}
X'_j    = ψ( X_j, A_j )
```

Spatial dimensions `H` and `W` are preserved through every learned transform. Aggregations are permutation-invariant; the layers are permutation-equivariant under node relabelling (verified by `tests/test_math_invariants_v030.py`).
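The three steps above can be sketched in plain PyTorch (a simplified illustration of the pattern, not TGraphX's actual implementation): messages `M_{i→j}` come from a 1×1 convolution on source feature maps, `AGG` is a permutation-invariant sum over incoming edges via `index_add_`, and `ψ` is a simple residual update.

```python
import torch
import torch.nn as nn

# Toy graph: N nodes with [C, H, W] feature maps, E directed edges.
N, C, H, W = 5, 4, 3, 3
x = torch.randn(N, C, H, W)
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])  # row 0 = source i, row 1 = destination j

phi = nn.Conv2d(C, C, kernel_size=1)   # message function: 1x1 conv keeps H, W intact

src, dst = edge_index
messages = phi(x[src])                 # M_{i->j} for every edge, shape [E, C, H, W]

# AGG: permutation-invariant sum over the incoming edges of each destination.
agg = torch.zeros_like(x)
agg.index_add_(0, dst, messages)       # A_j = sum_{i in N(j)} M_{i->j}

x_new = x + agg                        # psi: residual update X'_j = X_j + A_j
assert x_new.shape == (N, C, H, W)     # spatial layout survives the whole step
```

Because the sum is order-independent, relabelling the edges leaves `agg` unchanged, which is the permutation-invariance property the tests verify.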

---

## Stable core APIs

| Component | Class | Input node shape | Notes |
|-----------|-------|-----------------|-------|
| Graph data structure | `Graph`, `GraphBatch` | any | Validated; supports `.to(device)` |
| Tensor-aware GCN-style message passing | `ConvMessagePassing` | `[N, C, H, W]` | `aggr="sum" / "mean" / "max"`; chunked forward |
| Tensor-aware multi-head GAT | `TensorGATLayer` | `[N, C, H, W]` | Softmax over incoming edges per destination per head; two-pass log-sum-exp chunked forward |
| Tensor-aware GraphSAGE | `TensorGraphSAGELayer` | `[N, C, H, W]` | Mean / max aggregation; chunked forward |
| Tensor-aware GIN / GINEConv | `TensorGINLayer` | `[N, C, H, W]` | `(1+ε)·h_j + Σ h_i`; chunked forward |
| Vector message passing | `LinearMessagePassing` | `[N, D]` | Base layer with linear projections |
| Custom-layer base class | `TensorMessagePassingLayer` | any | Override `message` / `update` |
| Vector model-zoo | `GCNConv`, `GATv2Conv`, `APPNP` | `[N, D]` | New in v0.3.0; permutation-equivariance and tiny-overfit tested |
| Pooling helpers | `global_mean_pool`, `global_sum_pool`, `global_max_pool` | `[N, *]` | Per-graph reductions |
| CNN patch encoder | `CNNEncoder` | `[N, C_in, pH, pW]` | Outputs spatial feature maps |
| Graph classification | `GraphClassifier` | `[N, C, H, W]` | Mean / sum / max readout |
| Node classification | `NodeClassifier` | `[N, D]` | Vector features |

3-D volumetric node features `[N, C, D, H, W]` are supported by `ConvMessagePassing`, `TensorGATLayer`, `TensorGraphSAGELayer`, and `TensorGINLayer` (pass `spatial_rank=3`).
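The pooling helpers listed in the table above reduce node features to one representation per graph using a batch-assignment index. A minimal plain-PyTorch sketch of a mean readout over vector features (illustrative only; the shipped helpers' signatures may differ):

```python
import torch

# 7 nodes split across 2 graphs; batch[i] says which graph node i belongs to.
x = torch.randn(7, 16)                       # [N, D] node features
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # per-node graph assignment
num_graphs = int(batch.max()) + 1

# Sum features per graph, then divide by node counts -> mean readout.
sums = torch.zeros(num_graphs, x.shape[1]).index_add_(0, batch, x)
counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
pooled = sums / counts.unsqueeze(1)          # [num_graphs, D]

assert pooled.shape == (2, 16)
assert torch.allclose(pooled[0], x[:3].mean(0))
```

Sum and max readouts follow the same pattern with the division dropped or `index_add_` swapped for a scatter-max.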

### Factories

```python
from tgraphx import make_layer, build_model

layer = make_layer("gat", in_shape=(8, 4, 4), out_shape=(16, 4, 4), heads=2, residual=True)
model = build_model(
    task="graph_classification", layer="conv",
    in_shape=(3, 4, 4), hidden_shape=(8, 4, 4),
    num_layers=2, num_classes=10, pooling="mean",
)
```

| Task | `[N,D]` | `[N,C,H,W]` | `[N,C,D,H,W]` |
|------|:-:|:-:|:-:|
| `node_classification` | yes | yes | yes |
| `node_regression` | yes | yes | yes |
| `graph_classification` | yes | yes | yes |
| `graph_regression` | yes | yes | yes |
| `edge_prediction` | yes | yes (spatial pool) | yes (spatial pool) |

---

## Datasets, transforms, metrics, benchmarks

TGraphX ships a unified dataset registry, native synthetic datasets, folder-backed datasets, and optional adapters for torchvision, PyG, DGL, and OGB. **TGraphX does not redistribute third-party datasets**; adapters delegate download and parsing to the upstream library, and any download requires the user to pass `download=True` explicitly.

```python
from tgraphx.datasets import list_datasets, dataset_info, get_dataset

print(list_datasets())                        # 34 registered datasets
ds = get_dataset("synthetic:patch_graph", num_graphs=32, seed=0)
g = ds[0]
```

Datasets are cached under an explicit `root` argument if given, otherwise `$TGRAPHX_DATA`, otherwise `~/.cache/tgraphx/datasets`. Inspect the cache with `tgraphx.datasets.cache_summary()`; clear it with `clear_cache(...)`.

| Layer | Highlights |
|-------|-----------|
| `tgraphx.datasets` | Native synthetic + folder + torchvision/PyG/DGL/OGB adapters; safe atomic downloads with SHA-256 checksums and path-traversal-blocked archive extraction |
| `tgraphx.transforms` | `Compose`, `AddSelfLoops`, `RemoveSelfLoops`, `ToUndirected`, `NormalizeFeatures`, `StandardizeFeatures`, `AddDegreeFeatures`, `RandomNodeSplit`, `RandomLinkSplit`, `AddDegreeEncoding`, `AddLaplacianEigenvectors`, `PatchifyImage`, `BuildGridGraph`, … |
| `tgraphx.metrics` | Pure-PyTorch `accuracy`, `top_k_accuracy`, `precision_recall_f1`, `classification_report`, `mae`, `mse`, `rmse`, `r2_score`, `hits_at_k`, `mean_reciprocal_rank`, `ndcg_at_k`, `roc_auc`, `average_precision` |
| `benchmarks/` | `benchmark_dataset_loading.py`, `benchmark_training_synthetic.py`, `benchmark_tensor_vs_flatten.py`, `benchmark_transforms.py`, `benchmark_metrics.py`, `benchmark_layers.py`, `benchmark_graph_builders.py`, `benchmark_sampling.py`. Every benchmark accepts `--small` and `--output JSON`. |
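"Pure-PyTorch" here means the metrics are plain tensor arithmetic with no extra dependencies. For instance, a classification accuracy equivalent to what such a metric computes can be written as (a sketch, not the library's exact code):

```python
import torch

logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3], [0.1, 0.2]])
targets = torch.tensor([0, 1, 1, 1])

preds = logits.argmax(dim=-1)                   # predicted class per sample
acc = (preds == targets).float().mean().item()  # fraction of correct predictions
assert abs(acc - 0.75) < 1e-6                   # 3 of the 4 predictions match
```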

Detailed write-ups: [docs/datasets.md](docs/datasets.md), [docs/transforms.md](docs/transforms.md), [docs/metrics.md](docs/metrics.md), [docs/benchmarks.md](docs/benchmarks.md), [docs/dataset_license_policy.md](docs/dataset_license_policy.md).

Importing `tgraphx`, `tgraphx.datasets`, `tgraphx.transforms`, `tgraphx.metrics`, `tgraphx.experiments`, or `tgraphx.explain` does **not** import torch_geometric / dgl / ogb.

---

## Experiment manager

`tgraphx.experiments` (new in v0.3.0) is a lightweight, dashboard-compatible experiment manager. YAML / JSON configs are validated against an explicit schema (no `eval`, no `exec`). Every run writes its artefacts under `run_dir` only:

```
runs/<run_name>/<timestamp>/
├── run_metadata.json           # run name, status, device, seed, version
├── experiment_config.json      # exact config copy
├── experiment_summary.json     # epochs, best metric, final loss
├── metrics.csv                 # dashboard-compatible
└── checkpoints/{best,latest}.pt
```

```python
from tgraphx.experiments import Runner, load_config

cfg = load_config("examples/configs/synthetic_patch_graph.yaml")
runner = Runner(cfg)
history = runner.fit()
```

Or via the CLI (registered as console scripts):

```bash
tgraphx-train  examples/configs/synthetic_patch_graph.yaml
tgraphx-grid   examples/configs/grid_sweep.yaml
tgraphx-report runs/                  # → Markdown summary
```

Built-in callbacks: `EarlyStopping`, `ModelCheckpoint`, `CSVLoggerCallback`, `LearningRateLogger`. Multi-seed and grid sweeps via `GridRunner`. The dashboard reads the same files when you run `tgraphx-dashboard --logdir runs/<run_name>`.

See [docs/experiments.md](docs/experiments.md) for the full schema.

---

## Explainability

`tgraphx.explain` (new in v0.3.0) provides diagnostic explainability tools. They run on CPU by default, do not retain autograd graphs, and never make causal claims about a model's predictions.

```python
from tgraphx.explain import (
    node_feature_saliency, integrated_gradients,
    edge_perturbation_attribution, attention_to_edge_scores,
    patch_saliency_to_image_grid,
    export_explanation_metadata, export_edge_scores_csv,
)

sal      = node_feature_saliency(model, graph, target=label)
ig       = integrated_gradients(model, graph, target=label, steps=16)
edge_imp = edge_perturbation_attribution(model, graph, target=label, max_edges=64)

# For TensorGATLayer outputs.
out, attn = layer(x, edge_index, return_attention=True)
edge_scores = attention_to_edge_scores(attn, edge_index, head_reduce="mean")

# Patch-level heatmap from grid_shape metadata.
heatmap = patch_saliency_to_image_grid(sal, grid_shape=graph.metadata["grid_shape"])

# Export artefacts the dashboard can render (explicit paths only).
export_explanation_metadata("runs/demo/explanation_metadata.json",
                            method="saliency", target=int(label))
export_edge_scores_csv("runs/demo/explanation_edges.csv",
                       graph.edge_index, edge_scores, top_k=20)
```

See [docs/explainability.md](docs/explainability.md) for examples and limits.

---

## Optional integrations

| Adapter | Install | Notes |
|---------|---------|-------|
| Native synthetic / folder datasets | (none) | Deterministic, learnable, no network |
| Torchvision-backed datasets | already a TGraphX base dependency | `MNIST`, `CIFAR-10/100`, `SVHN`, `STL-10`, `FakeData`, … converted to patch graphs |
| PyTorch Geometric | `pip install "tgraphx[pyg]"` | `Planetoid`, `TUDataset`, generic adapter; data-format converters only |
| DGL | follow upstream install | citation graphs, generic adapter |
| OGB | `pip install "tgraphx[ogb]"` | node / link / graph property prediction wrappers + `OGBEvaluatorWrapper` |
| MLflow | `pip install "tgraphx[mlflow]"` | `MLflowLogger`, lazy import |
| TensorBoard | `pip install "tgraphx[tracking]"` | `TensorBoardLogger`, lazy import |
| Hardware monitoring | `pip install "tgraphx[monitoring]"` | `psutil` + `pynvml` for the dashboard's hardware panel |

TGraphX is interoperable with the PyG / DGL / OGB ecosystems through small data-format converters; it is not a drop-in replacement for any of those frameworks.

---

## Backend and platform support

| Platform | Forward | torch.compile | AMP | CI coverage |
|----------|---------|--------------|-----|-------------|
| Linux + CPU | yes | yes | bfloat16 (recommended) | Full Ubuntu CI on Python 3.10 / 3.11 / 3.12 |
| NVIDIA CUDA | yes | yes | float16 / bfloat16 | Local tests; no GPU runners in CI |
| Apple Silicon MPS | yes | partial | partial | macOS smoke CI (import + build) |
| Windows + CPU | yes | yes | yes | Windows smoke CI (Python 3.11) |
| macOS + CPU | yes | yes | yes | macOS smoke CI (Python 3.11) |
| Multi-GPU (DDP) | helpers | user-managed | user-managed | `tgraphx.distributed` rank-zero / barrier helpers; full DDP setup is the user's responsibility |

For the AMP recommendations, dtype-cast policy, and per-platform caveats, see [docs/performance.md](docs/performance.md). For the precise API stability classification, see [docs/api_stability.md](docs/api_stability.md).

---

## Dashboard, local-first, privacy

TGraphX ships a local-first training dashboard that reads run artefacts (`metrics.csv`, `run_metadata.json`, `experiment_config.json`, `experiment_summary.json`, `dataset_metadata.json`, `transform_metadata.json`, `metrics_summary.json`, `benchmark_results.json`, `explanation_metadata.json`, `explanation_edges.csv`, `explanation_patch_heatmap.json`, `hetero_graph_metadata.json`, `temporal_metadata.json`, `sampling_metadata.json`, `hardware_report.json`).

```bash
tgraphx-dashboard --logdir runs/demo
# → http://127.0.0.1:8765 (localhost only, no token needed)

tgraphx-dashboard --logdir runs/demo --host 0.0.0.0 --token MY_SECRET_TOKEN
tgraphx-dashboard --logdir runs/demo --export-html snapshot.html
```

Properties:

* Off by default. The dashboard is opt-in; nothing about importing TGraphX runs a server.
* Local-first. Reads files from your `--logdir`; never executes code, never loads checkpoints, never runs models.
* No telemetry, no analytics, no external CDN; the offline HTML export is fully self-contained.
* LAN access requires `--token`; localhost mode does not.
* Path-traversal protected: every file read is validated against the resolved logdir.
* Background threads are launched only when you call `launch_dashboard_background(...)` explicitly.
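A path-traversal guard of the kind described above can be sketched with `pathlib` alone (an illustration of the pattern, not the dashboard's actual code):

```python
import tempfile
from pathlib import Path

def safe_read(logdir: str, requested: str) -> bytes:
    """Read a file only if it resolves inside the dashboard's logdir."""
    root = Path(logdir).resolve()
    target = (root / requested).resolve()  # collapses any ../ components
    if not target.is_relative_to(root):    # reject paths that escape the root
        raise PermissionError(f"refusing to read outside logdir: {requested}")
    return target.read_bytes()

# Demo: a file inside the logdir is served; an escape attempt is refused.
with tempfile.TemporaryDirectory() as logdir:
    (Path(logdir) / "metrics.csv").write_text("epoch,loss\n")
    ok = safe_read(logdir, "metrics.csv")
    try:
        safe_read(logdir, "../outside.txt")
        blocked = False
    except PermissionError:
        blocked = True

assert ok == b"epoch,loss\n" and blocked
```

Resolving both paths before comparing is what defeats `../` sequences and symlinks; `Path.is_relative_to` requires Python 3.9+, which the `>=3.10` floor already guarantees.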

Helpers for writing metadata files the dashboard understands:

```python
from tgraphx import (
    write_run_metadata, write_dataset_metadata, write_transform_metadata,
    write_metrics_summary, write_benchmark_results,
    write_explanation_metadata, write_experiment_config,
    write_hardware_report, write_sampling_metadata,
    write_hetero_graph_metadata, write_temporal_metadata,
)
```

`write_graph_stats` from earlier releases is preserved.

See [docs/dashboard.md](docs/dashboard.md) for the full UI, security model, and offline export.

---

## Privacy summary

| Behaviour | Default |
|---|---|
| Telemetry / analytics | None — never |
| Remote calls at import | None |
| Dashboard | Off — launch explicitly |
| CSV / TensorBoard / MLflow logging | Off — create a logger explicitly |
| Hardware monitoring | Off — pass `include_hardware=True` to `env_report` |
| Checkpoints | Off — call `save_checkpoint` or attach `ModelCheckpoint` explicitly |
| Dataset downloads | Off — adapters require `download=True` |
| Background threads | None unless `launch_dashboard_background` is called |
| File writes | Only to paths the user provides |

No `~/.tgraphx` directory or user-level config is created.

---

## Examples and tutorials

```bash
python examples/run_all_fast_examples.py        # every CPU-safe demo
python examples/datasets_quickstart.py           # registry + synthetic datasets
python examples/transforms_metrics_demo.py       # Compose + classification report
python examples/synthetic_datasets_demo.py       # all 7 native synthetic datasets
python examples/sampling_demo_v028.py            # random walk + hetero + temporal sampling
python examples/training_with_dashboard.py      # fit() + CSVLogger → dashboard
python examples/graph_transformer_demo.py        # vector graph transformer
python examples/gat_chunking_demo.py             # GAT chunked forward parity
```

The full set of demos lives under `examples/` (see [examples/README.md](examples/README.md) when present).

### Validation scripts

A small set of stand-alone validation scripts verifies end-to-end that the
surfaces TGraphX advertises actually work:

| Script | Purpose |
|--------|---------|
| `examples/device_validation.py` | Runs vector + spatial layer smokes on CPU / CUDA / MPS, optionally under autocast (`--amp`). Emits a JSON report. |
| `examples/dashboard_artifact_validation.py` | Writes every supported dashboard metadata file, exports an offline HTML snapshot, and checks for CDN / token / `eval` leaks. |
| `examples/experiment_end_to_end_validation.py` | Trains a tiny synthetic experiment, asserts dashboard files exist, resumes from a checkpoint. |
| `examples/explainability_end_to_end_validation.py` | Trains, runs saliency / integrated gradients / edge perturbation, exports dashboard-readable explanation artefacts. |
| `examples/public_datasets/*.py` | Manual, opt-in scripts that exercise the torchvision / PyG / DGL / OGB adapters against real upstream loaders. They require `--download`, cap dataset size, skip cleanly if the optional package is missing, and never run in CI. |

See [docs/public_dataset_validation.md](docs/public_dataset_validation.md)
and [docs/device_validation.md](docs/device_validation.md) for the
exact invocations and policy.

---

## Boundaries

TGraphX is intentionally focused.

* TGraphX is not a drop-in replacement for PyG or DGL; the optional adapters convert data only.
* TGraphX provides DDP-aware helpers and a single-process smoke example, not an automatic multi-GPU training framework.
* Per-pixel and per-voxel GAT attention scores are not shipped — naive `[E, K, H, W]` score tensors are memory-prohibitive. Per-channel attention is shipped as `attention_mode="channel"`.
* Recurrent temporal memory modules (TGN, TGAT) are not shipped; temporal workflows use a stateless snapshot-loop pattern.
* Synthetic datasets are sanity / tutorial datasets, not benchmarks; benchmark scripts are reproducibility tools, not real-world performance comparisons.
* `kNN`, `radius`, `IoU`, and fully-connected graph builders are inherently `O(N²)` in the number of nodes and warn on large `N`; chunked variants reduce peak memory.
* Universal arbitrary-rank node-feature support across every layer is a future direction; the supported layouts today are vector `[N, D]`, 2-D spatial `[N, C, H, W]`, and 3-D volumetric `[N, C, D, H, W]`.

Detailed limitations live in [docs/limitations.md](docs/limitations.md); the roadmap is in [docs/roadmap.md](docs/roadmap.md).
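The chunked graph builders mentioned in the list above trade one `N×N` distance matrix for row-block computation. A plain-PyTorch sketch of a chunked kNN edge builder (illustrative only; the shipped builders' names and signatures may differ):

```python
import torch

def knn_edges(pos: torch.Tensor, k: int, chunk: int = 1024) -> torch.Tensor:
    """Build directed kNN edges [2, N*k] without materialising the full N x N matrix."""
    N = pos.shape[0]
    rows = []
    for start in range(0, N, chunk):
        block = pos[start:start + chunk]               # [B, D] query rows
        d = torch.cdist(block, pos)                    # [B, N] distances for this block only
        b = block.shape[0]
        d[torch.arange(b), torch.arange(start, start + b)] = float("inf")  # mask self-loops
        rows.append(d.topk(k, largest=False).indices)  # k nearest per query row
    nbrs = torch.cat(rows)                             # [N, k] neighbour indices
    src = nbrs.reshape(-1)
    dst = torch.arange(N).repeat_interleave(k)
    return torch.stack([src, dst])                     # messages flow neighbour -> node

edge_index = knn_edges(torch.randn(50, 3), k=4, chunk=16)
assert edge_index.shape == (2, 50 * 4)
```

Peak memory is `O(chunk · N)` instead of `O(N²)`, at the cost of one `cdist` call per block.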

---

## Citation

```bibtex
@misc{sajjadi2025tgraphxtensorawaregraphneural,
    title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
    author={Arash Sajjadi and Mark Eramian},
    year={2025},
    eprint={2504.03953},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2504.03953},
}
```

---

## Authorship

**Software author and maintainer:** Arash Sajjadi, PhD Candidate in Computer Science, University of Saskatchewan ([arash.sajjadi@usask.ca](mailto:arash.sajjadi@usask.ca)).
**Academic supervision:** Mark Eramian, PhD supervisor / academic advisor, University of Saskatchewan.

---

## License

TGraphX is released under the [MIT License](LICENSE).
