pyfector.fect

"""
Public entry point for pyfector.

This module exposes :func:`fect`, the single function most users call,
and the :class:`FectResult` dataclass that packages its output.

:func:`fect` orchestrates the full pipeline:

1. :func:`pyfector.panel.prepare_panel` converts the input DataFrame to
   dense ``(T, N)`` matrices with treatment timing and unit
   classification.
2. :func:`pyfector.panel.initial_fit` computes the starting fit ``Y0``
   and initial covariate coefficients on the control subsample.
3. If ``CV=True`` and the user passes a range for ``r`` (IFE) or leaves
   ``lam=None`` (MC), :mod:`pyfector.cv` chooses the hyperparameter by
   cross-validation.
4. :mod:`pyfector.estimators` iterates the EM loop for the chosen
   method (``fe``, ``ife``, ``mc``, or ``cfe``).
5. Treatment effects are computed from ``eff = Y - Y_ct``; the overall
   ATT and dynamic ATT by relative event time are derived in
   :func:`_compute_effects` via :func:`numpy.bincount`.  Raw missing
   outcomes remain excluded from effect averages; model imputations are
   counterfactual predictions, not substitutes for unobserved treated
   outcomes.
6. When ``se=True`` the requested inference routine
   (``bootstrap`` or ``jackknife``) from :mod:`pyfector.inference` is
   called with a closure that re-runs the EM loop on resampled or
   left-out units.
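
Step 4 can be pictured with a minimal sketch of EM-style imputation for
the simplest case: additive two-way fixed effects (``r=0``) with no
covariates or weights.  This illustrates the idea under those assumptions
and is not the :mod:`pyfector.estimators` code.  ``em_twfe_sketch`` and
its arguments are hypothetical.  On each pass, cells outside ``fit_mask``
(treated or missing) are replaced by the current fit, and the fixed
effects are re-estimated on the completed matrix.  This repeats until
the fit stops changing::

    import numpy as np

    def em_twfe_sketch(Y, fit_mask, tol=1e-7, max_iter=5000):
        # Hypothetical sketch. Y: (T, N) outcomes; fit_mask: (T, N) bool,
        # True for observed untreated cells that fit the response surface.
        # Initialize all other cells with the mean of the fitting cells.
        Y_work = np.where(fit_mask, Y, Y[fit_mask].mean())
        fit = Y_work.copy()
        for _ in range(max_iter):
            # M-step: two-way FE fit of the completed (balanced) matrix.
            mu = Y_work.mean()
            alpha = Y_work.mean(axis=0) - mu   # unit effects, (N,)
            xi = Y_work.mean(axis=1) - mu      # time effects, (T,)
            new_fit = mu + xi[:, None] + alpha[None, :]
            # E-step: keep fitting cells, impute everything else.
            Y_work = np.where(fit_mask, Y, new_fit)
            converged = np.max(np.abs(new_fit - fit)) < tol
            fit = new_fit
            if converged:
                break
        return fit  # plays the role of Y_ct; eff = Y - fit on treated cells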

Example
-------
::

    import pyfector

    result = pyfector.fect(
        data=df,             # long-format pandas or polars DataFrame
        Y="outcome", D="treat",
        index=("unit", "year"),
        X=["gdp", "pop"],
        method="ife",
        r=(0, 5),            # CV over 0..5 factors
        se=True,
        nboots=200,
        device="cpu",         # or "gpu"
        n_jobs=4,
        seed=42,
    )
    print(result.summary())
    result.plot(kind="gap")
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal
import os

import numpy as np

from .backend import set_device, get_backend, to_numpy, to_device, make_rng
from .panel import PanelData, prepare_panel, initial_fit
from .estimators import estimate_ife, estimate_mc, estimate_cfe, EstimationResult
from .cv import cv_ife, cv_mc, CVResult
from .inference import bootstrap, jackknife, InferenceResult
from .diagnostics import (
    Diagnostics,
    run_diagnostics as _run_diagnostics,
    validate_diagnostics_request,
)


@dataclass
class FectResult:
    """Container for all fect estimation results."""
    # Method info
    method: str
    r_cv: int | None = None
    lambda_cv: float | None = None

    # Point estimates
    att_avg: float = 0.0
    att_avg_unit: float = 0.0

    # Dynamic effects
    att_on: np.ndarray | None = None
    time_on: np.ndarray | None = None
    count_on: np.ndarray | None = None

    # Exit effects (treatment reversal)
    att_off: np.ndarray | None = None
    time_off: np.ndarray | None = None

    # Coefficients
    beta: np.ndarray | None = None
    covariate_names: list[str] = field(default_factory=list)

    # Fixed effects
    mu: float = 0.0
    alpha: np.ndarray | None = None   # unit FE
    xi: np.ndarray | None = None      # time FE
    factors: np.ndarray | None = None
    loadings: np.ndarray | None = None

    # Counterfactual and effects matrices
    Y_ct: np.ndarray | None = None    # T×N counterfactual
    eff: np.ndarray | None = None     # T×N treatment effects
    residuals: np.ndarray | None = None

    # Model fit
    sigma2: float = 0.0
    sigma2_fect: float = 0.0   # additive-FE baseline residual variance
    IC: float = 0.0
    PC: float = 0.0
    rmse: float = 0.0
    niter: int = 0
    converged: bool = False

    # Inference
    inference: InferenceResult | None = None

    # Diagnostics (populated when fect(..., diagnostics="full" | list))
    diagnostics: Diagnostics | None = None

    # CV
    cv_result: CVResult | None = None

    # Panel metadata
    panel: PanelData | None = None

    # Reproducibility
    seed: int | None = None

    def summary(self) -> str:
        """Return a formatted summary table of the results as a string."""
        lines = []
        lines.append(f"pyfector estimation results")
        lines.append(f"{'='*60}")
        lines.append(f"Method: {self.method}")
        if self.r_cv is not None:
            lines.append(f"Number of factors (CV): {self.r_cv}")
        if self.lambda_cv is not None:
            lines.append(f"Lambda (CV): {self.lambda_cv:.6f}")
        lines.append(f"Converged: {self.converged} (iter={self.niter})")
        lines.append(f"Sigma^2: {self.sigma2:.6f}")
        lines.append(f"Sigma^2_fect (FE baseline): {self.sigma2_fect:.6f}")
        lines.append(f"")
        lines.append(f"ATT (average): {self.att_avg:.6f}")
        if self.inference is not None:
            inf = self.inference
            lines.append(f"  SE:     {inf.att_avg_se:.6f}")
            lines.append(f"  CI:     [{inf.att_avg_ci[0]:.6f}, {inf.att_avg_ci[1]:.6f}]")
            lines.append(f"  p-val:  {inf.att_avg_pval:.4f}")

        if self.beta is not None and len(self.beta) > 0:
            lines.append(f"")
            lines.append(f"Coefficients:")
            for i, name in enumerate(self.covariate_names):
                lines.append(f"  {name}: {self.beta[i]:.6f}")

        if self.att_on is not None and self.time_on is not None:
            lines.append(f"")
            lines.append(f"Dynamic effects (ATT by relative time):")
            lines.append(f"  {'Time':>6s}  {'ATT':>10s}  {'Count':>6s}")
            for i, t in enumerate(self.time_on):
                count = self.count_on[i] if self.count_on is not None else ""
                att = self.att_on[i]
                if self.inference is not None:
                    se = self.inference.att_on_se[i]
                    lines.append(f"  {t:>6.0f}  {att:>10.4f}  ({se:.4f})  {count}")
                else:
                    lines.append(f"  {t:>6.0f}  {att:>10.4f}  {count}")

        lines.append(f"{'='*60}")
        if self.panel is not None:
            lines.append(f"N={self.panel.N}, T={self.panel.T}")
        if self.seed is not None:
            lines.append(f"Seed: {self.seed}")
        if self.diagnostics is not None:
            lines.append("")
            lines.append(self.diagnostics.summary())
        return "\n".join(lines)

    def __repr__(self):
        return self.summary()

    def plot(self, kind="gap", **kwargs):
        """Plot results. Shortcut for ``pyfector.plot(self, kind, ...)``."""
        from .plotting import plot as _plot
        return _plot(self, kind=kind, **kwargs)

    def diagnose(self, **kwargs):
        """Run diagnostic tests. Shortcut for ``pyfector.run_diagnostics(self, ...)``."""
        from .diagnostics import run_diagnostics
        return run_diagnostics(self, **kwargs)


def fect(
    data,
    Y: str,
    D: str,
    index: tuple[str, str],
    X: list[str] | None = None,
    W: str | None = None,
    group: str | None = None,
    method: Literal["fe", "ife", "mc", "cfe", "both"] = "ife",
    force: Literal["none", "unit", "time", "two-way"] = "two-way",
    r: int | tuple[int, int] = 0,
    lam: float | None = None,
    nlambda: int = 10,
    lambda_candidates: list[float] | np.ndarray | None = None,
    CV: bool = True,
    k: int = 10,
    cv_prop: float = 0.1,
    cv_nobs: int = 3,
    cv_treat: bool = True,
    cv_donut: int = 0,
    criterion: str = "mspe",
    cv_rule: Literal["min", "onepct"] = "min",
    se: bool = False,
    vartype: Literal["bootstrap", "jackknife"] = "bootstrap",
    nboots: int = 200,
    alpha: float = 0.05,
    tol: float = 1e-7,
    max_iter: int = 5000,
    min_T0: int = 1,
    min_T0_strict: bool = False,
    max_missing: float = 1.0,
    normalize: bool = False,
    # CFE-specific
    Z: list[str] | None = None,
    Q: list[str] | None = None,
    # Performance
    device: Literal["cpu", "gpu"] = "cpu",
    n_jobs: int | None = -1,
    seed: int | None = None,
    # Diagnostics (optional; run at fit time and attached to result)
    diagnostics: Literal["none", "full"] | list[str] = "none",
    diagnostics_options: dict | None = None,
) -> FectResult:
    """Estimate counterfactual treatment effects for panel data.

    This is the main Python entry point for the counterfactual estimator
    workflow.  Where the paper and the historical R package differ,
    pyfector defaults to the paper's statistical definition and exposes
    R-package-style behavior through explicit options.

    Missing outcome policy
    ----------------------
    pyfector distinguishes raw missing outcomes from counterfactual
    missingness caused by treatment.  Observed untreated cells
    (``D == 0`` and non-missing ``Y``) fit the response surface.  Observed
    treated cells (``D == 1`` and non-missing ``Y``) contribute to ATT as
    ``Y - Y_ct``.  If a treated outcome is missing in the input data, the
    model can still produce a counterfactual ``Y_ct`` for that cell, but
    the cell is not counted in ``att_avg`` or ``att_on`` because the
    treated potential outcome was not observed.
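
    A toy panel makes the rule concrete.  The arrays below are made up.
    They only mimic the ``(D > 0) & (I > 0)`` mask used by
    ``_compute_effects`` and are not pyfector's internal ``PanelData``
    layout::

        import numpy as np

        # Hypothetical 3-period, 2-unit panel: unit 1 is treated from t=1,
        # and its treated outcome at t=2 is missing in the raw data.
        Y = np.array([[1.0, 2.0],
                      [1.5, 5.0],
                      [2.0, np.nan]])
        D = np.array([[0, 0],
                      [0, 1],
                      [0, 1]])
        # The model still yields a counterfactual for every cell.
        Y_ct = np.array([[1.0, 2.0],
                         [1.5, 2.5],
                         [2.0, 3.0]])

        I = ~np.isnan(Y)               # observed-outcome indicator
        treated_obs = (D > 0) & I      # only these cells enter att_avg
        att_avg = float(np.mean((Y - Y_ct)[treated_obs]))
        # att_avg == 2.5: the t=2 cell has a Y_ct but no observed Y(1).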

    By default, ``min_T0`` is enforced only for treated and reversal
    units.  Sparse controls are retained if they have at least one
    observed outcome, because they may still inform the low-rank response
    surface.  Set ``min_T0_strict=True`` to require controls to satisfy
    ``min_T0`` too, matching the more conservative R fect sparse-panel
    behavior.
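
    The unit filter can be sketched as follows.  ``keep_units_sketch`` is
    a hypothetical helper, not :func:`pyfector.panel.prepare_panel`.  It
    assumes ``min_T0`` counts observed untreated periods and ignores
    ``max_missing``::

        import numpy as np

        def keep_units_sketch(D, observed, min_T0=1, strict=False):
            # Hypothetical sketch. D, observed: (T, N) boolean arrays.
            untreated_obs = (observed & ~D).sum(axis=0)  # per-unit count
            ever_treated = D.any(axis=0)                 # treated or reversal
            has_obs = observed.any(axis=0)               # at least one Y
            # min_T0 binds for ever-treated units, and for controls only
            # when strict=True (the min_T0_strict behavior).
            enforce = ever_treated | strict
            return has_obs & ((untreated_obs >= min_T0) | ~enforce)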

    Parameters
    ----------
    data : polars.DataFrame, pandas.DataFrame
        Long-format panel data.
    Y, D : str
        Column names for outcome and binary treatment indicator.
    index : (str, str)
        Column names for (unit_id, time_period).
    X : list of str, optional
        Time-varying covariates.
    W : str, optional
        Observation weight column.
    group : str, optional
        Reserved for grouped estimation. Currently raises
        ``NotImplementedError`` when supplied.
    method : {"fe", "ife", "mc", "cfe", "both"}
        Estimation method. ``"both"`` currently fits only the IFE model
        and returns its results.
    force : {"none", "unit", "time", "two-way"}
        Fixed effects specification.
    r : int or (int, int)
        Number of factors.  If a tuple and ``CV=True``, cross-validation
        selects from that range for ``"ife"`` and ``"both"``; otherwise the
        first element is used.  Ignored for ``method="fe"`` (which uses
        ``r=0``) and for ``method="mc"``.
    lam : float, optional
        Nuclear norm penalty for MC.  If None with ``CV=True``, it is
        selected by cross-validation; if None with ``CV=False``, ``0.0``
        is used.
    nlambda : int
        Number of automatically generated lambda candidates for MC CV.
    lambda_candidates : array-like, optional
        Explicit non-negative lambda candidates for MC CV. When supplied,
        ``nlambda`` is ignored.
    CV : bool
        If True, cross-validate over ``r`` for IFE (and ``"both"``) when
        ``r`` is a tuple, or over ``lam`` for MC when ``lam`` is None.
    k : int
        Number of CV folds.
    cv_prop : float
        Fraction of eligible observed control cells masked per CV fold.
    cv_nobs : int
        Number of consecutive within-unit observations to mask as a block.
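
        One simple way such block masking could work is sketched below.
        ``block_mask_sketch`` and its sampling scheme are hypothetical and
        are not taken from :mod:`pyfector.cv`.  The sketch holds out whole
        blocks of ``cv_nobs`` consecutive eligible cells in one unit until
        at least ``cv_prop`` of the eligible cells are masked::

            import numpy as np

            def block_mask_sketch(eligible, cv_prop=0.1, cv_nobs=3, seed=0):
                # Hypothetical sketch. eligible: (T, N) bool array of cells
                # that may be held out (e.g. observed control cells).
                rng = np.random.default_rng(seed)
                T, N = eligible.shape
                target = int(np.ceil(cv_prop * eligible.sum()))
                mask = np.zeros((T, N), dtype=bool)
                # Block starts whose next cv_nobs cells are all eligible.
                starts = [(t, i) for i in range(N)
                          for t in range(T - cv_nobs + 1)
                          if eligible[t:t + cv_nobs, i].all()]
                for j in rng.permutation(len(starts)):
                    if mask.sum() >= target:
                        break
                    t, i = starts[j]
                    mask[t:t + cv_nobs, i] = True  # blocks may overlap
                return mask
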
    cv_treat : bool
        If True, restrict CV masks to pre-treatment cells of ever-treated
        units. If False, use all observed control cells.
    cv_donut : int
        Exclude this many periods around treatment onset from CV evaluation.
    criterion : {"mspe", "gmspe", "mad"}
        Cross-validation loss.
    cv_rule : {"min", "onepct"}
        CV selection rule. ``"min"`` chooses the strict minimum-score
        candidate and is the paper-faithful default. ``"onepct"`` chooses
        the simplest candidate within 1% of the best score (lower ``r`` for
        IFE, higher ``lam`` for MC).
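
        The two rules can be sketched as follows.  The sketch assumes
        non-negative losses (lower is better) and reads "within 1%" as
        ``score <= 1.01 * best``.  ``select_candidate_sketch`` is
        hypothetical, not the :mod:`pyfector.cv` implementation::

            def select_candidate_sketch(candidates, scores, rule="min",
                                        simplest="low"):
                # Hypothetical sketch. simplest="low" for r (fewer factors),
                # "high" for lam (stronger penalty).
                best = min(scores)
                if rule == "min":
                    return candidates[scores.index(best)]
                # "onepct": every candidate within 1% of the best score ...
                near = [c for c, s in zip(candidates, scores)
                        if s <= 1.01 * best]
                # ... then the simplest of those.
                return min(near) if simplest == "low" else max(near)
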
    se : bool
        Compute standard errors via bootstrap/jackknife.
    vartype : {"bootstrap", "jackknife"}
        Inference method when ``se=True``.
    nboots : int
        Number of bootstrap replications. Ignored for jackknife.
    alpha : float
        Significance level for confidence intervals and tests.
    tol : float
        EM convergence tolerance for final point estimation.
    max_iter : int
        Maximum EM iterations.
    min_T0 : int
        Minimum untreated/pre-treatment observed periods. By default this is
        enforced only for treated and treatment-reversal units.
    min_T0_strict : bool
        If True, enforce ``min_T0`` on all units, including controls. This
        matches R fect's conservative handling of sparse control rows.
    max_missing : float
        Maximum missing-outcome fraction per unit, in ``[0, 1]``. Units with
        no observed outcomes are always dropped, regardless of this threshold,
        because they provide neither fitting information nor observed treated
        effects.
    normalize : bool
        If True, estimate on the outcome divided by its observed standard
        deviation (no centering), then transform effects back to the
        original scale.
    Z, Q : list of str, optional
        Reserved CFE interaction arguments. Currently raise
        ``NotImplementedError`` when supplied.
    device : {"cpu", "gpu"}
        Compute device.
    n_jobs : int, optional
        Parallel workers for CV and bootstrap/jackknife. ``-1`` or ``None``
        uses all available CPUs.  Forced to 1 when ``device="gpu"``.
    seed : int, optional
        Random seed for full reproducibility.
    diagnostics : {"none", "full"} or list of str
        Diagnostics to run at fit time and attach to ``result.diagnostics``.
        The request is validated before any estimation starts.
    diagnostics_options : dict, optional
        Extra keyword options forwarded to the diagnostics runner.
    """
    # Set device
    set_device(device)
    xp = get_backend()
    n_jobs = _resolve_n_jobs(n_jobs)
    if device == "gpu":
        n_jobs = 1

    if group is not None:
        raise NotImplementedError("The `group` argument is not implemented yet.")
    if Z is not None or Q is not None:
        raise NotImplementedError("The `Z` and `Q` CFE interaction arguments are not implemented yet.")
    if criterion not in {"mspe", "gmspe", "mad"}:
        raise ValueError("criterion must be 'mspe', 'gmspe', or 'mad'")
    if cv_rule not in {"min", "onepct"}:
        raise ValueError("cv_rule must be 'min' or 'onepct'")
    if min_T0 < 0:
        raise ValueError("min_T0 must be non-negative")
    if not 0.0 <= max_missing <= 1.0:
        raise ValueError("max_missing must be between 0 and 1")

    # Validate diagnostics request before doing any expensive estimation
    # so users don't wait for a 30-min MC fit only to find their config
    # is wrong.
    requested_diag = validate_diagnostics_request(
        diagnostics, diagnostics_options, se,
    )

    # Map force string to int
    force_map = {"none": 0, "unit": 1, "time": 2, "two-way": 3}
    force_int = force_map[force]

    # Prepare panel data
    panel = prepare_panel(
        data, Y=Y, D=D, index=index, X=X, W=W,
        group=group, min_T0=min_T0, min_T0_strict=min_T0_strict,
        max_missing=max_missing,
    )

    # Move to device
    Y_mat = to_device(panel.Y)
    D_mat = to_device(panel.D)
    I_mat = to_device(panel.I)
    II_mat = to_device(panel.II)
    X_mat = to_device(panel.X) if panel.X is not None else None
    W_mat = to_device(panel.W) if panel.W is not None else None

    # Normalize
    norm_factor = 1.0
    if normalize:
        sd_y = float(xp.std(Y_mat[I_mat > 0]))
        if sd_y > 0:
            Y_mat = Y_mat / sd_y
            norm_factor = sd_y

    # Initial fit
    Y0, beta0 = initial_fit(Y_mat, X_mat, II_mat, force_int)

    # Determine r and lambda
    r_cv = None
    lambda_cv = None
    cv_result = None

    if method == "ife":
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]

    elif method == "mc":
        if lam is None and CV:
            cv_result = cv_mc(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, lambda_candidates=lambda_candidates,
                nlambda=nlambda, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            lambda_cv = cv_result.best_lambda
        else:
            lambda_cv = lam if lam is not None else 0.0

    elif method == "fe":
        r_cv = 0

    elif method == "cfe":
        r_cv = r if isinstance(r, int) else r[0]

    # Point estimation
    if method in ("fe", "ife"):
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "mc":
        est = estimate_mc(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            lam=lambda_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "cfe":
        est = estimate_cfe(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "both":
        # "both" currently runs only the IFE estimator (with CV over r when
        # requested) and returns the IFE results; no MC fit is performed here.
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    else:
        raise ValueError(f"Unknown method: {method}")

    # Compute effects
    eff = Y_mat - est.fit
    Y_ct = est.fit

    # Additive-FE baseline residual variance (Liu et al. 2024 sigma2.fect).
    # For method="fe" the main estimator IS the additive-FE pass, so reuse
    # est.sigma2. For ife/mc/cfe/both, run an extra r=0 IFE pass on the
    # same panel with the user's requested FE structure.
    sigma2_value = float(est.sigma2)
    if method == "fe":
        sigma2_fect_value = sigma2_value
    else:
        est_fect = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=0, force=force_int, tol=tol, max_iter=max_iter,
        )
        sigma2_fect_value = float(est_fect.sigma2)

    # Denormalize.  Both residual variances are rescaled so that sigma2 and
    # sigma2_fect are reported on the same (original) outcome scale.
    if normalize and norm_factor != 1.0:
        eff = eff * norm_factor
        Y_ct = Y_ct * norm_factor
        Y_mat = Y_mat * norm_factor
        if est.beta is not None:
            est = est._replace(beta=est.beta * norm_factor)
        sigma2_value *= norm_factor ** 2
        sigma2_fect_value *= norm_factor ** 2

    # ATT computation
    att_avg, att_on, time_on, count_on, att_avg_unit = _compute_effects(
        to_numpy(eff), to_numpy(D_mat), to_numpy(panel.T_on), to_numpy(I_mat),
    )

    # Build result
    result = FectResult(
        method=method,
        r_cv=r_cv,
        lambda_cv=lambda_cv,
        att_avg=att_avg,
        att_avg_unit=att_avg_unit,
        att_on=att_on,
        time_on=time_on,
        count_on=count_on,
        beta=to_numpy(est.beta) if est.beta is not None else None,
        covariate_names=panel.covariate_names,
        mu=est.mu,
        alpha=to_numpy(est.alpha) if est.alpha is not None else None,
        xi=to_numpy(est.xi) if est.xi is not None else None,
        factors=to_numpy(est.factors) if est.factors is not None else None,
        loadings=to_numpy(est.loadings) if est.loadings is not None else None,
        Y_ct=to_numpy(Y_ct),
        eff=to_numpy(eff),
        residuals=to_numpy(est.residuals),
        sigma2=sigma2_value,
        sigma2_fect=sigma2_fect_value,
        IC=est.IC,
        PC=est.PC,
        niter=est.niter,
        converged=est.converged,
        cv_result=cv_result,
        panel=panel,
        seed=seed,
    )

    # Inference
    if se:
        result.inference = _run_inference(
            result, panel, Y_mat, X_mat, W_mat, beta0, Y0,
            method=method, r_cv=r_cv, lambda_cv=lambda_cv,
            force_int=force_int, tol=tol, max_iter=max_iter,
            vartype=vartype, nboots=nboots, alpha=alpha,
            n_jobs=n_jobs, seed=seed, normalize=normalize,
            norm_factor=norm_factor,
        )

    # Run requested diagnostics at fit time. requested_diag is None when
    # diagnostics="none". Validation already enforced se=True and
    # required-config presence.
    if requested_diag is not None:
        opts = dict(diagnostics_options or {})
        if "loo" in requested_diag:
            opts["loo"] = True
        else:
            opts.setdefault("loo", False)
        result.diagnostics = _run_diagnostics(
            result, _requested=requested_diag, **opts,
        )

    return result


def _compute_effects(eff, D, T_on, I):
    """Compute overall ATT, per-unit ATT, and dynamic ATT by event time.

    ATT averages are defined only over observed outcome cells.  A missing
    raw treated outcome has no observed ``Y(1)``, so it is excluded even
    though the estimator may have produced a counterfactual ``Y_ct`` for
    that matrix position.

    Dynamic ATT grouping is done with :func:`numpy.bincount`, so the cost
    is O(n + K), where n is the number of observed cells of ever-treated
    units and K is the span of relative event times, rather than one
    masked pass per distinct relative-time value.  Per-unit ATT uses a
    column-wise masked mean via ``np.add.reduce`` rather than a Python loop.
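
    A small worked example of the grouping step, using made-up relative
    times and effects (not real pyfector output)::

        import numpy as np

        rel_time = np.array([-2, -1, 0, 0, 1, -1])        # event time per cell
        eff = np.array([0.1, -0.1, 2.0, 4.0, 5.0, 0.3])   # Y - Y_ct per cell
        offset = int(rel_time.min())                      # shift so min maps to 0
        idx = rel_time - offset
        sums = np.bincount(idx, weights=eff)              # per-time sums
        counts = np.bincount(idx)                         # per-time counts
        keep = counts > 0
        time_on = offset + np.arange(len(counts))[keep]   # [-2, -1, 0, 1]
        att_on = sums[keep] / counts[keep]                # about [0.1, 0.1, 3.0, 5.0]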
    """
    treated = (D > 0) & (I > 0)
    n_treated = int(np.sum(treated))

    att_avg = float(np.sum(eff[treated]) / max(n_treated, 1))

    # Per-unit ATT (post-treatment only): sum then divide by count,
    # keeping only units that have at least one treated observation.
    # np.where (rather than eff * treated) keeps any NaN in unobserved
    # cells from propagating into the column sums.
    col_counts = treated.sum(axis=0)                                 # (N,)
    col_sums = np.where(treated, eff, 0.0).sum(axis=0)               # (N,)
    has_any = col_counts > 0
    if np.any(has_any):
        unit_atts = col_sums[has_any] / col_counts[has_any]
        att_avg_unit = float(unit_atts.mean())
    else:
        att_avg_unit = 0.0

    # Dynamic ATT by relative event time.  Include pre-treatment periods
    # for ever-treated units (counterfactual gaps before onset).
    ever_treated = np.any(D > 0, axis=0)                             # (N,)
    all_periods = (I > 0) & ever_treated[np.newaxis, :]              # (T, N)

    T_on_flat = T_on[all_periods]
    eff_flat = eff[all_periods]
    valid = ~np.isnan(T_on_flat)
    T_on_flat = T_on_flat[valid].astype(np.int64)
    eff_flat = eff_flat[valid]

    if T_on_flat.size == 0:
        return att_avg, np.array([]), np.array([]), np.array([], dtype=np.int64), att_avg_unit

    # Bincount over the shifted integer indices.
    offset = int(T_on_flat.min())
    idx = T_on_flat - offset
    minlength = int(T_on_flat.max() - offset + 1)
    sums = np.bincount(idx, weights=eff_flat, minlength=minlength)
    counts = np.bincount(idx, minlength=minlength)
    keep = counts > 0
    time_on = (offset + np.arange(minlength))[keep].astype(np.float64)
    att_on = (sums[keep] / counts[keep]).astype(np.float64)
    count_on = counts[keep].astype(np.int64)

    return att_avg, att_on, time_on, count_on, att_avg_unit


def _run_inference(
    result, panel, Y_mat, X_mat, W_mat, beta0, Y0,
    method, r_cv, lambda_cv, force_int, tol, max_iter,
    vartype, nboots, alpha, n_jobs, seed, normalize, norm_factor,
):
    """Run bootstrap or jackknife inference.

    Bootstrap and jackknife replicates use a relaxed convergence tolerance
    (``max(tol, 1e-3)``) because bootstrap SEs converge to 3–4
    significant digits regardless of inner precision.  The full-sample
    fit is used as the warm-start initialiser for each replicate,
    which cuts EM iterations by roughly 30-60%.
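
    For orientation, here is a sketch of a generic unit-level
    nonparametric bootstrap.  It is not
    :func:`pyfector.inference.bootstrap`, whose interface differs.
    ``unit_bootstrap_sketch`` and ``estimate_fn`` are hypothetical.
    Here ``estimate_fn`` maps resampled column indices to a scalar ATT,
    much as ``_estimate_fn`` maps them to a replicate's effect matrices::

        import numpy as np

        def unit_bootstrap_sketch(estimate_fn, N, nboots=200, alpha=0.05,
                                  seed=None):
            # Hypothetical sketch: resample whole units (columns) with
            # replacement, re-estimate, and summarize the replicates.
            rng = np.random.default_rng(seed)
            draws = np.empty(nboots)
            for b in range(nboots):
                unit_idx = rng.integers(0, N, size=N)
                draws[b] = estimate_fn(unit_idx)
            se = draws.std(ddof=1)
            # Percentile confidence interval.
            lo, hi = np.quantile(draws, [alpha / 2, 1 - alpha / 2])
            return se, (lo, hi)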
    """
    xp = get_backend()

    # Relaxed tolerance: bootstrap SEs don't benefit from tight inner tol.
    boot_tol = max(tol, 1e-3)

    # Pre-move panel data to device once (avoids CPU→GPU copy per bootstrap rep)
    II_dev = to_device(panel.II)
    D_dev = to_device(panel.D)
    I_dev = to_device(panel.I)
    T_on_np = panel.T_on  # keep on CPU for ATT computation

    # Warm-start: use full-sample fitted values as initial imputation
    # for each replicate instead of recomputing initial_fit from scratch.
    Y0_full = to_device(result.Y_ct)

    def _estimate_fn(unit_idx):
        """Re-estimate on a subset of units."""
        # fect() denormalizes Y_mat and result.Y_ct before calling this
        # function, while beta0, r_cv and lambda_cv were chosen on the
        # normalized scale.  Dividing by norm_factor (1.0 when
        # normalize=False) fits each replicate on the same scale as the
        # point estimate.
        Y_sub = Y_mat[:, unit_idx] / norm_factor
        II_sub = II_dev[:, unit_idx]
        D_sub = D_dev[:, unit_idx]
        I_sub = I_dev[:, unit_idx]
        X_sub = X_mat[:, unit_idx, :] if X_mat is not None else None
        W_sub = W_mat[:, unit_idx] if W_mat is not None else None
        T_on_sub = T_on_np[:, unit_idx]
        beta0_sub = beta0

        # Warm-start from full-sample fit (columns for resampled units).
        # This is much closer to the replicate's solution than a cold
        # initial_fit, cutting EM iterations significantly.
        Y0_sub = Y0_full[:, unit_idx] / norm_factor

        if method in ("fe", "ife"):
            est = estimate_ife(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                r=r_cv, force=force_int, tol=boot_tol, max_iter=max_iter,
            )
        elif method == "mc":
            est = estimate_mc(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                lam=lambda_cv, force=force_int, tol=boot_tol, max_iter=max_iter,
            )
        elif method == "cfe":
            # Re-run the same estimator that produced the point estimate.
            est = estimate_cfe(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                r=r_cv, force=force_int, tol=boot_tol, max_iter=max_iter,
            )
        else:
            est = estimate_ife(
                Y_sub, Y0_sub, X_sub, II_sub, W_sub, beta0_sub,
                r=r_cv or 0, force=force_int, tol=boot_tol, max_iter=max_iter,
            )

        # Map replicate effects back to the original outcome scale.
        eff = to_numpy(Y_sub - est.fit)
        if normalize and norm_factor != 1.0:
            eff = eff * norm_factor

        return eff, to_numpy(D_sub), T_on_sub, to_numpy(I_sub)

    if vartype == "bootstrap":
        return bootstrap(
            _estimate_fn, to_numpy(Y_mat), to_numpy(panel.D),
            to_numpy(panel.I), panel.T_on, panel.unit_type,
            nboots=nboots, alpha=alpha, n_jobs=n_jobs, seed=seed,
            point_estimate=(result.eff, panel.D, panel.T_on, panel.I),
        )
    else:
        return jackknife(
            _estimate_fn, to_numpy(Y_mat), to_numpy(panel.D),
            to_numpy(panel.I), panel.T_on, panel.unit_type,
            alpha=alpha, n_jobs=n_jobs,
        )


def _resolve_n_jobs(n_jobs: int | None) -> int:
    """Normalize public n_jobs values before passing them to joblib."""
    if n_jobs is None or n_jobs == -1:
        return os.cpu_count() or 1
    if n_jobs == 0 or n_jobs < -1:
        raise ValueError("n_jobs must be a positive integer, -1, or None")
    return int(n_jobs)
@dataclass
class FectResult:
@dataclass
class FectResult:
    """Container for all fect estimation results."""
    # Method info
    method: str
    r_cv: int | None = None
    lambda_cv: float | None = None

    # Point estimates
    att_avg: float = 0.0
    att_avg_unit: float = 0.0

    # Dynamic effects
    att_on: np.ndarray | None = None
    time_on: np.ndarray | None = None
    count_on: np.ndarray | None = None

    # Exit effects (treatment reversal)
    att_off: np.ndarray | None = None
    time_off: np.ndarray | None = None

    # Coefficients
    beta: np.ndarray | None = None
    covariate_names: list[str] = field(default_factory=list)

    # Fixed effects
    mu: float = 0.0
    alpha: np.ndarray | None = None   # unit FE
    xi: np.ndarray | None = None      # time FE
    factors: np.ndarray | None = None
    loadings: np.ndarray | None = None

    # Counterfactual and effects matrices
    Y_ct: np.ndarray | None = None    # T×N counterfactual
    eff: np.ndarray | None = None     # T×N treatment effects
    residuals: np.ndarray | None = None

    # Model fit
    sigma2: float = 0.0
    sigma2_fect: float = 0.0   # additive-FE baseline residual variance
    IC: float = 0.0
    PC: float = 0.0
    rmse: float = 0.0
    niter: int = 0
    converged: bool = False

    # Inference
    inference: InferenceResult | None = None

    # Diagnostics (populated when fect(..., diagnostics="full" | list))
    diagnostics: Diagnostics | None = None

    # CV
    cv_result: CVResult | None = None

    # Panel metadata
    panel: PanelData | None = None

    # Reproducibility
    seed: int | None = None

    def summary(self) -> str:
        """Return a formatted summary table of results as a string."""
        lines = []
        lines.append("pyfector estimation results")
        lines.append("=" * 60)
        lines.append(f"Method: {self.method}")
        if self.r_cv is not None:
            lines.append(f"Number of factors (CV): {self.r_cv}")
        if self.lambda_cv is not None:
            lines.append(f"Lambda (CV): {self.lambda_cv:.6f}")
        lines.append(f"Converged: {self.converged} (iter={self.niter})")
        lines.append(f"Sigma^2: {self.sigma2:.6f}")
        lines.append(f"Sigma^2_fect (FE baseline): {self.sigma2_fect:.6f}")
        lines.append("")
        lines.append(f"ATT (average): {self.att_avg:.6f}")
        if self.inference is not None:
            inf = self.inference
            lines.append(f"  SE:     {inf.att_avg_se:.6f}")
            lines.append(f"  CI:     [{inf.att_avg_ci[0]:.6f}, {inf.att_avg_ci[1]:.6f}]")
            lines.append(f"  p-val:  {inf.att_avg_pval:.4f}")

        if self.beta is not None and len(self.beta) > 0:
            lines.append("")
            lines.append("Coefficients:")
            for i, name in enumerate(self.covariate_names):
                lines.append(f"  {name}: {self.beta[i]:.6f}")

        if self.att_on is not None and self.time_on is not None:
            show_se = self.inference is not None
            lines.append("")
            lines.append("Dynamic effects (ATT by relative time):")
            # Header columns match the row layout, including SE when present.
            header = f"  {'Time':>6s}  {'ATT':>10s}"
            if show_se:
                header += f"  {'SE':>10s}"
            lines.append(header + f"  {'Count':>6s}")
            for i, t in enumerate(self.time_on):
                count = self.count_on[i] if self.count_on is not None else ""
                row = f"  {t:>6.0f}  {self.att_on[i]:>10.4f}"
                if show_se:
                    row += f"  {self.inference.att_on_se[i]:>10.4f}"
                lines.append(row + f"  {count!s:>6}")

        lines.append("=" * 60)
        if self.panel is not None:
            lines.append(f"N={self.panel.N}, T={self.panel.T}")
        if self.seed is not None:
            lines.append(f"Seed: {self.seed}")
        if self.diagnostics is not None:
            lines.append("")
            lines.append(self.diagnostics.summary())
        return "\n".join(lines)

    def __repr__(self):
        return self.summary()

    def plot(self, kind="gap", **kwargs):
        """Plot results. Shortcut for ``pyfector.plot(self, kind, ...)``."""
        from .plotting import plot as _plot
        return _plot(self, kind=kind, **kwargs)

    def diagnose(self, **kwargs):
        """Run diagnostic tests. Shortcut for ``pyfector.run_diagnostics(self, ...)``."""
        from .diagnostics import run_diagnostics
        return run_diagnostics(self, **kwargs)

Container for all fect estimation results.

FectResult( method: str, r_cv: int | None = None, lambda_cv: float | None = None, att_avg: float = 0.0, att_avg_unit: float = 0.0, att_on: numpy.ndarray | None = None, time_on: numpy.ndarray | None = None, count_on: numpy.ndarray | None = None, att_off: numpy.ndarray | None = None, time_off: numpy.ndarray | None = None, beta: numpy.ndarray | None = None, covariate_names: list[str] = <factory>, mu: float = 0.0, alpha: numpy.ndarray | None = None, xi: numpy.ndarray | None = None, factors: numpy.ndarray | None = None, loadings: numpy.ndarray | None = None, Y_ct: numpy.ndarray | None = None, eff: numpy.ndarray | None = None, residuals: numpy.ndarray | None = None, sigma2: float = 0.0, sigma2_fect: float = 0.0, IC: float = 0.0, PC: float = 0.0, rmse: float = 0.0, niter: int = 0, converged: bool = False, inference: pyfector.inference.InferenceResult | None = None, diagnostics: pyfector.Diagnostics | None = None, cv_result: pyfector.cv.CVResult | None = None, panel: pyfector.panel.PanelData | None = None, seed: int | None = None)
method: str
r_cv: int | None = None
lambda_cv: float | None = None
att_avg: float = 0.0
att_avg_unit: float = 0.0
att_on: numpy.ndarray | None = None
time_on: numpy.ndarray | None = None
count_on: numpy.ndarray | None = None
att_off: numpy.ndarray | None = None
time_off: numpy.ndarray | None = None
beta: numpy.ndarray | None = None
covariate_names: list[str]
mu: float = 0.0
alpha: numpy.ndarray | None = None
xi: numpy.ndarray | None = None
factors: numpy.ndarray | None = None
loadings: numpy.ndarray | None = None
Y_ct: numpy.ndarray | None = None
eff: numpy.ndarray | None = None
residuals: numpy.ndarray | None = None
sigma2: float = 0.0
sigma2_fect: float = 0.0
IC: float = 0.0
PC: float = 0.0
rmse: float = 0.0
niter: int = 0
converged: bool = False
inference: pyfector.inference.InferenceResult | None = None
diagnostics: pyfector.Diagnostics | None = None
cv_result: pyfector.cv.CVResult | None = None
panel: pyfector.panel.PanelData | None = None
seed: int | None = None
def summary(self) -> str:

Return a formatted summary table of results as a string.

def plot(self, kind='gap', **kwargs):

Plot results. Shortcut for pyfector.plot(self, kind, ...).

def diagnose(self, **kwargs):

Run diagnostic tests. Shortcut for pyfector.run_diagnostics(self, ...).

def fect( data, Y: str, D: str, index: tuple[str, str], X: list[str] | None = None, W: str | None = None, group: str | None = None, method: Literal['fe', 'ife', 'mc', 'cfe', 'both'] = 'ife', force: Literal['none', 'unit', 'time', 'two-way'] = 'two-way', r: int | tuple[int, int] = 0, lam: float | None = None, nlambda: int = 10, lambda_candidates: list[float] | numpy.ndarray | None = None, CV: bool = True, k: int = 10, cv_prop: float = 0.1, cv_nobs: int = 3, cv_treat: bool = True, cv_donut: int = 0, criterion: str = 'mspe', cv_rule: Literal['min', 'onepct'] = 'min', se: bool = False, vartype: Literal['bootstrap', 'jackknife'] = 'bootstrap', nboots: int = 200, alpha: float = 0.05, tol: float = 1e-07, max_iter: int = 5000, min_T0: int = 1, min_T0_strict: bool = False, max_missing: float = 1.0, normalize: bool = False, Z: list[str] | None = None, Q: list[str] | None = None, device: Literal['cpu', 'gpu'] = 'cpu', n_jobs: int | None = -1, seed: int | None = None, diagnostics: Union[Literal['none', 'full'], list[str]] = 'none', diagnostics_options: dict | None = None) -> FectResult:
def fect(
    data,
    Y: str,
    D: str,
    index: tuple[str, str],
    X: list[str] | None = None,
    W: str | None = None,
    group: str | None = None,
    method: Literal["fe", "ife", "mc", "cfe", "both"] = "ife",
    force: Literal["none", "unit", "time", "two-way"] = "two-way",
    r: int | tuple[int, int] = 0,
    lam: float | None = None,
    nlambda: int = 10,
    lambda_candidates: list[float] | np.ndarray | None = None,
    CV: bool = True,
    k: int = 10,
    cv_prop: float = 0.1,
    cv_nobs: int = 3,
    cv_treat: bool = True,
    cv_donut: int = 0,
    criterion: str = "mspe",
    cv_rule: Literal["min", "onepct"] = "min",
    se: bool = False,
    vartype: Literal["bootstrap", "jackknife"] = "bootstrap",
    nboots: int = 200,
    alpha: float = 0.05,
    tol: float = 1e-7,
    max_iter: int = 5000,
    min_T0: int = 1,
    min_T0_strict: bool = False,
    max_missing: float = 1.0,
    normalize: bool = False,
    # CFE-specific
    Z: list[str] | None = None,
    Q: list[str] | None = None,
    # Performance
    device: Literal["cpu", "gpu"] = "cpu",
    n_jobs: int | None = -1,
    seed: int | None = None,
    # Diagnostics (optional; run at fit time and attached to result)
    diagnostics: Literal["none", "full"] | list[str] = "none",
    diagnostics_options: dict | None = None,
) -> FectResult:
    """Estimate counterfactual treatment effects for panel data.

    This is the main Python entry point for the counterfactual estimator
    workflow.  Where the paper and the historical R package differ,
    pyfector defaults to the paper's statistical definition and exposes
    R-package-style behavior through explicit options.

    Missing outcome policy
    ----------------------
    pyfector distinguishes raw missing outcomes from counterfactual
    missingness caused by treatment.  Observed untreated cells
    (``D == 0`` and non-missing ``Y``) fit the response surface.  Observed
    treated cells (``D == 1`` and non-missing ``Y``) contribute to ATT as
    ``Y - Y_ct``.  If a treated outcome is missing in the input data, the
    model can still produce a counterfactual ``Y_ct`` for that cell, but
    the cell is not counted in ``att_avg`` or ``att_on`` because the
    treated potential outcome was not observed.

    By default, ``min_T0`` is enforced only for treated and reversal
    units.  Sparse controls are retained if they have at least one
    observed outcome, because they may still inform the low-rank response
    surface.  Set ``min_T0_strict=True`` to require controls to satisfy
    ``min_T0`` too, matching the more conservative R fect sparse-panel
    behavior.

    Parameters
    ----------
    data : polars.DataFrame, pandas.DataFrame
        Long-format panel data.
    Y, D : str
        Column names for outcome and binary treatment indicator.
    index : (str, str)
        Column names for (unit_id, time_period).
    X : list of str, optional
        Time-varying covariates.
    W : str, optional
        Observation weight column.
    group : str, optional
        Reserved for grouped estimation. Currently raises
        ``NotImplementedError`` when supplied.
    method : {"fe", "ife", "mc", "cfe", "both"}
        Estimation method. ``"both"`` currently fits IFE only and does
        not yet report an MC comparison.
    force : {"none", "unit", "time", "two-way"}
        Fixed effects specification.
    r : int or (int, int)
        Number of factors.  For ``"ife"`` and ``"both"``, a tuple with
        ``CV=True`` makes CV select from that range; otherwise the first
        element is used.
    lam : float, optional
        Nuclear norm penalty for MC.  If None and ``CV=True``, it is
        selected by CV; if None and ``CV=False``, 0.0 is used.
    nlambda : int
        Number of automatically generated lambda candidates for MC CV.
    lambda_candidates : array-like, optional
        Explicit non-negative lambda candidates for MC CV. When supplied,
        ``nlambda`` is ignored.
    CV : bool
        If True, cross-validate over ``r`` for IFE when ``r`` is a tuple,
        or over ``lam`` for MC when ``lam`` is None.
    k : int
        Number of CV folds.
    cv_prop : float
        Fraction of eligible observed control cells masked per CV fold.
    cv_nobs : int
        Number of consecutive within-unit observations to mask as a block.
    cv_treat : bool
        If True, restrict CV masks to pre-treatment cells of ever-treated
        units. If False, use all observed control cells.
    cv_donut : int
        Exclude this many periods around treatment onset from CV evaluation.
    criterion : {"mspe", "gmspe", "mad"}
        Cross-validation loss.
    cv_rule : {"min", "onepct"}
        CV selection rule. ``"min"`` chooses the strict minimum-score
        candidate and is the paper-faithful default. ``"onepct"`` chooses
        the simplest candidate within 1% of the best score (lower ``r`` for
        IFE, higher ``lam`` for MC).
    se : bool
        Compute standard errors via bootstrap or jackknife.
    vartype : {"bootstrap", "jackknife"}
        Inference method when ``se=True``.
    nboots : int
        Number of bootstrap replications. Ignored for jackknife.
    alpha : float
        Significance level for confidence intervals and tests.
    tol : float
        EM convergence tolerance, used for the final point estimate and
        for the CV fits.
    max_iter : int
        Maximum EM iterations.
    min_T0 : int
        Minimum untreated/pre-treatment observed periods. By default this is
        enforced only for treated and treatment-reversal units.
    min_T0_strict : bool
        If True, enforce ``min_T0`` on all units, including controls. This
        matches R fect's conservative handling of sparse control rows.
    max_missing : float
        Maximum missing-outcome fraction per unit, in ``[0, 1]``. Units with
        no observed outcomes are always dropped, regardless of this threshold,
        because they provide neither fitting information nor observed treated
        effects.
    normalize : bool
        If True, estimate on an outcome standardized by its observed standard
        deviation, then transform effects, coefficients, residuals and
        residual variances back to the original scale.
    Z, Q : list of str, optional
        Reserved CFE interaction arguments. Currently raise
        ``NotImplementedError`` when supplied.
    device : {"cpu", "gpu"}
        Compute device.
    n_jobs : int, optional
        Parallel workers for CV and bootstrap. ``-1`` or ``None`` uses
        all available CPUs; ``0`` or values below ``-1`` raise
        ``ValueError``. Forced to 1 when ``device="gpu"``.
    seed : int, optional
        Random seed for full reproducibility.
    diagnostics : {"none", "full"} or list of str
        Diagnostics to run at fit time and attach to
        ``result.diagnostics``. Any request other than ``"none"``
        requires ``se=True``; the request is validated before estimation
        starts.
    diagnostics_options : dict, optional
        Extra options forwarded to the diagnostics runner. ``loo`` is
        forced to True when ``"loo"`` is requested and defaults to False
        otherwise.
    """
    # Set device
    set_device(device)
    xp = get_backend()
    n_jobs = _resolve_n_jobs(n_jobs)
    if device == "gpu":
        n_jobs = 1

    if group is not None:
        raise NotImplementedError("The `group` argument is not implemented yet.")
    if Z is not None or Q is not None:
        raise NotImplementedError("The `Z` and `Q` CFE interaction arguments are not implemented yet.")
    if method not in {"fe", "ife", "mc", "cfe", "both"}:
        raise ValueError(f"Unknown method: {method}")
    if force not in {"none", "unit", "time", "two-way"}:
        raise ValueError("force must be 'none', 'unit', 'time', or 'two-way'")
    if criterion not in {"mspe", "gmspe", "mad"}:
        raise ValueError("criterion must be 'mspe', 'gmspe', or 'mad'")
    if cv_rule not in {"min", "onepct"}:
        raise ValueError("cv_rule must be 'min' or 'onepct'")
    if min_T0 < 0:
        raise ValueError("min_T0 must be non-negative")
    if not 0.0 <= max_missing <= 1.0:
        raise ValueError("max_missing must be between 0 and 1")

    # Validate diagnostics request before doing any expensive estimation
    # so users don't wait for a 30-min MC fit only to find their config
    # is wrong.
    requested_diag = validate_diagnostics_request(
        diagnostics, diagnostics_options, se,
    )

    # Map force string to int
    force_map = {"none": 0, "unit": 1, "time": 2, "two-way": 3}
    force_int = force_map[force]

    # Prepare panel data
    panel = prepare_panel(
        data, Y=Y, D=D, index=index, X=X, W=W,
        group=group, min_T0=min_T0, min_T0_strict=min_T0_strict,
        max_missing=max_missing,
    )

    # Move to device
    Y_mat = to_device(panel.Y)
    D_mat = to_device(panel.D)
    I_mat = to_device(panel.I)
    II_mat = to_device(panel.II)
    X_mat = to_device(panel.X) if panel.X is not None else None
    W_mat = to_device(panel.W) if panel.W is not None else None

    # Normalize
    norm_factor = 1.0
    if normalize:
        sd_y = float(xp.std(Y_mat[I_mat > 0]))
        if sd_y > 0:
            Y_mat = Y_mat / sd_y
            norm_factor = sd_y

    # Initial fit
    Y0, beta0 = initial_fit(Y_mat, X_mat, II_mat, force_int)

    # Determine r and lambda
    r_cv = None
    lambda_cv = None
    cv_result = None

    if method == "ife":
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]

    elif method == "mc":
        if lam is None and CV:
            cv_result = cv_mc(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, lambda_candidates=lambda_candidates,
                nlambda=nlambda, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            lambda_cv = cv_result.best_lambda
        else:
            lambda_cv = lam if lam is not None else 0.0

    elif method == "fe":
        r_cv = 0

    elif method == "cfe":
        r_cv = r if isinstance(r, int) else r[0]

    # Point estimation
    if method in ("fe", "ife"):
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "mc":
        est = estimate_mc(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            lam=lambda_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "cfe":
        est = estimate_cfe(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    elif method == "both":
        # Currently fits IFE only; MC is not run, so no IFE/MC comparison
        # is attached to the result yet.
        if isinstance(r, tuple) and CV:
            cv_result = cv_ife(
                Y_mat, Y0, X_mat, I_mat, II_mat, D_mat, W_mat, beta0,
                force=force_int, r_range=r, k=k, cv_prop=cv_prop,
                cv_nobs=cv_nobs, cv_treat=cv_treat, cv_donut=cv_donut,
                criterion=criterion, cv_rule=cv_rule,
                tol=tol, max_iter=max_iter,
                n_jobs=n_jobs, seed=seed,
            )
            r_cv = cv_result.best_r
        else:
            r_cv = r if isinstance(r, int) else r[0]
        est = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=r_cv, force=force_int, tol=tol, max_iter=max_iter,
        )
    else:
        raise ValueError(f"Unknown method: {method}")

    # Compute effects
    eff = Y_mat - est.fit
    Y_ct = est.fit

    # Additive-FE baseline residual variance (Liu et al. 2024 sigma2.fect).
    # For method="fe" the main estimator IS the additive-FE pass, so reuse
    # est.sigma2. For ife/mc/cfe/both, run an extra r=0 IFE pass on the
    # same panel with the user's requested FE structure.
    if method == "fe":
        sigma2_fect_value = float(est.sigma2)
    else:
        est_fect = estimate_ife(
            Y_mat, Y0, X_mat, II_mat, W_mat, beta0,
            r=0, force=force_int, tol=tol, max_iter=max_iter,
        )
        sigma2_fect_value = float(est_fect.sigma2)

    # Denormalize back to the original outcome scale. Effects, coefficients
    # and residuals scale with sd_y; residual variances scale with sd_y**2.
    # Fixed effects, factors and loadings stay on the standardized scale.
    if normalize and norm_factor != 1.0:
        eff = eff * norm_factor
        Y_ct = Y_ct * norm_factor
        Y_mat = Y_mat * norm_factor
        updates = {
            "residuals": est.residuals * norm_factor,
            "sigma2": est.sigma2 * norm_factor ** 2,
        }
        if est.beta is not None:
            updates["beta"] = est.beta * norm_factor
        est = est._replace(**updates)
        sigma2_fect_value *= norm_factor ** 2

    # ATT computation
    att_avg, att_on, time_on, count_on, att_avg_unit = _compute_effects(
        to_numpy(eff), to_numpy(D_mat), to_numpy(panel.T_on), to_numpy(I_mat),
    )

    # Build result
    result = FectResult(
        method=method,
        r_cv=r_cv,
        lambda_cv=lambda_cv,
        att_avg=att_avg,
        att_avg_unit=att_avg_unit,
        att_on=att_on,
        time_on=time_on,
        count_on=count_on,
        beta=to_numpy(est.beta) if est.beta is not None else None,
        covariate_names=panel.covariate_names,
        mu=est.mu,
        alpha=to_numpy(est.alpha) if est.alpha is not None else None,
        xi=to_numpy(est.xi) if est.xi is not None else None,
        factors=to_numpy(est.factors) if est.factors is not None else None,
        loadings=to_numpy(est.loadings) if est.loadings is not None else None,
        Y_ct=to_numpy(Y_ct),
        eff=to_numpy(eff),
        residuals=to_numpy(est.residuals),
        sigma2=float(est.sigma2),
        sigma2_fect=sigma2_fect_value,
        IC=est.IC,
        PC=est.PC,
        niter=est.niter,
        converged=est.converged,
        cv_result=cv_result,
        panel=panel,
        seed=seed,
    )

    # Inference
    if se:
        result.inference = _run_inference(
            result, panel, Y_mat, X_mat, W_mat, beta0, Y0,
            method=method, r_cv=r_cv, lambda_cv=lambda_cv,
            force_int=force_int, tol=tol, max_iter=max_iter,
            vartype=vartype, nboots=nboots, alpha=alpha,
            n_jobs=n_jobs, seed=seed, normalize=normalize,
            norm_factor=norm_factor,
        )

    # Run requested diagnostics at fit time. requested_diag is None when
    # diagnostics="none". Validation already enforced se=True and
    # required-config presence.
    if requested_diag is not None:
        opts = dict(diagnostics_options or {})
        if "loo" in requested_diag:
            opts["loo"] = True
        else:
            opts.setdefault("loo", False)
        result.diagnostics = _run_diagnostics(
            result, _requested=requested_diag, **opts,
        )

    return result
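The denormalization step depends on how each quantity scales with the outcome. The check below uses ordinary least squares as a stand-in for the EM fit. That is an assumption: the argument is exact for estimators that are linear in `Y`, but only approximate for the penalized MC fit, whose `lam` is chosen on the standardized scale. It shows that coefficients scale by `sd` and residual variances by `sd ** 2`.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = X @ np.array([1.5, -0.5]) + rng.normal(scale=2.0, size=200)


def ols_fit(X, y):
    """Least-squares stand-in for the EM point estimate."""
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ beta
    return beta, float(np.mean(resid ** 2))


beta_raw, sigma2_raw = ols_fit(X, y)

# normalize=True analogue: fit on y / sd, then map back to the raw scale.
sd = float(np.std(y))
beta_std, sigma2_std = ols_fit(X, y / sd)
beta_back = beta_std * sd            # coefficients scale like y
sigma2_back = sigma2_std * sd ** 2   # variances scale like y squared
```

A variance that is only rescaled by `sd`, or not rescaled at all, is left on a different scale from `sigma2_fect`. In that case the two summary lines are no longer comparable.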

Estimate counterfactual treatment effects for panel data.

This is the main Python entry point for the counterfactual estimator workflow. Where the paper and the historical R package differ, pyfector defaults to the paper's statistical definition and exposes R-package-style behavior through explicit options.

Missing outcome policy

pyfector distinguishes raw missing outcomes from counterfactual missingness caused by treatment. Observed untreated cells (D == 0 and non-missing Y) fit the response surface. Observed treated cells (D == 1 and non-missing Y) contribute to ATT as Y - Y_ct. If a treated outcome is missing in the input data, the model can still produce a counterfactual Y_ct for that cell, but the cell is not counted in att_avg or att_on because the treated potential outcome was not observed.
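The small NumPy sketch below makes this policy concrete on a hand-built `(T, N)` panel. It is not pyfector's `_compute_effects`. Two details are assumptions for illustration:

- the relative-time convention, where the first treated period is labelled 1;
- the use of `numpy.bincount` for the per-period averages.

The missing treated outcome produces `NaN` in `eff` and is left out of both the overall and the dynamic averages. Its counterfactual `Y_ct` still exists.

```python
import numpy as np

# Toy (T, N) = (4, 4) panel. Units 0-1 are controls; unit 2 is treated
# from period 1, unit 3 from period 2.
D = np.array([
    [0, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
])
Y_ct = np.full((4, 4), 10.0)          # counterfactual predictions
Y = Y_ct.copy()
Y[1, 2], Y[2, 2] = 11.0, 12.0         # observed treated outcomes
Y[3, 2] = np.nan                      # treated outcome missing in raw data
Y[2, 3], Y[3, 3] = 13.0, 14.0

eff = Y - Y_ct                        # NaN wherever Y is missing
counted = (D == 1) & ~np.isnan(Y)     # observed treated cells only

att_avg = eff[counted].mean()

# Relative event time, with the first treated period labelled 1.
onset = D.argmax(axis=0)              # first treated period per unit
rel = np.arange(D.shape[0])[:, None] - onset[None, :] + 1
r = rel[counted]
offset = r.min()                      # bincount needs non-negative bins
sums = np.bincount(r - offset, weights=eff[counted])
count_on = np.bincount(r - offset)
keep = count_on > 0
time_on = np.arange(len(count_on))[keep] + offset
att_on = sums[keep] / count_on[keep]
count_on = count_on[keep]
```

Here `att_avg` is 2.5 from four counted cells. `att_on` is `[2.0, 3.0]` at relative times 1 and 2, each with a count of 2. The shift by `offset` would also accommodate non-positive relative times if pre-treatment cells were included.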

By default, min_T0 is enforced only for treated and reversal units. Sparse controls are retained if they have at least one observed outcome, because they may still inform the low-rank response surface. Set min_T0_strict=True to require controls to satisfy min_T0 too, matching the more conservative R fect sparse-panel behavior.
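The retention rule can be written as a per-unit mask. `keep_units` below is a hypothetical helper, not `pyfector.panel.prepare_panel`, and it rests on three assumptions:

- any ever-treated unit, including a reversal unit, is subject to `min_T0`;
- "untreated periods" means observed cells with `D == 0`;
- units with no observed outcome are always dropped.

```python
import numpy as np


def keep_units(D, observed, min_T0=1, strict=False):
    """Hypothetical per-unit retention mask for a (T, N) panel."""
    n_observed = observed.sum(axis=0)
    # Observed untreated periods: pre-treatment periods for staggered
    # adopters, plus any post-reversal periods for reversal units.
    n_untreated = ((D == 0) & observed).sum(axis=0)
    ever_treated = (D == 1).any(axis=0)
    # min_T0 applies to treated/reversal units, and to controls if strict.
    subject_to_min_T0 = ever_treated | strict
    meets_min_T0 = n_untreated >= min_T0
    # Units with no observed outcome are always dropped.
    return (n_observed >= 1) & (meets_min_T0 | ~subject_to_min_T0)


# Unit 0: fully observed control. Unit 1: control observed only at t = 3, 4.
# Unit 2: treated from t = 3 on, so it has 3 observed pre-treatment periods.
D = np.zeros((5, 3), dtype=int)
D[3:, 2] = 1
observed = np.ones((5, 3), dtype=bool)
observed[:3, 1] = False

print(keep_units(D, observed, min_T0=4))               # [ True  True False]
print(keep_units(D, observed, min_T0=4, strict=True))  # [ True False False]
```

With the default policy, the sparse control (unit 1) survives because it has observed outcomes. Under `strict=True` it is dropped along with the treated unit, which has too few pre-treatment periods.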

Parameters

data : polars.DataFrame, pandas.DataFrame
    Long-format panel data.
Y, D : str
    Column names for outcome and binary treatment indicator.
index : (str, str)
    Column names for (unit_id, time_period).
X : list of str, optional
    Time-varying covariates.
W : str, optional
    Observation weight column.
group : str, optional
    Reserved for grouped estimation. Currently raises NotImplementedError when supplied.
method : {"fe", "ife", "mc", "cfe", "both"}
    Estimation method. "both" currently fits IFE only and does not yet report an MC comparison.
force : {"none", "unit", "time", "two-way"}
    Fixed effects specification.
r : int or (int, int)
    Number of factors. For "ife" and "both", a tuple with CV=True makes CV select from that range; otherwise the first element is used.
lam : float, optional
    Nuclear norm penalty for MC. If None and CV=True, it is selected by CV; if None and CV=False, 0.0 is used.
nlambda : int
    Number of automatically generated lambda candidates for MC CV.
lambda_candidates : array-like, optional
    Explicit non-negative lambda candidates for MC CV. When supplied, nlambda is ignored.
CV : bool
    If True, cross-validate over r for IFE when r is a tuple, or over lam for MC when lam is None.
k : int
    Number of CV folds.
cv_prop : float
    Fraction of eligible observed control cells masked per CV fold.
cv_nobs : int
    Number of consecutive within-unit observations to mask as a block.
cv_treat : bool
    If True, restrict CV masks to pre-treatment cells of ever-treated units. If False, use all observed control cells.
cv_donut : int
    Exclude this many periods around treatment onset from CV evaluation.
criterion : {"mspe", "gmspe", "mad"}
    Cross-validation loss.
cv_rule : {"min", "onepct"}
    CV selection rule. "min" chooses the strict minimum-score candidate and is the paper-faithful default. "onepct" chooses the simplest candidate within 1% of the best score (lower r for IFE, higher lam for MC).
se : bool
    Compute standard errors via bootstrap or jackknife.
vartype : {"bootstrap", "jackknife"}
    Inference method when se=True.
nboots : int
    Number of bootstrap replications. Ignored for jackknife.
alpha : float
    Significance level for confidence intervals and tests.
tol : float
    EM convergence tolerance, used for the final point estimate and for the CV fits.
max_iter : int
    Maximum EM iterations.
min_T0 : int
    Minimum untreated/pre-treatment observed periods. By default this is enforced only for treated and treatment-reversal units.
min_T0_strict : bool
    If True, enforce min_T0 on all units, including controls. This matches R fect's conservative handling of sparse control rows.
max_missing : float
    Maximum missing-outcome fraction per unit, in [0, 1]. Units with no observed outcomes are always dropped, regardless of this threshold, because they provide neither fitting information nor observed treated effects.
normalize : bool
    If True, estimate on an outcome standardized by its observed standard deviation, then transform effects, coefficients, residuals and residual variances back to the original scale.
Z, Q : list of str, optional
    Reserved CFE interaction arguments. Currently raise NotImplementedError when supplied.
device : {"cpu", "gpu"}
    Compute device.
n_jobs : int, optional
    Parallel workers for CV and bootstrap. -1 or None uses all available CPUs; 0 or values below -1 raise ValueError. Forced to 1 when device="gpu".
seed : int, optional
    Random seed for full reproducibility.
diagnostics : {"none", "full"} or list of str
    Diagnostics to run at fit time and attach to result.diagnostics. Any request other than "none" requires se=True; the request is validated before estimation starts.
diagnostics_options : dict, optional
    Extra options forwarded to the diagnostics runner. loo is forced to True when "loo" is requested and defaults to False otherwise.
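The two cv_rule options differ only in how they pick from the scored grid. `select_candidate` below is a hypothetical restatement, not pyfector's CV code. It assumes non-negative CV losses, that "within 1%" means `score <= 1.01 * best`, and that ties under "min" go to the first candidate in the grid.

```python
from collections.abc import Sequence


def select_candidate(
    candidates: Sequence[float],
    scores: Sequence[float],
    rule: str = "min",
    prefer: str = "low",
) -> float:
    """Hypothetical restatement of the cv_rule options.

    prefer="low" picks the fewest factors (IFE); prefer="high" picks the
    largest penalty (MC). Scores are CV losses, assumed non-negative.
    """
    best = min(scores)
    if rule == "min":
        # Strict minimum; ties go to the first candidate in the grid.
        return candidates[list(scores).index(best)]
    if rule != "onepct":
        raise ValueError("cv_rule must be 'min' or 'onepct'")
    # Every candidate whose score is within 1% of the best score.
    eligible = [c for c, s in zip(candidates, scores) if s <= best * 1.01]
    # Among those, return the simplest model.
    return min(eligible) if prefer == "low" else max(eligible)


r_grid, r_scores = [0, 1, 2, 3], [2.0, 1.005, 1.0, 1.2]
print(select_candidate(r_grid, r_scores))                  # 2
print(select_candidate(r_grid, r_scores, rule="onepct"))   # 1
```

With these scores, "min" returns r = 2. "onepct" returns r = 1, because 1.005 is within 1% of the best score of 1.0. For MC, the same call with `prefer="high"` favors the largest eligible penalty.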