# ==========================================================================================
# 12. EXPLAINABILITY
# ==========================================================================================

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Synthetic demo data: 100 rows of 5 standard-normal features with random binary
# labels. The labels are pure noise, so the fitted forest is only a stand-in model
# for the explainability snippets that follow.
X = np.random.standard_normal((100, 5))
y = np.random.randint(low=0, high=2, size=100)
rf = RandomForestClassifier()
rf.fit(X, y)

# ------------------------------------------------------------------------------------------
# Feature importance ≠ causality
# ------------------------------------------------------------------------------------------
# In most models, importance = association with prediction. Does NOT imply causal effect.

# ------------------------------------------------------------------------------------------
# SHAP (SHapley Additive ExPlanations)
# ------------------------------------------------------------------------------------------
# Theory: Allocates model prediction difference fairly among all features, respecting feature interactions.
# Code:
# Requires: pip install shap
# Example for Tree-based:
# import shap
# shap_explainer = shap.TreeExplainer(rf)
# shap_values = shap_explainer.shap_values(X)
# shap.summary_plot(shap_values, X)
# INTERPRET: Bar plots show global importance; waterfall plots show local explanations.

# ------------------------------------------------------------------------------------------
# LIME (Local Interpretable Model-agnostic Explanations)
# ------------------------------------------------------------------------------------------
# Theory: Locally approximates model prediction via simple surrogate (e.g., linear) model.
# Requires: pip install lime
# Code:
# import lime.lime_tabular
# lime_explainer = lime.lime_tabular.LimeTabularExplainer(X, mode='classification')
# exp = lime_explainer.explain_instance(X[0], rf.predict_proba)
# exp.show_in_notebook() # or exp.as_list()
# INTERPRET: Lists features with biggest impact on prediction for given instance.

# ------------------------------------------------------------------------------------------
# Global vs Local explanations
# ------------------------------------------------------------------------------------------
# Global: Overall, which features drive model? (feature_importances_, mean SHAP values)
# Local: Why *this* prediction? (instance-specific SHAP/LIME).

# ==========================================================================================
# 14. DATA VISUALIZATION AS STATISTICAL REASONING TOOL
# ==========================================================================================

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.model_selection import learning_curve, validation_curve
from sklearn.preprocessing import StandardScaler
from statsmodels.graphics.gofplots import qqplot

# Each block: Python code + What statistical assumption it checks + exam conclusion + exam pitfall

# -----------------------------
# Univariate Visualizations
# -----------------------------

# Histogram of 1000 standard-normal draws: shows shape, skewness, and modality,
# and gives a first visual impression of normality.
hist_sample = np.random.randn(1000)
plt.hist(hist_sample, bins=30)
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram: Univariate Distribution')
plt.show()
# Checks: Normality, outliers, skew. Conclusion: Is the distribution symmetric or not?
# Pitfall: The bin width can hide or reveal features; symmetry alone, however perfect,
#          is never sufficient evidence of normality.

# KDE (Kernel Density Estimate): a smoothed estimate of the probability density,
# drawn here for 1000 standard-normal samples.
kde_sample = np.random.randn(1000)
sns.kdeplot(kde_sample)
plt.xlabel('Value')
plt.title('KDE: Density Shape')
plt.show()
# Checks: Skewness, modality. Conclusion: Is it multimodal? Are the tails heavy?
# Pitfall: The bandwidth setting can over-smooth or under-smooth the estimate.

# Boxplot: summarizes quartiles, median, and flags outliers for 1000 normal draws.
box_sample = np.random.randn(1000)
sns.boxplot(box_sample)
plt.title('Boxplot: Outlier & Spread Detection')
plt.show()
# Checks: Outliers, symmetry, central tendency. Especially useful for comparing groups.
# Pitfall: Says little about tail shape and is not useful with only a few observations.

