#!python

"""
This tool converts one or more SAM/BAM files, or tsv/bed fragment files, into
BigWig files containing the 5' counts: a pair of stranded files by default, or
a single file when run in unstranded mode. The input files can be either
stored locally or remote URLs. When multiple files are provided, the reads are
concatenated. Only fully mapped reads are kept.
"""

import gzip
import numpy
import argparse
import collections

import pysam
import pyBigWig
import pyfaidx

from tqdm import tqdm
from joblib import Parallel
from joblib import delayed

### Load the user-specified arguments

# Base arguments

parser = argparse.ArgumentParser(
    prog='bam2bw',
    description=(
        'This tool will convert SAM/BAM or tsv/bed fragment files to bigwig '
        'files without an intermediate.'
    ),
)

# One or more inputs; the counts from all of them are pooled into one output.
parser.add_argument('filename', nargs='+',
    help="""The SAM/BAM, tsv/tsv.gz or bed/bed.gz file(s) to be processed.""")
parser.add_argument('-s', '--sizes', required=True,
    help="""A chrom_sizes or FASTA file.""")

# Data properties

parser.add_argument('-u', '--unstranded', action='store_true',
    help="Have only one, unstranded, output.")
parser.add_argument('-f', '--fragments', action='store_true', default=False,
    help="The data is fragments and so both ends should be recorded.")

parser.add_argument('-ps', '--pos_shift', default=0, type=int,
    help="""A shift to apply to positive strand reads.""")
parser.add_argument('-ns', '--neg_shift', default=0, type=int,
    help="""A shift to apply to negative strand reads.""")

parser.add_argument('-sf', '--scale_factor', default=1, type=float,
    help="""A scaling factor to multiply each position by.""")
parser.add_argument('-r', '--read_depth', default=False, action='store_true',
    help="""Whether to divide through by total (pre-scaled) read depth.""")

# Misc arguments

parser.add_argument('-p', '--parallel', default=1, type=int,
    help="The number of jobs to use, max of one per input file.")
# The name is used as the prefix of every output path, so it is mandatory.
parser.add_argument('-n', '--name', required=True,
    help="""Prefix of the output file(s): <name>.bw when unstranded,
    otherwise <name>.+.bw and <name>.-.bw.""")
parser.add_argument('-z', '--zooms', default=0, type=int,
    help="""The number of zooms to store in the bigwig.""")
parser.add_argument('-v', '--verbose', action='store_true',
    help="""Show progress bars and report chromosomes that are present in
    the input but missing from the sizes file.""")
args = parser.parse_args()


###

# First check that the files are the right format

acceptable_formats = '.bam', '.sam', '.tsv', '.tsv.gz', '.bed', '.bed.gz'

for filename in args.filename:
    if not filename.endswith(acceptable_formats):
        raise ValueError("Filenames must end in one of {}.".format(', '.join(acceptable_formats)))

# Here, we are determining the chromosomes and their sizes. This is necessary
# for creating the bigWig header(s) and for figuring out which reads to filter
# out (those that do not map to the provided chromosomes).
#
# Because we allow you to provide either a two-column chrom_sizes file or a
# FASTA file, we need code to handle the two situations. We use pyfaidx to
# quickly process the FASTA file so that we do not have to scan through the
# entire thing just to get the sizes.

# Extensions that mark the sizes argument as a FASTA file rather than a
# chrom_sizes table.
fasta_formats = '.fa', '.fasta', '.fna'

chrom_sizes = []

# If provided a FASTA file, read the lengths of the sequences
if args.sizes.endswith(fasta_formats):
    with pyfaidx.Fasta(args.sizes) as fa:
        for chrom, seq in fa.items():
            chrom_sizes.append((chrom, len(seq)))

# If provided a chrom_sizes file, just use the provided lengths
else:
    with open(args.sizes, "r") as size_file:
        for line_number, line in enumerate(size_file, start=1):
            fields = line.split()

            # Tolerate blank lines (e.g. a trailing newline) and comments.
            if not fields or fields[0].startswith("#"):
                continue

            if len(fields) < 2:
                raise ValueError(
                    "Line {} of {} does not contain a chromosome name and a size.".format(
                        line_number, args.sizes))

            # Only the first two columns are used, so tables that carry extra
            # columns after the name and length (such as a FASTA .fai index)
            # are accepted as well.
            chrom_sizes.append((fields[0], int(fields[1])))

###

