from html import escape
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
import crabbymetrics as cm
def html_table(headers, rows):
    """Render *rows* as a compact HTML table string with a header row.

    Every cell value is passed through ``str`` and HTML-escaped, so
    arbitrary Python objects (and untrusted text) are safe to display.
    """
    head_cells = "".join(f"<th>{escape(str(name))}</th>" for name in headers)
    body_rows = "".join(
        "<tr>" + "".join(f"<td>{escape(str(value))}</td>" for value in row) + "</tr>"
        for row in rows
    )
    return f"<table><thead><tr>{head_cells}</tr></thead><tbody>{body_rows}</tbody></table>"
def expit(x):
    """Numerically stable logistic sigmoid ``1 / (1 + exp(-x))``.

    The naive form overflows in ``np.exp(-x)`` for large negative ``x``
    (emitting a RuntimeWarning and producing an ``inf`` intermediate).
    Here we only ever exponentiate non-positive values, so ``exp`` stays
    in (0, 1] for any finite input.

    Parameters
    ----------
    x : array_like
        Log-odds value(s).

    Returns
    -------
    numpy.ndarray
        Sigmoid of ``x``, elementwise, in (0, 1).
    """
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))  # always in (0, 1]; cannot overflow
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) — same value, safe form.
    return np.where(x >= 0.0, 1.0 / (1.0 + z), z / (1.0 + z))
def kang_schafer_dgp(n, rng):
    """Kang-Schafer style simulation with misspecified observed covariates.

    Draws latent covariates, assigns treatment via a logistic propensity in
    the latent space, generates the outcome linearly in the latents, and
    exposes only nonlinear transformations of the latents as observables.

    Returns ``(y, d, x, z)``: outcome, 0/1 treatment, observed covariates
    (n x 4), and latent covariates (n x 4).
    """
    latent = rng.normal(size=(n, 4))
    z1, z2, z3, z4 = latent.T
    p_treat = expit(-z1 + 0.5 * z2 - 0.25 * z3 - 0.1 * z4)
    d = rng.binomial(1, p_treat)
    # Treatment has no effect on the outcome, so the true ATT is zero.
    y = 210.0 + 27.4 * z1 + 13.7 * (z2 + z3 + z4) + rng.normal(size=n)
    observed = np.column_stack(
        [
            np.exp(z1 / 2.0),
            z2 / (1.0 + np.exp(z1)) + 10.0,
            (z1 * z3 / 25.0 + 0.6) ** 3,
            (z2 + z4 + 20.0) ** 2,
        ]
    )
    return y, d, observed, latent
def hainmueller_dgp(
    n,
    rng,
    overlap_design=1,
    pscore_design=1,
    outcome_design=1,
):
    """Hainmueller-style benchmark data generating process.

    Draws six covariates (three correlated normals, a uniform, a chi-square,
    and a Bernoulli), then picks one of three designs each for the selection
    noise, the treatment index, and the outcome equation.

    Returns ``(y, d, x)``: outcome, 0/1 treatment, covariate matrix (n x 6).

    Raises ``ValueError`` if any of the three design codes is not 1, 2, or 3.
    """
    sigma = np.array([[2.0, 1.0, -1.0], [1.0, 1.0, -0.5], [-1.0, -0.5, 1.0]])
    x1, x2, x3 = rng.multivariate_normal(mean=np.zeros(3), cov=sigma, size=n).T
    x4 = rng.uniform(-3.0, 3.0, size=n)
    x5 = rng.chisquare(df=1.0, size=n)
    x6 = rng.binomial(1, 0.5, size=n)
    x = np.column_stack([x1, x2, x3, x4, x5, x6])

    # Selection noise: the three variants differ in spread and skew; design 3
    # rescales a chi-square to mimic a heavy-tailed, asymmetric error.
    noise_draws = {
        1: lambda: rng.normal(0.0, np.sqrt(30.0), size=n),
        2: lambda: rng.normal(0.0, 10.0, size=n),
        3: lambda: (rng.chisquare(df=5.0, size=n) - 5.0) / np.sqrt(10.0) * np.sqrt(67.6) + 0.5,
    }
    if overlap_design not in noise_draws:
        raise ValueError("unknown overlap_design")
    epsilon = noise_draws[overlap_design]()

    # Treatment index: linear, mildly nonlinear, or trigonometric in x.
    index_forms = {
        1: lambda: x1 + 2.0 * x2 - 2.0 * x3 - x4 - 0.5 * x5 + x6,
        2: lambda: x1 + x1**2 - x4 * x6,
        3: lambda: 2.0 * np.cos(x1) + np.sin(np.pi * x2),
    }
    if pscore_design not in index_forms:
        raise ValueError("unknown pscore_design")
    d = (index_forms[pscore_design]() + epsilon > 0.0).astype(int)

    eta = rng.normal(0.0, 1.0, size=n)
    # Outcome equation: again linear, mildly nonlinear, or strongly nonlinear.
    outcome_forms = {
        1: lambda: x1 + x2 + x3 - x4 + x5 + x6 + eta,
        2: lambda: x1 + x2 + 0.2 * x3 * x4 - np.sqrt(x5) + eta,
        3: lambda: 2.0 * np.cos(x1) + np.sin(np.pi * x2) + (x1 + x2 + x5) ** 2 + eta,
    }
    if outcome_design not in outcome_forms:
        raise ValueError("unknown outcome_design")
    y = outcome_forms[outcome_design]()
    return y, d, x
