"""Plotting utilities to surface metric-to-metric relationships.
These plots make the sample-vs-alarm gap (Andrade 2024), threshold
sensitivity, cadence sensitivity, and IoC vs surrogate transparent.
All return (fig, ax); none save to disk — that's the caller's job.
"""
from __future__ import annotations
import numpy as np
[docs]
def sensitivity_vs_fp_per_hour(sweep_df, *, ax=None,
                               acceptable_fp_per_hour: float | None = 0.15,
                               wearable_fp_per_hour: float | None = 0.042):
    """Operating-curve plot from forecasting.sweep_thresholds output.

    Plots alarm-based sensitivity (y) against false positives per hour (x),
    one point per row of ``sweep_df``, joined by a line.

    Parameters
    ----------
    sweep_df : pandas.DataFrame
        Threshold-sweep table; must contain ``fp_per_hour`` and
        ``sensitivity`` columns.
    ax : matplotlib Axes, optional
        Axes to draw on. A new 5x4 in figure is created when omitted.
    acceptable_fp_per_hour : float or None
        Position of a dashed grey vertical reference line (default 0.15/h,
        Mormann's acceptable rate). ``None`` suppresses the line.
    wearable_fp_per_hour : float or None
        Position of a dotted black vertical reference line (default
        0.042/h, the wearable target). ``None`` suppresses the line.

    Returns
    -------
    (fig, ax)
        The figure owning ``ax`` and the axes drawn on.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))
    # Several thresholds can share one FP rate. Tie-breaking on sensitivity
    # draws them as a single rising vertical step, and the curve always
    # leaves that x position from its highest point. Without the tie-break
    # the order of tied rows is arbitrary and the curve can zig-zag.
    df = sweep_df.sort_values(["fp_per_hour", "sensitivity"])
    ax.plot(df["fp_per_hour"], df["sensitivity"],
            marker="o", linestyle="-", label="model")
    if acceptable_fp_per_hour is not None:
        ax.axvline(acceptable_fp_per_hour, color="grey", linestyle="--",
                   label=f"Mormann ({acceptable_fp_per_hour:g}/h)")
    if wearable_fp_per_hour is not None:
        ax.axvline(wearable_fp_per_hour, color="black", linestyle=":",
                   label=f"wearable ({wearable_fp_per_hour:g}/h)")
    ax.set_xlabel("FP per hour (interictal)")
    ax.set_ylabel("alarm-based sensitivity")
    # symlog keeps FP/hr == 0 on the axis (linear below 0.01, log above).
    ax.set_xscale("symlog", linthresh=0.01)
    ax.set_ylim(0, 1.02)
    ax.legend(loc="lower right", fontsize=8)
    return ax.figure, ax
[docs]
def sample_vs_alarm_scatter(per_patient_df, *, ax=None,
                            x_metric: str = "roc_auc",
                            y_metric: str = "sensitivity"):
    """Scatter each patient's sample-based score against an alarm-based one.

    Reproduces the Andrade 2024 finding that per-patient sample-based AUC
    does not track alarm-based sensitivity. The dotted identity line shows
    where the two would agree if they corresponded directly. A dashed red
    horizontal line is drawn at y = 0.5 and labelled "chance sensitivity".

    Parameters
    ----------
    per_patient_df : pandas.DataFrame
        One row per patient; must contain the ``x_metric`` and ``y_metric``
        columns.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 5x4.5 in figure is created when omitted.
    x_metric, y_metric : str
        Column names plotted on the x and y axes. They also appear in the
        axis labels.

    Returns
    -------
    (fig, ax)
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.subplots(figsize=(5, 4.5))[1]

    sample_scores = per_patient_df[x_metric]
    alarm_scores = per_patient_df[y_metric]
    ax.scatter(sample_scores, alarm_scores, s=40, alpha=0.85)

    # Reference guides: perfect agreement, then the 0.5 horizontal level.
    ax.plot([0, 1], [0, 1], color="grey", linestyle=":", label="identity")
    ax.axhline(0.5, color="red", linestyle="--", alpha=0.5,
               label="chance sensitivity")

    ax.set_xlabel(f"sample-based {x_metric}")
    ax.set_ylabel(f"alarm-based {y_metric}")
    ax.set_xlim(0, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc="lower right", fontsize=8)
    return ax.figure, ax
[docs]
def cadence_ablation(sweep_df, *, ax=None, x: str = "cadence_s",
                     y: str = "fp_per_hour", logx: bool = True):
    """Line plot of one sweep metric (FP/hr by default) against cadence.

    Parameters
    ----------
    sweep_df : pandas.DataFrame
        Sweep results, e.g. the output of ``forecasting.sweep_policies``.
        Rows are sorted by ``x`` here, so the caller need not pre-sort.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 5x4 in figure is created when omitted.
    x, y : str
        Column names for the horizontal (default ``"cadence_s"``) and
        vertical (default ``"fp_per_hour"``) axes; also used as axis labels.
    logx : bool
        When true, switch the x axis to a log scale; when false the axis
        scale is left unchanged.

    Returns
    -------
    (fig, ax)
    """
    import matplotlib.pyplot as plt

    target = ax if ax is not None else plt.subplots(figsize=(5, 4))[1]
    ordered = sweep_df.sort_values(x)
    target.plot(ordered[x], ordered[y], marker="s", linestyle="-")
    if logx:
        target.set_xscale("log")
    target.set_xlabel(x)
    target.set_ylabel(y)
    return target.figure, target
