# ==========================================================================================
# 3. STATISTICAL HYPOTHESIS TESTING
# ==========================================================================================

import numpy as np
from scipy import stats
from sklearn.utils import resample

# ------------------------------------------------------------------------------------------
# Z-test (Large n, known variance; rarely used in practice)
# ------------------------------------------------------------------------------------------
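# Null (one-sample version below): the population mean equals mu0.
# When: population std is known (or n is large enough that the sample std is a reliable
# estimate) and a mean is compared with a reference value; use a t-test otherwise.
# scipy has no dedicated z-test function (statsmodels provides
# `statsmodels.stats.weightstats.ztest`); a one-sample version is implemented below: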

def z_test(sample, mu0, sigma0, alternative="two-sided"):
    """Z-test for an observed mean vs a population mean with known variance.

    Parameters
    ----------
    sample : array_like
        Observations; all elements are pooled into one sample.
    mu0 : float
        Hypothesised population mean under the null.
    sigma0 : float
        Known population standard deviation (must be > 0).
    alternative : {"two-sided", "less", "greater"}, optional
        Alternative hypothesis, named as in scipy.stats tests. Defaults to
        "two-sided", matching the original behaviour.

    Returns
    -------
    z : float
        Standardised statistic (mean(sample) - mu0) / (sigma0 / sqrt(n)).
    p : float
        p-value for the chosen alternative.

    Raises
    ------
    ValueError
        If `sample` is empty, `sigma0` is not positive, or `alternative`
        is not recognised.
    """
    x = np.asarray(sample, dtype=float)
    n = x.size
    if n == 0:
        raise ValueError("sample must contain at least one observation")
    if not sigma0 > 0:  # also rejects NaN
        raise ValueError("sigma0 must be positive")

    z = (x.mean() - mu0) / (sigma0 / np.sqrt(n))
    # Use the survival function (sf = 1 - cdf) directly: computing 1 - cdf(|z|)
    # rounds to exactly 0.0 once |z| exceeds ~8.3, while sf keeps tail precision.
    if alternative == "two-sided":
        p = 2 * stats.norm.sf(abs(z))
    elif alternative == "greater":
        p = stats.norm.sf(z)
    elif alternative == "less":
        p = stats.norm.cdf(z)
    else:
        raise ValueError("alternative must be 'two-sided', 'less' or 'greater'")
    # Interpret: Small p (<0.05): reject the null that mu0 is the true mean.
    return z, p

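# EXAM TRAP: Don't use a Z-test when the variance is unknown and estimated from the data (use a
# t-test), or with small samples from a non-normal population (the normal approximation needs large n).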

# ------------------------------------------------------------------------------------------
# Student t-test (1-sample, independent, paired)
# ------------------------------------------------------------------------------------------
# 1-sample t-test: compare sample mean to population mean
# A fixed-seed Generator keeps every example below reproducible from run to run
# (the unseeded global np.random state gives different p-values on each run).
rng = np.random.default_rng(42)
sample = rng.normal(0, 1, size=20)
tstat, pval = stats.ttest_1samp(sample, popmean=0)
# stats.ttest_1samp: Null = population mean == popmean (here 0).
# Example: p < 0.05: sample mean differs from 0.

# Independent samples t-test (equal variances)
a = rng.normal(0, 1, 30)
b = rng.normal(0, 1, 30)
tstat, pval = stats.ttest_ind(a, b, equal_var=True)
# stats.ttest_ind: Null = means equal.
# equal_var=True: classic Student test with a pooled variance; assumes equal variances.

# Paired t-test
before = rng.normal(loc=1, scale=1, size=30)
after = before + rng.normal(loc=0, scale=0.2, size=30)
tstat, pval = stats.ttest_rel(before, after)
# stats.ttest_rel: Null = mean of the paired differences (before - after) is 0
# (repeated measures on the same units).

# ------------------------------------------------------------------------------------------
# Welch t-test (unequal variances)
# ------------------------------------------------------------------------------------------
tstat, pval = stats.ttest_ind(a, b, equal_var=False)
# stats.ttest_ind with equal_var=False: Welch's test; separate variance estimates plus
# Welch-Satterthwaite degrees of freedom, so equal variances are not assumed.
# Null = means equal.
# TYPICAL EXAM TRAP: Don't use Student t-test with unequal variances (worst when group
# sizes also differ); Welch is a safe default.

# ------------------------------------------------------------------------------------------
# Mann–Whitney U (Wilcoxon rank-sum)
# ------------------------------------------------------------------------------------------
stat, pval = stats.mannwhitneyu(a, b, alternative='two-sided')
# Nonparametric rank test for two independent samples.
# Null = both samples come from the same distribution; it mainly detects whether values
# from one group tend to be larger than values from the other (P(X > Y) != P(Y > X)).
# Use for ordinal or non-normal scale.
# TYPICAL EXAM TRAP: It is not a test of medians in general; reading it as a comparison of
# medians requires both distributions to have the same shape (a pure location shift).

# ------------------------------------------------------------------------------------------
# Wilcoxon signed-rank
# ------------------------------------------------------------------------------------------
stat, pval = stats.wilcoxon(before, after)
# Nonparametric paired comparison based on the differences before - after.
# Null = the distribution of the paired differences is symmetric about 0.
# Use for paired non-normal data.
# Not for independent samples!

# ------------------------------------------------------------------------------------------
# ANOVA (One-way)
# ------------------------------------------------------------------------------------------
# Fixed-seed Generator so the example statistics are reproducible from run to run.
rng = np.random.default_rng(7)
group1 = rng.normal(0, 1, 30)
group2 = rng.normal(0.1, 1, 30)
group3 = rng.normal(-0.1, 1, 30)
fstat, pval = stats.f_oneway(group1, group2, group3)
# stats.f_oneway: Null = all means equal.
# Assumes independent groups, approximately normal data and equal variances.
# Use for k>2 independent groups (with k=2 it reduces to Student's t-test: F = t^2).
# If p < 0.05: At least one mean is different; a post-hoc test (e.g. Tukey HSD)
# is needed to find which groups differ.

# ------------------------------------------------------------------------------------------
# Kruskal–Wallis (Nonparametric ANOVA)
# ------------------------------------------------------------------------------------------
hstat, pval = stats.kruskal(group1, group2, group3)
# stats.kruskal: Null = all groups come from the same distribution (rank-based test).
# Use for k>2 independent, non-normal groups.

# ------------------------------------------------------------------------------------------
# Chi-square test of independence
# ------------------------------------------------------------------------------------------
# 2x2 table of observed counts for two categorical variables.
table = np.array([
    [10, 20],
    [20, 20],
])
# correction=True is scipy's default: with 1 degree of freedom (a 2x2 table)
# Yates' continuity correction is applied to the statistic.
chi2_val, p, dof, expected = stats.chi2_contingency(observed=table, correction=True)
# Null hypothesis: the row variable and the column variable are independent.
# Works on categorical counts in a contingency table; rule of thumb: expected counts >= 5.

# ------------------------------------------------------------------------------------------
# Fisher exact test (2x2 tables, small counts)
# ------------------------------------------------------------------------------------------
oddsratio, p = stats.fisher_exact(table, alternative='two-sided')
# Null hypothesis: independence (for a 2x2 table, equivalently odds ratio == 1).
# Exact p-value (no large-sample approximation), so it stays valid for small N.
# Only for 2x2 tables.

# ------------------------------------------------------------------------------------------
# Kolmogorov–Smirnov two-sample
# ------------------------------------------------------------------------------------------
dstat, pval = stats.ks_2samp(a, b, alternative='two-sided')
# Null hypothesis: both samples are drawn from the same continuous distribution.
# D = largest vertical gap between the two empirical CDFs.
# Nonparametric; reacts to any difference (location, spread or shape), but is less
# sensitive in the tails than near the centre of the distributions.

# ------------------------------------------------------------------------------------------
# Permutation Tests
# ------------------------------------------------------------------------------------------
def permutation_t_test(a, b, n_permutations=10000, random_state=None):
    """Two-sided permutation test for a difference of means between a and b.

    Group labels are randomly reassigned `n_permutations` times. The p-value is
    the share of relabelings whose |mean(a') - mean(b')| is at least the
    observed |mean(a) - mean(b)|. Despite the name, the statistic is the raw
    difference of means, not a t statistic.

    Parameters
    ----------
    a, b : array_like
        1-D samples; both must be non-empty.
    n_permutations : int, optional
        Number of random relabelings (must be >= 1).
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator for reproducible results. None (default) uses
        NumPy's global random state, as before, so np.random.seed still applies.

    Returns
    -------
    float
        Monte Carlo p-value in (0, 1].

    Raises
    ------
    ValueError
        If either sample is empty or `n_permutations` < 1.
    """
    x = np.asarray(a, dtype=float)
    y = np.asarray(b, dtype=float)
    if x.size == 0 or y.size == 0:
        raise ValueError("both samples must be non-empty")
    if n_permutations < 1:
        raise ValueError("n_permutations must be a positive integer")

    rng = np.random if random_state is None else np.random.default_rng(random_state)

    n_a = len(x)
    observed = abs(np.mean(x) - np.mean(y))
    # Relabelings that tie the observed statistic mathematically can come out a
    # hair smaller after floating-point rounding; a tiny relative tolerance keeps
    # them counted (the same approach scipy.stats.permutation_test takes).
    threshold = observed - 100 * np.finfo(float).eps * observed
    combined = np.concatenate([x, y])  # fresh array, so shuffling never touches a/b
    count = 0
    for _ in range(n_permutations):
        rng.shuffle(combined)
        diff = np.mean(combined[:n_a]) - np.mean(combined[n_a:])
        if abs(diff) >= threshold:
            count += 1
    # +1 in numerator and denominator counts the observed labeling itself: the
    # estimate is never exactly 0 and the test keeps its nominal size.
    # INTERPRET: If p < 0.05, difference unlikely under null.
    return (count + 1) / (n_permutations + 1)
# Nonparametric; makes minimal assumptions (main one: observations are exchangeable under the null).

# ------------------------------------------------------------------------------------------
# Bootstrap Confidence Intervals
# ------------------------------------------------------------------------------------------
def bootstrap_ci(data, n_bootstrap=10000, ci=0.95, statistic=np.mean, random_state=None):
    """Percentile bootstrap confidence interval for a statistic (default: the mean).

    Parameters
    ----------
    data : array_like
        Observations; resampling draws whole rows (elements along axis 0).
    n_bootstrap : int, optional
        Number of bootstrap resamples (must be >= 1).
    ci : float, optional
        Confidence level, strictly between 0 and 1.
    statistic : callable, optional
        Function mapping one resample to a scalar estimate; np.mean by default.
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator for reproducible intervals. None (default) draws
        from NumPy's global random state.

    Returns
    -------
    (lower, upper) : tuple of float
        The (1 - ci)/2 and (1 + ci)/2 percentiles of the bootstrap estimates.

    Raises
    ------
    ValueError
        If `data` is empty, `ci` is outside (0, 1) or `n_bootstrap` < 1.
    """
    x = np.asarray(data)
    if x.ndim == 0 or x.shape[0] == 0:
        raise ValueError("data must contain at least one observation")
    if not 0 < ci < 1:
        raise ValueError("ci must be strictly between 0 and 1")
    if n_bootstrap < 1:
        raise ValueError("n_bootstrap must be a positive integer")

    rng = np.random if random_state is None else np.random.default_rng(random_state)
    n = x.shape[0]
    # Each replicate draws n row indices with replacement (same size as the data).
    estimates = np.array(
        [statistic(x[rng.choice(n, size=n, replace=True)]) for _ in range(n_bootstrap)]
    )
    lower, upper = np.percentile(estimates, [(1 - ci) / 2 * 100, (1 + ci) / 2 * 100])
    # INTERPRET: the procedure yields intervals that contain the true parameter in about
    # ci*100% of repeated samples; it is not a probability statement about this one interval.
    return lower, upper
# TYPICAL EXAM TRAP: Bootstrap assumes IID data; for dependent data (e.g. time series) use a block bootstrap.
