#!/usr/bin/env python

import multiprocessing
from functools import partial
import logging
import os, sys, glob, datetime, time, gzip
import argparse
import collections
from math import log
import numpy as np
import requests
from tqdm import tqdm
sys.path.append(os.path.join(os.path.split(sys.argv[0])[0], '../'))
from scTE.miniglbase import genelist, glload, location

# Chromosome names (without the 'chr' prefix) kept when reading annotations:
# numbered chromosomes 1-49, X/Y/M, plus the split chromosome names used by the
# dm6 (2L/2R/3L/3R) and panTro6 (2A/2B) presets, which would otherwise be
# silently dropped from both the gene and TE annotations.
chr_list = [str(k) for k in range(1, 50)] + ['X', 'Y', 'M'] + ['2A', '2B', '2L', '2R', '3L', '3R']

def read_opts(parser):
    """Parse the command line with *parser* and validate the result.

    Exits with status 1 (after logging an error and printing the help text)
    when the counting mode or genome is unsupported, or when genome is
    'other' but -te and -gene were not both supplied.

    Returns:
        The parsed namespace with an extra ``info`` attribute bound to
        ``logging.info``.
    """
    args = parser.parse_args()

    supported_modes = ['inclusive', 'exclusive', 'nointron']
    supported_genomes = ['mm10', 'hg38', 'panTro6', 'macFas5', 'dm6', 'danRer11', 'xenTro9', 'other']

    if args.mode not in supported_modes:
        problem = "Counting mode %s not supported\n" % (args.mode)
    elif args.genome not in supported_genomes:
        problem = "Counting genome %s not supported\n" % (args.genome)
    elif args.genome == 'other' and not (args.tefile and args.genefile):
        problem = "When using -g other, both -te and -gene arguments are required\n"
    else:
        problem = None

    if problem is not None:
        logging.error(problem)
        parser.print_help()
        sys.exit(1)

    args.info = logging.info
    return args

def cleanexon(exons):
    """
    Merge overlapping/adjacent exon regions of each gene, per chromosome.

    Intervals are merged only when they lie on the same chromosome. Gene names
    can occur on several chromosomes (e.g. PAR genes on chrX and chrY, or
    multi-locus names), so comparing intervals across chromosomes would merge
    unrelated loci.

    Args:
        exons: Dictionary mapping gene names to lists of [chr, start, end] exon
            coordinates.

    Returns:
        List of dicts ``{'loc': location(...), 'annot': gene_name}``, one per
        merged interval, ordered by gene name, then chromosome, then start.
    """
    merged_exons = []

    for gene_name in sorted(exons):
        # Group by chromosome first, then by start, so the sweep below never
        # compares intervals from different chromosomes.
        sorted_exons = sorted(exons[gene_name], key=lambda x: (x[0], x[1]))
        if not sorted_exons:
            continue

        cur_chr, cur_start, cur_end = sorted_exons[0]
        for chrom, start, end in sorted_exons[1:]:
            # Overlapping or directly adjacent (end + 1) on the same chromosome.
            if chrom == cur_chr and start <= cur_end + 1:
                cur_end = max(cur_end, end)
            else:
                merged_exons.append({
                    'loc': location(chr=cur_chr, left=cur_start, right=cur_end),
                    'annot': gene_name
                })
                cur_chr, cur_start, cur_end = chrom, start, end

        # Flush the last open interval for this gene.
        merged_exons.append({
            'loc': location(chr=cur_chr, left=cur_start, right=cur_end),
            'annot': gene_name
        })

    return merged_exons