# Violin plot: a boxplot with a mirrored KDE around it, drawn for 1000 normal draws.
violin_sample = np.random.randn(1000)
sns.violinplot(violin_sample)
plt.title('Violin Plot')
plt.show()
# Checks: Distribution shape together with spread. Use for groupwise comparison.
# Pitfall: Minor bumps are easy to over-interpret (they are sensitive to bandwidth).

# -----------------------------
# Bivariate Visualizations
# -----------------------------

# Scatter plot: visual test for linearity, heteroscedasticity, and grouping.
# The data follow y = 2x + unit-variance Gaussian noise.
x = np.random.normal(0, 1, 100)
scatter_noise = np.random.normal(0, 1, 100)
y = 2 * x + scatter_noise
plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Scatter Plot: Linearity Check')
plt.show()
# Checks: The linearity assumption required for regression.
# Pitfall: Cause and effect are easy to misread; overplotting hides the trend for large n.

# Pairplot: every pairwise scatterplot of a dataframe in one grid.
# df holds 100 rows of four independent standard-normal columns A-D.
pair_columns = ['A', 'B', 'C', 'D']
df = pd.DataFrame(np.random.randn(100, 4), columns=pair_columns)
sns.pairplot(df)
# The pairplot figure is the current figure, so the pyplot-level suptitle lands on it.
plt.suptitle("Pairplot: Explore All Pairwise Dependencies")
plt.show()
# Checks: Linear/nonlinear relations, clusters, outliers.
# Pitfall: Only practical for low-dimensional data (visual overload beyond ~6 dimensions).

# Hexbin: a binned "heatmap" replacement for a scatter plot when n is large.
# 10,000 points with y = x + unit-variance Gaussian noise.
x = np.random.randn(10000)
hexbin_noise = np.random.randn(10000)
y = x + hexbin_noise
plt.hexbin(x, y, gridsize=50, cmap='Blues')
plt.title('Hexbin: Density Structure for Large n')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()  # colour scale = number of points per hexagon
plt.show()
# Checks: Bivariate structure and density shape when there are many points.
# Pitfall: Binning can obscure tails or individual outliers.

# Correlation heatmap: pairwise correlation matrix of df, annotated with its values.
corr = df.corr()
heatmap_style = {'annot': True, 'cmap': 'vlag'}
sns.heatmap(corr, **heatmap_style)
plt.title("Correlation Heatmap")
plt.show()
# Checks: Highly correlated predictors (multicollinearity).
# Pitfall: Correlation ≠ causation; a correlation coefficient misses nonlinear relationships.

# -----------------------------
# Multivariate Visualizations
# -----------------------------

# PCA projection plot (2D): standardize the columns of df, then project onto the
# first two principal components.
scaler = StandardScaler()
scaler.fit(df)
df_scaled = scaler.transform(df)
pca = PCA(n_components=2)
pca.fit(df_scaled)
proj = pca.transform(df_scaled)
pc1_scores, pc2_scores = proj[:, 0], proj[:, 1]
plt.scatter(pc1_scores, pc2_scores)
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.title('PCA Projection: Multivariate Structure')
plt.show()
# Checks: Data structure, clustering, or separation. A good view before clustering.
# Pitfall: PCA captures only linear structure; nonlinear structure is ignored.

# Parallel coordinates (multivariate patterns): one polyline per row, coloured by a
# randomly assigned 'cluster' label in {0, 1, 2} that is added to df as a new column.
from pandas.plotting import parallel_coordinates
cluster_labels = np.random.randint(0, 3, size=100)
df['cluster'] = cluster_labels
parallel_coordinates(df, class_column='cluster')
plt.title('Parallel Coordinates: Multivariate Trends')
plt.show()
# Checks: Groupwise differences, high-dimensional patterns.
# Pitfall: Overplotting when there are many dimensions or many groups.

# Andrews curves: multivariate normality and outlier structure. Each row of df is
# drawn as one curve, coloured by its 'cluster' column.
from pandas.plotting import andrews_curves
andrews_curves(df, class_column='cluster')
plt.title('Andrews Curves: Shape Structure in High-D')
plt.show()
# Checks: Separability, class overlap.
# Pitfall: Not interpretable when the dimension d is very high.

