import numpy as np
from caml.extensions.plots import plot_cates_true_vs_estimated
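
# Simulate "true" CATEs and noisy estimates of them
np.random.seed(42)
true_cates = np.random.normal(0, 1, 100)
estimated_cates = true_cates + np.random.normal(0, 0.5, 100)

# Build the scatter plot; a bare `fig` on the last line displays it in a notebook
fig = plot_cates_true_vs_estimated(true_cates, estimated_cates)
fig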
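In this synthetic example the estimation error is pure Gaussian noise with standard deviation 0.5, so the points should scatter vertically around the line estimated = true by roughly that amount. The short NumPy check below (plain NumPy, not part of caml) summarizes the same data with an RMSE and a correlation coefficient, which is a rough numeric reading of what the plot shows visually.

```python
import numpy as np

# Same synthetic data as the example above
np.random.seed(42)
true_cates = np.random.normal(0, 1, 100)
estimated_cates = true_cates + np.random.normal(0, 0.5, 100)

# Residuals are the vertical distances from each point to the line y = x
residuals = estimated_cates - true_cates
rmse = float(np.sqrt(np.mean(residuals**2)))

# Pearson correlation between true and estimated CATEs
correlation = float(np.corrcoef(true_cates, estimated_cates)[0, 1])

print(f"RMSE: {rmse:.3f}, correlation: {correlation:.3f}")
```

With a noise standard deviation of 0.5 and unit-variance true CATEs, the RMSE should land near 0.5 and the correlation near 1 / sqrt(1.25), about 0.89.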
plot_cates_true_vs_estimated
extensions.plots.plot_cates_true_vs_estimated(true_cates, estimated_cates, *, figure_kwargs={}, scatter_kwargs={})
Creates a scatter plot of the estimated CATEs against the true CATEs.
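The function body is not shown in this reference. As a rough sketch of the described behaviour, assuming `figure_kwargs` is forwarded to `matplotlib.pyplot.figure` and `scatter_kwargs` to `Axes.scatter`, a comparable plot could be built as follows. The helper name `sketch_cates_true_vs_estimated`, the axis labels, and the dashed y = x reference line are hypothetical illustration choices, not necessarily what caml draws.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure


def sketch_cates_true_vs_estimated(
    true_cates,
    estimated_cates,
    *,
    figure_kwargs=None,
    scatter_kwargs=None,
) -> Figure:
    """Hypothetical stand-in for plot_cates_true_vs_estimated (not caml's code)."""
    true_cates = np.asarray(true_cates, dtype=float)
    estimated_cates = np.asarray(estimated_cates, dtype=float)
    if true_cates.shape != estimated_cates.shape:
        raise ValueError("true_cates and estimated_cates must have the same shape")

    # Assumption: figure_kwargs go to plt.figure, scatter_kwargs to Axes.scatter
    fig = plt.figure(**(figure_kwargs or {}))
    ax = fig.add_subplot()
    ax.scatter(true_cates, estimated_cates, **(scatter_kwargs or {}))

    # Hypothetical y = x reference: a perfect estimator puts every point on it
    lo = min(true_cates.min(), estimated_cates.min())
    hi = max(true_cates.max(), estimated_cates.max())
    ax.plot([lo, hi], [lo, hi], linestyle="--", color="gray")

    ax.set_xlabel("True CATE")
    ax.set_ylabel("Estimated CATE")
    return fig


np.random.seed(42)
true_cates = np.random.normal(0, 1, 100)
estimated_cates = true_cates + np.random.normal(0, 0.5, 100)
fig = sketch_cates_true_vs_estimated(
    true_cates,
    estimated_cates,
    figure_kwargs={"figsize": (5, 5)},
    scatter_kwargs={"alpha": 0.6, "s": 15},
)
```

The sketch uses `None` defaults instead of the `{}` shown in the signature, which sidesteps Python's shared-mutable-default pitfall while behaving the same when the caller passes nothing.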
Parameters
Name | Type | Description | Default |
---|---|---|---|
true_cates | numpy.typing.ArrayLike | The true CATEs. | required |
estimated_cates | numpy.typing.ArrayLike | The estimated CATEs. | required |
figure_kwargs | dict | Matplotlib figure arguments. | {} |
scatter_kwargs | dict | Matplotlib scatter arguments. | {} |
Returns
Type | Description |
---|---|
matplotlib.figure.Figure | The scatter plot figure object. |
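Because the return value is a standard `matplotlib.figure.Figure`, the usual Figure methods apply to it. The sketch below builds a plain Matplotlib stand-in figure (it does not call caml) to show saving to PNG and reading the plotted points back; with the real function you would start from the `fig` it returns instead, assuming the scatter is the first collection on its first axes.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend

import matplotlib.pyplot as plt
import numpy as np

# Stand-in for the Figure returned by plot_cates_true_vs_estimated
np.random.seed(42)
true_cates = np.random.normal(0, 1, 100)
estimated_cates = true_cates + np.random.normal(0, 0.5, 100)
fig, ax = plt.subplots()
ax.scatter(true_cates, estimated_cates)

# Save the figure to an in-memory PNG (a file path works the same way)
buffer = io.BytesIO()
fig.savefig(buffer, format="png", dpi=100)
png_bytes = buffer.getvalue()

# Recover the plotted (true, estimated) pairs from the first scatter collection
points = np.asarray(fig.axes[0].collections[0].get_offsets())

# Release the figure when done to avoid accumulating open figures
plt.close(fig)
```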