Module statkit.decision
Evaluate models using decision curve analysis.
Expand source code
"""Evaluate models using decision curve analysis."""
from typing import Literal, Optional
from matplotlib import pyplot as plt # type: ignore
from numpy import array, divide, linspace, ones_like, r_, zeros_like
from numpy.typing import NDArray
from pandas import Series
from sklearn.metrics import confusion_matrix # type: ignore
from sklearn.utils import column_or_1d # type: ignore
def _binary_classification_thresholds(y_true, y_proba, thresholds):
    """Compute confusion-matrix counts for each probability threshold.

    A sample is predicted positive when its probability is strictly greater than
    the threshold. For thresholds of 1 or more the comparison is inclusive
    (`>=`), so a probability of exactly 1 is still predicted positive.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_proba: Probability of the positive class label.
        thresholds: Iterable of probability thresholds.

    Returns:
        Tuple `(tns, fps, fns, tps)` of arrays with one count per threshold.
    """
    y_true = column_or_1d(y_true)
    y_proba = column_or_1d(y_proba)

    matrices = array(
        [
            confusion_matrix(
                y_true,
                (y_proba > t if t < 1 else y_proba >= t).astype(int),
                # Pin the label order: without it the matrix shape depends on
                # which classes happen to occur, and the indexing below breaks.
                labels=[0, 1],
            )
            for t in thresholds
        ]
    # Keeps the (n, 2, 2) layout even when `thresholds` is empty.
    ).reshape(-1, 2, 2)

    tns = matrices[:, 0, 0]
    fps = matrices[:, 0, 1]
    fns = matrices[:, 1, 0]
    tps = matrices[:, 1, 1]
    return tns, fps, fns, tps
def net_benefit(
    y_true: Series | NDArray,
    y_pred: Series | NDArray,
    thresholds=100,
    action: bool = True,
):
    """Net benefit of taking an action using a model's predictions.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Probability of positive class label.
        thresholds: When an array-like, evaluate net benefit at these coordinates
            (the probability thresholds). When an int, the number of x coordinates
            (evenly spaced on [0, 1]).
        action: When `True` (`False`), estimate net benefit of taking (not taking) an
            action/intervention/treatment.

    Returns:
        thresholds: Probability thresholds for predicting the positive class.
        benefit: The net benefit corresponding to the thresholds.

    Raises:
        ValueError: When `y_true`, cast to int, does not contain exactly the labels
            0 and 1.

    References:
        [1]: Vickers-Elkin. "Decision curve analysis: a novel method for evaluating
        prediction models." Medical Decision Making 26.6 (2006): 565-574.
        [2]: Rousson-Zumbrunn. "Decision curve analysis revisited: overall net
        benefit, relationships to ROC curve analysis, and application to case-control
        studies." BMC medical informatics and decision making 11.1 (2011): 1–9.
    """
    if set(y_true.astype(int)) != {0, 1}:
        raise ValueError(
            "Decision curve analysis only supports binary classification (with labels 1 and 0)."
        )

    if isinstance(thresholds, int):
        thresholds = linspace(0, 1, num=thresholds)
    else:
        # Force a float array: lists do not support `1 - thresholds`, and an
        # integer array would give `zeros_like` an int buffer that the true
        # division below cannot be written into.
        thresholds = array(thresholds, dtype=float)

    N = len(y_true)
    tns, fps, fns, tps = _binary_classification_thresholds(y_true, y_pred, thresholds)

    if action:
        # Odds t / (1 - t); left at 0 for t >= 1 where the ratio is undefined.
        loss_over_profit = divide(
            thresholds, 1 - thresholds, where=thresholds < 1, out=zeros_like(thresholds)
        )
        benefit = tps / N - fps / N * loss_over_profit
    else:
        # Invert 0<-->1 so that true positives are true negatives, and false positives
        # are false negatives.
        profit_over_loss = divide(
            1 - thresholds,
            thresholds,
            where=thresholds != 0,
            out=zeros_like(thresholds),
        )
        benefit = tns / N - fns / N * profit_over_loss
    return thresholds, benefit
def net_benefit_oracle(y_true, action: bool = True) -> float:
    """Net benefit of a hypothetical perfect (omniscient) predictor.

    Args:
        y_true: Binary ground truth label; must provide a `.mean()` method.
        action: When `True`, return the mean of `y_true`; otherwise one minus
            that mean.
    """
    prevalence = y_true.mean()
    return prevalence if action else 1 - prevalence