# -----------------------------
# Model Diagnostics Visualizations
# -----------------------------
# Residuals vs Fitted: diagnoses heteroscedasticity and nonlinearity.
# Data: y = 2x + unit-variance Gaussian noise with x uniform on [-3, 3].
X_diag = np.random.uniform(-3, 3, 100)
y_diag = 2 * X_diag + np.random.normal(0, 1, 100)
diag_design = X_diag.reshape(-1, 1)  # scikit-learn expects a 2-D feature matrix
model = LinearRegression().fit(diag_design, y_diag)
diag_fitted = model.predict(diag_design)
resid = y_diag - diag_fitted
plt.scatter(diag_fitted, resid)
plt.axhline(0, color='red', ls='--')  # reference line: zero residual
plt.xlabel('Fitted values')
plt.ylabel('Residuals')
plt.title('Residuals vs Fitted')
plt.show()
# Checks: Nonlinear trends (a healthy plot is a structureless cloud around zero);
#         a cone shape indicates heteroscedasticity.
# Pitfall: Clusters or trends in the residuals indicate model misspecification.

# Q-Q plot: compares the quantiles of the regression residuals with normal quantiles;
# line='s' adds the standardized reference line.
qqplot(data=resid, line='s')
plt.title('Q-Q Plot of Residuals')
plt.show()
# Checks: Normality (points should follow the diagonal).
# Pitfall: Outliers and heavy tails are easily missed when n is large.

# ROC Curve: Classification discrimination
# A logistic regression is fitted on a synthetic 4-feature binary problem and its
# ROC curve is drawn against the chance diagonal.
from sklearn.datasets import make_classification
# LogisticRegression was used here without any visible import (NameError); import it
# locally, next to the other local import this demo already needs.
from sklearn.linear_model import LogisticRegression
Xc, yc = make_classification(n_samples=100, n_features=4, random_state=42)
cl = LogisticRegression().fit(Xc, yc)
# NOTE: scores are computed on the same data the model was fitted on, so this curve
# shows in-sample (optimistic) discrimination.
fpr, tpr, _ = roc_curve(yc, cl.predict_proba(Xc)[:,1])
plt.plot(fpr, tpr)
plt.plot([0,1],[0,1],'k--')  # chance-level classifier
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.show()
# Checks: Class separability, threshold-independent performance.
# Pitfall: ROC can look overly optimistic under high class imbalance.

