#! /usr/bin/env python
# -*- coding: utf-8 -*-
'''
SIGNET: Estimation and community detection of a signed network.
Copyright(c)  Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
All rights reserved.
'''

"""
signet CLI (TSV / SNAP / MAT) + (LouvainSigned / Leiden modularity / Leiden CPM)

Examples
--------
# TSV (positive + negative), Leiden modularity (alpha), full outputs
signet --pos data/hsc_CDI_score_data_thre10.txt --neg data/hsc_EEI_score_data_thre5.txt \
  --thre-pos 10 --thre-neg 5 --method leiden-mod-alpha --alpha 0.6 --resolution 1.0 --seed 12345 \
  --out-prefix results/hsc_alpha06

# Outputs generated:
#   results/hsc_alpha06_partition.tsv
#   results/hsc_alpha06_summary.txt
#   results/hsc_alpha06_communities.gmt
#   results/hsc_alpha06_community_sizes.pdf
#   results/hsc_alpha06_inter_community.pdf
#   results/hsc_alpha06_module_0.pdf  ...

# Disable plots
signet --pos ... --neg ... --method leiden-mod-alpha --out-prefix results/run1 --no-plot

# Disable GMT
signet --pos ... --neg ... --method leiden-mod-alpha --out-prefix results/run1 --no-gmt

# SNAP signed TSV (Epinions), Leiden modularity (alpha)
signet --format snap --snap data/soc-sign-epinions.txt.gz --method leiden-mod-alpha \
  --alpha 0.6 --resolution 1.0 --seed 12345 --out-prefix results/epinions_alpha06

# SNAP signed TSV, single signed CPM
signet --format snap --snap data/soc-sign-epinions.txt.gz --method leiden-cpm-single \
  --lambda-neg 1.0 --gamma 0.5 --min-degree 2 --degree-mode union

# MAT (MATLAB format), Leiden modularity (alpha)
signet --format mat --mat data/epinions_data.mat --method leiden-mod-alpha --alpha 0.6 --resolution 1.0

# MAT (MATLAB format), single signed CPM
signet --format mat --mat data/epinions_data.mat --method leiden-cpm-single --lambda-neg 1.0 --gamma 0.5
"""

import argparse
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import igraph as ig

try:
    from importlib.metadata import version as pkg_version
    __version__ = pkg_version("signet")
except Exception:
    __version__ = "dev"

import signet.network_module as nr
import signet.LeidenSigned as les
from signet.LouvainSigned import LouvainSigned


# =============================================================================
# IO helpers
# =============================================================================

def load_input_graphs(args):
    """
    Load the positive and negative graphs described by the parsed CLI options.

    The loader is chosen from ``args.format`` ("tsv", "snap" or "mat"). For
    TSV and SNAP input the graph library follows ``args.method``: methods
    whose name starts with "leiden" go through the ``*_igraph`` loaders of
    ``signet.network_module``, while "louvain" uses the networkx variant.
    MAT input is always turned into igraph graphs whose vertices carry
    "gene_id" and "name" attributes set to "0" .. "n-1".

    Returns
    -------
    tuple
        ``(G_pos, G_neg, meta)`` where ``meta`` records the input source, the
        paths/keys that were read and, for igraph results, whether the
        "gene_id" and "name" vertex attributes exist on ``G_pos``.

    Raises
    ------
    FileNotFoundError
        An input file does not exist.
    KeyError
        A requested key is missing from the MAT file.
    ValueError
        Unknown format, unsupported method for the format, or MAT matrices
        that produce graphs with different vertex counts.
    ImportError
        "louvain" was requested but ``LouvainSigned`` is ``None``.
    """
    fmt = args.format

    if fmt == "tsv":
        pos_file, neg_file = Path(args.pos), Path(args.neg)
        if not pos_file.exists():
            raise FileNotFoundError(f"Positive TSV not found: {pos_file}")
        if not neg_file.exists():
            raise FileNotFoundError(f"Negative TSV not found: {neg_file}")

        use_igraph = args.method.startswith("leiden")
        if use_igraph:
            read_tsv = nr.load_graph_from_TSV_igraph
        elif args.method == "louvain":
            if LouvainSigned is None:
                raise ImportError("LouvainSigned is not importable.")
            read_tsv = nr.load_graph_from_TSV
        else:
            raise ValueError(f"Unsupported method for TSV: {args.method}")

        G_pos = read_tsv(str(pos_file), threshold=float(args.thre_pos))
        G_neg = read_tsv(str(neg_file), threshold=float(args.thre_neg))

        meta = {}
        if use_igraph:
            meta["has_gene_id"] = "gene_id" in G_pos.vs.attributes()
            meta["has_name"] = "name" in G_pos.vs.attributes()
        meta.update(source="tsv", pos_path=str(pos_file), neg_path=str(neg_file))
        return G_pos, G_neg, meta

    if fmt == "snap":
        snap_file = Path(args.snap)
        if not snap_file.exists():
            raise FileNotFoundError(f"SNAP signed TSV not found: {snap_file}")

        use_igraph = args.method.startswith("leiden")
        if use_igraph:
            # Leiden-family methods use the igraph loader.
            read_snap = nr.load_snap_signed_tsv_igraph
        elif args.method == "louvain":
            # LouvainSigned uses the networkx loader.
            if LouvainSigned is None:
                raise ImportError("LouvainSigned is not importable.")
            read_snap = nr.load_snap_signed_tsv_nx
        else:
            raise ValueError(f"Unsupported method for SNAP TSV: {args.method}")

        G_pos, G_neg = read_snap(
            str(snap_file),
            undirected=not args.directed,
            conflict=args.conflict,
            add_weights=not args.no_weights,
            keep_self_loops=args.keep_self_loops,
            min_degree=args.min_degree,
            degree_mode=args.degree_mode,
        )

        meta = {}
        if use_igraph:
            meta["has_gene_id"] = "gene_id" in G_pos.vs.attributes()
            meta["has_name"] = "name" in G_pos.vs.attributes()
        meta.update(source="snap", snap_path=str(snap_file))
        return G_pos, G_neg, meta

    if fmt == "mat":
        from scipy.io import loadmat
        import scipy.sparse as sp

        mat_file = Path(args.mat)
        if not mat_file.exists():
            raise FileNotFoundError(f"MAT not found: {mat_file}")

        contents = loadmat(str(mat_file))
        if not (args.pos_key in contents and args.neg_key in contents):
            raise KeyError(f"Keys not found in MAT: need '{args.pos_key}' and '{args.neg_key}'")

        # Convert both matrices first, then clear their diagonals so that
        # self-loops never reach the graphs.
        matrices = [sp.csr_matrix(contents[key]) for key in (args.pos_key, args.neg_key)]
        for matrix in matrices:
            matrix.setdiag(0)
            matrix.eliminate_zeros()

        G_pos, G_neg = [
            ig.Graph.Weighted_Adjacency(matrix, mode="undirected", attr="weight", loops="ignore")
            for matrix in matrices
        ]

        n = G_pos.vcount()
        if G_neg.vcount() != n:
            raise ValueError("pos/neg matrices must have same shape.")

        # MAT files carry no labels here: vertices are named by their index.
        for graph in (G_pos, G_neg):
            graph.vs["gene_id"] = [str(i) for i in range(n)]
            graph.vs["name"] = [str(i) for i in range(n)]

        meta = {
            "source": "mat",
            "mat_path": str(mat_file),
            "pos_key": args.pos_key,
            "neg_key": args.neg_key,
            "has_gene_id": True,
            "has_name": True,
        }
        return G_pos, G_neg, meta

    raise ValueError(f"Unknown format type: {args.format}")


