pyrollmatch

pyrollmatch — Rolling entry matching and weighting for staggered adoption studies.

A Python package for causal inference with staggered treatment adoption, built on polars and numpy for scalable matching/weighting on large panel datasets (100K+ units).

Methods

  • Matching: Nearest-neighbor matching on propensity scores or pairwise distances (Mahalanobis, Euclidean, robust Mahalanobis). Supports calipers, replacement modes, and matching order (following MatchIt conventions).
  • Entropy balancing: Direct covariate balance via convex optimization (Hainmueller 2012). Each entry cohort is weighted independently (stacked design).
  • Custom: User-defined per-period weighting functions.

Quick Start

>>> from pyrollmatch import rollmatch
>>>
>>> # Propensity score matching
>>> result = rollmatch(
...     data, treat="treat", tm="time", entry="entry_time", id="unit_id",
...     covariates=["x1", "x2", "x3"],
...     ps_caliper=0.2, num_matches=3,
... )
>>>
>>> # Mahalanobis distance matching
>>> result = rollmatch(
...     data, ..., model_type="mahalanobis",
... )
>>>
>>> # Entropy balancing
>>> result = rollmatch(
...     data, ..., method="ebal", moment=1,
... )
>>>
>>> result.balance         # covariate balance (SMD table)
>>> result.weights         # unit-level weights
>>> result.matched_data    # match pairs (matching only)

References

  • Witman et al. (2018). "Comparison Group Selection in the Presence of Rolling Entry." Health Services Research, 54(1), 262-270.
  • Hainmueller, J. (2012). "Entropy Balancing for Causal Effects." Political Analysis, 20(1), 25-46.
  • Ho, Imai, King, Stuart (2011). "MatchIt: Nonparametric Preprocessing for Parametric Causal Inference." Journal of Statistical Software, 42(8), 1-28.
  • Rosenbaum, P. (2010). Design of Observational Studies, ch. 8.
 1"""
 2pyrollmatch — Rolling entry matching and weighting for staggered adoption studies.
 3
 4A Python package for causal inference with staggered treatment adoption,
 5built on polars and numpy for scalable matching/weighting on large panel
 6datasets (100K+ units).
 7
 8Methods
 9-------
10- **Matching**: Nearest-neighbor matching on propensity scores or pairwise
11  distances (Mahalanobis, Euclidean, robust Mahalanobis). Supports calipers,
12  replacement modes, and matching order (following MatchIt conventions).
13- **Entropy balancing**: Direct covariate balance via convex optimization
14  (Hainmueller 2012). Each entry cohort weights independently (stacked design).
15- **Custom**: User-defined per-period weighting functions.
16
17Quick Start
18-----------
19>>> from pyrollmatch import rollmatch
20>>>
21>>> # Propensity score matching
22>>> result = rollmatch(
23...     data, treat="treat", tm="time", entry="entry_time", id="unit_id",
24...     covariates=["x1", "x2", "x3"],
25...     ps_caliper=0.2, num_matches=3,
26... )
27>>>
28>>> # Mahalanobis distance matching
29>>> result = rollmatch(
30...     data, ..., model_type="mahalanobis",
31... )
32>>>
33>>> # Entropy balancing
34>>> result = rollmatch(
35...     data, ..., method="ebal", moment=1,
36... )
37>>>
38>>> result.balance         # covariate balance (SMD table)
39>>> result.weights         # unit-level weights
40>>> result.matched_data    # match pairs (matching only)
41
42References
43----------
44- Witman et al. (2018). "Comparison Group Selection in the Presence of
45  Rolling Entry." Health Services Research, 54(1), 262-270.
46- Hainmueller, J. (2012). "Entropy Balancing for Causal Effects."
47  Political Analysis, 20(1), 25-46.
48- Imai, King, Stuart (2011). MatchIt: Nonparametric Preprocessing for
49  Parametric Causal Inference.
50- Rosenbaum, P. (2010). Design of Observational Studies, ch. 8.
51"""
52
53from .core import rollmatch, RollmatchResult
54from .reduce import reduce_data
55from .score import score_data, ScoredResult, SUPPORTED_MODELS, DISTANCE_MODELS, PROPENSITY_MODELS
56from .match import DistanceSpec
57from .weight import entropy_balance
58from .balance import (
59    compute_balance, compute_balance_weighted,
60    balance_by_period, balance_by_period_weighted,
61    smd_table,
62)
63from .diagnostics import (
64    balance_test, equivalence_test,
65    balance_test_weighted, equivalence_test_weighted,
66)
67
68__version__ = "0.1.0"
69__all__ = [
70    # Core
71    "rollmatch",
72    "RollmatchResult",
73    # Pipeline stages
74    "reduce_data",
75    "score_data",
76    "ScoredResult",
77    # Distance
78    "DistanceSpec",
79    # Constants
80    "SUPPORTED_MODELS",
81    "DISTANCE_MODELS",
82    "PROPENSITY_MODELS",
83    # Weighting
84    "entropy_balance",
85    # Balance diagnostics
86    "compute_balance",
87    "compute_balance_weighted",
88    "balance_by_period",
89    "balance_by_period_weighted",
90    "smd_table",
91    # Statistical tests
92    "balance_test",
93    "equivalence_test",
94    "balance_test_weighted",
95    "equivalence_test_weighted",
96]
def rollmatch( data: polars.dataframe.frame.DataFrame, treat: str, tm: str, entry: str, id: str, covariates: list[str], lookback: int = 1, method: Union[str, Callable] = 'matching', verbose: bool = True, **method_kwargs) -> RollmatchResult | None:
def rollmatch(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    covariates: list[str],
    lookback: int = 1,
    method: str | Callable = "matching",
    verbose: bool = True,
    **method_kwargs,
) -> RollmatchResult | None:
    """Run the rolling entry matching/weighting pipeline.

    This is the main entry point for pyrollmatch. It orchestrates data
    reduction, scoring, matching (or weighting), and balance computation
    in a single call by dispatching on ``method``.

    Parameters
    ----------
    data : pl.DataFrame
        Panel data with unit × time observations. Must contain columns
        for treatment status, time period, entry period, unit ID, and
        covariates.
    treat : str
        Column name for binary treatment indicator (1=treated, 0=control).
    tm : str
        Column name for time period (integer).
    entry : str
        Column name for entry period. Treated units: the period when
        treatment begins. Controls: any value > max(tm) or null.
    id : str
        Column name for unit identifier.
    covariates : list[str]
        Covariate column names used for scoring and balance diagnostics.
    lookback : int, default 1
        Number of periods before entry to use as baseline. Must be >= 1.
    method : str or callable, default ``"matching"``
        Weighting method:

        ``"matching"``
            Nearest-neighbor matching on propensity scores or pairwise
            distances. Returns match pairs in ``matched_data``.

            **Matching kwargs:**

            - ``ps_caliper`` (float, default 0): PS caliper multiplier.
              0 = no caliper.
            - ``ps_caliper_std`` (str, default ``"average"``): How to
              compute pooled SD for PS caliper. ``"average"``,
              ``"weighted"``, or ``"none"``.
            - ``num_matches`` (int, default 1): Controls per treated.
            - ``replacement`` (str, default ``"cross_cohort"``):
              ``"unrestricted"``, ``"cross_cohort"``, or ``"global_no"``.
            - ``model_type`` (str, default ``"logistic"``): Scoring model.
              See ``pyrollmatch.score.SUPPORTED_MODELS``.
            - ``block_size`` (int, default 2000): Block size for memory.
            - ``mahvars`` (list[str] or None): Covariates for Mahalanobis
              distance matching with PS caliper (MatchIt pattern).
            - ``m_order`` (str or None): Matching order: ``"largest"``,
              ``"smallest"``, ``"random"``, ``"data"``, or ``None`` (auto).
            - ``caliper`` (dict or None): Per-variable calipers,
              e.g. ``{"x1": 0.5, "x2": 0.3}``.
            - ``std_caliper`` (bool, default True): Whether per-variable
              caliper widths are in SD units.

        ``"ebal"``
            Entropy balancing (Hainmueller 2012). Returns per-cohort
            weights in ``weighted_data``.

            **Ebal kwargs:**

            - ``moment`` (int, default 1): Moment constraint (1=mean,
              2=mean+variance, 3=mean+variance+skewness).
            - ``max_weight`` (float or None): Maximum weight cap.

        callable
            User-defined function with signature
            ``fn(treated_data, control_data, covariates, id, **kwargs)``
            returning ``pl.DataFrame`` with columns ``[id, weight]``.

    verbose : bool, default True
        Print progress and balance summary.
    **method_kwargs
        Method-specific keyword arguments (see ``method`` above). They are
        forwarded unchanged to the selected pipeline.

    Returns
    -------
    RollmatchResult or None
        Contains ``matched_data``, ``balance``, ``weights``, and summary
        statistics. Returns ``None`` if no matches/weights could be
        computed (e.g. empty data, convergence failure).

    Raises
    ------
    ValueError
        If ``method`` is neither ``"matching"``, ``"ebal"``, nor a callable.

    Examples
    --------
    Propensity score matching (default):

    >>> result = rollmatch(
    ...     data, treat="treat", tm="time", entry="entry_time",
    ...     id="unit_id", covariates=["x1", "x2", "x3"],
    ...     ps_caliper=0.2, num_matches=3,
    ... )
    >>> result.matched_data    # match pairs
    >>> result.balance         # SMD table

    Mahalanobis distance matching:

    >>> result = rollmatch(
    ...     data, treat="treat", tm="time", entry="entry_time",
    ...     id="unit_id", covariates=["x1", "x2", "x3"],
    ...     model_type="mahalanobis",
    ... )

    Mahalanobis matching with PS caliper (MatchIt ``mahvars`` pattern):

    >>> result = rollmatch(
    ...     data, treat="treat", tm="time", entry="entry_time",
    ...     id="unit_id", covariates=["x1", "x2", "x3"],
    ...     ps_caliper=0.25, mahvars=["x1", "x2"],
    ... )

    Entropy balancing:

    >>> result = rollmatch(
    ...     data, treat="treat", tm="time", entry="entry_time",
    ...     id="unit_id", covariates=["x1", "x2", "x3"],
    ...     method="ebal", moment=1,
    ... )
    >>> result.weights         # unit weights
    >>> result.weighted_data   # per-cohort weights
    """
    # User-supplied weighting function: it is passed first, ahead of the
    # shared positional arguments. The str check keeps string-like objects
    # on the named-method path below.
    if callable(method) and not isinstance(method, str):
        return _run_callable(
            method, data, treat, tm, entry, id, covariates,
            lookback, verbose, **method_kwargs,
        )

    if method == "matching":
        return _run_matching(
            data, treat, tm, entry, id, covariates,
            lookback, verbose, **method_kwargs,
        )

    if method == "ebal":
        return _run_ebal(
            data, treat, tm, entry, id, covariates,
            lookback, verbose, **method_kwargs,
        )

    raise ValueError(
        f"method must be 'matching', 'ebal', or a callable, got {method!r}"
    )

Run the rolling entry matching/weighting pipeline.

This is the main entry point for pyrollmatch. It orchestrates data reduction, scoring, matching (or weighting), and balance computation in a single call.

