#!/usr/bin/env python3
# Timestamp: "2025-10-01 18:53:25 (ywatanabe)"
# File: scitex_stats/tests/categorical/_test_chi2.py
# ----------------------------------------
r"""
Chi-square test of independence for categorical data.
Tests association between two categorical variables in a contingency table.
"""
from __future__ import annotations
import os
from typing import Literal, Optional, Union
import matplotlib.axes
import numpy as np
import pandas as pd
import matplotlib.pyplot as _mpl_plt # noqa: E402
from scipy import stats
from scitex_stats._logging import getLogger
from scitex_stats._utils._formatters import p2stars
from scitex_stats._utils._normalizers import convert_results, force_dataframe
# Path of this source file and the directory that contains it.
__FILE__ = __file__
__DIR__ = os.path.split(__FILE__)[0]  # head of split() == os.path.dirname()

# Module-level logger; test_chi2 writes its verbose output here.
logger = getLogger(__name__)
def cramers_v(chi2: float, n: int, r: int, c: int) -> float:
    """
    Compute Cramér's V effect size for a chi-square test.

    Parameters
    ----------
    chi2 : float
        Chi-square statistic
    n : int
        Total sample size
    r : int
        Number of rows
    c : int
        Number of columns

    Returns
    -------
    v : float
        Cramér's V (0 to 1). Returns 0.0 when ``n`` is 0 or when the table
        has a single row or column (``min(r, c) <= 1``), cases in which V is
        undefined.

    Notes
    -----
    Formula: V = sqrt(χ² / (n × (min(r, c) - 1)))

    Conventional interpretation thresholds (Cohen, 1988) are Cohen's w
    benchmarks 0.10 / 0.30 / 0.50 divided by sqrt(df*), where
    df* = min(r - 1, c - 1):

    - df*=1: small=0.10, medium=0.30, large=0.50
    - df*=2: small=0.07, medium=0.21, large=0.35
    - df*=3: small=0.06, medium=0.17, large=0.29

    References
    ----------
    Cramér, H. (1946). Mathematical Methods of Statistics. Princeton
    University Press.
    Cohen, J. (1988). Statistical Power Analysis for the Behavioral
    Sciences (2nd ed.). Lawrence Erlbaum.
    """
    # No observations: there is no association to measure.
    if n == 0:
        return 0.0
    min_dim = min(r, c)
    # A 1×k table has zero degrees of freedom; the denominator would be 0.
    if min_dim <= 1:
        return 0.0
    v = np.sqrt(chi2 / (n * (min_dim - 1)))
    return float(v)
def interpret_cramers_v(v: float, df_star: int) -> str:
    """
    Map a Cramér's V value to a qualitative effect-size label.

    Parameters
    ----------
    v : float
        Cramér's V
    df_star : int
        min(rows-1, cols-1). Any value other than 1 or 2 uses the df*=3
        cutoffs.

    Returns
    -------
    interpretation : str
        'negligible', 'small', 'medium', or 'large'. Each cutoff is an
        exclusive upper bound for the label below it; a value that is not
        below the 'large' cutoff (including NaN) is reported as 'large'.
    """
    # (small, medium, large) cutoffs for the given df*
    if df_star == 1:
        cutoffs = (0.10, 0.30, 0.50)
    elif df_star == 2:
        cutoffs = (0.07, 0.21, 0.35)
    else:
        cutoffs = (0.06, 0.17, 0.29)

    # Walk the cutoffs in ascending order; the first one v falls below wins.
    for label, upper in zip(("negligible", "small", "medium"), cutoffs):
        if v < upper:
            return label
    return "large"
