from pathlib import Path
import numpy as np
import pandas as pd
import crabbymetrics as cm
# Print numpy arrays with 4 decimal places and fixed-point (non-scientific) notation.
np.set_printoptions(precision=4, suppress=True)
def repo_root():
    """Return the closest directory, starting at the cwd, that contains ``ding_w_source``.

    The current working directory is resolved, then it and each of its
    ancestors are checked in order (nearest first).

    Raises:
        FileNotFoundError: if no directory on that path contains a
            ``ding_w_source`` entry.
    """
    here = Path.cwd().resolve()
    search_path = (here, *here.parents)
    found = next(
        (folder for folder in search_path if (folder / "ding_w_source").exists()),
        None,
    )
    if found is None:
        raise FileNotFoundError("could not locate ding_w_source from the current working directory")
    return found
def expit(x):
    """Elementwise logistic sigmoid ``1 / (1 + exp(-x))``.

    Accepts a scalar or array-like and returns numpy output of matching shape.

    For very negative ``x``, ``exp(-x)`` overflows to ``inf``. The expression
    then evaluates to ``1 / inf == 0.0``, which is the correct limit. That
    overflow is expected, so numpy's "overflow encountered in exp"
    RuntimeWarning is silenced here instead of being emitted on every
    replication.
    """
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-x))
def linear_basis(x):
    """Return the covariates unchanged (main effects only).

    In ``simulate_one`` this is the basis used when a nuisance model is meant
    to be misspecified. The very same object is returned; no copy is made.
    """
    design = x
    return design
def rich_basis(x):
    """Expand the first two covariate columns into a flexible design matrix.

    With ``x1 = x[:, 0]`` and ``x2 = x[:, 1]``, the returned columns are, in
    order: x1, x2, x1**2, x2**2, x1*x2, sin(x1). Any further columns of ``x``
    are ignored.
    """
    a, b = x[:, 0], x[:, 1]
    columns = (a, b, a**2, b**2, a * b, np.sin(a))
    return np.column_stack(columns)
def fit_binary_logit(x, d):
    """Fit a crabbymetrics logistic model of ``d`` on ``x`` and return fitted probabilities.

    ``d`` is cast to int32 labels before fitting. The probabilities are
    rebuilt from the ``intercept`` and ``coef`` entries of
    ``model.summary()`` and passed through ``expit``.

    NOTE(review): ``alpha=1.0`` is presumably a regularization strength in
    ``cm.Logit``; confirm against the crabbymetrics docs.
    """
    labels = d.astype(np.int32)
    logit_model = cm.Logit(alpha=1.0, max_iterations=300)
    logit_model.fit(x, labels)
    fitted = logit_model.summary()
    linear_index = fitted["intercept"] + x @ np.asarray(fitted["coef"])
    return expit(linear_index)
def fit_outcome_regression(x, y, d):
    """Fit one OLS outcome model per treatment arm and predict for every row.

    Returns ``(mu1_hat, mu0_hat)``. ``mu1_hat`` is the prediction over all of
    ``x`` from the model fit on rows with ``d == 1``; ``mu0_hat`` is the same
    for the model fit on rows with ``d == 0``.
    """
    models = {}
    for arm in (1, 0):
        in_arm = d == arm
        ols = cm.OLS()
        ols.fit(x[in_arm], y[in_arm])
        models[arm] = ols
    return models[1].predict(x), models[0].predict(x)
def manual_estimators(y, d, x_outcome, x_pscore):
    """Compute four ATE estimates from hand-fitted nuisance models.

    The propensity score is fit on ``x_pscore`` with ``fit_binary_logit`` and
    clipped to [0.02, 0.98]. Arm-specific outcome regressions are fit on
    ``x_outcome`` with ``fit_outcome_regression``.

    Returns a tuple of floats ``(reg, ht_ipw, hajek_ipw, aipw)``:
    outcome-regression imputation, Horvitz-Thompson IPW, normalized (Hajek)
    IPW, and augmented IPW.
    """
    pi_hat = np.clip(fit_binary_logit(x_pscore, d), 0.02, 0.98)
    mu1_hat, mu0_hat = fit_outcome_regression(x_outcome, y, d)

    untreated = 1 - d
    one_minus_pi = 1 - pi_hat

    # Inverse-probability-weighted outcomes for each arm.
    ipw1 = d * y / pi_hat
    ipw0 = untreated * y / one_minus_pi

    # Residual corrections that augment the regression contrast.
    correction1 = d * (y - mu1_hat) / pi_hat
    correction0 = untreated * (y - mu0_hat) / one_minus_pi

    reg = float(np.mean(mu1_hat - mu0_hat))
    ht_ipw = float(np.mean(ipw1 - ipw0))
    # Hajek: divide each arm's weighted sum by its total weight.
    hajek_ipw = float(
        np.mean(ipw1) / np.mean(d / pi_hat)
        - np.mean(ipw0) / np.mean(untreated / one_minus_pi)
    )
    aipw = float(np.mean(mu1_hat - mu0_hat + correction1 - correction0))
    return reg, ht_ipw, hajek_ipw, aipw
def simulate_one(seed, outcome_correct, pscore_correct):
    """Run one replication of the ATE simulation and return all estimates.

    Draws 700 units with two standard-normal covariates, a logistic treatment
    assignment, and a constant treatment effect of 1.0. Each nuisance model
    uses ``rich_basis`` when its ``*_correct`` flag is truthy and
    ``linear_basis`` otherwise.

    Returns:
        np.ndarray ``[tau, reg, ht_ipw, hajek_ipw, aipw, native_ate]``.
    """
    n = 700
    tau = 1.0
    rng = np.random.default_rng(seed)

    covariates = rng.normal(size=(n, 2))
    x1, x2 = covariates[:, 0], covariates[:, 1]

    # Both the true propensity logit and the untreated outcome contain x1*x2
    # and x1**2, which appear in rich_basis but not in linear_basis.
    propensity = expit(-0.2 + 0.8 * x1 - 0.7 * x2 + 0.6 * x1 * x2 - 0.4 * x1**2)
    d = rng.binomial(1, propensity, size=n).astype(float)
    y0 = 1.0 + 0.7 * x1 - 0.6 * x2 + 0.8 * x1 * x2 + 0.5 * x1**2 + rng.normal(scale=0.7, size=n)
    y = y0 + tau * d

    outcome_basis = rich_basis if outcome_correct else linear_basis
    pscore_basis = rich_basis if pscore_correct else linear_basis
    x_outcome = outcome_basis(covariates)
    x_pscore = pscore_basis(covariates)

    estimates = manual_estimators(y, d, x_outcome, x_pscore)

    # NOTE(review): cm.AIPW is fit on x_pscore only, so its outcome model
    # sees the propensity basis. Confirm that is intended when
    # outcome_correct and pscore_correct differ.
    native = cm.AIPW(penalty=np.logspace(-4, 2, 20), cv=2, n_folds=2, seed=seed)
    native.fit(y, d, x_pscore)
    native_ate = float(native.summary()["ate"])
    return np.array([tau, *estimates, native_ate])