Parameters

data : pl.DataFrame
    Panel data with unit × time observations. Must contain columns for treatment status, time period, entry period, unit ID, and covariates.
treat : str
    Column name for binary treatment indicator (1=treated, 0=control).
tm : str
    Column name for time period (integer).
entry : str
    Column name for entry period. Treated units: the period when treatment begins. Controls: any value > max(tm) or null.
id : str
    Column name for unit identifier.
covariates : list[str]
    Covariate column names used for scoring and balance diagnostics.
lookback : int, default 1
    Number of periods before entry to use as baseline. Must be >= 1.
method : str or callable, default "matching"
    Weighting method:

``"matching"``
    Nearest-neighbor matching on propensity scores or pairwise
    distances. Returns match pairs in ``matched_data``.

    **Matching kwargs:**

    - ``ps_caliper`` (float, default 0): PS caliper multiplier.
      0 = no caliper.
    - ``ps_caliper_std`` (str, default ``"average"``): How to
      compute pooled SD for PS caliper. ``"average"``,
      ``"weighted"``, or ``"none"``.
    - ``num_matches`` (int, default 1): Controls per treated.
    - ``replacement`` (str, default ``"cross_cohort"``):
      ``"unrestricted"``, ``"cross_cohort"``, or ``"global_no"``.
    - ``model_type`` (str, default ``"logistic"``): Scoring model.
      See ``pyrollmatch.score.SUPPORTED_MODELS``.
    - ``block_size`` (int, default 2000): Block size for memory.
    - ``mahvars`` (list[str] or None): Covariates for Mahalanobis
      distance matching with PS caliper (MatchIt pattern).
    - ``m_order`` (str or None): Matching order: ``"largest"``,
      ``"smallest"``, ``"random"``, ``"data"``, or ``None`` (auto).
    - ``caliper`` (dict or None): Per-variable calipers,
      e.g. ``{"x1": 0.5, "x2": 0.3}``.
    - ``std_caliper`` (bool, default True): Whether per-variable
      caliper widths are in SD units.

``"ebal"``
    Entropy balancing (Hainmueller 2012). Returns per-cohort
    weights in ``weighted_data``.

    **Ebal kwargs:**

    - ``moment`` (int, default 1): Moment constraint (1=mean,
      2=mean+variance, 3=mean+variance+skewness).
    - ``max_weight`` (float or None): Maximum weight cap.

callable
    User-defined function with signature
    ``fn(treated_data, control_data, covariates, id, **kwargs)``
    returning ``pl.DataFrame`` with columns ``[id, weight]``.

verbose : bool, default True
    Print progress and balance summary.
**method_kwargs
    Method-specific keyword arguments (see method above).

Returns

RollmatchResult or None
    Contains matched_data, balance, weights, and summary statistics. Returns None if no matches/weights could be computed (e.g. empty data, convergence failure).

Examples

Propensity score matching (default):

>>> result = rollmatch(
...     data, treat="treat", tm="time", entry="entry_time",
...     id="unit_id", covariates=["x1", "x2", "x3"],
...     ps_caliper=0.2, num_matches=3,
... )
>>> result.matched_data    # match pairs
>>> result.balance         # SMD table

Mahalanobis distance matching:

>>> result = rollmatch(
...     data, treat="treat", tm="time", entry="entry_time",
...     id="unit_id", covariates=["x1", "x2", "x3"],
...     model_type="mahalanobis",
... )

Mahalanobis matching with PS caliper (MatchIt mahvars pattern):

>>> result = rollmatch(
...     data, treat="treat", tm="time", entry="entry_time",
...     id="unit_id", covariates=["x1", "x2", "x3"],
...     ps_caliper=0.25, mahvars=["x1", "x2"],
... )

Entropy balancing:

>>> result = rollmatch(
...     data, treat="treat", tm="time", entry="entry_time",
...     id="unit_id", covariates=["x1", "x2", "x3"],
...     method="ebal", moment=1,
... )
>>> result.weights         # unit weights
>>> result.weighted_data   # per-cohort weights
@dataclass
class RollmatchResult:
@dataclass
class RollmatchResult:
    """Result from :func:`rollmatch`.

    Attributes
    ----------
    matched_data : pl.DataFrame or None
        Match pairs with columns ``[tm, treat_id, control_id, difference]``.
        ``None`` for weighting-only methods (ebal, custom).
    balance : pl.DataFrame
        Covariate balance table with pre-match and post-match SMDs.
    n_treated_total : int
        Total treated units in the reduced sample.
    n_treated_matched : int
        Treated units that were successfully matched or weighted.
    n_controls_matched : int
        Unique control units used.
    ps_caliper : float or None
        Propensity score caliper multiplier used. ``None`` for non-matching
        methods.
    weights : pl.DataFrame
        Unit-level weights ``[id, weight]``. For matching, derived from
        pairs. For ebal, collapsed across cohorts.
    weighted_data : pl.DataFrame or None
        Per-cohort weights ``[tm, id, weight]`` (ebal/custom only).
    method : str
        Method used: ``"matching"``, ``"ebal"``, or ``"custom"``.
    """

    # Required fields (no defaults) come first, as dataclasses demand.
    matched_data: pl.DataFrame | None
    balance: pl.DataFrame
    n_treated_total: int
    n_treated_matched: int
    n_controls_matched: int
    ps_caliper: float | None
    weights: pl.DataFrame
    # Optional fields default to the plain-matching case.
    weighted_data: pl.DataFrame | None = None
    method: str = "matching"

Result from rollmatch().

Attributes

matched_data : pl.DataFrame or None
    Match pairs with columns [tm, treat_id, control_id, difference]. None for weighting-only methods (ebal, custom).
balance : pl.DataFrame
    Covariate balance table with pre-match and post-match SMDs.
n_treated_total : int
    Total treated units in the reduced sample.
n_treated_matched : int
    Treated units that were successfully matched or weighted.
n_controls_matched : int
    Unique control units used.
ps_caliper : float or None
    Propensity score caliper multiplier used. None for non-matching methods.
weights : pl.DataFrame
    Unit-level weights [id, weight]. For matching, derived from pairs. For ebal, collapsed across cohorts.
weighted_data : pl.DataFrame or None
    Per-cohort weights [tm, id, weight] (ebal/custom only).
method : str
    Method used: "matching", "ebal", or "custom".

