#!/usr/bin/env python3
"""Expanded sanity/stress checks for CRCAAgent (v1.2.0).

- Offline path exercises graph ops, batch predict, caching, fitting, uncertainty,
  sensitivity, Shapley, VAR/Granger, multi-layer, alternate realities, Bellman,
  async wrappers, optimization, Bayesian inference, root cause analysis, and vector autoregression.
- Optional LLM path runs only if OPENAI_API_KEY is set; still lightweight.
- Exit code: 0 on success; non-zero on failure.
"""
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
import os
import sys
import math
import asyncio
import warnings

# Suppress litellm cleanup RuntimeWarning noise for this test script.
# `message` is interpreted by the warnings module as a regular expression that
# must match the start of the warning text; only RuntimeWarning instances whose
# text begins with this string are ignored, all other warnings still surface.
warnings.filterwarnings(
    "ignore",
    message="coroutine 'close_litellm_async_clients' was never awaited",
    category=RuntimeWarning,
)


def load_crca_agent():
    """Load ``CR-CA/CRCA.py`` by file path and return its ``CRCAAgent`` attribute.

    The source file is located relative to this script: starting from the
    resolved script path, go to the parent of the script's directory and then
    into ``CR-CA/CRCA.py``. The file is executed as a module named
    ``crca_module``.

    NOTE(review): loading by path is presumably needed because the hyphenated
    directory name ``CR-CA`` cannot be used in a normal ``import`` statement.

    Returns:
        The ``CRCAAgent`` object defined by the loaded module.
    """
    script_path = Path(__file__).resolve()
    crca_source = script_path.parents[1] / "CR-CA" / "CRCA.py"
    module_spec = spec_from_file_location("crca_module", str(crca_source))
    crca_module = module_from_spec(module_spec)
    # Executing the spec's loader runs the module body and populates its namespace.
    module_spec.loader.exec_module(crca_module)  # type: ignore
    return crca_module.CRCAAgent


def mask_key(k: str) -> str:
    """Return a display-safe version of a secret key.

    Args:
        k: The secret to mask. An empty (or otherwise falsy) value is treated
            as "no key".

    Returns:
        ``"<missing>"`` for a falsy key; otherwise a string of the same length
        as ``k`` in which the hidden characters are replaced by ``*``:

        * 1-2 characters: fully masked (nothing is revealed);
        * 3-8 characters: only the first two characters are shown;
        * 9+ characters: the first four and last four characters are shown.
    """
    if not k:
        return "<missing>"
    length = len(k)
    if length <= 2:
        # Showing a two-character prefix here would reveal the whole key.
        return "*" * length
    if length <= 8:
        return k[:2] + "*" * (length - 2)
    return k[:4] + "*" * (length - 8) + k[-4:]