def net_benefit_action(y_true, thresholds, action: bool = True):
    """Net benefit of always doing an action/intervention/treatment.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class); must
            provide a `.mean()` method (e.g. NumPy array or pandas Series).
        thresholds: Probability threshold(s) at which to evaluate the net benefit;
            a scalar or any array-like.
        action: When `False`, invert positive label in `y_true`.

    Returns:
        The net benefit at each threshold. Where the odds ratio is undefined
        (threshold 1 when `action` is `True`, threshold 0 otherwise) the ratio is
        replaced by the large constant 1e9.
    """
    # Accept lists and scalars as well as arrays.
    thresholds = array(thresholds, dtype=float)
    prevalence = y_true.mean()

    # Pre-fill with the fallback value. Scaling in place keeps this an ndarray
    # even for a scalar threshold, which `divide(..., out=...)` requires.
    ratio = ones_like(thresholds)
    ratio *= 1e9

    if action:
        # loss / profit = t / (1 - t)
        divide(thresholds, 1 - thresholds, where=thresholds < 1, out=ratio)
        return prevalence - (1 - prevalence) * ratio

    # profit / loss = (1 - t) / t
    divide(1 - thresholds, thresholds, where=thresholds != 0, out=ratio)
    return 1 - prevalence - prevalence * ratio
def overall_net_benefit(y_true, y_pred, n_thresholds: int = 100):
    """Sum of the net benefit of acting and the net benefit of not acting.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Probability of positive class label.
        n_thresholds: Forwarded to `net_benefit` as its `thresholds` argument.

    Returns:
        The thresholds and the summed net benefit at each of them.
    """
    grid, gain_acting = net_benefit(y_true, y_pred, n_thresholds, action=True)
    # Evaluate the no-action curve on the same grid so both curves line up.
    gain_abstaining = net_benefit(y_true, y_pred, grid, action=False)[1]
    return grid, gain_acting + gain_abstaining
