pyfector
pyfector — Fast counterfactual estimators for panel data in Python.
A high-performance Python reimplementation of the R fect package,
featuring randomized SVD, GPU acceleration (CuPy), parallel computing
(joblib), Polars data ingestion, and seeded reproducibility.
Usage::

    import pyfector

    result = pyfector.fect(
        data=df,
        Y="outcome", D="treat",
        index=("unit", "year"),
        method="ife",
        r=(0, 5),
        se=True,
        seed=42,
    )
    result.summary()
    result.plot()
1""" 2pyfector — Fast counterfactual estimators for panel data in Python. 3 4A high-performance Python reimplementation of the R ``fect`` package, 5featuring randomized SVD, GPU acceleration (CuPy), parallel computing 6(joblib), Polars data ingestion, and seeded reproducibility. 7 8Usage:: 9 10 import pyfector 11 12 result = pyfector.fect( 13 data=df, 14 Y="outcome", D="treat", 15 index=("unit", "year"), 16 method="ife", 17 r=(0, 5), 18 se=True, 19 seed=42, 20 ) 21 result.summary() 22 result.plot() 23 24""" 25 26__version__ = "0.2.0" 27 28from .fect import fect, FectResult 29from .backend import set_device, get_device 30from .diagnostics import ( 31 run_diagnostics, 32 Diagnostics, 33 DiagnosticResult, 34 TostResult, 35 PretrendFResult, 36 EquivFResult, 37 PlaceboResult, 38 CarryoverResult, 39 LooResult, 40) 41from .plotting import plot 42 43__all__ = [ 44 "fect", 45 "FectResult", 46 "set_device", 47 "get_device", 48 "run_diagnostics", 49 "Diagnostics", 50 "DiagnosticResult", 51 "TostResult", 52 "PretrendFResult", 53 "EquivFResult", 54 "PlaceboResult", 55 "CarryoverResult", 56 "LooResult", 57 "plot", 58]
def fect(
    data,
    Y: str,
    D: str,
    index: tuple[str, str],
    X: list[str] | None = None,
    W: str | None = None,
    group: str | None = None,
    method: Literal["fe", "ife", "mc", "cfe", "both"] = "ife",
    force: Literal["none", "unit", "time", "two-way"] = "two-way",
    r: int | tuple[int, int] = 0,
    lam: float | None = None,
    nlambda: int = 10,
    lambda_candidates: list[float] | np.ndarray | None = None,
    CV: bool = True,
    k: int = 10,
    cv_prop: float = 0.1,
    cv_nobs: int = 3,
    cv_treat: bool = True,
    cv_donut: int = 0,
    criterion: str = "mspe",
    cv_rule: Literal["min", "onepct"] = "min",
    se: bool = False,
    vartype: Literal["bootstrap", "jackknife"] = "bootstrap",
    nboots: int = 200,
    alpha: float = 0.05,
    tol: float = 1e-7,
    max_iter: int = 5000,
    min_T0: int = 1,
    min_T0_strict: bool = False,
    max_missing: float = 1.0,
    normalize: bool = False,
    # CFE-specific
    Z: list[str] | None = None,
    Q: list[str] | None = None,
    # Performance
    device: Literal["cpu", "gpu"] = "cpu",
    n_jobs: int | None = -1,
    seed: int | None = None,
    # Diagnostics (optional; run at fit time and attached to result)
    diagnostics: Literal["none", "full"] | list[str] = "none",
    diagnostics_options: dict | None = None,
) -> FectResult:
    """Estimate counterfactual treatment effects for panel data.

    This is the main Python entry point for the counterfactual estimator
    workflow. Where the paper and the historical R package differ,
    pyfector defaults to the paper's statistical definition and exposes
    R-package-style behavior through explicit options.

    Missing outcome policy
    ----------------------
    pyfector distinguishes raw missing outcomes from counterfactual
    missingness caused by treatment. Observed untreated cells
    (``D == 0`` and non-missing ``Y``) fit the response surface. Observed
    treated cells (``D == 1`` and non-missing ``Y``) contribute to ATT as
    ``Y - Y_ct``. If a treated outcome is missing in the input data, the
    model can still produce a counterfactual ``Y_ct`` for that cell, but
    the cell is not counted in ``att_avg`` or ``att_on`` because the
    treated potential outcome was not observed.

    By default, ``min_T0`` is enforced only for treated and reversal
    units. Sparse controls are retained if they have at least one
    observed outcome, because they may still inform the low-rank response
    surface. Set ``min_T0_strict=True`` to require controls to satisfy
    ``min_T0`` too, matching the more conservative R fect sparse-panel
    behavior.

    Parameters
    ----------
    data : polars.DataFrame or pandas.DataFrame
        Long-format panel data.
    Y, D : str
        Column names for outcome and binary treatment indicator.
    index : (str, str)
        Column names for (unit_id, time_period).
    X : list of str, optional
        Time-varying covariates.
    W : str, optional
        Observation weight column.
    group : str, optional
        Reserved for grouped estimation. Currently raises
        ``NotImplementedError`` when supplied.
    method : {"fe", "ife", "mc", "cfe", "both"}
        Estimation method. ``"both"`` currently selects the rank and fits
        exactly like ``"ife"``; no separate MC fit is computed here.
    force : {"none", "unit", "time", "two-way"}
        Fixed effects specification.
    r : int or (int, int)
        Number of factors. If tuple, CV selects from range. If a tuple is
        given with ``CV=False``, its first element is used.
    lam : float, optional
        Nuclear norm penalty for MC. If None with CV=True, auto-selected.
        If None with CV=False, ``0.0`` is used.
    nlambda : int
        Number of automatically generated lambda candidates for MC CV.
    lambda_candidates : array-like, optional
        Explicit non-negative lambda candidates for MC CV. When supplied,
        ``nlambda`` is ignored.
    CV : bool
        If True, cross-validate over ``r`` for IFE when ``r`` is a tuple,
        or over ``lam`` for MC when ``lam`` is None.
    k : int
        Number of CV folds.
    cv_prop : float
        Fraction of eligible observed control cells masked per CV fold.
    cv_nobs : int
        Number of consecutive within-unit observations to mask as a block.
    cv_treat : bool
        If True, restrict CV masks to pre-treatment cells of ever-treated
        units. If False, use all observed control cells.
    cv_donut : int
        Exclude this many periods around treatment onset from CV evaluation.
    criterion : {"mspe", "gmspe", "mad"}
        Cross-validation loss.
    cv_rule : {"min", "onepct"}
        CV selection rule. ``"min"`` chooses the strict minimum-score
        candidate and is the paper-faithful default. ``"onepct"`` chooses
        the simplest candidate within 1% of the best score (lower ``r`` for
        IFE, higher ``lam`` for MC).
    se : bool
        Compute standard errors via bootstrap/jackknife.
    vartype : {"bootstrap", "jackknife"}
        Inference method when ``se=True``.
    nboots : int
        Number of bootstrap replications. Ignored for jackknife.
    alpha : float
        Significance level for confidence intervals and tests.
    tol : float
        EM convergence tolerance for final point estimation.
    max_iter : int
        Maximum EM iterations.
    min_T0 : int
        Minimum untreated/pre-treatment observed periods. By default this is
        enforced only for treated and treatment-reversal units.
    min_T0_strict : bool
        If True, enforce ``min_T0`` on all units, including controls. This
        matches R fect's conservative handling of sparse control rows.
    max_missing : float
        Maximum missing-outcome fraction per unit, in ``[0, 1]``. Units with
        no observed outcomes are always dropped, regardless of this threshold,
        because they provide neither fitting information nor observed treated
        effects.
    normalize : bool
        If True, estimate on an outcome standardized by its observed standard
        deviation, then transform effects back to the original scale.
    Z, Q : list of str, optional
        Reserved CFE interaction arguments. Currently raise
        ``NotImplementedError`` when supplied.
    device : {"cpu", "gpu"}
        Compute device. With ``"gpu"`` the worker count is forced to 1.
    n_jobs : int, optional
        Parallel workers for CV and bootstrap. ``-1`` or ``None`` uses
        all available CPUs.
    seed : int, optional
        Random seed for full reproducibility.
    diagnostics : {"none", "full"} or list of str
        Diagnostics to run at fit time; the outcome is attached to the
        result as ``result.diagnostics``. The request is validated
        (together with ``diagnostics_options`` and ``se``) before any
        estimation starts.
    diagnostics_options : dict, optional
        Keyword options forwarded to the diagnostics runner. The ``"loo"``
        entry is overridden to ``True`` when ``"loo"`` is among the
        requested diagnostics and defaults to ``False`` otherwise.

    Returns
    -------
    FectResult
        Point estimates, counterfactual/effect matrices, fit statistics and,
        when requested, inference and diagnostics.

    Raises
    ------
    NotImplementedError
        If ``group``, ``Z`` or ``Q`` is supplied.
    ValueError
        If ``criterion``, ``cv_rule``, ``min_T0``, ``max_missing`` or
        ``method`` has an unsupported value.
    KeyError
        If ``force`` is not one of the listed options.
    """
    # Set device
    set_device(device)
    xp = get_backend()
    n_jobs = _resolve_n_jobs(n_jobs)
    if device == "gpu":
        n_jobs = 1

    if group is not None:
        raise NotImplementedError("The `group` argument is not implemented yet.")
    if Z is not None or Q is not None:
        raise NotImplementedError("The `Z` and `Q` CFE interaction arguments are not implemented yet.")
    if criterion not in {"mspe", "gmspe", "mad"}:
        raise ValueError("criterion must be 'mspe', 'gmspe', or 'mad'")
    if cv_rule not in {"min", "onepct"}:
        raise ValueError("cv_rule must be 'min' or 'onepct'")
    if min_T0 < 0:
        raise ValueError("min_T0 must be non-negative")
    if not 0.0 <= max_missing <= 1.0:
        raise ValueError("max_missing must be between 0 and 1")
    # Reject an unknown method here instead of after panel preparation, the
    # initial fit and CV, so a typo fails immediately.
    if method not in ("fe", "ife", "mc", "cfe", "both"):
        raise ValueError(f"Unknown method: {method}")

    # Validate diagnostics request before doing any expensive estimation
    # so users don't wait for a 30-min MC fit only to find their config
    # is wrong.
    requested_diag = validate_diagnostics_request(
        diagnostics, diagnostics_options, se,
    )

    # Map force string to int
    force_map = {"none": 0, "unit": 1, "time": 2, "two-way": 3}
    force_int = force_map[force]

    # Prepare panel data
    panel = prepare_panel(
        data, Y=Y, D=D, index=index, X=X, W=W,
        group=group, min_T0=min_T0, min_T0_strict=min_T0_strict,
        max_missing=max_missing,
    )

    # Move to device
    Y_mat = to_device(panel.Y)
    D_mat = to_device(panel.D)
    I_mat = to_device(panel.I)
    II_mat = to_device(panel.II)
    X_mat = to_device(panel.X) if panel.X is not None else None
    W_mat = to_device(panel.W) if panel.W is not None else None

    # Normalize: scale by the standard deviation of the cells selected by
    # I_mat > 0; a zero standard deviation leaves the outcome untouched.
    norm_factor = 1.0
    if normalize:
        sd_y = float(xp.std(Y_mat[I_mat > 0]))
        if sd_y > 0:
            Y_mat = Y_mat / sd_y
            norm_factor = sd_y

    # Initial fit
    Y0, beta0 = initial_fit(Y_mat, X_mat, II_mat, force_int)

    # Determine r and lambda
    r_cv = None
    lambda_cv = None
    cv_result = None

    if method in ("ife", "both") and isinstance(r, tuple) and CV:
        # Cross-validate the number of factors over the requested range.
        cv_result = cv_ife(
            Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
            force=force_int, r_range=r, k=k, cv_prop=cv_prop,
            cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
            criterion=criterion, cv_rule=cv_rule,
            tol=tol, max_iter=max_iter,
            n_jobs=n_jobs, seed=seed,
        )
        r_cv = cv_result.best_r
    elif method in ("ife", "both", "cfe"):
        # Fixed rank: an int is used as-is, a tuple contributes its lower end.
        r_cv = r if isinstance(r, int) else r[0]
    elif method == "mc":
        if lam is None and CV:
            cv_result = cv_mc(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, lambda_candidates=lambda_candidates,
                nlambda=nlambda, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            lambda_cv = cv_result.best_lambda
        else:
            lambda_cv = lam if lam is not None else 0.0
    elif method == "fe":
        r_cv = 0

    # Point estimation. "both" is fitted with the IFE estimator only; no MC
    # fit is produced in this function.
    if method in ("fe", "ife", "both"):
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "mc":
        est = estimate_mc(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            lam=lambda_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    else:  # method == "cfe" (all other values were rejected above)
        est = estimate_cfe(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )

    # Compute effects
    eff = Y_mat - est.fit
    Y_ct = est.fit

    # Additive-FE baseline residual variance (Liu et al. 2024 sigma2.fect).
    # For method="fe" the main estimator IS the additive-FE pass, so reuse
    # est.sigma2. For ife/mc/cfe/both, run an extra r=0 IFE pass on the
    # same panel with the user's requested FE structure.
    if method == "fe":
        sigma2_fect_value = float(est.sigma2)
    else:
        est_fect = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=0, force=force_int, tol=tol, max_iter=max_iter,
        )
        sigma2_fect_value = float(est_fect.sigma2)

    # Denormalize
    # NOTE(review): est.sigma2, est.residuals and the fixed-effect terms are
    # handed to the result below without this rescaling — confirm that is
    # intended when normalize=True.
    if normalize and norm_factor != 1.0:
        eff = eff * norm_factor
        Y_ct = Y_ct * norm_factor
        Y_mat = Y_mat * norm_factor
        if est.beta is not None:
            est = est._replace(beta=est.beta * norm_factor)
        sigma2_fect_value *= norm_factor ** 2

    # ATT computation (done on host arrays)
    att_avg, att_on, time_on, count_on, att_avg_unit = _compute_effects(
        to_numpy(eff), to_numpy(D_mat), to_numpy(panel.T_on), to_numpy(I_mat),
    )

    # Build result
    result = FectResult(
        method=method,
        r_cv=r_cv,
        lambda_cv=lambda_cv,
        att_avg=att_avg,
        att_avg_unit=att_avg_unit,
        att_on=att_on,
        time_on=time_on,
        count_on=count_on,
        beta=to_numpy(est.beta) if est.beta is not None else None,
        covariate_names=panel.covariate_names,
        mu=est.mu,
        alpha=to_numpy(est.alpha) if est.alpha is not None else None,
        xi=to_numpy(est.xi) if est.xi is not None else None,
        factors=to_numpy(est.factors) if est.factors is not None else None,
        loadings=to_numpy(est.loadings) if est.loadings is not None else None,
        Y_ct=to_numpy(Y_ct),
        eff=to_numpy(eff),
        residuals=to_numpy(est.residuals),
        sigma2=est.sigma2,
        sigma2_fect=sigma2_fect_value,
        IC=est.IC,
        PC=est.PC,
        niter=est.niter,
        converged=est.converged,
        cv_result=cv_result,
        panel=panel,
        seed=seed,
    )

    # Inference
    if se:
        result.inference = _run_inference(
            result, panel, Y_mat, X_mat, W_mat, beta0, Y0,
            method=method, r_cv=r_cv, lambda_cv=lambda_cv,
            force_int=force_int, tol=tol, max_iter=max_iter,
            vartype=vartype, nboots=nboots, alpha=alpha,
            n_jobs=n_jobs, seed=seed, normalize=normalize,
            norm_factor=norm_factor,
        )

    # Run requested diagnostics at fit time. requested_diag is None when
    # diagnostics="none". Validation already enforced se=True and
    # required-config presence.
    if requested_diag is not None:
        opts = dict(diagnostics_options or {})
        if "loo" in requested_diag:
            opts["loo"] = True
        else:
            opts.setdefault("loo", False)
        result.diagnostics = _run_diagnostics(
            result, _requested=requested_diag, **opts,
        )

    return result
Estimate counterfactual treatment effects for panel data.
This is the main Python entry point for the counterfactual estimator workflow. Where the paper and the historical R package differ, pyfector defaults to the paper's statistical definition and exposes R-package-style behavior through explicit options.
Missing outcome policy
pyfector distinguishes raw missing outcomes from counterfactual
missingness caused by treatment. Observed untreated cells
(D == 0 and non-missing Y) fit the response surface. Observed
treated cells (D == 1 and non-missing Y) contribute to ATT as
Y - Y_ct. If a treated outcome is missing in the input data, the
model can still produce a counterfactual Y_ct for that cell, but
the cell is not counted in att_avg or att_on because the
treated potential outcome was not observed.
By default, min_T0 is enforced only for treated and reversal
units. Sparse controls are retained if they have at least one
observed outcome, because they may still inform the low-rank response
surface. Set min_T0_strict=True to require controls to satisfy
min_T0 too, matching the more conservative R fect sparse-panel
behavior.
Parameters
data : polars.DataFrame or pandas.DataFrame
Long-format panel data.
Y, D : str
Column names for outcome and binary treatment indicator.
index : (str, str)
Column names for (unit_id, time_period).
X : list of str, optional
Time-varying covariates.
W : str, optional
Observation weight column.
group : str, optional
Reserved for grouped estimation. Currently raises
NotImplementedError when supplied.
method : {"fe", "ife", "mc", "cfe", "both"}
Estimation method.
force : {"none", "unit", "time", "two-way"}
Fixed effects specification.
r : int or (int, int)
Number of factors. If tuple, CV selects from range.
lam : float, optional
Nuclear norm penalty for MC. If None with CV=True, auto-selected.
nlambda : int
Number of automatically generated lambda candidates for MC CV.
lambda_candidates : array-like, optional
Explicit non-negative lambda candidates for MC CV. When supplied,
nlambda is ignored.
CV : bool
If True, cross-validate over r for IFE when r is a tuple,
or over lam for MC when lam is None.
k : int
Number of CV folds.
cv_prop : float
Fraction of eligible observed control cells masked per CV fold.
cv_nobs : int
Number of consecutive within-unit observations to mask as a block.
cv_treat : bool
If True, restrict CV masks to pre-treatment cells of ever-treated
units. If False, use all observed control cells.
cv_donut : int
Exclude this many periods around treatment onset from CV evaluation.
criterion : {"mspe", "gmspe", "mad"}
Cross-validation loss.
cv_rule : {"min", "onepct"}
CV selection rule. "min" chooses the strict minimum-score
candidate and is the paper-faithful default. "onepct" chooses
the simplest candidate within 1% of the best score (lower r for
IFE, higher lam for MC).
se : bool
Compute standard errors via bootstrap/jackknife.
vartype : {"bootstrap", "jackknife"}
Inference method when se=True.
nboots : int
Number of bootstrap replications. Ignored for jackknife.
alpha : float
Significance level for confidence intervals and tests.
tol : float
EM convergence tolerance for final point estimation.
max_iter : int
Maximum EM iterations.
min_T0 : int
Minimum untreated/pre-treatment observed periods. By default this is
enforced only for treated and treatment-reversal units.
min_T0_strict : bool
If True, enforce min_T0 on all units, including controls. This
matches R fect's conservative handling of sparse control rows.
max_missing : float
Maximum missing-outcome fraction per unit, in [0, 1]. Units with
no observed outcomes are always dropped, regardless of this threshold,
because they provide neither fitting information nor observed treated
effects.
normalize : bool
If True, estimate on an outcome standardized by its observed standard
deviation, then transform effects back to the original scale.
Z, Q : list of str, optional
Reserved CFE interaction arguments. Currently raise
NotImplementedError when supplied.
device : {"cpu", "gpu"}
Compute device.
n_jobs : int, optional
Parallel workers for CV and bootstrap. -1 or None uses
all available CPUs.
seed : int, optional
Random seed for full reproducibility.
@dataclass
class FectResult:
    """Container for all fect estimation results.

    Fields are filled in by ``fect``; ``inference`` and ``diagnostics``
    stay ``None`` unless they were requested. ``summary()`` (also used as
    ``repr``) renders the contents as text, while ``plot()`` and
    ``diagnose()`` delegate to the package-level helpers.
    """
    # Method info
    method: str
    r_cv: int | None = None
    lambda_cv: float | None = None

    # Point estimates
    att_avg: float = 0.0
    att_avg_unit: float = 0.0

    # Dynamic effects
    att_on: np.ndarray | None = None
    time_on: np.ndarray | None = None
    count_on: np.ndarray | None = None

    # Exit effects (treatment reversal)
    att_off: np.ndarray | None = None
    time_off: np.ndarray | None = None

    # Coefficients
    beta: np.ndarray | None = None
    covariate_names: list[str] = field(default_factory=list)

    # Fixed effects
    mu: float = 0.0
    alpha: np.ndarray | None = None  # unit FE
    xi: np.ndarray | None = None  # time FE
    factors: np.ndarray | None = None
    loadings: np.ndarray | None = None

    # Counterfactual and effects matrices
    Y_ct: np.ndarray | None = None  # T×N counterfactual
    eff: np.ndarray | None = None  # T×N treatment effects
    residuals: np.ndarray | None = None

    # Model fit
    sigma2: float = 0.0
    sigma2_fect: float = 0.0  # additive-FE baseline residual variance
    IC: float = 0.0
    PC: float = 0.0
    rmse: float = 0.0
    niter: int = 0
    converged: bool = False

    # Inference
    inference: InferenceResult | None = None

    # Diagnostics (populated when fect(..., diagnostics="full" | list))
    diagnostics: Diagnostics | None = None

    # CV
    cv_result: CVResult | None = None

    # Panel metadata
    panel: PanelData | None = None

    # Reproducibility
    seed: int | None = None

    def summary(self) -> str:
        """Return a multi-line, human-readable summary of the results.

        The text covers the method and selected tuning values, convergence,
        residual variances, the average ATT (with SE, CI and p-value when
        inference is attached), covariate coefficients, the per-period
        dynamic effects, panel size, seed and any attached diagnostics.
        Nothing is printed; the caller decides what to do with the string.
        """
        rule = "=" * 60
        lines = [
            "pyfector estimation results",
            rule,
            f"Method: {self.method}",
        ]
        if self.r_cv is not None:
            lines.append(f"Number of factors (CV): {self.r_cv}")
        if self.lambda_cv is not None:
            lines.append(f"Lambda (CV): {self.lambda_cv:.6f}")
        lines.append(f"Converged: {self.converged} (iter={self.niter})")
        lines.append(f"Sigma^2: {self.sigma2:.6f}")
        lines.append(f"Sigma^2_fect (FE baseline): {self.sigma2_fect:.6f}")
        lines.append("")
        lines.append(f"ATT (average): {self.att_avg:.6f}")
        inf = self.inference
        if inf is not None:
            lines.append(f" SE: {inf.att_avg_se:.6f}")
            lines.append(f" CI: [{inf.att_avg_ci[0]:.6f}, {inf.att_avg_ci[1]:.6f}]")
            lines.append(f" p-val: {inf.att_avg_pval:.4f}")

        if self.beta is not None and len(self.beta) > 0:
            lines.append("")
            lines.append("Coefficients:")
            for name, coef in zip(self.covariate_names, self.beta):
                lines.append(f" {name}: {coef:.6f}")

        if self.att_on is not None and self.time_on is not None:
            lines.append("")
            lines.append("Dynamic effects (ATT by relative time):")
            lines.append(f" {'Time':>6s} {'ATT':>10s} {'Count':>6s}")
            # Per-period standard errors are optional on the inference
            # object; fall back to the SE-less row instead of failing.
            se_on = inf.att_on_se if inf is not None else None
            for i, t in enumerate(self.time_on):
                count = self.count_on[i] if self.count_on is not None else ""
                att = self.att_on[i]
                if se_on is not None:
                    lines.append(f" {t:>6.0f} {att:>10.4f} ({se_on[i]:.4f}) {count}")
                else:
                    lines.append(f" {t:>6.0f} {att:>10.4f} {count}")

        lines.append(rule)
        if self.panel is not None:
            lines.append(f"N={self.panel.N}, T={self.panel.T}")
        if self.seed is not None:
            lines.append(f"Seed: {self.seed}")
        if self.diagnostics is not None:
            lines.append("")
            lines.append(self.diagnostics.summary())
        return "\n".join(lines)

    def __repr__(self):
        # The interactive representation is the full summary text.
        return self.summary()

    def plot(self, kind="gap", **kwargs):
        """Plot results. Shortcut for ``pyfector.plot(self, kind, ...)``."""
        from .plotting import plot as _plot
        return _plot(self, kind=kind, **kwargs)

    def diagnose(self, **kwargs):
        """Run diagnostic tests. Shortcut for ``pyfector.run_diagnostics(self, ...)``."""
        from .diagnostics import run_diagnostics
        return run_diagnostics(self, **kwargs)
Container for all fect estimation results.
def summary(self) -> str:
    """Return a multi-line, human-readable summary of the results.

    The text covers the method and selected tuning values, convergence,
    residual variances, the average ATT (with SE, CI and p-value when
    inference is attached), covariate coefficients, the per-period
    dynamic effects, panel size, seed and any attached diagnostics.
    Nothing is printed; the caller decides what to do with the string.
    """
    rule = "=" * 60
    lines = [
        "pyfector estimation results",
        rule,
        f"Method: {self.method}",
    ]
    if self.r_cv is not None:
        lines.append(f"Number of factors (CV): {self.r_cv}")
    if self.lambda_cv is not None:
        lines.append(f"Lambda (CV): {self.lambda_cv:.6f}")
    lines.append(f"Converged: {self.converged} (iter={self.niter})")
    lines.append(f"Sigma^2: {self.sigma2:.6f}")
    lines.append(f"Sigma^2_fect (FE baseline): {self.sigma2_fect:.6f}")
    lines.append("")
    lines.append(f"ATT (average): {self.att_avg:.6f}")
    inf = self.inference
    if inf is not None:
        lines.append(f" SE: {inf.att_avg_se:.6f}")
        lines.append(f" CI: [{inf.att_avg_ci[0]:.6f}, {inf.att_avg_ci[1]:.6f}]")
        lines.append(f" p-val: {inf.att_avg_pval:.4f}")

    if self.beta is not None and len(self.beta) > 0:
        lines.append("")
        lines.append("Coefficients:")
        for name, coef in zip(self.covariate_names, self.beta):
            lines.append(f" {name}: {coef:.6f}")

    if self.att_on is not None and self.time_on is not None:
        lines.append("")
        lines.append("Dynamic effects (ATT by relative time):")
        lines.append(f" {'Time':>6s} {'ATT':>10s} {'Count':>6s}")
        # Per-period standard errors are optional on the inference object;
        # fall back to the SE-less row instead of failing on None.
        se_on = inf.att_on_se if inf is not None else None
        for i, t in enumerate(self.time_on):
            count = self.count_on[i] if self.count_on is not None else ""
            att = self.att_on[i]
            if se_on is not None:
                lines.append(f" {t:>6.0f} {att:>10.4f} ({se_on[i]:.4f}) {count}")
            else:
                lines.append(f" {t:>6.0f} {att:>10.4f} {count}")

    lines.append(rule)
    if self.panel is not None:
        lines.append(f"N={self.panel.N}, T={self.panel.T}")
    if self.seed is not None:
        lines.append(f"Seed: {self.seed}")
    if self.diagnostics is not None:
        lines.append("")
        lines.append(self.diagnostics.summary())
    return "\n".join(lines)
Print summary table of results.
def plot(self, kind="gap", **kwargs):
    """Draw this result; equivalent to ``pyfector.plot(self, kind, ...)``.

    ``kind`` and any extra keyword arguments are forwarded unchanged, and
    whatever the plotting helper returns is returned to the caller.
    """
    # Imported lazily, at call time, rather than at module import.
    from .plotting import plot as _render
    return _render(self, kind=kind, **kwargs)
Plot results. Shortcut for pyfector.plot(self, kind, ...).
def diagnose(self, **kwargs):
    """Run the diagnostic tests on this result.

    Equivalent to ``pyfector.run_diagnostics(self, ...)``; all keyword
    arguments are forwarded unchanged and its return value is passed back.
    """
    # Imported lazily, at call time, rather than at module import.
    from .diagnostics import run_diagnostics as _run
    return _run(self, **kwargs)
Run diagnostic tests. Shortcut for pyfector.run_diagnostics(self, ...).
def set_device(device: Literal["cpu", "gpu"]) -> None:
    """Set the compute device globally.

    Parameters
    ----------
    device : {"cpu", "gpu"}
        Device name stored in the module-level ``_DEVICE`` setting.

    Raises
    ------
    ValueError
        If ``device`` is neither ``"cpu"`` nor ``"gpu"``. Without this
        check a misspelled name (e.g. ``"GPU"``) would be stored silently.
    ImportError
        If ``"gpu"`` is requested but the CuPy check fails.
    """
    global _DEVICE
    if device not in ("cpu", "gpu"):
        raise ValueError(f"device must be 'cpu' or 'gpu', got {device!r}")
    if device == "gpu" and not _check_cupy():
        raise ImportError(
            "CuPy is required for GPU support. Install it with: "
            "pip install cupy-cuda12x (adjust for your CUDA version)"
        )
    # Only reached for a valid (and, for GPU, available) device, so a failed
    # call leaves the previous setting untouched.
    _DEVICE = device
Set the compute device globally.
Return the current device.
def run_diagnostics(
    result,
    f_threshold: float = 0.5,
    tost_threshold: float = 0.36,
    placebo_period: tuple[int, int] | None = None,
    carryover_period: tuple[int, int] | None = None,
    loo: bool = False,
    alpha: float = 0.05,
    *,
    _requested: list[str] | None = None,
) -> Diagnostics:
    """Run diagnostic tests on a FectResult.

    An empty ``Diagnostics`` object is returned immediately when the result
    has no inference, no dynamic effects, or no pre-treatment event times
    (``time_on < 0``); in that case none of the tests below are attempted.

    Parameters
    ----------
    result : FectResult
        Must have inference results (``se=True``).
    f_threshold : float
        Non-centrality parameter for equivalence F-test.
    tost_threshold : float
        Equivalence bound for TOST. The literal ``0.36`` (default or
        explicit) triggers Liu et al. (2024)'s scale-aware bound,
        ``0.36 * sqrt(result.sigma2_fect)``. Any other positive float
        is taken as an absolute outcome-scale bound. ``None`` is
        invalid.
    placebo_period : (start, end), optional
        Relative time window for placebo test. Restricted to
        pre-treatment event times: ``(time_on >= start) &
        (time_on <= end) & (time_on < 0)``. So ``(-3, 0)`` selects
        rel_time ∈ {-3, -2, -1}.
    carryover_period : (start, end), optional
        Relative time window for carryover test.
    loo : bool
        If True, run leave-one-out post-period sensitivity.
    alpha : float
        Significance cutoff used for ``TostResult.all_pass``.
    _requested : list of str, optional
        Internal: names of the diagnostics to compute. ``None`` means every
        name in ``_VALID_DIAG_NAMES``.

    Returns
    -------
    Diagnostics
        Object whose fields are set only for the tests that were requested
        and could be computed; ``options`` records the settings used.
    """
    from scipy import stats

    diag = Diagnostics()

    if result.inference is None:
        return diag

    inf = result.inference
    time_on = result.time_on
    att_on = result.att_on

    if time_on is None or att_on is None:
        return diag

    pre_mask = time_on < 0
    if not np.any(pre_mask):
        return diag

    pre_idx = np.where(pre_mask)[0]
    k = len(pre_idx)
    att_pre = att_on[pre_idx]

    requested = set(_requested) if _requested is not None else _VALID_DIAG_NAMES.copy()
    # The absolute equivalence bound is resolved only when a test needs it.
    needs_tost_threshold = bool({"tost", "placebo"} & requested)
    threshold_abs = float("nan")
    threshold_source: str | None = None
    sigma2_fect_used: float | None = None
    if needs_tost_threshold:
        threshold_abs, threshold_source, sigma2_fect_used = _resolve_tost_threshold(
            result, tost_threshold
        )

    # Pre-trend / equivalence F-tests share the bootstrap covariance pass.
    boot_pre = None
    n_boot = 0
    if (
        ("pretrend_f" in requested or "equiv_f" in requested)
        and inf.att_on_boot is not None
    ):
        boot_pre_all = inf.att_on_boot[pre_idx, :]
        # Keep only replications that are finite in every pre-period.
        valid_boot = np.all(np.isfinite(boot_pre_all), axis=0)
        boot_pre = boot_pre_all[:, valid_boot]
        n_boot = boot_pre.shape[1]
        # More valid replications than pre-periods are required; otherwise
        # the F-tests are skipped (the denominator df would be <= 0).
        if n_boot <= k:
            boot_pre = None

    if boot_pre is not None:
        S = np.cov(boot_pre)
        if k == 1:
            # np.cov returns a scalar-like value for a single row.
            S = np.asarray(S).reshape(1, 1)

        try:
            cond = np.linalg.cond(S)
            # Fall back to the pseudo-inverse for badly conditioned S.
            S_inv = np.linalg.pinv(S) if cond > 1e12 else np.linalg.inv(S)
            F_raw = float(att_pre @ S_inv @ att_pre)
            scale = (n_boot - k) / ((n_boot - 1) * k)
            F_stat = F_raw * scale

            if np.isfinite(F_stat) and F_stat >= 0:
                if "pretrend_f" in requested:
                    diag.pretrend_f = PretrendFResult(
                        f_stat=F_stat,
                        # sf() is the upper tail; it keeps precision where
                        # 1 - cdf() would round tiny p-values to zero.
                        p_value=float(stats.f.sf(F_stat, k, n_boot - k)),
                        df1=k,
                        df2=n_boot - k,
                    )
                if "equiv_f" in requested:
                    ncp = n_boot * f_threshold
                    diag.equiv_f = EquivFResult(
                        p_value=float(stats.ncf.cdf(F_stat, k, n_boot - k, ncp)),
                        f_threshold=f_threshold,
                    )
        except np.linalg.LinAlgError:
            # Deliberate best-effort: a failed inversion leaves the F-test
            # fields unset.
            pass

    # Per-period TOST.
    if "tost" in requested and inf.att_on_se is not None:
        se_pre = inf.att_on_se[pre_idx]

        tost_pvals = np.full(k, np.nan)
        for i in range(k):
            if se_pre[i] > 0:
                if inf.att_on_boot is not None:
                    n_boot_i = int(np.isfinite(inf.att_on_boot[pre_idx[i], :]).sum())
                else:
                    # NOTE(review): hard-coded fallback replication count;
                    # presumably mirrors fect's default nboots=200 — confirm.
                    n_boot_i = 200
                df = max(n_boot_i - 1, 1)
                t_upper = (att_pre[i] - threshold_abs) / se_pre[i]
                t_lower = (att_pre[i] + threshold_abs) / se_pre[i]
                p_upper = float(stats.t.cdf(t_upper, df))
                p_lower = float(stats.t.sf(t_lower, df))
                tost_pvals[i] = max(p_upper, p_lower)

        finite = tost_pvals[np.isfinite(tost_pvals)]
        max_p = float(finite.max()) if finite.size else float("nan")
        all_pass = bool(finite.size and (finite < alpha).all())

        diag.tost = TostResult(
            pvals=tost_pvals,
            periods=time_on[pre_idx].copy(),
            threshold=threshold_abs,
            threshold_source=threshold_source,
            sigma2_fect=sigma2_fect_used,
            max_pval=max_p,
            all_pass=all_pass,
        )

    # Placebo: bootstrap percentile p + TOST equivalence p.
    if "placebo" in requested and placebo_period is not None:
        p_start, p_end = placebo_period
        placebo_mask = (time_on >= p_start) & (time_on <= p_end) & pre_mask
        if np.any(placebo_mask):
            p_idx = np.where(placebo_mask)[0]
            estimate = float(np.mean(att_on[p_idx]))
            p_value = float("nan")
            equiv_p = float("nan")
            se = float("nan")
            if inf.att_on_boot is not None:
                boot_window = inf.att_on_boot[p_idx, :]
                valid_boot = np.all(np.isfinite(boot_window), axis=0)
                boot_placebo = np.mean(boot_window[:, valid_boot], axis=0)
                # Two-sided percentile p-value, capped at 1.
                p_value = min(
                    2 * float(np.mean(boot_placebo >= 0)),
                    2 * float(np.mean(boot_placebo <= 0)),
                    1.0,
                ) if boot_placebo.size else float("nan")
                if boot_placebo.size > 1:
                    se = float(np.std(boot_placebo, ddof=1))
                    if se > 0:
                        df_p = boot_placebo.size - 1
                        t_up = (estimate - threshold_abs) / se
                        t_lo = (estimate + threshold_abs) / se
                        equiv_p = max(
                            float(stats.t.cdf(t_up, df_p)),
                            float(stats.t.sf(t_lo, df_p)),
                        )
            diag.placebo = PlaceboResult(
                estimate=estimate,
                se=se,
                p_value=p_value,
                equiv_p_value=equiv_p,
                period=(int(p_start), int(p_end)),
            )

    # Carryover (estimate only; bootstrap p deferred).
    att_off = getattr(result, "att_off", None)
    time_off = getattr(result, "time_off", None)
    if (
        "carryover" in requested
        and carryover_period is not None
        and att_off is not None
        and time_off is not None
    ):
        c_start, c_end = carryover_period
        off_mask = (time_off >= c_start) & (time_off <= c_end)
        if np.any(off_mask):
            diag.carryover = CarryoverResult(
                estimate=float(np.mean(att_off[off_mask])),
                period=(int(c_start), int(c_end)),
            )

    # Leave-one-out post-period.
    if "loo" in requested and loo:
        post_idx = np.where(time_on >= 0)[0]
        if len(post_idx) > 1:
            # NOTE(review): each leave-one-out value is an unweighted mean of
            # per-period ATTs, compared against result.att_avg — confirm the
            # two are on the same weighting.
            full_att = result.att_avg
            loo_atts: list[float] = []
            loo_periods: list[float] = []
            for drop_i in post_idx:
                # At least one period always remains: len(post_idx) > 1.
                remaining = post_idx[post_idx != drop_i]
                loo_atts.append(float(np.mean(att_on[remaining])))
                loo_periods.append(float(time_on[drop_i]))
            atts_arr = np.array(loo_atts)
            diag.loo = LooResult(
                atts=atts_arr,
                periods=np.array(loo_periods),
                max_change=float(np.max(np.abs(atts_arr - full_att))),
            )

    diag.options = {
        "tost_threshold": threshold_abs if needs_tost_threshold else None,
        "tost_threshold_source": threshold_source,
        "f_threshold": f_threshold,
        "alpha": alpha,
        "placebo_period": placebo_period,
        "carryover_period": carryover_period,
        "loo": loo,
        "sigma2_fect": sigma2_fect_used if sigma2_fect_used is not None
        else getattr(result, "sigma2_fect", None),
    }
    try:
        from . import __version__ as _ver
        diag.options["pyfector_version"] = _ver
    except Exception:
        # Best-effort metadata only: a missing version must never make the
        # diagnostics themselves fail.
        pass

    return diag
Run diagnostic tests on a FectResult.
Parameters
result : FectResult
Must have inference results (se=True).
f_threshold : float
Non-centrality parameter for equivalence F-test.
tost_threshold : float
Equivalence bound for TOST. The literal 0.36 (default or
explicit) triggers Liu et al. (2024)'s scale-aware bound,
0.36 * sqrt(result.sigma2_fect). Any other positive float
is taken as an absolute outcome-scale bound. None is
invalid.
placebo_period : (start, end), optional
Relative time window for placebo test. Restricted to
pre-treatment event times: (time_on >= start) &
(time_on <= end) & (time_on < 0). So (-3, 0) selects
rel_time ∈ {-3, -2, -1}.
carryover_period : (start, end), optional
Relative time window for carryover test.
loo : bool
If True, run leave-one-out post-period sensitivity.
alpha : float
Significance cutoff used for TostResult.all_pass.
@dataclass
class Diagnostics:
    """Slim-safe registry of diagnostic test results.

    ``tests`` is the future-proof extension point. Built-in tests expose
    convenience properties for stable user code, but the container itself
    does not need a new field every time pyfector adds a diagnostic.

    The container also supports ``name in diag`` and ``diag[name]`` as
    thin wrappers around the ``tests`` mapping.
    """
    # Free-form settings dictionary stored alongside the results.
    options: dict = field(default_factory=dict)
    # Mapping of diagnostic name -> result object.
    tests: dict[str, Any] = field(default_factory=dict)

    @property
    def available(self) -> tuple[str, ...]:
        """Names of populated diagnostics, with built-ins first.

        Names listed in ``_DIAGNOSTIC_SUMMARY_ORDER`` keep that order;
        every other registered name follows in sorted order.
        """
        known = [name for name in _DIAGNOSTIC_SUMMARY_ORDER if name in self.tests]
        # Build the lookup set once rather than once per registered test.
        known_names = set(known)
        extra = sorted(name for name in self.tests if name not in known_names)
        return tuple(known + extra)

    def get(self, name: str, default: Any = None) -> Any:
        """Return a diagnostic by name, or ``default`` if absent."""
        return self.tests.get(name, default)

    def set_test(self, name: str, value: Any | None) -> None:
        """Set or remove a diagnostic result by name.

        Passing ``value=None`` removes the entry (a missing entry is not
        an error). Raises ``ValueError`` when ``name`` is empty or not a
        string.
        """
        if not name or not isinstance(name, str):
            raise ValueError("diagnostic name must be a non-empty string")
        if value is None:
            self.tests.pop(name, None)
        else:
            self.tests[name] = value

    def set(self, name: str, value: Any | None) -> None:
        """Compatibility shortcut for :meth:`set_test`."""
        self.set_test(name, value)

    def __contains__(self, name: str) -> bool:
        return name in self.tests

    def __getitem__(self, name: str) -> Any:
        return self.tests[name]

    def _typed(self, name: str, typ: type) -> Any | None:
        # Only hand back the stored value when it has the expected type;
        # anything else registered under the same name reads as absent.
        value = self.tests.get(name)
        return value if isinstance(value, typ) else None

    @property
    def tost(self) -> TostResult | None:
        return self._typed("tost", TostResult)

    @tost.setter
    def tost(self, value: TostResult | None) -> None:
        self.set_test("tost", value)

    @property
    def pretrend_f(self) -> PretrendFResult | None:
        return self._typed("pretrend_f", PretrendFResult)

    @pretrend_f.setter
    def pretrend_f(self, value: PretrendFResult | None) -> None:
        self.set_test("pretrend_f", value)

    @property
    def equiv_f(self) -> EquivFResult | None:
        return self._typed("equiv_f", EquivFResult)

    @equiv_f.setter
    def equiv_f(self, value: EquivFResult | None) -> None:
        self.set_test("equiv_f", value)

    @property
    def placebo(self) -> PlaceboResult | None:
        return self._typed("placebo", PlaceboResult)

    @placebo.setter
    def placebo(self, value: PlaceboResult | None) -> None:
        self.set_test("placebo", value)

    @property
    def carryover(self) -> CarryoverResult | None:
        return self._typed("carryover", CarryoverResult)

    @carryover.setter
    def carryover(self, value: CarryoverResult | None) -> None:
        self.set_test("carryover", value)

    @property
    def loo(self) -> LooResult | None:
        return self._typed("loo", LooResult)

    @loo.setter
    def loo(self, value: LooResult | None) -> None:
        self.set_test("loo", value)

    def summary(self) -> str:
        """Return a text report covering every entry in ``available``.

        Results exposing a ``summary`` method contribute its output;
        anything else is rendered as ``name: repr(value)``.
        """
        lines = ["Diagnostic Tests", "=" * 50]
        for name in self.available:
            sub = self.tests[name]
            if sub is not None:
                if hasattr(sub, "summary"):
                    lines.append(sub.summary())
                else:
                    lines.append(f"{name}: {sub!r}")
        return "\n".join(lines)
Slim-safe registry of diagnostic test results.
tests is the future-proof extension point. Built-in tests expose
convenience properties for stable user code, but the container itself
does not need a new field every time pyfector adds a diagnostic.
@property
def available(self) -> tuple[str, ...]:
    """Names of populated diagnostics, with built-ins first.

    Names listed in ``_DIAGNOSTIC_SUMMARY_ORDER`` keep that order; every
    other registered name follows in sorted order.
    """
    known = [name for name in _DIAGNOSTIC_SUMMARY_ORDER if name in self.tests]
    # Build the lookup set once rather than once per registered test.
    known_names = set(known)
    extra = sorted(name for name in self.tests if name not in known_names)
    return tuple(known + extra)
Names of populated diagnostics, with built-ins first.
173 def get(self, name: str, default: Any = None) -> Any: 174 """Return a diagnostic by name, or ``default`` if absent.""" 175 return self.tests.get(name, default)
Return a diagnostic by name, or default if absent.
177 def set_test(self, name: str, value: Any | None) -> None: 178 """Set or remove a diagnostic result by name.""" 179 if not name or not isinstance(name, str): 180 raise ValueError("diagnostic name must be a non-empty string") 181 if value is None: 182 self.tests.pop(name, None) 183 else: 184 self.tests[name] = value
Set or remove a diagnostic result by name.
186 def set(self, name: str, value: Any | None) -> None: 187 """Compatibility shortcut for :meth:`set_test`.""" 188 self.set_test(name, value)
Compatibility shortcut for set_test().
51@dataclass(frozen=True) 52class TostResult: 53 """Per-period TOST equivalence test on pre-treatment coefficients.""" 54 pvals: np.ndarray 55 periods: np.ndarray 56 threshold: float 57 threshold_source: str 58 sigma2_fect: float | None 59 max_pval: float 60 all_pass: bool 61 62 def summary(self) -> str: 63 lines = [ 64 f"TOST per pre-period (threshold={self.threshold:.4f}, " 65 f"source={self.threshold_source}):" 66 ] 67 for i, t in enumerate(self.periods): 68 p = float(self.pvals[i]) 69 status = "PASS" if p < 0.05 else "fail" 70 lines.append(f" t={t:+.0f}: p={p:.4f} [{status}]") 71 lines.append( 72 f" max p={self.max_pval:.4f} all_pass={self.all_pass}" 73 ) 74 return "\n".join(lines)
Per-period TOST equivalence test on pre-treatment coefficients.
62 def summary(self) -> str: 63 lines = [ 64 f"TOST per pre-period (threshold={self.threshold:.4f}, " 65 f"source={self.threshold_source}):" 66 ] 67 for i, t in enumerate(self.periods): 68 p = float(self.pvals[i]) 69 status = "PASS" if p < 0.05 else "fail" 70 lines.append(f" t={t:+.0f}: p={p:.4f} [{status}]") 71 lines.append( 72 f" max p={self.max_pval:.4f} all_pass={self.all_pass}" 73 ) 74 return "\n".join(lines)
@dataclass(frozen=True)
class PretrendFResult:
    """Joint F-test of the hypothesis that every pre-treatment ATT is zero."""
    f_stat: float
    p_value: float
    # Degrees of freedom, printed as F(df1,df2) in the report.
    df1: int
    df2: int

    def summary(self) -> str:
        """Return a two-line report: a header, then the F statistic and p-value."""
        header = "Pre-trend F-test:"
        stat_part = f" F({self.df1},{self.df2}) = {self.f_stat:.4f}, "
        p_part = f"p = {self.p_value:.4f}"
        return "\n".join([header, stat_part + p_part])
Joint F-test for all pre-treatment ATTs equal to zero.
@dataclass(frozen=True)
class EquivFResult:
    """Result of the equivalence F-test (non-central F) on the pre-trend bound."""
    p_value: float
    f_threshold: float

    def summary(self) -> str:
        """Return a two-line report: a header, then the p-value and threshold."""
        header = "Equivalence F-test:"
        detail = f" p = {self.p_value:.4f} (f_threshold={self.f_threshold})"
        return "\n".join([header, detail])
Equivalence F-test (non-central F) for pre-trend bound.
@dataclass(frozen=True)
class PlaceboResult:
    """Mean ATT over the placebo window with its bootstrap and TOST p-values."""
    estimate: float
    se: float
    p_value: float
    equiv_p_value: float
    # (start, end) of the placebo window.
    period: tuple[int, int]

    def summary(self) -> str:
        """Return a two-line report: the window, then estimate, SE and p-values."""
        header = f"Placebo test (window {self.period}):"
        estimate_part = f" ATT={self.estimate:.4f} (SE={self.se:.4f}) "
        pvalue_part = f"p={self.p_value:.4f} equiv_p={self.equiv_p_value:.4f}"
        return "\n".join([header, estimate_part + pvalue_part])
Placebo window mean ATT plus bootstrap percentile and TOST p-values.
@dataclass(frozen=True)
class CarryoverResult:
    """Mean ATT over the carryover (post-reversal) window; no SE or p-value yet."""
    estimate: float
    # (start, end) of the carryover window.
    period: tuple[int, int]

    def summary(self) -> str:
        """Return a two-line report: the window, then the estimate."""
        header = f"Carryover test (window {self.period}):"
        detail = f" ATT={self.estimate:.4f}"
        return "\n".join([header, detail])
Carryover-window mean ATT (post-reversal). SE/p deferred.
@dataclass(frozen=True)
class LooResult:
    """Sensitivity of the post-period ATT to leaving one period out."""
    atts: np.ndarray
    periods: np.ndarray
    max_change: float

    def summary(self) -> str:
        """Return a two-line report: a header, then the largest ATT change."""
        header = "Leave-one-out:"
        detail = f" Max ATT change = {self.max_change:.6f}"
        return "\n".join([header, detail])
Leave-one-out post-period sensitivity.
20def plot( 21 result, 22 kind: Literal["gap", "status", "factors", "counterfactual", "equiv", "calendar"] = "gap", 23 units: list | None = None, 24 show_ci: bool = True, 25 title: str | None = None, 26 figsize: tuple[float, float] = (10, 6), 27 ax=None, 28 **kwargs, 29): 30 """Plot fect results. 31 32 Parameters 33 ---------- 34 result : FectResult 35 Output from ``pyfector.fect()``. 36 kind : str 37 Plot type. 38 units : list, optional 39 Unit IDs for counterfactual plot. 40 show_ci : bool 41 Show confidence intervals (requires ``se=True`` in estimation). 42 """ 43 import matplotlib.pyplot as plt 44 45 if ax is None: 46 fig, ax = plt.subplots(figsize=figsize) 47 else: 48 fig = ax.get_figure() 49 50 if kind == "gap": 51 _plot_gap(result, ax, show_ci, **kwargs) 52 elif kind == "status": 53 _plot_status(result, ax, **kwargs) 54 elif kind == "factors": 55 _plot_factors(result, ax, **kwargs) 56 elif kind == "counterfactual": 57 _plot_counterfactual(result, ax, units, **kwargs) 58 elif kind == "equiv": 59 _plot_equiv(result, ax, **kwargs) 60 elif kind == "calendar": 61 _plot_calendar(result, ax, show_ci, **kwargs) 62 else: 63 raise ValueError(f"Unknown plot kind: {kind}") 64 65 if title: 66 ax.set_title(title) 67 68 fig.tight_layout() 69 return fig
Plot fect results.
Parameters
result : FectResult
Output from pyfector.fect.
kind : str
Plot type.
units : list, optional
Unit IDs for counterfactual plot.
show_ci : bool
Show confidence intervals (requires se=True in estimation).