#!python
desc = '''

The scATAC-seq data comes as three files, P1, P2 and the barcode, and there is no UMI

You can just align P1 and P2 with your favourite aligner (we prefer STAR with these settings):

****
teopts=' --outFilterMultimapNmax 100 --winAnchorMultimapNmax 100 --outSAMmultNmax 1 --outSAMtype BAM SortedByCoordinate --twopassMode Basic --outWigType wiggle --outWigNorm RPM'
opts='--runRNGseed 42 --runThreadN 12 --readFilesCommand zcat '

genome_mm10='--genomeDir mm10_gencode_vM21_starsolo/SAindex'
genome_hg38='--genomeDir hg38_gencode_v30_starsolo/SAindex'

# p1 = read
# p2 = barcode and UMI
# Make sure you set the correct genome index;
STAR $opts $teopts $genome_hg38 --outFileNamePrefix ss.${out} --readFilesIn ${p1} ${p2}
****

This script will then reprocess the BAM file, and put the BARCODE into CR SAM tag and spoof a UMI

The UMI is generated by incrementing the sequence, so there are up to 4^14 (about 268 million) distinct UMIs.
I guess there remains a chance of a clash, but it should be so rare
as to be basically impossible.

Keep in mind though that downstream UMI statistics are inaccurate

Requires pysam

'''
import sys, os , time
import gzip
import argparse
import logging
try:
    import pysam
except ImportError:
    print('pack_scatacseq requires pysam')
    sys.exit(1)

sys.path.append(os.path.join(os.path.split(sys.argv[0])[0], '../'))
# from scTE.scatacseq import build_barcode_dict, parse_bam, load_expected_whitelist
from scTE.scatacseq import atacBam2bed,para_atacBam2bed
from scTE.base import *