# This is the main loop that goes through the reads and records them in
# one or two dictionaries (depending on if the data is stranded). The
# processing of BAM/SAM files relies on pysam, whereas tsv/tsv.gz files use
# basic file iteration. The processing of reads in both cases is largely the
# same except that, usually, the -f flag will be passed in for .tsv/.tsv.gz
# files because those come from fragments from ATAC-seq-like experiments,
# whereas BAM/SAM files are usually just reads.

# First, a function that reads a single file and returns one or two
# dictionaries of reads. This can be parallelized across files. 

def extract_reads(args, chrom_sizes, idx):
    """Count the recorded read/fragment positions in a single input file.

    Parameters
    ----------
    args : argparse.Namespace
        The parsed command-line options. The attributes read here are
        `filename`, `unstranded`, `fragments`, `pos_shift`, `neg_shift` and
        `verbose`.
    chrom_sizes : list of (str, int)
        The allowed chromosomes. Only the names are used here; entries on any
        other chromosome are discarded.
    idx : int
        Index into `args.filename` of the file to process. It is also used as
        the row of this file's progress bar.

    Returns
    -------
    pos_reads, neg_reads : dict, dict
        Each maps chromosome name -> defaultdict(int) of position -> count.
        When `args.unstranded` is set both names refer to the same dictionary.
        Entries from tsv/bed files are always recorded in `pos_reads`.
    """
    missing_chroms = set()
    pos_reads = {}
    neg_reads = pos_reads if args.unstranded else {}

    # Create dictionaries for each chrom, regardless of whether reads map there
    for chrom, _ in chrom_sizes:
        pos_reads[chrom] = collections.defaultdict(int)
        if not args.unstranded:
            neg_reads[chrom] = collections.defaultdict(int)

    def is_known(chrom):
        # Returns whether chrom is one of the allowed chromosomes. The first
        # time an unknown chromosome is seen it is reported (if verbose).
        if chrom in pos_reads:
            return True

        if chrom not in missing_chroms:
            missing_chroms.add(chrom)
            if args.verbose:
                print("{} encountered in input but not in FASTA/chrom sizes.".format(
                    chrom))

        return False

    # Extract the file being considered
    filename = args.filename[idx]
    name = filename.split("/")[-1]

    if filename.endswith((".bam", ".sam")):
        # SAM is plain text and BAM is binary, so they need different modes.
        mode = "rb" if filename.endswith(".bam") else "r"

        with pysam.AlignmentFile(filename, mode) as bam:
            for read in tqdm(bam.fetch(until_eof=True), disable=not args.verbose, position=idx, desc=name):
                # reference_end is None when there is no alignment to take an
                # end coordinate from, so such reads cannot be placed.
                if read.is_unmapped or read.reference_end is None:
                    continue

                chrom = read.reference_name
                if not is_known(chrom):
                    continue

                start = read.reference_start + args.pos_shift
                end = read.reference_end + args.neg_shift

                # Here, we need to deal with two related issues.
                #
                #    (1) Does the read map to the fwd or rev strand?
                #    (2) Are we recording only the start of the read, or both
                #        the start and the end (fragments)?
                #
                # Accordingly, we first check the strand the read is on and
                # take the start of the read (start for fwd, end-1 for rev
                # reads). Then, if fragments are requested, the other end is
                # recorded too. This works even if the data is unstranded
                # because pos_reads and neg_reads are the same dictionary in
                # that case.

                if read.is_forward:
                    pos_reads[chrom][start] += 1
                    if args.fragments:
                        pos_reads[chrom][end-1] += 1

                else:
                    neg_reads[chrom][end-1] += 1
                    if args.fragments:
                        neg_reads[chrom][start] += 1

    elif filename.endswith(('.tsv', '.bed', '.tsv.gz', '.bed.gz')):
        # Open the file using the correct opener -- the standard one if the
        # file is not compressed, otherwise the gzip opener if gzipped.
        opener = gzip.open if filename.endswith('.gz') else open

        # Here, we process the entries in a similar manner to using pysam
        # except that we assume the coordinates are all fwd strand. Whether
        # both the start and the end of the entry are recorded is controlled
        # by the -f flag, but strandedness is not handled here.

        with opener(filename, "rt") as handle:
            for line in tqdm(handle, disable=not args.verbose, position=idx, desc=name):
                fields = line.split()

                # Skip blank lines and the comment/track/browser header lines
                # that tsv and bed files may begin with.
                if not fields or fields[0].startswith("#") or fields[0] in ("track", "browser"):
                    continue

                chrom, start, end = fields[:3]
                if not is_known(chrom):
                    continue

                start = int(float(start)) + args.pos_shift
                end = int(float(end)) + args.neg_shift

                pos_reads[chrom][start] += 1
                if args.fragments:
                    pos_reads[chrom][end-1] += 1

    return pos_reads, neg_reads