def readGtf(filename, chroms=None):
    """
    Read and parse a GTF file, extracting exon and UTR annotations.

    Args:
        filename: Path to the GTF file. Names containing '.gz' are read as gzip.
        chroms: Optional iterable of chromosome names (without the 'chr'
            prefix) to keep. Defaults to the module-level ``chr_list``.

    Returns:
        tuple: (raw_annotations, clean_annotations). Each is a dict mapping
        gene names to lists of [chromosome, start, end] coordinates. The
        chromosome always carries a 'chr' prefix. ``raw`` holds every exon/UTR
        record that has a gene_name. ``clean`` holds only the records whose line
        mentions 'protein_coding' or 'lincRNA'.
    """
    if chroms is None:
        chroms = chr_list
    keep_chroms = set(chroms)  # O(1) membership test per line

    raw = {}    # all exon/UTR annotations
    clean = {}  # protein-coding and lincRNA annotations only

    # Text mode lets gzip do the decoding. UTF-8 is a superset of ASCII, so
    # ASCII files read identically and stray non-ASCII attribute text no
    # longer aborts parsing.
    opener = gzip.open if '.gz' in filename else open
    with opener(filename, 'rt', encoding='utf-8') as f:
        for line in f:
            if line.startswith('#'):
                continue

            fields = line.strip().split('\t')
            # Blank or truncated lines: a GTF record has 9 tab-separated columns.
            if len(fields) < 9:
                continue

            if fields[2] not in ('exon', 'UTR'):
                continue

            # Normalise to a 'chr'-prefixed name, then filter on the bare name.
            chrom = fields[0]
            if 'chr' not in chrom:
                chrom = 'chr' + chrom
            if chrom.replace('chr', '') not in keep_chroms:
                continue

            start = int(fields[3])
            end = int(fields[4])

            # Only the quoted GTF form  gene_name "X";  can be parsed below.
            if 'gene_name "' not in fields[8]:
                continue
            gene_name = fields[8].split('gene_name "')[1].split('";')[0]

            raw.setdefault(gene_name, []).append([chrom, start, end])

            if 'protein_coding' in line or 'lincRNA' in line:
                clean.setdefault(gene_name, []).append([chrom, start, end])

    return raw, clean


def download_with_retry(url, max_retries=2):
    """
    Download a file from a URL into the current directory, with retries.

    Supports HTTP(S) via requests and FTP via ftplib. If a file of the same
    name already exists, its size is compared with the server-reported size:
    a match (or an unknown remote size) reuses the file, a mismatch triggers a
    fresh download. If every attempt fails, any partially written file is
    removed so a later run does not mistake it for a complete download.

    Args:
        url: The URL to download from (http://, https://, or ftp://).
        max_retries: Number of extra attempts after the first one fails.

    Returns:
        The downloaded filename (the last path component of ``url``).

    Raises:
        The exception from the final failed attempt.
    """
    import ftplib

    filename = url.split('/')[-1]
    is_ftp = url.startswith('ftp://')

    def _ftp_host_path():
        # 'ftp://host/a/b' -> ('host', '/a/b')
        parts = url.replace('ftp://', '').split('/')
        return parts[0], '/' + '/'.join(parts[1:])

    # Determine the expected file size; 0 means "unknown".
    expected_size = 0
    if is_ftp:
        try:
            host, path = _ftp_host_path()
            # Context manager guarantees the control connection is closed.
            with ftplib.FTP(host, timeout=10) as ftp:
                ftp.login()
                # Many servers refuse SIZE in ASCII mode; switch to binary first.
                ftp.voidcmd('TYPE I')
                # size() may return None on an unexpected reply.
                expected_size = ftp.size(path) or 0
        except Exception:
            expected_size = 0
    else:
        try:
            head_resp = requests.head(url, allow_redirects=True, timeout=10)
            head_resp.raise_for_status()
            expected_size = int(head_resp.headers.get('content-length', 0))
        except Exception:
            expected_size = 0

    # Verify an existing file.
    if os.path.exists(filename):
        actual_size = os.path.getsize(filename)
        if expected_size > 0 and actual_size == expected_size:
            logging.warning("%s already exists (%s bytes), skip download" % (filename, f"{actual_size:,}"))
            return filename
        elif expected_size > 0:
            logging.warning("%s size mismatch (expected %s, got %s), re-downloading" % (filename, f"{expected_size:,}", f"{actual_size:,}"))
        else:
            logging.warning("%s already exists, skip download (size not verified)" % filename)
            return filename

    for attempt in range(max_retries + 1):
        try:
            if is_ftp:
                host, path = _ftp_host_path()
                with ftplib.FTP(host, timeout=30) as ftp:
                    ftp.login()
                    with open(filename, 'wb') as f, tqdm(
                        desc=filename,
                        total=expected_size or None,  # None -> tqdm shows an open-ended bar
                        unit='B',
                        unit_scale=True,
                        unit_divisor=1024,
                    ) as progress_bar:
                        def callback(data):
                            f.write(data)
                            progress_bar.update(len(data))
                        ftp.retrbinary('RETR ' + path, callback)
            else:
                # Using the response as a context manager releases the
                # streamed connection even if writing fails part-way.
                with requests.get(url, stream=True, timeout=30) as response:
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))
                    with open(filename, 'wb') as f, tqdm(
                        desc=filename,
                        total=total_size or None,
                        unit='B',
                        unit_scale=True,
                        unit_divisor=1024,
                    ) as progress_bar:
                        for data in response.iter_content(chunk_size=1048576):
                            progress_bar.update(f.write(data))
            return filename
        except Exception as e:
            if attempt == max_retries:
                # Drop any partial file: with an unknown remote size the
                # existence check above would otherwise accept it next run.
                try:
                    os.remove(filename)
                except OSError:
                    pass
                raise
            logging.warning("Download failed, retry %s/%s: %s" % (attempt+1, max_retries, str(e)))