# =============================================================================
# Partition extraction helpers
# =============================================================================

def extract_gene_ids_and_membership(part, G_pos, is_louvain=False):
    """
    Split a partition result into ``(gene_ids, membership, gene_names)``.

    With ``is_louvain=True`` ``part`` is read as a mapping node -> community:
    its keys become the gene ids and no display names are returned.

    Otherwise ``part.membership`` gives the community of each vertex. Vertex
    ids come from the "gene_id" attribute of ``part.graph`` when it exists,
    else from the same attribute of ``G_pos``, else they are the positions
    "0", "1", ... as strings. ``gene_names`` maps gene_id -> the "name"
    attribute of whichever graph supplied the ids, or is empty when that
    attribute is missing.
    """
    if is_louvain:
        node_ids = list(part.keys())
        return node_ids, [part[node] for node in node_ids], {}

    membership = list(part.membership)

    # Prefer the graph attached to the partition, then fall back to G_pos.
    labelled = None
    if hasattr(part, "graph") and "gene_id" in part.graph.vs.attributes():
        labelled = part.graph
    elif hasattr(G_pos, "vs") and "gene_id" in G_pos.vs.attributes():
        labelled = G_pos

    if labelled is None:
        return [str(pos) for pos in range(len(membership))], membership, {}

    node_ids = list(labelled.vs["gene_id"])
    display = {}
    if "name" in labelled.vs.attributes():
        display = dict(zip(node_ids, labelled.vs["name"]))
    return node_ids, membership, display


# =============================================================================
# Output writers
# =============================================================================

def write_partition_tsv(out_path, gene_ids, membership, gene_names=None):
    """
    Write the partition as a tab-separated table, one node per row.

    The ``gene_name`` column is emitted only when ``gene_names`` maps at least
    one of ``gene_ids`` to something other than the id itself; ids without an
    entry fall back to the id. The file is overwritten and UTF-8 encoded.
    """
    target = Path(out_path)
    with_names = bool(gene_names) and any(
        gene_names.get(gid, gid) != gid for gid in gene_ids
    )

    if with_names:
        header = "gene_id\tgene_name\tcommunity\n"
        rows = (
            f"{gid}\t{gene_names.get(gid, gid)}\t{comm}\n"
            for gid, comm in zip(gene_ids, membership)
        )
    else:
        header = "gene_id\tcommunity\n"
        rows = (f"{gid}\t{comm}\n" for gid, comm in zip(gene_ids, membership))

    with target.open("w", encoding="utf-8") as handle:
        handle.write(header)
        handle.writelines(rows)