RollmatchResult( matched_data: polars.dataframe.frame.DataFrame | None, balance: polars.dataframe.frame.DataFrame, n_treated_total: int, n_treated_matched: int, n_controls_matched: int, ps_caliper: float | None, weights: polars.dataframe.frame.DataFrame, weighted_data: polars.dataframe.frame.DataFrame | None = None, method: str = 'matching')
matched_data: polars.dataframe.frame.DataFrame | None
balance: polars.dataframe.frame.DataFrame
n_treated_total: int
n_treated_matched: int
n_controls_matched: int
ps_caliper: float | None
weights: polars.dataframe.frame.DataFrame
weighted_data: polars.dataframe.frame.DataFrame | None = None
method: str = 'matching'
def reduce_data( data: polars.dataframe.frame.DataFrame, treat: str, tm: str, entry: str, id: str, lookback: int = 1) -> polars.dataframe.frame.DataFrame:
def reduce_data(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    lookback: int = 1,
) -> pl.DataFrame:
    """Construct quasi-panel for rolling entry matching.

    Keeps each treated unit's row at its baseline period
    (``entry - lookback``) and every control row observed at any of those
    baseline periods.

    Parameters
    ----------
    data : pl.DataFrame
        Panel data with treat, time, entry, id columns and covariates.
    treat : str
        Column name for binary treatment indicator (1=treated, 0=control).
    tm : str
        Column name for time period.
    entry : str
        Column name for entry period. For treated units, this is the time
        period when treatment begins (integer). For control units, use
        either:

        - ``None``/``null`` (recommended), or
        - any integer larger than ``max(tm)`` (e.g. 99, 999).

        Control units' entry values are never used by the algorithm.
    id : str
        Column name for unit identifier.
    lookback : int, default 1
        Number of periods to look back from entry for baseline covariates.
        Must be >= 1.

    Returns
    -------
    pl.DataFrame
        Reduced dataset: treated at baseline + controls at all baseline
        periods (treated rows first, then control rows).

    Raises
    ------
    ValueError
        If ``lookback < 1`` or any of the ``treat``, ``tm``, ``entry``,
        ``id`` columns is absent from ``data``.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")

    for col in [treat, tm, entry, id]:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in data")

    # Treatment set: treated units at their baseline period (entry - lookback).
    # Rows with a null entry are excluded explicitly so the arithmetic
    # comparison below never sees a null.
    treat_set = data.filter(
        (pl.col(treat) == 1)
        & pl.col(entry).is_not_null()
        & (pl.col(tm) == pl.col(entry) - lookback)
    )

    # Unique baseline time periods observed among the treated rows.
    baseline_periods = treat_set[tm].unique().to_list()

    # Control set: controls observed at those same baseline periods. Their
    # entry value is not consulted, so it may be null or a sentinel.
    control_set = data.filter(
        (pl.col(treat) == 0) & (pl.col(tm).is_in(baseline_periods))
    )

    return pl.concat([treat_set, control_set])

Construct quasi-panel for rolling entry matching.

Parameters

data : pl.DataFrame
    Panel data with treat, time, entry, id columns and covariates.
treat : str
    Column name for binary treatment indicator (1=treated, 0=control).
tm : str
    Column name for time period.
entry : str
    Column name for entry period. For treated units, this is the time period when treatment begins (integer). For control units, use either:
      - None/null (recommended), or
      - Any integer larger than max(tm) (e.g., 99, 999)
    Control units' entry values are never used by the algorithm.
id : str
    Column name for unit identifier.
lookback : int
    Number of periods to look back from entry for baseline covariates.

Returns

pl.DataFrame
    Reduced dataset: treated at baseline + controls at all baseline periods.

def score_data( reduced_data: polars.dataframe.frame.DataFrame, covariates: list[str], treat: str, model_type: str = 'logistic', match_on: str = 'logit', max_iter: int = 1000) -> ScoredResult:
def score_data(
    reduced_data: pl.DataFrame,
    covariates: list[str],
    treat: str,
    model_type: str = "logistic",
    match_on: str = "logit",
    max_iter: int = 1000,
) -> ScoredResult:
    """Fit a propensity/distance model and return scored data with metadata.

    This is the scoring step of the matching pipeline. For propensity score
    models, a classifier is fitted and scores are added to the data. For
    distance-based models, covariance / scaling metadata is computed and
    returned in the :class:`ScoredResult` for the matching engine to use.

    Parameters
    ----------
    reduced_data : pl.DataFrame
        Output from :func:`reduce_data`.
    covariates : list[str]
        Column names of matching covariates. Must be non-empty.
    treat : str
        Column name for binary treatment indicator (1=treated, 0=control).
    model_type : str, default ``"logistic"``
        Scoring model. One of ``SUPPORTED_MODELS``:

        **Propensity score models** (produce scalar scores per unit):
        ``"logistic"``, ``"probit"``, ``"gbm"``, ``"rf"``,
        ``"lasso"``, ``"ridge"``, ``"elasticnet"``.

        **Distance-based models** (pairwise distances computed by matcher):
        ``"mahalanobis"``, ``"scaled_euclidean"``,
        ``"robust_mahalanobis"``, ``"euclidean"``.
    match_on : str, default ``"logit"``
        Score transformation for propensity models:
        ``"logit"`` for log-odds (recommended), ``"pscore"`` for raw
        probability. Ignored for distance-based models.
    max_iter : int, default 1000
        Maximum optimizer iterations (propensity models only).

    Returns
    -------
    ScoredResult
        Contains ``.data`` (DataFrame with ``"score"`` column),
        ``.model`` (fitted classifier or None), and distance metadata.

    Raises
    ------
    ValueError
        If ``covariates`` is empty, ``model_type`` is not in
        ``SUPPORTED_MODELS``, ``match_on`` is invalid for a propensity
        model, a covariate column is missing, or the covariates contain
        NaN values.

    Examples
    --------
    >>> result = score_data(reduced, ["x1", "x2"], "treat")
    >>> result.data["score"]        # propensity scores
    >>> result.model                # fitted LogisticRegression

    >>> result = score_data(reduced, ["x1", "x2"], "treat",
    ...                     model_type="mahalanobis")
    >>> result.cov_inv              # inverse covariance matrix
    >>> result.distance_metric      # "mahalanobis"
    """
    # An empty covariate list would produce a zero-width design matrix and
    # fail later with an unrelated length/fit error; reject it up front.
    if len(covariates) == 0:
        raise ValueError("covariates must contain at least one column name")

    if model_type not in SUPPORTED_MODELS:
        raise ValueError(
            f"model_type must be one of {SUPPORTED_MODELS}, got '{model_type}'"
        )

    # match_on is only validated for propensity models; distance models
    # ignore it.
    valid_match = ("logit", "pscore")
    if match_on not in valid_match and model_type not in DISTANCE_MODELS:
        raise ValueError(f"match_on must be one of {valid_match}, got '{match_on}'")

    for col in covariates:
        if col not in reduced_data.columns:
            raise ValueError(f"Covariate '{col}' not found in data")

    # Extract numpy arrays
    X = reduced_data.select(covariates).to_numpy().astype(np.float64)
    y = reduced_data[treat].to_numpy().astype(np.int32)

    # Check for NaN
    nan_mask = np.isnan(X).any(axis=1)
    if nan_mask.any():
        raise ValueError(
            f"{nan_mask.sum()} rows have NaN in covariates. "
            "Remove NaN rows before scoring."
        )

    # Initialize metadata fields
    cov_inv = None
    distance_metric = None
    distance_transform = None
    ranked_covariates = None

    if model_type in DISTANCE_MODELS:
        scores = np.zeros(len(X))  # placeholder — matcher uses covariates
        model = None

        if model_type == "mahalanobis":
            cov = _pooled_within_group_cov(X, y)
            # Small ridge on the diagonal keeps the matrix invertible.
            cov += np.eye(cov.shape[0]) * 1e-6
            cov_inv = np.linalg.inv(cov)
            distance_metric = "mahalanobis"

        elif model_type == "scaled_euclidean":
            sd = _pooled_within_group_sd(X, y)
            # Avoid dividing by (near-)zero for constant covariates.
            sd[sd < 1e-10] = 1.0
            distance_transform = np.diag(1.0 / sd)
            distance_metric = "scaled_euclidean"

        elif model_type == "robust_mahalanobis":
            # MatchIt: rank full sample, cov(ranks), scale by sd(1:n)
            n = X.shape[0]
            X_ranked = np.column_stack([
                rankdata(X[:, j], method="average") for j in range(X.shape[1])
            ])
            var_r = np.cov(X_ranked, rowvar=False, ddof=1)
            # np.cov returns a 0-d array for a single covariate.
            var_r = np.atleast_2d(var_r)
            sd_1_to_n = np.std(np.arange(1, n + 1), ddof=1)
            multiplier = sd_1_to_n / np.sqrt(np.diag(var_r)).clip(1e-10)
            var_r = var_r * np.outer(multiplier, multiplier)
            var_r += np.eye(var_r.shape[0]) * 1e-6
            cov_inv = np.linalg.inv(var_r)
            distance_metric = "robust_mahalanobis"

        elif model_type == "euclidean":
            distance_metric = "euclidean"

    else:
        # Propensity score model
        model = _build_model(model_type, max_iter)
        model.fit(X, y)
        scores = _predict_scores(model, X, model_type, match_on)

    result_df = reduced_data.with_columns(pl.Series("score", scores))

    # For robust_mahalanobis: add full-dataset ranked columns so the matcher
    # works in the same rank space that cov_inv was computed in.
    if model_type == "robust_mahalanobis":
        ranked_covariates = [f"_ranked_{c}" for c in covariates]
        result_df = result_df.with_columns([
            pl.Series(f"_ranked_{col}", X_ranked[:, j])
            for j, col in enumerate(covariates)
        ])

    return ScoredResult(
        data=result_df, model=model, covariates=covariates,
        model_type=model_type, match_on=match_on,
        cov_inv=cov_inv,
        distance_metric=distance_metric,
        distance_transform=distance_transform,
        ranked_covariates=ranked_covariates,
    )

Fit a propensity/distance model and return scored data with metadata.

This is the scoring step of the matching pipeline. For propensity score models, a classifier is fitted and scores are added to the data. For distance-based models, covariance / scaling metadata is computed and returned in the ScoredResult for the matching engine to use.

Parameters

reduced_data : pl.DataFrame
    Output from reduce_data().
covariates : list[str]
    Column names of matching covariates.
treat : str
    Column name for binary treatment indicator (1=treated, 0=control).
model_type : str, default "logistic"
    Scoring model. One of SUPPORTED_MODELS:

**Propensity score models** (produce scalar scores per unit):
``"logistic"``, ``"probit"``, ``"gbm"``, ``"rf"``,
``"lasso"``, ``"ridge"``, ``"elasticnet"``.

**Distance-based models** (pairwise distances computed by matcher):
``"mahalanobis"``, ``"scaled_euclidean"``,
``"robust_mahalanobis"``, ``"euclidean"``.

match_on : str, default "logit"
    Score transformation for propensity models: "logit" for log-odds (recommended), "pscore" for raw probability. Ignored for distance-based models.
max_iter : int, default 1000
    Maximum optimizer iterations (propensity models only).

Returns

ScoredResult
    Contains .data (DataFrame with "score" column), .model (fitted classifier or None), and distance metadata.

Raises

ValueError
    If model_type is not in SUPPORTED_MODELS, covariates are missing, or data contains NaN values.

Examples

>>> result = score_data(reduced, ["x1", "x2"], "treat")
>>> result.data["score"]        # propensity scores
>>> result.model                # fitted LogisticRegression
>>> result = score_data(reduced, ["x1", "x2"], "treat",
...                     model_type="mahalanobis")
>>> result.cov_inv              # inverse covariance matrix
>>> result.distance_metric      # "mahalanobis"
@dataclass
class ScoredResult:
 74@dataclass
 75class ScoredResult:
 76    """Container for scored data and distance/model metadata.
 77
 78    Always returned by :func:`score_data`. Provides a uniform interface
 79    regardless of whether the model is propensity-based or distance-based.
 80
 81    Attributes
 82    ----------
 83    data : pl.DataFrame
 84        Input data with an added ``"score"`` column. For propensity models,
 85        this contains real propensity scores. For distance models, it
 86        contains placeholder zeros (the matcher uses covariates directly).
 87    model : Any
 88        Fitted sklearn classifier for propensity models. ``None`` for
 89        distance-based models.
 90    covariates : list[str]
 91        Covariate column names used for scoring.
 92    model_type : str
 93        The model type used (e.g. ``"logistic"``, ``"mahalanobis"``).
 94    match_on : str
 95        Score transformation applied: ``"logit"`` or ``"pscore"``.
 96        Only meaningful for propensity models.
 97    cov_inv : np.ndarray or None
 98        Inverse covariance matrix. Set for ``"mahalanobis"`` and
 99        ``"robust_mahalanobis"``; ``None`` otherwise.