class NetBenefitDisplay:
    """Net benefit decision curve analysis visualisation.

    Args:
        threshold_probability: Probability to dichotomise the predicted probability
            of the model.
        net_benefit: Net benefit of taking an action as a function of
            `threshold_probability`.
        net_benefit_action: Optional reference curve drawn as "Always act".
        net_benefit_noop: Optional reference curve drawn as "Never act".
        benefit_type: Kind of net benefit shown (`"action"`, `"noop"` or
            `"overall"`); selects the y-axis label.
        oracle: The (constant) net benefit of a perfect predictor.
        estimator_name: Legend label of the model's curve.

    Example:
        ```python
        from sklearn.datasets import make_blobs
        from sklearn.linear_model import LogisticRegression
        from statkit.decision import NetBenefitDisplay

        centers = [[0, 0], [1, 1]]
        X_train, y_train = make_blobs(
            centers=centers, cluster_std=1, n_samples=20, random_state=5
        )
        X_test, y_test = make_blobs(
            centers=centers, cluster_std=1, n_samples=20, random_state=1005
        )
        clf = LogisticRegression(random_state=5).fit(X_train, y_train)
        y_pred_base = clf.predict_proba(X_test)[:, 1]
        NetBenefitDisplay.from_predictions(y_test, y_pred_base, name='Logistic Regression')
        ```
    """

    def __init__(
        self,
        threshold_probability,
        net_benefit,
        net_benefit_action=None,
        net_benefit_noop=None,
        benefit_type: Literal["action", "noop", "overall"] = "action",
        oracle: Optional[float] = None,
        estimator_name: Optional[str] = None,
    ):
        self.threshold_probability = threshold_probability
        self.net_benefit = net_benefit
        self.net_benefit_action = net_benefit_action
        self.net_benefit_noop = net_benefit_noop
        self.benefit_type = benefit_type
        self.estimator_name = estimator_name
        self.oracle = oracle

    def plot(self, show_references: bool = True, ax=None):
        """Draw the decision curve.

        Args:
            show_references: Show oracle (requires prevalence) and no
                action/treatment/intervention reference curves.
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            This display object; the axes and figure used are stored in `ax_` and
            `figure_`.
        """
        if ax is None:
            _, ax = plt.subplots()
        self.ax_ = ax
        self.figure_ = ax.figure

        ax.plot(
            self.threshold_probability, self.net_benefit, "-", label=self.estimator_name
        )
        if show_references:
            references = (
                (self.net_benefit_action, ":", "Always act"),
                (self.net_benefit_noop, "-.", "Never act"),
            )
            for curve, fmt, label in references:
                # The reference curves default to `None` in the constructor;
                # skip the ones that were not supplied instead of plotting None.
                if curve is not None:
                    ax.plot(self.threshold_probability, curve, fmt, label=label)
            if self.oracle is not None:
                ax.plot([0, 1], [self.oracle, self.oracle], "--", label="Oracle")

        ylabels = {
            "action": "Net benefit (action)",
            "noop": "Net benefit (no-action)",
            "overall": "Overall net benefit",
        }
        ylabel = ylabels.get(self.benefit_type, "Net benefit")
        ax.set(xlabel="Threshold probability", ylabel=ylabel)
        margin = 0.05
        ax.set_ylim([-margin, 1 + margin])
        ax.legend(loc="upper right", frameon=False)
        return self

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        benefit_type: Literal["action", "noop", "overall"] = "action",
        thresholds: int = 100,
        name: Optional[str] = None,
        show_references: bool = True,
        ax=None,
    ):
        """Make a net benefit plot from true labels and predicted probabilities.

        Args:
            y_true: Binary ground truth label (1: positive, 0: negative class).
            y_pred: Predicted probability of the positive class.
            benefit_type: Type of net benefit curve. `"action"`: net benefit of
                treatment/intervention/action; `"noop"`: net benefit of no
                treatment/intervention/action; `"overall"`: overall net benefit (see
                `overall_net_benefit`).
            thresholds: When an array, evaluate net benefit at these coordinates (the
                probability thresholds). When an int, the number of x coordinates.
            name: Legend label of the model's curve.
            show_references: Show oracle (requires prevalence) and no
                action/treatment/intervention reference curves.
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            The plotted `NetBenefitDisplay`.

        Raises:
            ValueError: When `benefit_type` is not one of the supported values.
        """
        # Validate up front; otherwise an unknown value falls through every
        # branch and surfaces as an UnboundLocalError.
        if benefit_type not in ("action", "noop", "overall"):
            raise ValueError(
                f"Unknown benefit_type {benefit_type!r}; "
                "expected 'action', 'noop' or 'overall'."
            )

        if benefit_type == "action":
            oracle = net_benefit_oracle(y_true, action=True)
            thresholds, benefit = net_benefit(y_true, y_pred, thresholds)
            benefit_action = net_benefit_action(y_true, thresholds, action=True)
            # Never acting yields zero net benefit of acting.
            benefit_noop = zeros_like(benefit)
        elif benefit_type == "noop":
            oracle = net_benefit_oracle(y_true, action=False)
            thresholds, benefit = net_benefit(y_true, y_pred, thresholds, action=False)
            benefit_noop = net_benefit_action(y_true, thresholds, action=False)
            benefit_action = zeros_like(benefit)
        else:  # "overall"
            oracle = 1.0
            thresholds, benefit = overall_net_benefit(y_true, y_pred, thresholds)
            benefit_action = net_benefit_action(y_true, thresholds, action=True)
            benefit_noop = net_benefit_action(y_true, thresholds, action=False)

        return cls(
            thresholds,
            benefit,
            benefit_action,
            benefit_noop,
            benefit_type,
            oracle,
            estimator_name=name,
        ).plot(ax=ax, show_references=show_references)
