Metadata-Version: 2.4
Name: kagu
Version: 0.1.0
Summary: Bayesian graphical causal models
Requires-Python: >=3.11
Requires-Dist: arviz<1.0,>=0.17
Requires-Dist: matplotlib>=3.7
Requires-Dist: numpy>=1.26
Requires-Dist: pandas>=2.0
Requires-Dist: pymc>=5.0
Requires-Dist: tqdm>=4.67.3
Provides-Extra: dev
Requires-Dist: ipykernel; extra == 'dev'
Requires-Dist: jupyter; extra == 'dev'
Requires-Dist: pytest-cov; extra == 'dev'
Requires-Dist: pytest>=8.0; extra == 'dev'
Provides-Extra: docs
Requires-Dist: mkdocs-gen-files>=0.5; extra == 'docs'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'docs'
Requires-Dist: mkdocs-literate-nav>=0.6; extra == 'docs'
Requires-Dist: mkdocs-material>=9.0; extra == 'docs'
Requires-Dist: mkdocstrings[python]>=0.24; extra == 'docs'
Description-Content-Type: text/markdown

# Kagu

**Kagu** is a Python library for fitting Bayesian graphical causal models (GCMs). It treats causality as a first-class concept rather than an afterthought, and provides a unified framework for causal inference that replaces the traditional patchwork of context-specific methods.

Models are specified as directed acyclic graphs (DAGs) and fitted with full Bayesian inference via PyMC. Causal effects are then extracted by propagating interventions forward through the structural model.

---

## Installation

```bash
uv add kagu
```

---

## Example

```python
import pandas as pd
import kagu as kg

# --- 1. Fit ---
data = pd.read_csv("health_data.csv")

model = kg.Model(
    dag={
        "age":      [],
        "smoking":  ["age"],
        "exercise": ["age"],
        "fitness":  ["exercise", "age"],
        "health":   ["smoking", "fitness", "age"],
    }
)
model.fit(data)

# Each node's conditional P(X | parents) is fitted as a separate PyMC model,
# exploiting the Markov factorisation of the DAG for efficiency.

# --- 2. Causal effects ---

# Unit effect at the mean (smoking raised from its mean to mean + 1).
# With linear mechanisms this is equivalent to a regression coefficient.
# Returns a full posterior; the default HDI is 90%.
effect = model.effects("smoking", "health")
print(effect.summary())
# ┌─────────────────────────────────────────────────────┐
# │ smoking → health                                    │
# │ From: 8.2   To: 9.2                                 │
# │ Mean effect: -0.31   HDI 90%: [-0.45, -0.16]        │
# └─────────────────────────────────────────────────────┘

# Effect of a 1-SD increase centred at the mean
effect = model.effects("smoking", "health", std_units=True)

# Specific contrast: non-smoker (0) vs heavy smoker (20 cigarettes/day)
effect = model.effects("smoking", "health", values=(0, 20))
print(effect.summary())

# Conditional effect: effect of smoking for older adults only
effect = model.effects("smoking", "health", values=(0, 20), conditions={"age": 65})
print(effect.summary())

# Custom HDI width
effect = model.effects("smoking", "health", hdi=0.95)

# --- 3. Effect plot ---

# Sweep over a range of treatment values and plot E[health | do(smoking=x)]
# with an HDI ribbon
effect = model.effects("smoking", "health", sweep=True)
effect.plot()  # matplotlib figure, shows posterior mean + HDI band

# Override sweep defaults
effect = model.effects("smoking", "health", sweep=True,
                        sweep_n=100, sweep_range=(0, 30))
effect.plot()

# --- 4. Summaries and diagnostics ---

# Coefficient table across all mechanisms
print(model.summary())
# ┌──────────┬────────────┬──────────┬────────────────────┐
# │ node     │ parameter  │ mean     │ HDI 90%            │
# ├──────────┼────────────┼──────────┼────────────────────┤
# │ smoking  │ alpha      │  5.10    │ [ 3.82,  6.41]     │
# │ smoking  │ beta_age   │  0.21    │ [ 0.14,  0.28]     │
# │ fitness  │ alpha      │ 12.30    │ [10.91, 13.72]     │
# │ ...      │ ...        │ ...      │ ...                │
# └──────────┴────────────┴──────────┴────────────────────┘

# R-hat and ESS for a specific node (or all nodes if omitted)
model.diagnostics("health")

# --- 5. Plots ---

model.plot_dag()                  # DAG visualisation via matplotlib
model.plot_posterior("health")    # posterior predictive check for one node

# --- 6. Save and load ---

model.save("health_model.pkl")
model = kg.Model.load("health_model.pkl")
```

---

## Core concepts

### Structural causal models

Kagu represents a causal system as a set of structural equations:

```
X_i = f_i(Pa(X_i), ε_i)
```

where `Pa(X_i)` are the parents of node `X_i` in the DAG and `ε_i` is independent noise. Each `f_i` is a *mechanism* — a parameterised probabilistic model fitted to the data.

The joint distribution factorises over nodes:

```
P(X_1, ..., X_n) = ∏ P(X_i | Pa(X_i))
```

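This means each mechanism can be fitted independently, once per node, rather than as a single large joint model, provided the data are fully observed and the priors are independent across nodes. Fitting stays computationally cheap, and each mechanism can be inspected on its own.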
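To make the factorisation concrete, the sketch below evaluates the log joint density of one observation as a sum of per-node terms for linear-Gaussian mechanisms. It does not use Kagu. The DAG, the `params` layout, and every numeric value are illustrative assumptions, not fitted results.

```python
import math
from statistics import NormalDist

# Hypothetical three-node DAG in {node: parents} form.
dag = {"age": [], "exercise": ["age"], "fitness": ["exercise", "age"]}

# Made-up parameters for Normal(alpha + sum(beta_j * parent_j), sigma).
params = {
    "age": {"alpha": 40.0, "beta": {}, "sigma": 10.0},
    "exercise": {"alpha": 5.0, "beta": {"age": -0.05}, "sigma": 1.0},
    "fitness": {"alpha": 2.0, "beta": {"exercise": 1.5, "age": -0.02}, "sigma": 2.0},
}


def node_logp(node, row):
    """log P(node | parents) for a single observation."""
    p = params[node]
    mu = p["alpha"] + sum(p["beta"][pa] * row[pa] for pa in dag[node])
    return math.log(NormalDist(mu, p["sigma"]).pdf(row[node]))


def joint_logp(row):
    """log P(all nodes), computed as the sum of the per-node terms."""
    return sum(node_logp(node, row) for node in dag)


row = {"age": 50.0, "exercise": 3.0, "fitness": 6.0}
terms = {node: node_logp(node, row) for node in dag}
total = joint_logp(row)
```

Each term depends on one node's parameters only, which is why the per-node likelihoods can be handled separately.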

### Causal effect estimation

Effects are estimated via the *do-operator*. To compute the effect of `do(X = x)` on outcome `Y`:

1. Fix the treatment node at `x` (severing its incoming edges).
2. Propagate forward through the DAG in topological order, drawing from each downstream mechanism's posterior.
3. Compare `E[Y | do(X = x)]` to `E[Y | do(X = x')]` for your chosen baseline `x'`.

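Because the full structural model is fitted, interventional quantities follow directly from simulating the modified model, so no separate backdoor adjustment set or propensity-score model has to be specified. The result is only as valid as the DAG itself: it assumes the graph is correct and that there are no unobserved confounders.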
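The three steps above can be sketched in plain Python for a linear-Gaussian model. This is not Kagu's implementation: the DAG, the `params` values, and the helper names `simulate` and `mean_outcome` are hypothetical, and the parameters stand in for a single posterior draw. A Bayesian treatment would repeat the contrast for every draw to obtain a posterior over the effect.

```python
import random
from graphlib import TopologicalSorter
from statistics import fmean

# Hypothetical DAG in {node: parents} form; age confounds smoking and health.
dag = {"age": [], "smoking": ["age"], "health": ["smoking", "age"]}

# Made-up values standing in for one posterior draw of each mechanism.
params = {
    "age": {"alpha": 50.0, "beta": {}, "sigma": 10.0},
    "smoking": {"alpha": 5.0, "beta": {"age": 0.1}, "sigma": 2.0},
    "health": {"alpha": 80.0, "beta": {"smoking": -0.3, "age": -0.2}, "sigma": 3.0},
}


def simulate(do, n, seed):
    """Draw n samples of every node under the interventions in `do`."""
    rng = random.Random(seed)
    # graphlib takes {node: predecessors}, so parents are visited first.
    order = list(TopologicalSorter(dag).static_order())
    samples = []
    for _ in range(n):
        row = {}
        for node in order:
            if node in do:
                # Step 1: clamp the treatment and ignore its parents.
                row[node] = do[node]
            else:
                # Step 2: sample the node from its mechanism given its parents.
                p = params[node]
                mu = p["alpha"] + sum(p["beta"][pa] * row[pa] for pa in dag[node])
                row[node] = rng.gauss(mu, p["sigma"])
        samples.append(row)
    return samples


def mean_outcome(do, outcome, n=20_000, seed=0):
    """Monte Carlo estimate of E[outcome | do(...)]."""
    return fmean(row[outcome] for row in simulate(do, n, seed))


# Step 3: contrast two interventions. Reusing the seed gives both runs the
# same noise draws, so the noise cancels in the difference.
effect = mean_outcome({"smoking": 1.0}, "health") - mean_outcome({"smoking": 0.0}, "health")
```

With these made-up parameters the contrast recovers the `smoking` coefficient of the `health` mechanism, since smoking affects health only through the direct edge.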

---

## Mechanisms

| Class | Model | Use case |
|---|---|---|
| `LinearMechanism` | `X_i ~ Normal(α + Σ βⱼ Paⱼ, σ)` | Continuous, unbounded nodes |

The mechanism framework is designed to be extensible — any node can use a different mechanism, and custom mechanisms can be added by implementing the `Mechanism` ABC.

---

## Roadmap

- [ ] Documentation site
- [ ] CI/CD
- [ ] **GLM mechanisms** — support for the most common data types out of the box:
  - `LogNormalMechanism` / `GammaMechanism` — positive continuous (income, reaction times, counts-as-continuous)
  - `PoissonMechanism` / `NegBinomialMechanism` — count data
  - `BernoulliMechanism` / `BetaMechanism` — binary outcomes and proportions
  - `OrderedMechanism` — ordinal data (Likert scales, ratings)
- [ ] Model fit diagnostics per node (posterior predictive checks, LOO)
- [ ] User-defined priors via mechanism configuration
- [ ] Model comparison per node (WAIC / LOO)