100    distance_metric : str or None
101        Distance metric identifier passed to the matcher:
102        ``"mahalanobis"``, ``"robust_mahalanobis"``, ``"scaled_euclidean"``,
103        ``"euclidean"``, or ``None`` for propensity score matching.
104    distance_transform : np.ndarray or None
105        Diagonal scaling matrix for ``"scaled_euclidean"``.
106        ``None`` for all other model types.
107    ranked_covariates : list[str] or None
108        Column names of pre-ranked covariates added to ``data`` for
109        ``"robust_mahalanobis"`` (e.g. ``["_ranked_x1", "_ranked_x2"]``).
110        The matcher uses these instead of raw covariates so that
111        full-dataset ranks are consistent with ``cov_inv``.
112    """
113    data: pl.DataFrame
114    model: Any
115    covariates: list[str]
116    model_type: str
117    match_on: str
118    cov_inv: np.ndarray | None = None
119    distance_metric: str | None = None
120    distance_transform: np.ndarray | None = None
121    ranked_covariates: list[str] | None = None

Container for scored data and distance/model metadata.

Always returned by score_data(). Provides a uniform interface regardless of whether the model is propensity-based or distance-based.

Attributes

data : pl.DataFrame
    Input data with an added "score" column. For propensity models, this contains real propensity scores. For distance models, it contains placeholder zeros (the matcher uses covariates directly).
model : Any
    Fitted sklearn classifier for propensity models. None for distance-based models.
covariates : list[str]
    Covariate column names used for scoring.
model_type : str
    The model type used (e.g. "logistic", "mahalanobis").
match_on : str
    Score transformation applied: "logit" or "pscore". Only meaningful for propensity models.
cov_inv : np.ndarray or None
    Inverse covariance matrix. Set for "mahalanobis" and "robust_mahalanobis"; None otherwise.
distance_metric : str or None
    Distance metric identifier passed to the matcher: "mahalanobis", "robust_mahalanobis", "scaled_euclidean", "euclidean", or None for propensity score matching.
distance_transform : np.ndarray or None
    Diagonal scaling matrix for "scaled_euclidean". None for all other model types.
ranked_covariates : list[str] or None
    Column names of pre-ranked covariates added to data for "robust_mahalanobis" (e.g. ["_ranked_x1", "_ranked_x2"]). The matcher uses these instead of raw covariates so that full-dataset ranks are consistent with cov_inv.

ScoredResult( data: polars.dataframe.frame.DataFrame, model: Any, covariates: list[str], model_type: str, match_on: str, cov_inv: numpy.ndarray | None = None, distance_metric: str | None = None, distance_transform: numpy.ndarray | None = None, ranked_covariates: list[str] | None = None)
data: polars.dataframe.frame.DataFrame
model: Any
covariates: list[str]
model_type: str
match_on: str
cov_inv: numpy.ndarray | None = None
distance_metric: str | None = None
distance_transform: numpy.ndarray | None = None
ranked_covariates: list[str] | None = None
@dataclass
class DistanceSpec:
52@dataclass
53class DistanceSpec:
54    """Bundle of distance-related parameters for the matching engine.
55
56    Encapsulates all metadata needed to compute pairwise distances
57    between treated and control units. For propensity score matching,
58    all fields are ``None``/default and the matcher uses scalar score
59    differences.
60
61    Attributes
62    ----------
63    metric : str or None
64        Distance metric: ``"mahalanobis"``, ``"robust_mahalanobis"``,
65        ``"scaled_euclidean"``, ``"euclidean"``, or ``None`` for
66        propensity score matching.
67    covariates : list[str] or None
68        Covariate column names to extract from the scored DataFrame
69        for pairwise distance computation.
70    cov_inv : np.ndarray or None
71        Inverse covariance matrix (for Mahalanobis variants).
72    transform : np.ndarray or None
73        Diagonal scaling matrix (for scaled Euclidean).
74    is_mahvars : bool
75        If True, propensity scores are real (mahvars pattern) and the
76        PS caliper should be applied alongside distance matching.
77    """
78    metric: str | None = None
79    covariates: list[str] | None = None
80    cov_inv: np.ndarray | None = None
81    transform: np.ndarray | None = None
82    is_mahvars: bool = False
83
84    @property
85    def use_pairwise(self) -> bool:
86        """Whether to compute pairwise covariate distances."""
87        return self.metric is not None and self.covariates is not None

Bundle of distance-related parameters for the matching engine.

Encapsulates all metadata needed to compute pairwise distances between treated and control units. For propensity score matching, all fields are None/default and the matcher uses scalar score differences.

Attributes

metric : str or None Distance metric: "mahalanobis", "robust_mahalanobis", "scaled_euclidean", "euclidean", or None for propensity score matching. covariates : list[str] or None Covariate column names to extract from the scored DataFrame for pairwise distance computation. cov_inv : np.ndarray or None Inverse covariance matrix (for Mahalanobis variants). transform : np.ndarray or None Diagonal scaling matrix (for scaled Euclidean). is_mahvars : bool If True, propensity scores are real (mahvars pattern) and the PS caliper should be applied alongside distance matching.

DistanceSpec( metric: str | None = None, covariates: list[str] | None = None, cov_inv: numpy.ndarray | None = None, transform: numpy.ndarray | None = None, is_mahvars: bool = False)
metric: str | None = None
covariates: list[str] | None = None
cov_inv: numpy.ndarray | None = None
transform: numpy.ndarray | None = None
is_mahvars: bool = False
use_pairwise: bool
84    @property
85    def use_pairwise(self) -> bool:
86        """Whether to compute pairwise covariate distances."""
87        return self.metric is not None and self.covariates is not None

Whether to compute pairwise covariate distances.

# Model types that produce a propensity score.
PROPENSITY_MODELS = (
    "logistic",
    "probit",
    "gbm",
    "rf",
    "lasso",
    "ridge",
    "elasticnet",
)
# Model types that match on pairwise covariate distance.
DISTANCE_MODELS = (
    "mahalanobis",
    "scaled_euclidean",
    "robust_mahalanobis",
    "euclidean",
)
# Every accepted model type: propensity models first, then distance models.
SUPPORTED_MODELS = PROPENSITY_MODELS + DISTANCE_MODELS
def entropy_balance( treated_data: polars.dataframe.frame.DataFrame, control_data: polars.dataframe.frame.DataFrame, covariates: list[str], id: str, moment: int = 1, max_weight: float | None = None, tol: float = 1e-06, max_iter: int = 200) -> polars.dataframe.frame.DataFrame | None:
def entropy_balance(
    treated_data: pl.DataFrame,
    control_data: pl.DataFrame,
    covariates: list[str],
    id: str,
    moment: int = 1,
    max_weight: float | None = None,
    tol: float = 1e-6,
    max_iter: int = 200,
) -> pl.DataFrame | None:
    """Compute entropy-balanced weights for control units.

    Finds weights that minimize KL divergence from uniform while
    satisfying exact covariate moment constraints (Hainmueller 2012).
    The dual problem is solved with L-BFGS-B; the control weights are
    then rescaled so that they sum to the number of treated units.

    Parameters
    ----------
    treated_data : pl.DataFrame
        Treated units for this cohort.
    control_data : pl.DataFrame
        Control pool for this cohort.
    covariates : list[str]
        Covariate columns to balance.
    id : str
        Unit identifier column.
    moment : int (1, 2, or 3)
        Moments to balance: 1=means, 2=means+variances, 3=+skewness.
    max_weight : float or None
        Optional upper bound on any single control weight, expressed on
        the final scale (control weights sum to ``n_treated``). The cap
        is enforced exactly: capped units are pinned at ``max_weight``
        and the remaining mass is spread proportionally over the other
        units. If the cap cannot be met at all
        (``max_weight * n_controls < n_treated``) a warning is issued
        and uniform control weights ``n_treated / n_controls`` are
        returned.
    tol : float
        Convergence tolerance for the optimizer (used for both ``ftol``
        and ``gtol``).
    max_iter : int
        Maximum iterations.

    Returns
    -------
    pl.DataFrame with columns [id, weight] for treated (weight=1) and
    controls (entropy-balanced weights), or None if either group is
    empty or the optimization fails.

    Warns
    -----
    UserWarning
        If the optimizer does not converge, if the effective sample
        size is below 10% of the control pool, or if ``max_weight`` is
        infeasible.
    """
    X_t = treated_data.select(covariates).to_numpy().astype(np.float64)
    X_c = control_data.select(covariates).to_numpy().astype(np.float64)

    n_t = X_t.shape[0]
    n_c = X_c.shape[0]

    if n_t == 0 or n_c == 0:
        return None

    # Build constraint matrix and target
    C = _build_constraint_matrix(X_c, moment)
    target = _build_target(X_t, moment)

    # Standardize for numerical stability
    col_mean = C.mean(axis=0)
    col_std = C.std(axis=0)
    col_std[col_std < 1e-10] = 1.0  # avoid division by zero for constant cols
    # Don't standardize the intercept column
    col_mean[0] = 0.0
    col_std[0] = 1.0

    C_std = (C - col_mean) / col_std
    target_std = (target - col_mean) / col_std

    # Base weights (uniform)
    base_q = np.full(n_c, 1.0 / n_c)

    # Solve the dual
    lam0 = np.zeros(C_std.shape[1])
    result = minimize(
        _dual_objective,
        lam0,
        args=(C_std, target_std, base_q),
        method="L-BFGS-B",
        jac=True,  # _dual_objective returns (obj, grad)
        options={"maxiter": max_iter, "ftol": tol, "gtol": tol},
    )

    if not result.success:
        warnings.warn(
            f"Entropy balancing did not converge: {result.message}. "
            "This may indicate insufficient overlap between treated and "
            "control covariate distributions.",
            stacklevel=2,
        )
        return None

    # Recover primal weights; clip the exponent so np.exp cannot overflow
    lin = C_std @ result.x
    lin = np.clip(lin, -500, 500)
    w = base_q * np.exp(lin)

    # w should sum to ~1 (enforced by intercept constraint); rescale to n_treated
    w = w * (n_t / w.sum())

    # Effective sample size check (computed on the uncapped weights)
    n_eff = w.sum() ** 2 / np.sum(w ** 2)
    if n_eff / n_c < 0.1:
        warnings.warn(
            f"Effective sample size is very low: n_eff={n_eff:.1f} "
            f"({100*n_eff/n_c:.1f}% of {n_c} controls). "
            "Consider relaxing moment constraints or checking overlap.",
            stacklevel=2,
        )

    # Optional weight capping.
    #
    # Repeating "clip, then rescale everything" never lands exactly on the
    # cap: the final rescale pushes the clipped units above max_weight
    # again. Instead, pin the offending units at the cap and redistribute
    # the remaining mass over the still-free units only. Every pass pins at
    # least one more unit, so the loop ends after at most n_c passes with
    # max(w) <= max_weight and sum(w) == n_t.
    if max_weight is not None and np.max(w) > max_weight:
        if max_weight * n_c < n_t:
            # Even with every control at the cap the weights cannot sum to
            # n_t, so the bound is unattainable.
            warnings.warn(
                f"max_weight={max_weight} cannot be satisfied: {n_c} "
                f"controls capped at that value cannot sum to "
                f"n_treated={n_t}. Using uniform control weights instead.",
                stacklevel=2,
            )
            w = np.full(n_c, n_t / n_c)
        else:
            capped = np.zeros(n_c, dtype=bool)
            while True:
                over = ~capped & (w > max_weight)
                if not over.any():
                    break
                capped |= over
                w[capped] = max_weight
                free = ~capped
                free_sum = w[free].sum() if free.any() else 0.0
                if free_sum <= 0:
                    # Nothing left to absorb the remaining mass.
                    break
                w[free] *= (n_t - max_weight * capped.sum()) / free_sum

    # Build output DataFrame
    treat_ids = treated_data[id].to_numpy()
    ctrl_ids = control_data[id].to_numpy()

    treat_weights = pl.DataFrame({
        id: treat_ids,
        "weight": np.ones(n_t),
    })
    ctrl_weights = pl.DataFrame({
        id: ctrl_ids,
        "weight": w,
    })

    return pl.concat([treat_weights, ctrl_weights])

Compute entropy-balanced weights for control units.

Finds weights that minimize KL divergence from uniform while satisfying exact covariate moment constraints (Hainmueller 2012).

Parameters

treated_data : pl.DataFrame Treated units for this cohort. control_data : pl.DataFrame Control pool for this cohort. covariates : list[str] Covariate columns to balance. id : str Unit identifier column. moment : int (1, 2, or 3) Moments to balance: 1=means, 2=means+variances, 3=+skewness. max_weight : float or None Optional upper bound on any single weight (after normalization). tol : float Convergence tolerance for the optimizer. max_iter : int Maximum iterations.

Returns

pl.DataFrame with columns [id, weight] for treated (weight=1) and controls (entropy-balanced weights), or None if optimization fails.

def compute_balance( scored_data: polars.dataframe.frame.DataFrame, matches: polars.dataframe.frame.DataFrame, treat: str, id: str, tm: str, covariates: list[str]) -> polars.dataframe.frame.DataFrame:
 43def compute_balance(
 44    scored_data: pl.DataFrame,
 45    matches: pl.DataFrame,
 46    treat: str,
 47    id: str,
 48    tm: str,
 49    covariates: list[str],
 50) -> pl.DataFrame:
 51    """Compute covariate balance before and after matching.
 52
 53    Returns a table with means, SDs, and SMDs for each covariate,
 54    both in the full sample and the matched sample.
 55
 56    Parameters
 57    ----------
 58    scored_data : pl.DataFrame
 59        Reduced data with treatment indicator and covariates.
 60    matches : pl.DataFrame
 61        Match results with treat_id and control_id columns.
 62    treat : str
 63        Treatment indicator column.
 64    id : str
 65        Unit identifier column.
 66    tm : str
 67        Time period column.
 68    covariates : list[str]
 69        Covariate column names.
 70
 71    Returns
 72    -------
 73    pl.DataFrame with columns:
 74        covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c,
 75        full_smd, matched_mean_t, matched_mean_c, matched_sd_t,
 76        matched_sd_c, matched_smd
 77    """
 78    # Pre-compute matched data ONCE (not per covariate)
 79    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
 80    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
 81    matched_ids_df = pl.concat([treat_matches, control_matches])
 82    matched_data = scored_data.join(matched_ids_df, on=[tm, id], how="semi")
 83
 84    # Pre-split by treatment group
 85    full_t = scored_data.filter(pl.col(treat) == 1)
 86    full_c = scored_data.filter(pl.col(treat) == 0)
 87    match_t = matched_data.filter(pl.col(treat) == 1)
 88    match_c = matched_data.filter(pl.col(treat) == 0)
 89
 90    rows = []
 91
 92    for cov in covariates:
 93        vals_t = full_t[cov].drop_nulls().to_numpy()
 94        vals_c = full_c[cov].drop_nulls().to_numpy()
 95
 96        full_mean_t = np.mean(vals_t) if len(vals_t) > 0 else np.nan
 97        full_mean_c = np.mean(vals_c) if len(vals_c) > 0 else np.nan
 98        full_sd_t = np.std(vals_t, ddof=1) if len(vals_t) > 1 else np.nan
 99        full_sd_c = np.std(vals_c, ddof=1) if len(vals_c) > 1 else np.nan
100        full_pooled = np.sqrt((full_sd_t**2 + full_sd_c**2) / 2) if not (np.isnan(full_sd_t) or np.isnan(full_sd_c)) else np.nan
101        full_smd = (full_mean_t - full_mean_c) / full_pooled if full_pooled and full_pooled > 0 else np.nan
102
103        mvals_t = match_t[cov].drop_nulls().to_numpy()
104        mvals_c = match_c[cov].drop_nulls().to_numpy()
105
106        m_mean_t = np.mean(mvals_t) if len(mvals_t) > 0 else np.nan
107        m_mean_c = np.mean(mvals_c) if len(mvals_c) > 0 else np.nan
108        m_sd_t = np.std(mvals_t, ddof=1) if len(mvals_t) > 1 else np.nan
109        m_sd_c = np.std(mvals_c, ddof=1) if len(mvals_c) > 1 else np.nan
110        m_pooled = np.sqrt((m_sd_t**2 + m_sd_c**2) / 2) if not (np.isnan(m_sd_t) or np.isnan(m_sd_c)) else np.nan
111        m_smd = (m_mean_t - m_mean_c) / m_pooled if m_pooled and m_pooled > 0 else np.nan
112
113        rows.append({
114            "covariate": cov,
115            "full_mean_t": round(full_mean_t, 4),
116            "full_mean_c": round(full_mean_c, 4),
117            "full_sd_t": round(full_sd_t, 4),
118            "full_sd_c": round(full_sd_c, 4),
119            "full_smd": round(full_smd, 4),
120            "matched_mean_t": round(m_mean_t, 4),
121            "matched_mean_c": round(m_mean_c, 4),
122            "matched_sd_t": round(m_sd_t, 4),
123            "matched_sd_c": round(m_sd_c, 4),
124            "matched_smd": round(m_smd, 4),
125        })
126
127    return pl.DataFrame(rows)

Compute covariate balance before and after matching.

Returns a table with means, SDs, and SMDs for each covariate, both in the full sample and the matched sample.

Parameters

scored_data : pl.DataFrame Reduced data with treatment indicator and covariates. matches : pl.DataFrame Match results with treat_id and control_id columns. treat : str Treatment indicator column. id : str Unit identifier column. tm : str Time period column. covariates : list[str] Covariate column names.

Returns

pl.DataFrame with columns: covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c, full_smd, matched_mean_t, matched_mean_c, matched_sd_t, matched_sd_c, matched_smd

def compute_balance_weighted( data: polars.dataframe.frame.DataFrame, weights: polars.dataframe.frame.DataFrame, treat: str, id: str, covariates: list[str]) -> polars.dataframe.frame.DataFrame:
def compute_balance_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    covariates: list[str],
) -> pl.DataFrame:
    """Summarise covariate balance before and after weighting.

    The ``full_*`` columns are plain (unweighted) statistics over all
    rows of ``data``. The ``matched_*`` columns are weighted statistics
    over the rows that survive an inner join with ``weights`` on ``id``,
    computed with the module's ``_weighted_mean`` / ``_weighted_std``
    helpers. Nulls are dropped per covariate (together with their
    weights); all reported numbers are rounded to 4 decimals.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    weights : pl.DataFrame
        Unit weights with columns [id, weight].
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    pl.DataFrame with same schema as compute_balance():
        covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c,
        full_smd, matched_mean_t, matched_mean_c, matched_sd_t,
        matched_sd_c, matched_smd
    """

    def _pooled_smd(mean_t, mean_c, sd_t, sd_c):
        """SMD with the pooled SD; NaN when the pooled SD is 0 or undefined."""
        if not np.isnan(sd_t) and not np.isnan(sd_c):
            pooled = np.sqrt((sd_t**2 + sd_c**2) / 2)
        else:
            pooled = np.nan
        if pooled > 0:
            return (mean_t - mean_c) / pooled
        return np.nan

    def _plain(frame, column):
        """Unweighted (mean, sd) of the non-null values of ``column``."""
        arr = frame[column].drop_nulls().to_numpy()
        mean = np.nan if len(arr) == 0 else np.mean(arr)
        sd = np.nan if len(arr) < 2 else np.std(arr, ddof=1)
        return mean, sd

    def _weighted(frame, column):
        """Weighted (mean, sd) of the non-null values of ``column``."""
        arr = frame[column].drop_nulls().to_numpy()
        # Keep only the weights that line up with the surviving values.
        wts = frame.filter(pl.col(column).is_not_null())["weight"].to_numpy()
        mean = np.nan if len(arr) == 0 else _weighted_mean(arr, wts)
        sd = np.nan if len(arr) < 2 else _weighted_std(arr, wts)
        return mean, sd

    is_treated = pl.col(treat) == 1
    is_control = pl.col(treat) == 0

    # Unweighted reference sample
    raw_t = data.filter(is_treated)
    raw_c = data.filter(is_control)

    # Rows that actually carry a weight
    with_weights = data.join(weights, on=id, how="inner")
    wtd_t = with_weights.filter(is_treated)
    wtd_c = with_weights.filter(is_control)

    records = []
    for cov in covariates:
        f_mean_t, f_sd_t = _plain(raw_t, cov)
        f_mean_c, f_sd_c = _plain(raw_c, cov)
        f_smd = _pooled_smd(f_mean_t, f_mean_c, f_sd_t, f_sd_c)

        w_mean_t, w_sd_t = _weighted(wtd_t, cov)
        w_mean_c, w_sd_c = _weighted(wtd_c, cov)
        w_smd = _pooled_smd(w_mean_t, w_mean_c, w_sd_t, w_sd_c)

        records.append({
            "covariate": cov,
            "full_mean_t": round(f_mean_t, 4),
            "full_mean_c": round(f_mean_c, 4),
            "full_sd_t": round(f_sd_t, 4),
            "full_sd_c": round(f_sd_c, 4),
            "full_smd": round(f_smd, 4),
            "matched_mean_t": round(w_mean_t, 4),
            "matched_mean_c": round(w_mean_c, 4),
            "matched_sd_t": round(w_sd_t, 4),
            "matched_sd_c": round(w_sd_c, 4),
            "matched_smd": round(w_smd, 4),
        })

    return pl.DataFrame(records)

Compute covariate balance using weighted means and SDs.

Parameters

data : pl.DataFrame Reduced data with treatment indicator and covariates. weights : pl.DataFrame Unit weights with columns [id, weight]. treat : str Treatment indicator column. id : str Unit identifier column. covariates : list[str] Covariate column names.

Returns

pl.DataFrame with same schema as compute_balance(): covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c, full_smd, matched_mean_t, matched_mean_c, matched_sd_t, matched_sd_c, matched_smd

def balance_by_period( scored_data: polars.dataframe.frame.DataFrame, matches: polars.dataframe.frame.DataFrame, treat: str, id: str, tm: str, covariates: list[str]) -> tuple[polars.dataframe.frame.DataFrame, polars.dataframe.frame.DataFrame]:
def balance_by_period(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Covariate balance (SMD) of the matched sample, one period at a time.

    A pooled SMD can hide imbalance inside individual cohorts because
    opposite-signed differences cancel. Here the SMD is computed within
    each period found in ``matches`` and then summarised per covariate.
    Periods with fewer than two matched treated or two matched control
    rows are skipped entirely; a covariate with fewer than two non-null
    values in either group gets a NaN row for that period.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id, control_id, and time period columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    (aggregate, detail) : tuple[pl.DataFrame, pl.DataFrame]

        **aggregate** — one row per covariate with columns:
        ``covariate``, ``wtd_mean_smd``, ``median_abs_smd``,
        ``max_abs_smd``, ``n_periods``. The weighted mean uses each
        period's ``n_treated`` as its weight; NaN SMDs are excluded.

        **detail** — one row per (period, covariate) with columns:
        ``period``, ``covariate``, ``n_treated``, ``n_controls``,
        ``mean_treated``, ``mean_control``, ``smd``.
    """
    detail_rows = []

    for period in matches[tm].unique().sort().to_list():
        in_period = pl.col(tm) == period
        pairs = matches.filter(in_period)

        # Matched treated rows of this period
        treated_ids = pairs.select("treat_id").unique().rename({"treat_id": id})
        treated = scored_data.filter((pl.col(treat) == 1) & in_period).join(
            treated_ids, on=id, how="semi"
        )

        # Matched control rows of this period
        control_ids = pairs.select("control_id").unique().rename({"control_id": id})
        controls = scored_data.filter((pl.col(treat) == 0) & in_period).join(
            control_ids, on=id, how="semi"
        )

        n_treated = treated.height
        n_controls = controls.height

        # A sample SD needs at least two rows on each side.
        if min(n_treated, n_controls) < 2:
            continue

        for cov in covariates:
            arr_t = treated[cov].drop_nulls().to_numpy()
            arr_c = controls[cov].drop_nulls().to_numpy()

            if len(arr_t) >= 2 and len(arr_c) >= 2:
                mean_t = np.mean(arr_t)
                mean_c = np.mean(arr_c)
                sd_t = np.std(arr_t, ddof=1)
                sd_c = np.std(arr_c, ddof=1)
                pooled = np.sqrt((sd_t**2 + sd_c**2) / 2)
                if pooled > 0:
                    smd = (mean_t - mean_c) / pooled
                else:
                    smd = np.nan
                out_mean_t = round(mean_t, 4)
                out_mean_c = round(mean_c, 4)
                out_smd = round(smd, 4)
            else:
                # Too few non-null values for this covariate.
                out_mean_t = np.nan
                out_mean_c = np.nan
                out_smd = np.nan

            detail_rows.append({
                "period": period, "covariate": cov,
                "n_treated": n_treated, "n_controls": n_controls,
                "mean_treated": out_mean_t,
                "mean_control": out_mean_c,
                "smd": out_smd,
            })

    if not detail_rows:
        # Nothing usable: return typed, empty frames.
        detail = pl.DataFrame(
            schema={"period": pl.Int64, "covariate": pl.Utf8,
                    "n_treated": pl.Int64, "n_controls": pl.Int64,
                    "mean_treated": pl.Float64, "mean_control": pl.Float64,
                    "smd": pl.Float64}
        )
        agg = pl.DataFrame(
            schema={"covariate": pl.Utf8, "wtd_mean_smd": pl.Float64,
                    "median_abs_smd": pl.Float64, "max_abs_smd": pl.Float64,
                    "n_periods": pl.UInt32}
        )
        return agg, detail

    detail = pl.DataFrame(detail_rows)

    # Per covariate: n_treated-weighted mean SMD, median |SMD|, max |SMD|
    agg_rows = []
    for cov in covariates:
        usable = detail.filter(
            (pl.col("covariate") == cov) & pl.col("smd").is_not_nan()
        )
        if usable.height > 0:
            smd_values = usable["smd"].to_numpy()
            abs_values = np.abs(smd_values)
            sizes = usable["n_treated"].to_numpy().astype(float)

            if sizes.sum() > 0:
                wtd_mean = float(np.average(smd_values, weights=sizes))
            else:
                wtd_mean = np.nan

            agg_rows.append({
                "covariate": cov,
                "wtd_mean_smd": round(wtd_mean, 4),
                "median_abs_smd": round(float(np.median(abs_values)), 4),
                "max_abs_smd": round(float(np.max(abs_values)), 4),
                "n_periods": usable.height,
            })
        else:
            agg_rows.append({
                "covariate": cov, "wtd_mean_smd": np.nan,
                "median_abs_smd": np.nan, "max_abs_smd": np.nan,
                "n_periods": 0,
            })

    return pl.DataFrame(agg_rows), detail