Functions
def net_benefit(y_true: pandas.core.series.Series | numpy.ndarray[typing.Any, numpy.dtype[+_ScalarType_co]], y_pred: pandas.core.series.Series | numpy.ndarray[typing.Any, numpy.dtype[+_ScalarType_co]], thresholds=100, action: bool = True)
-
Net benefit of taking an action using a model's predictions.
Args
y_true
- Binary ground truth label (1: positive, 0: negative class).
y_pred
- Probability of positive class label.
thresholds
- When an array, evaluate net benefit at these coordinates (the probability thresholds). When an int, the number of x coordinates.
action
- When True (False), estimate net benefit of taking (not taking) an action/intervention/treatment.
Returns
thresholds
- Probability thresholds for predicting the positive class.
benefit
- The net benefit corresponding to the thresholds.
References
[1]: Vickers-Elkin. "Decision curve analysis: a novel method for evaluating prediction models." Medical Decision Making 26.6 (2006): 565-574.
[2]: Rousson-Zumbrunn. "Decision curve analysis revisited: overall net benefit, relationships to ROC curve analysis, and application to case-control studies." BMC medical informatics and decision making 11.1 (2011): 1–9.
Expand source code
def net_benefit(
    y_true: Series | NDArray,
    y_pred: Series | NDArray,
    thresholds=100,
    action: bool = True,
):
    """Net benefit of taking an action using a model's predictions.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Probability of positive class label.
        thresholds: When an array-like, evaluate net benefit at these coordinates
            (the probability thresholds). When an int, the number of x coordinates
            (evenly spaced on [0, 1]).
        action: When `True` (`False`), estimate net benefit of taking (not taking) an
            action/intervention/treatment.

    Returns:
        thresholds: Probability thresholds for predicting the positive class.
        benefit: The net benefit corresponding to the thresholds.

    Raises:
        ValueError: When `y_true`, cast to int, does not contain exactly the labels
            0 and 1.

    References:
        [1]: Vickers-Elkin. "Decision curve analysis: a novel method for evaluating
        prediction models." Medical Decision Making 26.6 (2006): 565-574.
        [2]: Rousson-Zumbrunn. "Decision curve analysis revisited: overall net
        benefit, relationships to ROC curve analysis, and application to case-control
        studies." BMC medical informatics and decision making 11.1 (2011): 1–9.
    """
    if set(y_true.astype(int)) != {0, 1}:
        raise ValueError(
            "Decision curve analysis only supports binary classification (with labels 1 and 0)."
        )

    if isinstance(thresholds, int):
        thresholds = linspace(0, 1, num=thresholds)
    else:
        # Force a float array: lists do not support `1 - thresholds`, and an
        # integer array would give `zeros_like` an int buffer that the true
        # division below cannot be written into.
        thresholds = array(thresholds, dtype=float)

    N = len(y_true)
    tns, fps, fns, tps = _binary_classification_thresholds(y_true, y_pred, thresholds)

    if action:
        # Odds t / (1 - t); left at 0 for t >= 1 where the ratio is undefined.
        loss_over_profit = divide(
            thresholds, 1 - thresholds, where=thresholds < 1, out=zeros_like(thresholds)
        )
        benefit = tps / N - fps / N * loss_over_profit
    else:
        # Invert 0<-->1 so that true positives are true negatives, and false positives
        # are false negatives.
        profit_over_loss = divide(
            1 - thresholds,
            thresholds,
            where=thresholds != 0,
            out=zeros_like(thresholds),
        )
        benefit = tns / N - fns / N * profit_over_loss
    return thresholds, benefit
def net_benefit_action(y_true, thresholds, action: bool = True)
-
Net benefit of always doing an action/intervention/treatment.
Args
action
- When False, invert positive label in y_true.
Expand source code
def net_benefit_action(y_true, thresholds, action: bool = True):
    """Net benefit of always doing an action/intervention/treatment.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class); must
            provide a `.mean()` method (e.g. NumPy array or pandas Series).
        thresholds: Probability threshold(s) at which to evaluate the net benefit;
            a scalar or any array-like.
        action: When `False`, invert positive label in `y_true`.

    Returns:
        The net benefit at each threshold. Where the odds ratio is undefined
        (threshold 1 when `action` is `True`, threshold 0 otherwise) the ratio is
        replaced by the large constant 1e9.
    """
    # Accept lists and scalars as well as arrays.
    thresholds = array(thresholds, dtype=float)
    prevalence = y_true.mean()

    # Pre-fill with the fallback value. Scaling in place keeps this an ndarray
    # even for a scalar threshold, which `divide(..., out=...)` requires.
    ratio = ones_like(thresholds)
    ratio *= 1e9

    if action:
        # loss / profit = t / (1 - t)
        divide(thresholds, 1 - thresholds, where=thresholds < 1, out=ratio)
        return prevalence - (1 - prevalence) * ratio

    # profit / loss = (1 - t) / t
    divide(1 - thresholds, thresholds, where=thresholds != 0, out=ratio)
    return 1 - prevalence - prevalence * ratio
def net_benefit_oracle(y_true, action: bool = True) ‑> float
-
Net benefit of omniscient strategy, i.e., a hypothetical perfect predictor.
Expand source code
def net_benefit_oracle(y_true, action: bool = True) -> float:
    """Net benefit of a hypothetical perfect (omniscient) predictor.

    Args:
        y_true: Binary ground truth label; must provide a `.mean()` method.
        action: When `True`, return the mean of `y_true`; otherwise one minus
            that mean.
    """
    prevalence = y_true.mean()
    return prevalence if action else 1 - prevalence