def fit_att_balancing(y, d, x, objective):
    """Estimate the ATT by reweighting controls to balance treated covariates.

    Fits ``cm.BalancingWeights`` with the given *objective* so the control
    rows of *x* match the treated rows, then contrasts the treated outcome
    mean with the weighted control outcome mean.

    Returns ``(att_hat, summary)`` where *summary* is the model's summary
    dict (expected to contain at least a ``"weights"`` entry).
    """
    is_treated = d == 1
    model = cm.BalancingWeights(
        objective=objective,
        solver="auto",
        autoscale=True,
        max_iterations=300,
        tolerance=1e-8,
    )
    model.fit(x[~is_treated], x[is_treated])
    info = model.summary()
    w = np.asarray(info["weights"])
    att_hat = y[is_treated].mean() - w @ y[~is_treated]
    return att_hat, info
def standardized_mean_difference(x_treated, x_control, weights=None):
treated_mean = x_treated.mean(axis=0)
control_mean = x_control.mean(axis=0) if weights is None else np.average(
x_control, axis=0, weights=weights
)
treated_var = x_treated.var(axis=0)
control_var = x_control.var(axis=0) if weights is None else np.average(
(x_control - control_mean) ** 2, axis=0, weights=weights
)
pooled = np.sqrt(0.5 * (treated_var + control_var))
pooled = np.where(pooled > 1e-12, pooled, 1.0)
return (treated_mean - control_mean) / pooled
def evaluate_single_dataset():
    """Run one Kang-Schafer draw and compare ATT estimators side by side.

    Fits quadratic and entropy balancing on the observed (misspecified)
    covariates plus an entropy-balancing "oracle" on the latent covariates,
    displays an HTML summary table, and plots absolute standardized mean
    differences before/after weighting.

    Returns a dict with the four ATT point estimates.
    """
    rng = np.random.default_rng(123)
    y, d, x, z = kang_schafer_dgp(2000, rng)
    is_treated = d == 1
    is_control = ~is_treated
    naive_att = y[is_treated].mean() - y[is_control].mean()

    quad_att, quad_summary = fit_att_balancing(y, d, x, "quadratic")
    ent_att, ent_summary = fit_att_balancing(y, d, x, "entropy")
    # Oracle: balance directly on the latent z, which the observed x garbles.
    oracle_att, oracle_summary = fit_att_balancing(y, d, z, "entropy")

    smd_raw = standardized_mean_difference(x[is_treated], x[is_control])
    smd_quad = standardized_mean_difference(
        x[is_treated], x[is_control], weights=np.asarray(quad_summary["weights"])
    )
    smd_ent = standardized_mean_difference(
        x[is_treated], x[is_control], weights=np.asarray(ent_summary["weights"])
    )

    table_rows = [["Naive difference", f"{naive_att: .3f}", "--", "--", "--"]]
    for label, att, info in (
        ("Quadratic balancing on observed X", quad_att, quad_summary),
        ("Entropy balancing on observed X", ent_att, ent_summary),
        ("Entropy balancing on latent Z (oracle)", oracle_att, oracle_summary),
    ):
        table_rows.append(
            [
                label,
                f"{att: .3f}",
                info["success"],
                f"{info['effective_sample_size']: .1f}",
                f"{info['max_abs_diff']: .2e}",
            ]
        )
    display(HTML(html_table(["Estimator", "ATT estimate", "Success", "ESS", "Max balance error"], table_rows)))

    covariate_labels = [f"x{j + 1}" for j in range(x.shape[1])]
    fig, ax = plt.subplots(figsize=(8, 4))
    positions = np.arange(len(covariate_labels))
    bar_width = 0.25
    ax.bar(positions - bar_width, np.abs(smd_raw), width=bar_width, label="Unweighted")
    ax.bar(positions, np.abs(smd_quad), width=bar_width, label="Quadratic")
    ax.bar(positions + bar_width, np.abs(smd_ent), width=bar_width, label="Entropy")
    # Conventional |SMD| < 0.1 balance threshold.
    ax.axhline(0.1, color="black", linestyle="--", linewidth=1.0)
    ax.set_xticks(positions)
    ax.set_xticklabels(covariate_labels)
    ax.set_ylabel("Absolute standardized mean difference")
    ax.set_title("Kang-Schafer: single-dataset balance on observed transformed covariates")
    ax.legend()
    fig.tight_layout()

    return {
        "naive": naive_att,
        "quadratic": quad_att,
        "entropy": ent_att,
        "oracle": oracle_att,
    }