def test_chi2(  # noqa: C901
    observed: Union[np.ndarray, pd.DataFrame],
    var_row: Optional[str] = None,
    var_col: Optional[str] = None,
    alpha: float = 0.05,
    correction: bool = True,
    plot: bool = False,
    ax: Optional[matplotlib.axes.Axes] = None,
    return_as: Literal["dict", "dataframe"] = "dict",
    decimals: int = 3,
    verbose: bool = False,
) -> Union[dict, pd.DataFrame]:
    """
    Chi-square test of independence for contingency tables.

    Tests whether two categorical variables are independent.

    Parameters
    ----------
    observed : array-like or DataFrame
        Observed frequencies as a 2-D contingency table (rows × columns)
        with at least 2 rows and 2 columns. If a DataFrame, its
        ``index.name`` / ``columns.name`` are used as the default row /
        column variable names.
    var_row : str, optional
        Name of row variable (default: DataFrame ``index.name`` if set,
        otherwise 'row_variable')
    var_col : str, optional
        Name of column variable (default: DataFrame ``columns.name`` if set,
        otherwise 'col_variable')
    alpha : float, default 0.05
        Significance level
    correction : bool, default True
        Apply Yates' continuity correction. Only used for 2×2 tables;
        ignored for larger tables.
    plot : bool, default False
        If True, draw heatmaps of the observed counts, expected counts and
        Pearson residuals (O - E) / sqrt(E).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If provided, ``plot`` is forced to True and only the
        residual heatmap is drawn on this axes. If omitted while ``plot`` is
        True, a new 1×3 figure is created (the figure is not returned).
    return_as : str, default 'dict'
        'dict' returns the result dict; 'dataframe' converts it with
        ``force_dataframe``; any other value is forwarded to
        ``convert_results(result, return_as=return_as)``.
    decimals : int, default 3
        Number of decimal places used to round ``statistic``, ``pvalue``,
        ``effect_size`` and ``expected_min``. ``significant`` and ``stars``
        are computed from the unrounded p-value.
    verbose : bool, default False
        If True, log the test results (and any assumption warnings)

    Returns
    -------
    result : dict or DataFrame
        Test results with:
        - test_method: Name of test
        - statistic: Chi-square statistic (χ²)
        - stat_symbol: "χ²"
        - pvalue: p-value
        - df: Degrees of freedom
        - alpha: Significance level
        - significant: Whether result is significant (p < alpha)
        - stars: Significance stars
        - effect_size: Cramér's V
        - effect_size_metric: "Cramér's V"
        - effect_size_interpretation: Interpretation
        - n: Total sample size
        - n_rows, n_cols: Table dimensions
        - expected_min: Minimum expected frequency
        - assumptions_met: False if any expected-frequency rule is violated
        - var_row: Row variable name
        - var_col: Column variable name
        - H0: Null hypothesis as text
        - warnings: "; "-joined assumption warnings (only present if any)

    Raises
    ------
    ValueError
        If ``observed`` is not 2-D, has fewer than 2 rows or columns, or its
        total count (truncated to an integer) is 0. Errors raised by
        ``scipy.stats.chi2_contingency`` (e.g. for negative counts)
        propagate unchanged.

    Notes
    -----
    Chi-square test of independence tests:
    H₀: Two categorical variables are independent
    H₁: Two categorical variables are associated

    Test statistic:
    χ² = Σ[(O - E)² / E]
    where O = observed frequencies, E = expected frequencies

    Assumptions:
    1. Independence of observations
    2. Expected frequencies ≥ 5 in at least 80% of cells
    3. No expected frequencies < 1

    For 2×2 tables with small expected frequencies, use Fisher's exact test
    instead.

    Cramér's V measures strength of association (0 to 1):
    - 0 = no association
    - 1 = perfect association
    V is computed from the reported statistic, i.e. the Yates-corrected χ²
    for 2×2 tables when ``correction`` is True.

    References
    ----------
    Cramér, H. (1946). Mathematical Methods of Statistics. Princeton
    University Press.

    Examples
    --------
    >>> import numpy as np
    >>> from scitex_stats.tests.categorical import test_chi2

    # Example 1: 2×2 contingency table (treatment × outcome)
    >>> observed = np.array([[30, 10], [20, 40]])
    >>> result = test_chi2(observed, var_row='Treatment', var_col='Outcome', plot=True)
    >>> print(result)

    # Example 2: Using DataFrame
    >>> import pandas as pd
    >>> df = pd.DataFrame([[12, 8, 5], [15, 20, 10]],
    ...                   index=['Group A', 'Group B'],
    ...                   columns=['Low', 'Med', 'High'])
    >>> result = test_chi2(df, plot=True)

    # Example 3: Test gender × preference association
    >>> observed = np.array([
    ...     [20, 30, 15],  # Male: product A, B, C
    ...     [25, 20, 40]   # Female: product A, B, C
    ... ])
    >>> result = test_chi2(observed, var_row='Gender', var_col='Product', plot=True)
    >>> print(f"χ² = {result['statistic']:.2f}, p = {result['pvalue']:.4f}")
    >>> print(f"Cramér's V = {result['effect_size']:.3f} ({result['effect_size_interpretation']})")

    # Example 4: Small expected frequencies warning
    >>> observed = np.array([[2, 8], [3, 7]])  # Small counts
    >>> result = test_chi2(observed)

    # Example 5: Export to various formats
    >>> from scitex_stats._utils._normalizers import convert_results
    >>> result = test_chi2(observed, return_as='dataframe')
    >>> convert_results(result, return_as='latex', path='chi2_test.tex')
    """
    # Normalise input to a NumPy array and resolve variable labels.
    if isinstance(observed, pd.DataFrame):
        if var_row is None:
            var_row = observed.index.name or "row_variable"
        if var_col is None:
            var_col = observed.columns.name or "col_variable"
        observed = observed.values
    else:
        observed = np.asarray(observed)
        if var_row is None:
            var_row = "row_variable"
        if var_col is None:
            var_col = "col_variable"

    # Validate table shape
    if observed.ndim != 2:
        raise ValueError(f"Contingency table must be 2D (got {observed.ndim}D)")
    rows, cols = observed.shape
    if rows < 2 or cols < 2:
        raise ValueError(f"Need at least 2×2 table (got {rows}×{cols})")

    # Total sample size
    n = int(np.sum(observed))
    if n == 0:
        raise ValueError("Contingency table is empty (sum = 0)")

    # Yates' correction is only meaningful (and only requested) for 2×2 tables.
    use_yates = bool(correction and rows == 2 and cols == 2)
    chi2_stat, pvalue, dof, expected = stats.chi2_contingency(
        observed, correction=use_yates
    )
    chi2_stat = float(chi2_stat)
    pvalue = float(pvalue)
    dof = int(dof)

    # Cramér's V effect size
    v = cramers_v(chi2_stat, n, rows, cols)
    df_star = min(rows - 1, cols - 1)
    interpretation = interpret_cramers_v(v, df_star)

    # Expected-frequency assumption checks
    expected_min = float(np.min(expected))
    assumptions_met, warning_msgs = _check_chi2_assumptions(expected, rows, cols)

    # Significance is decided on the unrounded p-value.
    significant = pvalue < alpha
    stars = p2stars(pvalue)

    result = {
        "test_method": "Chi-square test of independence",
        "statistic": round(chi2_stat, decimals),
        "stat_symbol": "χ²",
        "pvalue": round(pvalue, decimals),
        "df": dof,
        "alpha": alpha,
        "significant": significant,
        "stars": stars,
        "effect_size": round(v, decimals),
        "effect_size_metric": "Cramér's V",
        "effect_size_interpretation": interpretation,
        "n": n,
        "n_rows": rows,
        "n_cols": cols,
        "expected_min": round(expected_min, decimals),
        "assumptions_met": assumptions_met,
        "var_row": var_row,
        "var_col": var_col,
        "H0": f"{var_row} and {var_col} are independent",
    }
    if warning_msgs:
        result["warnings"] = "; ".join(warning_msgs)

    if verbose:
        logger.info(
            f"Chi-square: χ² = {chi2_stat:.3f}, df = {dof}, p = {pvalue:.4f} {stars}"
        )
        logger.info(f"Cramér's V = {v:.3f} ({interpretation})")
        if warning_msgs:
            logger.warning(f"⚠ {'; '.join(warning_msgs)}")

    # A caller-supplied axes implies plotting; it gets the single-panel view.
    if ax is not None:
        _plot_chi2_simple(
            observed, expected, chi2_stat, pvalue, v, var_row, var_col, ax
        )
    elif plot:
        # The full view needs 3 panels, so create a dedicated figure.
        _, axes = _mpl_plt.subplots(1, 3, figsize=(15, 5))
        _plot_chi2_full(
            observed, expected, chi2_stat, pvalue, v, var_row, var_col, axes
        )

    if return_as == "dict":
        return result
    if return_as == "dataframe":
        return force_dataframe(result)
    return convert_results(result, return_as=return_as)