# Process every input file, one job per file, using up to args.parallel
# workers. Each job returns its own (pos_reads, neg_reads) pair.
reads = Parallel(n_jobs=args.parallel, backend='multiprocessing')(
    delayed(extract_reads)(args, chrom_sizes, i) for i in range(len(args.filename))
)

### Collect the reads into a single object.

# The counts from the first file are used as the accumulator and the counts
# from every other file are added into it.
pos_reads, neg_reads = reads[0]
for pos_reads_, neg_reads_ in reads[1:]:
    for chrom, reads_ in pos_reads_.items():
        merged = pos_reads[chrom]
        for position, count in reads_.items():
            merged[position] = merged.get(position, 0) + count

    # When unstranded, the pos and neg dictionaries hold the same counts, so
    # adding both would count every read twice.
    if not args.unstranded:
        for chrom, reads_ in neg_reads_.items():
            merged = neg_reads[chrom]
            for position, count in reads_.items():
                merged[position] = merged.get(position, 0) + count

###

# Now that we have our dictionary(ies) of reads, we need to create bigWig
# objects and store them. Because the entries need to be sorted along the
# length of each chromosome we have to convert the dictionaries to numpy arrays
# and then sort them, but this usually is not that big of a hassle. It is much
# faster to sort the arrays at the end like this than it is to try to keep
# everything in order as you see the reads.

# Here, we open the bigWigs that we will be saving data into. If the data is
# stranded, we are saving two bigWigs. If the data is not stranded, we are only
# saving one bigWig.

def _write_chrom(bw, chrom, size, counts, scale_factor, read_depth=None):
    """Write one chromosome's position -> count mapping into an open bigWig.

    Values are multiplied by `scale_factor` and, if `read_depth` is given,
    divided by it. Positions outside [0, size) are dropped because they are
    not valid coordinates on this chromosome; they can arise from the
    pos/neg shifts or from input entries that run past the chromosome end.
    Nothing is written if no positions remain.
    """
    if len(counts) == 0:
        return

    starts = numpy.array(list(counts.keys()), dtype='int64')
    values = numpy.array(list(counts.values()), dtype='float64')
    values *= scale_factor

    if read_depth is not None:
        values /= read_depth

    in_bounds = (starts >= 0) & (starts < size)
    starts, values = starts[in_bounds], values[in_bounds]
    if len(starts) == 0:
        return

    # Entries have to be added in sorted order along the chromosome.
    order = numpy.argsort(starts)
    bw.addEntries(chrom, starts[order], values=values[order], span=1)


# Open the output bigWig(s): one file if unstranded, otherwise one per strand.
if args.unstranded:
    bw_pos = pyBigWig.open(args.name + ".bw", "w")
    bw_neg = None

    bw_pos.addHeader(chrom_sizes, maxZooms=args.zooms)

else:
    bw_pos = pyBigWig.open(args.name + ".+.bw", "w")
    bw_neg = pyBigWig.open(args.name + ".-.bw", "w")

    bw_pos.addHeader(chrom_sizes, maxZooms=args.zooms)
    bw_neg.addHeader(chrom_sizes, maxZooms=args.zooms)


# We use pyBigWig for our tool to create bigWig files. We choose to save
# entries using only the coordinate and the value, so two numbers per non-zero
# position, rather than as spans of size 1 which would be three numbers per
# non-zero position. This reduces file size and also I/O time.

# The read depth is the total of all recorded counts (before scaling and
# before any out-of-range positions are dropped). The neg counts are only
# added when stranded because otherwise they are the same dictionary as pos.
read_depth = None
if args.read_depth:
    read_depth = sum([sum(pos_reads[chrom].values()) for chrom, _ in chrom_sizes])
    if not args.unstranded:
        read_depth += sum([sum(neg_reads[chrom].values()) for chrom, _ in chrom_sizes])


for chrom, size in chrom_sizes:
    _write_chrom(bw_pos, chrom, size, pos_reads[chrom], args.scale_factor, read_depth)

    if not args.unstranded:
        _write_chrom(bw_neg, chrom, size, neg_reads[chrom], args.scale_factor, read_depth)

bw_pos.close()
if not args.unstranded:
    bw_neg.close()
