#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import time
from argparse import ArgumentParser

import rnalib

"""Script to call rnalib tools."""

# Banner text handed to ArgumentParser(description=...) in main(); it embeds
# the installed rnalib version. Built from adjacent literals so each output
# line is visible on its own source line.
description = (
    "\n"
    f"  Rnalib {rnalib.__version__} Copyright (C) 2024.  All rights reserved.\n\n"
    '  Distributed on an "AS IS" basis without warranties or conditions of any kind, either express or implied.\n'
)


def _build_parser():
    """Create the command line parser with one sub-parser per rnalib tool.

    The name of the selected sub-command is stored in ``args.mode``
    (``None`` if no sub-command was given).

    Returns:
        ArgumentParser: the fully configured top-level parser.
    """
    parser = ArgumentParser(description=description)
    sp = parser.add_subparsers(title='tools', description='', dest="mode")
    # ============================================================================
    p = sp.add_parser('create_testdata', help='Create test data directory')
    p.add_argument("-o", "--out_dir", type=str, required=True, help='Output directory.')
    p.add_argument("-a", "--all", action='store_true', required=False, default=False,
                   help='if set, both sets of resources are created (small + large)')
    # ============================================================================
    p = sp.add_parser('tag_tc', help='Annotate T-to-C reads',
                      epilog="Example: rnalib tag_tc -b <bam> -o ./test -c 'chr20' -q 20 -F 0 -Q 20")
    p.add_argument("-b", "--bam", type=str, required=True, dest="bam_file",
                   metavar="bam", help="BAM file")
    p.add_argument("-o", "--out_prefix", type=str, required=False, dest="out_prefix",
                   metavar="out_prefix",
                   help="output out_prefix (default is bam_filepath w/o extension)")
    # NOTE(review): store_true combined with default=True means this flag can
    # never switch the histogram off; kept as-is to preserve current behaviour.
    p.add_argument("-w", "--write_density_histogram", action="store_true", required=False,
                   default=True,
                   dest="write_density_histogram", help="write density histogram")
    # NOTE(review): the value is forwarded as a plain string although the help
    # text describes a dict - confirm what tag_tc() expects.
    p.add_argument("-t", "--tags", type=str, required=False, dest="tags",
                   metavar="tags", help="BAM tags to be used for the output BAM file. "
                                        "Default is {'ntc': 'xc', 'ntt': 'xt', 'col': 'YC'}")
    p.add_argument("-c", "--included_chrom", type=str, required=False,
                   dest="included_chrom", metavar="included_chrom",
                   help="chromosome(s) to be included. Default: all")
    p.add_argument("-s", "--snp_vcf", type=str, required=False, dest="snp_vcf_file",
                   metavar="snp_vcf_file", help="VCF file containing variant calls. "
                                                "T/C and A/G SNPs will be masked")
    p.add_argument("-f", "--fractional_counts", action="store_true", required=False,
                   dest="fractional_counts", help="fractional counts")
    p.add_argument("-q", "--min_mapping_quality", type=int, required=False,
                   dest="min_mapping_quality", default=0,
                   metavar="min_mapping_quality", help="minimum mapping quality")
    p.add_argument("-F", "--flag_filter", type=int, required=False, dest="flag_filter",
                   default=rnalib.DEFAULT_FLAG_FILTER,
                   metavar="flag_filter", help="flag filter for filtering reads. Default is "
                                               "rnalib.DEFAULT_FLAG_FILTER")
    p.add_argument("-Q", "--min_base_quality", type=int, required=False,
                   default=10, dest="min_base_quality", metavar="min_base_quality",
                   help="minimum base quality")
    p.add_argument("-r", "--reverse_strand", action="store_true", required=False,
                   default=False,
                   dest="reverse_strand", help="reverse read strand (default is False)")
    # ============================================================================
    p = sp.add_parser('filter_tc', help='Filter T-to-C reads',
                      epilog="Example: rnalib filter_tc -b <bam> -o test_tc2.bam -m 2 "
                             "(write only reads with at least 2 T/C conversions)")
    p.add_argument("-b", "--bam", type=str, required=True, dest="bam_file",
                   metavar="bam", help="BAM file")
    p.add_argument("-o", "--outfile", type=str, required=False, dest="outfile",
                   metavar="outfile", help="output file (default is {bam_file_name}_tc-only.bam)")
    # The help text promises a default of 1, so set it explicitly instead of
    # forwarding None to filter_tc().
    p.add_argument("-m", "--min_tc", type=int, required=False, dest="min_tc", default=1,
                   metavar="min_tc", help="minimum number of T/C conversions (default is 1)")
    # NOTE(review): forwarded as a plain string, see the tag_tc --tags note.
    p.add_argument("--tags", type=str, required=False, dest="tags", metavar="tags",
                   help="BAM tags to be used for the output BAM file. "
                        "Default is {'ntc': 'xc', 'ntt': 'xt', 'col': 'YC'}")
    # ============================================================================
    p = sp.add_parser('prune_tags', help='Remove TAGs from BAM file')
    p.add_argument("-b", "--bam_file", type=str, required=True, help='Input BAM file.')
    p.add_argument("-o", "--out_file", type=str, required=True, help='Output BAM file.')
    p.add_argument("-t", "--kept_tags", type=str, required=False, default=None,
                   help='Comma-separated list of tags that are kept (default: ("NH", "HI", "AS"))')
    # ============================================================================
    p = sp.add_parser('build_amplicon_resources', help='Build amplicon resource files')
    p.add_argument("-n", "--transcriptome_name", required=False,
                   default='transcriptome', dest="transcriptome_name",
                   help="Name of the transcriptome (used for naming files)")
    p.add_argument("-b", "--bed_file", required=True, dest="bed_file",
                   help="BED file with transcript coordinates")
    p.add_argument("-f", "--fasta_file", required=True, dest="fasta_file",
                   help="FASTA file (reference genome)")
    p.add_argument("-o", "--out_dir", required=True, dest="out_dir",
                   help="Output directory")
    # type=int so that user-supplied values have the same type as the defaults
    # (without it argparse hands over the raw command line string).
    p.add_argument("-p", "--padding", type=int, required=False, default=100,
                   dest="padding", help="Pad transcripts by this amount (default is 100)")
    p.add_argument("-e", "--amp_extension", type=int, required=False, default=100,
                   dest="amp_extension",
                   help="Extend transcripts by this amount (default is 100)")
    # ============================================================================
    p = sp.add_parser('calculate_mm_profile', help='Calculate a mismatch profile from a BAM file')
    p.add_argument("-b", "--bam_file", type=str, required=True, help='Input BAM file.')
    p.add_argument("-g", "--gff_file", type=str, required=True, help='Annotation GFF file.')
    p.add_argument("-f", "--fasta_file", type=str, required=True,
                   help='Reference sequence FASTA file.')
    p.add_argument("-a", "--annotation_flavour", type=str, required=False, default="gencode",
                   help='Annotation flavour (default is gencode).')
    p.add_argument("-mc", "--min_cov", type=int, required=False, default=10,
                   help='Minimum coverage (default is 10).')
    p.add_argument("-mf", "--max_mm_frac", type=float, required=False, default=0.1,
                   help='Maximum mismatch fraction (default is 0.1).')
    # argparse applies type= only to string defaults, so the default must already
    # be an int; the float literal 1e6 would otherwise be passed through as a float.
    p.add_argument("-ms", "--max_sample", type=int, required=False, default=1_000_000,
                   help='Maximum number of samples (default is 1e6).')
    p.add_argument("-s", "--strand_specific", action="store_true", required=False, default=False,
                   help='Strand specific (default is False).')
    p.add_argument("-o", "--out_file", type=str, required=False, help='Output TSV file.')
    return parser