Compute per-period covariate balance (SMD) for matched samples.

Pooled SMD can mask within-cohort imbalance through cancellation. This function computes SMD separately for each entry period, then aggregates across periods.

Parameters

scored_data : pl.DataFrame Reduced data with treatment indicator and covariates. matches : pl.DataFrame Match results with treat_id, control_id, and time period columns. treat : str Treatment indicator column. id : str Unit identifier column. tm : str Time period column. covariates : list[str] Covariate column names.

Returns

(aggregate, detail) : tuple[pl.DataFrame, pl.DataFrame]

**aggregate** — one row per covariate with columns:
``covariate``, ``wtd_mean_smd``, ``median_abs_smd``,
``max_abs_smd``, ``n_periods``.

**detail** — one row per (period, covariate) with columns:
``period``, ``covariate``, ``n_treated``, ``n_controls``,
``mean_treated``, ``mean_control``, ``smd``.
def balance_by_period_weighted( data: polars.dataframe.frame.DataFrame, weights: polars.dataframe.frame.DataFrame, treat: str, id: str, tm: str, covariates: list[str]) -> tuple[polars.dataframe.frame.DataFrame, polars.dataframe.frame.DataFrame]:
def balance_by_period_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Weighted covariate balance (SMD), one period at a time.

    ``data`` is inner-joined with ``weights`` and the weighted SMD is
    computed within each period of the joined frame using the module's
    ``_weighted_mean`` / ``_weighted_std`` helpers. Periods with fewer
    than two treated or two control rows are skipped; a covariate with
    fewer than two non-null values in either group gets a NaN row.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data.
    weights : pl.DataFrame
        Unit weights. Either stacked [tm, id, weight] (per-cohort
        weights) or collapsed [id, weight]. The presence of the ``tm``
        column decides which: stacked weights are joined on
        ``[tm, id]``, collapsed weights on ``id`` alone.
    treat, id, tm : str
        Column names.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    (aggregate, detail) with same schemas as balance_by_period().
    """
    # Stacked weights carry their own period column.
    join_keys = [tm, id] if tm in weights.columns else id
    joined = data.join(weights, on=join_keys, how="inner")

    detail_rows = []
    for period in joined[tm].unique().sort().to_list():
        cohort = joined.filter(pl.col(tm) == period)
        treated = cohort.filter(pl.col(treat) == 1)
        controls = cohort.filter(pl.col(treat) == 0)

        n_treated = treated.height
        n_controls = controls.height
        if min(n_treated, n_controls) < 2:
            continue

        for cov in covariates:
            has_value = pl.col(cov).is_not_null()
            arr_t = treated[cov].drop_nulls().to_numpy()
            arr_c = controls[cov].drop_nulls().to_numpy()
            # Weights restricted to the rows whose covariate is present
            wts_t = treated.filter(has_value)["weight"].to_numpy()
            wts_c = controls.filter(has_value)["weight"].to_numpy()

            if len(arr_t) >= 2 and len(arr_c) >= 2:
                mean_t = _weighted_mean(arr_t, wts_t)
                mean_c = _weighted_mean(arr_c, wts_c)
                sd_t = _weighted_std(arr_t, wts_t)
                sd_c = _weighted_std(arr_c, wts_c)

                if not np.isnan(sd_t) and not np.isnan(sd_c):
                    pooled = np.sqrt((sd_t**2 + sd_c**2) / 2)
                else:
                    pooled = np.nan
                # Zero or undefined pooled SD leaves the SMD undefined.
                if pooled > 0:
                    smd = (mean_t - mean_c) / pooled
                else:
                    smd = np.nan

                out_mean_t = round(mean_t, 4)
                out_mean_c = round(mean_c, 4)
                out_smd = round(smd, 4)
            else:
                out_mean_t = np.nan
                out_mean_c = np.nan
                out_smd = np.nan

            detail_rows.append({
                "period": period, "covariate": cov,
                "n_treated": n_treated, "n_controls": n_controls,
                "mean_treated": out_mean_t,
                "mean_control": out_mean_c,
                "smd": out_smd,
            })

    if not detail_rows:
        # Nothing usable: return typed, empty frames.
        detail = pl.DataFrame(
            schema={"period": pl.Int64, "covariate": pl.Utf8,
                    "n_treated": pl.Int64, "n_controls": pl.Int64,
                    "mean_treated": pl.Float64, "mean_control": pl.Float64,
                    "smd": pl.Float64}
        )
        agg = pl.DataFrame(
            schema={"covariate": pl.Utf8, "wtd_mean_smd": pl.Float64,
                    "median_abs_smd": pl.Float64, "max_abs_smd": pl.Float64,
                    "n_periods": pl.UInt32}
        )
        return agg, detail

    detail = pl.DataFrame(detail_rows)

    # Per covariate: n_treated-weighted mean SMD, median |SMD|, max |SMD|
    agg_rows = []
    for cov in covariates:
        usable = detail.filter(
            (pl.col("covariate") == cov) & pl.col("smd").is_not_nan()
        )
        if usable.height > 0:
            smd_values = usable["smd"].to_numpy()
            abs_values = np.abs(smd_values)
            sizes = usable["n_treated"].to_numpy().astype(float)

            if sizes.sum() > 0:
                wtd_mean = float(np.average(smd_values, weights=sizes))
            else:
                wtd_mean = np.nan

            agg_rows.append({
                "covariate": cov,
                "wtd_mean_smd": round(wtd_mean, 4),
                "median_abs_smd": round(float(np.median(abs_values)), 4),
                "max_abs_smd": round(float(np.max(abs_values)), 4),
                "n_periods": usable.height,
            })
        else:
            agg_rows.append({
                "covariate": cov, "wtd_mean_smd": np.nan,
                "median_abs_smd": np.nan, "max_abs_smd": np.nan,
                "n_periods": 0,
            })

    return pl.DataFrame(agg_rows), detail

Per-period covariate balance using weighted means/SDs.

Parameters

data : pl.DataFrame Reduced data. weights : pl.DataFrame Unit weights. Either stacked [tm, id, weight] (per-cohort weights) or collapsed [id, weight]. If stacked, per-cohort weights are used for each period's balance computation. treat, id, tm : str Column names. covariates : list[str] Covariate column names.

Returns

(aggregate, detail) with same schemas as balance_by_period().

def smd_table( balance: polars.dataframe.frame.DataFrame, threshold: float = 0.1) -> None:
def smd_table(balance: pl.DataFrame, threshold: float = 0.1) -> None:
    """Print a formatted SMD table with pass/fail indicators.

    A covariate passes when its ``matched_smd`` is present, not NaN and
    ``|matched_smd| < threshold``. The header verdict is derived from
    those same per-row checks, so it can never disagree with the marks
    printed in the table: "ALL PASS" is shown only when there is at
    least one row and every row passes. Missing (``None``) values are
    printed as ``NA`` instead of raising while formatting.

    Parameters
    ----------
    balance : pl.DataFrame
        Output from compute_balance() (needs the ``covariate``,
        ``full_smd`` and ``matched_smd`` columns).
    threshold : float
        |SMD| threshold for pass/fail (default 0.1).
    """

    def _usable(value) -> bool:
        # NaN is the only value that is not equal to itself.
        return value is not None and value == value

    def _cell(value, width: int) -> str:
        if value is None:
            return f"{'NA':>{width}}"
        return f"{value:>{width}.4f}"

    rows = list(balance.iter_rows(named=True))

    # Evaluate each row exactly once; header and table share the result.
    verdicts = [
        _usable(row["matched_smd"]) and abs(row["matched_smd"]) < threshold
        for row in rows
    ]
    magnitudes = [
        abs(row["matched_smd"]) for row in rows if _usable(row["matched_smd"])
    ]
    max_smd = max(magnitudes) if magnitudes else float("nan")
    all_pass = bool(rows) and all(verdicts)

    print(f"\n{'='*70}")
    print(f"  Standardized Mean Differences (threshold: |SMD| < {threshold})")
    print(f"  Max |SMD| = {max_smd:.4f}  {'✓ ALL PASS' if all_pass else '✗ SOME FAIL'}")
    print(f"{'='*70}\n")
    print(f"  {'Covariate':<30} {'Full SMD':>10} {'Matched SMD':>12} {'Pass':>6}")
    print(f"  {'-'*30} {'-'*10} {'-'*12} {'-'*6}")

    for row, passed in zip(rows, verdicts):
        full_cell = _cell(row["full_smd"], 10)
        matched_cell = _cell(row["matched_smd"], 12)
        print(f"  {row['covariate']:<30} {full_cell} {matched_cell} {'✓' if passed else '✗':>6}")

    print()

Print a formatted SMD table with pass/fail indicators.

Parameters

balance : pl.DataFrame Output from compute_balance(). threshold : float |SMD| threshold for pass/fail (default 0.1).

def balance_test( scored_data: polars.dataframe.frame.DataFrame, matches: polars.dataframe.frame.DataFrame, treat: str, id: str, tm: str, covariates: list[str], threshold: float = 0.1) -> polars.dataframe.frame.DataFrame:
 17def balance_test(
 18    scored_data: pl.DataFrame,
 19    matches: pl.DataFrame,
 20    treat: str,
 21    id: str,
 22    tm: str,
 23    covariates: list[str],
 24    threshold: float = 0.1,
 25) -> pl.DataFrame:
 26    """Run comprehensive balance diagnostics on matched sample.
 27
 28    For each covariate, computes:
 29    - Standardized mean difference (SMD)
 30    - Two-sample t-test (H0: means are equal)
 31    - Variance ratio (treat/control)
 32    - Kolmogorov-Smirnov test (H0: distributions are equal)
 33
 34    Parameters
 35    ----------
 36    scored_data : pl.DataFrame
 37        Reduced data with treatment indicator and covariates.
 38    matches : pl.DataFrame
 39        Match results with treat_id, control_id, tm columns.
 40    treat : str
 41        Treatment indicator column.
 42    id : str
 43        Unit identifier column.
 44    tm : str
 45        Time period column.
 46    covariates : list[str]
 47        Covariate column names.
 48    threshold : float
 49        SMD threshold for pass/fail (default 0.1).
 50
 51    Returns
 52    -------
 53    pl.DataFrame with diagnostics per covariate.
 54    """
 55    # Get matched units
 56    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
 57    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
 58    matched_ids = pl.concat([treat_matches, control_matches])
 59    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")
 60
 61    rows = []
 62    for cov in covariates:
 63        vals_t = matched_data.filter(pl.col(treat) == 1)[cov].drop_nulls().to_numpy().astype(float)
 64        vals_c = matched_data.filter(pl.col(treat) == 0)[cov].drop_nulls().to_numpy().astype(float)
 65
 66        if len(vals_t) < 2 or len(vals_c) < 2:
 67            continue
 68
 69        # SMD
 70        sd_t, sd_c = np.std(vals_t, ddof=1), np.std(vals_c, ddof=1)
 71        pooled_sd = np.sqrt((sd_t**2 + sd_c**2) / 2)
 72        smd = (np.mean(vals_t) - np.mean(vals_c)) / pooled_sd if pooled_sd > 0 else np.nan
 73
 74        # Two-sample t-test (Welch's)
 75        t_stat, t_pvalue = stats.ttest_ind(vals_t, vals_c, equal_var=False)
 76
 77        # Variance ratio
 78        var_ratio = np.var(vals_t, ddof=1) / np.var(vals_c, ddof=1) if np.var(vals_c, ddof=1) > 0 else np.nan
 79
 80        # KS test
 81        ks_stat, ks_pvalue = stats.ks_2samp(vals_t, vals_c)
 82
 83        rows.append({
 84            "covariate": cov,
 85            "mean_treated": round(np.mean(vals_t), 4),
 86            "mean_control": round(np.mean(vals_c), 4),
 87            "smd": round(smd, 4),
 88            "smd_pass": bool(abs(smd) < threshold),
 89            "t_stat": round(t_stat, 4),
 90            "t_pvalue": round(t_pvalue, 4),
 91            "var_ratio": round(var_ratio, 4),
 92            "var_ratio_pass": bool(0.5 < var_ratio < 2.0) if not np.isnan(var_ratio) else False,
 93            "ks_stat": round(ks_stat, 4),
 94            "ks_pvalue": round(ks_pvalue, 4),
 95        })
 96
 97    result = pl.DataFrame(rows)
 98
 99    # Print summary
100    n_pass_smd = result.filter(pl.col("smd_pass")).height
101    n_pass_var = result.filter(pl.col("var_ratio_pass")).height
102    n_total = result.height
103
104    print(f"\n{'='*70}")
105    print(f"  Post-Matching Balance Diagnostics")
106    print(f"{'='*70}")
107    print(f"  SMD < {threshold}: {n_pass_smd}/{n_total} pass")
108    print(f"  Variance ratio in (0.5, 2.0): {n_pass_var}/{n_total} pass")
109    print(f"{'='*70}\n")
110
111    print(f"  {'Covariate':<25} {'SMD':>8} {'t-test p':>10} {'VR':>8} {'KS p':>8}")
112    print(f"  {'-'*25} {'-'*8} {'-'*10} {'-'*8} {'-'*8}")
113    for row in result.iter_rows(named=True):
114        smd_flag = "✓" if row["smd_pass"] else "✗"
115        vr_flag = "✓" if row["var_ratio_pass"] else "✗"
116        print(f"  {row['covariate']:<25} {row['smd']:>7.4f}{smd_flag} {row['t_pvalue']:>10.4f} {row['var_ratio']:>7.3f}{vr_flag} {row['ks_pvalue']:>8.4f}")
117
118    return result

Run comprehensive balance diagnostics on matched sample.

For each covariate, computes:

  • Standardized mean difference (SMD)
  • Two-sample t-test (H0: means are equal)
  • Variance ratio (treat/control)
  • Kolmogorov-Smirnov test (H0: distributions are equal)

Parameters

scored_data : pl.DataFrame Reduced data with treatment indicator and covariates. matches : pl.DataFrame Match results with treat_id, control_id, tm columns. treat : str Treatment indicator column. id : str Unit identifier column. tm : str Time period column. covariates : list[str] Covariate column names. threshold : float SMD threshold for pass/fail (default 0.1).

Returns

pl.DataFrame with diagnostics per covariate.

def equivalence_test( scored_data: polars.dataframe.frame.DataFrame, matches: polars.dataframe.frame.DataFrame, treat: str, id: str, tm: str, covariates: list[str], multiplier: float = 0.36) -> polars.dataframe.frame.DataFrame:
def equivalence_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    multiplier: float = 0.36,
) -> pl.DataFrame:
    """TOST equivalence test for covariate balance.

    Tests H0: |mean difference| >= delta (non-equivalence), where
    delta = multiplier × pooled_SD (Hartman & Hidalgo 2018); this is the
    same as bounding |SMD| by ``multiplier``.
    Rejection = GOOD (positive evidence of negligible difference).
    A covariate is flagged equivalent when the TOST p-value is below 0.05.
    A per-covariate summary is printed to stdout.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data.
    matches : pl.DataFrame
        Match results with treat_id, control_id and the ``tm`` column.
    treat, id, tm : str
        Column names.
    covariates : list[str]
        Covariate names.
    multiplier : float
        Equivalence bound as fraction of pooled SD (default 0.36).

    Returns
    -------
    pl.DataFrame with TOST results per covariate. Covariates with fewer
    than two non-null observations in either arm are omitted; if none
    qualify, an empty frame with the full column schema is returned.
    """
    # Distinct (period, unit) pairs that appear on either side of a match.
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")

    # Arm membership does not depend on the covariate, so split once.
    treated = matched_data.filter(pl.col(treat) == 1)
    controls = matched_data.filter(pl.col(treat) == 0)

    rows = []
    for cov in covariates:
        vals_t = treated[cov].drop_nulls().to_numpy().astype(float)
        vals_c = controls[cov].drop_nulls().to_numpy().astype(float)

        # Sample variances (ddof=1) need at least two observations per arm.
        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        m, n = len(vals_t), len(vals_c)
        diff = np.mean(vals_t) - np.mean(vals_c)
        var_t = np.var(vals_t, ddof=1)
        var_c = np.var(vals_c, ddof=1)

        # Pooled SD: weighted formula matching Hartman & Hidalgo (2018)
        # equivtest R package: sqrt(((m-1)*var_x + (n-1)*var_y) / (m+n-2))
        pooled_sd = np.sqrt(((m - 1) * var_t + (n - 1) * var_c) / (m + n - 2))
        delta = multiplier * pooled_sd

        # Two one-sided t-tests following equivtest::tost()
        # Uses Welch's t-test (unequal variances)
        se = np.sqrt(var_t / m + var_c / n)
        if se > 0:
            df_welch = se**4 / ((var_t / m) ** 2 / (m - 1) + (var_c / n) ** 2 / (n - 1))
            t_upper = (diff - delta) / se
            t_lower = (diff + delta) / se
        else:
            # Both arms are constant: no sampling variability to test with,
            # so force both one-sided p-values to 1 (never "equivalent").
            df_welch = 1
            t_upper = np.inf
            t_lower = -np.inf

        # Upper test: H0: diff >= delta, alt: diff < delta
        p_upper = stats.t.cdf(t_upper, df=df_welch)

        # Lower test: H0: diff <= -delta, alt: diff > -delta
        # sf() is the upper tail, computed without the cancellation of 1 - cdf.
        p_lower = stats.t.sf(t_lower, df=df_welch)

        # TOST rejects only if both one-sided tests reject.
        tost_p = max(p_upper, p_lower)

        rows.append({
            "covariate": cov,
            "diff": round(diff, 6),
            "se": round(se, 6),
            "delta": round(delta, 4),
            "tost_p_upper": round(p_upper, 4),
            "tost_p_lower": round(p_lower, 4),
            "tost_p": round(tost_p, 4),
            "equivalent": bool(tost_p < 0.05),
        })

    if rows:
        result = pl.DataFrame(rows)
    else:
        # Without rows the frame would have no columns and the summary
        # filter below would raise; build an empty frame with the schema.
        result = pl.DataFrame(schema={
            "covariate": pl.Utf8,
            "diff": pl.Float64,
            "se": pl.Float64,
            "delta": pl.Float64,
            "tost_p_upper": pl.Float64,
            "tost_p_lower": pl.Float64,
            "tost_p": pl.Float64,
            "equivalent": pl.Boolean,
        })

    n_equiv = result.filter(pl.col("equivalent")).height
    print(f"\n  TOST Equivalence Test (bound = {multiplier}σ)")
    print(f"  Equivalent: {n_equiv}/{result.height} covariates (p < 0.05 = GOOD)")
    for row in result.iter_rows(named=True):
        flag = "✓ EQUIV" if row["equivalent"] else "  not equiv"
        print(f"    {row['covariate']:<25} p={row['tost_p']:.4f} {flag}")

    return result

TOST equivalence test for covariate balance.

Tests H0: |mean difference| >= delta (non-equivalence). Rejection = GOOD (positive evidence of negligible difference). Uses Hartman & Hidalgo (2018) approach: delta = multiplier × pooled_SD, which is equivalent to bounding |SMD| by the multiplier.

Parameters

scored_data : pl.DataFrame Reduced data. matches : pl.DataFrame Match results. treat, id, tm : str Column names. covariates : list[str] Covariate names. multiplier : float Equivalence bound as fraction of pooled SD (default 0.36).

Returns

pl.DataFrame with TOST results per covariate.

def balance_test_weighted( data: polars.dataframe.frame.DataFrame, weights: polars.dataframe.frame.DataFrame, treat: str, id: str, covariates: list[str], threshold: float = 0.1) -> polars.dataframe.frame.DataFrame:
def balance_test_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    covariates: list[str],
    threshold: float = 0.1,
) -> pl.DataFrame:
    """Run balance diagnostics on weighted sample.

    For each covariate, computes:
    - Weighted standardized mean difference (SMD)
    - Weighted two-sample Welch t-test
    - Weighted variance ratio
    - Effective sample sizes

    A summary and a per-covariate table are printed to stdout.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    weights : pl.DataFrame
        Unit weights with columns [id, weight].
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    covariates : list[str]
        Covariate column names.
    threshold : float
        SMD threshold for pass/fail (default 0.1).

    Returns
    -------
    pl.DataFrame with diagnostics per covariate. Covariates with fewer than
    two non-null observations in either arm are omitted; if none qualify,
    an empty frame with the full column schema is returned.
    """
    # Inner join: units without a weight row are dropped from the diagnostics.
    weighted = data.join(weights, on=id, how="inner")
    w_t = weighted.filter(pl.col(treat) == 1)
    w_c = weighted.filter(pl.col(treat) == 0)

    rows = []
    for cov in covariates:
        vals_t = w_t[cov].drop_nulls().to_numpy().astype(float)
        vals_c = w_c[cov].drop_nulls().to_numpy().astype(float)
        # Keep the weights aligned with the non-null covariate values.
        ww_t = w_t.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)
        ww_c = w_c.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)

        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        # Weighted SMD
        sd_t = _weighted_std(vals_t, ww_t)
        sd_c = _weighted_std(vals_c, ww_c)
        pooled_sd = np.sqrt((sd_t**2 + sd_c**2) / 2)
        mean_t = _weighted_mean(vals_t, ww_t)
        mean_c = _weighted_mean(vals_c, ww_c)
        smd = (mean_t - mean_c) / pooled_sd if pooled_sd > 0 else np.nan

        # Effective sample sizes
        n_eff_t = _effective_n(ww_t)
        n_eff_c = _effective_n(ww_c)

        # Weighted Welch t-test; NaN standard deviations propagate to NaN.
        var_t = sd_t ** 2
        var_c = sd_c ** 2
        se = np.sqrt(var_t / n_eff_t + var_c / n_eff_c) if (n_eff_t > 0 and n_eff_c > 0) else np.nan

        # `se > 0` is False for NaN, so undefined cases fall to the else arm.
        if se > 0 and n_eff_t > 1 and n_eff_c > 1:
            t_stat = (mean_t - mean_c) / se
            # Satterthwaite degrees of freedom
            df_welch = (var_t / n_eff_t + var_c / n_eff_c) ** 2 / (
                (var_t / n_eff_t) ** 2 / (n_eff_t - 1) +
                (var_c / n_eff_c) ** 2 / (n_eff_c - 1)
            )
            # sf() avoids the precision loss of 1 - cdf for large |t|.
            t_pvalue = 2 * stats.t.sf(abs(t_stat), df=df_welch)
        else:
            t_stat = np.nan
            t_pvalue = np.nan

        # Variance ratio; undefined when the control variance is 0 or NaN.
        var_ratio = var_t / var_c if var_c > 0 else np.nan

        rows.append({
            "covariate": cov,
            "mean_treated": round(mean_t, 4),
            "mean_control": round(mean_c, 4),
            "smd": round(smd, 4),
            "smd_pass": bool(abs(smd) < threshold) if not np.isnan(smd) else False,
            "t_stat": round(t_stat, 4) if not np.isnan(t_stat) else np.nan,
            "t_pvalue": round(t_pvalue, 4) if not np.isnan(t_pvalue) else np.nan,
            "var_ratio": round(var_ratio, 4) if not np.isnan(var_ratio) else np.nan,
            "var_ratio_pass": bool(0.5 < var_ratio < 2.0) if not np.isnan(var_ratio) else False,
            "n_eff_treated": round(n_eff_t, 1),
            "n_eff_control": round(n_eff_c, 1),
        })

    if rows:
        result = pl.DataFrame(rows)
    else:
        # Without rows the frame would have no columns and the summary
        # filters below would raise; build an empty frame with the schema.
        result = pl.DataFrame(schema={
            "covariate": pl.Utf8,
            "mean_treated": pl.Float64,
            "mean_control": pl.Float64,
            "smd": pl.Float64,
            "smd_pass": pl.Boolean,
            "t_stat": pl.Float64,
            "t_pvalue": pl.Float64,
            "var_ratio": pl.Float64,
            "var_ratio_pass": pl.Boolean,
            "n_eff_treated": pl.Float64,
            "n_eff_control": pl.Float64,
        })

    n_pass_smd = result.filter(pl.col("smd_pass")).height
    n_pass_var = result.filter(pl.col("var_ratio_pass")).height
    n_total = result.height

    print(f"\n{'='*70}")
    print(f"  Weighted Balance Diagnostics")
    print(f"{'='*70}")
    print(f"  SMD < {threshold}: {n_pass_smd}/{n_total} pass")
    print(f"  Variance ratio in (0.5, 2.0): {n_pass_var}/{n_total} pass")
    print(f"{'='*70}\n")

    print(f"  {'Covariate':<25} {'SMD':>8} {'t-test p':>10} {'VR':>8} {'n_eff_c':>8}")
    print(f"  {'-'*25} {'-'*8} {'-'*10} {'-'*8} {'-'*8}")
    for row in result.iter_rows(named=True):
        smd_flag = "✓" if row["smd_pass"] else "✗"
        vr_flag = "✓" if row["var_ratio_pass"] else "✗"
        print(f"  {row['covariate']:<25} {row['smd']:>7.4f}{smd_flag} "
              f"{row['t_pvalue']:>10.4f} {row['var_ratio']:>7.3f}{vr_flag} "
              f"{row['n_eff_control']:>8.1f}")

    return result

Run balance diagnostics on weighted sample.

For each covariate, computes:

  • Weighted standardized mean difference (SMD)
  • Weighted two-sample Welch t-test
  • Weighted variance ratio
  • Effective sample sizes

Parameters

data : pl.DataFrame Reduced data with treatment indicator and covariates. weights : pl.DataFrame Unit weights with columns [id, weight]. treat : str Treatment indicator column. id : str Unit identifier column. covariates : list[str] Covariate column names. threshold : float SMD threshold for pass/fail (default 0.1).

Returns

pl.DataFrame with diagnostics per covariate.

def equivalence_test_weighted( data: polars.dataframe.frame.DataFrame, weights: polars.dataframe.frame.DataFrame, treat: str, id: str, covariates: list[str], multiplier: float = 0.36) -> polars.dataframe.frame.DataFrame:
def equivalence_test_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    covariates: list[str],
    multiplier: float = 0.36,
) -> pl.DataFrame:
    """TOST equivalence test for weighted covariate balance.

    Same as equivalence_test() but uses weighted statistics and
    effective sample sizes. The null hypothesis is
    |weighted mean difference| >= delta with delta = multiplier × pooled SD;
    a covariate is flagged equivalent when the TOST p-value is below 0.05.
    A per-covariate summary is printed to stdout.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data.
    weights : pl.DataFrame
        Unit weights with columns [id, weight].
    treat, id : str
        Column names.
    covariates : list[str]
        Covariate names.
    multiplier : float
        Equivalence bound as fraction of pooled SD (default 0.36).

    Returns
    -------
    pl.DataFrame with TOST results per covariate. Covariates with fewer
    than two non-null observations in either arm are omitted; if none
    qualify, an empty frame with the full column schema is returned.
    """
    # Inner join: units without a weight row are dropped from the test.
    weighted = data.join(weights, on=id, how="inner")
    w_t = weighted.filter(pl.col(treat) == 1)
    w_c = weighted.filter(pl.col(treat) == 0)

    rows = []
    for cov in covariates:
        vals_t = w_t[cov].drop_nulls().to_numpy().astype(float)
        vals_c = w_c[cov].drop_nulls().to_numpy().astype(float)
        # Keep the weights aligned with the non-null covariate values.
        ww_t = w_t.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)
        ww_c = w_c.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)

        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        mean_t = _weighted_mean(vals_t, ww_t)
        mean_c = _weighted_mean(vals_c, ww_c)
        diff = mean_t - mean_c

        var_t = _weighted_std(vals_t, ww_t) ** 2
        var_c = _weighted_std(vals_c, ww_c) ** 2

        n_eff_t = _effective_n(ww_t)
        n_eff_c = _effective_n(ww_c)

        # Pooled SD (weighted, for delta calculation)
        pooled_sd = np.sqrt(
            ((n_eff_t - 1) * var_t + (n_eff_c - 1) * var_c) /
            (n_eff_t + n_eff_c - 2)
        ) if (n_eff_t + n_eff_c > 2) else np.nan

        # NaN pooled SD propagates to a NaN delta.
        delta = multiplier * pooled_sd

        # Weighted SE and Welch df
        se = np.sqrt(var_t / n_eff_t + var_c / n_eff_c) if (n_eff_t > 0 and n_eff_c > 0) else np.nan
        # `se > 0` is False for NaN, so undefined cases skip the test below.
        se_usable = bool(se > 0)
        if se_usable and n_eff_t > 1 and n_eff_c > 1:
            df_welch = se ** 4 / (
                (var_t / n_eff_t) ** 2 / (n_eff_t - 1) +
                (var_c / n_eff_c) ** 2 / (n_eff_c - 1)
            )
        else:
            df_welch = 1

        # TOST
        if se_usable and not np.isnan(delta):
            # Upper test: H0: diff >= delta
            t_upper = (diff - delta) / se
            p_upper = stats.t.cdf(t_upper, df=df_welch)
            # Lower test: H0: diff <= -delta; sf() is the upper tail without
            # the cancellation of 1 - cdf.
            t_lower = (diff + delta) / se
            p_lower = stats.t.sf(t_lower, df=df_welch)
            tost_p = max(p_upper, p_lower)
        else:
            p_upper = np.nan
            p_lower = np.nan
            tost_p = np.nan

        rows.append({
            "covariate": cov,
            "diff": round(diff, 6),
            "se": round(se, 6) if not np.isnan(se) else np.nan,
            "delta": round(delta, 4) if not np.isnan(delta) else np.nan,
            "tost_p_upper": round(p_upper, 4) if not np.isnan(p_upper) else np.nan,
            "tost_p_lower": round(p_lower, 4) if not np.isnan(p_lower) else np.nan,
            "tost_p": round(tost_p, 4) if not np.isnan(tost_p) else np.nan,
            "equivalent": bool(tost_p < 0.05) if not np.isnan(tost_p) else False,
        })

    if rows:
        result = pl.DataFrame(rows)
    else:
        # Without rows the frame would have no columns and the summary
        # filter below would raise; build an empty frame with the schema.
        result = pl.DataFrame(schema={
            "covariate": pl.Utf8,
            "diff": pl.Float64,
            "se": pl.Float64,
            "delta": pl.Float64,
            "tost_p_upper": pl.Float64,
            "tost_p_lower": pl.Float64,
            "tost_p": pl.Float64,
            "equivalent": pl.Boolean,
        })

    n_equiv = result.filter(pl.col("equivalent")).height
    print(f"\n  Weighted TOST Equivalence Test (bound = {multiplier}σ)")
    print(f"  Equivalent: {n_equiv}/{result.height} covariates (p < 0.05 = GOOD)")
    for row in result.iter_rows(named=True):
        flag = "✓ EQUIV" if row["equivalent"] else "  not equiv"
        print(f"    {row['covariate']:<25} p={row['tost_p']:.4f} {flag}")

    return result

TOST equivalence test for weighted covariate balance.

Same as equivalence_test() but uses weighted statistics and effective sample sizes.

Parameters

data : pl.DataFrame Reduced data. weights : pl.DataFrame Unit weights with columns [id, weight]. treat, id : str Column names. covariates : list[str] Covariate names. multiplier : float Equivalence bound as fraction of pooled SD (default 0.36).

Returns

pl.DataFrame with TOST results per covariate.