def write_summary(out_path, args, G_pos, G_neg, gene_ids, membership):
    """
    Write a plain-text run summary: parameters, graph sizes, partition stats.

    Parameters
    ----------
    out_path : str or Path
        Destination file; overwritten, UTF-8 encoded.
    args : argparse.Namespace
        Parsed CLI options. Only the options relevant to ``args.format`` and
        ``args.method`` are read and echoed.
    G_pos, G_neg
        Input graphs. Objects exposing ``vcount`` are reported through
        ``vcount()``/``ecount()``, anything else through
        ``number_of_nodes()``/``number_of_edges()``.
    gene_ids : sequence
        Not used by this function.
    membership : sequence of int
        Community id of each node (formatted with an integer format spec).

    Notes
    -----
    Normalised entropy is the Shannon entropy of the community-size
    distribution divided by ``log(number of communities)``; it is 0.0 when
    there are fewer than two communities. An empty ``membership`` is reported
    with zeros throughout rather than ``nan`` statistics.
    """
    out_path = Path(out_path)
    counts = Counter(membership)
    sizes = sorted(counts.values(), reverse=True)
    n_singletons = sum(1 for s in sizes if s == 1)

    if len(sizes) > 1:
        probs = np.array(sizes, dtype=float)
        probs = probs / probs.sum()
        entropy = -np.sum(probs * np.log(probs))
        norm_entropy = entropy / np.log(len(sizes))
    else:
        norm_entropy = 0.0

    # np.mean/np.median of an empty list warn and yield nan; report 0.0 so the
    # empty case matches the "largest"/"smallest" fallbacks below.
    mean_size = float(np.mean(sizes)) if sizes else 0.0
    median_size = float(np.median(sizes)) if sizes else 0.0

    with out_path.open("w", encoding="utf-8") as f:
        f.write("=" * 60 + "\n")
        f.write("SIGNET — Signed Network Community Detection Summary\n")
        f.write("=" * 60 + "\n\n")

        f.write("[Parameters]\n")
        f.write(f"  method           : {args.method}\n")
        f.write(f"  format           : {args.format}\n")
        if args.format == "tsv":
            f.write(f"  pos              : {args.pos}\n")
            f.write(f"  neg              : {args.neg}\n")
            f.write(f"  thre_pos         : {args.thre_pos}\n")
            f.write(f"  thre_neg         : {args.thre_neg}\n")
        elif args.format == "snap":
            f.write(f"  snap             : {args.snap}\n")
            f.write(f"  directed         : {args.directed}\n")
            f.write(f"  conflict         : {args.conflict}\n")
            f.write(f"  no_weights       : {args.no_weights}\n")
            f.write(f"  keep_self_loops  : {args.keep_self_loops}\n")
            f.write(f"  min_degree       : {args.min_degree}\n")
            f.write(f"  degree_mode      : {args.degree_mode}\n")
        elif args.format == "mat":
            f.write(f"  mat              : {args.mat}\n")
            f.write(f"  pos_key          : {args.pos_key}\n")
            f.write(f"  neg_key          : {args.neg_key}\n")

        # Echo only the algorithm parameters the selected method consumes.
        if args.method in ("louvain", "leiden-mod-alpha"):
            f.write(f"  alpha            : {args.alpha}\n")
        if args.method == "leiden-mod-lambda":
            f.write(f"  lambda_neg       : {args.lambda_neg}\n")
        if args.method in ("louvain", "leiden-mod-alpha", "leiden-mod-lambda"):
            f.write(f"  resolution       : {args.resolution}\n")
        if args.method == "leiden-cpm-single":
            f.write(f"  gamma            : {args.gamma}\n")
            f.write(f"  lambda_neg       : {args.lambda_neg}\n")
            f.write(f"  neg_weight_mode  : {args.neg_weight_mode}\n")
        f.write(f"  seed             : {args.seed}\n\n")

        f.write("[Graph]\n")
        if hasattr(G_pos, "vcount"):
            f.write(f"  G_pos            : {G_pos.vcount()} vertices, {G_pos.ecount()} edges\n")
            f.write(f"  G_neg            : {G_neg.vcount()} vertices, {G_neg.ecount()} edges\n")
        else:
            f.write(f"  G_pos            : {G_pos.number_of_nodes()} nodes, {G_pos.number_of_edges()} edges\n")
            f.write(f"  G_neg            : {G_neg.number_of_nodes()} nodes, {G_neg.number_of_edges()} edges\n")
        f.write("\n")

        f.write("[Partition]\n")
        f.write(f"  total nodes      : {len(membership)}\n")
        f.write(f"  communities      : {len(counts)}\n")
        f.write(f"  largest          : {sizes[0] if sizes else 0}\n")
        f.write(f"  smallest         : {sizes[-1] if sizes else 0}\n")
        f.write(f"  mean size        : {mean_size:.1f}\n")
        f.write(f"  median size      : {median_size:.1f}\n")
        f.write(f"  singletons       : {n_singletons}\n")
        f.write(f"  norm. entropy    : {norm_entropy:.4f}\n\n")

        f.write("[Community sizes]\n")
        for comm_id in sorted(counts):
            f.write(f"  community {comm_id:>4d}   : {counts[comm_id]} nodes\n")


def write_gmt(out_path, gene_ids, membership, gene_names=None):
    """
    Write one GMT line per community, ordered by community id.

    Line layout: ``community_<id> <TAB> n=<size> <TAB> member <TAB> ...``.
    Members keep the order in which they appear in ``gene_ids``. When
    ``gene_names`` is non-empty each member is written under its mapped name,
    falling back to the id itself when there is no entry.
    """
    target = Path(out_path)

    members_by_comm = {}
    for gid, comm in zip(gene_ids, membership):
        if comm not in members_by_comm:
            members_by_comm[comm] = []
        members_by_comm[comm].append(gid)

    with target.open("w", encoding="utf-8") as handle:
        for comm in sorted(members_by_comm):
            members = members_by_comm[comm]
            shown = [gene_names.get(gid, gid) for gid in members] if gene_names else members
            fields = [f"community_{comm}", f"n={len(members)}"]
            fields.extend(str(item) for item in shown)
            handle.write("\t".join(fields) + "\n")


# =============================================================================
# Plot writers
# =============================================================================