def _build_gene_buckets(merged_clean, chr_set):
    """Build chromosome-keyed bucket index from merged clean exons.

    Returns: dict of chr -> {bucket_left: [[left, rite], ...]}
    """
    gene = {}
    for l in merged_clean:
        chr = l['loc'].loc['chr']
        if chr not in chr_set:
            continue
        left = l['loc']['left']
        rite = l['loc']['right']

        left_buck = ((left-1)//10000) * 10000
        right_buck = (rite//10000) * 10000
        buckets_reqd = range(left_buck, right_buck+10000, 10000)

        if chr not in gene:
            gene[chr] = {}

        if buckets_reqd:
            for buck in buckets_reqd:
                if buck not in gene[chr]:
                    gene[chr][buck] = []
                gene[chr][buck].append([left, rite])
    return gene


def _te_overlaps_gene(left, rite, chr_gene_buckets):
    """Check if a TE interval overlaps any gene exon in the bucket index."""
    left_buck = ((left-1)//10000) * 10000
    right_buck = (rite//10000) * 10000
    buckets_reqd = range(left_buck, right_buck+10000, 10000)

    for buck in buckets_reqd:
        if buck not in chr_gene_buckets:
            continue
        for k in chr_gene_buckets[buck]:
            if left < k[1] and rite > k[0]:
                return True
    return False


def genomeIndex(genome, mode, tefile, genefile, outname, geneurls, teurls):
    """Build and save the scTE annotation index for one genome and mode.

    Args:
        genome: Genome label; used as the output prefix when ``outname`` is empty.
        mode: One of 'exclusive', 'inclusive' or 'nointron'.
        tefile: Optional custom TE BED file (plain or .gz). When empty, the TE
            table is downloaded from ``teurls``.
        genefile: Optional custom gene GTF file (plain or .gz). When empty, the
            GTF is downloaded from ``geneurls``.
        outname: Optional output prefix.
        geneurls: URL of the gene GTF (used only when ``genefile`` is empty).
        teurls: URL of the TE table (used only when ``tefile`` is empty).

    Files downloaded here are deleted after the index has been written.

    Raises:
        ValueError: if ``mode`` is not a supported counting mode.
    """
    supported_modes = ('exclusive', 'inclusive', 'nointron')
    if mode not in supported_modes:
        # Fail before any download rather than silently building nothing.
        raise ValueError("Unsupported mode %r; expected one of %s" % (mode, ', '.join(supported_modes)))

    genefilename = genefile if genefile else download_with_retry(geneurls)

    gtf_data = readGtf(genefilename)
    raw = cleanexon(gtf_data[0])
    clean = cleanexon(gtf_data[1])
    # nointron mode needs the unmerged per-gene exon lists from the GTF.
    raw_gtf = gtf_data[0]
    del gtf_data

    # Chromosomes allowed in the index; extended with any custom TE chromosomes.
    active_chr_set = set(chr_list)
    if tefile:
        tefilename = tefile
        # Store names without the 'chr' prefix: the TE processing compares
        # t[0].replace('chr', '') against this set, so a prefixed name added
        # here could never match and custom chromosomes would be dropped.
        is_gz = '.gz' in tefilename
        with (gzip.open(tefilename, 'rt', encoding='utf-8') if is_gz
              else open(tefilename, 'r', encoding='utf-8')) as fh:
            for line in fh:
                chrom = line.strip().split('\t')[0]
                if chrom:
                    active_chr_set.add(chrom.replace('chr', ''))
    else:
        tefilename = download_with_retry(teurls)

    if mode == 'exclusive':
        _build_exclusive_index(raw, clean, tefile, tefilename, active_chr_set, genome, outname)
    elif mode == 'inclusive':
        _build_inclusive_index(raw, tefile, tefilename, active_chr_set, genome, outname)
    else:  # 'nointron' (validated above)
        _build_nointron_index(raw, raw_gtf, clean, tefile, tefilename, active_chr_set, genome, outname)
    del raw, clean, raw_gtf

    # Remove only files this function downloaded; user-supplied files are kept.
    downloaded = []
    if not tefile:
        downloaded.append(tefilename)
    if not genefile:
        downloaded.append(genefilename)
    for path in downloaded:
        try:
            os.remove(path)
        except OSError as err:
            logging.warning("Could not remove downloaded file %s: %s", path, err)


def _build_exclusive_index(raw, clean, tefile, tefilename, active_chr_set, genome, outname):
    """Write the exclusive-mode index: all gene exons plus non-overlapping TEs.

    ``clean`` (merged protein-coding/lincRNA exons) is turned into a bucket
    index that is used only to discard TEs that overlap those exons. ``raw``
    (merged exons of all genes) becomes the backing list of the index, and
    surviving TEs are appended to it. The result is saved as
    '<outname or genome>.exclusive.idx'.
    """
    logging.info("Building gene bucket index for overlap detection...")
    gene = _build_gene_buckets(clean, active_chr_set)
    del clean  # only the bucket index is needed from here on

    # Take ownership of the raw exon list; TE items are appended to it in place.
    all_items = raw
    del raw

    logging.info("Processing TE annotations (chr-by-chr streaming)...")
    with (gzip.open(tefilename, 'rb') if '.gz' in tefilename
          else open(tefilename, 'r')) as handle:
        if tefile:
            # Custom TE file: may be unsorted, so filter in one genome-wide pass.
            _process_tes_legacy(all_items, handle, tefilename, tefile, gene, active_chr_set)
        else:
            # Downloaded rmsk table: chromosome-sorted, streamed chr by chr.
            _process_tes_streaming(all_items, handle, tefilename, gene, active_chr_set)
    del gene

    logging.info("Building final genelist index (compact mode)...")
    all_annot = genelist()
    all_annot.linearData = all_items  # hand over the list itself, no copy
    del all_items
    all_annot._optimiseData_compact()  # columnar + buckets only

    prefix = outname if outname else genome
    all_annot.save('%s.exclusive.idx' % prefix)
    logging.info('Done the index building, results output to %s.exclusive.idx' % prefix)


def _process_tes_streaming(all_items, file_handle, tefilename, gene, active_chr_set):
    """Stream chr-sorted rmsk TEs, filtering against gene buckets chr-by-chr.

    Flushes each chromosome's non-overlapping TEs to all_items as chr
    boundaries are crossed, keeping per-chr memory low.
    """
    current_chr = None
    chr_noverlap = []

    valid_classes = {'DNA', 'LINE', 'LTR', 'SINE', 'Satellite', 'Retroposon'}

    for line in file_handle:
        if '.gz' in tefilename:
            line = line.decode('ascii')
        t = line.strip().split('\t')

        chr = t[5].replace('chr', '')
        left = int(t[6])
        rite = int(t[7])
        name = t[10]
        clas = t[11]

        # Class filter
        if clas not in valid_classes:
            continue

        # Chr filter
        if chr not in active_chr_set:
            continue

        # Chr boundary detection: flush previous chr
        if chr != current_chr:
            if chr_noverlap:
                all_items.extend(chr_noverlap)
                chr_noverlap.clear()
            current_chr = chr

        # No genes on this chr → all TEs pass
        if chr not in gene:
            chr_noverlap.append({
                'loc': location(chr=chr, left=left, right=rite),
                'annot': name
            })
            continue

        # Overlap check
        if not _te_overlaps_gene(left, rite, gene[chr]):
            chr_noverlap.append({
                'loc': location(chr=chr, left=left, right=rite),
                'annot': name
            })

    # Flush final chromosome
    if chr_noverlap:
        all_items.extend(chr_noverlap)


def _process_tes_legacy(all_items, file_handle, tefilename, tefile, gene, active_chr_set):
    """Legacy TE processing for custom (unsorted) TE files.

    Builds noverlap genome-wide, then extends all_items.
    """
    noverlap = []

    for line in file_handle:
        if '.gz' in tefilename:
            line = line.decode('ascii')
        t = line.strip().split('\t')

        if not tefile:
            chr = t[5].replace('chr', '')
            left = int(t[6])
            rite = int(t[7])
            name = t[10]
            clas = t[11]
            valid_classes = {'DNA', 'LINE', 'LTR', 'SINE', 'Satellite', 'Retroposon'}
            if clas not in valid_classes:
                continue
        else:
            chr = t[0].replace('chr', '')
            left = int(t[1])
            rite = int(t[2])
            name = t[3]

        if chr not in active_chr_set:
            continue
        if chr not in gene:
            noverlap.append({
                'loc': location(chr=chr, left=left, right=rite),
                'annot': name
            })
            continue

        if not _te_overlaps_gene(left, rite, gene[chr]):
            noverlap.append({
                'loc': location(chr=chr, left=left, right=rite),
                'annot': name
            })

    all_items.extend(noverlap)


def _build_inclusive_index(raw, tefile, tefilename, active_chr_set, genome, outname):
    """Build the inclusive-mode index: all gene exons plus all TEs (no overlap filter).

    Args:
        raw: Merged exon items (dicts with 'loc' and 'annot') for all genes.
        tefile: Custom TE BED path, or a falsy value when the UCSC rmsk table is used.
        tefilename: Path actually read (equal to ``tefile`` for custom input).
        active_chr_set: Chromosome names (without 'chr') kept from the rmsk table.
        genome: Genome label; used as the output prefix when ``outname`` is empty.
        outname: Optional output prefix.

    Writes '<outname or genome>.inclusive.idx'.
    """
    genes = genelist()
    genes.load_list(raw, copy=False)

    # Both the rmsk table and a custom BED may be gzip-compressed. The custom
    # branch previously ignored this and failed on .gz input, even though the
    # -te help text advertises .gz support.
    gz_kwargs = {'gzip': True} if tefilename.endswith('.gz') else {}

    if not tefile:
        teform = {'force_tsv': True, 'loc': 'location(chr=column[5], left=column[6], right=column[7])',
                  'annot': 10, 'clas': 11}
        TEs = genelist(tefilename, format=teform, **gz_kwargs)

        valid_classes = {'DNA', 'LINE', 'LTR', 'SINE', 'Satellite', 'Retroposon'}
        keep = []
        for item in TEs:
            if item['clas'] not in valid_classes:
                continue
            if item['loc']['chr'] not in active_chr_set:
                continue
            tmp = item.copy()
            del tmp['clas']  # the class column is only needed for filtering
            keep.append(tmp)
        del TEs
        gls = genelist()
        gls.load_list(keep, copy=False)
    else:
        # The custom TE genelist is read once below and then discarded, so
        # the former deepcopy only doubled peak memory.
        gls = genelist(tefilename, format={
            'force_tsv': True,
            'loc': 'location(chr=column[0], left=column[1], right=column[2])',
            'annot': 3
        }, **gz_kwargs)

    # Merge by list concatenation rather than the genelist '+' operator,
    # which would deep-copy both inputs.
    all_items = list(genes.linearData)
    all_items.extend(gls.linearData)
    del genes, gls

    all_annot = genelist()
    all_annot.linearData = all_items  # direct assignment, no copy
    del all_items
    all_annot._optimiseData_compact()  # columnar + buckets only

    prefix = outname if outname else genome
    all_annot.save('%s.inclusive.idx' % prefix)
    logging.info('Done the index building, results output to %s.inclusive.idx' % prefix)


def _gene_spans_by_chr(raw_gtf):
    """Collapse each gene's exons into one [chr, min_start, max_end] span per chromosome."""
    spans = {}
    for gene_name, exons in raw_gtf.items():
        per_chr = {}
        for chrom, start, end in exons:
            if chrom in per_chr:
                lo, hi = per_chr[chrom]
                per_chr[chrom] = (min(lo, start), max(hi, end))
            else:
                per_chr[chrom] = (start, end)
        spans[gene_name] = [[chrom, lo, hi] for chrom, (lo, hi) in per_chr.items()]
    return spans


def _build_nointron_index(raw, raw_gtf, clean_merged, tefile, tefilename, active_chr_set, genome, outname):
    """Build the nointron-mode index.

    Every gene's full extent (first exon start to last exon end, i.e. exons
    plus introns) is used as an overlap mask, and TEs hitting the mask are
    discarded. The extent is computed separately for each chromosome a gene
    name occurs on. Gene names can map to several chromosomes (e.g. PAR genes,
    multi-locus names), and one span across chromosomes would mask huge
    unrelated regions. The index stores ``raw`` (merged exons of all genes)
    plus the surviving TEs and is saved as '<outname or genome>.nointron.idx'.

    Args:
        raw: Merged exon items for all genes; becomes the index's item list.
        raw_gtf: Unmerged gene_name -> [[chr, start, end], ...] dict from the GTF.
        clean_merged: Unused here; accepted for call compatibility.
        tefile, tefilename, active_chr_set: As for the other index builders.
        genome, outname: Output prefix (``outname`` wins when non-empty).
    """
    gene_spans = _gene_spans_by_chr(raw_gtf)
    del raw_gtf

    # One mask interval per (gene, chromosome). No merging is needed, because
    # the buckets only answer "does a TE hit any interval".
    span_items = [
        {'loc': location(chr=chrom, left=lo, right=hi), 'annot': gene_name}
        for gene_name in sorted(gene_spans)
        for chrom, lo, hi in gene_spans[gene_name]
    ]
    del gene_spans

    gene = _build_gene_buckets(span_items, active_chr_set)
    del span_items

    # Start the index with merged raw exon regions; TEs are appended in place.
    all_items = raw
    del raw, clean_merged

    with (gzip.open(tefilename, 'rb') if '.gz' in tefilename
          else open(tefilename, 'r')) as handle:
        if tefile:
            _process_tes_legacy(all_items, handle, tefilename, tefile, gene, active_chr_set)
        else:
            _process_tes_streaming(all_items, handle, tefilename, gene, active_chr_set)
    del gene

    all_annot = genelist()
    all_annot.linearData = all_items  # direct assignment, no copy
    del all_items
    all_annot._optimiseData_compact()  # columnar + buckets only

    prefix = outname if outname else genome
    all_annot.save('%s.nointron.idx' % prefix)
    logging.info('Done the index building, results output to %s.nointron.idx' % prefix)


def prepare_parser():
    """Create the scTE_build command-line parser.

    Options: -te/-gene (custom annotation files), -m/--mode (counting mode,
    default 'exclusive'), -o/--out (output prefix) and -g/--genome (preset
    genome, default 'other').

    Returns:
        argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        prog='scTE_build',
        description="Build genome annotation index for scTE",
        epilog="Example: scTE_build -te Data/TE.bed -gene Data/Gene.gtf",
    )

    custom_files = parser.add_argument_group('required arguments for custom genome')
    custom_files.add_argument(
        '-te', dest='tefile',
        help='Six columns bed file for transposable elements annotation. Support .gz format.',
    )
    custom_files.add_argument(
        '-gene', dest='genefile',
        help='Gtf file for genes annotation. Support .gz format.',
    )

    mode_help = ('How to count TEs expression: inclusive (include all reads that can map to TEs), '
                 'exclusive (exclude reads mapping to protein coding genes and lncRNAs exons), '
                 'or nointron (exclude reads mapping to genes exons and introns). '
                 'DEFAULT: exclusive')
    parser.add_argument(
        '-m', '--mode', dest='mode', type=str, default='exclusive',
        choices=['inclusive', 'exclusive', 'nointron'], help=mode_help,
    )

    parser.add_argument('-o', '--out', dest='out',
                        help='Output file prefix. Default: the genome name')

    genome_help = ('Genome preset to use. When set to "other", -te and -gene arguments are required. '
                   'Default: other')
    parser.add_argument(
        '-g', '--genome', dest='genome', type=str, default='other',
        choices=['other', 'mm10', 'hg38', 'panTro6', 'macFas5', 'dm6', 'danRer11', 'xenTro9'],
        help=genome_help,
    )

    return parser

# Define genome configurations as a dictionary
# Preset genome -> {'gene_url': URL of the gene annotation GTF}.
# mm10/hg38 use GENCODE (FTP); the remaining presets use Ensembl (HTTP).
GENOME_CONFIGS = {
    genome_name: {'gene_url': gene_url}
    for genome_name, gene_url in (
        ('mm10', 'ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M21/gencode.vM21.annotation.gtf.gz'),
        ('hg38', 'ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_30/gencode.v30.annotation.gtf.gz'),
        ('panTro6', 'http://ftp.ensembl.org/pub/release-103/gtf/pan_troglodytes/Pan_troglodytes.Pan_tro_3.0.103.gtf.gz'),
        ('macFas5', 'http://ftp.ensembl.org/pub/release-102/gtf/macaca_fascicularis/Macaca_fascicularis.Macaca_fascicularis_5.0.102.gtf.gz'),
        ('dm6', 'http://ftp.ensembl.org/pub/release-103/gtf/drosophila_melanogaster/Drosophila_melanogaster.BDGP6.32.103.gtf.gz'),
        ('danRer11', 'http://ftp.ensembl.org/pub/release-103/gtf/danio_rerio/Danio_rerio.GRCz11.103.gtf.gz'),
        ('xenTro9', 'http://ftp.ensembl.org/pub/release-103/gtf/xenopus_tropicalis/Xenopus_tropicalis.Xenopus_tropicalis_v9.1.103.gtf.gz'),
    )
}

def main():
    """Command-line entry point: parse options, log a summary, build the index."""
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if sys.version_info < (3, 6):
        sys.exit('Python >=3.6 is required')

    # Without a configured handler the root logger drops INFO records, which
    # would hide the parameter summary and progress messages below.
    # basicConfig() is a no-op if logging was already configured elsewhere.
    logging.basicConfig(level=logging.INFO, format='%(levelname)-8s: %(message)s')

    args = read_opts(prepare_parser())

    # Print a readable parameter summary
    logging.info("=" * 50)
    logging.info("scTE_build parameter summary:")
    logging.info("  Genome:            %s" % args.genome)
    logging.info("  Mode:              %s" % args.mode)
    logging.info("  Output prefix:     %s" % (args.out or '(default: genome name)'))
    logging.info("  TE annotation:     %s" % (args.tefile or '(download from UCSC)'))
    logging.info("  Gene annotation:   %s" % (args.genefile or '(download from GENCODE)'))
    logging.info("=" * 50)

    info = args.info

    info("Building the scTE genome annotation index... %s" % (datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

    if args.genome == 'other':
        # Both files are required for 'other' (enforced by read_opts), so the
        # URL arguments are never used.
        genomeIndex(args.genome, args.mode, args.tefile, args.genefile, args.out, 'No path', 'No path')
    else:
        config = GENOME_CONFIGS[args.genome]
        te_url = f'http://hgdownload.soe.ucsc.edu/goldenPath/{args.genome}/database/rmsk.txt.gz'
        genomeIndex(args.genome, args.mode, args.tefile, args.genefile, args.out,
                    config['gene_url'], te_url)

    info("Done genome annotation index building... %s"%(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt !\n")
        # 130 = 128 + SIGINT: the conventional status for a Ctrl-C abort, so
        # shells and workflow managers do not treat an interrupted build as
        # a success.
        sys.exit(130)
