pyrollmatch
pyrollmatch — Rolling entry matching and weighting for staggered adoption studies.
A Python package for causal inference with staggered treatment adoption, built on polars and numpy for scalable matching/weighting on large panel datasets (100K+ units).
Methods
- Matching: Nearest-neighbor matching on propensity scores or pairwise distances (Mahalanobis, Euclidean, robust Mahalanobis). Supports calipers, replacement modes, and matching order (following MatchIt conventions).
- Entropy balancing: Direct covariate balance via convex optimization (Hainmueller 2012). Each entry cohort is weighted independently (stacked design).
- Custom: User-defined per-period weighting functions.
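Entropy balancing, as listed above, chooses control weights of the form w_i ∝ q_i exp(λ'c_i) so that weighted control moments match the treated ones exactly. In the simplest case (one covariate, `moment=1`, uniform base weights q_i) λ is a scalar and solves one equation: the weighted control mean must equal the treated mean. The pure-Python sketch below solves it with plain Newton steps (the derivative of the weighted mean in λ is the weighted variance), rescales the weights to sum to the number of treated units, and reports the effective sample size (sum w)^2 / sum w^2 that `entropy_balance` also checks. `ebal_1d` is a hypothetical name; the package itself solves the general multivariate dual with SciPy's L-BFGS-B, and this sketch has no line search, so it is only meant for well-overlapping toy data.

```python
import math


def ebal_1d(x_c: list[float], target: float, n_t: int,
            tol: float = 1e-10, max_iter: int = 100) -> tuple[list[float], float]:
    """Hypothetical 1-D entropy balancing on the mean (moment=1).

    Uniform base weights; returns (weights summing to n_t, n_eff).
    """
    if not min(x_c) < target < max(x_c):
        raise ValueError("target mean must lie strictly inside the control range")
    lam = 0.0
    for _ in range(max_iter):
        z = [lam * x for x in x_c]
        zmax = max(z)  # shift exponents for stability; cancels on normalization
        e = [math.exp(v - zmax) for v in z]
        s = sum(e)
        p = [v / s for v in e]  # normalized weights, sum to 1
        mean = sum(pi * x for pi, x in zip(p, x_c))
        var = sum(pi * (x - mean) ** 2 for pi, x in zip(p, x_c))
        gap = mean - target
        if abs(gap) < tol:
            break
        lam -= gap / var  # Newton step: d(mean)/d(lambda) = weighted variance
    else:
        raise RuntimeError("Newton iteration did not converge")
    w = [n_t * pi for pi in p]  # rescale so weights sum to n_treated
    n_eff = sum(w) ** 2 / sum(v * v for v in w)
    return w, n_eff


weights, n_eff = ebal_1d([0.0, 1.0, 2.0, 3.0], target=2.0, n_t=2)
print([round(v, 3) for v in weights], round(n_eff, 2))
```

With more covariates or higher moments λ becomes a vector, which is where a general-purpose optimizer is needed.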
Quick Start
>>> from pyrollmatch import rollmatch
>>>
>>> # Propensity score matching
>>> result = rollmatch(
... data, treat="treat", tm="time", entry="entry_time", id="unit_id",
... covariates=["x1", "x2", "x3"],
... ps_caliper=0.2, num_matches=3,
... )
>>>
>>> # Mahalanobis distance matching
>>> result = rollmatch(
... data, ..., model_type="mahalanobis",
... )
>>>
>>> # Entropy balancing
>>> result = rollmatch(
... data, ..., method="ebal", moment=1,
... )
>>>
>>> result.balance # covariate balance (SMD table)
>>> result.weights # unit-level weights
>>> result.matched_data # match pairs (matching only)
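`result.balance` reports standardized mean differences (SMDs). As a rough illustration of what one entry measures, here is a pure-Python sketch of a weighted SMD for a single covariate. The helper `weighted_smd` is hypothetical, and its denominator (square root of the average of the unweighted treated and control sample variances) is an assumed convention; pyrollmatch's own `compute_balance` / `smd_table` may define it differently.

```python
from statistics import fmean, variance


def weighted_smd(x_t: list[float], x_c: list[float], w_c: list[float]) -> float:
    """Hypothetical helper: SMD of one covariate, treated vs weighted controls.

    Assumption: denominator is sqrt((var_t + var_c) / 2) from unweighted
    sample variances, so pre- and post-match SMDs share one scale.
    """
    mean_t = fmean(x_t)
    # Weighted control mean; weights need not sum to 1
    mean_c = sum(w * x for w, x in zip(w_c, x_c)) / sum(w_c)
    pooled_sd = ((variance(x_t) + variance(x_c)) / 2) ** 0.5
    return (mean_t - mean_c) / pooled_sd


x_treated = [1.0, 2.0, 3.0]
x_control = [0.0, 1.0, 2.0, 3.0]
print(round(weighted_smd(x_treated, x_control, [1, 1, 1, 1]), 4))  # 0.433
print(round(weighted_smd(x_treated, x_control, [0, 1, 1, 1]), 4))  # 0.0
```

Giving control unit 0 a weight of zero equalizes the means, so the SMD drops from about 0.433 to 0.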
References
- Witman et al. (2018). "Comparison Group Selection in the Presence of Rolling Entry." Health Services Research, 54(1), 262-270.
- Hainmueller, J. (2012). "Entropy Balancing for Causal Effects." Political Analysis, 20(1), 25-46.
- Ho, D. E., Imai, K., King, G., Stuart, E. A. (2011). "MatchIt: Nonparametric Preprocessing for Parametric Causal Inference." Journal of Statistical Software, 42(8), 1-28.
- Rosenbaum, P. (2010). Design of Observational Studies, ch. 8.
```python
from .core import rollmatch, RollmatchResult
from .reduce import reduce_data
from .score import score_data, ScoredResult, SUPPORTED_MODELS, DISTANCE_MODELS, PROPENSITY_MODELS
from .match import DistanceSpec
from .weight import entropy_balance
from .balance import (
    compute_balance, compute_balance_weighted,
    balance_by_period, balance_by_period_weighted,
    smd_table,
)
from .diagnostics import (
    balance_test, equivalence_test,
    balance_test_weighted, equivalence_test_weighted,
)

__version__ = "0.1.0"
__all__ = [
    # Core
    "rollmatch",
    "RollmatchResult",
    # Pipeline stages
    "reduce_data",
    "score_data",
    "ScoredResult",
    # Distance
    "DistanceSpec",
    # Constants
    "SUPPORTED_MODELS",
    "DISTANCE_MODELS",
    "PROPENSITY_MODELS",
    # Weighting
    "entropy_balance",
    # Balance diagnostics
    "compute_balance",
    "compute_balance_weighted",
    "balance_by_period",
    "balance_by_period_weighted",
    "smd_table",
    # Statistical tests
    "balance_test",
    "equivalence_test",
    "balance_test_weighted",
    "equivalence_test_weighted",
]
```
```python
def rollmatch(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    covariates: list[str],
    lookback: int = 1,
    method: str | Callable = "matching",
    verbose: bool = True,
    **method_kwargs,
) -> RollmatchResult | None:
    """Run the rolling entry matching/weighting pipeline."""
    if callable(method) and not isinstance(method, str):
        return _run_callable(
            method, data, treat, tm, entry, id, covariates,
            lookback, verbose, **method_kwargs,
        )
    elif method == "matching":
        return _run_matching(
            data, treat, tm, entry, id, covariates,
            lookback, verbose, **method_kwargs,
        )
    elif method == "ebal":
        return _run_ebal(
            data, treat, tm, entry, id, covariates,
            lookback, verbose, **method_kwargs,
        )
    else:
        raise ValueError(
            f"method must be 'matching', 'ebal', or a callable, got {method!r}"
        )
```
Run the rolling entry matching/weighting pipeline.
This is the main entry point for pyrollmatch. It orchestrates data reduction, scoring, matching (or weighting), and balance computation in a single call.
Parameters
data : pl.DataFrame
Panel data with unit × time observations. Must contain columns
for treatment status, time period, entry period, unit ID, and
covariates.
treat : str
Column name for binary treatment indicator (1=treated, 0=control).
tm : str
Column name for time period (integer).
entry : str
Column name for entry period. Treated units: the period when
treatment begins. Controls: any value > max(tm) or null.
id : str
Column name for unit identifier.
covariates : list[str]
Covariate column names used for scoring and balance diagnostics.
lookback : int, default 1
Offset from entry to the baseline period: treated units are measured at period entry - lookback. Must be >= 1.
method : str or callable, default "matching"
Weighting method:
``"matching"``
Nearest-neighbor matching on propensity scores or pairwise
distances. Returns match pairs in ``matched_data``.
**Matching kwargs:**
- ``ps_caliper`` (float, default 0): PS caliper multiplier.
0 = no caliper.
- ``ps_caliper_std`` (str, default ``"average"``): How to
compute pooled SD for PS caliper. ``"average"``,
``"weighted"``, or ``"none"``.
- ``num_matches`` (int, default 1): Number of controls matched to each treated unit.
- ``replacement`` (str, default ``"cross_cohort"``):
``"unrestricted"``, ``"cross_cohort"``, or ``"global_no"``.
- ``model_type`` (str, default ``"logistic"``): Scoring model.
See ``pyrollmatch.score.SUPPORTED_MODELS``.
- ``block_size`` (int, default 2000): Block size used to limit memory use.
- ``mahvars`` (list[str] or None): Covariates for Mahalanobis
distance matching with PS caliper (MatchIt pattern).
- ``m_order`` (str or None): Matching order: ``"largest"``,
``"smallest"``, ``"random"``, ``"data"``, or ``None`` (auto).
- ``caliper`` (dict or None): Per-variable calipers,
e.g. ``{"x1": 0.5, "x2": 0.3}``.
- ``std_caliper`` (bool, default True): Whether per-variable
caliper widths are in SD units.
``"ebal"``
Entropy balancing (Hainmueller 2012). Returns per-cohort
weights in ``weighted_data``.
**Ebal kwargs:**
- ``moment`` (int, default 1): Moment constraint (1=mean,
2=mean+variance, 3=mean+variance+skewness).
- ``max_weight`` (float or None): Maximum weight cap.
callable
User-defined function with signature
``fn(treated_data, control_data, covariates, id, **kwargs)``
returning ``pl.DataFrame`` with columns ``[id, weight]``.
verbose : bool, default True
Print progress and balance summary.
**method_kwargs
Method-specific keyword arguments (see method above).
Returns
RollmatchResult or None
Contains matched_data, balance, weights, and summary
statistics. Returns None if no matches/weights could be
computed (e.g. empty data, convergence failure).
Examples
Propensity score matching (default):
>>> result = rollmatch(
... data, treat="treat", tm="time", entry="entry_time",
... id="unit_id", covariates=["x1", "x2", "x3"],
... ps_caliper=0.2, num_matches=3,
... )
>>> result.matched_data # match pairs
>>> result.balance # SMD table
Mahalanobis distance matching:
>>> result = rollmatch(
... data, treat="treat", tm="time", entry="entry_time",
... id="unit_id", covariates=["x1", "x2", "x3"],
... model_type="mahalanobis",
... )
Mahalanobis matching with PS caliper (MatchIt mahvars pattern):
>>> result = rollmatch(
... data, treat="treat", tm="time", entry="entry_time",
... id="unit_id", covariates=["x1", "x2", "x3"],
... ps_caliper=0.25, mahvars=["x1", "x2"],
... )
Entropy balancing:
>>> result = rollmatch(
... data, treat="treat", tm="time", entry="entry_time",
... id="unit_id", covariates=["x1", "x2", "x3"],
... method="ebal", moment=1,
... )
>>> result.weights # unit weights
>>> result.weighted_data # per-cohort weights
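To make ``ps_caliper`` and ``num_matches`` concrete, the following is a minimal pure-Python sketch of greedy 1:k nearest-neighbor matching on logit propensity scores with a caliper, for a single entry cohort. It is not the package's matching engine: `greedy_caliper_match` is a hypothetical name, it does not reuse controls within the cohort, it processes treated units in data order (similar in spirit to ``m_order="data"``), and it assumes the ``"average"`` pooled SD means sqrt((var_t + var_c) / 2) of the logit scores.

```python
import math
from statistics import variance


def logit(p: float) -> float:
    """Log-odds transform (the match_on="logit" scale)."""
    return math.log(p / (1.0 - p))


def greedy_caliper_match(
    ps_t: dict[str, float],
    ps_c: dict[str, float],
    ps_caliper: float = 0.2,
    num_matches: int = 1,
) -> list[tuple[str, str, float]]:
    """Hypothetical sketch: greedy k-NN on logit scores within one cohort.

    Controls are not reused. Returns (treat_id, control_id, difference).
    """
    s_t = {i: logit(p) for i, p in ps_t.items()}
    s_c = {i: logit(p) for i, p in ps_c.items()}
    # Assumed "average" pooled SD: sqrt of the mean of the two group variances
    pooled_sd = math.sqrt(
        (variance(list(s_t.values())) + variance(list(s_c.values()))) / 2
    )
    # ps_caliper == 0 means no caliper
    width = ps_caliper * pooled_sd if ps_caliper > 0 else math.inf

    available = list(s_c)  # a list keeps tie-breaking deterministic
    pairs = []
    for t_id, t_score in s_t.items():  # treated units in data order
        ranked = sorted(available, key=lambda c: abs(s_c[c] - t_score))
        for c_id in ranked[:num_matches]:
            diff = abs(s_c[c_id] - t_score)
            if diff > width:
                break  # remaining candidates are at least as far away
            pairs.append((t_id, c_id, diff))
            available.remove(c_id)
    return pairs


treated = {"t1": 0.60, "t2": 0.30}
controls = {"c1": 0.58, "c2": 0.32, "c3": 0.90, "c4": 0.10}
pairs = greedy_caliper_match(treated, controls, ps_caliper=0.2)
print([(t, c) for t, c, _ in pairs])  # [('t1', 'c1'), ('t2', 'c2')]
```

Shrinking the caliper to 0.01 leaves both treated units unmatched, which is how a caliper trades sample size for closeness.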
```python
@dataclass
class RollmatchResult:
    """Result from :func:`rollmatch`."""
    matched_data: pl.DataFrame | None
    balance: pl.DataFrame
    n_treated_total: int
    n_treated_matched: int
    n_controls_matched: int
    ps_caliper: float | None
    weights: pl.DataFrame
    weighted_data: pl.DataFrame | None = None
    method: str = "matching"
```
Result from rollmatch().
Attributes
matched_data : pl.DataFrame or None
Match pairs with columns [tm, treat_id, control_id, difference].
None for weighting-only methods (ebal, custom).
balance : pl.DataFrame
Covariate balance table with pre-match and post-match SMDs.
n_treated_total : int
Total treated units in the reduced sample.
n_treated_matched : int
Treated units that were successfully matched or weighted.
n_controls_matched : int
Unique control units used.
ps_caliper : float or None
Propensity score caliper multiplier used. None for non-matching
methods.
weights : pl.DataFrame
Unit-level weights [id, weight]. For matching, derived from
pairs. For ebal, collapsed across cohorts.
weighted_data : pl.DataFrame or None
Per-cohort weights [tm, id, weight] (ebal/custom only).
method : str
Method used: "matching", "ebal", or "custom".
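For matching, ``weights`` is described as derived from the match pairs. One simple convention, assumed here and not confirmed as pyrollmatch's exact rule, gives each treated unit weight 1 and each control the sum of 1/k over the treated units it is paired with, where k is that treated unit's number of matches. `pairs_to_weights` is a hypothetical helper.

```python
from collections import Counter, defaultdict


def pairs_to_weights(pairs: list[tuple[str, str]]) -> dict[str, float]:
    """Hypothetical helper: collapse (treat_id, control_id) pairs to weights.

    Treated units get weight 1; each control accumulates 1/k from every
    treated unit it is paired with, where k is that unit's match count.
    """
    k = Counter(t for t, _ in pairs)  # number of matches per treated unit
    weights: dict[str, float] = defaultdict(float)
    for t, c in pairs:
        weights[t] = 1.0
        weights[c] += 1.0 / k[t]
    return dict(weights)


print(pairs_to_weights([("t1", "c1"), ("t1", "c2"), ("t2", "c2")]))
# {'t1': 1.0, 'c1': 0.5, 'c2': 1.5, 't2': 1.0}
```

Under this rule the control weights sum to the number of matched treated units (2 here), so a control reused across treated units (``c2``) simply carries more weight.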
```python
def reduce_data(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    lookback: int = 1,
) -> pl.DataFrame:
    """Construct quasi-panel for rolling entry matching."""
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")

    for col in [treat, tm, entry, id]:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in data")

    # Treatment set: treated units at their baseline period (entry - lookback)
    # Controls may have null/None entry; that's fine, they're filtered by treat==0
    treat_set = data.filter(
        (pl.col(treat) == 1)
        & pl.col(entry).is_not_null()
        & (pl.col(tm) == pl.col(entry) - lookback)
    )

    # Get unique baseline time periods from treated set
    baseline_periods = treat_set[tm].unique().to_list()

    # Control set: controls at those same time periods
    control_set = data.filter(
        (pl.col(treat) == 0) & (pl.col(tm).is_in(baseline_periods))
    )

    # Combine
    reduced = pl.concat([treat_set, control_set])

    return reduced
```
Construct quasi-panel for rolling entry matching.
Parameters
data : pl.DataFrame
Panel data with treat, time, entry, id columns and covariates.
treat : str
Column name for binary treatment indicator (1=treated, 0=control).
tm : str
Column name for time period.
entry : str
Column name for entry period. For treated units, this is the time
period when treatment begins (integer). For control units, use either:
- None/null (recommended), or
- Any integer larger than max(tm) (e.g., 99, 999)
Control units' entry values are never used by the algorithm.
id : str
Column name for unit identifier.
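lookback : int, default 1
Offset from entry to the baseline period: treated units are taken at
period entry - lookback. Must be >= 1.
Returns
pl.DataFrame
Reduced dataset: treated units at their baseline period, plus control
units at every period that is a baseline for at least one treated unit
(so a control can appear once per entry cohort).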
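The reduction rule can be traced on a toy panel without polars. This pure-Python sketch mirrors the two filters in `reduce_data` on a list of row dicts; the row-dict layout and the helper name `reduce_rows` are illustrative assumptions, not part of the package.

```python
def reduce_rows(rows: list[dict], treat: str, tm: str, entry: str,
                lookback: int = 1) -> list[dict]:
    """Hypothetical pure-Python mirror of the reduce_data filters."""
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")
    # Treated units at their baseline period; null entries are skipped
    treat_set = [
        r for r in rows
        if r[treat] == 1 and r[entry] is not None and r[tm] == r[entry] - lookback
    ]
    baseline_periods = {r[tm] for r in treat_set}
    # Controls at every period that is a baseline for some treated cohort
    control_set = [r for r in rows if r[treat] == 0 and r[tm] in baseline_periods]
    return treat_set + control_set


# Unit A enters at t=3, unit B at t=4, unit C is a never-treated control
panel = [
    {"id": u, "treat": tr, "time": t, "entry": e}
    for u, tr, e in [("A", 1, 3), ("B", 1, 4), ("C", 0, None)]
    for t in range(1, 6)
]
reduced = reduce_rows(panel, "treat", "time", "entry", lookback=1)
print([(r["id"], r["time"]) for r in reduced])
# [('A', 2), ('B', 3), ('C', 2), ('C', 3)]
```

Control C appears twice, once as a candidate for each cohort's baseline period, which is what makes the result a quasi-panel rather than a cross-section.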
```python
def score_data(
    reduced_data: pl.DataFrame,
    covariates: list[str],
    treat: str,
    model_type: str = "logistic",
    match_on: str = "logit",
    max_iter: int = 1000,
) -> ScoredResult:
    """Fit a propensity/distance model and return scored data with metadata."""
    if model_type not in SUPPORTED_MODELS:
        raise ValueError(
            f"model_type must be one of {SUPPORTED_MODELS}, got '{model_type}'"
        )

    valid_match = ("logit", "pscore")
    if match_on not in valid_match and model_type not in DISTANCE_MODELS:
        raise ValueError(f"match_on must be one of {valid_match}, got '{match_on}'")

    for col in covariates:
        if col not in reduced_data.columns:
            raise ValueError(f"Covariate '{col}' not found in data")

    # Extract numpy arrays
    X = reduced_data.select(covariates).to_numpy().astype(np.float64)
    y = reduced_data[treat].to_numpy().astype(np.int32)

    # Check for NaN
    nan_mask = np.isnan(X).any(axis=1)
    if nan_mask.any():
        raise ValueError(
            f"{nan_mask.sum()} rows have NaN in covariates. "
            "Remove NaN rows before scoring."
        )

    # Initialize metadata fields
    cov_inv = None
    distance_metric = None
    distance_transform = None
    ranked_covariates = None

    if model_type in DISTANCE_MODELS:
        scores = np.zeros(len(X))  # placeholder: matcher uses covariates
        model = None

        if model_type == "mahalanobis":
            cov = _pooled_within_group_cov(X, y)
            cov += np.eye(cov.shape[0]) * 1e-6
            cov_inv = np.linalg.inv(cov)
            distance_metric = "mahalanobis"

        elif model_type == "scaled_euclidean":
            sd = _pooled_within_group_sd(X, y)
            sd[sd < 1e-10] = 1.0
            distance_transform = np.diag(1.0 / sd)
            distance_metric = "scaled_euclidean"

        elif model_type == "robust_mahalanobis":
            # MatchIt: rank full sample, cov(ranks), scale by sd(1:n)
            n = X.shape[0]
            X_ranked = np.column_stack([
                rankdata(X[:, j], method="average") for j in range(X.shape[1])
            ])
            var_r = np.cov(X_ranked, rowvar=False, ddof=1)
            var_r = np.atleast_2d(var_r)
            sd_1_to_n = np.std(np.arange(1, n + 1), ddof=1)
            multiplier = sd_1_to_n / np.sqrt(np.diag(var_r)).clip(1e-10)
            var_r = var_r * np.outer(multiplier, multiplier)
            var_r += np.eye(var_r.shape[0]) * 1e-6
            cov_inv = np.linalg.inv(var_r)
            distance_metric = "robust_mahalanobis"

        elif model_type == "euclidean":
            distance_metric = "euclidean"

    else:
        # Propensity score model
        model = _build_model(model_type, max_iter)
        model.fit(X, y)
        scores = _predict_scores(model, X, model_type, match_on)

    result_df = reduced_data.with_columns(pl.Series("score", scores))

    # For robust_mahalanobis: add full-dataset ranked columns
    if model_type == "robust_mahalanobis":
        ranked_covariates = [f"_ranked_{c}" for c in covariates]
        result_df = result_df.with_columns([
            pl.Series(f"_ranked_{col}", X_ranked[:, j])
            for j, col in enumerate(covariates)
        ])

    return ScoredResult(
        data=result_df, model=model, covariates=covariates,
        model_type=model_type, match_on=match_on,
        cov_inv=cov_inv,
        distance_metric=distance_metric,
        distance_transform=distance_transform,
        ranked_covariates=ranked_covariates,
    )
```
Fit a propensity/distance model and return scored data with metadata.
This is the scoring step of the matching pipeline. For propensity score
models, a classifier is fitted and scores are added to the data. For
distance-based models, covariance / scaling metadata is computed and
returned in the ScoredResult for the matching engine to use.
Parameters
reduced_data : pl.DataFrame
Output from reduce_data().
covariates : list[str]
Column names of matching covariates.
treat : str
Column name for binary treatment indicator (1=treated, 0=control).
model_type : str, default "logistic"
Scoring model. One of SUPPORTED_MODELS:
**Propensity score models** (produce scalar scores per unit):
``"logistic"``, ``"probit"``, ``"gbm"``, ``"rf"``,
``"lasso"``, ``"ridge"``, ``"elasticnet"``.
**Distance-based models** (pairwise distances computed by matcher):
``"mahalanobis"``, ``"scaled_euclidean"``,
``"robust_mahalanobis"``, ``"euclidean"``.
match_on : str, default "logit"
Score transformation for propensity models:
"logit" for log-odds (recommended), "pscore" for raw
probability. Ignored for distance-based models.
max_iter : int, default 1000
Maximum optimizer iterations (propensity models only).
Returns
ScoredResult
Contains .data (DataFrame with "score" column),
.model (fitted classifier or None), and distance metadata.
Raises
ValueError
If model_type is not in SUPPORTED_MODELS, covariates
are missing, or data contains NaN values.
Examples
>>> result = score_data(reduced, ["x1", "x2"], "treat")
>>> result.data["score"] # propensity scores
>>> result.model # fitted LogisticRegression
>>> result = score_data(reduced, ["x1", "x2"], "treat",
... model_type="mahalanobis")
>>> result.cov_inv # inverse covariance matrix
>>> result.distance_metric # "mahalanobis"
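For ``model_type="mahalanobis"``, `score_data` computes a pooled within-group covariance, adds 1e-6 to its diagonal and inverts it. The private helper `_pooled_within_group_cov` is not shown, so the pure-Python sketch below assumes one standard definition: squared deviations from each group's own mean, summed over both groups and divided by n - 2. It handles exactly two covariates so the 2x2 inverse can be written out; all function names are hypothetical. In the toy data the group means are far apart, but the pooled within-group covariance ignores that gap.

```python
def pooled_within_cov_2d(X: list[list[float]], y: list[int]) -> list[list[float]]:
    """Hypothetical: pooled within-group covariance for two covariates.

    Deviations are taken from each group's own mean; divisor n - 2 (assumed).
    """
    sums = [[0.0, 0.0], [0.0, 0.0]]
    for g in (0, 1):
        rows = [x for x, t in zip(X, y) if t == g]
        m = [sum(r[j] for r in rows) / len(rows) for j in range(2)]
        for r in rows:
            d = [r[0] - m[0], r[1] - m[1]]
            for a in range(2):
                for b in range(2):
                    sums[a][b] += d[a] * d[b]
    return [[s / (len(X) - 2) for s in row] for row in sums]


def inv_2x2(S: list[list[float]], ridge: float = 1e-6) -> list[list[float]]:
    """Inverse of S + ridge * I (mirrors the 1e-6 jitter in score_data)."""
    a, b = S[0][0] + ridge, S[0][1]
    c, d = S[1][0], S[1][1] + ridge
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def mahalanobis(u: list[float], v: list[float], cov_inv: list[list[float]]) -> float:
    """sqrt((u - v)' S^-1 (u - v)) for two covariates."""
    diff = [u[0] - v[0], u[1] - v[1]]
    q = sum(diff[i] * cov_inv[i][j] * diff[j] for i in range(2) for j in range(2))
    return q ** 0.5


X = [[0.0, 0.0], [2.0, 2.0], [4.0, 6.0], [6.0, 4.0]]
y = [1, 1, 0, 0]
S = pooled_within_cov_2d(X, y)  # [[2.0, 0.0], [0.0, 2.0]]
S_inv = inv_2x2(S)
print(round(mahalanobis(X[0], X[2], S_inv), 4))  # 5.099, i.e. about sqrt(26)
```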
```python
@dataclass
class ScoredResult:
    """Container for scored data and distance/model metadata."""
    data: pl.DataFrame
    model: Any
    covariates: list[str]
    model_type: str
    match_on: str
    cov_inv: np.ndarray | None = None
    distance_metric: str | None = None
    distance_transform: np.ndarray | None = None
    ranked_covariates: list[str] | None = None
```
Container for scored data and distance/model metadata.
Always returned by score_data(). Provides a uniform interface
regardless of whether the model is propensity-based or distance-based.
Attributes
data : pl.DataFrame
Input data with an added "score" column. For propensity models,
this contains the fitted propensity scores on the scale set by
match_on (log-odds or probability). For distance models, it
contains placeholder zeros (the matcher uses covariates directly).
model : Any
Fitted sklearn classifier for propensity models. None for
distance-based models.
covariates : list[str]
Covariate column names used for scoring.
model_type : str
The model type used (e.g. "logistic", "mahalanobis").
match_on : str
Score transformation applied: "logit" or "pscore".
Only meaningful for propensity models.
cov_inv : np.ndarray or None
Inverse covariance matrix. Set for "mahalanobis" and
"robust_mahalanobis"; None otherwise.
distance_metric : str or None
Distance metric identifier passed to the matcher:
"mahalanobis", "robust_mahalanobis", "scaled_euclidean",
"euclidean", or None for propensity score matching.
distance_transform : np.ndarray or None
Diagonal scaling matrix for "scaled_euclidean".
None for all other model types.
ranked_covariates : list[str] or None
Column names of pre-ranked covariates added to data for
"robust_mahalanobis" (e.g. ["_ranked_x1", "_ranked_x2"]).
The matcher uses these instead of raw covariates so that
full-dataset ranks are consistent with cov_inv.
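For ``"robust_mahalanobis"``, `score_data` ranks each covariate over the full reduced sample (ties get average ranks), takes the covariance of the ranks, and multiplies column j by sd(1..n) / sd(ranks_j), so every diagonal entry becomes the variance of 1..n. The pure-Python sketch below covers the rank-and-rescale step for one column; `average_ranks` and `rank_scale` are hypothetical stand-ins for `scipy.stats.rankdata(..., method="average")` and the NumPy arithmetic in `score_data`.

```python
from statistics import stdev, variance


def average_ranks(values: list[float]) -> list[float]:
    """Hypothetical: 1-based ranks; tied values share their mean position."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over a run of tied values
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def rank_scale(values: list[float]) -> tuple[list[float], float]:
    """Hypothetical: return ranks and the multiplier sd(1..n) / sd(ranks)."""
    r = average_ranks(values)
    sd_1_to_n = stdev(list(range(1, len(values) + 1)))
    return r, sd_1_to_n / max(stdev(r), 1e-10)


r, m = rank_scale([10.0, 20.0, 20.0, 30.0])
print(r)  # [1.0, 2.5, 2.5, 4.0]
```

After scaling, variance(r) * m**2 equals the variance of 1..n, which is why ties do not shrink a covariate's influence on the robust distance.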
```python
@dataclass
class DistanceSpec:
    """Bundle of distance-related parameters for the matching engine."""
    metric: str | None = None
    covariates: list[str] | None = None
    cov_inv: np.ndarray | None = None
    transform: np.ndarray | None = None
    is_mahvars: bool = False

    @property
    def use_pairwise(self) -> bool:
        """Whether to compute pairwise covariate distances."""
        return self.metric is not None and self.covariates is not None
```
Bundle of distance-related parameters for the matching engine.
Encapsulates all metadata needed to compute pairwise distances
between treated and control units. For propensity score matching,
all fields are None/default and the matcher uses scalar score
differences.
Attributes
metric : str or None
Distance metric: "mahalanobis", "robust_mahalanobis",
"scaled_euclidean", "euclidean", or None for
propensity score matching.
covariates : list[str] or None
Covariate column names to extract from the scored DataFrame
for pairwise distance computation.
cov_inv : np.ndarray or None
Inverse covariance matrix (for Mahalanobis variants).
transform : np.ndarray or None
Diagonal scaling matrix (for scaled Euclidean).
is_mahvars : bool
If True, propensity scores are real (mahvars pattern) and the
PS caliper should be applied alongside distance matching.
def entropy_balance(
    treated_data: pl.DataFrame,
    control_data: pl.DataFrame,
    covariates: list[str],
    id: str,
    moment: int = 1,
    max_weight: float | None = None,
    tol: float = 1e-6,
    max_iter: int = 200,
) -> pl.DataFrame | None:
    """Compute entropy-balanced weights for control units.

    Finds weights that minimize KL divergence from uniform while
    satisfying exact covariate moment constraints (Hainmueller 2012).

    Parameters
    ----------
    treated_data : pl.DataFrame
        Treated units for this cohort.
    control_data : pl.DataFrame
        Control pool for this cohort.
    covariates : list[str]
        Covariate columns to balance.
    id : str
        Unit identifier column.
    moment : int (1, 2, or 3)
        Moments to balance: 1=means, 2=means+variances, 3=+skewness.
    max_weight : float or None
        Optional upper bound on any single weight (after normalization).
    tol : float
        Convergence tolerance for the optimizer.
    max_iter : int
        Maximum iterations.

    Returns
    -------
    pl.DataFrame with columns [id, weight] for treated (weight=1) and
    controls (entropy-balanced weights), or None if optimization fails.
    """
    X_t = treated_data.select(covariates).to_numpy().astype(np.float64)
    X_c = control_data.select(covariates).to_numpy().astype(np.float64)

    n_t = X_t.shape[0]
    n_c = X_c.shape[0]

    if n_t == 0 or n_c == 0:
        return None

    # Build constraint matrix and target
    C = _build_constraint_matrix(X_c, moment)
    target = _build_target(X_t, moment)

    # Standardize for numerical stability
    col_mean = C.mean(axis=0)
    col_std = C.std(axis=0)
    col_std[col_std < 1e-10] = 1.0  # avoid division by zero for constant cols
    # Don't standardize the intercept column
    col_mean[0] = 0.0
    col_std[0] = 1.0

    C_std = (C - col_mean) / col_std
    target_std = (target - col_mean) / col_std

    # Base weights (uniform)
    base_q = np.full(n_c, 1.0 / n_c)

    # Solve the dual
    lam0 = np.zeros(C_std.shape[1])
    result = minimize(
        _dual_objective,
        lam0,
        args=(C_std, target_std, base_q),
        method="L-BFGS-B",
        jac=True,  # _dual_objective returns (obj, grad)
        options={"maxiter": max_iter, "ftol": tol, "gtol": tol},
    )

    if not result.success:
        warnings.warn(
            f"Entropy balancing did not converge: {result.message}. "
            "This may indicate insufficient overlap between treated and "
            "control covariate distributions.",
            stacklevel=2,
        )
        return None

    # Recover primal weights
    lin = C_std @ result.x
    lin = np.clip(lin, -500, 500)
    w = base_q * np.exp(lin)

    # w should sum to ~1 (enforced by intercept constraint); rescale to n_treated
    w = w * (n_t / w.sum())

    # Effective sample size check
    n_eff = w.sum() ** 2 / np.sum(w ** 2)
    if n_eff / n_c < 0.1:
        warnings.warn(
            f"Effective sample size is very low: n_eff={n_eff:.1f} "
            f"({100*n_eff/n_c:.1f}% of {n_c} controls). "
            "Consider relaxing moment constraints or checking overlap.",
            stacklevel=2,
        )

    # Optional weight capping (iterate to handle re-normalization)
    if max_weight is not None:
        for _ in range(20):  # converges in a few iterations
            if np.max(w) <= max_weight:
                break
            w = np.minimum(w, max_weight)
            w = w * (n_t / w.sum())

    # Build output DataFrame
    treat_ids = treated_data[id].to_numpy()
    ctrl_ids = control_data[id].to_numpy()

    treat_weights = pl.DataFrame({
        id: treat_ids,
        "weight": np.ones(n_t),
    })
    ctrl_weights = pl.DataFrame({
        id: ctrl_ids,
        "weight": w,
    })

    return pl.concat([treat_weights, ctrl_weights])
Compute entropy-balanced weights for control units.
Finds weights that minimize KL divergence from uniform while satisfying exact covariate moment constraints (Hainmueller 2012).
Parameters
treated_data : pl.DataFrame
Treated units for this cohort.
control_data : pl.DataFrame
Control pool for this cohort.
covariates : list[str]
Covariate columns to balance.
id : str
Unit identifier column.
moment : int (1, 2, or 3)
Moments to balance: 1 = means; 2 = means and variances; 3 = means, variances, and skewness.
max_weight : float or None
Optional upper bound on any single control weight, applied after the control weights are normalized to sum to the number of treated units.
tol : float
Convergence tolerance for the optimizer.
max_iter : int
Maximum number of optimizer iterations.
Returns
pl.DataFrame with columns [id, weight] for treated units (weight = 1) and controls (entropy-balanced weights, rescaled to sum to the number of treated units), or None if either group is empty or the optimization fails to converge.
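`entropy_balance` passes a helper, `_dual_objective`, to `scipy.optimize.minimize` with `jac=True`, then recovers the weights as `base_q * exp(C_std @ result.x)`. The helper itself is not shown in this section. The sketch below uses one dual that is consistent with that weight formula, and it is an assumption rather than the package's code.

Minimizing f(λ) = Σ_i q_i exp(c_i·λ) - λ·target gives the gradient Cᵀw - target. That gradient is zero exactly when the weighted control moments equal the treated targets, and the intercept column forces Σw = 1. To stay NumPy-only, the sketch swaps L-BFGS-B for a damped Newton method. `dual_objective` and `solve_dual` are hypothetical names.

```python
import numpy as np


def dual_objective(lam, C, target, base_q):
    """Assumed entropy-balancing dual: returns (objective, gradient)."""
    w = base_q * np.exp(np.clip(C @ lam, -500, 500))
    obj = w.sum() - lam @ target
    grad = C.T @ w - target  # zero exactly when the moment constraints hold
    return obj, grad


def solve_dual(C, target, base_q, tol=1e-10, max_iter=100):
    """Damped Newton's method on the convex dual (hypothetical helper)."""
    lam = np.zeros(C.shape[1])
    for _ in range(max_iter):
        obj, grad = dual_objective(lam, C, target, base_q)
        if np.max(np.abs(grad)) < tol:
            break
        w = base_q * np.exp(np.clip(C @ lam, -500, 500))
        hess = C.T @ (w[:, None] * C)  # Hessian of the dual: C' diag(w) C
        step = np.linalg.solve(hess, grad)
        t = 1.0
        # Backtracking (Armijo) line search keeps every step a descent step
        while (dual_objective(lam - t * step, C, target, base_q)[0]
               > obj - 1e-4 * t * (grad @ step)) and t > 1e-10:
            t *= 0.5
        lam = lam - t * step
    return lam


rng = np.random.default_rng(1)
X_c = rng.normal(size=(500, 2))
X_t = rng.normal(loc=0.3, size=(100, 2))
C = np.column_stack([np.ones(len(X_c)), X_c])  # intercept + first moments
target = np.concatenate([[1.0], X_t.mean(axis=0)])
base_q = np.full(len(X_c), 1.0 / len(X_c))

lam = solve_dual(C, target, base_q)
w = base_q * np.exp(C @ lam)
w_scaled = w * (len(X_t) / w.sum())  # rescale to n_treated
```

The final rescaling by `n_t / w.sum()` mirrors what `entropy_balance` does before it builds its output frame.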
def compute_balance(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> pl.DataFrame:
    """Compute covariate balance before and after matching.

    Returns a table with means, SDs, and SMDs for each covariate,
    both in the full sample and the matched sample.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id and control_id columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    pl.DataFrame with columns:
        covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c,
        full_smd, matched_mean_t, matched_mean_c, matched_sd_t,
        matched_sd_c, matched_smd
    """
    # Pre-compute matched data ONCE (not per covariate)
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids_df = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids_df, on=[tm, id], how="semi")

    # Pre-split by treatment group
    full_t = scored_data.filter(pl.col(treat) == 1)
    full_c = scored_data.filter(pl.col(treat) == 0)
    match_t = matched_data.filter(pl.col(treat) == 1)
    match_c = matched_data.filter(pl.col(treat) == 0)

    rows = []

    for cov in covariates:
        vals_t = full_t[cov].drop_nulls().to_numpy()
        vals_c = full_c[cov].drop_nulls().to_numpy()

        full_mean_t = np.mean(vals_t) if len(vals_t) > 0 else np.nan
        full_mean_c = np.mean(vals_c) if len(vals_c) > 0 else np.nan
        full_sd_t = np.std(vals_t, ddof=1) if len(vals_t) > 1 else np.nan
        full_sd_c = np.std(vals_c, ddof=1) if len(vals_c) > 1 else np.nan
        full_pooled = np.sqrt((full_sd_t**2 + full_sd_c**2) / 2) if not (np.isnan(full_sd_t) or np.isnan(full_sd_c)) else np.nan
        full_smd = (full_mean_t - full_mean_c) / full_pooled if full_pooled and full_pooled > 0 else np.nan

        mvals_t = match_t[cov].drop_nulls().to_numpy()
        mvals_c = match_c[cov].drop_nulls().to_numpy()

        m_mean_t = np.mean(mvals_t) if len(mvals_t) > 0 else np.nan
        m_mean_c = np.mean(mvals_c) if len(mvals_c) > 0 else np.nan
        m_sd_t = np.std(mvals_t, ddof=1) if len(mvals_t) > 1 else np.nan
        m_sd_c = np.std(mvals_c, ddof=1) if len(mvals_c) > 1 else np.nan
        m_pooled = np.sqrt((m_sd_t**2 + m_sd_c**2) / 2) if not (np.isnan(m_sd_t) or np.isnan(m_sd_c)) else np.nan
        m_smd = (m_mean_t - m_mean_c) / m_pooled if m_pooled and m_pooled > 0 else np.nan

        rows.append({
            "covariate": cov,
            "full_mean_t": round(full_mean_t, 4),
            "full_mean_c": round(full_mean_c, 4),
            "full_sd_t": round(full_sd_t, 4),
            "full_sd_c": round(full_sd_c, 4),
            "full_smd": round(full_smd, 4),
            "matched_mean_t": round(m_mean_t, 4),
            "matched_mean_c": round(m_mean_c, 4),
            "matched_sd_t": round(m_sd_t, 4),
            "matched_sd_c": round(m_sd_c, 4),
            "matched_smd": round(m_smd, 4),
        })

    return pl.DataFrame(rows)
Compute covariate balance before and after matching.
Returns a table with means, SDs, and SMDs for each covariate, both in the full sample and the matched sample.
Parameters
scored_data : pl.DataFrame
Reduced data with treatment indicator and covariates.
matches : pl.DataFrame
Match results with the time period (tm), treat_id, and control_id columns.
treat : str
Treatment indicator column.
id : str
Unit identifier column.
tm : str
Time period column.
covariates : list[str]
Covariate column names.
Returns
pl.DataFrame with columns: covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c, full_smd, matched_mean_t, matched_mean_c, matched_sd_t, matched_sd_c, matched_smd
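The SMD in both the `full_*` and `matched_*` columns is the difference in group means divided by the square root of the average of the two group variances (`ddof=1`). NaN is returned when a group has fewer than two values or the pooled SD is zero. Below is a minimal standalone version; `smd` is a hypothetical helper name.

```python
import numpy as np


def smd(x_t: np.ndarray, x_c: np.ndarray) -> float:
    """Standardized mean difference with pooled SD sqrt((s_t^2 + s_c^2) / 2)."""
    if len(x_t) < 2 or len(x_c) < 2:
        return float("nan")
    s_t = np.std(x_t, ddof=1)
    s_c = np.std(x_c, ddof=1)
    pooled = np.sqrt((s_t**2 + s_c**2) / 2)
    return float((np.mean(x_t) - np.mean(x_c)) / pooled) if pooled > 0 else float("nan")


print(smd(np.array([1.0, 2.0, 3.0]), np.array([0.0, 1.0, 2.0])))  # 1.0
```

This pooled SD is an unweighted average of the two variances. `equivalence_test` instead sets its equivalence bound with the degrees-of-freedom-weighted pooled SD, so the two scales differ slightly when the group sizes are unequal.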
def compute_balance_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    covariates: list[str],
) -> pl.DataFrame:
    """Compute covariate balance using weighted means and SDs.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    weights : pl.DataFrame
        Unit weights with columns [id, weight].
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    pl.DataFrame with same schema as compute_balance():
        covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c,
        full_smd, matched_mean_t, matched_mean_c, matched_sd_t,
        matched_sd_c, matched_smd
    """
    # Full sample (unweighted)
    full_t = data.filter(pl.col(treat) == 1)
    full_c = data.filter(pl.col(treat) == 0)

    # Weighted sample: join with weights
    weighted = data.join(weights, on=id, how="inner")
    w_t = weighted.filter(pl.col(treat) == 1)
    w_c = weighted.filter(pl.col(treat) == 0)

    rows = []
    for cov in covariates:
        # Full sample (unweighted)
        vals_ft = full_t[cov].drop_nulls().to_numpy()
        vals_fc = full_c[cov].drop_nulls().to_numpy()
        full_mean_t = np.mean(vals_ft) if len(vals_ft) > 0 else np.nan
        full_mean_c = np.mean(vals_fc) if len(vals_fc) > 0 else np.nan
        full_sd_t = np.std(vals_ft, ddof=1) if len(vals_ft) > 1 else np.nan
        full_sd_c = np.std(vals_fc, ddof=1) if len(vals_fc) > 1 else np.nan
        full_pooled = np.sqrt((full_sd_t**2 + full_sd_c**2) / 2) if not (np.isnan(full_sd_t) or np.isnan(full_sd_c)) else np.nan
        full_smd = (full_mean_t - full_mean_c) / full_pooled if full_pooled and full_pooled > 0 else np.nan

        # Weighted sample
        wvals_t = w_t[cov].drop_nulls().to_numpy()
        wvals_c = w_c[cov].drop_nulls().to_numpy()
        ww_t = w_t.filter(pl.col(cov).is_not_null())["weight"].to_numpy()
        ww_c = w_c.filter(pl.col(cov).is_not_null())["weight"].to_numpy()

        m_mean_t = _weighted_mean(wvals_t, ww_t) if len(wvals_t) > 0 else np.nan
        m_mean_c = _weighted_mean(wvals_c, ww_c) if len(wvals_c) > 0 else np.nan
        m_sd_t = _weighted_std(wvals_t, ww_t) if len(wvals_t) > 1 else np.nan
        m_sd_c = _weighted_std(wvals_c, ww_c) if len(wvals_c) > 1 else np.nan
        m_pooled = np.sqrt((m_sd_t**2 + m_sd_c**2) / 2) if not (np.isnan(m_sd_t) or np.isnan(m_sd_c)) else np.nan
        m_smd = (m_mean_t - m_mean_c) / m_pooled if m_pooled and m_pooled > 0 else np.nan

        rows.append({
            "covariate": cov,
            "full_mean_t": round(full_mean_t, 4),
            "full_mean_c": round(full_mean_c, 4),
            "full_sd_t": round(full_sd_t, 4),
            "full_sd_c": round(full_sd_c, 4),
            "full_smd": round(full_smd, 4),
            "matched_mean_t": round(m_mean_t, 4),
            "matched_mean_c": round(m_mean_c, 4),
            "matched_sd_t": round(m_sd_t, 4),
            "matched_sd_c": round(m_sd_c, 4),
            "matched_smd": round(m_smd, 4),
        })

    return pl.DataFrame(rows)
Compute covariate balance using weighted means and SDs.
Parameters
data : pl.DataFrame
Reduced data with treatment indicator and covariates.
weights : pl.DataFrame
Unit weights with columns [id, weight].
treat : str
Treatment indicator column.
id : str
Unit identifier column.
covariates : list[str]
Covariate column names.
Returns
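pl.DataFrame with the same schema as compute_balance(): covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c, full_smd, matched_mean_t, matched_mean_c, matched_sd_t, matched_sd_c, matched_smd. The full_* columns are unweighted. The matched_* columns use the supplied weights, and units absent from `weights` are dropped by the inner join.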
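`compute_balance_weighted` and the other weighted diagnostics call `_weighted_mean`, `_weighted_std`, and `_effective_n`, none of which is shown in this section. The sketch below gives plausible versions under two assumptions:

- The weighted SD uses the reliability-weights correction, which reduces to `ddof=1` when all weights are equal.
- The effective sample size is Kish's (Σw)²/Σw², the same expression `entropy_balance` uses for its low-`n_eff` warning.

The package's actual helpers may differ, and the names below are hypothetical.

```python
import numpy as np


def weighted_mean(x: np.ndarray, w: np.ndarray) -> float:
    return float(np.sum(w * x) / np.sum(w))


def weighted_std(x: np.ndarray, w: np.ndarray) -> float:
    """Weighted SD with the reliability-weights bias correction (assumed)."""
    m = weighted_mean(x, w)
    v1, v2 = np.sum(w), np.sum(w**2)
    denom = v1 - v2 / v1  # equals n - 1 when all weights are 1
    if denom <= 0:
        return float("nan")
    return float(np.sqrt(np.sum(w * (x - m) ** 2) / denom))


def effective_n(w: np.ndarray) -> float:
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    return float(np.sum(w) ** 2 / np.sum(w**2))


x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([1.0, 1.0, 2.0, 4.0])
print(weighted_mean(x, w), weighted_std(x, w), effective_n(w))
```

All three quantities are invariant to rescaling the weights. This matters because `entropy_balance` normalizes control weights to sum to the number of treated units rather than to one.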
def balance_by_period(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Compute per-period covariate balance (SMD) for matched samples.

    Pooled SMD can mask within-cohort imbalance through cancellation.
    This function computes SMD separately for each entry period, then
    aggregates across periods.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id, control_id, and time period columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    (aggregate, detail) : tuple[pl.DataFrame, pl.DataFrame]

        **aggregate** — one row per covariate with columns:
        ``covariate``, ``wtd_mean_smd``, ``median_abs_smd``,
        ``max_abs_smd``, ``n_periods``.

        **detail** — one row per (period, covariate) with columns:
        ``period``, ``covariate``, ``n_treated``, ``n_controls``,
        ``mean_treated``, ``mean_control``, ``smd``.
    """
    time_periods = matches[tm].unique().sort().to_list()

    detail_rows = []

    for t in time_periods:
        period_matches = matches.filter(pl.col(tm) == t)

        # Treated units in this period
        t_ids = period_matches.select("treat_id").unique().rename({"treat_id": id})
        t_data = scored_data.filter(
            (pl.col(treat) == 1) & (pl.col(tm) == t)
        ).join(t_ids, on=id, how="semi")

        # Control units in this period
        c_ids = period_matches.select("control_id").unique().rename({"control_id": id})
        c_data = scored_data.filter(
            (pl.col(treat) == 0) & (pl.col(tm) == t)
        ).join(c_ids, on=id, how="semi")

        n_treated = t_data.height
        n_controls = c_data.height

        if n_treated < 2 or n_controls < 2:
            continue

        for cov in covariates:
            vals_t = t_data[cov].drop_nulls().to_numpy()
            vals_c = c_data[cov].drop_nulls().to_numpy()

            if len(vals_t) < 2 or len(vals_c) < 2:
                detail_rows.append({
                    "period": t, "covariate": cov,
                    "n_treated": n_treated, "n_controls": n_controls,
                    "mean_treated": np.nan, "mean_control": np.nan,
                    "smd": np.nan,
                })
                continue

            mean_t = np.mean(vals_t)
            mean_c = np.mean(vals_c)
            sd_t = np.std(vals_t, ddof=1)
            sd_c = np.std(vals_c, ddof=1)
            pooled = np.sqrt((sd_t**2 + sd_c**2) / 2)
            smd = (mean_t - mean_c) / pooled if pooled > 0 else np.nan

            detail_rows.append({
                "period": t, "covariate": cov,
                "n_treated": n_treated, "n_controls": n_controls,
                "mean_treated": round(mean_t, 4),
                "mean_control": round(mean_c, 4),
                "smd": round(smd, 4),
            })

    detail = pl.DataFrame(detail_rows) if detail_rows else pl.DataFrame(
        schema={"period": pl.Int64, "covariate": pl.Utf8,
                "n_treated": pl.Int64, "n_controls": pl.Int64,
                "mean_treated": pl.Float64, "mean_control": pl.Float64,
                "smd": pl.Float64}
    )

    if detail.height == 0:
        agg = pl.DataFrame(
            schema={"covariate": pl.Utf8, "wtd_mean_smd": pl.Float64,
                    "median_abs_smd": pl.Float64, "max_abs_smd": pl.Float64,
                    "n_periods": pl.UInt32}
        )
        return agg, detail

    # Aggregate: weighted mean (by n_treated), median |SMD|, max |SMD|
    agg_rows = []
    for cov in covariates:
        cov_detail = detail.filter(
            (pl.col("covariate") == cov) & pl.col("smd").is_not_nan()
        )
        if cov_detail.height == 0:
            agg_rows.append({
                "covariate": cov, "wtd_mean_smd": np.nan,
                "median_abs_smd": np.nan, "max_abs_smd": np.nan,
                "n_periods": 0,
            })
            continue

        smds = cov_detail["smd"].to_numpy()
        weights = cov_detail["n_treated"].to_numpy().astype(float)
        total_w = weights.sum()

        wtd_mean = float(np.average(smds, weights=weights)) if total_w > 0 else np.nan
        median_abs = float(np.median(np.abs(smds)))
        max_abs = float(np.max(np.abs(smds)))

        agg_rows.append({
            "covariate": cov,
            "wtd_mean_smd": round(wtd_mean, 4),
            "median_abs_smd": round(median_abs, 4),
            "max_abs_smd": round(max_abs, 4),
            "n_periods": cov_detail.height,
        })

    agg = pl.DataFrame(agg_rows)
    return agg, detail
Compute per-period covariate balance (SMD) for matched samples.
Pooled SMD can mask within-cohort imbalance through cancellation. This function computes SMD separately for each entry period, then aggregates across periods. Periods with fewer than two matched treated units or fewer than two matched controls are skipped.
Parameters
scored_data : pl.DataFrame
Reduced data with treatment indicator and covariates.
matches : pl.DataFrame
Match results with treat_id, control_id, and time period columns.
treat : str
Treatment indicator column.
id : str
Unit identifier column.
tm : str
Time period column.
covariates : list[str]
Covariate column names.
Returns
(aggregate, detail) : tuple[pl.DataFrame, pl.DataFrame]
**aggregate** — one row per covariate with columns:
``covariate``, ``wtd_mean_smd``, ``median_abs_smd``,
``max_abs_smd``, ``n_periods``.
**detail** — one row per (period, covariate) with columns:
``period``, ``covariate``, ``n_treated``, ``n_controls``,
``mean_treated``, ``mean_control``, ``smd``.
def balance_by_period_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Per-period covariate balance using weighted means/SDs.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data.
    weights : pl.DataFrame
        Unit weights. Either stacked [tm, id, weight] (per-cohort
        weights) or collapsed [id, weight]. If stacked, per-cohort
        weights are used for each period's balance computation.
    treat, id, tm : str
        Column names.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    (aggregate, detail) with same schemas as balance_by_period().
    """
    # Detect stacked vs collapsed weights
    stacked = tm in weights.columns
    if stacked:
        weighted = data.join(weights, on=[tm, id], how="inner")
    else:
        weighted = data.join(weights, on=id, how="inner")
    time_periods = weighted[tm].unique().sort().to_list()

    detail_rows = []
    for t in time_periods:
        period_data = weighted.filter(pl.col(tm) == t)
        t_data = period_data.filter(pl.col(treat) == 1)
        c_data = period_data.filter(pl.col(treat) == 0)

        n_treated = t_data.height
        n_controls = c_data.height
        if n_treated < 2 or n_controls < 2:
            continue

        for cov in covariates:
            vals_t = t_data[cov].drop_nulls().to_numpy()
            vals_c = c_data[cov].drop_nulls().to_numpy()
            ww_t = t_data.filter(pl.col(cov).is_not_null())["weight"].to_numpy()
            ww_c = c_data.filter(pl.col(cov).is_not_null())["weight"].to_numpy()

            if len(vals_t) < 2 or len(vals_c) < 2:
                detail_rows.append({
                    "period": t, "covariate": cov,
                    "n_treated": n_treated, "n_controls": n_controls,
                    "mean_treated": np.nan, "mean_control": np.nan,
                    "smd": np.nan,
                })
                continue

            mean_t = _weighted_mean(vals_t, ww_t)
            mean_c = _weighted_mean(vals_c, ww_c)
            sd_t = _weighted_std(vals_t, ww_t)
            sd_c = _weighted_std(vals_c, ww_c)
            pooled = np.sqrt((sd_t**2 + sd_c**2) / 2) if not (np.isnan(sd_t) or np.isnan(sd_c)) else np.nan
            smd = (mean_t - mean_c) / pooled if pooled and pooled > 0 else np.nan

            detail_rows.append({
                "period": t, "covariate": cov,
                "n_treated": n_treated, "n_controls": n_controls,
                "mean_treated": round(mean_t, 4),
                "mean_control": round(mean_c, 4),
                "smd": round(smd, 4),
            })

    detail = pl.DataFrame(detail_rows) if detail_rows else pl.DataFrame(
        schema={"period": pl.Int64, "covariate": pl.Utf8,
                "n_treated": pl.Int64, "n_controls": pl.Int64,
                "mean_treated": pl.Float64, "mean_control": pl.Float64,
                "smd": pl.Float64}
    )

    if detail.height == 0:
        agg = pl.DataFrame(
            schema={"covariate": pl.Utf8, "wtd_mean_smd": pl.Float64,
                    "median_abs_smd": pl.Float64, "max_abs_smd": pl.Float64,
                    "n_periods": pl.UInt32}
        )
        return agg, detail

    agg_rows = []
    for cov in covariates:
        cov_detail = detail.filter(
            (pl.col("covariate") == cov) & pl.col("smd").is_not_nan()
        )
        if cov_detail.height == 0:
            agg_rows.append({
                "covariate": cov, "wtd_mean_smd": np.nan,
                "median_abs_smd": np.nan, "max_abs_smd": np.nan,
                "n_periods": 0,
            })
            continue

        smds = cov_detail["smd"].to_numpy()
        period_weights = cov_detail["n_treated"].to_numpy().astype(float)
        total_w = period_weights.sum()

        wtd_mean = float(np.average(smds, weights=period_weights)) if total_w > 0 else np.nan
        median_abs = float(np.median(np.abs(smds)))
        max_abs = float(np.max(np.abs(smds)))

        agg_rows.append({
            "covariate": cov,
            "wtd_mean_smd": round(wtd_mean, 4),
            "median_abs_smd": round(median_abs, 4),
            "max_abs_smd": round(max_abs, 4),
            "n_periods": cov_detail.height,
        })

    agg = pl.DataFrame(agg_rows)
    return agg, detail
Per-period covariate balance using weighted means/SDs.
Parameters
data : pl.DataFrame
Reduced data.
weights : pl.DataFrame
Unit weights. Either stacked [tm, id, weight] (per-cohort weights) or collapsed [id, weight]. If stacked, per-cohort weights are used for each period's balance computation.
treat, id, tm : str
Column names.
covariates : list[str]
Covariate column names.
Returns
(aggregate, detail) with same schemas as balance_by_period().
def smd_table(balance: pl.DataFrame, threshold: float = 0.1) -> None:
    """Print a formatted SMD table with pass/fail indicators.

    Parameters
    ----------
    balance : pl.DataFrame
        Output from compute_balance().
    threshold : float
        |SMD| threshold for pass/fail (default 0.1).
    """
    max_smd = balance["matched_smd"].abs().max()
    all_pass = balance["matched_smd"].abs().max() < threshold

    print(f"\n{'='*70}")
    print(f" Standardized Mean Differences (threshold: |SMD| < {threshold})")
    print(f" Max |SMD| = {max_smd:.4f} {'✓ ALL PASS' if all_pass else '✗ SOME FAIL'}")
    print(f"{'='*70}\n")
    print(f" {'Covariate':<30} {'Full SMD':>10} {'Matched SMD':>12} {'Pass':>6}")
    print(f" {'-'*30} {'-'*10} {'-'*12} {'-'*6}")

    for row in balance.iter_rows(named=True):
        smd = row["matched_smd"]
        passed = abs(smd) < threshold if smd is not None else False
        print(f" {row['covariate']:<30} {row['full_smd']:>10.4f} {smd:>12.4f} {'✓' if passed else '✗':>6}")

    print()
Print a formatted SMD table with pass/fail indicators.
Parameters
balance : pl.DataFrame
Output from compute_balance() or compute_balance_weighted(), which share the same schema.
threshold : float
|SMD| threshold for pass/fail (default 0.1).
def balance_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    threshold: float = 0.1,
) -> pl.DataFrame:
    """Run comprehensive balance diagnostics on matched sample.

    For each covariate, computes:
    - Standardized mean difference (SMD)
    - Two-sample t-test (H0: means are equal)
    - Variance ratio (treat/control)
    - Kolmogorov-Smirnov test (H0: distributions are equal)

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id, control_id, tm columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.
    threshold : float
        SMD threshold for pass/fail (default 0.1).

    Returns
    -------
    pl.DataFrame with diagnostics per covariate.
    """
    # Get matched units
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")

    rows = []
    for cov in covariates:
        vals_t = matched_data.filter(pl.col(treat) == 1)[cov].drop_nulls().to_numpy().astype(float)
        vals_c = matched_data.filter(pl.col(treat) == 0)[cov].drop_nulls().to_numpy().astype(float)

        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        # SMD
        sd_t, sd_c = np.std(vals_t, ddof=1), np.std(vals_c, ddof=1)
        pooled_sd = np.sqrt((sd_t**2 + sd_c**2) / 2)
        smd = (np.mean(vals_t) - np.mean(vals_c)) / pooled_sd if pooled_sd > 0 else np.nan

        # Two-sample t-test (Welch's)
        t_stat, t_pvalue = stats.ttest_ind(vals_t, vals_c, equal_var=False)

        # Variance ratio
        var_ratio = np.var(vals_t, ddof=1) / np.var(vals_c, ddof=1) if np.var(vals_c, ddof=1) > 0 else np.nan

        # KS test
        ks_stat, ks_pvalue = stats.ks_2samp(vals_t, vals_c)

        rows.append({
            "covariate": cov,
            "mean_treated": round(np.mean(vals_t), 4),
            "mean_control": round(np.mean(vals_c), 4),
            "smd": round(smd, 4),
            "smd_pass": bool(abs(smd) < threshold),
            "t_stat": round(t_stat, 4),
            "t_pvalue": round(t_pvalue, 4),
            "var_ratio": round(var_ratio, 4),
            "var_ratio_pass": bool(0.5 < var_ratio < 2.0) if not np.isnan(var_ratio) else False,
            "ks_stat": round(ks_stat, 4),
            "ks_pvalue": round(ks_pvalue, 4),
        })

    result = pl.DataFrame(rows)

    # Print summary
    n_pass_smd = result.filter(pl.col("smd_pass")).height
    n_pass_var = result.filter(pl.col("var_ratio_pass")).height
    n_total = result.height

    print(f"\n{'='*70}")
    print(f" Post-Matching Balance Diagnostics")
    print(f"{'='*70}")
    print(f" SMD < {threshold}: {n_pass_smd}/{n_total} pass")
    print(f" Variance ratio in (0.5, 2.0): {n_pass_var}/{n_total} pass")
    print(f"{'='*70}\n")

    print(f" {'Covariate':<25} {'SMD':>8} {'t-test p':>10} {'VR':>8} {'KS p':>8}")
    print(f" {'-'*25} {'-'*8} {'-'*10} {'-'*8} {'-'*8}")
    for row in result.iter_rows(named=True):
        smd_flag = "✓" if row["smd_pass"] else "✗"
        vr_flag = "✓" if row["var_ratio_pass"] else "✗"
        print(f" {row['covariate']:<25} {row['smd']:>7.4f}{smd_flag} {row['t_pvalue']:>10.4f} {row['var_ratio']:>7.3f}{vr_flag} {row['ks_pvalue']:>8.4f}")

    return result
Run comprehensive balance diagnostics on the matched sample.
For each covariate, computes:
- Standardized mean difference (SMD)
- Two-sample t-test (H0: means are equal)
- Variance ratio (treat/control)
- Kolmogorov-Smirnov test (H0: distributions are equal)
Parameters
scored_data : pl.DataFrame
Reduced data with treatment indicator and covariates.
matches : pl.DataFrame
Match results with treat_id, control_id, tm columns.
treat : str
Treatment indicator column.
id : str
Unit identifier column.
tm : str
Time period column.
covariates : list[str]
Covariate column names.
threshold : float
SMD threshold for pass/fail (default 0.1).
Returns
pl.DataFrame with diagnostics per covariate.
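The variance-ratio and KS checks can be sketched in plain NumPy. The sketch assumes the same rules as `balance_test`: the ratio is treated variance over control variance with `ddof=1`, and it passes when it lies strictly inside (0.5, 2.0). The KS statistic is the largest gap between the two empirical CDFs, which is the two-sided statistic `scipy.stats.ks_2samp` reports. The p-value is left to SciPy. Function names are hypothetical.

```python
import numpy as np


def variance_ratio(x_t: np.ndarray, x_c: np.ndarray) -> float:
    """Treated/control variance ratio (ddof=1); NaN if the control variance is 0."""
    v_c = np.var(x_c, ddof=1)
    return float(np.var(x_t, ddof=1) / v_c) if v_c > 0 else float("nan")


def ks_statistic(x_t: np.ndarray, x_c: np.ndarray) -> float:
    """Two-sample KS statistic: sup_x |F_t(x) - F_c(x)|."""
    x_t, x_c = np.sort(x_t), np.sort(x_c)
    grid = np.concatenate([x_t, x_c])  # the supremum is attained at a data point
    cdf_t = np.searchsorted(x_t, grid, side="right") / len(x_t)
    cdf_c = np.searchsorted(x_c, grid, side="right") / len(x_c)
    return float(np.max(np.abs(cdf_t - cdf_c)))


rng = np.random.default_rng(2)
x_t = rng.normal(0.0, 1.0, size=200)
x_c = rng.normal(0.0, 1.5, size=300)
vr = variance_ratio(x_t, x_c)
vr_pass = 0.5 < vr < 2.0
print(round(vr, 3), vr_pass, round(ks_statistic(x_t, x_c), 3))
```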
def equivalence_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    multiplier: float = 0.36,
) -> pl.DataFrame:
    """TOST equivalence test for covariate balance.

    Tests H0: |SMD| >= delta (non-equivalence).
    Rejection = GOOD (positive evidence of negligible difference).
    Uses Hartman & Hidalgo (2018) approach: delta = multiplier × pooled_SD.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data.
    matches : pl.DataFrame
        Match results.
    treat, id, tm : str
        Column names.
    covariates : list[str]
        Covariate names.
    multiplier : float
        Equivalence bound as fraction of pooled SD (default 0.36).

    Returns
    -------
    pl.DataFrame with TOST results per covariate.
    """
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")

    rows = []
    for cov in covariates:
        vals_t = matched_data.filter(pl.col(treat) == 1)[cov].drop_nulls().to_numpy().astype(float)
        vals_c = matched_data.filter(pl.col(treat) == 0)[cov].drop_nulls().to_numpy().astype(float)

        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        m, n = len(vals_t), len(vals_c)
        diff = np.mean(vals_t) - np.mean(vals_c)
        var_t = np.var(vals_t, ddof=1)
        var_c = np.var(vals_c, ddof=1)

        # Pooled SD: weighted formula matching Hartman & Hidalgo (2018)
        # equivtest R package: sqrt(((m-1)*var_x + (n-1)*var_y) / (m+n-2))
        pooled_sd = np.sqrt(((m - 1) * var_t + (n - 1) * var_c) / (m + n - 2))
        delta = multiplier * pooled_sd

        # Two one-sided t-tests following equivtest::tost()
        # Uses Welch's t-test (unequal variances)
        se = np.sqrt(var_t / m + var_c / n)
        df_welch = se**4 / ((var_t/m)**2/(m-1) + (var_c/n)**2/(n-1)) if se > 0 else 1

        # Upper test: H0: diff >= delta, alt: diff < delta
        t_upper = (diff - delta) / se if se > 0 else np.inf
        p_upper = stats.t.cdf(t_upper, df=df_welch)

        # Lower test: H0: diff <= -delta, alt: diff > -delta
        t_lower = (diff + delta) / se if se > 0 else -np.inf
        p_lower = 1 - stats.t.cdf(t_lower, df=df_welch)

        tost_p = max(p_upper, p_lower)

        rows.append({
            "covariate": cov,
            "diff": round(diff, 6),
            "se": round(se, 6),
            "delta": round(delta, 4),
            "tost_p_upper": round(p_upper, 4),
            "tost_p_lower": round(p_lower, 4),
            "tost_p": round(tost_p, 4),
            "equivalent": bool(tost_p < 0.05),
        })

    result = pl.DataFrame(rows)

    n_equiv = result.filter(pl.col("equivalent")).height
    print(f"\n TOST Equivalence Test (bound = {multiplier}σ)")
    print(f" Equivalent: {n_equiv}/{result.height} covariates (p < 0.05 = GOOD)")
    for row in result.iter_rows(named=True):
        flag = "✓ EQUIV" if row["equivalent"] else " not equiv"
        print(f" {row['covariate']:<25} p={row['tost_p']:.4f} {flag}")

    return result
TOST equivalence test for covariate balance.
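Tests H0: |mean difference| >= delta (non-equivalence). Rejecting H0 is the desired outcome, because it gives positive evidence that the difference is negligible. Uses the Hartman & Hidalgo (2018) approach, delta = multiplier × pooled_SD, so H0 is equivalent to |SMD| >= multiplier on the pooled-SD scale.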
Parameters
scored_data : pl.DataFrame
Reduced data.
matches : pl.DataFrame
Match results.
treat, id, tm : str
Column names.
covariates : list[str]
Covariate names.
multiplier : float
Equivalence bound as a fraction of the pooled SD (default 0.36).
Returns
pl.DataFrame with TOST results per covariate.
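Written out, the statistics `equivalence_test` computes are given below. Here d is the difference in matched means, m and n are the treated and control sample sizes, and s_t² and s_c² are their sample variances. F_ν is the Student t CDF with Welch-Satterthwaite degrees of freedom ν. A covariate is declared equivalent when p < 0.05.

```latex
\begin{aligned}
s_p &= \sqrt{\frac{(m-1)s_t^2 + (n-1)s_c^2}{m+n-2}}, \qquad
\delta = \text{multiplier}\cdot s_p,\\
\mathrm{SE} &= \sqrt{\frac{s_t^2}{m} + \frac{s_c^2}{n}}, \qquad
\nu = \frac{\mathrm{SE}^4}{\dfrac{(s_t^2/m)^2}{m-1} + \dfrac{(s_c^2/n)^2}{n-1}},\\
p_U &= F_\nu\!\left(\frac{d-\delta}{\mathrm{SE}}\right), \qquad
p_L = 1 - F_\nu\!\left(\frac{d+\delta}{\mathrm{SE}}\right), \qquad
p = \max(p_U, p_L).
\end{aligned}
```

Taking the maximum of the two one-sided p-values means equivalence is declared only when both one-sided nulls, d ≥ δ and d ≤ -δ, are rejected at the 5% level.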
def balance_test_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    covariates: list[str],
    threshold: float = 0.1,
) -> pl.DataFrame:
    """Run balance diagnostics on a weighted sample.

    For each covariate, computes:
    - Weighted standardized mean difference (SMD)
    - Weighted two-sample Welch t-test
    - Weighted variance ratio
    - Effective sample sizes

    Covariates with fewer than two non-null values in either arm are skipped.
    A summary table is also printed to stdout.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    weights : pl.DataFrame
        Unit weights with columns [id, weight].
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    covariates : list[str]
        Covariate column names.
    threshold : float
        SMD threshold for pass/fail (default 0.1).

    Returns
    -------
    pl.DataFrame with one row of diagnostics per covariate.
    """
    weighted = data.join(weights, on=id, how="inner")
    w_t = weighted.filter(pl.col(treat) == 1)
    w_c = weighted.filter(pl.col(treat) == 0)

    rows = []
    for cov in covariates:
        vals_t = w_t[cov].drop_nulls().to_numpy().astype(float)
        vals_c = w_c[cov].drop_nulls().to_numpy().astype(float)
        ww_t = w_t.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)
        ww_c = w_c.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)

        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        # Weighted SMD
        sd_t = _weighted_std(vals_t, ww_t)
        sd_c = _weighted_std(vals_c, ww_c)
        pooled_sd = np.sqrt((sd_t**2 + sd_c**2) / 2)
        mean_t = _weighted_mean(vals_t, ww_t)
        mean_c = _weighted_mean(vals_c, ww_c)
        smd = (mean_t - mean_c) / pooled_sd if pooled_sd > 0 else np.nan

        # Effective sample sizes
        n_eff_t = _effective_n(ww_t)
        n_eff_c = _effective_n(ww_c)

        # Weighted Welch t-test
        var_t = sd_t ** 2 if not np.isnan(sd_t) else np.nan
        var_c = sd_c ** 2 if not np.isnan(sd_c) else np.nan
        se = np.sqrt(var_t / n_eff_t + var_c / n_eff_c) if (n_eff_t > 0 and n_eff_c > 0) else np.nan

        if se and se > 0 and n_eff_t > 1 and n_eff_c > 1:
            t_stat = (mean_t - mean_c) / se
            # Satterthwaite degrees of freedom
            df_welch = (var_t / n_eff_t + var_c / n_eff_c) ** 2 / (
                (var_t / n_eff_t) ** 2 / (n_eff_t - 1) +
                (var_c / n_eff_c) ** 2 / (n_eff_c - 1)
            )
            t_pvalue = 2 * (1 - stats.t.cdf(abs(t_stat), df=df_welch))
        else:
            t_stat = np.nan
            t_pvalue = np.nan

        # Variance ratio
        var_ratio = var_t / var_c if var_c and var_c > 0 else np.nan

        rows.append({
            "covariate": cov,
            "mean_treated": round(mean_t, 4),
            "mean_control": round(mean_c, 4),
            "smd": round(smd, 4),
            "smd_pass": bool(abs(smd) < threshold) if not np.isnan(smd) else False,
            "t_stat": round(t_stat, 4) if not np.isnan(t_stat) else np.nan,
            "t_pvalue": round(t_pvalue, 4) if not np.isnan(t_pvalue) else np.nan,
            "var_ratio": round(var_ratio, 4) if not np.isnan(var_ratio) else np.nan,
            "var_ratio_pass": bool(0.5 < var_ratio < 2.0) if not np.isnan(var_ratio) else False,
            "n_eff_treated": round(n_eff_t, 1),
            "n_eff_control": round(n_eff_c, 1),
        })

    result = pl.DataFrame(rows)
    if result.is_empty():
        # Every covariate was skipped, so there are no columns to summarize
        return result

    n_pass_smd = result.filter(pl.col("smd_pass")).height
    n_pass_var = result.filter(pl.col("var_ratio_pass")).height
    n_total = result.height

    print(f"\n{'='*70}")
    print(" Weighted Balance Diagnostics")
    print(f"{'='*70}")
    print(f" SMD < {threshold}: {n_pass_smd}/{n_total} pass")
    print(f" Variance ratio in (0.5, 2.0): {n_pass_var}/{n_total} pass")
    print(f"{'='*70}\n")

    print(f" {'Covariate':<25} {'SMD':>8} {'t-test p':>10} {'VR':>8} {'n_eff_c':>8}")
    print(f" {'-'*25} {'-'*8} {'-'*10} {'-'*8} {'-'*8}")
    for row in result.iter_rows(named=True):
        smd_flag = "✓" if row["smd_pass"] else "✗"
        vr_flag = "✓" if row["var_ratio_pass"] else "✗"
        print(f" {row['covariate']:<25} {row['smd']:>7.4f}{smd_flag} "
              f"{row['t_pvalue']:>10.4f} {row['var_ratio']:>7.3f}{vr_flag} "
              f"{row['n_eff_control']:>8.1f}")

    return result
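`balance_test_weighted` relies on three module-private helpers, `_weighted_mean`, `_weighted_std`, and `_effective_n`, whose definitions are not part of this section. The sketch below shows one plausible version of each, assuming a weighted standard deviation with no small-sample correction and Kish's effective sample size, (sum of w)^2 / sum of w^2; the package's own helpers may use different conventions. The names `weighted_mean`, `weighted_std`, `effective_n`, and `weighted_smd` are hypothetical stand-ins, not the package's API. The SMD divides by sqrt((sd_t^2 + sd_c^2) / 2), the same pooled SD the function above uses.

```python
import numpy as np


def weighted_mean(x: np.ndarray, w: np.ndarray) -> float:
    """Weighted mean: sum(w * x) / sum(w)."""
    return float(np.sum(w * x) / np.sum(w))


def weighted_std(x: np.ndarray, w: np.ndarray) -> float:
    """Weighted SD around the weighted mean.

    Assumption: no small-sample (n - 1) correction; the package's
    _weighted_std may follow a different convention.
    """
    mu = weighted_mean(x, w)
    return float(np.sqrt(np.sum(w * (x - mu) ** 2) / np.sum(w)))


def effective_n(w: np.ndarray) -> float:
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    return float(np.sum(w) ** 2 / np.sum(w ** 2))


def weighted_smd(x_t: np.ndarray, w_t: np.ndarray,
                 x_c: np.ndarray, w_c: np.ndarray) -> float:
    """Weighted SMD using the pooled SD sqrt((sd_t^2 + sd_c^2) / 2)."""
    sd_t, sd_c = weighted_std(x_t, w_t), weighted_std(x_c, w_c)
    pooled_sd = np.sqrt((sd_t ** 2 + sd_c ** 2) / 2)
    if pooled_sd == 0:
        # Both arms constant: SMD is undefined, mirroring the NaN in the function
        return float("nan")
    return float((weighted_mean(x_t, w_t) - weighted_mean(x_c, w_c)) / pooled_sd)


# Usage: equal weights reduce everything to the unweighted statistics
x_t = np.array([1.0, 2.0, 3.0, 4.0])
x_c = np.array([0.0, 1.0, 2.0, 3.0])
equal = np.ones(4)
print(effective_n(equal))                              # 4.0
print(round(weighted_smd(x_t, equal, x_c, equal), 4))  # 0.8944
# Uneven weights shrink the effective sample size below the raw count of 3
print(round(effective_n(np.array([1.0, 1.0, 2.0])), 4))  # 2.6667
```

A lower effective sample size is what widens the Welch standard error in the weighted t-test, so heavily concentrated weights show up as a small `n_eff_control` in the printed table.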
def equivalence_test_weighted(
    data: pl.DataFrame,
    weights: pl.DataFrame,
    treat: str,
    id: str,
    covariates: list[str],
    multiplier: float = 0.36,
) -> pl.DataFrame:
    """TOST equivalence test for weighted covariate balance.

    Same as equivalence_test() but uses weighted statistics and
    effective sample sizes.

    Parameters
    ----------
    data : pl.DataFrame
        Reduced data.
    weights : pl.DataFrame
        Unit weights with columns [id, weight].
    treat, id : str
        Column names.
    covariates : list[str]
        Covariate names.
    multiplier : float
        Equivalence bound as a fraction of the pooled SD (default 0.36).

    Returns
    -------
    pl.DataFrame with TOST results per covariate; ``equivalent`` is True
    when the TOST p-value is below 0.05.
    """
    weighted = data.join(weights, on=id, how="inner")
    w_t = weighted.filter(pl.col(treat) == 1)
    w_c = weighted.filter(pl.col(treat) == 0)

    rows = []
    for cov in covariates:
        vals_t = w_t[cov].drop_nulls().to_numpy().astype(float)
        vals_c = w_c[cov].drop_nulls().to_numpy().astype(float)
        ww_t = w_t.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)
        ww_c = w_c.filter(pl.col(cov).is_not_null())["weight"].to_numpy().astype(float)

        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        mean_t = _weighted_mean(vals_t, ww_t)
        mean_c = _weighted_mean(vals_c, ww_c)
        diff = mean_t - mean_c

        var_t = _weighted_std(vals_t, ww_t) ** 2
        var_c = _weighted_std(vals_c, ww_c) ** 2

        n_eff_t = _effective_n(ww_t)
        n_eff_c = _effective_n(ww_c)

        # Pooled SD (weighted, for delta calculation)
        pooled_sd = np.sqrt(
            ((n_eff_t - 1) * var_t + (n_eff_c - 1) * var_c) /
            (n_eff_t + n_eff_c - 2)
        ) if (n_eff_t + n_eff_c > 2) else np.nan

        delta = multiplier * pooled_sd if not np.isnan(pooled_sd) else np.nan

        # Weighted SE and Welch df
        se = np.sqrt(var_t / n_eff_t + var_c / n_eff_c) if (n_eff_t > 0 and n_eff_c > 0) else np.nan
        if se and se > 0 and n_eff_t > 1 and n_eff_c > 1:
            # se ** 4 equals (var_t / n_eff_t + var_c / n_eff_c) ** 2 (Satterthwaite)
            df_welch = se ** 4 / (
                (var_t / n_eff_t) ** 2 / (n_eff_t - 1) +
                (var_c / n_eff_c) ** 2 / (n_eff_c - 1)
            )
        else:
            # Conservative fallback: df = 1 gives the heaviest tails
            df_welch = 1

        # TOST
        if se and se > 0 and not np.isnan(delta):
            t_upper = (diff - delta) / se
            p_upper = stats.t.cdf(t_upper, df=df_welch)
            t_lower = (diff + delta) / se
            p_lower = 1 - stats.t.cdf(t_lower, df=df_welch)
            tost_p = max(p_upper, p_lower)
        else:
            p_upper = np.nan
            p_lower = np.nan
            tost_p = np.nan

        rows.append({
            "covariate": cov,
            "diff": round(diff, 6),
            "se": round(se, 6) if not np.isnan(se) else np.nan,
            "delta": round(delta, 4) if not np.isnan(delta) else np.nan,
            "tost_p_upper": round(p_upper, 4) if not np.isnan(p_upper) else np.nan,
            "tost_p_lower": round(p_lower, 4) if not np.isnan(p_lower) else np.nan,
            "tost_p": round(tost_p, 4) if not np.isnan(tost_p) else np.nan,
            "equivalent": bool(tost_p < 0.05) if not np.isnan(tost_p) else False,
        })

    result = pl.DataFrame(rows)
    if result.is_empty():
        # Every covariate was skipped, so there are no columns to summarize
        return result

    n_equiv = result.filter(pl.col("equivalent")).height
    print(f"\n Weighted TOST Equivalence Test (bound = {multiplier}σ)")
    print(f" Equivalent: {n_equiv}/{result.height} covariates (p < 0.05 = GOOD)")
    for row in result.iter_rows(named=True):
        flag = "✓ EQUIV" if row["equivalent"] else " not equiv"
        print(f" {row['covariate']:<25} p={row['tost_p']:.4f} {flag}")

    return result
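The TOST step in `equivalence_test_weighted` can be isolated into a small sketch that takes the weighted variances and effective sample sizes as given (it assumes the module's `_weighted_std` and `_effective_n` have already produced them). The names `pooled_sd`, `welch_df`, and `tost` are hypothetical. The sketch follows the function's logic: the bound delta is `multiplier` times an effective-n pooled SD, both one-sided tests use Welch-Satterthwaite degrees of freedom, and the reported TOST p-value is the larger of the two one-sided p-values. One deliberate swap: it uses `scipy.stats.t.sf` for the upper tail instead of `1 - cdf`, which is mathematically identical but keeps precision when the tail probability is tiny.

```python
import numpy as np
from scipy import stats


def pooled_sd(var_t: float, n_t: float, var_c: float, n_c: float) -> float:
    """Pooled SD with effective sample sizes standing in for raw counts."""
    return float(np.sqrt(((n_t - 1) * var_t + (n_c - 1) * var_c) / (n_t + n_c - 2)))


def welch_df(var_t: float, n_t: float, var_c: float, n_c: float) -> float:
    """Welch-Satterthwaite degrees of freedom."""
    a, b = var_t / n_t, var_c / n_c
    return (a + b) ** 2 / (a ** 2 / (n_t - 1) + b ** 2 / (n_c - 1))


def tost(diff: float, se: float, df: float, delta: float) -> tuple[float, float, float]:
    """Two one-sided t-tests of H0: |diff| >= delta.

    Returns (p_upper, p_lower, tost_p). Equivalence within +/- delta is
    concluded when tost_p is below the chosen alpha.
    """
    # H0_upper: diff >= +delta; small p when (diff - delta) / se is strongly negative
    p_upper = stats.t.cdf((diff - delta) / se, df=df)
    # H0_lower: diff <= -delta; sf(x) equals 1 - cdf(x) without cancellation error
    p_lower = stats.t.sf((diff + delta) / se, df=df)
    return float(p_upper), float(p_lower), float(max(p_upper, p_lower))


# Usage: unit variances, 100 effective units per arm, bound = 0.36 pooled SDs
var_t = var_c = 1.0
n_t = n_c = 100.0
se = np.sqrt(var_t / n_t + var_c / n_c)
df = welch_df(var_t, n_t, var_c, n_c)
delta = 0.36 * pooled_sd(var_t, n_t, var_c, n_c)

print(tost(0.0, se, df, delta)[2] < 0.05)  # True: a zero difference sits inside +/- 0.36
print(tost(0.5, se, df, delta)[2] < 0.05)  # False: a difference of 0.5 exceeds the bound
```

Because TOST reverses the usual null hypothesis, a small p-value is the desired outcome here: it indicates the weighted difference is demonstrably inside the bound, not merely indistinguishable from zero.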