def overall_net_benefit(y_true, y_pred, n_thresholds: int = 100)
-
Net benefit combining both taking and not-taking action.
Expand source code
def overall_net_benefit(y_true, y_pred, n_thresholds: int = 100):
    """Sum of the net benefit of acting and the net benefit of not acting.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Probability of positive class label.
        n_thresholds: Forwarded to `net_benefit` as its `thresholds` argument.

    Returns:
        The thresholds and the summed net benefit at each of them.
    """
    grid, gain_acting = net_benefit(y_true, y_pred, n_thresholds, action=True)
    # Evaluate the no-action curve on the same grid so both curves line up.
    gain_abstaining = net_benefit(y_true, y_pred, grid, action=False)[1]
    return grid, gain_acting + gain_abstaining
Classes
class NetBenefitDisplay (threshold_probability, net_benefit, net_benefit_action=None, net_benefit_noop=None, benefit_type: Literal['action', 'noop', 'overall'] = 'action', oracle: Optional[float] = None, estimator_name: Optional[str] = None)
-
Net benefit decision curve analysis visualisation.
Args
threshold_probability
- Probability to dichotomise the predicted probability of the model.
net_benefit
- Net benefit of taking an action as a function of threshold_probability.
oracle
- The (constant) net benefit of a perfect predictor.
benefit_type
- Kind of net benefit shown, one of "action", "noop" or "overall" (default "action"); selects the y-axis label.
Example
from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression from statkit.decision import NetBenefitDisplay centers = [[0, 0], [1, 1]] X_train, y_train = make_blobs( centers=centers, cluster_std=1, n_samples=20, random_state=5 ) X_test, y_test = make_blobs( centers=centers, cluster_std=1, n_samples=20, random_state=1005 ) clf = LogisticRegression(random_state=5).fit(X_train, y_train) y_pred_base = clf.predict_proba(X_test)[:, 1] NetBenefitDisplay.from_predictions(y_test, y_pred_base, name='Logistic Regression')
Expand source code
class NetBenefitDisplay:
    """Net benefit decision curve analysis visualisation.

    Args:
        threshold_probability: Probability to dichotomise the predicted probability
            of the model.
        net_benefit: Net benefit of taking an action as a function of
            `threshold_probability`.
        net_benefit_action: Optional reference curve drawn as "Always act".
        net_benefit_noop: Optional reference curve drawn as "Never act".
        benefit_type: Kind of net benefit shown (`"action"`, `"noop"` or
            `"overall"`); selects the y-axis label.
        oracle: The (constant) net benefit of a perfect predictor.
        estimator_name: Legend label of the model's curve.

    Example:
        ```python
        from sklearn.datasets import make_blobs
        from sklearn.linear_model import LogisticRegression
        from statkit.decision import NetBenefitDisplay

        centers = [[0, 0], [1, 1]]
        X_train, y_train = make_blobs(
            centers=centers, cluster_std=1, n_samples=20, random_state=5
        )
        X_test, y_test = make_blobs(
            centers=centers, cluster_std=1, n_samples=20, random_state=1005
        )
        clf = LogisticRegression(random_state=5).fit(X_train, y_train)
        y_pred_base = clf.predict_proba(X_test)[:, 1]
        NetBenefitDisplay.from_predictions(y_test, y_pred_base, name='Logistic Regression')
        ```
    """

    def __init__(
        self,
        threshold_probability,
        net_benefit,
        net_benefit_action=None,
        net_benefit_noop=None,
        benefit_type: Literal["action", "noop", "overall"] = "action",
        oracle: Optional[float] = None,
        estimator_name: Optional[str] = None,
    ):
        self.threshold_probability = threshold_probability
        self.net_benefit = net_benefit
        self.net_benefit_action = net_benefit_action
        self.net_benefit_noop = net_benefit_noop
        self.benefit_type = benefit_type
        self.estimator_name = estimator_name
        self.oracle = oracle

    def plot(self, show_references: bool = True, ax=None):
        """Draw the decision curve.

        Args:
            show_references: Show oracle (requires prevalence) and no
                action/treatment/intervention reference curves.
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            This display object; the axes and figure used are stored in `ax_` and
            `figure_`.
        """
        if ax is None:
            _, ax = plt.subplots()
        self.ax_ = ax
        self.figure_ = ax.figure

        ax.plot(
            self.threshold_probability, self.net_benefit, "-", label=self.estimator_name
        )
        if show_references:
            references = (
                (self.net_benefit_action, ":", "Always act"),
                (self.net_benefit_noop, "-.", "Never act"),
            )
            for curve, fmt, label in references:
                # The reference curves default to `None` in the constructor;
                # skip the ones that were not supplied instead of plotting None.
                if curve is not None:
                    ax.plot(self.threshold_probability, curve, fmt, label=label)
            if self.oracle is not None:
                ax.plot([0, 1], [self.oracle, self.oracle], "--", label="Oracle")

        ylabels = {
            "action": "Net benefit (action)",
            "noop": "Net benefit (no-action)",
            "overall": "Overall net benefit",
        }
        ylabel = ylabels.get(self.benefit_type, "Net benefit")
        ax.set(xlabel="Threshold probability", ylabel=ylabel)
        margin = 0.05
        ax.set_ylim([-margin, 1 + margin])
        ax.legend(loc="upper right", frameon=False)
        return self

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        benefit_type: Literal["action", "noop", "overall"] = "action",
        thresholds: int = 100,
        name: Optional[str] = None,
        show_references: bool = True,
        ax=None,
    ):
        """Make a net benefit plot from true labels and predicted probabilities.

        Args:
            y_true: Binary ground truth label (1: positive, 0: negative class).
            y_pred: Predicted probability of the positive class.
            benefit_type: Type of net benefit curve. `"action"`: net benefit of
                treatment/intervention/action; `"noop"`: net benefit of no
                treatment/intervention/action; `"overall"`: overall net benefit (see
                `overall_net_benefit`).
            thresholds: When an array, evaluate net benefit at these coordinates (the
                probability thresholds). When an int, the number of x coordinates.
            name: Legend label of the model's curve.
            show_references: Show oracle (requires prevalence) and no
                action/treatment/intervention reference curves.
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            The plotted `NetBenefitDisplay`.

        Raises:
            ValueError: When `benefit_type` is not one of the supported values.
        """
        # Validate up front; otherwise an unknown value falls through every
        # branch and surfaces as an UnboundLocalError.
        if benefit_type not in ("action", "noop", "overall"):
            raise ValueError(
                f"Unknown benefit_type {benefit_type!r}; "
                "expected 'action', 'noop' or 'overall'."
            )

        if benefit_type == "action":
            oracle = net_benefit_oracle(y_true, action=True)
            thresholds, benefit = net_benefit(y_true, y_pred, thresholds)
            benefit_action = net_benefit_action(y_true, thresholds, action=True)
            # Never acting yields zero net benefit of acting.
            benefit_noop = zeros_like(benefit)
        elif benefit_type == "noop":
            oracle = net_benefit_oracle(y_true, action=False)
            thresholds, benefit = net_benefit(y_true, y_pred, thresholds, action=False)
            benefit_noop = net_benefit_action(y_true, thresholds, action=False)
            benefit_action = zeros_like(benefit)
        else:  # "overall"
            oracle = 1.0
            thresholds, benefit = overall_net_benefit(y_true, y_pred, thresholds)
            benefit_action = net_benefit_action(y_true, thresholds, action=True)
            benefit_noop = net_benefit_action(y_true, thresholds, action=False)

        return cls(
            thresholds,
            benefit,
            benefit_action,
            benefit_noop,
            benefit_type,
            oracle,
            estimator_name=name,
        ).plot(ax=ax, show_references=show_references)