# Command-line options;
def prepare_parser():
    """Build the argparse parser for the scATAC-seq entry point.

    The parser lists the required arguments (``-i``, ``-o``, ``-x``) before
    the optional ones in ``--help`` output. The True/False style options
    (``-CB``, ``-UMI``, ``--ignoreDuplicates``, ``--keeptmp``, ``--hdf5``)
    are kept as the strings ``'True'``/``'False'``; giving one of them
    without a value is interpreted as ``'True'``.

    As a side effect this configures the root logger (DEBUG level) and
    attaches a ``scTE_scatacseq`` logger to the returned parser as
    ``parser.log``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # Only options that the parser really accepts appear in the example.
    exmp = 'scTEATAC -i input.bam -o out -x mm10.te.idx'

    description = 'Package the BAM and BARCODE for the scATAC-seq data to make it suitable for scTE main pipeline'

    parser = argparse.ArgumentParser(prog='scTE_scatacseq', description=description, epilog=exmp)

    # Detach the default optional group so that the "required arguments"
    # group is printed first; it is re-attached at the end.
    optional = parser._action_groups.pop()

    def add_bool_option(*flags, dest, default, help):
        # String-valued True/False switch. const='True' makes a bare flag
        # (no value given) mean True instead of silently becoming None.
        optional.add_argument(*flags, dest=dest, type=str, nargs='?', const='True',
                              default=default, choices=['True', 'False'], help=help)

    optional.add_argument('--ondisk', action='store_true',
                        help='Do everything in memory (faster, but you will need a lot!), or do it on disk (slower, but no memory requirement)')

    optional.add_argument('--min_counts', dest='countnumber', metavar='INT', type=int, default=1000,
                        help='Minimum number of counts required for a cell to pass filtering. Default: 1000')

    add_bool_option('-CB', dest='CB', default='False',
                    help='Set to false to ignore for cell barcodes, Default: False')

    add_bool_option('-UMI', dest='UMI', default='False',
                    help='Set to false to ignore for UMI. Default: False')

    add_bool_option('--ignoreDuplicates', dest='noDup', default='True',
                    help='If set, reads that have the same orientation and start position will be considered only once. If reads are paired, the mate’s position also has to coincide to ignore a read. Default: True')

    add_bool_option('--keeptmp', dest='keeptmp', default='False',
                    help='Keep the _scTEtmp file, which is useful for debugging. Default: False')

    optional.add_argument('-p', '--thread', metavar='INT', dest='thread', type=int, default=1,
                        help='Number of threads to use, Default: 1')

    add_bool_option('--hdf5', dest='hdf5', default='False',
                    help='Save the output as .h5ad formatted file instead of csv file. Default: False')

    required = parser.add_argument_group('required arguments')

    required.add_argument('-i', '--input', dest='input', type=str, nargs='+', required=True,
                        help='Input file: BAM/SAM file')

    required.add_argument('-o', '--out', dest='out', nargs='?', required=True, help='Output file prefix')

    required.add_argument('-x', dest='annoglb', nargs='+', required=True,
                    help='The filename of the indexed genome')

    parser._action_groups.append(optional)

    logging.basicConfig(level=logging.DEBUG,
                    format='%(levelname)-8s: %(message)s',
                    datefmt='%m-%d %H:%M')

    parser.log = logging.getLogger('scTE_scatacseq')

    return parser

def main():
    """Run the scATAC-seq pipeline from the command line.

    Steps, in order:
      1. parse the command line and convert the 'True'/'False' string
         options to booleans;
      2. load the annotation index with ``Readanno`` and drop ``chrM``;
      3. convert the input BAM/SAM file(s) to bed with ``atacBam2bed``
         (one input) or ``para_atacBam2bed`` (several inputs, merged into
         ``<out>_scTEtmp/o1``);
      4. split per chromosome and build the barcode whitelist;
      5. run ``align`` for every chromosome and merge ``o3`` into ``o4``;
      6. call ``Countexpression`` to write the result;
      7. remove ``<out>_scTEtmp`` unless ``--keeptmp True`` was given.

    A single-process code path is used when ``--thread`` is 1, otherwise
    the work is distributed over ``multiprocessing`` pools.
    """
    # Imported here so that this function does not depend on these names
    # leaking in through ``from scTE.base import *``.
    import datetime
    import multiprocessing
    import shlex
    import shutil
    from functools import partial

    # An explicit check: ``assert`` would be stripped under ``python -O``.
    if sys.version_info < (3, 6):
        sys.exit('Python >=3.6 is required')

    def stamp():
        # Timestamp appended to the progress messages.
        return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    timestart = datetime.datetime.now()

    parser = prepare_parser()
    args = parser.parse_args()
    logger = parser.log
    info = logger.info

    def run_shell(cmd):
        # Run a shell pipeline and report (but do not abort on) a failure.
        status = os.system(cmd)
        if status != 0:
            logger.warning('Command returned non-zero status %s: %s', status, cmd)
        return status

    # The parser delivers these as the strings 'True'/'False' (or None).
    for flag in ('CB', 'hdf5', 'noDup', 'UMI'):
        setattr(args, flag, getattr(args, flag) == 'True')

    args.genenumber = 0
    args.cellnumber = 1e4

    info('Arguments:')
    info('out: %s' % args.out)
    info('index: %s \n' % args.annoglb[0])
    info("Minimum number of counts required = %s" % args.countnumber)
    info("Number of threads = %s " % args.thread)

    # Only the last path component of the prefix is used for file names.
    outname = args.out.split('/')[-1]
    tmpdir = '%s_scTEtmp' % outname
    # Quoted forms for the shell pipelines below; the glob part of each
    # command is deliberately left outside the quotes.
    q_tmpdir = shlex.quote(tmpdir)
    q_outname = shlex.quote(outname)

    info("Loading the genome annotation index... %s" % stamp())
    allelement, chr_list, all_annot, glannot = Readanno(filename=outname, annoglb=args.annoglb[0])
    chr_list = [k for k in chr_list if k not in ['chrM']]
    info("Finished loading the genome annotation index... %s \n" % stamp())

    info("Processing BAM/SAM files ...%s" % stamp())

    # A single comma-separated argument is also accepted as a list of inputs.
    if len(args.input) == 1 and ',' in args.input[0]:
        args.input = args.input[0].split(',')

    os.makedirs('%s/o1' % tmpdir, exist_ok=True)

    if len(args.input) > 1:
        info('Using para_atacBam2bed as more than 1 input BAM')
        partial_work = partial(para_atacBam2bed, CB=args.CB, out=outname, noDup=args.noDup)
        with multiprocessing.Pool(processes=args.thread) as pool:
            pool.map(partial_work, args.input)

        # Merge the per-input bed files into a single o1 file.
        run_shell('gunzip -c -f %s/o0/*.bed.gz | gzip > %s/o1/%s.bed.gz' % (q_tmpdir, q_tmpdir, q_outname))
    else:
        atacBam2bed(args.input[0], outname, CB=args.CB, UMI=args.UMI, noDup=args.noDup, num_threads=args.thread)
    info("Done BAM/SAM files processing ...%s \n" % stamp())

    info("Splitting ...%s" % stamp())
    if args.thread == 1:
        # Single thread path; this is useful for profiling, as under
        # multiprocessing the profile just gets locked up in
        # {method 'acquire' of '_thread.lock' objects}
        info('Executing single thread path')
        whitelist = splitAllChrs(chr_list, filename=outname, genenumber=args.genenumber, countnumber=args.countnumber, UMI=args.UMI)
    else:
        info('Executing multiple thread path with %s threads' % args.thread)
        partial_work = partial(splitChr, filename=outname, CB=args.CB, UMI=args.UMI)
        with multiprocessing.Pool(processes=args.thread) as pool:
            pool.map(partial_work, chr_list)
        whitelist = filterCRs(filename=outname, genenumber=args.genenumber, countnumber=args.countnumber)

    info("Finished processing sample files %s \n" % stamp())

    info("Fetching from the annotation index... %s" % stamp())
    if args.thread == 1:
        # Single thread path
        for chrom in chr_list:
            align(chr=chrom, filename=outname, all_annot=None, glannot=glannot, whitelist=whitelist)
    else:
        # Multiprocessing path: each worker is sent a copy of the index.
        partial_work = partial(align, filename=outname, all_annot=all_annot, glannot=None, whitelist=whitelist)
        with multiprocessing.Pool(processes=args.thread) as pool:
            pool.map(partial_work, chr_list)

    os.makedirs('%s/o4' % tmpdir, exist_ok=True)
    # Merge the per-chromosome o3 files into a single o4 file.
    run_shell('gunzip -c -f %s/o3/%s.*.bed.gz | gzip > %s/o4/%s.bed.gz' % (q_tmpdir, q_outname, q_tmpdir, q_outname))
    info("Done fetching... %s \n" % stamp())

    info("Calculating expression... %s" % stamp())
    len_res, genenumber, filename = Countexpression(filename=outname, allelement=allelement, genenumber=args.genenumber, cellnumber=args.cellnumber, hdf5=args.hdf5)
    info('Detect {0} cells expressed at least {1} genes, results output to {2}.csv'.format(len_res, genenumber, filename))
    info("Finished calculating expression %s" % stamp())

    if args.keeptmp != 'True':
        shutil.rmtree(tmpdir, ignore_errors=True)

    timeend = datetime.datetime.now()
    info("Done with %s\n" % timediff(timestart, timeend))

    if args.ondisk:
        # Cleanup the DB.
        # NOTE(review): nothing in this function creates an on-disk DB and
        # ``tmpfilename`` is not defined in this file, so the unconditional
        # os.remove() raised NameError. Only remove it if such a module-level
        # name exists (e.g. via the star import) -- confirm whether --ondisk
        # is still supported.
        db_path = globals().get('tmpfilename')
        if db_path and os.path.exists(db_path):
            os.remove(db_path)

# Script entry point: run the pipeline and turn Ctrl-C into a short message
# instead of a traceback.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt\n")
        # Exit non-zero (128 + SIGINT) so that calling shells and workflow
        # managers do not mistake an aborted run for a successful one.
        sys.exit(130)