# Precision–Recall curve: the classifier view suited to imbalanced data.
# Uses the positive-class probabilities of `cl` on (Xc, yc).
pr_scores = cl.predict_proba(Xc)[:, 1]
prec, rec, _ = precision_recall_curve(yc, pr_scores)
plt.plot(rec, prec)
plt.title('Precision-Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()
# Checks: Performance on the positive class; useful in imbalanced settings.
# Pitfall: PR AUC can still be misleading; always compare against the baseline.

# Learning curves: diagnose over-/underfitting from the train and validation scores
# as the training-set size grows (5-fold cross-validation).
train_sizes, train_scores, val_scores = learning_curve(cl, Xc, yc, cv=5)
# Each score matrix has one row per training size; average across the CV folds.
lc_train_mean = train_scores.mean(axis=1)
lc_val_mean = val_scores.mean(axis=1)
plt.plot(train_sizes, lc_train_mean, label='Train')
plt.plot(train_sizes, lc_val_mean, label='Validation')
plt.xlabel('Training Set Size')
plt.ylabel('Score')
plt.title('Learning Curve')
plt.legend()
plt.show()
# Checks: A large train/validation gap = overfitting; both scores low = underfitting.
# Pitfall: Noise from small validation folds is easily misread as signal.

# Validation curves: diagnose model complexity by sweeping the hyperparameter C of
# `cl` over 1..9 with 3-fold cross-validation.
param_range = np.arange(1, 10)
train_scores, val_scores = validation_curve(
    cl, Xc, yc, param_name="C", param_range=param_range, cv=3
)
# One row per C value; average across the CV folds.
vc_train_mean = train_scores.mean(axis=1)
vc_val_mean = val_scores.mean(axis=1)
plt.plot(param_range, vc_train_mean, label='Train')
plt.plot(param_range, vc_val_mean, label='Validation')
plt.xlabel('Hyperparameter (C)')
plt.ylabel('Score')
plt.title('Validation Curve')
plt.legend()
plt.show()
# Checks: Identifies the overfitting and underfitting regimes.

# ==========================================================================================
# 34. CLASSIFICATION VISUAL DIAGNOSTICS
# ==========================================================================================

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, RocCurveDisplay, PrecisionRecallDisplay, brier_score_loss
from sklearn.calibration import calibration_curve

# (A) CONFUSION MATRIX — Statistical evidence for performance
# A random forest is fitted on X with random binary labels and then predicts on the
# same X, so both matrices below describe in-sample performance.
y_true = np.random.randint(0, 2, 100)
clf_viz = RandomForestClassifier(random_state=0)
clf_viz.fit(X, y_true)
y_pred = clf_viz.predict(X)

# Raw counts.
cm_raw = confusion_matrix(y_true, y_pred)
cmd = ConfusionMatrixDisplay(confusion_matrix=cm_raw)
cmd.plot()
plt.title('Confusion Matrix (Raw)')
plt.show()

# Row-normalized: each true-class row sums to 1.
cm_norm = confusion_matrix(y_true, y_pred, normalize='true')
cmd_norm = ConfusionMatrixDisplay(confusion_matrix=cm_norm)
cmd_norm.plot()
plt.title('Confusion Matrix (Normalized)')
plt.show()

# Purpose: Quantifies the core outcomes (TP/FP/FN/TN). Checks: error asymmetry, misclassification patterns.
# Accuracy = trace(cm) / sum(cm); Precision = TP/(TP+FP); Recall = TP/(TP+FN)
# Exam traps: High accuracy can mask a poorly predicted minority class; imbalance distorts prevalence.
#             Cost asymmetry (do false negatives cost more?).
# Symmetry illusion: The matrix may look "good" even if only the majority class is predicted.

# (B) ROC CURVE & AUC — Statistical discriminatory evidence
# roc_auc_score is imported locally: the metric imports visible in this script do not
# include it, so the AUC line would otherwise raise NameError.
from sklearn.metrics import roc_auc_score
# Positive-class probabilities of clf_viz on X (the data it was fitted on).
y_prob_viz = clf_viz.predict_proba(X)[:,1]
fpr, tpr, _ = roc_curve(y_true, y_prob_viz)
RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
plt.title('ROC Curve')
plt.show()
# Scalar summary of the curve above (area under the ROC curve).
auc = roc_auc_score(y_true, y_prob_viz)

# Purpose: Evaluates ranking over thresholds. TPR vs. FPR shows tradeoff at all possible thresholds.
# Exam traps: ROC can look good even with poor minority detection ("class imbalance paradox"), area under curve does NOT reflect true positive rate at a particular cutoff.
# Comparing ROC curves requires DeLong/statistical test for significance.

# (C) PRECISION–RECALL CURVE (PR curve)
# average_precision_score is imported locally: the metric imports visible in this
# script do not include it, so the AP line would otherwise raise NameError.
from sklearn.metrics import average_precision_score
prec_viz, rec_viz, _ = precision_recall_curve(y_true, y_prob_viz)
PrecisionRecallDisplay(precision=prec_viz, recall=rec_viz).plot()
plt.title('Precision–Recall Curve')
plt.show()
# Scalar summary of the PR curve (average precision).
ap = average_precision_score(y_true, y_prob_viz)
# Purpose: Focuses on positive class ("rare class"). Preferred for rare events.
# Relation to prevalence: Baseline is random guess = positive frequency.
# Exam traps: Interpreting PR-AUC without knowing prevalence; PR-AUC and ROC-AUC measure conceptually different things.

# (D) CALIBRATION CURVES (Reliability Diagram)
# Bin the predicted probabilities into 10 bins and compare each bin's mean prediction
# with the observed fraction of positives.
prob_true, prob_pred = calibration_curve(y_true, y_prob_viz, n_bins=10)
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0,1], [0,1], 'k--', label='Perfectly Calibrated')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
plt.title('Calibration (Reliability) Diagram')
plt.legend()
plt.show()
# Brier score: mean squared difference between predicted probability and outcome.
brier = brier_score_loss(y_true, y_prob_viz)
# Purpose: Checks whether the predicted probabilities reflect the actual outcome frequencies.
# Exam traps: A good ROC does NOT imply good calibration (a model can rank well yet be uncalibrated).
# Ranking vs calibration: A model can have perfect AUC but very poor probability estimates.