def main():
    """Parse the command line and run the selected rnalib tool.

    The tool implementations are imported inside their branch, so only the
    module of the requested sub-command is loaded. When no sub-command is
    given, the help text is printed and the process exits with status 0.
    On success, the elapsed wall-clock time is printed.

    Raises:
        SystemExit: raised by argparse on invalid arguments, on ``-h`` and
            when no sub-command was selected.
    """
    parser = _build_parser()
    args = parser.parse_args()
    start_time = time.time()
    if args.mode == "tag_tc":
        from rnalib.tools import tag_tc
        tag_tc(args.bam_file,
               out_prefix=args.out_prefix,
               write_density_histogram=args.write_density_histogram,
               tags=args.tags,
               included_chrom=args.included_chrom,
               snp_vcf_file=args.snp_vcf_file,
               fractional_counts=args.fractional_counts,
               min_mapping_quality=args.min_mapping_quality,
               flag_filter=args.flag_filter,
               min_base_quality=args.min_base_quality,
               reverse_strand=args.reverse_strand)
    elif args.mode == "filter_tc":
        from rnalib.tools import filter_tc
        filter_tc(bam_file=args.bam_file, out_file=args.outfile,
                  min_tc=args.min_tc, tags=args.tags)
    elif args.mode == "build_amplicon_resources":
        from rnalib.tools import build_amplicon_resources
        build_amplicon_resources(
            transcriptome_name=args.transcriptome_name,
            bed_file=args.bed_file, fasta_file=args.fasta_file,
            out_dir=args.out_dir, padding=args.padding, amp_extension=args.amp_extension)
    elif args.mode == "create_testdata":
        from rnalib.testdata import create_testdata, test_resources, large_test_resources
        create_testdata(args.out_dir, test_resources, show_data_dir=True)
        if args.all:
            create_testdata(args.out_dir, large_test_resources, show_data_dir=True)
        print("Done. Add the following monkeypatch if you want to use get_resource(<name>) in your code:")
        print(f"import rnalib; rnalib.__RNALIB_TESTDATA__ = '{args.out_dir}'")
    elif args.mode == "prune_tags":
        from rnalib.tools import prune_tags
        # An omitted (or empty) --kept_tags falls back to the documented default tags.
        kept_tags = args.kept_tags.split(',') if args.kept_tags else ("NH", "HI", "AS")
        prune_tags(args.bam_file, args.out_file, kept_tags)
    elif args.mode == "calculate_mm_profile":
        from rnalib.tools import calculate_mm_profile
        calculate_mm_profile(bam_file=args.bam_file,
                             gff_file=args.gff_file,
                             fasta_file=args.fasta_file,
                             annotation_flavour=args.annotation_flavour,
                             min_cov=args.min_cov,
                             max_mm_frac=args.max_mm_frac,
                             max_sample=args.max_sample,
                             strand_specific=args.strand_specific,
                             out_file=args.out_file)
    else:
        # No sub-command selected: show the usage and stop with status 0
        # (same outcome as the former re-parse of "-h", without the bogus raise).
        parser.print_help()
        parser.exit()
    print(f"All done in {time.time() - start_time} sec.")


# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