def plot_community_sizes(out_path, membership):
    """
    Save a two-panel figure of the community size distribution.

    Left panel: sizes as bars ordered from largest to smallest community; the
    y axis switches to log scale when the largest/smallest ratio exceeds 50.
    Right panel: histogram of sizes, with logarithmic bins and x axis when the
    largest community has more than 100 nodes, unit-width bins otherwise.

    Parameters
    ----------
    out_path : str or Path
        Figure destination, passed to ``Figure.savefig``.
    membership : sequence
        Community id of each node.

    Nothing is drawn and no file is written when ``membership`` is empty.
    """
    counts = Counter(membership)
    sizes = sorted(counts.values(), reverse=True)
    if not sizes:
        # max()/min() below would raise on an empty partition; like the other
        # plot helpers, skip silently when there is nothing to draw.
        return

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    largest, smallest = sizes[0], sizes[-1]  # sizes is sorted descending

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Rank-size plot
    ax = axes[0]
    ranks = np.arange(1, len(sizes) + 1)
    ax.bar(ranks, sizes, color="steelblue", edgecolor="none")
    ax.set_xlabel("Community rank")
    ax.set_ylabel("Size (# nodes)")
    ax.set_title("Community rank–size")
    if largest / max(smallest, 1) > 50:
        ax.set_yscale("log")

    # Histogram
    ax = axes[1]
    if largest > 100:
        bins = np.logspace(0, np.log10(largest + 1), 30)
        ax.set_xscale("log")
    else:
        bins = range(1, largest + 2)
    ax.hist(sizes, bins=bins, color="steelblue", edgecolor="white")
    ax.set_xlabel("Community size")
    ax.set_ylabel("Frequency")
    ax.set_title("Community size distribution")

    fig.tight_layout()
    fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
    plt.close(fig)


def _build_inter_community_matrix(G, membership_dict, community_ids):
    """
    Sum edge weights between every pair of the listed communities.

    ``membership_dict`` maps the graph's own node keys (vertex index for
    igraph, node object for networkx) to a community id. Row/column ``i`` of
    the result corresponds to ``community_ids[i]``. Edges with an endpoint
    that has no community, or whose community is not listed, are skipped.
    Edges without a "weight" count as 1.0. The matrix is symmetric; an
    intra-community edge is added once, on the diagonal.
    """
    slot_of = {cid: slot for slot, cid in enumerate(community_ids)}
    totals = np.zeros((len(community_ids), len(community_ids)), dtype=float)

    # Normalise both graph libraries to a stream of (u, v, weight) triples.
    if nr.is_ig(G):
        has_weight = "weight" in G.es.attributes()
        triples = (
            (e.source, e.target, e["weight"] if has_weight else 1.0)
            for e in G.es
        )
    elif nr.is_nx(G):
        triples = (
            (a, b, attrs.get("weight", 1.0))
            for a, b, attrs in G.edges(data=True)
        )
    else:
        triples = ()

    for a, b, weight in triples:
        comm_a = membership_dict.get(a)
        comm_b = membership_dict.get(b)
        if comm_a is None or comm_b is None:
            continue
        row = slot_of.get(comm_a)
        col = slot_of.get(comm_b)
        if row is None or col is None:
            continue
        totals[row, col] += weight
        if row != col:
            totals[col, row] += weight

    return totals


def _membership_dict_for_graph(G, gene_ids, membership):
    """
    Translate the (gene_ids, membership) pairing into ``G``'s own node keys.

    * igraph with a "gene_id" vertex attribute: vertex index -> community,
      for vertices whose gene_id appears in ``gene_ids``.
    * igraph without that attribute: position in ``membership`` -> community,
      limited to positions below the vertex count.
    * networkx: node -> community, for nodes present in ``gene_ids``.
    * anything else: the plain gene_id -> community dict.
    """
    community_of = dict(zip(gene_ids, membership))

    if nr.is_ig(G):
        if "gene_id" in G.vs.attributes():
            return {
                index: community_of[gid]
                for index, gid in enumerate(G.vs["gene_id"])
                if gid in community_of
            }
        n_vertices = G.vcount()
        return {
            position: comm
            for position, comm in enumerate(membership)
            if position < n_vertices
        }

    if nr.is_nx(G):
        return {node: community_of[node] for node in G.nodes() if node in community_of}

    return community_of