def _check_chi2_assumptions(expected, rows, cols):
    """
    Check the chi-square expected-frequency rules of thumb.

    Returns ``(assumptions_met, messages)``. ``assumptions_met`` is False if
    any cell has expected frequency < 1, or more than 20% of cells have
    expected frequency < 5. For a failing 2×2 table, a hint to use Fisher's
    exact test is appended to ``messages``.
    """
    n_cells = expected.size
    n_lt1 = int(np.sum(expected < 1))
    n_lt5 = int(np.sum(expected < 5))

    messages = []
    if n_lt1 > 0:
        messages.append(f"{n_lt1} cells have expected frequency < 1")
    if n_lt5 > 0.2 * n_cells:
        pct = 100 * n_lt5 / n_cells
        messages.append(
            f"{n_lt5}/{n_cells} cells ({pct:.1f}%) have expected frequency < 5"
        )

    # Only the two rules above decide the verdict; the Fisher hint is advice.
    assumptions_met = not messages
    if not assumptions_met and rows == 2 and cols == 2:
        messages.append(
            "Consider using Fisher's exact test for 2×2 table with small counts"
        )
    return assumptions_met, messages
def _plot_chi2_full(observed, expected, chi2_stat, pvalue, v, var_row, var_col, axes):
    """
    Draw the three-panel chi-square figure onto ``axes``.

    Panels, left to right: observed counts, expected counts, and Pearson
    residuals (O - E) / sqrt(E) on a colour scale symmetric about zero. Each
    panel gets per-cell value labels and a colorbar. The residual panel also
    carries a text box with χ², p (plus stars) and Cramér's V.

    Parameters
    ----------
    observed, expected : ndarray
        2-D tables of the same shape (rows × columns).
    chi2_stat, pvalue, v : float
        Values shown in the annotation box.
    var_row, var_col : str
        Axis labels for the row / column variable.
    axes : indexable of matplotlib.axes.Axes
        Only ``axes[0]``, ``axes[1]`` and ``axes[2]`` are drawn on.
    """
    n_rows, n_cols = observed.shape
    pearson = (observed - expected) / np.sqrt(expected)
    bound = max(abs(pearson.min()), abs(pearson.max()))

    # (axes index, data, colormap, title, cell format, imshow extras,
    #  whether cell text colour adapts to the background)
    panel_specs = (
        (0, observed, "Blues", "Observed", "{:.0f}", {}, False),
        (1, expected, "Oranges", "Expected", "{:.1f}", {}, False),
        (2, pearson, "RdBu_r", "Chi-Square Test", "{:.2f}",
         {"vmin": -bound, "vmax": bound}, True),
    )

    for idx, grid, cmap, title, fmt, imshow_extra, adaptive_text in panel_specs:
        panel = axes[idx]
        image = panel.imshow(grid, cmap=cmap, aspect="auto", **imshow_extra)
        panel.set_title(title)
        panel.set_xlabel(var_col)
        panel.set_ylabel(var_row)
        panel.set_xticks(range(n_cols))
        panel.set_yticks(range(n_rows))
        panel.set_xticklabels([f"C{k + 1}" for k in range(n_cols)])
        panel.set_yticklabels([f"R{k + 1}" for k in range(n_rows)])

        for r in range(n_rows):
            for c in range(n_cols):
                cell = grid[r, c]
                text_kwargs = {"ha": "center", "va": "center"}
                if adaptive_text:
                    # Light text on strongly coloured (large |residual|) cells.
                    text_kwargs["color"] = (
                        "white" if abs(cell) > bound / 2 else "black"
                    )
                panel.text(c, r, fmt.format(cell), **text_kwargs)

        _mpl_plt.colorbar(image, ax=panel)

    # Summary annotation on the residual panel
    summary_ax = axes[2]
    stars_text = p2stars(pvalue).replace("ns", "$n$s")
    text_str = (
        f"$\\chi^2$ = {chi2_stat:.3f}\n$p$ = {pvalue:.4f} {stars_text}\n$V$ = {v:.3f}"
    )
    summary_ax.text(
        0.02,
        0.98,
        text_str,
        transform=summary_ax.transAxes,
        verticalalignment="top",
        color="black",
        fontsize=6,
    )