Static methods
def from_predictions(y_true, y_pred, benefit_type: Literal['action', 'noop', 'overall'] = 'action', thresholds: int = 100, name: Optional[str] = None, show_references: bool = True, ax=None)
-
Make a net benefit plot from true and predicted labels.
Args
y_true
- Binary ground truth label (1: positive, 0: negative class).
y_pred
- Predicted probability of the positive class.
benefit_type
- Type of net benefit curve. "action": net benefit of treatment/intervention/action; "noop": net benefit of no treatment/intervention/action; "overall": overall net benefit (see overall_net_benefit()).
show_references
- Show oracle (requires prevalence) and no action/treatment/intervention reference curves.
thresholds
- When an array, evaluate net benefit at these coordinates (the probability thresholds). When an int, the number of x coordinates.
ax
- Optional axes object to plot on. If None, a new figure and axes is created.
Expand source code
@classmethod
def from_predictions(
    cls,
    y_true,
    y_pred,
    benefit_type: Literal["action", "noop", "overall"] = "action",
    thresholds: int = 100,
    name: Optional[str] = None,
    show_references: bool = True,
    ax=None,
):
    """Make a net benefit plot from true labels and predicted probabilities.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Predicted probability of the positive class.
        benefit_type: Type of net benefit curve. `"action"`: net benefit of
            treatment/intervention/action; `"noop"`: net benefit of no
            treatment/intervention/action; `"overall"`: overall net benefit (see
            `overall_net_benefit`).
        thresholds: When an array, evaluate net benefit at these coordinates (the
            probability thresholds). When an int, the number of x coordinates.
        name: Legend label of the model's curve.
        show_references: Show oracle (requires prevalence) and no
            action/treatment/intervention reference curves.
        ax: Optional axes object to plot on. If `None`, a new figure and axes is
            created.

    Returns:
        The plotted display object.

    Raises:
        ValueError: When `benefit_type` is not one of the supported values.
    """
    # Validate up front; otherwise an unknown value falls through every
    # branch and surfaces as an UnboundLocalError.
    if benefit_type not in ("action", "noop", "overall"):
        raise ValueError(
            f"Unknown benefit_type {benefit_type!r}; "
            "expected 'action', 'noop' or 'overall'."
        )

    if benefit_type == "action":
        oracle = net_benefit_oracle(y_true, action=True)
        thresholds, benefit = net_benefit(y_true, y_pred, thresholds)
        benefit_action = net_benefit_action(y_true, thresholds, action=True)
        # Never acting yields zero net benefit of acting.
        benefit_noop = zeros_like(benefit)
    elif benefit_type == "noop":
        oracle = net_benefit_oracle(y_true, action=False)
        thresholds, benefit = net_benefit(y_true, y_pred, thresholds, action=False)
        benefit_noop = net_benefit_action(y_true, thresholds, action=False)
        benefit_action = zeros_like(benefit)
    else:  # "overall"
        oracle = 1.0
        thresholds, benefit = overall_net_benefit(y_true, y_pred, thresholds)
        benefit_action = net_benefit_action(y_true, thresholds, action=True)
        benefit_noop = net_benefit_action(y_true, thresholds, action=False)

    return cls(
        thresholds,
        benefit,
        benefit_action,
        benefit_noop,
        benefit_type,
        oracle,
        estimator_name=name,
    ).plot(ax=ax, show_references=show_references)