def plot_inter_community_heatmap(out_path, G_pos, G_neg, gene_ids, membership, top_n=30):
    """
    Save three heatmaps of edge density between the largest communities.

    Only the ``top_n`` most populous communities are shown, ordered by id.
    Density is the summed edge weight between two communities divided by the
    number of node pairs: ``size_a * size_b`` off the diagonal and
    ``size * (size - 1) / 2`` on it, floored at 1. The panels show positive
    density, negative density and their difference on a colour scale centred
    at zero. Returns without writing when fewer than two communities exist.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.colors import TwoSlopeNorm

    counts = Counter(membership)
    selected = sorted(cid for cid, _ in counts.most_common(top_n))
    if len(selected) < 2:
        return

    weight_sums = []
    for graph in (G_pos, G_neg):
        weight_sums.append(graph)
    mem_pos = _membership_dict_for_graph(weight_sums[0], gene_ids, membership)
    mem_neg = _membership_dict_for_graph(weight_sums[1], gene_ids, membership)
    mat_pos = _build_inter_community_matrix(G_pos, mem_pos, selected)
    mat_neg = _build_inter_community_matrix(G_neg, mem_neg, selected)

    # Number of node pairs behind each cell, used as the normaliser.
    sizes = np.array([counts[cid] for cid in selected], dtype=float)
    pair_counts = np.outer(sizes, sizes)
    np.fill_diagonal(pair_counts, sizes * (sizes - 1) / 2)
    pair_counts = np.maximum(pair_counts, 1.0)

    density_pos = mat_pos / pair_counts
    density_neg = mat_neg / pair_counts
    density_signed = density_pos - density_neg

    tick_labels = [str(cid) for cid in selected]
    tick_positions = range(len(tick_labels))

    fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))

    def _panel(axis, values, title, **image_opts):
        """Draw one labelled heatmap with its own colour bar."""
        image = axis.imshow(values, aspect="equal", **image_opts)
        axis.set_xticks(tick_positions)
        axis.set_xticklabels(tick_labels, rotation=90, fontsize=7)
        axis.set_yticks(tick_positions)
        axis.set_yticklabels(tick_labels, fontsize=7)
        axis.set_title(title)
        fig.colorbar(image, ax=axis, shrink=0.7)

    _panel(axes[0], density_pos, "Positive edge density", cmap="Reds")
    _panel(axes[1], density_neg, "Negative edge density", cmap="Blues")

    # Symmetric limits keep zero at the centre of the diverging colour map.
    limit = max(abs(density_signed.min()), abs(density_signed.max()), 1e-9)
    centred = TwoSlopeNorm(vmin=-limit, vcenter=0, vmax=limit)
    _panel(axes[2], density_signed, "Signed density (pos − neg)", cmap="RdBu_r", norm=centred)

    fig.suptitle(f"Inter-community edge density (top {len(selected)} communities)", fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_module_subnetwork(out_path, G_pos, G_neg, gene_ids, membership,
                           community_id, gene_names=None, max_nodes=200):
    """
    Draw a single module's signed subnetwork (positive=red, negative=blue).

    The members of ``community_id`` are looked up through the
    ``gene_ids``/``membership`` pairing, the edges among them are extracted
    from both graphs, and the union is laid out with a spring layout using a
    fixed seed. Returns without writing a file when the community has no
    members or the extracted subnetwork has no nodes.

    Parameters
    ----------
    out_path : str or Path
        Figure destination, passed to ``Figure.savefig``.
    G_pos, G_neg
        Positive and negative graphs (igraph or networkx, as recognised by
        ``nr.is_ig`` / ``nr.is_nx``).
    gene_ids, membership : sequences
        Parallel sequences giving each node's id and community.
    community_id
        Community to draw.
    gene_names : dict, optional
        gene_id -> label; overrides the node's own "name" attribute.
    max_nodes : int
        Modules with more nodes than this are trimmed to the ``max_nodes``
        nodes with the highest degree in the positive subgraph.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import networkx as nx
    from matplotlib.lines import Line2D

    # If gene_ids contains duplicates, the last community assignment wins.
    gid_to_comm = dict(zip(gene_ids, membership))
    member_gids = [gid for gid, c in gid_to_comm.items() if c == community_id]
    if not member_gids:
        return

    gid_set = set(member_gids)

    def _extract_subgraph_nx(G, target_gids):
        """
        Return the subgraph induced by ``target_gids`` as a networkx Graph.

        Nodes are keyed by gene id. For igraph input the id is the "gene_id"
        vertex attribute (or the vertex index as a string when absent) and
        each node gets a "name" attribute; edges without a "weight" get 1.0.
        For networkx input node and edge attributes are copied as they are.
        """
        sub = nx.Graph()
        if nr.is_ig(G):
            gid_key = "gene_id" if "gene_id" in G.vs.attributes() else None
            name_key = "name" if "name" in G.vs.attributes() else None
            # vertex index -> gene id, only for vertices inside the module
            idx_to_gid = {}
            for v in G.vs:
                gid = v[gid_key] if gid_key else str(v.index)
                if gid in target_gids:
                    idx_to_gid[v.index] = gid
                    label = v[name_key] if name_key else gid
                    sub.add_node(gid, name=label)
            for e in G.es:
                gu = idx_to_gid.get(e.source)
                gv = idx_to_gid.get(e.target)
                # keep an edge only when both endpoints belong to the module
                if gu is not None and gv is not None:
                    w = e["weight"] if "weight" in G.es.attributes() else 1.0
                    sub.add_edge(gu, gv, weight=w)
        elif nr.is_nx(G):
            for node, data in G.nodes(data=True):
                if node in target_gids:
                    sub.add_node(node, **data)
            for u, v, data in G.edges(data=True):
                if u in target_gids and v in target_gids:
                    sub.add_edge(u, v, **data)
        return sub

    sub_pos = _extract_subgraph_nx(G_pos, gid_set)
    sub_neg = _extract_subgraph_nx(G_neg, gid_set)

    # Trim large modules to top-degree nodes. The ranking uses the unweighted
    # degree in the positive subgraph only, so nodes that appear solely in the
    # negative subgraph are never selected.
    all_nodes = set(sub_pos.nodes()) | set(sub_neg.nodes())
    if len(all_nodes) > max_nodes:
        degrees = {n: sub_pos.degree(n, weight=None) for n in sub_pos.nodes()}
        top_nodes = set(sorted(degrees, key=degrees.get, reverse=True)[:max_nodes])
        sub_pos = sub_pos.subgraph(top_nodes).copy()
        sub_neg = sub_neg.subgraph(top_nodes & set(sub_neg.nodes())).copy()
        all_nodes = set(sub_pos.nodes()) | set(sub_neg.nodes())
        title_suffix = f" (top {max_nodes} by degree)"
    else:
        title_suffix = ""

    # Union of both subgraphs: used for the layout, node drawing and labels.
    # Edge attributes are not carried over; edge colours come from the
    # per-sign edge lists below.
    combined = nx.Graph()
    combined.add_nodes_from(sub_pos.nodes(data=True))
    combined.add_nodes_from(sub_neg.nodes(data=True))
    combined.add_edges_from(sub_pos.edges())
    combined.add_edges_from(sub_neg.edges())

    if combined.number_of_nodes() == 0:
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    # Spring-layout distance parameter shrinks with the node count; the fixed
    # seed makes the layout reproducible between runs.
    k_val = 2.0 / np.sqrt(max(combined.number_of_nodes(), 2))
    pos = nx.spring_layout(combined, seed=42, k=k_val, iterations=200)

    # Smaller markers for bigger modules, never below 30.
    node_size = max(30, 300 - combined.number_of_nodes())

    # Above 50 nodes only the 30 highest-degree nodes are labelled.
    if combined.number_of_nodes() > 50:
        deg_rank = sorted(combined.degree(), key=lambda x: x[1], reverse=True)
        label_nodes = {n for n, _ in deg_rank[:30]}
    else:
        label_nodes = set(combined.nodes())

    # Label priority: gene_names entry, then the node's "name", then its id.
    labels = {}
    for node in label_nodes:
        d = combined.nodes[node]
        labels[node] = d.get("name", str(node))
        if gene_names and node in gene_names:
            labels[node] = gene_names[node]

    nx.draw_networkx_nodes(combined, pos, node_color="lightblue",
                           node_size=node_size, alpha=0.8, ax=ax)
    nx.draw_networkx_edges(combined, pos, edgelist=list(sub_pos.edges()),
                           edge_color="red", width=1.0, alpha=0.3, ax=ax)
    nx.draw_networkx_edges(combined, pos, edgelist=list(sub_neg.edges()),
                           edge_color="blue", width=1.0, alpha=0.3, ax=ax)
    nx.draw_networkx_labels(combined, pos, labels=labels, font_size=7, ax=ax)

    legend_elements = [
        Line2D([0], [0], color="red", lw=2, label="Positive edge"),
        Line2D([0], [0], color="blue", lw=2, label="Negative edge"),
    ]
    ax.legend(handles=legend_elements, loc="upper left")
    # The title reports the full community size, even when the plot is trimmed.
    ax.set_title(f"Community {community_id} ({len(member_gids)} nodes){title_suffix}")
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
    plt.close(fig)


