pyfector.fect
Public entry point for pyfector.

This module exposes fect(), the single function most users call,
and the FectResult dataclass that packages its output.

fect() orchestrates the full pipeline:

1. pyfector.panel.prepare_panel() converts the input DataFrame to dense (T, N) matrices with treatment timing and unit classification.
2. pyfector.panel.initial_fit() computes the starting fit Y0 and initial covariate coefficients on the control subsample.
3. If the user passes a range for r (IFE) or leaves lam=None (MC), and CV=True (the default), pyfector.cv chooses the hyperparameter by cross-validation.
4. pyfector.estimators iterates the EM loop for the chosen method (fe, ife, mc, or cfe).
5. Treatment effects are computed from eff = Y - Y_ct; the overall ATT and dynamic ATT by relative event time are derived in _compute_effects() via numpy.bincount(). Raw missing outcomes remain excluded from effect averages; model imputations are counterfactual predictions, not substitutes for unobserved treated outcomes.
6. When se=True, the requested inference routine (bootstrap or jackknife) from pyfector.inference is called with a closure that re-runs the EM loop on resampled units.
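The relative-time averaging in step 5 can be illustrated with a small standalone sketch. It assumes dense (T, N) NumPy arrays like those _compute_effects() receives: effects eff, a treatment indicator D, relative event time T_on (NaN where undefined), and an observation mask I that is 0 for raw missing outcomes. It also assumes, as in R fect, that the first treated period has relative time 1. The function name dynamic_att is hypothetical; this is a simplified illustration, not the pyfector implementation.

```python
import numpy as np


def dynamic_att(eff, D, T_on, I):
    """Average eff by relative event time over observed cells of ever-treated units.

    Hypothetical helper mirroring the bincount approach; all inputs are (T, N).
    """
    ever_treated = np.any(D > 0, axis=0)            # (N,) units treated at least once
    cells = (I > 0) & ever_treated[np.newaxis, :]   # observed cells of those units only
    rel = T_on[cells]
    vals = eff[cells]
    keep = ~np.isnan(rel)                           # drop cells with no defined event time
    rel = rel[keep].astype(np.int64)
    vals = vals[keep]
    if rel.size == 0:
        return np.array([]), np.array([]), np.array([], dtype=np.int64)
    offset = rel.min()                              # shift so bincount sees indices >= 0
    idx = rel - offset
    sums = np.bincount(idx, weights=vals)           # per-relative-time sum of effects
    counts = np.bincount(idx)                       # per-relative-time number of cells
    present = counts > 0                            # skip relative times with no cells
    times = np.arange(counts.size)[present] + offset
    return times, sums[present] / counts[present], counts[present]


# Two periods, three units: unit 0 is never treated, and unit 2 has a
# missing treated outcome at the second period (I == 0), so its 9.9 is ignored.
eff = np.array([[0.1, -0.2, 0.0], [0.5, 2.0, 9.9]])
D = np.array([[0, 0, 0], [0, 1, 1]])
T_on = np.array([[np.nan, 0.0, 0.0], [np.nan, 1.0, 1.0]])
I = np.array([[1, 1, 1], [1, 1, 0]])
times, att, counts = dynamic_att(eff, D, T_on, I)
# times -> [0, 1], att -> [-0.1, 2.0], counts -> [2, 1]
```

Using one weighted and one unweighted bincount keeps the grouping vectorized; the only Python-level work is building the masks.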
Example
::

    import pyfector

    result = pyfector.fect(
        data=df,
        Y="outcome", D="treat",
        index=("unit", "year"),
        X=["gdp", "pop"],
        method="ife",
        r=(0, 5),          # CV over 0..5 factors
        se=True,
        nboots=200,
        device="cpu",      # or "gpu"
        n_jobs=4,
        seed=42,
    )
    print(result.summary())
    result.plot(kind="gap")
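The example's r=(0, 5) asks cross-validation to choose the number of factors. fect() documents two selection rules: cv_rule="min" takes the strict minimum-score candidate, and cv_rule="onepct" takes the simplest candidate within 1% of the best score (lower r for IFE, higher lam for MC). The sketch below illustrates both rules on plain lists of scores. The helper select_candidate, the reading of "within 1%" as score <= 1.01 * best, and the scores themselves are illustrative assumptions, not pyfector.cv's actual code or output.

```python
def select_candidate(candidates, scores, rule="min", simpler_is_lower=True):
    """Pick a hyperparameter from CV scores given as lists (lower score is better).

    Hypothetical helper. simpler_is_lower=True suits IFE (fewer factors is
    simpler); pass False for MC, where a larger lambda is simpler.
    Assumes non-negative scores such as MSPE.
    """
    if not candidates or len(candidates) != len(scores):
        raise ValueError("need one score per candidate")
    best = min(scores)
    if rule == "min":
        # Strict minimum; ties go to the first candidate listed.
        return candidates[scores.index(best)]
    if rule == "onepct":
        # Every candidate whose score is within 1% of the best score.
        close = [c for c, s in zip(candidates, scores) if s <= best * 1.01]
        return min(close) if simpler_is_lower else max(close)
    raise ValueError("rule must be 'min' or 'onepct'")


r_grid = [0, 1, 2, 3, 4, 5]
mspe = [2.40, 1.31, 1.005, 1.00, 1.02, 1.03]  # made-up scores
print(select_candidate(r_grid, mspe, "min"))     # 3
print(select_candidate(r_grid, mspe, "onepct"))  # 2
```

The "onepct" rule trades a negligible loss in CV score for a more parsimonious model, which is why it falls back to the lowest qualifying r here.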
"""
Public entry point for pyfector.

This module exposes :func:`fect`, the single function most users call,
and the :class:`FectResult` dataclass that packages its output.

:func:`fect` orchestrates the full pipeline:

1. :func:`pyfector.panel.prepare_panel` converts the input DataFrame to
   dense ``(T, N)`` matrices with treatment timing and unit
   classification.
2. :func:`pyfector.panel.initial_fit` computes the starting fit ``Y0``
   and initial covariate coefficients on the control subsample.
3. If the user passes a range for ``r`` (IFE) or leaves ``lam=None``
   (MC), and ``CV=True`` (the default), :mod:`pyfector.cv` chooses the
   hyperparameter by cross-validation.
4. :mod:`pyfector.estimators` iterates the EM loop for the chosen
   method (``fe``, ``ife``, ``mc``, or ``cfe``).
5. Treatment effects are computed from ``eff = Y - Y_ct``; the overall
   ATT and dynamic ATT by relative event time are derived in
   :func:`_compute_effects` via :func:`numpy.bincount`. Raw missing
   outcomes remain excluded from effect averages; model imputations are
   counterfactual predictions, not substitutes for unobserved treated
   outcomes.
6. When ``se=True``, the requested inference routine
   (``bootstrap`` or ``jackknife``) from :mod:`pyfector.inference` is
   called with a closure that re-runs the EM loop on resampled units.

Example
-------
::

    import pyfector

    result = pyfector.fect(
        data=df,
        Y="outcome", D="treat",
        index=("unit", "year"),
        X=["gdp", "pop"],
        method="ife",
        r=(0, 5),          # CV over 0..5 factors
        se=True,
        nboots=200,
        device="cpu",      # or "gpu"
        n_jobs=4,
        seed=42,
    )
    print(result.summary())
    result.plot(kind="gap")
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal
import os

import numpy as np

from .backend import set_device, get_backend, to_numpy, to_device, make_rng
from .panel import PanelData, prepare_panel, initial_fit
from .estimators import estimate_ife, estimate_mc, estimate_cfe, EstimationResult
from .cv import cv_ife, cv_mc, CVResult
from .inference import bootstrap, jackknife, InferenceResult


@dataclass
class FectResult:
    """Container for all fect estimation results."""
    # Method info
    method: str
    r_cv: int | None = None
    lambda_cv: float | None = None

    # Point estimates
    att_avg: float = 0.0
    att_avg_unit: float = 0.0

    # Dynamic effects
    att_on: np.ndarray | None = None
    time_on: np.ndarray | None = None
    count_on: np.ndarray | None = None

    # Exit effects (treatment reversal)
    att_off: np.ndarray | None = None
    time_off: np.ndarray | None = None

    # Coefficients
    beta: np.ndarray | None = None
    covariate_names: list[str] = field(default_factory=list)

    # Fixed effects
    mu: float = 0.0
    alpha: np.ndarray | None = None    # unit FE
    xi: np.ndarray | None = None       # time FE
    factors: np.ndarray | None = None
    loadings: np.ndarray | None = None

    # Counterfactual and effects matrices
    Y_ct: np.ndarray | None = None     # T×N counterfactual
    eff: np.ndarray | None = None      # T×N treatment effects
    residuals: np.ndarray | None = None

    # Model fit
    sigma2: float = 0.0
    IC: float = 0.0
    PC: float = 0.0
    rmse: float = 0.0
    niter: int = 0
    converged: bool = False

    # Inference
    inference: InferenceResult | None = None

    # CV
    cv_result: CVResult | None = None

    # Panel metadata
    panel: PanelData | None = None

    # Reproducibility
    seed: int | None = None

    def summary(self) -> str:
        """Return the summary table of results as a string."""
        lines = []
        lines.append(f"pyfector estimation results")
        lines.append(f"{'='*60}")
        lines.append(f"Method: {self.method}")
        if self.r_cv is not None:
            lines.append(f"Number of factors (CV): {self.r_cv}")
        if self.lambda_cv is not None:
            lines.append(f"Lambda (CV): {self.lambda_cv:.6f}")
        lines.append(f"Converged: {self.converged} (iter={self.niter})")
        lines.append(f"Sigma^2: {self.sigma2:.6f}")
        lines.append(f"")
        lines.append(f"ATT (average): {self.att_avg:.6f}")
        if self.inference is not None:
            inf = self.inference
            lines.append(f"  SE:    {inf.att_avg_se:.6f}")
            lines.append(f"  CI:    [{inf.att_avg_ci[0]:.6f}, {inf.att_avg_ci[1]:.6f}]")
            lines.append(f"  p-val: {inf.att_avg_pval:.4f}")

        if self.beta is not None and len(self.beta) > 0:
            lines.append(f"")
            lines.append(f"Coefficients:")
            for i, name in enumerate(self.covariate_names):
                lines.append(f"  {name}: {self.beta[i]:.6f}")

        if self.att_on is not None and self.time_on is not None:
            lines.append(f"")
            lines.append(f"Dynamic effects (ATT by relative time):")
            lines.append(f"  {'Time':>6s} {'ATT':>10s} {'Count':>6s}")
            for i, t in enumerate(self.time_on):
                count = self.count_on[i] if self.count_on is not None else ""
                att = self.att_on[i]
                if self.inference is not None:
                    se = self.inference.att_on_se[i]
                    lines.append(f"  {t:>6.0f} {att:>10.4f} ({se:.4f}) {count}")
                else:
                    lines.append(f"  {t:>6.0f} {att:>10.4f} {count}")

        lines.append(f"{'='*60}")
        if self.panel is not None:
            lines.append(f"N={self.panel.N}, T={self.panel.T}")
        if self.seed is not None:
            lines.append(f"Seed: {self.seed}")
        return "\n".join(lines)

    def __repr__(self):
        return self.summary()

    def plot(self, kind="gap", **kwargs):
        """Plot results. Shortcut for ``pyfector.plot(self, kind, ...)``."""
        from .plotting import plot as _plot
        return _plot(self, kind=kind, **kwargs)

    def diagnose(self, **kwargs):
        """Run diagnostic tests. Shortcut for ``pyfector.run_diagnostics(self, ...)``."""
        from .diagnostics import run_diagnostics
        return run_diagnostics(self, **kwargs)


def fect(
    data,
    Y: str,
    D: str,
    index: tuple[str, str],
    X: list[str] | None = None,
    W: str | None = None,
    group: str | None = None,
    method: Literal["fe", "ife", "mc", "cfe", "both"] = "ife",
    force: Literal["none", "unit", "time", "two-way"] = "two-way",
    r: int | tuple[int, int] = 0,
    lam: float | None = None,
    nlambda: int = 10,
    CV: bool = True,
    k: int = 10,
    cv_prop: float = 0.1,
    cv_nobs: int = 3,
    cv_treat: bool = True,
    cv_donut: int = 0,
    criterion: str = "mspe",
    cv_rule: Literal["min", "onepct"] = "min",
    se: bool = False,
    vartype: Literal["bootstrap", "jackknife"] = "bootstrap",
    nboots: int = 200,
    alpha: float = 0.05,
    tol: float = 1e-7,
    max_iter: int = 5000,
    min_T0: int = 1,
    min_T0_strict: bool = False,
    max_missing: float = 1.0,
    normalize: bool = False,
    # CFE-specific
    Z: list[str] | None = None,
    Q: list[str] | None = None,
    # Performance
    device: Literal["cpu", "gpu"] = "cpu",
    n_jobs: int | None = -1,
    seed: int | None = None,
) -> FectResult:
    """Estimate counterfactual treatment effects for panel data.

    This is the main Python entry point for the counterfactual estimator
    workflow. Where the paper and the historical R package differ,
    pyfector defaults to the paper's statistical definition and exposes
    R-package-style behavior through explicit options.

    Missing outcome policy
    ----------------------
    pyfector distinguishes raw missing outcomes from counterfactual
    missingness caused by treatment. Observed untreated cells
    (``D == 0`` and non-missing ``Y``) fit the response surface. Observed
    treated cells (``D == 1`` and non-missing ``Y``) contribute to ATT as
    ``Y - Y_ct``. If a treated outcome is missing in the input data, the
    model can still produce a counterfactual ``Y_ct`` for that cell, but
    the cell is not counted in ``att_avg`` or ``att_on`` because the
    treated potential outcome was not observed.

    By default, ``min_T0`` is enforced only for treated and reversal
    units. Sparse controls are retained if they have at least one
    observed outcome, because they may still inform the low-rank response
    surface. Set ``min_T0_strict=True`` to require controls to satisfy
    ``min_T0`` too, matching the more conservative R fect sparse-panel
    behavior.

    Parameters
    ----------
    data : polars.DataFrame, pandas.DataFrame
        Long-format panel data.
    Y, D : str
        Column names for outcome and binary treatment indicator.
    index : (str, str)
        Column names for (unit_id, time_period).
    X : list of str, optional
        Time-varying covariates.
    W : str, optional
        Observation weight column.
    group : str, optional
        Reserved for grouped estimation. Currently raises
        ``NotImplementedError`` when supplied.
    method : {"fe", "ife", "mc", "cfe", "both"}
        Estimation method. ``"both"`` currently returns the IFE fit only.
    force : {"none", "unit", "time", "two-way"}
        Fixed effects specification.
    r : int or (int, int)
        Number of factors. If a tuple and ``CV=True``, CV selects from the
        range; otherwise the first element is used.
    lam : float, optional
        Nuclear norm penalty for MC. If None with ``CV=True``, it is
        auto-selected; if None with ``CV=False``, 0.0 is used.
    nlambda : int
        Number of automatically generated lambda candidates for MC CV.
    CV : bool
        If True, cross-validate over ``r`` for IFE when ``r`` is a tuple,
        or over ``lam`` for MC when ``lam`` is None.
    k : int
        Number of CV folds.
    cv_prop : float
        Fraction of eligible observed control cells masked per CV fold.
    cv_nobs : int
        Number of consecutive within-unit observations to mask as a block.
    cv_treat : bool
        If True, restrict CV masks to pre-treatment cells of ever-treated
        units. If False, use all observed control cells.
    cv_donut : int
        Exclude this many periods around treatment onset from CV evaluation.
    criterion : {"mspe", "gmspe", "mad"}
        Cross-validation loss.
    cv_rule : {"min", "onepct"}
        CV selection rule. ``"min"`` chooses the strict minimum-score
        candidate and is the paper-faithful default. ``"onepct"`` chooses
        the simplest candidate within 1% of the best score (lower ``r`` for
        IFE, higher ``lam`` for MC).
    se : bool
        Compute standard errors via bootstrap/jackknife.
    vartype : {"bootstrap", "jackknife"}
        Inference method when ``se=True``.
    nboots : int
        Number of bootstrap replications. Ignored for jackknife.
    alpha : float
        Significance level for confidence intervals and tests.
    tol : float
        EM convergence tolerance for final point estimation.
    max_iter : int
        Maximum EM iterations.
    min_T0 : int
        Minimum untreated/pre-treatment observed periods. By default this is
        enforced only for treated and treatment-reversal units.
    min_T0_strict : bool
        If True, enforce ``min_T0`` on all units, including controls. This
        matches R fect's conservative handling of sparse control rows.
    max_missing : float
        Maximum missing-outcome fraction per unit, in ``[0, 1]``. Units with
        no observed outcomes are always dropped, regardless of this threshold,
        because they provide neither fitting information nor observed treated
        effects.
    normalize : bool
        If True, estimate on an outcome standardized by its observed standard
        deviation, then transform effects back to the original scale.
    Z, Q : list of str, optional
        Reserved CFE interaction arguments. Currently raise
        ``NotImplementedError`` when supplied.
    device : {"cpu", "gpu"}
        Compute device.
    n_jobs : int, optional
        Parallel workers for CV and bootstrap. ``-1`` or ``None`` uses
        all available CPUs; any other value must be a positive integer.
    seed : int, optional
        Random seed for full reproducibility.
    """
    # Set device
    set_device(device)
    xp = get_backend()
    n_jobs = _resolve_n_jobs(n_jobs)

    if group is not None:
        raise NotImplementedError("The `group` argument is not implemented yet.")
    if Z is not None or Q is not None:
        raise NotImplementedError("The `Z` and `Q` CFE interaction arguments are not implemented yet.")
    if criterion not in {"mspe", "gmspe", "mad"}:
        raise ValueError("criterion must be 'mspe', 'gmspe', or 'mad'")
    if cv_rule not in {"min", "onepct"}:
        raise ValueError("cv_rule must be 'min' or 'onepct'")
    if min_T0 < 0:
        raise ValueError("min_T0 must be non-negative")
    if not 0.0 <= max_missing <= 1.0:
        raise ValueError("max_missing must be between 0 and 1")

    # Map force string to int
    force_map = {"none": 0, "unit": 1, "time": 2, "two-way": 3}
    if force not in force_map:
        raise ValueError("force must be 'none', 'unit', 'time', or 'two-way'")
    force_int = force_map[force]

    # Prepare panel data
    panel = prepare_panel(
        data, Y=Y, D=D, index=index, X=X, W=W,
        group=group, min_T0=min_T0, min_T0_strict=min_T0_strict,
        max_missing=max_missing,
    )

    # Move to device
    Y_mat = to_device(panel.Y)
    D_mat = to_device(panel.D)
    I_mat = to_device(panel.I)
    II_mat = to_device(panel.II)
    X_mat = to_device(panel.X) if panel.X is not None else None
    W_mat = to_device(panel.W) if panel.W is not None else None

    # Normalize
    norm_factor = 1.0
    if normalize:
        sd_y = float(xp.std(Y_mat[I_mat > 0]))
        if sd_y > 0:
            Y_mat = Y_mat / sd_y
            norm_factor = sd_y

    # Initial fit
    Y0, beta0 = initial_fit(Y_mat, X_mat, II_mat, force_int)

    # Determine r and lambda
    r_cv = None
    lambda_cv = None
    cv_result = None

    if method == "ife":
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]

    elif method == "mc":
        if lam is None and CV:
            cv_result = cv_mc(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, nlambda=nlambda, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            lambda_cv = cv_result.best_lambda
        else:
            lambda_cv = lam if lam is not None else 0.0

    elif method == "fe":
        r_cv = 0

    elif method == "cfe":
        r_cv = r if isinstance(r, int) else r[0]

    # Point estimation
    if method in ("fe", "ife"):
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "mc":
        est = estimate_mc(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            lam=lambda_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "cfe":
        est = estimate_cfe(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "both":
        # "both" currently fits IFE only (with optional CV over r);
        # no MC comparison is computed here.
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    else:
        raise ValueError(f"Unknown method: {method}")

    # Compute effects
    eff = Y_mat - est.fit
    Y_ct = est.fit

    # Keep the outcome on the estimation scale for inference: lambda_cv,
    # beta0 and the EM fits all refer to this (possibly normalized) scale.
    Y_est = Y_mat

    # Denormalize
    if normalize and norm_factor != 1.0:
        eff = eff * norm_factor
        Y_ct = Y_ct * norm_factor
        if est.beta is not None:
            est = est._replace(beta=est.beta * norm_factor)

    # ATT computation
    att_avg, att_on, time_on, count_on, att_avg_unit = _compute_effects(
        to_numpy(eff), to_numpy(D_mat), to_numpy(panel.T_on), to_numpy(I_mat),
    )

    # Build result
    result = FectResult(
        method=method,
        r_cv=r_cv,
        lambda_cv=lambda_cv,
        att_avg=att_avg,
        att_avg_unit=att_avg_unit,
        att_on=att_on,
        time_on=time_on,
        count_on=count_on,
        beta=to_numpy(est.beta) if est.beta is not None else None,
        covariate_names=panel.covariate_names,
        mu=est.mu,
        alpha=to_numpy(est.alpha) if est.alpha is not None else None,
        xi=to_numpy(est.xi) if est.xi is not None else None,
        factors=to_numpy(est.factors) if est.factors is not None else None,
        loadings=to_numpy(est.loadings) if est.loadings is not None else None,
        Y_ct=to_numpy(Y_ct),
        eff=to_numpy(eff),
        residuals=to_numpy(est.residuals),
        sigma2=est.sigma2,
        IC=est.IC,
        PC=est.PC,
        niter=est.niter,
        converged=est.converged,
        cv_result=cv_result,
        panel=panel,
        seed=seed,
    )

    # Inference
    if se:
        result.inference = _run_inference(
            result, panel, Y_est, X_mat, W_mat, beta0, Y0,
            method=method, r_cv=r_cv, lambda_cv=lambda_cv,
            force_int=force_int, tol=tol, max_iter=max_iter,
            vartype=vartype, nboots=nboots, alpha=alpha,
            n_jobs=n_jobs, seed=seed, normalize=normalize,
            norm_factor=norm_factor,
        )

    return result


def _compute_effects(eff, D, T_on, I):
    """Compute overall ATT, per-unit ATT, and dynamic ATT by event time.

    ATT averages are defined only over observed outcome cells. A missing
    raw treated outcome has no observed ``Y(1)``, so it is excluded even
    though the estimator may have produced a counterfactual ``Y_ct`` for
    that matrix position.

    Dynamic ATT grouping is done with :func:`numpy.bincount`, so the cost
    is linear in the number of observed ever-treated cells plus the span
    of relative times, with no Python loop over event times. Per-unit ATT
    uses column-wise masked sums and counts rather than a Python loop.
    """
    treated = (D > 0) & (I > 0)
    n_treated = int(np.sum(treated))

    att_avg = float(np.sum(eff[treated]) / max(n_treated, 1))

    # Per-unit ATT (post-treatment only): sum then divide by count,
    # keeping only units that have at least one treated observation.
    # np.where (rather than eff * treated) keeps NaN/inf values in
    # untreated or unobserved cells out of the column sums.
    col_counts = treated.sum(axis=0)                     # (N,)
    col_sums = np.where(treated, eff, 0.0).sum(axis=0)   # (N,)
    has_any = col_counts > 0
    if np.any(has_any):
        unit_atts = col_sums[has_any] / col_counts[has_any]
        att_avg_unit = float(unit_atts.mean())
    else:
        att_avg_unit = 0.0

    # Dynamic ATT by relative event time. Include pre-treatment periods
    # for ever-treated units (counterfactual gaps before onset).
    ever_treated = np.any(D > 0, axis=0)                 # (N,)
    all_periods = (I > 0) & ever_treated[np.newaxis, :]  # (T, N)

    T_on_flat = T_on[all_periods]
    eff_flat = eff[all_periods]
    valid = ~np.isnan(T_on_flat)
    T_on_flat = T_on_flat[valid].astype(np.int64)
    eff_flat = eff_flat[valid]

    if T_on_flat.size == 0:
        return att_avg, np.array([]), np.array([]), np.array([], dtype=np.int64), att_avg_unit

    # Bincount over the shifted integer indices.
    offset = int(T_on_flat.min())
    idx = T_on_flat - offset
    minlength = int(T_on_flat.max() - offset + 1)
    sums = np.bincount(idx, weights=eff_flat, minlength=minlength)
    counts = np.bincount(idx, minlength=minlength)
    keep = counts > 0
    time_on = (offset + np.arange(minlength))[keep].astype(np.float64)
    att_on = (sums[keep] / counts[keep]).astype(np.float64)
    count_on = counts[keep].astype(np.int64)

    return att_avg, att_on, time_on, count_on, att_avg_unit


def _run_inference(
    result, panel, Y_mat, X_mat, W_mat, beta0, Y0,
    method, r_cv, lambda_cv, force_int, tol, max_iter,
    vartype, nboots, alpha, n_jobs, seed, normalize, norm_factor,
):
    """Run bootstrap or jackknife inference.

    ``Y_mat`` must be on the estimation scale (divided by ``norm_factor``
    when ``normalize=True``), matching ``beta0`` and ``lambda_cv``.
    Replicate effects are rescaled to the original scale exactly once.

    Bootstrap replicates use a relaxed convergence tolerance
    (``max(tol, 1e-3)``) because bootstrap SEs converge to 3–4
    significant digits regardless of inner precision. The full-sample
    fit is used as the warm-start initialiser for each replicate,
    which cuts EM iterations by ~30-60%.
    """
    xp = get_backend()

    # Relaxed tolerance: bootstrap SEs don't benefit from tight inner tol.
    boot_tol = max(tol, 1e-3)

    # Pre-move panel data to device once (avoids CPU→GPU copy per bootstrap rep)
    II_dev = to_device(panel.II)
    D_dev = to_device(panel.D)
    I_dev = to_device(panel.I)
    T_on_np = panel.T_on  # keep on CPU for ATT computation

    # Warm-start: use full-sample fitted values as initial imputation
    # for each replicate instead of recomputing initial_fit from scratch.
    # result.Y_ct is on the original scale, so divide by norm_factor to
    # match Y_mat (a no-op when norm_factor == 1.0).
    Y0_full = to_device(result.Y_ct / norm_factor)

    def _estimate_fn(unit_idx):
        """Re-estimate on a subset of units."""
        Y_sub = Y_mat[:, unit_idx]
        II_sub = II_dev[:, unit_idx]
        D_sub = D_dev[:, unit_idx]
        I_sub = I_dev[:, unit_idx]
        X_sub = X_mat[:, unit_idx, :] if X_mat is not None else None
        W_sub = W_mat[:, unit_idx] if W_mat is not None else None
        T_on_sub = T_on_np[:, unit_idx]
        beta0_sub = beta0

        # Warm-start from full-sample fit (columns for resampled units).
        # This is much closer to the replicate's solution than a cold
        # initial_fit, cutting EM iterations significantly.
        Y0_sub = Y0_full[:, unit_idx].copy()

        if method in ("fe", "ife"):
            est = estimate_ife(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                r=r_cv, force=force_int, tol=boot_tol, max_iter=max_iter,
            )
        elif method == "mc":
            est = estimate_mc(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                lam=lambda_cv, force=force_int, tol=boot_tol, max_iter=max_iter,
            )
        elif method == "cfe":
            # Re-run the same estimator as the point estimate.
            est = estimate_cfe(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                r=r_cv, force=force_int, tol=boot_tol, max_iter=max_iter,
            )
        else:
            # "both" uses the IFE fit, as in the point estimate.
            est = estimate_ife(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                r=r_cv or 0, force=force_int, tol=boot_tol, max_iter=max_iter,
            )

        # Y_sub and est.fit are on the estimation scale; rescale once.
        eff = to_numpy(Y_sub - est.fit)
        if normalize and norm_factor != 1.0:
            eff = eff * norm_factor

        return eff, to_numpy(D_sub), T_on_sub, to_numpy(I_sub)

    # Original-scale outcomes for the inference routines.
    Y_orig = to_numpy(Y_mat) * norm_factor

    if vartype == "bootstrap":
        return bootstrap(
            _estimate_fn, Y_orig, to_numpy(panel.D),
            to_numpy(panel.I), panel.T_on, panel.unit_type,
            nboots=nboots, alpha=alpha, n_jobs=n_jobs, seed=seed,
        )
    else:
        return jackknife(
            _estimate_fn, Y_orig, to_numpy(panel.D),
            to_numpy(panel.I), panel.T_on, panel.unit_type,
            alpha=alpha, n_jobs=n_jobs,
        )


def _resolve_n_jobs(n_jobs: int | None) -> int:
    """Normalize public n_jobs values before passing them to joblib."""
    if n_jobs is None or n_jobs == -1:
        return os.cpu_count() or 1
    if n_jobs == 0 or n_jobs < -1:
        raise ValueError("n_jobs must be a positive integer, -1, or None")
    return int(n_jobs)
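_run_inference() hands the inference routine a closure that re-estimates the model on a vector of unit indices. The resampling pattern can be sketched in isolation: draw N unit indices with replacement, evaluate the estimator on those columns, and summarize the replicate estimates. In this sketch the closure returns a scalar ATT and a percentile interval is used; the names unit_bootstrap and att_from_units are hypothetical, the "estimator" is a plain average standing in for a full EM re-fit, and pyfector.inference.bootstrap may resample differently (for example, stratifying by unit_type) and report other interval types.

```python
import numpy as np


def unit_bootstrap(estimate_fn, n_units, nboots=200, alpha=0.05, seed=None):
    """Resample whole units (columns) with replacement and re-run estimate_fn.

    estimate_fn(unit_idx) returns a scalar estimate (or NaN if undefined).
    Returns the bootstrap standard error and a percentile confidence interval.
    """
    rng = np.random.default_rng(seed)
    draws = np.empty(nboots)
    for b in range(nboots):
        unit_idx = rng.integers(0, n_units, size=n_units)  # resample units, not cells
        draws[b] = estimate_fn(unit_idx)
    se = np.nanstd(draws, ddof=1)
    lo, hi = np.nanquantile(draws, [alpha / 2, 1 - alpha / 2])
    return se, (lo, hi)


# Toy (T, N) effect matrix and treated mask.
rng = np.random.default_rng(0)
eff = rng.normal(1.0, 0.5, size=(10, 40))
treated = np.zeros((10, 40), dtype=bool)
treated[6:, :20] = True  # first 20 units treated in the last 4 periods


def att_from_units(unit_idx):
    """Pooled ATT over treated cells of the resampled units (stand-in estimator)."""
    e, m = eff[:, unit_idx], treated[:, unit_idx]
    return e[m].mean() if m.any() else np.nan


se, (lo, hi) = unit_bootstrap(att_from_units, n_units=40, nboots=300, seed=42)
```

Resampling whole columns keeps each unit's time series intact, so serial correlation within a unit is carried into every replicate.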
class FectResult
Container for all fect estimation results.
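FectResult stores two overall averages. att_avg pools every observed treated cell, so units with more treated periods carry more weight; att_avg_unit first averages each treated unit's effects and then averages across units, as _compute_effects() does with column sums and counts. A small NumPy sketch with made-up numbers shows how the two can differ:

```python
import numpy as np

# (T, N) toy effects; True marks observed treated cells (D > 0 and I > 0).
eff = np.array([[0.0, 0.0],
                [4.0, 0.0],
                [4.0, 0.0],
                [4.0, 1.0]])
treated = np.array([[False, False],
                    [True, False],
                    [True, False],
                    [True, True]])

# Pooled ATT: every treated cell counts once.
att_avg = eff[treated].mean()

# Unit-weighted ATT: mean within each unit, then mean over treated units.
counts = treated.sum(axis=0)
sums = np.where(treated, eff, 0.0).sum(axis=0)
has_any = counts > 0
att_avg_unit = (sums[has_any] / counts[has_any]).mean()

print(att_avg, att_avg_unit)  # 3.25 2.5
```

Unit 0 contributes three treated cells and unit 1 only one, so the pooled average leans toward unit 0's effect of 4.0.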
FectResult.summary() -> str
Return the summary table of results as a string.
FectResult.plot(kind="gap", **kwargs)
Plot results. Shortcut for pyfector.plot(self, kind, ...).
FectResult.diagnose(**kwargs)
Run diagnostic tests. Shortcut for pyfector.run_diagnostics(self, ...).
```python
def fect(
    data,
    Y: str,
    D: str,
    index: tuple[str, str],
    X: list[str] | None = None,
    W: str | None = None,
    group: str | None = None,
    method: Literal["fe", "ife", "mc", "cfe", "both"] = "ife",
    force: Literal["none", "unit", "time", "two-way"] = "two-way",
    r: int | tuple[int, int] = 0,
    lam: float | None = None,
    nlambda: int = 10,
    CV: bool = True,
    k: int = 10,
    cv_prop: float = 0.1,
    cv_nobs: int = 3,
    cv_treat: bool = True,
    cv_donut: int = 0,
    criterion: str = "mspe",
    cv_rule: Literal["min", "onepct"] = "min",
    se: bool = False,
    vartype: Literal["bootstrap", "jackknife"] = "bootstrap",
    nboots: int = 200,
    alpha: float = 0.05,
    tol: float = 1e-7,
    max_iter: int = 5000,
    min_T0: int = 1,
    min_T0_strict: bool = False,
    max_missing: float = 1.0,
    normalize: bool = False,
    # CFE-specific
    Z: list[str] | None = None,
    Q: list[str] | None = None,
    # Performance
    device: Literal["cpu", "gpu"] = "cpu",
    n_jobs: int | None = -1,
    seed: int | None = None,
) -> FectResult:
    """Estimate counterfactual treatment effects for panel data.

    This is the main Python entry point for the counterfactual estimator
    workflow. Where the paper and the historical R package differ,
    pyfector defaults to the paper's statistical definition and exposes
    R-package-style behavior through explicit options.

    Missing outcome policy
    ----------------------
    pyfector distinguishes raw missing outcomes from counterfactual
    missingness caused by treatment. Observed untreated cells
    (``D == 0`` and non-missing ``Y``) fit the response surface. Observed
    treated cells (``D == 1`` and non-missing ``Y``) contribute to ATT as
    ``Y - Y_ct``. If a treated outcome is missing in the input data, the
    model can still produce a counterfactual ``Y_ct`` for that cell, but
    the cell is not counted in ``att_avg`` or ``att_on`` because the
    treated potential outcome was not observed.

    By default, ``min_T0`` is enforced only for treated and reversal
    units. Sparse controls are retained if they have at least one
    observed outcome, because they may still inform the low-rank response
    surface. Set ``min_T0_strict=True`` to require controls to satisfy
    ``min_T0`` too, matching the more conservative R fect sparse-panel
    behavior.

    Parameters
    ----------
    data : polars.DataFrame, pandas.DataFrame
        Long-format panel data.
    Y, D : str
        Column names for outcome and binary treatment indicator.
    index : (str, str)
        Column names for (unit_id, time_period).
    X : list of str, optional
        Time-varying covariates.
    W : str, optional
        Observation weight column.
    group : str, optional
        Reserved for grouped estimation. Currently raises
        ``NotImplementedError`` when supplied.
    method : {"fe", "ife", "mc", "cfe", "both"}
        Estimation method.
    force : {"none", "unit", "time", "two-way"}
        Fixed effects specification.
    r : int or (int, int)
        Number of factors. If tuple, CV selects from range.
    lam : float, optional
        Nuclear norm penalty for MC. If None with CV=True, auto-selected.
    nlambda : int
        Number of automatically generated lambda candidates for MC CV.
    CV : bool
        If True, cross-validate over ``r`` for IFE when ``r`` is a tuple,
        or over ``lam`` for MC when ``lam`` is None.
    k : int
        Number of CV folds.
    cv_prop : float
        Fraction of eligible observed control cells masked per CV fold.
    cv_nobs : int
        Number of consecutive within-unit observations to mask as a block.
    cv_treat : bool
        If True, restrict CV masks to pre-treatment cells of ever-treated
        units. If False, use all observed control cells.
    cv_donut : int
        Exclude this many periods around treatment onset from CV evaluation.
    criterion : {"mspe", "gmspe", "mad"}
        Cross-validation loss.
    cv_rule : {"min", "onepct"}
        CV selection rule. ``"min"`` chooses the strict minimum-score
        candidate and is the paper-faithful default. ``"onepct"`` chooses
        the simplest candidate within 1% of the best score (lower ``r`` for
        IFE, higher ``lam`` for MC).
    se : bool
        Compute standard errors via bootstrap/jackknife.
    vartype : {"bootstrap", "jackknife"}
        Inference method when ``se=True``.
    nboots : int
        Number of bootstrap replications. Ignored for jackknife.
    alpha : float
        Significance level for confidence intervals and tests.
    tol : float
        EM convergence tolerance for final point estimation.
    max_iter : int
        Maximum EM iterations.
    min_T0 : int
        Minimum untreated/pre-treatment observed periods. By default this is
        enforced only for treated and treatment-reversal units.
    min_T0_strict : bool
        If True, enforce ``min_T0`` on all units, including controls. This
        matches R fect's conservative handling of sparse control rows.
    max_missing : float
        Maximum missing-outcome fraction per unit, in ``[0, 1]``. Units with
        no observed outcomes are always dropped, regardless of this threshold,
        because they provide neither fitting information nor observed treated
        effects.
    normalize : bool
        If True, estimate on an outcome standardized by its observed standard
        deviation, then transform effects back to the original scale.
    Z, Q : list of str, optional
        Reserved CFE interaction arguments. Currently raise
        ``NotImplementedError`` when supplied.
    device : {"cpu", "gpu"}
        Compute device.
    n_jobs : int, optional
        Parallel workers for CV and bootstrap. ``-1`` or ``None`` uses
        all available CPUs.
    seed : int, optional
        Random seed for full reproducibility.
    """
    # Set device
    set_device(device)
    xp = get_backend()
    n_jobs = _resolve_n_jobs(n_jobs)

    if group is not None:
        raise NotImplementedError("The `group` argument is not implemented yet.")
    if Z is not None or Q is not None:
        raise NotImplementedError("The `Z` and `Q` CFE interaction arguments are not implemented yet.")
    if criterion not in {"mspe", "gmspe", "mad"}:
        raise ValueError("criterion must be 'mspe', 'gmspe', or 'mad'")
    if cv_rule not in {"min", "onepct"}:
        raise ValueError("cv_rule must be 'min' or 'onepct'")
    if min_T0 < 0:
        raise ValueError("min_T0 must be non-negative")
    if not 0.0 <= max_missing <= 1.0:
        raise ValueError("max_missing must be between 0 and 1")

    # Map force string to int
    force_map = {"none": 0, "unit": 1, "time": 2, "two-way": 3}
    force_int = force_map[force]

    # Prepare panel data
    panel = prepare_panel(
        data, Y=Y, D=D, index=index, X=X, W=W,
        group=group, min_T0=min_T0, min_T0_strict=min_T0_strict,
        max_missing=max_missing,
    )

    # Move to device
    Y_mat = to_device(panel.Y)
    D_mat = to_device(panel.D)
    I_mat = to_device(panel.I)
    II_mat = to_device(panel.II)
    X_mat = to_device(panel.X) if panel.X is not None else None
    W_mat = to_device(panel.W) if panel.W is not None else None

    # Normalize
    norm_factor = 1.0
    if normalize:
        sd_y = float(xp.std(Y_mat[I_mat > 0]))
        if sd_y > 0:
            Y_mat = Y_mat / sd_y
            norm_factor = sd_y

    # Initial fit
    Y0, beta0 = initial_fit(Y_mat, X_mat, II_mat, force_int)

    # Determine r and lambda
    r_cv = None
    lambda_cv = None
    cv_result = None

    if method == "ife":
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]

    elif method == "mc":
        if lam is None and CV:
            cv_result = cv_mc(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, nlambda=nlambda, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            lambda_cv = cv_result.best_lambda
        else:
            lambda_cv = lam if lam is not None else 0.0

    elif method == "fe":
        r_cv = 0

    elif method == "cfe":
        r_cv = r if isinstance(r, int) else r[0]

    # Point estimation
    if method in ("fe", "ife"):
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "mc":
        est = estimate_mc(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            lam=lambda_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "cfe":
        est = estimate_cfe(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "both":
        # Run both IFE and MC, return IFE results with MC comparison
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    else:
        raise ValueError(f"Unknown method: {method}")

    # Compute effects
    eff = Y_mat - est.fit
    Y_ct = est.fit

    # Denormalize
    if normalize and norm_factor != 1.0:
        eff = eff * norm_factor
        Y_ct = Y_ct * norm_factor
        Y_mat = Y_mat * norm_factor
        if est.beta is not None:
            est = est._replace(beta=est.beta * norm_factor)

    # ATT computation
    T_on = to_device(panel.T_on)
    att_avg, att_on, time_on, count_on, att_avg_unit = _compute_effects(
        to_numpy(eff), to_numpy(D_mat), to_numpy(panel.T_on), to_numpy(I_mat),
    )

    # Build result
    result = FectResult(
        method=method,
        r_cv=r_cv,
        lambda_cv=lambda_cv,
        att_avg=att_avg,
        att_avg_unit=att_avg_unit,
        att_on=att_on,
        time_on=time_on,
        count_on=count_on,
        beta=to_numpy(est.beta) if est.beta is not None else None,
        covariate_names=panel.covariate_names,
        mu=est.mu,
        alpha=to_numpy(est.alpha) if est.alpha is not None else None,
        xi=to_numpy(est.xi) if est.xi is not None else None,
        factors=to_numpy(est.factors) if est.factors is not None else None,
        loadings=to_numpy(est.loadings) if est.loadings is not None else None,
        Y_ct=to_numpy(Y_ct),
        eff=to_numpy(eff),
        residuals=to_numpy(est.residuals),
        sigma2=est.sigma2,
        IC=est.IC,
        PC=est.PC,
        niter=est.niter,
        converged=est.converged,
        cv_result=cv_result,
        panel=panel,
        seed=seed,
    )

    # Inference
    if se:
        result.inference = _run_inference(
            result, panel, Y_mat, X_mat, W_mat, beta0, Y0,
            method=method, r_cv=r_cv, lambda_cv=lambda_cv,
            force_int=force_int, tol=tol, max_iter=max_iter,
            vartype=vartype, nboots=nboots, alpha=alpha,
            n_jobs=n_jobs, seed=seed, normalize=normalize,
            norm_factor=norm_factor,
        )

    return result
```
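The branching in fect() that decides r_cv and lambda_cv can be summarized in a small pure-Python sketch. It mirrors only the control flow in the fect() source, not the cross-validation itself: the function name resolve_hyperparameters and the two callbacks standing in for cv_ife and cv_mc are hypothetical.

```python
from typing import Callable, Optional, Tuple, Union

RSpec = Union[int, Tuple[int, int]]


def resolve_hyperparameters(
    method: str,
    r: RSpec,
    lam: Optional[float],
    CV: bool,
    cv_select_r: Callable[[Tuple[int, int]], int],
    cv_select_lam: Callable[[], float],
) -> Tuple[Optional[int], Optional[float]]:
    """Return (r_cv, lambda_cv) following the branches in fect().

    cv_select_r and cv_select_lam are hypothetical stand-ins for cv_ife / cv_mc.
    """
    def first_r(spec: RSpec) -> int:
        # A bare int is used as-is; a tuple without CV falls back to its lower bound.
        return spec if isinstance(spec, int) else spec[0]

    if method in ("ife", "both"):
        # "both" follows the same r-selection path as "ife" in the fect() source.
        if isinstance(r, tuple) and CV:
            return cv_select_r(r), None
        return first_r(r), None
    if method == "mc":
        if lam is None and CV:
            return None, cv_select_lam()
        return None, (lam if lam is not None else 0.0)
    if method == "fe":
        return 0, None
    if method == "cfe":
        return first_r(r), None
    raise ValueError(f"Unknown method: {method}")


# Pretend that CV would pick r = 2 and lam = 0.3.
def pick_r(r_range):
    return 2


def pick_lam():
    return 0.3


print(resolve_hyperparameters("ife", (0, 5), None, True, pick_r, pick_lam))   # (2, None)
print(resolve_hyperparameters("ife", (1, 5), None, False, pick_r, pick_lam))  # (1, None)
print(resolve_hyperparameters("mc", 0, None, False, pick_r, pick_lam))        # (None, 0.0)
```

Two consequences are easy to see in the sketch: a tuple r with CV=False silently uses r[0], and method="mc" with lam=None and CV=False fits with no nuclear-norm penalty (0.0).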
Estimate counterfactual treatment effects for panel data.
This is the main Python entry point for the counterfactual estimator workflow. Where the paper and the historical R package differ, pyfector defaults to the paper's statistical definition and exposes R-package-style behavior through explicit options.
Missing outcome policy
pyfector distinguishes raw missing outcomes from counterfactual
missingness caused by treatment. Observed untreated cells
(D == 0 and non-missing Y) fit the response surface. Observed
treated cells (D == 1 and non-missing Y) contribute to ATT as
Y - Y_ct. If a treated outcome is missing in the input data, the
model can still produce a counterfactual Y_ct for that cell, but
the cell is not counted in att_avg or att_on because the
treated potential outcome was not observed.
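The missing-outcome policy can be illustrated with a small NumPy sketch. It is not pyfector's _compute_effects: the function name att_from_cells, the NaN encoding of missing outcomes, and the integer rel_time array of relative event times are assumptions. Following the module overview, it averages effects by relative event time with numpy.bincount; for simplicity it restricts the dynamic average to observed treated cells.

```python
import numpy as np


def att_from_cells(Y, D, Y_ct, rel_time):
    """Hypothetical sketch: ATT from observed treated cells only.

    Y, D, Y_ct, rel_time are (T, N) arrays; missing raw outcomes in Y are NaN.
    rel_time holds an integer relative event time for each cell (an assumption).
    """
    observed = ~np.isnan(Y)
    fit_cells = observed & (D == 0)   # inform the response surface
    att_cells = observed & (D == 1)   # contribute Y - Y_ct to the ATT
    if not att_cells.any():
        raise ValueError("no observed treated cells")

    eff = Y - Y_ct                    # NaN wherever Y is missing
    att_avg = eff[att_cells].mean()

    # Dynamic ATT: average eff within each relative event time via bincount.
    rt = rel_time[att_cells].astype(int)
    offset = rt.min()                 # shift so bincount sees non-negative bins
    sums = np.bincount(rt - offset, weights=eff[att_cells])
    counts = np.bincount(rt - offset)
    present = counts > 0
    time_on = np.arange(offset, offset + counts.size)[present]
    att_on = sums[present] / counts[present]
    return att_avg, time_on, att_on, counts[present], fit_cells


# Unit 0 is a control; unit 1 is treated from t = 1, and its t = 2 outcome is missing.
Y = np.array([[1.0, 2.0], [1.0, 5.0], [1.0, np.nan]])
D = np.array([[0, 0], [0, 1], [0, 1]])
Y_ct = np.array([[1.0, 2.0], [1.0, 3.0], [1.0, 3.5]])  # the model imputes every cell
rel_time = np.array([[0, 0], [0, 1], [0, 2]])

att_avg, time_on, att_on, count_on, fit_cells = att_from_cells(Y, D, Y_ct, rel_time)
print(att_avg, time_on, att_on, count_on)  # 2.0 [1] [2.] [1]
```

The treated cell at t = 2 still receives a counterfactual (3.5), but because its outcome is missing it enters neither att_avg nor att_on.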
By default, min_T0 is enforced only for treated and reversal
units. Sparse controls are retained if they have at least one
observed outcome, because they may still inform the low-rank response
surface. Set min_T0_strict=True to require controls to satisfy
min_T0 too, matching the more conservative R fect sparse-panel
behavior.
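A minimal NumPy sketch of this retention rule, under stated assumptions: outcomes are (T, N) arrays with NaN for missing values, any unit with D == 1 in some period counts as treated or reversal, and observed D == 0 cells count as untreated periods. keep_units is a hypothetical name, not pyfector's prepare_panel, which also applies max_missing.

```python
import numpy as np


def keep_units(Y, D, min_T0, min_T0_strict=False):
    """Hypothetical sketch of the min_T0 unit filter on (T, N) arrays."""
    observed = ~np.isnan(Y)
    ever_treated = (D == 1).any(axis=0)               # treated or reversal units
    n_obs = observed.sum(axis=0)
    n_untreated_obs = (observed & (D == 0)).sum(axis=0)

    if min_T0_strict:
        meets_T0 = n_untreated_obs >= min_T0          # every unit must meet min_T0
    else:
        # Treated units need min_T0 untreated periods; controls need one observation.
        meets_T0 = np.where(ever_treated, n_untreated_obs >= min_T0, n_obs >= 1)
    return meets_T0 & (n_obs >= 1)                    # fully missing units are dropped


# Unit 0: sparse control (1 observation); unit 1: 2 untreated periods; unit 2: 1.
Y = np.array([[1.0, 1.0, 1.0],
              [np.nan, 1.0, 1.0],
              [np.nan, 1.0, 1.0],
              [np.nan, 1.0, 1.0]])
D = np.array([[0, 0, 0],
              [0, 0, 1],
              [0, 1, 1],
              [0, 1, 1]])
print(keep_units(Y, D, min_T0=2).tolist())                      # [True, True, False]
print(keep_units(Y, D, min_T0=2, min_T0_strict=True).tolist())  # [False, True, False]
```

The sparse control survives the default rule but is dropped under min_T0_strict=True, while the treated unit with only one untreated period is dropped either way.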
Parameters
data : polars.DataFrame, pandas.DataFrame
Long-format panel data.
Y, D : str
Column names for outcome and binary treatment indicator.
index : (str, str)
Column names for (unit_id, time_period).
X : list of str, optional
Time-varying covariates.
W : str, optional
Observation weight column.
group : str, optional
Reserved for grouped estimation. Currently raises
NotImplementedError when supplied.
method : {"fe", "ife", "mc", "cfe", "both"}
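Estimation method: "fe" (additive fixed effects only, r = 0), "ife" (interactive fixed effects), "mc" (matrix completion), or "cfe" (complex fixed effects). "both" currently runs only the IFE branch and returns its results.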
force : {"none", "unit", "time", "two-way"}
Additive fixed effects to include: none, unit only, time only, or both unit and time ("two-way").
r : int or (int, int)
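Number of latent factors. If a tuple (r_min, r_max) and CV=True, cross-validation selects r from that range; if CV=False, the lower bound r[0] is used.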
lam : float, optional
Nuclear-norm penalty for MC. If None and CV=True, it is chosen by cross-validation over nlambda candidates; if None and CV=False, a penalty of 0.0 is used.
nlambda : int
Number of automatically generated lambda candidates for MC CV.
CV : bool
If True, cross-validate over r for IFE when r is a tuple,
or over lam for MC when lam is None.
k : int
Number of CV folds.
cv_prop : float
Fraction of eligible observed control cells masked per CV fold.
cv_nobs : int
Number of consecutive within-unit observations to mask as a block.
cv_treat : bool
If True, restrict CV masks to pre-treatment cells of ever-treated
units. If False, use all observed control cells.
cv_donut : int
Exclude this many periods around treatment onset from CV evaluation.
criterion : {"mspe", "gmspe", "mad"}
Cross-validation loss.
cv_rule : {"min", "onepct"}
CV selection rule. "min" chooses the strict minimum-score
candidate and is the paper-faithful default. "onepct" chooses
the simplest candidate within 1% of the best score (lower r for
IFE, higher lam for MC).
se : bool
Compute standard errors via bootstrap/jackknife.
vartype : {"bootstrap", "jackknife"}
Inference method when se=True.
nboots : int
Number of bootstrap replications. Ignored for jackknife.
alpha : float
Significance level for confidence intervals and tests.
tol : float
EM convergence tolerance for final point estimation.
max_iter : int
Maximum EM iterations.
min_T0 : int
Minimum number of observed untreated (pre-treatment) periods a unit must have to be kept. By default this is
enforced only for treated and treatment-reversal units.
min_T0_strict : bool
If True, enforce min_T0 on all units, including controls. This
matches R fect's conservative handling of sparse control rows.
max_missing : float
Maximum missing-outcome fraction per unit, in [0, 1]. Units with
no observed outcomes are always dropped, regardless of this threshold,
because they provide neither fitting information nor observed treated
effects.
normalize : bool
If True, estimate on an outcome standardized by its observed standard
deviation, then rescale effects, counterfactuals, and covariate coefficients back to the original scale.
Z, Q : list of str, optional
Reserved CFE interaction arguments. Currently raise
NotImplementedError when supplied.
device : {"cpu", "gpu"}
Compute device.
n_jobs : int, optional
Parallel workers for CV and bootstrap. -1 or None uses
all available CPUs.
seed : int, optional
Random seed passed to cross-validation and inference, for reproducible results.
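The two cv_rule options can be sketched in plain Python. select_candidate is a hypothetical helper, not pyfector's CV code; it assumes a non-negative loss where lower is better and reads "within 1% of the best score" as score <= 1.01 * best.

```python
from typing import Sequence


def select_candidate(candidates: Sequence[float], scores: Sequence[float],
                     rule: str = "min", prefer_low: bool = True) -> float:
    """Hypothetical sketch of cv_rule.

    prefer_low=True treats smaller candidates as simpler (r for IFE);
    prefer_low=False treats larger ones as simpler (lam for MC).
    """
    if not candidates or len(candidates) != len(scores):
        raise ValueError("need exactly one score per candidate")
    best = min(range(len(scores)), key=lambda i: scores[i])  # first strict minimum
    if rule == "min":
        return candidates[best]
    if rule == "onepct":
        # Assumes a non-negative loss such as MSPE, so "within 1%" means <= best * 1.01.
        threshold = scores[best] * 1.01
        near_best = [c for c, s in zip(candidates, scores) if s <= threshold]
        return min(near_best) if prefer_low else max(near_best)
    raise ValueError("cv_rule must be 'min' or 'onepct'")


r_grid, r_scores = [0, 1, 2, 3], [2.0, 1.005, 1.0, 1.2]
print(select_candidate(r_grid, r_scores, "min"))     # 2
print(select_candidate(r_grid, r_scores, "onepct"))  # 1

lam_grid, lam_scores = [0.1, 0.5, 1.0], [0.80, 0.805, 0.95]
print(select_candidate(lam_grid, lam_scores, "onepct", prefer_low=False))  # 0.5
```

With "onepct", r = 1 wins because its score is within 1% of the minimum at r = 2, and for MC the larger penalty 0.5 wins over 0.1 for the same reason.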