def main():
    """Run all CRCAAgent sanity checks, print a report, and exit the process.

    Each check runs inside its own ``try`` block so that one failure does not
    stop the remaining checks; failures are collected as ``(name, message)``
    pairs and listed at the end. The data-driven checks (Test 6) run only when
    both pandas and numpy can be imported, and the LLM checks run only when the
    ``OPENAI_API_KEY`` environment variable is set.

    Raises:
        SystemExit: once all checks have run -- code 0 when no failure was
            recorded, code 2 otherwise. Errors raised while loading the agent
            class itself are not caught and propagate unchanged.
    """
    CRCAAgent = load_crca_agent()
    failures = []
    rng_seed = 42
    # Optional dependencies: any import-time error (not only ImportError) is
    # treated as "not available" so the offline checks can still run.
    try:
        import pandas as pd  # type: ignore
        PANDAS_AVAILABLE = True
    except Exception:
        PANDAS_AVAILABLE = False
        pd = None
    try:
        import numpy as np  # type: ignore
        NUMPY_AVAILABLE = True
    except Exception:
        NUMPY_AVAILABLE = False
        np = None

    # Test 1: add edge and read strength
    try:
        a = CRCAAgent(variables=["A", "B"], seed=rng_seed)
        a.add_causal_relationship("A", "B", strength=0.72, confidence=0.9)
        s = a._edge_strength("A", "B")
        assert abs(s - 0.72) < 1e-6, f"edge strength {s} != 0.72"
        print("PASS: add_causal_relationship -> _edge_strength")
    except Exception as e:
        failures.append(("edge_strength", str(e)))

    # Test 2: topological sort and is_dag
    try:
        a2 = CRCAAgent(variables=["n1", "n2", "n3"], seed=rng_seed)
        a2.add_causal_relationship("n1", "n2", strength=1.0)
        a2.add_causal_relationship("n2", "n3", strength=1.0)
        order = a2._topological_sort()
        assert order.index("n1") < order.index("n2") < order.index("n3")
        assert a2.is_dag() is True
        print("PASS: _topological_sort and is_dag")
    except Exception as e:
        failures.append(("topo_is_dag", str(e)))

    # Test 3: rustworkx duplicate-edge protection and metadata update
    try:
        a3 = CRCAAgent(variables=["X", "Y"], seed=rng_seed)
        # Edge count before any relationship is added; None when the graph
        # backend does not expose num_edges().
        try:
            before_edges = a3._graph.num_edges()
        except Exception:
            before_edges = None
        a3.add_causal_relationship("X", "Y", strength=0.5)
        # add again with different strength to trigger update path
        a3.add_causal_relationship("X", "Y", strength=0.77)
        s_updated = a3._edge_strength("X", "Y")
        assert abs(s_updated - 0.77) < 1e-6, f"edge strength not updated: {s_updated}"
        # Only the num_edges() lookup is best-effort. The assertion is kept
        # outside that inner try so a duplicate edge is actually reported
        # (an AssertionError raised inside `except Exception: pass` would be
        # silently discarded).
        try:
            after_edges = a3._graph.num_edges()
        except Exception:
            after_edges = None
        if before_edges is not None and after_edges is not None:
            # Two adds of the same X->Y edge must create at most one edge.
            assert after_edges - before_edges <= 1, (
                f"duplicate edge added: edge count went from {before_edges} to {after_edges}"
            )
        print("PASS: rustworkx duplicate-edge protection and metadata update")
    except Exception as e:
        failures.append(("rustworkx_edge_update", str(e)))

    # Test 4: prediction cache behavior (cached wrapper)
    try:
        a4 = CRCAAgent(variables=["Z"], seed=rng_seed, enable_batch_predict=True, max_batch_size=8)
        a4.set_standardization_stats("Z", mean=0.0, std=1.0)
        a4._prediction_cache_max = 5
        a4.enable_cache(True)
        a4.clear_cache()
        factual = {"Z": 0.0}
        # call cached predictor twice
        r1 = a4._predict_outcomes_cached(factual, {})
        r2 = a4._predict_outcomes_cached(factual, {})
        # check cache populated
        cache_len = len(getattr(a4, "_prediction_cache", {}))
        assert cache_len >= 1, "prediction cache not populated"
        assert r1 == r2, "cached results differ between calls"
        # batch predict
        batch = a4._predict_outcomes_batch([{"Z": 1.0}, {"Z": -1.0}], interventions={"Z": 0.5})
        assert len(batch) == 2 and all("Z" in x for x in batch), "batch predict failed"
        print("PASS: _predict_outcomes_cached populates cache and returns consistent results")
    except Exception as e:
        failures.append(("predict_cache", str(e)))

    # Test 5: basic AAP smoke (deterministic)
    try:
        a5 = CRCAAgent(variables=["price", "volume", "momentum", "trading_volume", "market_sentiment"], seed=rng_seed)
        a5.set_standardization_stats("price", mean=1.0, std=1.0)
        a5.set_standardization_stats("volume", mean=100000.0, std=20000.0)
        a5.set_standardization_stats("momentum", mean=0.0, std=0.05)
        a5.set_standardization_stats("trading_volume", mean=100000.0, std=20000.0)
        a5.add_causal_relationship("volume", "price", strength=0.3)
        a5.add_causal_relationship("momentum", "price", strength=0.4)
        factual = {"price": 1.0, "volume": 100000.0, "momentum": 0.02, "trading_volume": 120000.0}
        res = a5.aap(factual, {"momentum": 0.05})
        assert "price" in res and isinstance(res["price"], (int, float))
        print("PASS: AAP smoke (price present and numeric)")
    except Exception as e:
        failures.append(("aap_smoke", str(e)))

    # Test 6: data fitting, uncertainty, optimization, async, and advanced analyses.
    # Gated on pandas and numpy only; NOTE(review): the agent may additionally
    # need scipy for some of these analyses -- that is not checked here.
    if PANDAS_AVAILABLE and NUMPY_AVAILABLE:
        # Anything that fails before the individually wrapped sub-checks below
        # (fit, uncertainty, async wrappers) is recorded as "pandas_block" and
        # skips the rest of this section.
        try:
            a6 = CRCAAgent(
                variables=["price", "volume", "momentum"],
                enable_batch_predict=True,
                max_batch_size=16,
                bootstrap_workers=2,
                use_async=True,
                seed=rng_seed,
            )
            a6.add_causal_relationship("volume", "price", strength=0.0)
            a6.add_causal_relationship("momentum", "price", strength=0.0)
            # Synthetic data drawn from a dedicated, seeded generator so every
            # run sees the same series and the global numpy RNG is untouched.
            n = 60
            data_rng = np.random.default_rng(rng_seed)
            ts = pd.DataFrame({
                "price": 100 + np.linspace(0, 5, n) + data_rng.normal(0, 1, n),
                "volume": 1e5 + data_rng.normal(0, 2e4, n),
                "momentum": data_rng.normal(0, 0.05, n),
            })
            a6.fit_from_dataframe(ts, variables=["price", "volume", "momentum"], window=30, decay_alpha=0.9)
            s_vol = a6._edge_strength("volume", "price")
            assert not math.isnan(s_vol), "fit_from_dataframe produced NaN"
            print("PASS: fit_from_dataframe")

            unc = a6.quantify_uncertainty(ts, variables=["price", "volume", "momentum"], windows=10, alpha=0.9)
            assert "edge_cis" in unc
            print("PASS: quantify_uncertainty")

            # Async wrappers (v1.2.0+), driven from a private event loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                unc_async = loop.run_until_complete(
                    a6.quantify_uncertainty_async(ts, variables=["price", "volume", "momentum"], windows=6, alpha=0.8)
                )
                assert "edge_cis" in unc_async
                gr_async = loop.run_until_complete(
                    a6.granger_causality_test_async(ts, var1="volume", var2="price", max_lag=2)
                )
                assert "granger_causes" in gr_async or "error" in gr_async
                var_async = loop.run_until_complete(
                    a6.vector_autoregression_estimation_async(
                        ts, variables=["price", "volume", "momentum"], max_lag=2
                    )
                )
                assert "coefficient_matrices" in var_async
                run_async_result = loop.run_until_complete(
                    a6.run_async(
                        initial_state={"price": 100.0, "volume": 100000.0, "momentum": 0.01},
                        max_steps=1
                    )
                )
                assert "evolved_state" in run_async_result
                print("PASS: async wrappers (v1.2.0: quantify_uncertainty_async, granger_causality_test_async, vector_autoregression_estimation_async, run_async)")
            finally:
                # Best-effort loop teardown; a teardown error must not mask the
                # result of the checks above.
                try:
                    if not loop.is_closed():
                        loop.run_until_complete(loop.shutdown_asyncgens())
                    loop.close()
                except Exception:
                    pass

            try:
                opt = a6.gradient_based_intervention_optimization(
                    initial_state={"price": 101.0, "volume": 120000.0, "momentum": 0.01},
                    target="price",
                    intervention_vars=["volume", "momentum"],
                    constraints={"volume": (80000.0, 150000.0), "momentum": (-0.2, 0.2)},
                )
                assert "optimal_intervention" in opt
                print("PASS: gradient_based_intervention_optimization")
            except Exception as e:
                failures.append(("gradient_opt", str(e)))

            try:
                gr = a6.granger_causality_test(ts, var1="volume", var2="price", max_lag=2)
                assert "granger_causes" in gr
                print("PASS: granger_causality_test")
            except Exception as e:
                failures.append(("granger", str(e)))

            try:
                info = a6.compute_information_theoretic_measures(ts, variables=["price", "volume", "momentum"])
                assert "entropies" in info
                print("PASS: compute_information_theoretic_measures")
            except Exception as e:
                failures.append(("info_theory", str(e)))

            try:
                sens = a6.sensitivity_analysis(
                    intervention={"volume": 120000.0, "momentum": 0.02},
                    target="price",
                    perturbation_size=0.01,
                )
                assert sens.get("most_influential_variable") is not None
                print("PASS: sensitivity_analysis")
            except Exception as e:
                failures.append(("sensitivity", str(e)))

            try:
                shap = a6.shapley_value_attribution(
                    baseline_state={"price": 100.0, "volume": 100000.0, "momentum": 0.0},
                    target_state={"price": 102.0, "volume": 130000.0, "momentum": 0.03},
                    target="price",
                    samples=20,
                )
                assert "shapley_values" in shap
                print("PASS: shapley_value_attribution")
            except Exception as e:
                failures.append(("shapley", str(e)))

            try:
                multi = a6.multi_layer_whatif_analysis(
                    scenarios=[{"volume": 110000.0}, {"momentum": 0.05}],
                    depth=2,
                )
                assert "multi_layer_analysis" in multi
                print("PASS: multi_layer_whatif_analysis")
            except Exception as e:
                failures.append(("multi_layer", str(e)))

            try:
                alt = a6.explore_alternate_realities(
                    factual_state={"price": 101.0, "volume": 120000.0, "momentum": 0.02},
                    target_outcome="price",
                    target_value=None,
                    max_realities=10,
                    max_interventions=2,
                )
                assert "best_reality" in alt
                print("PASS: explore_alternate_realities")
            except Exception as e:
                failures.append(("alternate_realities", str(e)))

            try:
                bell = a6.bellman_optimal_intervention(
                    initial_state={"price": 100.0, "volume": 100000.0, "momentum": 0.01},
                    target="price",
                    intervention_vars=["volume", "momentum"],
                    horizon=3,
                    discount=0.9,
                )
                assert "optimal_sequence" in bell
                print("PASS: bellman_optimal_intervention")
            except Exception as e:
                failures.append(("bellman", str(e)))

            try:
                var_result = a6.vector_autoregression_estimation(
                    ts, variables=["price", "volume", "momentum"], max_lag=2
                )
                assert "coefficient_matrices" in var_result
                print("PASS: vector_autoregression_estimation")
            except Exception as e:
                failures.append(("var_estimation", str(e)))

            try:
                bayesian_result = a6.bayesian_edge_inference(
                    ts, parent="volume", child="price", prior_mu=0.0, prior_sigma=1.0
                )
                assert "posterior_mean" in bayesian_result
                print("PASS: bayesian_edge_inference")
            except Exception as e:
                failures.append(("bayesian", str(e)))

            try:
                rca_result = a6.deep_root_cause_analysis(
                    problem_variable="price",
                    max_depth=10,
                    min_path_strength=0.01
                )
                assert "all_root_causes" in rca_result
                print("PASS: deep_root_cause_analysis")
            except Exception as e:
                failures.append(("root_cause", str(e)))

        except Exception as e:
            failures.append(("pandas_block", str(e)))
    else:
        print("SKIP: pandas/numpy not available; skipped data-fitting/uncertainty/optimization tests")

    # Check OPENAI_API_KEY presence (do not call network)
    key = os.environ.get("OPENAI_API_KEY", "")
    if key:
        print("OPENAI_API_KEY found:", mask_key(key))
    else:
        print("OPENAI_API_KEY not set (ok — LLM network tests skipped)")

    # Optional LLM integration smoke test (runs only if OPENAI_API_KEY is present)
    if key:
        try:
            # Initialize agent with the model named below for the LLM smoke test
            agent_llm = CRCAAgent(
                variables=["price", "demand"],
                model_name="gpt-4o-mini",
                max_loops=1,
                enable_batch_predict=True,
                seed=rng_seed,
            )
            # Short LLM task to exercise the LLM integration path
            task = "Provide a one-sentence causal analysis: describe why increasing price might change demand."
            # Run the LLM causal analysis workflow
            res = agent_llm._run_llm_causal_analysis(task)
            ca = res.get("causal_analysis", "")
            # Basic checks and sample output (be permissive about returned shape):
            # accept either a non-empty analysis string or a non-empty step list.
            ok = False
            if isinstance(ca, str) and ca.strip():
                ok = True
            elif isinstance(res.get("analysis_steps"), list) and len(res.get("analysis_steps", [])) > 0:
                ok = True

            if ok:
                print("PASS: LLM integration smoke (causal_analysis present)")
                if isinstance(ca, str) and ca.strip():
                    # Collapse the text onto one line and cap it at 240 characters.
                    sample = " ".join(ca.strip().splitlines())[:240]
                else:
                    sample = str(res)[:240]
                print("LLM sample:", sample)
            else:
                failures.append(("llm_smoke", f"unexpected LLM result shape: {list(res.keys())}"))

            # LLM-aware run with structured initial state (small)
            cf_res = agent_llm.run(
                task={"price": 10.0, "demand": 5.0},
                initial_state={"price": 10.0, "demand": 5.0},
                target_variables=["price", "demand"],
                max_steps=1,
            )
            if not isinstance(cf_res, dict) or "evolved_state" not in cf_res:
                failures.append(("llm_run_structured", "run() did not return expected shape"))
            else:
                print("PASS: LLM run() structured state path")
        except Exception as e:
            failures.append(("llm_smoke", str(e)))

        # Best-effort cleanup for litellm async clients to avoid warnings.
        # Failures here (including litellm not being installed) are ignored on
        # purpose; the loop is closed even when the cleanup coroutine raises.
        try:
            from litellm.llms.custom_httpx.async_client_cleanup import close_litellm_async_clients  # type: ignore
            cleanup_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(cleanup_loop)
                cleanup_loop.run_until_complete(close_litellm_async_clients())
            finally:
                cleanup_loop.close()
        except Exception:
            pass
    # Report
    if failures:
        print("\nSANITY CHECK FAILED:")
        for name, msg in failures:
            print(f"- {name}: {msg}")
        print("\nFix the above issues before considering prod deployment.")
        sys.exit(2)
    else:
        print("\nAll sanity checks passed.")
        sys.exit(0)


# Run the checks only when this file is executed as a script (not on import);
# main() ends the process itself via sys.exit().
if __name__ == "__main__":
    main() 