# =============================================================================
# Method runners
# =============================================================================

def run_louvain_signed(G_pos, G_neg, args):
    """Run LouvainSigned on the graph pair with the CLI's alpha, resolution and seed."""
    detector = LouvainSigned(G_pos, G_neg)
    return detector.best_partition(
        alpha=float(args.alpha),
        resolution=float(args.resolution),
        seed=args.seed,
    )


def run_leiden_modularity_alpha(G_pos, G_neg, args):
    """Run the alpha-weighted signed-modularity Leiden partitioner from the CLI options."""
    partitioner = les.find_partition_signed_modularity_alpha
    options = {
        "resolution": float(args.resolution),
        "alpha": float(args.alpha),
        "seed": args.seed,
    }
    return partitioner(G_pos, G_neg, **options)


def run_leiden_modularity_lambda(G_pos, G_neg, args):
    """Run the lambda_neg-weighted signed-modularity Leiden partitioner from the CLI options."""
    partitioner = les.find_partition_signed_modularity_lambda
    options = {
        "resolution": float(args.resolution),
        "lambda_neg": float(args.lambda_neg),
        "seed": args.seed,
    }
    return partitioner(G_pos, G_neg, **options)


def run_leiden_cpm_single_signed(G_pos, G_neg, args):
    """
    Collapse the positive and negative graphs into one signed graph and run
    the single-graph signed CPM partitioner on it.

    The signed weight of a node pair is

    * ``w_pos - lambda_neg * |w_neg|`` when ``args.neg_weight_mode`` is
      "absolute" (negative-graph weights are treated as magnitudes), and
    * ``w_pos + lambda_neg * w_neg`` otherwise (negative-graph weights are
      expected to carry their own sign).

    Pairs whose combined weight is exactly zero produce no edge. For MAT
    input the combination is done on sparse adjacency matrices; for other
    formats on the edge lists, where parallel edges are summed.

    Parameters
    ----------
    G_pos, G_neg : igraph.Graph
        Positive and negative graphs.
        NOTE(review): both branches assume the two graphs share the same
        vertex indexing - confirm against the loaders.
    args : argparse.Namespace
        Reads ``lambda_neg``, ``neg_weight_mode``, ``format``, ``gamma`` and
        ``seed``.

    Returns
    -------
    The partition returned by ``les.find_partition_signed_CPM_single_graph``.
    """
    lambda_neg = float(args.lambda_neg)
    absolute_mode = args.neg_weight_mode == "absolute"

    if args.format == "mat":
        Apos = G_pos.get_adjacency_sparse(attribute="weight")
        Aneg = G_neg.get_adjacency_sparse(attribute="weight")
        if absolute_mode:
            # Take magnitudes so a negative matrix stored with negative
            # entries is still subtracted, exactly as the edge-list branch
            # does with abs(wraw).
            Asigned = Apos - lambda_neg * abs(Aneg)
        else:
            Asigned = Apos + lambda_neg * Aneg
        Asigned = Asigned.tocsr()
        Asigned.setdiag(0)
        Asigned.eliminate_zeros()
        G_signed = ig.Graph.Weighted_Adjacency(Asigned, mode="undirected", attr="weight", loops="ignore")
        G_signed.vs["gene_id"] = list(G_pos.vs["gene_id"])
        G_signed.vs["name"] = list(G_pos.vs["name"])
    else:
        # (low index, high index) -> accumulated signed weight
        edge_w = {}

        def add_edge(u, v, w):
            key = (u, v) if u < v else (v, u)
            edge_w[key] = edge_w.get(key, 0.0) + w

        pos_weighted = "weight" in G_pos.es.attributes()
        for e in G_pos.es:
            u, v = e.tuple
            add_edge(u, v, float(e["weight"]) if pos_weighted else 1.0)

        neg_weighted = "weight" in G_neg.es.attributes()
        for e in G_neg.es:
            u, v = e.tuple
            wraw = float(e["weight"]) if neg_weighted else 1.0
            if absolute_mode:
                add_edge(u, v, -lambda_neg * abs(wraw))
            else:
                add_edge(u, v, lambda_neg * wraw)

        # Positive and negative contributions that cancel leave no edge.
        kept = [(pair, w) for pair, w in edge_w.items() if w != 0.0]

        G_signed = ig.Graph(n=G_pos.vcount(), edges=[pair for pair, _ in kept], directed=False)
        G_signed.es["weight"] = [w for _, w in kept]
        if "gene_id" in G_pos.vs.attributes():
            G_signed.vs["gene_id"] = list(G_pos.vs["gene_id"])
        if "name" in G_pos.vs.attributes():
            G_signed.vs["name"] = list(G_pos.vs["name"])

    part = les.find_partition_signed_CPM_single_graph(
        G_signed,
        gamma=float(args.gamma),
        seed=args.seed,
    )
    return part