def _plot_chi2_simple(observed, expected, chi2_stat, pvalue, v, var_row, var_col, ax):
    """
    Draw a single Pearson-residual heatmap for a chi-square test on ``ax``.

    Residuals are (O - E) / sqrt(E), shown with a diverging colormap whose
    limits are symmetric about zero. Each cell is labelled with its residual,
    and a colorbar is added. A text box in the top-left corner reports χ²,
    p (plus stars) and Cramér's V.

    Parameters
    ----------
    observed, expected : ndarray
        2-D tables of the same shape (rows × columns).
    chi2_stat, pvalue, v : float
        Values shown in the annotation box.
    var_row, var_col : str
        Axis labels for the row / column variable.
    ax : matplotlib.axes.Axes
        Target axes.
    """
    n_rows, n_cols = observed.shape
    resid = (observed - expected) / np.sqrt(expected)
    limit = max(abs(resid.min()), abs(resid.max()))

    mappable = ax.imshow(resid, cmap="RdBu_r", aspect="auto", vmin=-limit, vmax=limit)
    ax.set_title("Chi-Square Test")
    ax.set_xlabel(var_col)
    ax.set_ylabel(var_row)
    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    ax.set_xticklabels([f"C{k + 1}" for k in range(n_cols)])
    ax.set_yticklabels([f"R{k + 1}" for k in range(n_rows)])

    # Row-major walk over cells; light text where the colour is strong.
    half_limit = limit / 2
    for row_idx, col_idx in np.ndindex(n_rows, n_cols):
        value = resid[row_idx, col_idx]
        text_color = "white" if abs(value) > half_limit else "black"
        ax.text(
            col_idx,
            row_idx,
            f"{value:.2f}",
            ha="center",
            va="center",
            color=text_color,
        )

    _mpl_plt.colorbar(mappable, ax=ax)

    # Summary annotation
    stars_text = p2stars(pvalue).replace("ns", "$n$s")
    text_str = (
        f"$\\chi^2$ = {chi2_stat:.3f}\n$p$ = {pvalue:.4f} {stars_text}\n$V$ = {v:.3f}"
    )
    ax.text(
        0.02,
        0.98,
        text_str,
        transform=ax.transAxes,
        verticalalignment="top",
        color="black",
        fontsize=6,
    )
# Demo: python -m scitex_stats.tests.categorical._demo_chi2
# EOF