Metadata-Version: 2.4
Name: inter-gnn
Version: 0.1.0
Summary: Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening
Author: Harshal Loya, Jash Chauhan, Het Gala
License: MIT
Project-URL: Homepage, https://github.com/inter-gnn/inter-gnn
Project-URL: Documentation, https://inter-gnn.readthedocs.io
Project-URL: Repository, https://github.com/inter-gnn/inter-gnn
Keywords: graph-neural-networks,drug-discovery,explainable-ai,molecular-property-prediction,interpretability,activity-cliffs,concept-whitening
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: Topic :: Scientific/Engineering :: Chemistry
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0.0
Requires-Dist: torch-geometric>=2.4.0
Requires-Dist: rdkit>=2023.3.1
Requires-Dist: numpy>=1.24.0
Requires-Dist: scipy>=1.10.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: scikit-learn>=1.2.0
Requires-Dist: matplotlib>=3.7.0
Requires-Dist: pyyaml>=6.0
Requires-Dist: tqdm>=4.65.0
Provides-Extra: viz
Requires-Dist: plotly>=5.14.0; extra == "viz"
Requires-Dist: py3Dmol>=2.0.0; extra == "viz"
Requires-Dist: seaborn>=0.12.0; extra == "viz"
Requires-Dist: ipywidgets>=8.0.0; extra == "viz"
Provides-Extra: dev
Requires-Dist: pytest>=7.3.0; extra == "dev"
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
Requires-Dist: black>=23.0.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: mypy>=1.4.0; extra == "dev"

# InterGNN — Interpretable GNN-Based Framework for Drug Discovery

![Python 3.9+](https://img.shields.io/badge/python-3.9%2B-blue)
![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-red)
![License: MIT](https://img.shields.io/badge/License-MIT-green)

An interpretable Graph Neural Network framework that combines state-of-the-art molecular property prediction with intrinsic and post-hoc explainability methods, designed for drug discovery workflows that demand trust, transparency, and scientific insight.

---

## Architecture

```
SMILES → Standardize → Featurize → MolecularGNNEncoder ──┐
                                                          ├─ CrossAttention → TaskHead → Prediction
Protein → ProteinGraphBuilder → TargetGNNEncoder ─────────┘
                                    │
                    ┌───────────────┼───────────────┐
                    ▼               ▼               ▼
              PrototypeLayer   MotifHead    ConceptWhitening
              (case-based)    (substructure)  (axis-aligned)
```
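The fusion step lets each atom attend over the protein's residues before prediction. A single-head, NumPy-only sketch of that idea, purely illustrative (the framework's actual `CrossAttention` module is a learned torch layer):

```python
import numpy as np

def cross_attention(atom_feats, res_feats):
    """Atoms (queries) attend over protein residues (keys/values).

    Shapes: atom_feats (n_atoms, d), res_feats (n_res, d).
    Returns fused atom features of shape (n_atoms, d).
    """
    d = atom_feats.shape[1]
    scores = atom_feats @ res_feats.T / np.sqrt(d)        # (n_atoms, n_res)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)          # row-wise softmax
    return weights @ res_feats                             # weighted residue mix

rng = np.random.default_rng(0)
fused = cross_attention(rng.normal(size=(5, 16)), rng.normal(size=(7, 16)))
```

In the real model the queries, keys, and values would each pass through learned projections; this sketch omits them to keep the attention mechanics visible.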

### Key Features

| Feature | Description |
|---------|-------------|
| **Molecular Encoder** | GINEConv with edge-aware message passing and chirality features |
| **Target Encoder** | Multi-head GATConv for residue-level protein graphs |
| **Cross-Attention Fusion** | Atom-residue interaction for drug-target affinity |
| **PAGE Prototypes** | Case-based classification via learned prototypes |
| **MAGE Motifs** | Differentiable motif mask generation with Gumbel-sigmoid |
| **Concept Whitening** | ZCA whitening + axis-aligned concept interpretability |
| **CF-GNNExplainer** | Counterfactual minimal perturbation explanations |
| **T-GNNExplainer** | Sufficient subgraph identification |
| **CIDER Diagnostics** | Causal invariance testing across environments |
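The MAGE-style motif head draws soft binary masks with the Gumbel-sigmoid trick: add logistic noise to the mask logits, then squash with a temperature-scaled sigmoid. A NumPy sketch of the sampling step alone (the real layer is a torch module whose gradients flow through the relaxation):

```python
import numpy as np

def gumbel_sigmoid(logits, tau=1.0, rng=None):
    """Soft binary mask in (0, 1): logistic noise + temperature-scaled sigmoid.
    Lower tau pushes samples toward hard 0/1 decisions."""
    rng = rng or np.random.default_rng()
    u = rng.uniform(1e-9, 1 - 1e-9, size=np.shape(logits))
    noise = np.log(u) - np.log1p(-u)  # sample from Logistic(0, 1)
    return 1.0 / (1.0 + np.exp(-(np.asarray(logits) + noise) / tau))

mask = gumbel_sigmoid(np.array([2.0, -2.0, 0.0]), tau=0.5,
                      rng=np.random.default_rng(0))
```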

---

## Installation

```bash
# Clone the repository
git clone https://github.com/inter-gnn/inter-gnn.git
cd inter-gnn

# Install with all dependencies
pip install -e ".[viz,dev]"
```

### Requirements

- Python ≥ 3.9
- PyTorch ≥ 2.0
- PyTorch Geometric ≥ 2.4
- RDKit ≥ 2023.03
- NumPy, SciPy, Pandas, scikit-learn, matplotlib

---

## Quick Start

### 1. Create a Configuration

```yaml
# config.yaml
data:
  dataset_name: tox21
  split_method: scaffold
  batch_size: 32
  detect_cliffs: true
  compute_concepts: true

model:
  hidden_dim: 256
  num_mol_layers: 4
  task_type: classification
  num_tasks: 12

interpretability:
  use_prototypes: true
  num_prototypes_per_class: 5
  use_motifs: true
  num_motifs: 8
  use_concept_whitening: true

training:
  pretrain_epochs: 50
  finetune_epochs: 100
  learning_rate: 0.001
```
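The `split_method: scaffold` option groups molecules by Bemis-Murcko scaffold so no scaffold spans both train and test. A minimal pure-Python sketch of the greedy grouping, assuming scaffold strings have already been computed (the real pipeline would derive them with RDKit's `MurckoScaffold`):

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8):
    """Greedy scaffold split: assign whole scaffold groups to the training
    set, largest first, until the train fraction is filled."""
    groups = defaultdict(list)
    for idx, scaf in enumerate(scaffolds):
        groups[scaf].append(idx)
    train, test = [], []
    cutoff = frac_train * len(scaffolds)
    for members in sorted(groups.values(), key=len, reverse=True):
        (train if len(train) + len(members) <= cutoff else test).extend(members)
    return train, test

# Toy example: benzene, cyclohexane, and pyridine scaffolds
scaffolds = ["c1ccccc1", "c1ccccc1", "C1CCCCC1", "c1ccncc1", "C1CCCCC1"]
train_idx, test_idx = scaffold_split(scaffolds, frac_train=0.6)
```

Because splitting happens at the group level, the realized train fraction only approximates `frac_train` when scaffold groups are large.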

### 2. Train

```bash
inter-gnn train --config config.yaml
```

### 3. Evaluate

```bash
inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt
```

### 4. Generate Explanations

```bash
inter-gnn explain --config config.yaml --checkpoint model.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"
```

### 5. Dashboard

```bash
inter-gnn dashboard --config config.yaml --checkpoint model.pt --output report/
```

---

## Python API

```python
from inter_gnn.training.config import InterGNNConfig
from inter_gnn.training.trainer import InterGNNTrainer
from inter_gnn.data.datamodule import InterGNNDataModule

# Load config
config = InterGNNConfig.from_yaml("config.yaml")

# Build data
dm = InterGNNDataModule(config)
dm.prepare_data()
dm.setup()

# Train (two-phase: pretrain → finetune)
trainer = InterGNNTrainer(config)
history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())

# Explain a molecule
from inter_gnn.data.featurize import smiles_to_graph
import torch

graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")
batch = torch.zeros(graph.x.shape[0], dtype=torch.long)
output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)

importance = trainer.model.get_node_importance(
    graph.x, graph.edge_index, graph.edge_attr, batch
)
```

---

## Module Overview

```
inter_gnn/
├── data/                    # Data & Preprocessing
│   ├── standardize.py       #   Molecule standardization (tautomer, charge, stereo)
│   ├── featurize.py         #   SMILES → molecular graph (~78-dim atom, ~14-dim bond)
│   ├── protein.py           #   Protein sequence → k-NN / contact graph
│   ├── concepts.py          #   SMARTS concept library (~30 patterns)
│   ├── cliffs.py            #   Activity cliff detection
│   ├── splits.py            #   Scaffold, cold-target, temporal splits
│   ├── datasets.py          #   9 benchmark dataset loaders
│   └── datamodule.py        #   DataModule wrapper
├── models/                  # Core Model
│   ├── encoders.py          #   GINEConv (molecule) + GATConv (protein) encoders
│   ├── attention.py         #   Cross-attention fusion + bilinear alternative
│   ├── task_heads.py        #   Classification + regression heads
│   └── core_model.py        #   Unified InterGNN model
├── interpretability/        # Intrinsic Interpretability
│   ├── prototypes.py        #   PAGE-inspired prototype layer
│   ├── motifs.py            #   MAGE-inspired motif generator
│   ├── concept_whitening.py #   ZCA whitening + concept alignment
│   └── stability.py         #   Explanation stability regularizer
├── explainers/              # Post-hoc Explanations
│   ├── cf_explainer.py      #   CF-GNNExplainer (counterfactual)
│   ├── t_explainer.py       #   T-GNNExplainer (sufficient subgraph)
│   └── cider.py             #   CIDER causal invariance diagnostics
├── training/                # Training Pipeline
│   ├── losses.py            #   Combined multi-objective loss
│   ├── trainer.py           #   Two-phase trainer (pretrain + finetune)
│   ├── callbacks.py         #   EarlyStopping, checkpointing, monitoring
│   └── config.py            #   YAML config with dataclass hierarchy
├── evaluation/              # Evaluation Metrics
│   ├── predictive.py        #   ROC-AUC, PR-AUC, RMSE, CI, etc.
│   ├── faithfulness.py      #   Deletion/Insertion AUC, sufficiency/necessity
│   ├── stability_metrics.py #   Jaccard stability, cliff consistency
│   ├── chemical_validity.py #   Valence checks, SMARTS match rates
│   ├── causal.py            #   Invariance violation, environment alignment
│   └── statistical.py       #   Paired bootstrap, randomization tests
├── visualization/           # Visualization Tools
│   ├── molecule_viz.py      #   Atom/bond saliency rendering
│   ├── prototype_viz.py     #   Prototype gallery
│   ├── motif_viz.py         #   Motif activation heatmaps
│   ├── concept_viz.py       #   Concept activation bars
│   ├── counterfactual_viz.py #  Counterfactual edit display
│   └── dashboard.py         #   HTML batch-export dashboard
└── cli.py                   # Command-line interface
```
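`concept_whitening.py` builds on ZCA whitening, which decorrelates hidden features while keeping axes close to their original orientation (so individual axes can be aligned to concepts). A NumPy sketch of the transform itself, illustrative only (the actual layer is a torch module with running statistics):

```python
import numpy as np

def zca_whiten(X, eps=1e-5):
    """ZCA whitening: rotate into the eigenbasis of the covariance,
    rescale each direction to unit variance, rotate back."""
    Xc = X - X.mean(axis=0)
    cov = Xc.T @ Xc / (len(X) - 1)
    vals, vecs = np.linalg.eigh(cov)
    W = vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T  # symmetric whitener
    return Xc @ W

rng = np.random.default_rng(0)
Z = zca_whiten(rng.normal(size=(500, 8)) @ rng.normal(size=(8, 8)))
cov = np.cov(Z, rowvar=False)  # approximately the identity
```

Among all whitening transforms, ZCA minimizes distortion of the original coordinates, which is what makes the whitened axes usable as concept axes.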

---

## Supported Datasets

| Dataset | Type | Tasks | Source |
|---------|------|-------|--------|
| MUTAG | Classification | 1 | TUDataset |
| Tox21 | Classification | 12 | MoleculeNet |
| ClinTox | Classification | 2 | MoleculeNet |
| QM9 | Regression | 19 | MoleculeNet |
| Davis | DTA Regression | 1 | TDC |
| KIBA | DTA Regression | 1 | TDC |
| BindingDB | DTA Regression | 1 | TDC |
| SIDER | Classification | 27 | MoleculeNet |
| SynLethDB | Classification | 1 | Custom |

---

## Two-Phase Training

1. **Pre-training** — Trains encoders + task head with prediction loss only
2. **Joint Fine-tuning** — Attaches interpretability modules, trains all losses:
   - `L_pred`: Task prediction (BCE/MSE)
   - `L_pull/push/div`: Prototype losses
   - `L_sparsity/conn`: Motif losses
   - `L_align/decorr`: Concept whitening losses
   - `L_stability`: Explanation stability
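The fine-tuning objective is a weighted sum of the losses above. A minimal sketch with hypothetical loss values and weights (the framework's actual coefficients come from the YAML config, not these numbers):

```python
def combined_loss(losses, weights):
    """Weighted sum of the fine-tuning objectives; terms without a
    configured weight contribute nothing."""
    return sum(weights.get(name, 0.0) * value for name, value in losses.items())

# Hypothetical per-term loss values and weights, for illustration only
losses = {"pred": 0.42, "pull": 0.10, "push": 0.05,
          "sparsity": 0.30, "align": 0.20, "stability": 0.15}
weights = {"pred": 1.0, "pull": 0.1, "push": 0.1,
           "sparsity": 0.05, "align": 0.2, "stability": 0.1}
total = combined_loss(losses, weights)
```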

---

## Citation

```bibtex
@software{inter_gnn2025,
  title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
  author={Loya, Harshal and Chauhan, Jash and Gala, Het},
  year={2025},
  url={https://github.com/inter-gnn/inter-gnn},
}
```

## License

MIT License