# ==========================================================================================
# 42. CAUSAL & DEPENDENCE VISUALIZATION
# ==========================================================================================

# PDP/ICE
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.inspection import PartialDependenceDisplay
import numpy as np

# Synthetic data (100 x 4 Gaussian features, random binary labels) and a gradient
# boosting classifier whose dependence on feature 0 is visualized below.
Xlc_ex = np.random.randn(100, 4)
ylc_ex = np.random.randint(0, 2, 100)
clf_pdp = GradientBoostingClassifier()
clf_pdp.fit(Xlc_ex, ylc_ex)

# PDP: the default kind of from_estimator, i.e. the prediction averaged over samples.
PartialDependenceDisplay.from_estimator(clf_pdp, Xlc_ex, features=[0])
plt.suptitle('Partial Dependence Plot (PDP)')
plt.show()

# ICE: kind="individual" draws one dependence curve per sample.
ice_options = {'features': [0], 'kind': "individual"}
PartialDependenceDisplay.from_estimator(clf_pdp, Xlc_ex, **ice_options)
plt.suptitle('ICE Plots')
plt.show()
# Exam: PDP = average effect (marginal); ICE = effect per instance (conditional);
#       correlated features can make a PDP misleading (extrapolation outside the true feature support).

# ==========================================================================================
# 43. UNCERTAINTY VISUALIZATION
# ==========================================================================================

# CI: Bootstrapping example
# Resample the first feature of Xlc_ex with replacement 1000 times, record each
# resample's mean, and take the 2.5th / 97.5th percentiles as a 95% percentile interval.
boot_values = Xlc_ex[:, 0]
boot_n = len(Xlc_ex)
means = [
    np.mean(np.random.choice(boot_values, replace=True, size=boot_n))
    for _ in range(1000)
]
ci_lower, ci_upper = np.percentile(means, [2.5, 97.5])
plt.hist(means, bins=40)
for ci_bound in (ci_lower, ci_upper):
    plt.axvline(ci_bound, color='red')  # mark the interval limits
plt.title('Bootstrap Distribution & Confidence Interval')
plt.show()
# Purpose: Aleatoric = irreducible (noise); Epistemic = model-based (parameter) uncertainty.

# ==========================================================================================
# 44. FINAL VISUALIZATION EXAM TRAPS (MANDATORY)
# ==========================================================================================

# Over-plotting: Too much data per plot obscures pattern (e.g., scatter for n>10k); solution: density plots/hexbin.
# Cherry-picking axes: Projections may "invent" clusters, especially in high-d.
# Post-hoc visualization bias: Generating many plots and reporting only "interesting" ones inflates Type I error.
# Visual causality illusion: Apparent associations or trends do not imply causality; confounders, temporal dependence, and data-generating process must be considered.

# ===========================================================================
# (RESTATE)
# Visualization answers, "What statistical assumption does this check?"
# It is evidence supporting/undermining a modeling or inferential claim.
# ===========================================================================