# =============================================================================
# CLI
# =============================================================================

def build_parser():
    """
    Build the ``signet`` command-line parser.

    Options are declared in this order: input format, TSV inputs, SNAP
    inputs, MAT inputs, method selection, algorithm parameters and output
    control. ``--method`` is the only required option; per-format
    requirements are checked separately after parsing.
    """
    parser = argparse.ArgumentParser(
        prog="signet",
        description="SIGNET: Signed network community detection",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument

    # --- Input format ------------------------------------------------------
    add("--format", choices=["tsv", "snap", "mat"], default="tsv",
        help="Input format: tsv, snap, mat (default: tsv)")

    # --- TSV inputs --------------------------------------------------------
    add("--pos", type=str, help="Positive TSV file")
    add("--neg", type=str, help="Negative TSV file")
    add("--thre-pos", type=float, default=10.0,
        help="Threshold for positive TSV (default: 10)")
    add("--thre-neg", type=float, default=5.0,
        help="Threshold for negative TSV (default: 5)")

    # --- SNAP inputs -------------------------------------------------------
    add("--snap", type=str, help="SNAP signed TSV file (.txt or .txt.gz)")
    add("--directed", action="store_true",
        help="Treat SNAP edges as directed pairs (default: convert to undirected)")
    add("--conflict", choices=["both", "sum", "prefer_pos", "prefer_neg"],
        default="both",
        help="How to resolve conflicting signs on the same undirected pair (default: both)")
    add("--no-weights", action="store_true",
        help="Ignore multiplicity/count weights when loading SNAP signed TSV")
    add("--keep-self-loops", action="store_true",
        help="Keep self-loops in SNAP signed TSV (default: drop)")
    add("--min-degree", type=int, default=None,
        help="Remove nodes with degree < this threshold after loading SNAP graph")
    add("--degree-mode", choices=["union", "pos", "neg"], default="union",
        help="Degree basis for SNAP node filtering (default: union)")

    # --- MAT inputs --------------------------------------------------------
    add("--mat", type=str, help="MAT file path")
    add("--pos-key", type=str, default="pos",
        help="MAT key for positive matrix (default: pos)")
    add("--neg-key", type=str, default="neg",
        help="MAT key for negative matrix (default: neg)")

    # --- Method ------------------------------------------------------------
    add("--method",
        choices=["louvain", "leiden-mod-alpha", "leiden-mod-lambda", "leiden-cpm-single"],
        required=True, help="Community detection method")

    # --- Algorithm parameters ----------------------------------------------
    add("--seed", type=int, default=None, help="Random seed")
    add("--resolution", type=float, default=1.0,
        help="Resolution parameter (default: 1.0)")
    add("--alpha", type=float, default=0.5, help="Alpha in [0,1] (default: 0.5)")
    add("--lambda-neg", type=float, default=0.0,
        help="lambda_neg >= 0 (default: 0.0)")
    add("--gamma", type=float, default=0.5, help="CPM gamma (default: 0.5)")
    add("--neg-weight-mode", choices=["absolute", "signed"], default="absolute",
        help="Negative weight interpretation (default: absolute)")

    # --- Output control ----------------------------------------------------
    add("--out-prefix", type=str, default=None,
        help="Output file prefix (e.g., results/run1). "
             "Generates: <prefix>_partition.tsv, <prefix>_summary.txt, "
             "<prefix>_communities.gmt, <prefix>_community_sizes.pdf, "
             "<prefix>_inter_community.pdf, <prefix>_module_<id>.pdf")
    add("--no-plot", action="store_true",
        help="Disable all plot outputs (PDF)")
    add("--no-gmt", action="store_true",
        help="Disable GMT output")
    add("--top-modules", type=int, default=10,
        help="Number of top modules to plot individually (default: 10)")
    add("--max-nodes-per-module", type=int, default=200,
        help="Max nodes to draw per module plot (default: 200)")
    add("--quiet", action="store_true", help="Suppress progress messages")

    add("-v", "--version", action="version", version=f"%(prog)s {__version__}")
    return parser


def validate_args(args):
    """Check parsed CLI arguments for consistency before any data is loaded.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Reads ``format`` and ``method`` and,
        depending on their values, ``pos``, ``neg``, ``snap``, ``mat``,
        ``alpha`` and ``lambda_neg``.

    Raises
    ------
    ValueError
        If the input path option(s) required by ``--format`` are missing,
        if ``louvain`` is combined with a non-TSV format, or if ``--alpha`` /
        ``--lambda-neg`` is outside its allowed range (NaN included).
    ImportError
        If ``louvain`` is requested but ``LouvainSigned`` is ``None``.
    """
    # --- Input format: each format needs its own path option(s). ---------
    if args.format == "tsv":
        if not args.pos or not args.neg:
            raise ValueError("--pos and --neg are required when --format tsv")
    elif args.format == "snap":
        if not args.snap:
            raise ValueError("--snap is required when --format snap")
    elif args.format == "mat":
        if not args.mat:
            raise ValueError("--mat is required when --format mat")

    # --- Method prerequisites --------------------------------------------
    if args.method == "louvain":
        if args.format != "tsv":
            raise ValueError("louvain requires --format tsv")
        if LouvainSigned is None:
            raise ImportError("LouvainSigned is not available.")

    # --- Method parameters -----------------------------------------------
    # Both range checks are written as a negated "in range" test so that NaN
    # (which argparse's type=float happily parses) is rejected as well:
    # every comparison with NaN is False.
    if args.method in ("louvain", "leiden-mod-alpha"):
        if not (0.0 <= args.alpha <= 1.0):
            raise ValueError("--alpha must be in [0,1]")
    elif args.method in ("leiden-mod-lambda", "leiden-cpm-single"):
        if not (args.lambda_neg >= 0):
            raise ValueError("--lambda-neg must be >= 0")


def _log(msg, quiet=False):
    if not quiet:
        print(f"[signet] {msg}")


def main(argv=None):
    """Run the signet command-line workflow end to end.

    Parses and validates the arguments, loads the positive and negative
    graphs, runs the selected community detection method and then either
    prints a one-line result summary (when ``--out-prefix`` is not given) or
    writes the partition, summary, GMT and plot files under the prefix.

    Parameters
    ----------
    argv : list of str, optional
        Argument vector handed to the parser. ``None`` lets argparse read
        the process command line.

    Returns
    -------
    int
        ``0`` on success.

    Raises
    ------
    ValueError
        If argument validation fails or the method name is not recognised.
    ImportError
        Propagated from ``validate_args`` when ``louvain`` is requested but
        unavailable.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    validate_args(args)

    # =====================================================================
    # Load graphs
    # =====================================================================
    G_pos, G_neg, meta = load_input_graphs(args)

    if not args.quiet:
        # The graphs expose either a vcount()/ecount() interface or a
        # number_of_nodes()/number_of_edges() interface; report with
        # whichever one is present.
        if hasattr(G_pos, "vcount"):
            print(f"G_pos: v={G_pos.vcount()} e={G_pos.ecount()}")
            print(f"G_neg: v={G_neg.vcount()} e={G_neg.ecount()}")
        else:
            print(f"G_pos: nodes={G_pos.number_of_nodes()} edges={G_pos.number_of_edges()}")
            print(f"G_neg: nodes={G_neg.number_of_nodes()} edges={G_neg.number_of_edges()}")
        print(f"method={args.method} format={args.format}")

    # =====================================================================
    # Run community detection
    # =====================================================================
    is_louvain = (args.method == "louvain")

    if is_louvain:
        part = run_louvain_signed(G_pos, G_neg, args)
    elif args.method == "leiden-mod-alpha":
        part = run_leiden_modularity_alpha(G_pos, G_neg, args)
    elif args.method == "leiden-mod-lambda":
        part = run_leiden_modularity_lambda(G_pos, G_neg, args)
    elif args.method == "leiden-cpm-single":
        part = run_leiden_cpm_single_signed(G_pos, G_neg, args)
    else:
        raise ValueError(f"Unknown method: {args.method}")

    # =====================================================================
    # Extract results
    # =====================================================================
    gene_ids, membership, gene_names = extract_gene_ids_and_membership(
        part, G_pos, is_louvain=is_louvain,
    )

    # =====================================================================
    # Write outputs
    # =====================================================================
    if args.out_prefix is None:
        # No files requested: print a one-line summary instead.
        # default=0 keeps an empty partition from crashing max()/min().
        sizes = Counter(membership).values()
        print(f"[RESULT] nodes={len(membership)} communities={len(sizes)} "
              f"largest={max(sizes, default=0)} smallest={min(sizes, default=0)}")
        return 0

    prefix = Path(args.out_prefix)
    prefix.parent.mkdir(parents=True, exist_ok=True)

    # 1. Partition TSV
    path = Path(f"{prefix}_partition.tsv")
    write_partition_tsv(path, gene_ids, membership, gene_names)
    _log(f"wrote {path}", args.quiet)

    # 2. Summary text
    path = Path(f"{prefix}_summary.txt")
    write_summary(path, args, G_pos, G_neg, gene_ids, membership)
    _log(f"wrote {path}", args.quiet)

    # 3. GMT
    if not args.no_gmt:
        path = Path(f"{prefix}_communities.gmt")
        write_gmt(path, gene_ids, membership, gene_names)
        _log(f"wrote {path}", args.quiet)

    # 4. Plots
    if not args.no_plot:
        # 4a. Community size distribution
        path = Path(f"{prefix}_community_sizes.pdf")
        plot_community_sizes(path, membership)
        _log(f"wrote {path}", args.quiet)

        # 4b. Inter-community heatmap
        # Best effort: a failing plot must not abort the run. The failure is
        # reported on stderr even under --quiet, because a missing output
        # file is not a mere progress message.
        path = Path(f"{prefix}_inter_community.pdf")
        try:
            plot_inter_community_heatmap(
                path, G_pos, G_neg, gene_ids, membership,
                top_n=min(args.top_modules * 3, 30),
            )
            _log(f"wrote {path}", args.quiet)
        except Exception as e:
            print(f"[signet] skipped inter-community heatmap: {e}", file=sys.stderr)

        # 4c. Individual module subnetwork plots (largest communities first)
        counts = Counter(membership)
        top_comms = [cid for cid, _ in counts.most_common(args.top_modules)]
        for community_id in top_comms:
            path = Path(f"{prefix}_module_{community_id}.pdf")
            try:
                plot_module_subnetwork(
                    path, G_pos, G_neg, gene_ids, membership,
                    community_id=community_id,
                    gene_names=gene_names,
                    max_nodes=args.max_nodes_per_module,
                )
                _log(f"wrote {path}", args.quiet)
            except Exception as e:
                # Same best-effort policy as the heatmap above.
                print(f"[signet] skipped module {community_id}: {e}", file=sys.stderr)

    _log("done.", args.quiet)
    return 0


# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())