Methods
def plot(self, show_references: bool = True, ax=None)
-
Args
show_references
- Show oracle (requires prevalence) and no action/treatment/intervention reference curves.
ax
- Optional axes object to plot on. If None, a new figure and axes is created.
Expand source code
def plot(self, show_references: bool = True, ax=None):
    """Draw the decision curve.

    Args:
        show_references: Show oracle (requires prevalence) and no
            action/treatment/intervention reference curves.
        ax: Optional axes object to plot on. If `None`, a new figure and axes is
            created.

    Returns:
        This display object; the axes and figure used are stored in `ax_` and
        `figure_`.
    """
    if ax is None:
        _, ax = plt.subplots()
    self.ax_ = ax
    self.figure_ = ax.figure

    ax.plot(
        self.threshold_probability, self.net_benefit, "-", label=self.estimator_name
    )
    if show_references:
        references = (
            (self.net_benefit_action, ":", "Always act"),
            (self.net_benefit_noop, "-.", "Never act"),
        )
        for curve, fmt, label in references:
            # The reference curves default to `None` in the constructor;
            # skip the ones that were not supplied instead of plotting None.
            if curve is not None:
                ax.plot(self.threshold_probability, curve, fmt, label=label)
        if self.oracle is not None:
            ax.plot([0, 1], [self.oracle, self.oracle], "--", label="Oracle")

    ylabels = {
        "action": "Net benefit (action)",
        "noop": "Net benefit (no-action)",
        "overall": "Overall net benefit",
    }
    ylabel = ylabels.get(self.benefit_type, "Net benefit")
    ax.set(xlabel="Threshold probability", ylabel=ylabel)
    margin = 0.05
    ax.set_ylim([-margin, 1 + margin])
    ax.legend(loc="upper right", frameon=False)
    return self