[docs]
def ioc_vs_surrogate(sweep_df, *, ax=None):
    """Plot model vs surrogate sensitivity across alarm thresholds.

    The shaded band marks where IoC (sensitivity - surrogate_sensitivity)
    is positive, i.e. the threshold range where the model beats chance.

    Parameters
    ----------
    sweep_df : pandas.DataFrame
        Threshold-sweep table with ``threshold``, ``sensitivity`` and
        ``surrogate_sensitivity`` columns.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 5x4 in figure is created when omitted.

    Returns
    -------
    (fig, ax)
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))
    df = sweep_df.sort_values("threshold")
    ax.plot(df["threshold"], df["sensitivity"], label="model sens",
            marker="o")
    ax.plot(df["threshold"], df["surrogate_sensitivity"],
            label="surrogate sens", linestyle="--", marker="x")
    model_ahead = (df["sensitivity"] > df["surrogate_sensitivity"]).to_numpy()
    # interpolate=True extends each shaded region to the exact point where
    # the two curves cross. Without it the fill stops at the last sampled
    # threshold, and a region that is positive at a single sample collapses
    # to a zero-width sliver.
    ax.fill_between(df["threshold"], df["surrogate_sensitivity"],
                    df["sensitivity"], where=model_ahead, interpolate=True,
                    alpha=0.2, label="IoC > 0")
    ax.set_xlabel("alarm threshold")
    ax.set_ylabel("sensitivity")
    ax.legend(fontsize=8)
    ax.set_ylim(0, 1.02)
    return ax.figure, ax
[docs]
def reliability_diagram(cal_report, *, ax=None,
                        title: str = "Reliability diagram"):
    """Plot a reliability diagram from a CalibrationReport.

    The dashed identity line represents perfect calibration. Each bin is
    drawn at (bin center, observed positive rate). Marker area grows
    linearly with the bin count: 30 for an empty bin, 230 for the fullest
    bin. A text box reports ECE, Brier score, reliability and resolution.

    Parameters
    ----------
    cal_report : CalibrationReport
        Object exposing ``bin_counts``, ``bin_centers``, ``bin_observed``
        (array-likes of equal length), and the scalars
        ``expected_calibration_error``, ``brier``, ``reliability``, and
        ``resolution``.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 5x5 in figure is created when omitted.
    title : str
        Axes title.

    Returns
    -------
    (fig, ax)
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))
    # Accept any array-like of counts (list, tuple, ndarray). Also tolerate a
    # report with no bins: ndarray.max() raises on a zero-size array.
    cnts = np.asarray(cal_report.bin_counts, dtype=float)
    peak = cnts.max() if cnts.size else 0.0
    # The max(1, ...) floor avoids dividing by zero when every bin is empty.
    sizes = 30 + 200 * (cnts / max(1.0, peak))
    ax.plot([0, 1], [0, 1], color="grey", linestyle="--", label="ideal")
    # The line carries the legend entry; the scatter below draws the
    # count-scaled markers on top of it.
    ax.plot(cal_report.bin_centers, cal_report.bin_observed,
            color="C0", linestyle="-", marker="o", markersize=0,
            label="model")
    ax.scatter(cal_report.bin_centers, cal_report.bin_observed,
               s=sizes, color="C0", edgecolor="white", zorder=3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("predicted probability")
    ax.set_ylabel("observed positive rate")
    ax.set_title(title)
    ece = cal_report.expected_calibration_error
    ax.text(0.05, 0.95, f"ECE={ece:.3f}\nBrier={cal_report.brier:.3f}\n"
            f"Rel={cal_report.reliability:.3f}, "
            f"Res={cal_report.resolution:.3f}",
            transform=ax.transAxes, va="top", fontsize=9,
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.85))
    ax.legend(loc="lower right", fontsize=8)
    return ax.figure, ax
[docs]
def metric_correlation_heatmap(per_patient_df, *, ax=None,
                               metrics=None, method: str = "spearman"):
    """Heatmap of metric-to-metric correlations across patients.

    Surfaces redundancy ("this metric tells us nothing new") and the
    sample-vs-alarm divergence axis. Each cell is annotated with its
    coefficient. The colour scale is fixed to [-1, 1].

    Parameters
    ----------
    per_patient_df : pandas.DataFrame
        One row per patient, one column per metric.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 6x5 in figure is created when omitted.
    metrics : sequence of str, optional
        Columns to correlate. Defaults to the known metric names that are
        present in ``per_patient_df``. Non-numeric columns are dropped
        before correlating and do not appear in the heatmap.
    method : str
        Correlation method passed to ``DataFrame.corr`` (default
        ``"spearman"``).

    Returns
    -------
    (fig, ax)
    """
    import matplotlib.pyplot as plt

    if metrics is None:
        candidates = ["roc_auc", "pr_auc", "balanced_accuracy", "mcc",
                      "sensitivity", "precision", "f1", "fp_per_hour",
                      "ioc", "time_in_warning_frac"]
        metrics = [m for m in candidates if m in per_patient_df.columns]
    # list() so a tuple of names selects columns rather than acting as a
    # single (MultiIndex) key.
    sub = per_patient_df[list(metrics)].select_dtypes(include="number")
    corr = sub.corr(method=method)
    # select_dtypes may have dropped some requested metrics, so the labels
    # must come from the correlation matrix itself, not from `metrics`.
    # Using `metrics` would mislabel ticks and over-index corr.values.
    labels = list(corr.columns)
    values = corr.to_numpy()
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(values, vmin=-1, vmax=1, cmap="RdBu_r")
    positions = range(len(labels))
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(labels, fontsize=8)
    ax.figure.colorbar(im, ax=ax, label=f"{method} ρ")
    for i, row in enumerate(values):
        for j, rho in enumerate(row):
            # White text on strongly saturated cells keeps it readable.
            ax.text(j, i, f"{rho:.2f}",
                    ha="center", va="center", fontsize=7,
                    color="white" if abs(rho) > 0.5 else "black")
    return ax.figure, ax