#!/usr/bin/env python

import os
import sys
import argparse
import gzip
import csv
import logging
import subprocess
import multiprocessing as mp
import matplotlib
import random
random.seed(1)
import itertools
import sqlite3
import warnings
from matplotlib.colors import NoNorm

# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')

# Illustrator compatibility
new_rc_params = {'text.usetex': False, "svg.fonttype": 'none'}
matplotlib.rcParams.update(new_rc_params)

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch
import seaborn as sns

from sklearn import metrics
from sklearn.mixture import GaussianMixture
from sklearn.exceptions import ConvergenceWarning


import numpy as np
import pandas as pd
import pysam
import scipy.stats as ss

from pandas.api.types import CategoricalDtype

skbio_installed = True
conflicting_call_warning = False

try:
    import skbio.alignment as skalign
    import skbio.sequence as skseq

except ModuleNotFoundError:
    skbio_installed = False

if skbio_installed:
    warnings.filterwarnings('ignore', module='skbio')

from tqdm import tqdm
from re import finditer

from uuid import uuid4
from collections import defaultdict as dd
from collections import Counter
from collections import namedtuple
from itertools import product
from operator import itemgetter
from copy import deepcopy

ont_fast5_api_installed = True

try:
    from ont_fast5_api.fast5_interface import get_fast5_file

except ModuleNotFoundError:
    ont_fast5_api_installed = False

from bx.intervals.intersection import Intersecter, Interval

FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

from time import sleep

class Read:
    ''' Modification calls for one read and one modification type.

    Calls are keyed by reference site (cpg_loc). For each site the raw
    statistic is kept in llrs and the discrete call (1 = modified, -1 =
    unmodified, 0 = no confident call) in meth_calls. call_count is the number
    of sites whose current call is non-zero.
    '''
    def __init__(self, read_name, cpg_loc, stat, methcall, modname, phase=None):
        self.read_name  = read_name
        self.llrs       = {}    # site -> raw statistic passed to add_mod
        self.meth_calls = {}    # site -> call in (-1, 0, 1)
        self.phase      = phase
        self.modname    = modname
        self.call_count = 0     # number of sites with a non-zero call

        # used for locus per-read plot
        self.ypos   = None
        self.starts = []
        self.ends   = []

        self.add_mod(cpg_loc, stat, methcall, modname)

    def add_mod(self, cpg_loc, stat, methcall, modname, warn=True):
        ''' Record the call at cpg_loc, replacing any earlier call for that site.

        A repeated site triggers a one-time warning (a module-wide flag) unless
        warn is False. call_count always equals the number of sites whose
        current call is non-zero.
        '''
        assert methcall in (-1,0,1)
        assert self.modname == modname

        global conflicting_call_warning

        if cpg_loc in self.meth_calls:
            if not conflicting_call_warning:
                if warn:
                    logger.warning('''warning: multiple modification calls for the same site + mod, read: %s  cpg_loc: %d  mod: %s
                            consider excluding non-primary alignments from your analysis, or if using a .bam with mod tags try --primary_only
                            this warning is only shown once per thread'''% (self.read_name, cpg_loc, modname))

            conflicting_call_warning = True

            # the earlier call at this site is about to be overwritten: drop it
            # from call_count so a repeated site is not counted twice
            if self.meth_calls[cpg_loc] != 0:
                self.call_count -= 1

        self.llrs[cpg_loc]       = stat
        self.meth_calls[cpg_loc] = methcall

        if methcall != 0:
            self.call_count += 1

    def overlap(self, other):
        ''' True if the span max(ends)/min(starts) of this read overlaps other's by more than 0 (starts/ends must be non-empty). '''
        return min(max(self.ends), max(other.ends)) - max(min(self.starts), min(other.starts)) > 0

    def mod_count(self):
        ''' Number of sites called modified (call == 1). '''
        return sum(1 for call in self.meth_calls.values() if call == 1)

    def mod_count_region(self, start, end):
        ''' Number of sites in [start, end] (inclusive) called modified. '''
        return sum(1 for loc, call in self.meth_calls.items() if start <= loc <= end and call == 1)

    def site_count(self):
        ''' Number of sites with any call. '''
        return len(self.meth_calls)

    def site_count_region(self, start, end):
        ''' Number of sites with any call in [start, end] (inclusive). '''
        return sum(1 for loc in self.meth_calls if start <= loc <= end)

    def mod_frac(self):
        ''' Fraction of sites called modified; 0 when the read has no sites. '''
        if self.site_count() == 0:
            return 0

        return self.mod_count()/self.site_count()

    def mod_frac_region(self, start, end):
        ''' Fraction of sites in [start, end] called modified; 0 when there are none. '''
        local_site_count = self.site_count_region(start, end)
        if local_site_count == 0:
            return 0

        return self.mod_count_region(start, end)/local_site_count


class Gene:
    ''' Gene model: identifier, name, strand, transcript/CDS extents and exon blocks.

    Transcript and CDS extents are the union (smallest start, largest end) of
    every block added. Exons are kept as [start, end] lists sorted by start.
    '''
    def __init__(self, ensg, name, strand):
        self.ensg = ensg        # gene identifier (presumably an Ensembl ENSG id)
        self.name = name
        self.strand = strand
        self.tx_start = None
        self.tx_end = None
        self.cds_start = None
        self.cds_end = None
        self.exons = []

    def add_exon(self, block):
        ''' Add an exon given as a 2-item (start, end) sequence; the ends may be in either order. '''
        assert len(block) == 2

        # store a fresh [low, high] list so the caller's sequence is never
        # mutated and tuples are accepted as well as lists
        self.exons.append(sorted(block))
        self.exons.sort(key=itemgetter(0))

    def add_tx(self, block):
        ''' Widen the transcript extent to include block (start < end required). '''
        assert len(block) == 2
        assert block[0] < block[1]
        if self.tx_start is None or self.tx_start > block[0]:
            self.tx_start = block[0]

        if self.tx_end is None or self.tx_end < block[1]:
            self.tx_end = block[1]

    def add_cds(self, block):
        ''' Widen the CDS extent to include block; reversed blocks are logged and ignored. '''
        assert len(block) == 2
        if block[0] > block[1]:
            logger.warning('CDS block start > end in gene %s' % self.ensg)
            return None

        if self.cds_start is None or self.cds_start > block[0]:
            self.cds_start = block[0]

        if self.cds_end is None or self.cds_end < block[1]:
            self.cds_end = block[1]

    def has_tx(self):
        return None not in (self.tx_start, self.tx_end)

    def has_cds(self):
        return None not in (self.cds_start, self.cds_end)

    def merge_exons(self):
        ''' Collapse overlapping exons (overlap length > 0; merely touching blocks stay separate). '''
        if len(self.exons) == 0:
            return

        new_exons = []
        last_block = self.exons[0]

        for block in self.exons[1:]:
            if min(block[1], last_block[1]) - max(block[0], last_block[0]) > 0: # overlap
                last_block = [min(block[0], last_block[0]), max(block[1], last_block[1])]

            else:
                new_exons.append(last_block)
                last_block = block

        new_exons.append(last_block)

        self.exons = new_exons


# Watson-Crick complement for the upper-case bases plus N.
RC = dict(zip('ATCGN', 'TAGCN'))

class Motif:
    ''' Sequence motif written with its key base in brackets, e.g. "G[A]TC" or "[C]G".

    left_bases / right_bases are the bases before / after the brackets,
    key_base is the single bracketed base, and full_seq is the whole motif
    without brackets. All of these are upper-cased; motif keeps the original
    text.
    '''
    def __init__(self, motif):

        self.motif = motif

        assert len(motif.split('[')) == 2, 'bad motif syntax: %s' % motif
        assert len(motif.split(']')) == 2, 'bad motif syntax: %s' % motif

        self.left_bases  = motif.split('[')[0].upper()
        self.right_bases = motif.split(']')[1].upper()
        self.key_base    = motif.split('[')[1].split(']')[0].upper()
        self.full_seq    = self.left_bases + self.key_base + self.right_bases

        # alias for the previously misspelt attribute name, kept for compatibility
        self.left_base   = self.left_bases

        # '' and multi-base strings are substrings of 'ATCG' too, so check the length explicitly
        assert len(self.key_base) == 1 and self.key_base in 'ATCG', 'bad motif syntax: %s' % motif


    def match_seq(self, seq):
        ''' returns 0-based coords of key base for matches in seq

        Result is a defaultdict(list) mapping key-base position -> [self.motif].
        Matching is case-insensitive.
        '''

        sites = dd(list)

        seq = seq.upper()
        width = len(self.full_seq)
        offset = len(self.left_bases)

        # + 1 so the window ending at the last base of seq is also examined
        for i in range(len(seq) - width + 1):
            if seq[i:i+width] == self.full_seq:
                sites[i + offset].append(self.motif)

        return sites
        

class Mod:
    ''' One motif-site base in a read with its per-base modification values.

    The genome_* / chrom / aln_strand fields start as None and are filled in
    once the read position has been mapped to the reference.
    '''
    def __init__(self, read_id, read_pos, read_base, sites, mods):
        self.read_id, self.read_pos, self.read_base = read_id, read_pos, read_base
        self.sites, self.mods = sites, mods
        self.chrom = self.genome_pos = self.genome_base = self.aln_strand = None

    def __str__(self):
        # one tab-separated record; sites are comma-joined, mod values occupy the trailing columns
        columns = [
            self.read_id,
            self.chrom,
            str(self.genome_pos),
            self.genome_base,
            self.aln_strand,
            str(self.read_pos),
            self.read_base,
            ','.join(self.sites),
            '\t'.join(map(str, self.mods)),
        ]
        return '\t'.join(columns)


def base_field(mod_metadata):
    ''' Expand a guppy modified-base alphabet into a list of base names.

    Canonical letters (A/C/T/G) are kept; every other letter is replaced, in
    order, by the next name from modified_base_long_names. The alphabet is
    read from output_alphabet, or from modified_base_alphabet if the former is
    absent.

    Raises KeyError when neither alphabet field is present.
    '''
    long_names = mod_metadata['modified_base_long_names'].split()

    if 'output_alphabet' in mod_metadata:
        alpha_field_name = 'output_alphabet'

    elif 'modified_base_alphabet' in mod_metadata:
        alpha_field_name = 'modified_base_alphabet'

    else:
        raise KeyError('mod metadata has neither output_alphabet nor modified_base_alphabet')

    bases = []
    i = 0

    for base in mod_metadata[alpha_field_name]:
        if base not in 'ACTG':
            bases.append(long_names[i])
            i += 1
        else:
            bases.append(base)

    return bases


def pandas_df(out, mod_base, can_base):
    ''' Parse guppy mod-call lines into a DataFrame and add log-probability columns.

    out: sequence of lines; the first is a header (whitespace-separated column
         names), the rest are tab-separated records.
    mod_base / can_base: names of the columns holding the 0-255 scaled
         probabilities of the modified and canonical base.

    Adds can_log_prob and mod_log_prob (log((value + 1) / 256)) and
    modstat = mod_log_prob - can_log_prob. Record values other than those two
    columns stay as strings. Rows are indexed from 1 (their line number in out).
    '''
    lines = iter(out)
    header = next(lines, '').strip().split()

    rows = []
    for line in lines:
        values = line.strip().split('\t')
        if len(values) > len(header):
            raise ValueError('record has %d fields but header has %d: %r' % (len(values), len(header), line))
        rows.append(dict(zip(header, values)))

    # index from 1 so each row keeps its line number within out
    data = pd.DataFrame(rows, columns=header, index=range(1, len(rows) + 1))

    data['can_log_prob'] = np.log((pd.to_numeric(data[can_base])+1)/256)
    data['mod_log_prob'] = np.log((pd.to_numeric(data[mod_base])+1)/256)
    data['modstat'] = data['mod_log_prob'] - data['can_log_prob']

    return data


def remove_all_text(ax):
    ''' Strip axis labels, title, ticks, tick labels, legend and free text artists from a matplotlib Axes. '''
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_title('')

    ax.set_xticklabels([])
    ax.set_yticklabels([])

    ax.set_xticks([])
    ax.set_yticks([])

    legend = ax.get_legend()
    if legend:
        legend.remove()

    # iterate over a snapshot: removing an artist while iterating ax.texts
    # itself shrinks the underlying collection and skips every other text
    for text in list(ax.texts):
        text.remove()


def guppy_f5_fetch(fast5, bam_db, args):
    ''' Extract guppy modified-base probabilities at motif sites for the reads in one fast5 file.

    For every read in args.fast5/<fast5>, occurrences of args.motif are found
    in the called sequence. Each site is placed on the reference using SAM
    records stored in the sqlite table `bam` (columns readname, sam) of
    bam_db, with args.bam supplying the header and args.ref the reference
    bases. Sites whose read base disagrees with the reference (complemented
    for '-' alignments) are dropped unless args.include_unmatched is set.

    Writes a tab-separated table next to the fast5 file and returns its path.
    Returns None when the fast5 lacks basecall/mod metadata, args.modname is
    not an available mod, or the file contains no reads.
    '''
    bam = pysam.AlignmentFile(args.bam)
    ref = pysam.Fastafile(args.ref)

    try:
        m = Motif(args.motif)

        bases = []
        header = None
        modbases = dd(list)

        with get_fast5_file(args.fast5 + '/' + fast5, mode="r") as f5:
            for read_id in f5.get_read_ids():
                f5_read = f5.get_read(read_id)
                latest_basecall = f5_read.get_latest_analysis('Basecall_1D')

                mod_base_table = f5_read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/ModBaseProbs')
                called_base_table = f5_read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/Fastq')

                mod_table_path = '{}/BaseCalled_template/ModBaseProbs'.format(latest_basecall)
                called_table_path = '{}/BaseCalled_template/Fastq'.format(latest_basecall)

                mod_metadata = f5_read.get_analysis_attributes(mod_table_path)
                called_metadata = f5_read.get_analysis_attributes(called_table_path)

                if None in (mod_metadata, called_metadata):
                    logger.error('fast5 %s does not seem to contain basecalling information' % fast5)
                    return None

                if len(bases) == 0:
                    bases = base_field(mod_metadata)

                    mod_names = [b for b in bases if b not in ('A', 'T', 'C', 'G')]

                    if args.modname not in mod_names:
                        logger.error('mod "%s" not found in metadata. Available mod names: %s' % (args.modname, ','.join(mod_names)))
                        return None

                    header = "readname\tchrom\tpos\tgenome_base\tstrand\tread_pos\tread_base\tmotif\t" + '\t'.join(bases) + "\n"

                # the Fastq dataset is a FASTQ record: line index 1 holds the called sequence
                seq = called_base_table.split('\n')[1]

                assert len(seq) == len(mod_base_table)

                # sites keys are in increasing read position, so Mods are created in read order
                for i, site_motifs in m.match_seq(seq).items():
                    modbases[read_id].append(Mod(read_id, i, seq[i], site_motifs, mod_base_table[i]))

        logger.info('processed %d reads from %s' % (len(modbases), fast5))

        if header is None:
            logger.error('no reads found in fast5 %s' % fast5)
            return None

        logger.info('parsing alignments from %s' % bam_db)

        out = [header]

        conn = sqlite3.connect(bam_db)
        try:
            c = conn.cursor()

            for readname in modbases:
                # parameterised query: read names come from the input file
                for row in c.execute("SELECT sam FROM bam WHERE readname=?", (readname,)):
                    aln = pysam.AlignedSegment.fromstring(row[0], bam.header)

                    pos_lookup = {qpos: rpos for qpos, rpos in aln.get_aligned_pairs() if qpos is not None and rpos is not None}

                    # .get: never insert into modbases while iterating over it
                    for modbase in modbases.get(aln.query_name, ()):
                        modbase.aln_strand = '+'
                        aligned_read_pos = modbase.read_pos

                        if aln.is_reverse:
                            modbase.aln_strand = '-'
                            aligned_read_pos = (len(aln.seq) - modbase.read_pos)-1 # SAM format stores read on + strand

                        if aligned_read_pos not in pos_lookup:
                            continue

                        modbase.chrom = aln.reference_name
                        modbase.genome_pos = pos_lookup[aligned_read_pos]
                        modbase.genome_base = ref.fetch(modbase.chrom, modbase.genome_pos, modbase.genome_pos+1).upper()

                        if args.include_unmatched:
                            out.append(str(modbase) + '\n')

                        elif modbase.aln_strand == '+' and modbase.genome_base == modbase.read_base:
                            out.append(str(modbase) + '\n')

                        # RC.get: ambiguity codes in the reference simply fail to match
                        elif modbase.aln_strand == '-' and RC.get(modbase.genome_base) == modbase.read_base:
                            out.append(str(modbase) + '\n')

        finally:
            conn.close()

        data = pandas_df(out, args.modname, m.key_base)
        out_fn = args.fast5 + '/' + '.'.join(fast5.split('.')[:-1]) + '.gupmod.tsv'
        logger.info('writing output from %s to %s' % (fast5, out_fn))
        data.to_csv(out_fn, sep='\t', index=False)

        return out_fn

    finally:
        bam.close()
        ref.close()


def exclude_ambiguous_reads(fn, chrom, start, end, min_mapq=10, spanning_only=False):
    ''' Return names of reads overlapping chrom:start-end that do not lie wholly inside it.

    A read is kept when its first aligned reference position is < start or
    its last is > end, and its mapping quality is >= min_mapq. With
    spanning_only it must also start before start and end after end. Reads
    lying entirely inside the region (and records with no aligned positions)
    are left out.
    '''
    reads = []

    bam = pysam.AlignmentFile(fn)
    try:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            if not p:
                continue  # placed but unaligned record: no reference positions to test

            if p[0] >= start and p[-1] <= end:
                continue

            if read.mapq < min_mapq:
                continue

            if spanning_only and not (read.reference_start < start and read.reference_end > end):
                continue

            reads.append(read.query_name)
    finally:
        bam.close()

    return reads


def get_ambiguous_reads(fn, chrom, start, end, min_mapq=10, w=50):
    ''' Return names of reads in chrom:start-end considered ambiguous.

    A read is ambiguous when its mapping quality is < min_mapq, when its
    aligned reference positions lie strictly within (start - w, end + w), or
    when it has no aligned reference positions at all.
    '''
    reads = []

    bam = pysam.AlignmentFile(fn)
    try:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                reads.append(read.query_name)
                continue

            p = read.get_reference_positions()

            # an empty position list would crash p[0]; such a record cannot be anchored
            if not p or (p[0] > start-w and p[-1] < end+w):
                reads.append(read.query_name)
    finally:
        bam.close()

    return reads


def get_reads(fn, chrom, start, end, min_mapq=10, retry=5, spanning_only=False, primary_only=False):
    ''' List names of reads in chrom:start-end that pass the filters.

    Filters: mapping quality >= min_mapq; with primary_only, secondary and
    supplementary alignments are dropped; with spanning_only, the alignment
    must start before start and end after end.

    Opening the bam is retried up to retry times, one second apart; after that
    the program exits with the OSError.
    '''
    attempts = 0
    while True:
        try:
            bam = pysam.AlignmentFile(fn)
            break
        except OSError as e:
            if attempts >= retry:
                sys.exit(e)
            logger.warning(f'retry open on {fn}, possibly due to network filesystem')
            sleep(1)
            attempts += 1

    names = []
    for aln in bam.fetch(chrom, start, end):
        if primary_only and (aln.is_secondary or aln.is_supplementary):
            continue

        if aln.mapq < min_mapq:
            continue

        if spanning_only and not (aln.reference_start < start and aln.reference_end > end):
            continue

        names.append(aln.query_name)

    return names


def get_phased_reads(fn, chrom, start, end, min_mapq=10, retry=5, tag_untagged=False, ignore_tags=False, HP_only=False, primary_only=False):
    ''' Map each read name in chrom:start-end (mapq >= min_mapq) to a phase label.

    The label comes from the read's HP tag: str(HP), or "HP:PS" when a PS tag
    is also present and HP_only is False. Reads without an HP tag get
    'unphased' when tag_untagged or ignore_tags is set, and None otherwise.
    A PS tag on its own does not change the label. With ignore_tags every
    read is 'unphased'. With primary_only, secondary and supplementary
    alignments are skipped.

    Opening the bam is retried up to retry times, one second apart, before
    the program exits with the OSError.
    '''
    reads = {}

    retries = 0
    bam_open = False

    while not bam_open:
        try:
            bam = pysam.AlignmentFile(fn)
            bam_open = True

        except OSError as e:
            if retries < retry:
                logger.warning(f'retry open on {fn}, possibly due to network filesystem')
                sleep(1)
                retries += 1

            else:
                sys.exit(e)

    try:
        for read in bam.fetch(chrom, start, end):
            if primary_only and (read.is_secondary or read.is_supplementary):
                continue

            if read.mapq < min_mapq:
                continue

            phase = 'unphased' if (tag_untagged or ignore_tags) else None

            if not ignore_tags:
                # dict(): a later duplicate tag wins, as with a tag-by-tag scan
                tags = dict(read.get_tags())
                HP = tags.get('HP')
                PS = tags.get('PS')

                # PS is only meaningful together with HP; a PS-only read used to
                # crash (None + ':') or be labelled 'None' / 'unphased:<PS>'
                if HP is not None:
                    phase = str(HP)

                    if PS is not None and not HP_only:
                        phase = phase + ':' + str(PS)

            reads[read.query_name] = phase
    finally:
        bam.close()

    return reads


def get_variant_reads(fn, chrom, start, end, variants, min_mapq=10, primary_only=False):
    ''' Assign each read in chrom:start-end to 'alt', 'ref' or None for a single variant.

    variants must hold exactly one entry (see check_variants for its layout).
    Reads without a sequence, below min_mapq, or (with primary_only)
    secondary/supplementary are skipped. 'alt' takes precedence over 'ref'.
    '''
    assert len(variants) == 1

    assignments = {}

    bam = pysam.AlignmentFile(fn)
    for aln in bam.fetch(chrom, start, end):
        if aln.seq is None:
            continue

        if primary_only and (aln.is_secondary or aln.is_supplementary):
            continue

        if aln.mapq < min_mapq:
            continue

        if check_variants(variants, aln, allele='alt'):
            allele = 'alt'
        elif check_variants(variants, aln, allele='ref'):
            allele = 'ref'
        else:
            allele = None

        assignments[aln.query_name] = allele

    return assignments


def check_variants(variants, read, allele=None):
    ''' Return the variant positions at which read supports the requested allele.

    variants maps a reference position to (ref, alt, vartype, varlen). SNVs
    are matched by comparing the aligned read base (case-insensitive) with
    the ref or alt base; 'INS'/'DEL' entries are delegated to has_sv.
    allele must be 'ref' or 'alt'.
    '''
    assert allele in ('ref', 'alt')

    matched = []

    for qpos, rpos in read.get_aligned_pairs():
        if qpos is None or rpos not in variants:
            continue

        vref, valt, vartype, varlen = variants[rpos]

        if vartype == 'SNV':
            wanted = valt if allele == 'alt' else vref
            if read.seq[qpos].upper() == wanted.upper():
                matched.append(rpos)

        elif vartype in ('INS', 'DEL'):
            if has_sv(read, rpos, variants[rpos]) == allele:
                matched.append(rpos)

    return matched


def has_sv(read, bnd, variant, window=100):
    ''' Decide whether read supports a DEL/INS structural variant at reference position bnd.

    variant is laid out as (ref, alt, svtype, svlen); only svtype ('DEL' or
    'INS') and abs(svlen) are used. The query position aligned closest to bnd
    is located, then up to window query positions on each side are scanned.
    Each query base with no reference position adds 1 to that side's score.
    Each jump between consecutive aligned reference positions whose size is
    within window of svlen adds svlen.

    Returns 'alt' if the larger side score * 1.25 exceeds window (INS) or
    svlen (DEL), otherwise 'ref'. Returns None when the read has no
    sequence, has no aligned position, or has fewer than window query bases
    on either side of the closest aligned position.
    '''
    if read.seq is None:
        return None

    svtype = variant[2]
    svlen = abs(variant[3])

    assert svtype in ('DEL', 'INS')

    # query position -> reference position (None for query bases not aligned
    # to the reference). Deleted reference bases have query position None;
    # as dict keys they all collapse into one None entry, skipped below.
    ap = dict(read.get_aligned_pairs())

    best_rpos = None
    best_gpos = None

    for rpos, gpos in ap.items():
        if rpos is None or gpos is None:
            continue

        if best_gpos is None or abs(bnd-gpos) < abs(bnd-best_gpos):
            best_gpos = gpos
            best_rpos = rpos

    if best_rpos is None:
        return None  # nothing in this read is aligned to the reference

    if best_rpos < window:
        return None

    if len(read.seq)-best_rpos < window:
        return None

    bnd_right = 0
    bnd_left = 0
    prev_rp = 0

    for rp in range(best_rpos-1, best_rpos+window):
        if rp in ap:
            if ap[rp] is None:
                bnd_right += 1
            elif None not in (ap[rp], ap[prev_rp]) and abs(abs(ap[rp]-ap[prev_rp]) - svlen) < window:
                bnd_right += svlen
            if ap[rp] is not None:
                prev_rp = rp

    prev_rp = 0

    for rp in range(best_rpos-window, best_rpos+1):
        if rp in ap:
            if ap[rp] is None:
                bnd_left += 1
            elif None not in (ap[rp], ap[prev_rp]) and abs(abs(ap[rp]-ap[prev_rp]) - svlen) < window:
                bnd_left += svlen
            if ap[rp] is not None:
                prev_rp = rp

    if svtype == 'INS' and max(bnd_left, bnd_right)*1.25 > window:
        return 'alt'

    if svtype == 'DEL' and max(bnd_left, bnd_right)*1.25 > svlen:
        return 'alt'

    return 'ref'


def single_seq_fa(fn):
    ''' Read a single-record FASTA file and return its sequence with line breaks removed.

    Raises AssertionError if a header line follows sequence data, i.e. the
    file holds more than one record.
    '''
    chunks = []

    with open(fn, 'r') as handle:
        for raw in handle:
            if raw.startswith('>'):
                assert not chunks, 'input fa must have only one entry'
                continue

            piece = raw.strip()
            if piece:
                chunks.append(piece)

    return ''.join(chunks)


def rc(dna):
    ''' Return the reverse complement of dna, preserving case and complementing IUPAC ambiguity codes. '''
    table = str.maketrans('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
    # reversing first and then complementing is equivalent to the other order
    reversed_dna = dna[::-1]
    return reversed_dna.translate(table)


def iupac(motif):
    ''' Expand an IUPAC nucleotide motif (case-insensitive, U read as T) into every concrete sequence it denotes.

    Output order has the first motif position varying fastest. An empty motif
    yields an empty list. An unknown symbol exits the program with a message
    naming the first offending base.
    '''
    codes = {
        'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T', 'U': 'T',
        'R': 'AG', 'Y': 'CT', 'S': 'GC', 'W': 'AT', 'K': 'GT', 'M': 'AC',
        'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG',
        'N': 'ACGT',
    }

    choices = []
    for symbol in motif.upper():
        if symbol not in codes:
            sys.exit('base %s not an IUPAC base, please modify --motif' % symbol)
        choices.append(codes[symbol])

    if not choices:
        return []

    # product() varies its last argument fastest; feed it the positions in
    # reverse and flip each combination back so the first position varies fastest
    return [''.join(reversed(combo)) for combo in product(*reversed(choices))]


def get_modnames(meth_db):
    ''' Return the distinct modification names stored in the modnames table of a methylartist sqlite database.

    Exits the program if meth_db does not exist.
    '''
    if not os.path.exists(meth_db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    conn = sqlite3.connect(meth_db)
    try:
        return [row[0] for row in conn.execute("SELECT DISTINCT mod FROM modnames")]
    finally:
        conn.close()


def get_cutoffs(meth_db, mod):
    ''' Return the (upper, lower) cutoff row for mod from the cutoffs table, or None if mod has no row.

    Exits the program if meth_db does not exist.
    '''
    if not os.path.exists(meth_db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    conn = sqlite3.connect(meth_db)
    try:
        # parameterised: mod names containing quotes must not break (or inject into) the SQL
        return conn.execute("SELECT upper,lower FROM cutoffs WHERE modname=?", (mod,)).fetchone()
    finally:
        conn.close()


def densecall_filter(reads, max_density=.7):
    ''' Keep only reads whose modified-site fraction (mod_frac()) is at most max_density.

    reads maps read name -> Read. Returns a new dict in the same order and
    logs (debug) how many reads were removed.
    '''
    kept = {name: read for name, read in reads.items() if read.mod_frac() <= max_density}

    logger.debug('filtered %d reads via --max_read_density %f' % (len(reads) - len(kept), max_density))

    return kept


def is_bam(fn):
    ''' Return True if every comma-separated path in fn can be opened with pysam.AlignmentFile. '''
    for f in fn.split(','):
        try:
            bam = pysam.AlignmentFile(f)
        except Exception:
            # any open/parse failure means "not an alignment file"; catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through
            return False

        bam.close()

    return True


def get_bincover(bam_fn, chrom, start, end):
    ''' Count non-duplicate reads with mapping quality >= 10 in chrom:start-end.

    Returns [start, count] with start converted to int, so results can be
    sorted by window start.
    '''
    start, end = int(start), int(end)
    bam = pysam.AlignmentFile(bam_fn)

    count = sum(
        1
        for rec in bam.fetch(chrom, start, end)
        if rec.mapping_quality >= 10 and not rec.is_duplicate
    )

    return [start, count]


def bam_bincover(bam_fn, chrom, w_starts, w_ends, procs=1, log=False):
    ''' Read counts for each window (w_starts[i], w_ends[i]) on chrom, computed with procs worker processes.

    Returns a numpy array ordered by window start, log2(count + 1)
    transformed when log is True.
    '''
    assert len(w_starts) == len(w_ends)

    jobs = [(bam_fn, chrom, start, end) for start, end in zip(w_starts, w_ends)]

    # the pool's context manager terminates its workers on exit (previously they
    # were never closed), so all results are gathered inside the with-block
    with mp.Pool(processes=int(procs)) as pool:
        segs = pool.starmap(get_bincover, jobs)

    segs = sorted(segs, key=itemgetter(0))
    cover = np.asarray([s[1] for s in segs])

    if log:
        cover = np.log2(cover+1)

    return cover


def bam_pileupcover(bam_fn, chrom, w_starts, w_ends, procs=1, log=False):
    ''' Per-position base depth from `samtools mpileup` for positions inside the windows (w_starts[i], w_ends[i]).

    samtools is run once over min(w_starts)..max(w_ends) with -aa and -q10.
    Positions falling in any window (each padded by one base on both sides)
    are kept. Depth is the number of A/C/G/T characters (case-insensitive) in
    the pileup base column; rows with fewer than six columns or a '*' last
    column count as 0.

    Returns a numpy array ordered by position, log2(depth + 1) transformed
    when log is True. procs is accepted for signature parity with
    bam_bincover and is not used.
    '''
    assert len(w_starts) == len(w_ends)

    segtree = Intersecter()
    for start, end in zip(w_starts, w_ends):
        segtree.add_interval(Interval(start-1, end+1))

    region = '%s:%d-%d' % (chrom, min(w_starts), max(w_ends))

    segs = []
    samtools_cmd = ['samtools', 'mpileup', '-aa', '-q10', '-r', region, bam_fn]

    # DEVNULL replaces a never-closed os.devnull handle; the with-block waits for samtools
    with subprocess.Popen(samtools_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) as p:
        for pline in p.stdout:
            cols = pline.decode().strip().split()

            if len(cols) < 2:
                continue  # blank/truncated line: no position to report

            if len(cols) < 6 or cols[-1] == '*':
                seq = ''
            else:
                seq = cols[4]

            pos = int(cols[1])

            if segtree.find(pos, pos):
                base_depth = len([b for b in seq.upper() if b in ('A','T','C','G')])
                segs.append([pos, base_depth])

    if p.returncode != 0:
        logger.warning('samtools mpileup exited with status %d for %s %s' % (p.returncode, bam_fn, region))

    segs = sorted(segs, key=itemgetter(0))
    cover = np.asarray([s[1] for s in segs])

    if log:
        cover = np.log2(cover+1)

    return cover
    

def mods_bedmethyl(bed_fn, sample_len=100000):
    ''' Return the modification codes found in column 4 of a gzipped bedMethyl file, in first-seen order.

    At most sample_len data lines are examined. Blank lines and '#', 'track'
    or 'browser' header lines are skipped. If a column-4 value is not a
    single-character mod code, the file is assumed not to come from modkit
    and ['m'] (5mC) is returned.
    '''
    mods = []
    samples = 0

    with gzip.open(bed_fn) as bed:
        for line in bed:
            fields = line.decode().split()

            if not fields or fields[0] in ('track', 'browser') or fields[0].startswith('#'):
                continue

            if samples >= sample_len:
                break

            samples += 1

            modcol = fields[3]

            if len(modcol) != 1:
                logger.warning(f'found non-single-base mod code in column 4: {modcol}, assuming non-modkit bedmethyl and using mod "m" (5mC) for entire bed')
                return ['m']

            if modcol not in mods:
                mods.append(modcol)

    logger.info('found mods: %s' % ','.join(mods))

    return mods


def mods_methbam(bam_fn):
    ''' Return the modification codes listed in the MM (or Mm) tag of the first read with a non-empty tag.

    Each ';'-separated MM entry contributes the code after its strand sign
    (e.g. 'C+m?' -> 'm') with trailing '?'/'.' removed. Duplicates are
    dropped and the order is not preserved. Assumes every read that carries
    the tag lists the same mods. Exits the program when no read has a usable
    tag.
    '''
    bam = pysam.AlignmentFile(bam_fn)

    logger.info('fetching mod types from %s, if this takes awhile MM/ML tags may be missing.' % bam_fn)

    mm_warned = False

    try:
        for rec in bam.fetch():
            mm = None

            for tag in ('MM', 'Mm'):
                if rec.has_tag(tag):
                    mm = str(rec.get_tag(tag)).rstrip(';')
                    break

            if mm is None:
                if not mm_warned:
                    logger.debug('cannot find Mm tag in at least one read (example: %s), ensure this bam has MM and ML tags!' % rec.qname)
                    mm_warned = True
                continue

            mods = set()

            for mod_string in mm.split(';'):
                mod_info = mod_string.split(',')[0]

                if not mod_info:
                    continue  # empty tag or empty entry: nothing to parse

                mod_strand = '-' if '-' in mod_info else '+'

                mod_base, mod_type = mod_info.split(mod_strand)
                mods.add(mod_type.rstrip('?.'))

            if not mods:
                continue  # tag present but lists no mods; keep looking

            mods = list(mods)

            logger.info('found mods: %s' % ','.join(mods))

            return mods # assumes all mods are represented for each read that has an Mm tag

    finally:
        bam.close()

    logger.warning('no mods found in %s, missing MM/ML tags?' % bam_fn)
    logger.warning('If this is C/T substitution data please indicate .bam file with --ctbam')
    sys.exit(1)


def split_ml(mod_strings, ml):
    ''' Split a flat ML probability array into one slice per MM mod string.

    Each mod string contributes as many values as it has comma-separated
    entries after its leading descriptor. Returns [] when no mod string lists
    any position. Raises AssertionError when the total does not equal len(ml).
    '''
    counts = [len(ms.split(',')) - 1 for ms in mod_strings]
    total_ms = sum(counts)

    if total_ms == 0:
        return []

    assert total_ms == len(ml), 'mod bam formatting error total_ms: %d, len(ml): %d' % (total_ms, len(ml))

    bounds = [0, *itertools.accumulate(counts)]

    chunks = []
    for lo, hi in zip(bounds, bounds[1:]):
        piece = ml[lo:hi]
        assert hi - lo == len(piece), 'mod bam formatting error'
        chunks.append(piece)

    return chunks


class _AllReads:
    ''' Read-name container that reports every name as a member.

    parse_methbam yields nothing when its read collection is empty, and
    otherwise keeps only records whose name is in the collection. This
    stand-in is non-empty and contains every name, so no record is filtered
    by name.
    '''

    def __len__(self):
        return 1

    def __contains__(self, name):
        return True


def sample_bam(bam_fn, motif, ref, n):
    ''' Collect up to n modification probabilities (p_mod) from a mod-tagged bam.

    Walks the contigs of the reference fasta ref in order and gathers calls
    from primary alignments at sites whose reference sequence matches motif,
    stopping once n values are collected. Returns a numpy array, shorter than
    n if the bam runs out of calls.
    '''
    n = int(n)

    logger.info('sampling %d mods from %s' % (n, bam_fn))

    sample = []

    if n <= 0:
        return np.asarray(sample)

    fasta = pysam.Fastafile(ref)

    try:
        for chrom in fasta.references:
            # previously reads=[] made parse_methbam skip every record, and
            # omitting restrict_ref silently disabled the motif restriction
            calls = parse_methbam(bam_fn, _AllReads(), chrom, 0, fasta.get_reference_length(chrom), motifsize=len(motif), restrict_motif=motif, restrict_ref=ref, primary_only=True)

            for res in calls:
                sample.append(res[3])

                if len(sample) >= n:
                    calls.close()
                    return np.asarray(sample)
    finally:
        fasta.close()

    return np.asarray(sample)


def parse_methbam(bam_fn, reads, chrom, start, end, motifsize=2, meth_thresh=0.8, can_thresh=0.8, restrict_motif=None, restrict_ref=None, primary_only=False, retry=5):
    ''' Yield per-site modification calls from MM/ML-tagged alignments in chrom:start-end.

    Only mapped records with a sequence whose name is in reads are used.
    Nothing is yielded when reads is empty. With primary_only, secondary and
    supplementary alignments are skipped.

    Each yielded tuple is (read_name, chrom, genome_pos, p_mod, methstate,
    mod_code):
    - p_mod is the ML value / 255.
    - methstate is 1 when p_mod > meth_thresh, -1 when 1 - p_mod > can_thresh,
      else 0.
    - On reverse-strand reads genome_pos is shifted left by motifsize - 1
      (presumably to the motif start on the + strand).

    When both restrict_motif and restrict_ref (a fasta filename) are given,
    only sites whose reference sequence at genome_pos matches restrict_motif
    (case-insensitive) are yielded, and motifsize is taken from
    len(restrict_motif).

    Opening the bam is retried up to retry times, one second apart, before
    the program exits with the OSError.
    '''
    retries = 0

    while True:
        try:
            bam = pysam.AlignmentFile(bam_fn)
            break

        except OSError as e:
            if retries >= retry:
                sys.exit(e)

            logger.warning(f'retry open on {bam_fn}, possibly due to network filesystem')
            sleep(1)
            retries += 1

    if restrict_motif and len(restrict_motif) != motifsize:
        motifsize = len(restrict_motif)

    # membership is tested for every record: turn lists into a set once
    # instead of scanning the list each time (sets/dicts/others are left as-is)
    if isinstance(reads, (list, tuple)):
        reads = set(reads)

    if restrict_ref is not None:
        restrict_ref = pysam.Fastafile(restrict_ref)

    check_motif = restrict_motif is not None and restrict_ref is not None
    motif_upper = restrict_motif.upper() if check_motif else None

    try:
        records = bam.fetch(chrom, start, end)

        if len(reads) == 0:
            return

        for rec in records:
            if rec.is_unmapped or rec.seq is None:
                continue

            if rec.qname not in reads:
                continue

            if primary_only and (rec.is_secondary or rec.is_supplementary):
                continue

            try:
                mods = rec.modified_bases
            except Exception:
                continue  # best effort: records whose MM/ML tags pysam cannot parse are skipped

            if not mods:
                continue

            ap = {k: v for (k, v) in rec.get_aligned_pairs() if k is not None}

            for (canonical_base, strand, mod_code), calls in mods.items():
                mod_type = str(mod_code)

                for query_pos, qual in calls:
                    if qual == -1:
                        continue  # skip unknown implicit probabilities (?-tagged)

                    genome_pos = ap.get(query_pos)
                    if genome_pos is None:
                        continue

                    if rec.is_reverse:
                        genome_pos -= (int(motifsize)-1)

                    p_mod = qual/255
                    p_can = 1-p_mod

                    assert p_mod <= 1.0

                    methstate = 0

                    if p_mod > meth_thresh:
                        methstate = 1

                    if p_can > can_thresh:
                        methstate = -1

                    if check_motif:
                        if genome_pos < 0:
                            continue

                        try:
                            ref_motif = restrict_ref.fetch(rec.reference_name, genome_pos, genome_pos+motifsize)
                            if motifsize == 1 and rec.is_reverse:
                                ref_motif = rc(ref_motif)
                        except ValueError:
                            logger.warning('warning, out of bounds motif at position %d in read: %s' % (genome_pos, rec.tostring()))
                            continue

                        if ref_motif.upper() != motif_upper:
                            continue

                    yield (rec.qname, rec.reference_name, genome_pos, p_mod, methstate, mod_type)

    finally:
        bam.close()
        if restrict_ref is not None:
            restrict_ref.close()

def parse_ctbam(bam_fn, reads, chrom, start, end, motifsize=2, meth_thresh=0.8, can_thresh=0.8, restrict_motif=None, restrict_ref=None, primary_only=False, retry=5):
    ''' Yield 5mC calls from C/T-substitution alignments (records with an MD tag) in chrom:start-end.

    Only mapped records with a sequence whose name is in reads are used.
    Nothing is yielded when reads is empty. Per-site (pos, mod_prob) pairs
    come from bs_parse_mm.

    Each yielded tuple is (read_name, chrom, pos, mod_prob, methcall, 'm')
    with methcall 1 when mod_prob > meth_thresh, -1 when 1 - mod_prob >
    can_thresh (the same convention as parse_methbam), else 0.

    motifsize, restrict_motif and restrict_ref are accepted for signature
    parity with parse_methbam and are not used. Opening the bam is retried up
    to retry times, one second apart, before the program exits.
    '''
    retries = 0

    while True:
        try:
            bam = pysam.AlignmentFile(bam_fn)
            break

        except OSError as e:
            if retries >= retry:
                sys.exit(e)

            logger.warning(f'retry open on {bam_fn}, possibly due to network filesystem')
            sleep(1)
            retries += 1

    try:
        for read in bam.fetch(chrom, start, end):
            if read.is_unmapped or read.seq is None:
                continue

            if len(reads) == 0:
                continue

            if read.qname not in reads:
                continue

            if primary_only and (read.is_secondary or read.is_supplementary):
                continue

            if not read.has_tag('MD'):
                continue

            for pos, mod_prob in bs_parse_mm(read):
                methcall = 0

                # thresholds now honour meth_thresh/can_thresh (previously a
                # hard-coded 0.8/0.2, i.e. the default values)
                if mod_prob > meth_thresh:
                    methcall = 1

                if 1 - mod_prob > can_thresh:
                    methcall = -1

                yield (read.qname, read.reference_name, pos, mod_prob, methcall, 'm')

    finally:
        bam.close()


def get_segmeth_calls(args, bam_fn, mod_names, meth_dbs, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam, bedmethyl, ct_bams, meth_thresh, can_thresh, lowmethread_thresh, highmethread_thresh):
    """Summarise modification calls over one segment for each name in ``mod_names``.

    There are three input modes:
      * bedmethyl: ``bam_fn`` is a tabix-indexed bedMethyl-style file. Column
        10 (c[9]) is taken as the site count and column 11 (c[10]) as a
        percentage; these are turned into methylated (1) and unmethylated (-1)
        call counts.
      * methbam: calls are parsed from the BAM itself, with ``parse_methbam``
        or, for BAMs listed in ``ct_bams``, ``parse_ctbam``.
      * otherwise: calls are looked up per read in the sqlite ``meth_dbs``.

    In the read-based modes, per-read modified fractions are also compared to
    ``lowmethread_thresh``/``highmethread_thresh``. When ``args.predict_dmr``
    is set and ``phase`` is not, a mixture model is fitted to those fractions.

    Returns:
        (seg_result, info).
        seg_result maps modname -> defaultdict(int) of call value -> count.
        info is (chrom, seg_start, seg_end, seg_name, seg_strand, read_count,
        motif_sites, out_means, dmr_metrics, lowmeth_count, highmeth_count).
        read_count is 'NA' in bedmethyl mode. When a DMR call is rejected by
        the thresholds, out_means is 'NA' and dmr_metrics gets no entry for
        that modname.
    """
    lowmeth_count = {}
    highmeth_count = {}
    motif_sites = dd(dict)

    if bedmethyl:
        for modname in mod_names:
            lowmeth_count[modname] = 0
            highmeth_count[modname] = 0

        seg_result = {}
        out_means = {}
        dmr_metrics = {}

        bed = pysam.Tabixfile(bam_fn)

        try:
            if chrom in bed.contigs:
                for modname in mod_names:
                    seg_meth_calls = dd(int)
                    for rec in bed.fetch(chrom, seg_start, seg_end):
                        c = rec.strip().split('\t')
                        assert len(c) >= 11

                        if c[3] != modname:
                            continue

                        methfrac = float(c[10])/100.0

                        motif_sites[modname][c[1]] = True

                        numsites = int(c[9])

                        meth_count = int(numsites*methfrac)
                        unmeth_count = numsites-meth_count
                        assert unmeth_count >= 0

                        seg_meth_calls[1] += meth_count
                        seg_meth_calls[-1] += unmeth_count

                    seg_result[modname] = seg_meth_calls
                    out_means[modname] = 'NA'
                    dmr_metrics[modname] = 'NA'
        finally:
            bed.close()

        return seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand, 'NA', motif_sites, out_means, dmr_metrics, lowmeth_count, highmeth_count)

    if 'spanning_only' not in vars(args):
        args.spanning_only=False

    if hasattr(args, 'excl_ambig') and args.excl_ambig:
        reads = exclude_ambiguous_reads(bam_fn, chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), spanning_only=args.spanning_only)
    else:
        reads = get_reads(bam_fn, chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), spanning_only=args.spanning_only, primary_only=args.primary_only)

    reads = list(set(reads))

    if phase:
        phased_reads_dict = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=True)
        reads = [r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase]

    reads = set(reads)
    seg_reads = dd(dict)

    if methbam:
        parse_func = parse_methbam
        multi_warn = True

        if ct_bams is not None and bam_fn in ct_bams:
            parse_func = parse_ctbam
            multi_warn = False

        for row in parse_func(bam_fn, reads, chrom, seg_start, seg_end, meth_thresh=meth_thresh, can_thresh=can_thresh, restrict_ref=args.ref, restrict_motif=args.motif, primary_only=args.primary_only):
            index, cg_chrom, cg_start, stat, methstate, modname = row

            if chrom != cg_chrom:
                continue

            if cg_start < seg_start or cg_start > seg_end:
                continue

            motif_sites[modname][cg_start] = True

            cg_seg_start = (cg_start - seg_start) + 1

            if index not in seg_reads[modname]:
                seg_reads[modname][index] = Read(index, cg_seg_start, stat, methstate, modname)
            else:
                seg_reads[modname][index].add_mod(cg_seg_start, stat, methstate, modname, warn=multi_warn)

    else:
        # bound parameter: read names are never spliced into the SQL text
        query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"
        conns = [sqlite3.connect(meth_db) for meth_db in meth_dbs]

        try:
            for index in reads:
                for conn in conns:
                    for cg_chrom, cg_start, stat, methstate, modname in conn.execute(query, (index,)):
                        if chrom != cg_chrom:
                            continue

                        if cg_start < seg_start or cg_start > seg_end:
                            continue

                        motif_sites[modname][cg_start] = True

                        cg_seg_start = (cg_start - seg_start) + 1

                        if index not in seg_reads[modname]:
                            seg_reads[modname][index] = Read(index, cg_seg_start, stat, methstate, modname)
                        else:
                            seg_reads[modname][index].add_mod(cg_seg_start, stat, methstate, modname)
        finally:
            for conn in conns:
                conn.close()

    if args.max_read_density is not None:
        for modname in seg_reads:
            seg_reads[modname] = densecall_filter(seg_reads[modname], max_density=float(args.max_read_density))

    seg_result = {}
    seen_reads = {}
    out_means = {}
    dmr_metrics = {}

    modfrac_read_dist = {}

    # DMR prediction is skipped for phased runs. This is decided locally so the
    # shared args namespace is not mutated for later calls.
    predict_dmr = args.predict_dmr and not phase

    for modname in mod_names:
        seg_meth_calls = dd(int)
        lowmeth_count[modname] = 0
        highmeth_count[modname] = 0
        modfrac_read_dist[modname] = []

        for name, read in seg_reads[modname].items():
            if read.modname != modname:
                continue

            seen_reads[name] = 1

            for call in read.meth_calls.values():
                seg_meth_calls[call] += 1

            modfrac_read = read.mod_frac_region(0, seg_end-seg_start)

            modfrac_read_dist[modname].append(modfrac_read)

            if lowmethread_thresh is not None and modfrac_read <= float(lowmethread_thresh):
                lowmeth_count[modname] += 1

            if highmethread_thresh is not None and modfrac_read >= float(highmethread_thresh):
                highmeth_count[modname] += 1

        seg_result[modname] = seg_meth_calls

        if predict_dmr:
            means, sigmetrics = mixmodel(modfrac_read_dist[modname])

            # reject the two-group call if any quality threshold fails
            # (evaluated left to right, stopping at the first failure)
            if sigmetrics is not None and (
                    any(int(args.dmr_minreads) > n for n in sigmetrics['group_sizes'])
                    or float(args.dmr_minratio) > sigmetrics['group_ratio']
                    or float(args.dmr_maxoverlap) < sigmetrics['overlap_fraction']
                    or float(args.dmr_mindiff) > sigmetrics['raw_difference']
                    or int(args.dmr_minmotifs) > len(motif_sites[modname])):
                sigmetrics = None

            if sigmetrics is None:
                out_means[modname] = 'NA'
            else:
                out_means[modname] = ','.join(map(lambda x: str(round(x, 2)), means))
                dmr_metrics[modname] = sigmetrics
        else:
            out_means[modname] = 'NA'
            dmr_metrics[modname] = 'NA'

    read_count = len(seen_reads)

    return seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand, read_count, motif_sites, out_means, dmr_metrics, lowmeth_count, highmeth_count)


def mixmodel(fracs, reg_covar=1e-4, max_iter=1000, max_means=2):
    """Fit Gaussian mixtures with 1..max_means components and keep the lowest-BIC fit.

    Each candidate is a ``GaussianMixture`` with ``n_init=max_means`` and
    ``random_state=1``. Convergence warnings are silenced while fitting.

    Args:
        fracs: sequence of per-read modified fractions.

    Returns:
        (means, significance_metrics).
        * Empty input: ((0.0, 0.0), None).
        * One value v: ((v, v), None).
        * Otherwise: a sorted list of the chosen model's component means
          (duplicated when one component wins), plus the result of
          ``assess_mean_separation`` for that model.
    """
    n_obs = len(fracs)

    if n_obs == 0:
        return (0.0, 0.0), None

    if n_obs == 1:
        return (fracs[0], fracs[0]), None

    X = np.asarray(fracs).reshape(-1, 1)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)

        candidates = [
            GaussianMixture(k, reg_covar=reg_covar, max_iter=max_iter, n_init=max_means, random_state=1).fit(X)
            for k in np.arange(1, max_means + 1)
        ]

    bic_scores = [model.bic(X) for model in candidates]
    chosen = candidates[np.argmin(bic_scores)]

    metrics = assess_mean_separation(X.flatten(), chosen)

    component_means = list(chosen.means_.reshape(1, -1)[0])

    # always report two means, repeating the single one if needed
    if len(component_means) == 1:
        component_means = component_means * 2

    return sorted(component_means), metrics


def assess_mean_separation(data, gmm_model):
    """Quantify how well a two-component mixture separates ``data``.

    Args:
        data: 1-D numpy array of the values the model was fitted on.
        gmm_model: fitted mixture exposing ``n_components``, ``means_``,
            ``covariances_``, ``weights_`` and ``predict()``.

    Returns:
        None unless the model has exactly two components. Otherwise a dict:
        raw_difference, cohens_d (weight-pooled SD), separation_ratio
        (difference / mean SD), overlap_fraction (overlap of the two
        mean +/- 2 SD intervals over their summed widths), t_statistic and
        raw_p_value (``scipy.stats.ttest_ind`` of the two predicted groups,
        None unless both groups have more than one point), group_sizes and
        group_ratio (smaller group / total). "Group 1" is always the
        lower-mean component.
    """
    if gmm_model.n_components != 2:
        return None

    centres = gmm_model.means_.flatten()
    variances = gmm_model.covariances_.flatten()
    mix = gmm_model.weights_

    # sort both components by mean so the first is always the lower one
    rank = np.argsort(centres)
    m_lo, m_hi = centres[rank]
    s_lo, s_hi = np.sqrt(variances[rank])
    w_lo, w_hi = mix[rank]

    diff = abs(m_hi - m_lo)

    pooled = np.sqrt(((w_lo * s_lo**2) + (w_hi * s_hi**2)) / (w_lo + w_hi))
    if pooled > 0:
        d = diff / pooled
    else:
        d = float('inf')

    mean_sd = (s_lo + s_hi) / 2
    if mean_sd > 0:
        sep = diff / mean_sd
    else:
        sep = float('inf')

    lo_top, lo_bottom = m_lo + 2*s_lo, m_lo - 2*s_lo
    hi_top, hi_bottom = m_hi + 2*s_hi, m_hi - 2*s_hi
    shared = max(0, min(lo_top, hi_top) - max(lo_bottom, hi_bottom))
    span = lo_top - lo_bottom + hi_top - hi_bottom
    if span > 0:
        overlap = shared / span
    else:
        overlap = 1.0

    labels = gmm_model.predict(data.reshape(-1, 1))
    low_group = data[labels == rank[0]]
    high_group = data[labels == rank[1]]

    if len(low_group) > 1 and len(high_group) > 1:
        t_stat, p_value = ss.ttest_ind(low_group, high_group)
    else:
        t_stat, p_value = None, None

    n_low, n_high = len(low_group), len(high_group)

    return {
        'raw_difference': diff,
        'cohens_d': d,
        'separation_ratio': sep,
        'overlap_fraction': overlap,
        't_statistic': t_stat,
        'raw_p_value': p_value,
        'group_sizes': [n_low, n_high],
        'group_ratio': min(n_low, n_high)/(n_low+n_high)
    }


def get_meth_locus(args, bam, meth_dbs, mod, meth_thresh=0.8, can_thresh=0.8, allele=None, variant=None, phase=None, methbam=False, HP_only=False, restrict_motif=None, restrict_ref=None, ct_bams=None):
    """Collect per-read calls for modification ``mod`` over ``args.interval`` (used for locus plots).

    Reads come from ``bam``, optionally excluding ambiguous alignments. They
    can then be filtered by haplotype (``phase``), by allele at ``variant``
    and, with ``args.unambig_highlight``, by ambiguity across the span of the
    highlighted regions. Calls are read from the BAM itself when ``methbam``
    is set (``parse_ctbam`` for BAMs listed in ``ct_bams``), otherwise from
    the sqlite ``meth_dbs``. Calls with methstate 0 are dropped.

    Args:
        args: parsed command-line namespace. Reads ``interval``
            ('chrom:start-end', commas allowed), ``min_mapq``,
            ``primary_only``, ``motifsize``, ``max_read_density`` and the
            optional ``excl_ambig`` / ``unambig_highlight`` / ``highlight``.

    Returns:
        dict mapping read name -> Read. Positions passed to Read are
        (cg_start - interval start) + 1.

    Exits the process if a sqlite database path does not exist.
    """
    if phase:
        logger.info('fetching reads from %s with mod %s on phase %s' % (bam, mod, phase))
    else:
        logger.info('fetching reads from %s with mod %s' % (bam, mod))

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))

    if not methbam:
        for meth_db in meth_dbs:
            if not os.path.exists(meth_db):
                sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    # get relevant reads (excludes reads not anchored outside interval)
    if hasattr(args, 'excl_ambig') and args.excl_ambig:
        reads = exclude_ambiguous_reads(bam, chrom, elt_start, elt_end, min_mapq=int(args.min_mapq))
    else:
        reads = get_reads(bam, chrom, elt_start, elt_end, min_mapq=int(args.min_mapq), primary_only=args.primary_only)

    # a set keeps every later membership test (here and in the parsers) O(1)
    reads = set(reads)

    if phase:
        phased_reads_dict = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=HP_only)
        reads = {r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase}

    if variant:
        var_reads_dict = get_variant_reads(bam, chrom, elt_start, elt_end, variant, min_mapq=int(args.min_mapq))
        reads = {r for r in reads if r in var_reads_dict and var_reads_dict[r] == allele}

    if hasattr(args, 'unambig_highlight') and hasattr(args, 'highlight'):
        if args.unambig_highlight and args.highlight:
            h_coords = []
            for h in args.highlight.split(','):
                if ':' in h:
                    h = h.split(':')[-1]

                h_coords += map(int, h.split('-'))

            h_coords.sort()

            # drop reads ambiguous anywhere across the outermost highlight span
            h_start, h_end = h_coords[0], h_coords[-1]

            excl_reads = set(get_ambiguous_reads(bam, chrom, h_start, h_end, min_mapq=int(args.min_mapq)))
            reads = {read for read in reads if read not in excl_reads}

    seg_reads = {}

    if methbam:
        parse_func = parse_methbam
        multi_warn = True

        if ct_bams is not None and bam in ct_bams:
            parse_func = parse_ctbam
            multi_warn = False

        for row in parse_func(bam, reads, chrom, elt_start, elt_end, motifsize=args.motifsize, meth_thresh=meth_thresh, can_thresh=can_thresh, restrict_motif=restrict_motif, restrict_ref=restrict_ref, primary_only=args.primary_only):
            index, cg_chrom, cg_start, stat, methstate, modname = row

            if methstate == 0 or modname != mod or chrom != cg_chrom:
                continue

            if cg_start < elt_start or cg_start > elt_end:
                continue

            cg_seg_start = (cg_start - elt_start) + 1

            if index not in seg_reads:
                seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
            else:
                seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname, warn=multi_warn)

    else:
        # bound parameter: read names are never spliced into the SQL text
        query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"
        conns = [sqlite3.connect(meth_db) for meth_db in meth_dbs]

        try:
            for index in reads:
                for conn in conns:
                    for cg_chrom, cg_start, stat, methstate, modname in conn.execute(query, (index,)):
                        if methstate == 0 or modname != mod or chrom != cg_chrom:
                            continue

                        if cg_start < elt_start or cg_start > elt_end:
                            continue

                        cg_seg_start = (cg_start - elt_start) + 1

                        if index not in seg_reads:
                            seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                        else:
                            seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
        finally:
            for conn in conns:
                conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    return seg_reads


def get_meth_profile_composite(args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, phase, meth_thresh, can_thresh):
    """Build smoothed modification profiles of one element, projected onto a reference TE.

    For each BAM in ``data`` the steps are:
      1. collect calls for ``use_mod`` over the segment;
      2. compute a sliding-window modified fraction over dense motif-site ranks;
      3. smooth it;
      4. translate positions onto the reference element (``args.teref``)
         through a local alignment of the genomic segment (reverse-complemented
         on the '-' strand).

    Args:
        data: mapping of BAM path -> sqlite database path. The database is
            only used when ``methbam`` is False.

    Returns:
        dict. On success the key is the BAM basename without extension (plus
        '.phase<phase>' when ``phase`` is set), and the value is
        (ref_positions, smoothed_fractions, 'chrom_start_end_strand').
        When a BAM is skipped (too few calls or windows, no alignment, or an
        alignment too short) the key is the BAM path itself and the value is None.
    """
    per_bam_results = {}

    if not data:
        return per_bam_results

    # bound parameter: read names are never spliced into the SQL text
    query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"

    # The reference TE sequence and the genome FASTA are the same for every
    # BAM, so load them once (the FASTA was previously reopened per BAM and
    # never closed).
    te_ref_seq = single_seq_fa(args.teref).upper()
    ref = pysam.Fastafile(args.ref)

    try:
        for bam in data:
            logger.info('profiling %s: %s:%d-%d:%s:%s phase: %s' % (bam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, str(phase)))

            if args.excl_ambig:
                reads = exclude_ambiguous_reads(bam, seg_chrom, seg_start, seg_end, min_mapq=int(args.min_mapq))
            else:
                reads = get_reads(bam, seg_chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), primary_only=args.primary_only)

            if phase:
                phased_reads_dict = get_phased_reads(bam, seg_chrom, seg_start, seg_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=True)
                reads = [r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase]

            reads = set(reads)

            seg_reads = {}
            conn = None if methbam else sqlite3.connect(data[bam])

            try:
                if methbam:
                    rows = parse_methbam(bam, reads, seg_chrom, seg_start, seg_end, motifsize=len(args.motif), meth_thresh=meth_thresh, can_thresh=can_thresh, restrict_motif=args.motif, restrict_ref=args.ref, primary_only=args.primary_only)
                else:
                    # prepend the read name so both sources share one row layout
                    rows = ((index,) + tuple(row) for index in reads for row in conn.execute(query, (index,)))

                for index, cg_chrom, cg_start, stat, methstate, modname in rows:
                    if seg_chrom != cg_chrom:
                        continue

                    if cg_start < seg_start or cg_start > seg_end:
                        continue

                    if modname != use_mod:
                        continue

                    cg_seg_start = cg_start - seg_start

                    if index not in seg_reads:
                        seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                    else:
                        seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
            finally:
                if conn is not None:
                    conn.close()

            if args.max_read_density is not None:
                seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

            meth_table = dd(dict)
            sample = '.'.join(bam.split('.')[:-1])
            call_count = 0

            for name, read in seg_reads.items():
                for loc in read.llrs.keys():
                    uuid = str(uuid4())
                    meth_table[uuid]['loc'] = loc
                    meth_table[uuid]['modstat'] = read.llrs[loc]
                    meth_table[uuid]['read'] = name
                    meth_table[uuid]['sample'] = sample
                    meth_table[uuid]['call'] = read.meth_calls[loc]

                    if read.meth_calls[loc] in (1, -1):
                        call_count += 1

            if call_count < int(args.mincalls):
                logger.warning('too few calls on seg: %s:%d-%d (%d)' % (seg_chrom, seg_start, seg_end, call_count))
                per_bam_results[bam] = None
                continue

            meth_table = pd.DataFrame.from_dict(meth_table).T
            meth_table['loc'] = pd.to_numeric(meth_table['loc'])
            meth_table['modstat'] = pd.to_numeric(meth_table['modstat'])

            # window over dense site ranks, keeping a map back to segment offsets
            meth_table['orig_loc'] = meth_table['loc']
            meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

            cpg_to_coord = {}
            for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
                cpg_to_coord[new_loc] = orig_loc

            windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

            if len(windowed_methfrac) <= int(args.smoothwindowsize):
                logger.warning('too few sites after windowing: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
                per_bam_results[bam] = None
                continue

            smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

            coord_meth_pos = []

            for cpg in windowed_methfrac.keys():
                if cpg in cpg_to_coord:
                    if seg_strand == '+':
                        coord_meth_pos.append(cpg_to_coord[cpg])
                    if seg_strand == '-':
                        coord_meth_pos.append((seg_end-seg_start)-cpg_to_coord[cpg]-2)

            # alignment to ref elt

            elt_seq = ref.fetch(seg_chrom, seg_start, seg_end)

            if seg_strand == '-':
                elt_seq = rc(elt_seq)

            elt_seq = elt_seq.upper()

            s_ref = skseq.DNA(te_ref_seq)
            s_elt = skseq.DNA(elt_seq)

            try:
                aln_res = skalign.local_pairwise_align_ssw(s_ref, s_elt)
            except IndexError: # scikit-bio throws this if no bases align  >:|
                logger.warning('no align on seg: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
                per_bam_results[bam] = None
                continue

            coord_ref, coord_elt = aln_res[2]

            len_ref = coord_ref[1] - coord_ref[0]
            len_elt = coord_elt[1] - coord_elt[0]

            if len_ref / len(te_ref_seq) < float(args.lenfrac):
                logger.warning('ref align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_ref / len(te_ref_seq)))
                per_bam_results[bam] = None
                continue

            if len_elt / len(elt_seq) < float(args.lenfrac):
                logger.warning('elt align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_elt / len(elt_seq)))
                per_bam_results[bam] = None
                continue

            tab_msa = aln_res[0]

            # walk the alignment columns: element offset -> reference offset,
            # or 'na' where the element base is an insertion relative to ref
            elt_to_ref_coords = {}

            pos_ref = coord_ref[0]
            pos_elt = coord_elt[0]

            for pos in tab_msa.iter_positions():
                pos = list(pos)
                b_ref = pos[0]
                b_elt = pos[1]

                if '-' not in pos:
                    elt_to_ref_coords[pos_elt] = pos_ref
                    pos_ref += 1
                    pos_elt += 1

                if b_elt == '-':
                    pos_ref += 1

                if b_ref == '-':
                    elt_to_ref_coords[pos_elt] = 'na'
                    pos_elt += 1

            revised_coord_meth_pos = []
            meth_profile = []

            for pos, meth in zip(coord_meth_pos, smoothed_methfrac):
                if pos not in elt_to_ref_coords:
                    continue

                revised_pos = elt_to_ref_coords[pos]

                if revised_pos != 'na':
                    revised_coord_meth_pos.append(revised_pos)
                    meth_profile.append(meth)

            bam_noext = '.'.join(os.path.basename(bam).split('.')[:-1])
            elt_info = '_'.join( map(str, (seg_chrom, seg_start, seg_end, seg_strand)))
            if phase:
                bam_noext += '.phase' + phase
            per_bam_results[bam_noext] = (revised_coord_meth_pos, meth_profile, elt_info)

    finally:
        ref.close()

    return per_bam_results


def get_meth_calls_wg(args, bam_fn, meth_fn, chrom, seg_start, seg_end, phased, mod, motifsize, ct_bams, meth_thresh, can_thresh):
    """Tabulate per-position call counts for chrom:[seg_start, seg_end).

    Reads come from ``get_phased_reads``, which maps read name -> phase.
    Calls are taken from the BAM (``args.methdb is None``; ``parse_ctbam``
    whenever ``ct_bams`` is given) or from the sqlite file ``meth_fn``.
    Both sources now use the same half-open interval. ``mod`` (if not None)
    restricts the modification name.

    Returns:
        [table1, table2]: defaultdicts keyed by position (per-read location +
        seg_start), each entry holding 'chr', 'N' (calls other than 0) and
        'X' (calls equal to 1). With ``phased``, only reads whose phase is
        '1' or '2' contribute, to the matching table. Without it, every read
        goes to table1 and table2 stays empty.

    Exits the process if the sqlite file is required but missing.
    """
    # without a sqlite database, calls are read from the BAM itself
    methbam = args.methdb is None

    if not methbam and not os.path.exists(meth_fn):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_fn)

    reads = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=phased, min_mapq=int(args.min_mapq), HP_only=True)

    seg_reads = {}

    if methbam:
        parse_func = parse_methbam
        multi_warn = True

        if ct_bams is not None:
            parse_func = parse_ctbam
            multi_warn = False

        for row in parse_func(bam_fn, set(reads), chrom, seg_start, seg_end, meth_thresh=meth_thresh, can_thresh=can_thresh, restrict_ref=args.ref, restrict_motif=args.motif, motifsize=motifsize, primary_only=args.primary_only):
            index, cg_chrom, cg_start, stat, methstate, modname = row

            if mod is not None and mod != modname:
                continue

            if chrom != cg_chrom:
                continue

            if cg_start < seg_start or cg_start >= seg_end:
                continue

            if index not in reads:
                continue

            cg_seg_start = (cg_start - seg_start) + 1

            if index not in seg_reads:
                seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads[index])
            else:
                seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname, warn=multi_warn)

    else:
        # bound parameter: read names are never spliced into the SQL text
        query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"
        conn = sqlite3.connect(meth_fn)

        try:
            for index in reads:
                for cg_chrom, cg_start, stat, methstate, modname in conn.execute(query, (index,)):
                    if mod is not None and mod != modname:
                        continue

                    if chrom != cg_chrom:
                        continue

                    # Half-open [seg_start, seg_end), matching the BAM branch above.
                    # Presumably this keeps a site on the boundary between
                    # adjacent windows from being counted twice.
                    if cg_start < seg_start or cg_start >= seg_end:
                        continue

                    cg_seg_start = (cg_start - seg_start) + 1

                    if index not in seg_reads:
                        seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads[index])
                    else:
                        seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
        finally:
            conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    meth_data = {1: dd(list), 2: dd(list)}
    meth_table = {1: dd(dict), 2: dd(dict)}

    for read in seg_reads.values():
        if phased:
            # only reads assigned to haplotype '1' or '2' contribute
            # (None, 'unphased' and anything else are skipped)
            if read.phase not in ('1', '2'):
                continue

            phase = int(read.phase)
        else:
            # pass through as phase 1 if not analysing phases
            phase = 1

        for loc in read.llrs.keys():
            meth_data[phase][loc].append(read.meth_calls[loc])

    for phase in (1, 2):
        for loc, calls in meth_data[phase].items():
            N = len([call for call in calls if call != 0]) # call 0 == ambiguous
            X = len([call for call in calls if call == 1]) # call 1 == methylated

            if N > 0:
                pos = loc+seg_start
                meth_table[phase][pos]['chr'] = chrom
                meth_table[phase][pos]['N'] = N
                meth_table[phase][pos]['X'] = X

    return [meth_table[1], meth_table[2]]


def slide_window(meth_table, sample, width=20, slide=2):
    """Sliding-window modified fraction for one sample.

    Rows of ``meth_table`` whose 'sample' equals ``sample`` are used. Windows
    of ``width`` advance by ``slide`` from around the smallest to the
    largest 'loc'. Bounds are exclusive. Only rows with 'call' 1 (modified)
    or -1 (unmodified) are counted, and windows with no such rows are
    omitted.

    Returns:
        (fractions, counts): two dicts keyed by integer window midpoint,
        giving modified / (modified + unmodified) and the denominator.
    """
    rows = meth_table[meth_table['sample'] == sample]

    lowest = rows['loc'].min()
    highest = rows['loc'].max()

    calls = rows['call']
    meth_locs = rows[calls == 1]['loc'].values
    unmeth_locs = rows[calls == -1]['loc'].values

    def _inside(positions, lo, hi):
        return np.sum((positions > lo) & (positions < hi))

    left = int(lowest - width/2)
    right = left + width

    frac_by_mid = {}
    n_by_mid = {}

    while int((left + right)/2) < highest:
        left += slide
        right += slide

        n_meth = _inside(meth_locs, left, right)
        n_unmeth = _inside(unmeth_locs, left, right)
        total = n_meth + n_unmeth

        if total > 0:
            mid = int((left + right)/2)
            frac_by_mid[mid] = n_meth/total
            n_by_mid[mid] = total

    return frac_by_mid, n_by_mid


def slide_window_cover(cover_table, sample, width=20, slide=2):
    """Mean 'cover' of ``sample`` in sliding windows (used for locus and composite plots).

    The window range spans the minimum to the maximum 'loc' over the whole
    table (all samples). Windows advance by ``slide`` and their bounds are
    exclusive. Each window's mean is keyed by its integer midpoint. A window
    with no rows for ``sample`` yields NaN.

    Returns:
        dict mapping midpoint -> mean coverage.
    """
    locs = cover_table['loc']
    first_loc = min(locs)
    last_loc = max(locs)

    is_sample = cover_table['sample'] == sample

    lower = int(first_loc - width/2)
    upper = lower + width

    windows = {}

    while int((lower + upper)/2) < last_loc:
        lower += slide
        upper += slide

        hits = cover_table.loc[is_sample & (locs > lower) & (locs < upper)]
        windows[int((lower + upper)/2)] = np.mean(hits['cover'])

    return windows


def smooth(x, window_len=8, window='hanning'):
    """Smooth the 1-D numpy array ``x`` by convolution with a normalised window.

    Args:
        x: numpy array to smooth.
        window_len: window size. Odd sizes are bumped up by one; sizes below 3
            disable smoothing.
        window: 'flat', 'hanning', 'hamming', 'bartlett' or 'blackman'.
            Unrecognised names fall back to 'hanning'.

    Returns:
        an array the same length as ``x``. ``x`` itself is returned unchanged
        when it has fewer than ``window_len`` points (a warning is logged) or
        when ``window_len`` is below 3.
    """
    if x.size < window_len:
        logger.warning('cannot smooth segment: fewer data points than window_len')
        return x

    if window_len < 3:
        return x

    # the output trim at the end assumes an even-length window
    if window_len % 2:
        window_len += 1

    if window == 'flat':
        kernel = np.ones(window_len, 'd')
    else:
        makers = {
            'hanning': np.hanning,
            'hamming': np.hamming,
            'bartlett': np.bartlett,
            'blackman': np.blackman,
        }
        kernel = makers.get(window, np.hanning)(window_len)

    kernel = kernel / kernel.sum()

    # mirror both ends so the edges are smoothed with real neighbours
    padded = np.pad(x, (window_len-1, window_len-1), mode='reflect')
    smoothed = np.convolve(kernel, padded, mode='valid')

    half = int(window_len/2)
    return smoothed[half-1:-half]


def mask_methfrac(data, cutoff=20):
    """Return the runs of consecutive indices whose value is not above ``cutoff``.

    Used for locus plots. ``cutoff`` is passed through int() first.

    Returns:
        list of index lists, one per maximal run, in ascending order.
    """
    above = np.asarray(data) > int(cutoff)

    return [
        list(run)
        for is_above, run in itertools.groupby(range(len(above)), key=lambda i: above[i])
        if not is_above
    ]


def get_bed_annotations(fn, chrom, elt_start, elt_end):
    """Read BED intervals on ``chrom`` that overlap [elt_start, elt_end], clipped to that window.

    Expected columns: chrom start end [label [strand [colour]]]. Strand must
    be '+' or '-' when present. Blank lines, '#' comments and 'track' /
    'browser' header lines are skipped.

    Returns:
        list of BED(start, end, label, strand, colour) namedtuples in file
        order. Missing optional columns are None.

    Exits the process if a strand column is anything other than '+' or '-'.
    """
    # one record type for the whole file (was re-created for every line)
    BED = namedtuple("BED", "start end label strand colour")
    ann = []

    elt_start = int(elt_start)
    elt_end = int(elt_end)

    with open(fn) as bed:
        for bed_rec in bed:
            c = bed_rec.strip().split()

            # these lines previously crashed the 3-way unpack below
            if not c or c[0].startswith('#') or c[0] in ('track', 'browser'):
                continue

            bed_chrom, bed_start, bed_end = c[:3]

            if bed_chrom != chrom:
                continue

            bed_start = int(bed_start)
            bed_end = int(bed_end)

            if bed_end < elt_start or bed_start > elt_end:
                continue

            # clip to the element window
            bed_start = max(bed_start, elt_start)
            bed_end = min(bed_end, elt_end)

            bed_label = c[3] if len(c) > 3 else None
            strand = None
            bed_colour = None

            if len(c) > 4:
                if c[4] not in ('+', '-'):
                    sys.exit('format for --bed input must be: chrom start end label strand colour')

                strand = c[4]

            if len(c) > 5:
                bed_colour = c[5]

            ann.append(BED(start=bed_start, end=bed_end, label=bed_label, strand=strand, colour=bed_colour))

    return ann


def find_motif_windows(region_seq, motifs, start, end, motifs_per_window):
    """Split start..end into windows that each contain ``motifs_per_window`` motif hits.

    ``region_seq`` is scanned left to right. A position counts as one hit
    when the substring starting there equals any motif or the ``rc()`` of
    any motif. Overlapping hits are all counted, but each position counts at
    most once. After every ``motifs_per_window`` hits, a boundary is placed at
    ``start + i + 1``, where i is the offset of the hit that completed the
    window.

    Motifs may now have different lengths. Previously only motifs as long as
    ``motifs[0]`` could match, and an empty list raised IndexError; now an
    empty list yields a single window.

    Returns:
        (w_starts, w_ends): parallel lists. The first start is ``start`` and
        the last end is ``end``.
    """
    w_starts = [start]
    w_ends = []

    # forward and reverse-complement sites in a single set, so a motif equal
    # to its own reverse complement still counts once per position
    sites = set(motifs) | {rc(m) for m in motifs}
    site_lens = sorted({len(s) for s in sites})

    seq_len = len(region_seq)

    # scan every position where at least the shortest site still fits
    n_positions = seq_len - site_lens[0] + 1 if site_lens else 0

    motif_count = 0

    for i in range(n_positions):
        if any(region_seq[i:i+L] in sites for L in site_lens if i + L <= seq_len):
            motif_count += 1

            if motif_count == motifs_per_window:
                position = start + i + 1  # +1 because we've just processed this position
                w_starts.append(position)
                w_ends.append(position)
                motif_count = 0

    w_ends.append(end)

    return w_starts, w_ends


def build_genes(gtf, chrom, start, end, tx=False):
    '''
    Collect Gene objects (used for locus plots) from GTF records in chrom:start-end.

    gtf: object whose fetch(chrom, start, end) yields 9-column tab-delimited
         GTF lines (presumably a pysam TabixFile -- confirm with callers).
    tx:  group records by transcript_id instead of gene_id.

    Records lacking the grouping id are ignored; a missing gene_name /
    transcript_name falls back to the corresponding id. exon, CDS and
    transcript records are added to the matching Gene as [start, end].

    Returns a dict mapping id -> Gene.
    '''
    genes = {}

    for rec in gtf.fetch(chrom, start, end):
        (_seqname, _source, feature, rec_start, rec_end,
         _score, rec_strand, _frame, raw_attribs) = rec.split('\t')

        coords = [int(rec_start), int(rec_end)]

        # attribute column looks like: key "value"; key "value"; ...
        attrs = {}
        for item in raw_attribs.strip().split(';'):
            if not item:
                continue
            key, val = item.strip().split()[:2]
            attrs[key.strip()] = val.strip().strip('"')

        if 'gene_id' not in attrs:
            continue

        attrs.setdefault('gene_name', attrs['gene_id'])

        if tx:
            if 'transcript_id' not in attrs:
                continue
            attrs.setdefault('transcript_name', attrs['transcript_id'])
            group_id, label = attrs['transcript_id'], attrs['transcript_name']
        else:
            group_id, label = attrs['gene_id'], attrs['gene_name']

        if group_id not in genes:
            genes[group_id] = Gene(group_id, label, rec_strand)

        gene = genes[group_id]

        if feature == 'exon':
            gene.add_exon(coords)
        elif feature == 'CDS':
            gene.add_cds(coords)
        elif feature == 'transcript':
            gene.add_tx(coords)

    return genes


def sample_db(db, mod, n=1000000):
    '''
    Return a random sample of up to n per-site stat values for one mod.

    db:  path to a methylartist sqlite database containing a methdata table.
    mod: modification name to select (methdata.modname).
    n:   maximum number of values to return.

    Returns a flat numpy array (empty if no rows match). Exits the program
    if db does not exist.
    '''
    if not os.path.exists(db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % db)

    n = int(n)

    logger.info('sample %d values from %s where modname="%s"' % (n, db, mod))

    conn = sqlite3.connect(db)

    try:
        # bound parameters: a double-quoted "%s" would be read by SQLite as a
        # column identifier whenever mod matches a column name (e.g. "chrom")
        rows = conn.execute('SELECT stat FROM methdata WHERE modname=? ORDER BY RANDOM() LIMIT ?', (mod, n)).fetchall()
    finally:
        conn.close()

    return np.asarray(rows).flatten()


## tools

def scoredist(args):
    '''
    Plot kernel density estimates of per-site modification scores.

    Scores come either from one or more methylartist databases (-d/--db,
    comma-separated; --mod selects the modification) or from one or more
    .bam files (-b/--bam) sampled at --motif sites in --ref. Dashed vertical
    lines mark the methylated/unmethylated cutoffs: read from each database,
    or fixed at 0.8 / 0.2 for .bam input.

    The plot goes to --outfile if given, otherwise to a name built from the
    input basenames plus '.scoredist.png' ('.scoredist.svg' with --svg).
    '''
    sample = {}

    meth_cutoffs = []
    unmeth_cutoffs = []

    out_fn = None

    if args.db is None and args.bam is None:
        sys.exit('please specify either -d/--db or -b/--bam')

    if args.motif is not None:
        assert iupac(args.motif)

    if args.db:
        if args.bam is not None:
            sys.exit('please specify either -d/--db or -b/--bam but not both')

        dbs = args.db.split(',')

        # check paths before anything opens them (sqlite3.connect would create an empty file)
        for db in dbs:
            if not os.path.exists(db):
                sys.exit('%s not found' % db)

        avail_mods = set()
        for db in dbs:
            avail_mods.update(get_modnames(db))
        avail_mods = list(avail_mods)

        if args.mod is None or args.mod not in avail_mods:
            logger.warning('mod %s not found, available mods: %s' % (str(args.mod), ','.join(avail_mods)))
            sys.exit()

        for db in dbs:
            upper, lower = get_cutoffs(db, args.mod)
            meth_cutoffs.append(upper)
            unmeth_cutoffs.append(lower)

            sample[os.path.basename(db)] = sample_db(db, args.mod, n=int(args.n))

        out_fn = '_'.join(map(os.path.basename, dbs)) + '.scoredist'

    elif args.bam:
        if None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bam')
            sys.exit(1)

        bams = args.bam.split(',')

        for bam in bams:
            meth_cutoffs.append(0.8)
            unmeth_cutoffs.append(0.2)
            sample[os.path.basename(bam)] = sample_bam(bam, args.motif, args.ref, int(args.n))

        out_fn = '_'.join(map(os.path.basename, bams)) + '.scoredist'

    meth_cutoffs = list(set(meth_cutoffs))
    unmeth_cutoffs = list(set(unmeth_cutoffs))

    if len(meth_cutoffs) > 1:
        logger.warning('multiple upper (methylation) cutoffs:' + ','.join(map(str, meth_cutoffs)))

    if len(unmeth_cutoffs) > 1:
        logger.warning('multiple lower (unmethylation) cutoffs:' + ','.join(map(str, unmeth_cutoffs)))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    sns.kdeplot(data=sample, palette=args.palette, lw=float(args.lw))

    top = ax.get_ylim()[1]

    ax.vlines(meth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')
    ax.vlines(unmeth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')

    xmin, xmax = ax.get_xlim()

    if args.xmin:
        xmin = float(args.xmin)

    if args.xmax:
        xmax = float(args.xmax)

    ax.set_xlim(xmin, xmax)

    if args.svg:
        out_fn += '.svg'
    else:
        out_fn += '.png'

    if args.outfile is not None:
        out_fn = args.outfile

    plt.savefig(out_fn)
    plt.close(fig)  # release the figure; pyplot keeps figures alive until closed

    logger.info('plot written to %s' % out_fn)


def adjustcutoffs(args):
    '''
    Re-threshold methylation calls for one mod in an existing database.

    All methdata rows for --mod are reset to methstate 0. Rows with
    stat > --methylated are then set to 1, and rows with stat < --unmethylated
    are set to -1. The mod's row in the cutoffs table is updated to the new
    upper/lower values.

    Exits if the database is missing or does not contain --mod.
    '''
    if not os.path.exists(args.db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % args.db)

    avail_mods = get_modnames(args.db)

    if args.mod not in avail_mods:
        logger.warning('mod %s not found, available mods: %s' % (args.mod, ','.join(avail_mods)))
        sys.exit()

    upper = float(args.methylated)
    lower = float(args.unmethylated)

    conn = sqlite3.connect(args.db)

    try:
        c = conn.cursor()

        # bound parameters: keeps full threshold precision and avoids SQLite reading
        # a double-quoted mod name as a column identifier
        logger.info('%s: reset methylation states to 0 for mod %s' % (args.db, args.mod))
        c.execute('UPDATE methdata SET methstate=0 WHERE modname=?', (args.mod,))

        logger.info('%s: mark sites with stat > %f methylated (1) for mod %s' % (args.db, upper, args.mod))
        c.execute('UPDATE methdata SET methstate=1 WHERE stat > ? AND modname=?', (upper, args.mod))

        logger.info('%s: mark sites with stat < %f unmethylated (-1) for mod %s' % (args.db, lower, args.mod))
        c.execute('UPDATE methdata SET methstate=-1 WHERE stat < ? AND modname=?', (lower, args.mod))

        logger.info('%s: update cutoff table for mod %s' % (args.db, args.mod))
        c.execute('UPDATE cutoffs SET upper=?, lower=? WHERE modname=?', (upper, lower, args.mod))

        conn.commit()
    finally:
        conn.close()


def bs_parse_mm(read):
    '''
    Yield (ref_pos, mod_prob) for each reference CpG covered by a bisulfite read.

    read: aligned read (pysam AlignedSegment or equivalent) carrying an MD tag,
          so that get_aligned_pairs(with_seq=True) can report reference bases.

    A site is reported when the reference shows C then G at two adjacent
    aligned pairs with consecutive query positions, and the read base at the
    C is C or T. mod_prob is half the base-call error probability 10**(-q/10)
    of that base. When the read base is C, 1 minus that value is returned
    instead.
    '''
    aligned = read.get_aligned_pairs(with_seq=True)

    # examine every adjacent pair, including the final one
    for cur, nxt in zip(aligned, aligned[1:]):
        if None in cur or None in nxt:
            continue  # insertion, deletion or clipped position

        # reference bases come back lowercase at mismatches (e.g. a C read as T), hence upper()
        # NOTE(review): the G test is case-sensitive, so a mismatched G is not counted -- confirm intent
        if cur[2].upper() != 'C' or nxt[2] != 'G':
            continue

        seq_pos = cur[0]

        if nxt[0] - seq_pos != 1:
            continue

        seq_bp = read.query_sequence[seq_pos].upper()

        if seq_bp not in ('C', 'T'):
            continue

        mod_prob = 10**(read.query_qualities[seq_pos]/-10.0)*0.5

        if seq_bp == 'C':
            mod_prob = 1.0-mod_prob

        yield (cur[1], mod_prob)


def bs_parse_reads(bam_fn, db_fn):
    '''
    Load per-CpG calls from a bisulfite .bam into the methdata table of db_fn.

    Duplicate reads, reads without base qualities and reads without an MD tag
    are skipped. Each CpG yielded by bs_parse_mm becomes one row with
    modname 'm' and stat = mod_prob. methstate is 1 if mod_prob > 0.8,
    -1 if mod_prob < 0.2, and 0 otherwise.
    '''
    batch_size = 100000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    try:
        with pysam.AlignmentFile(bam_fn) as bam:
            rows = []

            for read in tqdm(bam.fetch(), total=bam.mapped):
                if read.is_duplicate or read.query_qualities is None:
                    continue

                if not read.has_tag('MD'):
                    continue

                strand = '-' if read.is_reverse else '+'

                for pos, mod_prob in bs_parse_mm(read):
                    methcall = 0

                    if mod_prob > 0.8:
                        methcall = 1

                    if mod_prob < 0.2:
                        methcall = -1

                    # columns: chrom, pos, strand, readname, stat, methstate, modname
                    rows.append((read.reference_name, pos, strand, read.query_name, mod_prob, methcall, 'm'))

                # insert in batches rather than one statement per CpG
                if len(rows) >= batch_size:
                    conn.executemany(insert_sql, rows)
                    rows = []

            if rows:
                conn.executemany(insert_sql, rows)

        conn.commit()
    finally:
        conn.close()


def bs_create_db(args):
    '''
    Create (or, with --append, reuse) the sqlite database for bisulfite calls.

    args.db:     output name; '.db' is appended if missing. Required.
    args.append: reuse an existing database instead of creating a new one.

    A new database gets the methdata, modnames and cutoffs tables, a readname
    index, the mod name 'm', and a cutoffs row upper=0.8 / lower=0.2. These
    match the methstate thresholds used by bs_parse_reads.

    Returns the database path. Exits if --db is missing, if the database
    already exists without --append, or if --append names a database that
    does not exist.
    '''
    if not args.db:
        sys.exit('please specify an output database name with --db')

    db_fn = args.db if args.db.endswith('.db') else args.db + '.db'

    logger.info(f'database output name: {db_fn}')

    if os.path.exists(db_fn) and not args.append:
        sys.exit(f'database {db_fn} already exists')

    if args.append and not os.path.exists(db_fn):
        sys.exit(f'database {db_fn} does not exist and --append has been called')

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

            c.execute("INSERT INTO modnames VALUES ('m')")
            # cutoffs columns are (upper, lower, modname)
            c.execute("INSERT INTO cutoffs VALUES (?, ?, 'm')", (0.8, 0.2))

        else:
            logger.info(f'appending records to {db_fn}')

        conn.commit()
    finally:
        conn.close()

    return db_fn


def db_sub(args):
    '''
    Build a bisulfite methylation database.

    Creates (or, with --append, reopens) the database named by args.db, then
    loads the CpG calls parsed from args.bam into it.
    '''
    target_db = bs_create_db(args)
    bs_parse_reads(args.bam, target_db)


def db_guppy(args):
    '''
    (deprecated) build a methylartist database from guppy fast5 output

    Requires ont_fast5_api and --force. Steps:
      1. Cache the aligned reads from args.bam (comma-separated) in
         <samplename>.bamcache.db. An existing cache is reused.
      2. Run guppy_f5_fetch on every .fast5 file in args.fast5, using
         args.procs worker processes. Each call returns a table filename or
         None.
      3. Load the returned tables into <samplename>.guppy.db.

    Each methdata row stores stat = mod_log_prob - can_log_prob, rounded to
    2 decimals. methstate is 1 if mod_log_prob >= args.minprob, and -1 if
    can_log_prob >= args.minprob (checked last, so it wins); otherwise 0.
    '-' strand positions are shifted left by args.motifsize - 1.

    Exits if the output database exists and --append was not given.
    '''
    if not ont_fast5_api_installed:
        sys.exit('ont_fast5_api is not installed but is required for this function. Please install e.g. via "pip install ont-fast5-api"')

    if not args.force:
        sys.exit('This function (db-guppy) is depreciated. Please use use guppy to create a .bam file with modification tags and use that as input to methylartist')

    assert os.path.exists(args.fast5), 'path not found: %s' % args.fast5

    bam_db = args.samplename + '.bamcache.db'

    if os.path.exists(bam_db):
        logger.info('using existing bam cache db %s' % bam_db)

    else:
        logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))
        conn = sqlite3.connect(bam_db)
        c = conn.cursor()

        logger.info('caching %s to %s' % (args.bam, bam_db))
        c.execute('''CREATE TABLE bam (readname text, sam text)''')
        c.execute('''CREATE INDEX read_index ON bam(readname)''')

        commit_interval = 100000

        for bam_fn in args.bam.split(','):
            with pysam.AlignmentFile(bam_fn) as bam:
                for i, read in enumerate(bam.fetch(), 1):
                    if read.is_secondary or read.is_supplementary or read.is_duplicate:
                        continue

                    read.query_qualities = None  # drop qualities to keep the cache small

                    # bound parameters: SAM text or read names may contain quotes
                    c.execute('INSERT INTO bam VALUES (?, ?)', (read.query_name, read.to_string()))

                    if i % commit_interval == 0:
                        logger.info('commiting %d records to %s...' % (commit_interval, bam_db))
                        conn.commit()

            logger.info('commiting remaining records to %s...' % bam_db)

        conn.commit()
        conn.close()

    fast5s = [fn for fn in os.listdir(args.fast5) if fn.endswith('.fast5')]

    logger.info('found %d fast5 files in %s' % (len(fast5s), args.fast5))

    logger.info('fetching base modifications...')

    pool = mp.Pool(processes=int(args.procs))

    results = [pool.apply_async(guppy_f5_fetch, [fast5, bam_db, args]) for fast5 in fast5s]

    outfiles = [res.get() for res in results]

    pool.close()
    pool.join()

    db_fn = args.samplename + '.guppy.db'

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    conn = sqlite3.connect(db_fn)
    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')

    else:
        logger.info('appending records to %s' % db_fn)

    minprob = float(args.minprob)
    mod_base = args.modname

    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
    minus_offset = int(args.motifsize) - 1

    for fn in outfiles:
        if fn is None:
            continue

        with open(fn) as tsv:
            logger.info('loading %s into %s' % (fn, db_fn))

            llr = None  # stays None when the file has no data rows

            for row in csv.DictReader(tsv, delimiter='\t'):
                pos      = int(row['pos'])
                strand   = row['strand']
                mod_prob = float(row['mod_log_prob'])
                can_prob = float(row['can_log_prob'])

                llr = mod_prob - can_prob

                if strand == '-':
                    pos -= minus_offset

                methcall = 0

                if mod_prob >= minprob:
                    methcall = 1

                if can_prob >= minprob:
                    methcall = -1

                assert (mod_prob + can_prob) < 1.01 # allow for some rounding error

                c.execute('INSERT INTO methdata VALUES (?,?,?,?,?,?,?)',
                          (row['chrom'], pos, strand, row['readname'], round(llr, 2), methcall, mod_base))

            c.execute('INSERT INTO modnames VALUES (?)', (args.modname,))

            if not args.append and llr is not None:
                # NOTE(review): cutoffs are taken from the last row's llr in this file,
                # which looks arbitrary -- confirm the intended thresholds
                c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (round(llr, 4), round(-1*llr, 4), args.modname))

            logger.info('commiting records from %s to %s' % (fn, db_fn))
            conn.commit()

    conn.close()
    logger.info('finished.')


def db_megalodon(args):
    '''
    Build (or append to) a methylartist database from megalodon per-read
    modified base calls.

    args.methdata:  comma-separated megalodon per-read tables (.gz allowed).
                    The columns read_id, chrm, pos, strand, mod_log_prob,
                    can_log_prob and mod_base are used.
    args.db:        output name ('.db' appended if missing). Defaults to
                    args.methdata minus its last extension, + .megalodon.db.
    args.append:    add records to an existing database instead of creating one.
    args.motifsize: motif length; '-' strand positions are shifted left by
                    motifsize - 1.
    args.minprob:   a site is called methylated (1) when exp(mod_log_prob) >=
                    minprob, and unmethylated (-1) when exp(can_log_prob) >=
                    minprob. The unmethylated test wins if both pass;
                    otherwise the call is 0.

    stat stores exp(mod_log_prob). For every input file, each mod name seen
    is added to modnames along with a cutoffs row (minprob, 1 - minprob).
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.megalodon.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

    minprob = float(args.minprob)

    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
    minus_offset = int(args.motifsize) - 1

    progress_interval = 1000000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')

    else:
        logger.info('appending records to %s' % db_fn)

    for tsv in args.methdata.split(','):
        if tsv.endswith('.gz'):
            methdata = gzip.open(tsv, 'rt')
        else:
            methdata = open(tsv)

        logger.info('parsing ' + tsv)

        modnames = {}  # dict (not set) keeps first-seen order of mod names
        ins_data = []

        with methdata:
            for i, row in enumerate(csv.DictReader(methdata, delimiter='\t'), 1):
                pos      = int(row['pos'])
                strand   = row['strand']
                mod_prob = np.exp(float(row['mod_log_prob']))
                can_prob = np.exp(float(row['can_log_prob']))
                mod_base = row['mod_base']

                modnames[mod_base] = True

                if strand == '-':
                    pos -= minus_offset

                methcall = 0

                if mod_prob >= minprob:
                    methcall = 1

                if can_prob >= minprob:
                    methcall = -1

                ins_data.append((row['chrm'], pos, strand, row['read_id'], mod_prob, methcall, mod_base))

                if i % progress_interval == 0:
                    conn.executemany(insert_sql, ins_data)
                    ins_data = []
                    logger.info('processed %d records from %s' % (i, tsv))

        if ins_data:
            conn.executemany(insert_sql, ins_data)

        for mod in modnames:
            c.execute('INSERT INTO modnames VALUES (?)', (mod,))
            c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (round(minprob, 4), round(1-minprob, 4), mod))

        logger.info('commiting records from %s to %s' % (tsv, db_fn))
        conn.commit()

    conn.close()


def db_custom(args):
    '''
    Build (or append to) a methylartist database from a user-described
    per-read call table.

    Column positions are 0-based indices: args.readname, args.chrom, args.pos,
    args.strand, args.modprob and, optionally, args.canprob (defaults to
    1 - modprob). The mod name is args.modbase for every row if set;
    otherwise it is read from column args.modbasecol. Lines are split on
    whitespace unless args.delimiter is given; args.header skips the first
    line of each file.

    A site is methylated (1) when modprob >= args.minmodprob, and
    unmethylated (-1) when canprob >= args.mincanprob (default: minmodprob).
    The unmethylated test wins if both pass; otherwise the call is 0.
    '-' strand positions are shifted left by args.motifsize - 1. For every
    input file, each mod seen is added to modnames with a cutoffs row
    (minmodprob, 1 - mincanprob).

    Exits if neither --modbase nor --modbasecol is given, or if a line has
    fewer than 5 columns.
    '''
    if args.modbase is None and args.modbasecol is None:
        sys.exit('must specify either --modbase or --modbasecol')

    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.custom.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

    c_readname = int(args.readname)
    c_chrom    = int(args.chrom)
    c_pos      = int(args.pos)
    c_strand   = int(args.strand)
    c_modprob  = int(args.modprob)
    c_canprob  = None if args.canprob is None else int(args.canprob)
    c_modbase  = None if args.modbasecol is None else int(args.modbasecol)

    minmodprob = float(args.minmodprob)
    mincanprob = minmodprob if args.mincanprob is None else float(args.mincanprob)

    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
    minus_offset = int(args.motifsize) - 1

    progress_interval = 1000000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')

    else:
        logger.info('appending records to %s' % db_fn)

    for tsv in args.methdata.split(','):
        if tsv.endswith('.gz'):
            methdata = gzip.open(tsv, 'rt')
        else:
            methdata = open(tsv)

        logger.info('parsing ' + tsv)

        modnames = {}  # dict (not set) keeps first-seen order of mod names
        ins_data = []

        with methdata:
            for i, line in enumerate(methdata):
                if args.header and i == 0:
                    continue

                if args.delimiter is not None:
                    cols = line.strip().split(args.delimiter)
                else:
                    cols = line.strip().split()

                if len(cols) < 5:
                    sys.exit('table %s has < 5 columns at line %d' % (tsv, i))

                readname = cols[c_readname]
                chrom    = cols[c_chrom]
                pos      = int(cols[c_pos])
                strand   = cols[c_strand]
                mod_prob = float(cols[c_modprob])

                if c_canprob is None:
                    can_prob = 1.0-mod_prob
                else:
                    can_prob = float(cols[c_canprob])

                # a fixed --modbase takes precedence over --modbasecol
                if args.modbase is not None:
                    mod_base = args.modbase
                else:
                    mod_base = cols[c_modbase]

                modnames[mod_base] = True

                if strand == '-':
                    pos -= minus_offset

                methcall = 0

                if mod_prob >= minmodprob:
                    methcall = 1

                if can_prob >= mincanprob:
                    methcall = -1

                ins_data.append((chrom, pos, strand, readname, mod_prob, methcall, mod_base))

                if len(ins_data) >= progress_interval:
                    conn.executemany(insert_sql, ins_data)
                    ins_data = []
                    logger.info('processed %d records from %s' % (i, tsv))

        if ins_data:
            conn.executemany(insert_sql, ins_data)

        # one modnames + cutoffs row per mod seen in this file
        for mod in modnames:
            c.execute('INSERT INTO modnames VALUES (?)', (mod,))
            c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (round(minmodprob, 4), round(1-mincanprob, 4), mod))

        logger.info('commiting records from %s to %s' % (tsv, db_fn))
        conn.commit()

    conn.close()


def db_nanopolish(args):
    '''
    Build (or append to) a methylartist database from nanopolish
    call-methylation output.

    args.methdata:   comma-separated nanopolish tables (.gz allowed). The
                     columns chromosome, strand, start, read_name,
                     log_lik_ratio and sequence are used, plus num_motifs
                     with args.scalegroup.
    args.db:         output name ('.db' appended if missing). Defaults to the
                     input name (minus its extension and any .tsv) +
                     .nanopolish.db.
    args.append:     add records to an existing database.
    args.thresh:     llr > thresh -> methstate 1, llr < -thresh -> -1,
                     otherwise 0.
    args.scalegroup: divide llr by num_motifs before thresholding.
    args.motif:      motif searched for in each reported sequence. One row
                     is written per occurrence, at start + (offset of the
                     occurrence - offset of the first occurrence).
    args.modname:    mod name stored for every row.

    Rows whose start / log_lik_ratio / sequence cannot be read are skipped
    with a warning. Exits if the motif is absent from a sequence.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])

    if basename.endswith('.tsv'):
        basename = '.'.join(basename.split('.')[:-1])

    db_fn = basename + '.nanopolish.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    thresh = float(args.thresh)
    mod_base = args.modname

    progress_interval = 500000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')

    else:
        logger.info('appending records to %s' % db_fn)

    for tsv in args.methdata.split(','):
        if tsv.endswith('.gz'):
            methdata = gzip.open(tsv, 'rt')
        else:
            methdata = open(tsv)

        logger.info('parsing ' + tsv)

        ins_data = []

        with methdata:
            for i, row in enumerate(csv.DictReader(methdata, delimiter='\t'), 1):
                try:
                    r_start = int(row['start'])
                    llr     = float(row['log_lik_ratio'])
                    seq     = row['sequence']
                except (KeyError, TypeError, ValueError):
                    # missing column, short row (DictReader fills None) or non-numeric field
                    logger.warning('bad line %d: %s' % (i, str(row)))
                    continue

                if args.scalegroup:
                    llr = llr/float(row['num_motifs'])

                methcall = 0

                if llr > thresh:
                    methcall = 1

                elif llr < thresh*-1:
                    methcall = -1

                # get per-CG position (nanopolish/calculate_methylation_frequency.py)
                if args.motif not in seq:
                    sys.exit('motif %s not found in kmer %s, please check --motif setting' % (args.motif, seq))

                first_hit = seq.find(args.motif)
                hit = first_hit

                while hit != -1:
                    ins_data.append((row['chromosome'], r_start + hit - first_hit, row['strand'], row['read_name'], llr, methcall, mod_base))
                    hit = seq.find(args.motif, hit + 1)

                if i % progress_interval == 0:
                    conn.executemany(insert_sql, ins_data)
                    ins_data = []
                    logger.info('processed %d records from %s' % (i, tsv))

        if ins_data:
            conn.executemany(insert_sql, ins_data)

        c.execute('INSERT INTO modnames VALUES (?)', (args.modname,))

        c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (round(thresh, 4), round(thresh*-1, 4), args.modname))

        logger.info('commiting records from %s to %s' % (tsv, db_fn))

        conn.commit()

    conn.close()


def segmeth(args):
    '''
    segment methylation stats over genomic intervals

    Each line of args.intervals gives chrom, start, end and, optionally,
    name and strand. Blank and '#' lines are ignored. Intervals are
    evaluated for every input sample (and, with --phased, for each of phases
    1 and 2) in parallel via get_segmeth_calls.

    Inputs are either -d/--data (one "bam db[,db...]" pair per line) or
    -b/--bams (comma-separated .bams, a file listing .bams, or bgzip/tabix
    bedMethyl with --bedmethyl).

    Output is one tab-delimited row per interval, with per-sample/per-mod
    columns named <sample>_<mod>_<stat>. Intervals missing any of those
    columns are skipped with a warning. The output file name is derived from
    the inputs unless --outfile is given.
    '''

    stats = [
    'meth_calls',
    'unmeth_calls',
    'no_calls',
    'methfrac',
    'readcount',
    'motifcount',
    'sigmeans',
    'highmeth',
    'lowmeth'
    ]

    mod_names = []

    data = dd(list)

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.motif is not None:
        assert iupac(args.motif)

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both (or use -b and specify --bedmethyl)')

        with open(args.data) as data_fh:
            for line in data_fh:
                bam, meth = line.strip().split()[:2]
                for m_db in meth.split(','):
                    data[bam].append(m_db)
                    mod_names.extend(get_modnames(m_db))

    ct_bams = None

    if args.bams is not None:
        if (not args.bedmethyl) and None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bams')
            sys.exit(1)

        if args.ref and args.bedmethyl:
            logger.warning('--ref not required with --bedmethyl and is ignored')

        if args.motif and args.bedmethyl:
            logger.warning('--motif not required with --bedmethyl and is ignored')

        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        elif is_bam(args.bams):
            bams = args.bams.split(',')

        elif args.bedmethyl:
            logger.info('input is expected to be bgzip/tabix bedMethyl')

            beds = args.bams.split(',')
            if beds[0].endswith('.gz'):
                bams = beds
            else:
                logger.info('-b input does not end in .gz, assuming this is a file with a list of bedMethyl .gz files...')
                with open(args.bams) as bed_list:
                    for line in bed_list:
                        if not line.strip().split()[0].endswith('.gz'):
                            sys.exit('cannot identify input type')
                        bams.append(line.strip().split()[0])

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    if not is_bam(line.strip().split()[0]):
                        sys.exit('cannot identify input type')
                    bams.append(line.strip().split()[0])

        ct_bams = []

        if args.ctbam is not None:
            if args.bedmethyl:
                sys.exit('--ctbam not compatible with --bedmethyl')

            for ctbam_fn in args.ctbam.split(','):
                if ctbam_fn not in bams:
                    sys.exit(f'{ctbam_fn} passed to --ctbam has no corresponding .bam passed to -b/--bam')

                ct_bams.append(ctbam_fn)
                logger.info(f'Noted C/T substitution .bam: {ctbam_fn}')

        ct_bams = set(ct_bams)

        for bam in bams:
            if ':' in bam:
                bam, _ = bam.split(':')

            if not os.path.exists(bam):
                sys.exit(f'.bam file not found: {bam}')

            if not args.bedmethyl:
                with pysam.AlignmentFile(bam) as fh:
                    if not fh.check_index():
                        sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        # NOTE(review): this runs once, after the loop above, so mod names come from the
        # last input only -- confirm this is intended for multi-sample runs
        if bam in ct_bams:
            logger.info(f'Assuming modtype m for C/T substitution bam {bam}')
            mod_names = ['m']
        else:
            if not args.bedmethyl:
                mod_names = mods_methbam(bam)
            else:
                mod_names = mods_bedmethyl(bam)

    mod_names = list(set(mod_names))

    if args.mods:
        user_mod_names = []
        for mod in args.mods.split(','):
            if mod not in mod_names:
                sys.exit(f'mod {mod} not found in data!')

            user_mod_names.append(mod)

        mod_names = list(set(user_mod_names))

    mod_names = sorted(mod_names) # ensure consistent order of appearance for mod names in table

    logger.info('using mods: %s' % ','.join(mod_names))

    base_names = ['.'.join(bam.split('.')[:-1]) for bam in data]

    if args.phased:
        phased_base_names = []
        for bn in base_names:
            phased_base_names.append(bn+'_ph1')
            phased_base_names.append(bn+'_ph2')
        base_names = phased_base_names

    meth_thresh = 0.8
    can_thresh = 0.8

    if args.meth_thresh:
        meth_thresh = float(args.meth_thresh)

    if args.can_thresh:
        can_thresh = float(args.can_thresh)

    if not args.bedmethyl:
        logger.info(f'methylated base threshold: {meth_thresh}')
        logger.info(f'canonical base threshold: {can_thresh}')

    if methbam:
        data_basename = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1])
        if len(args.bams.split(',')) > 1:
            data_basename += '.cohort'

    else:
        data_basename = '.'.join(os.path.basename(args.data).split('.')[:-1])

    ivl_basename = '.'.join(os.path.basename(args.intervals).split('.')[:-1])
    outfn = '.'.join((ivl_basename, data_basename))

    if args.excl_ambig:
        outfn += '.excl_ambig'

    if args.spanning_only:
        outfn += '.spanning_only'

    if args.phased:
        outfn += '.phased'

    if args.bedmethyl:
        outfn += '.bedMethyl'
    else:
        outfn += f'.mt{args.meth_thresh}.ct{args.can_thresh}'

    outfn += '.segmeth.tsv'

    if args.outfile is not None:
        outfn = args.outfile

    phases = [None]

    if args.phased:
        if args.bedmethyl:
            sys.exit('--bedmethyl not compatible with --phased (generate phased bedMethyl output before running segmeth)')

        phases = ['1', '2']

    with open(args.intervals) as seg_ivls:
        interval_count = sum(1 for _ in seg_ivls)

    logger.info(f'{args.intervals} has {interval_count} intervals')

    out = open(outfn, 'w')

    logger.info('segmeth output filename: %s' % outfn)

    lowmeth_thresh = float(args.lowmeth_thresh)
    highmeth_thresh = float(args.highmeth_thresh)

    pool = mp.Pool(processes=int(args.procs))

    results = []

    for bam_fn, meth_fn in data.items():
        for phase in phases:
            base_name = '.'.join(bam_fn.split('.')[:-1])

            if phase is not None:
                base_name += '_ph' + phase

            with open(args.intervals) as seg_ivls:
                logger.info(f'parse intervals from {bam_fn} phase {phase}')
                for line in tqdm(seg_ivls, total=interval_count):
                    c = line.strip().split()

                    # tolerate blank and '#' comment lines in the interval file
                    if not c or c[0].startswith('#'):
                        continue

                    chrom, seg_start, seg_end = c[:3]

                    seg_name = c[3] if len(c) > 3 else 'NA'
                    seg_strand = c[4] if len(c) > 4 else 'NA'

                    res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mod_names, meth_fn, chrom, int(seg_start), int(seg_end), seg_name, seg_strand, phase, methbam, args.bedmethyl, ct_bams, meth_thresh, can_thresh, lowmeth_thresh, highmeth_thresh])

                    results.append((res, base_name))

    logger.info('building results table')

    meth_segs = dd(dict)

    for res, base_name in tqdm(results):
        meth_result, seg = res.get()

        if meth_result is None:
            continue

        seg_id = '%s:%d-%d' % seg[:3]

        seg_chrom, seg_start, seg_end, seg_name, seg_strand, read_count, motif_count, gmm_means, dmr_metrics, lowmeth_count, highmeth_count = seg

        seg_vals = meth_segs[seg_id]

        seg_vals['seg_id']     = seg_id
        seg_vals['seg_chrom']  = seg_chrom
        seg_vals['seg_start']  = str(seg_start)
        seg_vals['seg_end']    = str(seg_end)
        seg_vals['seg_name']   = 'NoName' if seg_name == 'NA' else seg_name
        seg_vals['seg_strand'] = '.' if seg_strand == 'NA' else seg_strand

        sample_name = os.path.basename(base_name)

        for modname, meth_data in meth_result.items():
            # meth_data maps call state (-1 / 0 / 1) to a count
            unmeth_calls = meth_data[-1] if -1 in meth_data else 0
            no_calls = meth_data[0] if 0 in meth_data else 0
            meth_calls = meth_data[1] if 1 in meth_data else 0

            prefix = sample_name + '_' + modname + '_'

            seg_vals[prefix + 'meth_calls'] = meth_calls
            seg_vals[prefix + 'unmeth_calls'] = unmeth_calls
            seg_vals[prefix + 'no_calls'] = no_calls

            if meth_calls+unmeth_calls == 0:
                seg_vals[prefix + 'methfrac'] = 'NaN'
            else:
                seg_vals[prefix + 'methfrac'] = meth_calls/float(meth_calls+unmeth_calls)

            seg_vals[prefix + 'readcount'] = str(read_count)
            seg_vals[prefix + 'motifcount'] = str(len(motif_count[modname]))
            seg_vals[prefix + 'sigmeans'] = str(gmm_means[modname])
            seg_vals[prefix + 'lowmeth'] = str(lowmeth_count[modname])
            seg_vals[prefix + 'highmeth'] = str(highmeth_count[modname])

    pool.close()
    pool.join()

    header = ['seg_id', 'seg_chrom', 'seg_start', 'seg_end', 'seg_name', 'seg_strand']

    for bn in base_names:
        for m in mod_names:
            for s in stats:
                header.append('%s_%s_%s' % (os.path.basename(bn), m, s))

    out.write('\t'.join(header) + '\n')

    for seg_id, seg_vals in meth_segs.items():
        missing = [h for h in header if h not in seg_vals]

        # skip incomplete rows entirely so every written row lines up with the header
        if missing:
            logger.warning('no calls for sample %s in segment %s, skipped.' % (missing[0], seg_id))
            continue

        out.write('\t'.join(str(seg_vals[h]) for h in header) + '\n')

    out.close()


def _segplot_version_tuple(version):
    '''
    convert a version string such as "0.11.2" or "0.13.0rc1" into a comparable
    (major, minor, patch) tuple of ints; missing or non-numeric components count as 0
    '''
    parts = []
    for piece in version.split('.')[:3]:
        digits = ''.join(itertools.takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)

    parts += [0] * (3 - len(parts))
    return tuple(parts)


def _segplot_parse_columns(columns):
    '''
    derive sample and mod names from segmeth headers named <sample>_<mod>_meth_calls

    returns (samples, mods), each de-duplicated and in order of first appearance
    (a sample with several mods has one *_meth_calls column per mod, so it must
    only be listed once)
    '''
    suffix = '_meth_calls'
    samples = []
    mods = []

    for col in columns:
        if not col.endswith(suffix):
            continue

        # mod is the last '_'-delimited token, the sample name is everything before it
        sample, _, mod = col[:-len(suffix)].rpartition('_')

        if sample not in samples:
            samples.append(sample)

        if mod not in mods:
            mods.append(mod)

    return samples, mods


def _segplot_useable_mask(data, samples, mods, mincalls, minreads):
    '''
    boolean Series over data.index: True where every sample/mod pair has at least
    `mincalls` modification calls (meth + unmeth) and at least `minreads` reads

    comparisons are written as "not below threshold" so that missing (NaN) counts
    do not exclude a segment, matching the previous per-segment loop
    '''
    keep = pd.Series(True, index=data.index)

    for s in samples:
        for m in mods:
            sm = s + '_' + m
            calls = data[sm + '_meth_calls'] + data[sm + '_unmeth_calls']
            keep &= ~(calls < mincalls)
            keep &= ~(data[sm + '_readcount'] < minreads)

    return keep


def _segplot_ridge_labels(title):
    '''
    split a ridge-plot row title of the form "samplegroup = <sample> <group>" into
    (sample, group); the group (segment name) may itself contain spaces
    '''
    row_name = title.partition('=')[2].strip()
    samplename, _, groupname = row_name.partition(' ')
    return samplename, groupname


def segplot(args):
    '''
    strip plots / violin plots / ridge plots of segmented methylation data over categories by sample

    args is the parsed command line. Reads the segmeth table args.segmeth, writes the
    plotted values to <basename>.segplot_data.csv and saves the figure as .segplot.png
    (or .segplot.svg with args.svg) unless args.outfile is given. Exits via sys.exit on
    invalid options or when no segments pass the --mincalls / --minreads filters.
    '''

    if args.ridge and _segplot_version_tuple(sns.__version__) < (0, 11, 2):
        sys.exit('--ridge requires seaborn 0.11.2 or later')

    if args.usemeta and not args.group_by_annotation and not args.ridge:
        logger.warning('must use -a/--group_by_annotation with --usemeta and violin/strip plots')
        sys.exit(1)

    data = pd.read_csv(args.segmeth, sep='\t', header=0, index_col=0)

    samples, mods = _segplot_parse_columns(data.columns)

    logger.info('available samples: %s' % ','.join(samples))
    logger.info('available mods: %s' % ','.join(mods))

    if args.samples:
        user_samples = []
        for s in args.samples.split(','):
            if s not in samples:
                sys.exit('%s not in sample list!' % s)
            if s not in user_samples:
                user_samples.append(s)

        samples = user_samples

    metadata = None
    usemeta = []
    meta_colours = {}

    palette = args.palette

    if args.metadata:
        metadata = pd.read_csv(args.metadata, sep='\t', header=0, index_col=0)

        if args.usemeta:
            for md in args.usemeta.split(','):
                md = md.strip()
                if md not in metadata.columns:
                    logger.warning(f'column {md} not found in {args.metadata}')
                    sys.exit(1)
                usemeta.append(md)
        else:
            logger.warning('--metadata without --usemeta does not affect plot')

    elif args.usemeta:
        logger.warning('--usemeta has no effect without --metadata')

    # colouring / ordering by metadata only applies when --usemeta columns were given
    if usemeta:
        # one colour per combination of metadata values; values are compared as
        # strings and sorted so the colour assignment is reproducible between runs
        uniqmeta = [sorted(set(metadata[md].astype(str))) for md in usemeta]
        meta_categories = ['_'.join(mc) for mc in itertools.product(*uniqmeta)]

        mcp = sns.color_palette(args.palette, n_colors=len(meta_categories))
        meta_colours = dict(zip(meta_categories, mcp))
        palette = {}

        for s in samples:
            if s not in metadata.index:
                logger.warning(f'no metadata for sample: {s}')
                continue

            # .loc[row, col] keeps the column dtype (row-wise .loc[row] can upcast ints to floats)
            md = '_'.join(str(metadata.loc[s, col]) for col in usemeta)
            assert md in meta_colours, f'meta {md} not found'

            palette[s] = meta_colours[md]
            for m in mods:
                palette[s + '_' + m] = palette[s]

        # order samples by metadata, keeping only selected samples that have metadata
        selected = set(samples)
        metadata = metadata.sort_values(usemeta)
        samples = [s for s in metadata.index if s in selected]

        if not samples:
            sys.exit('no selected samples have metadata in %s' % args.metadata)

    if args.mods:
        user_mods = []
        for m in args.mods.split(','):
            if m not in mods:
                sys.exit('%s not in mod list!' % m)
            user_mods.append(m)

        mods = user_mods

    samples_mods = []
    for s in samples:
        for m in mods:
            sample_mod = s + '_' + m
            if sample_mod not in samples_mods:
                samples_mods.append(sample_mod)

    if not samples_mods:
        sys.exit('no sample/mod combinations to plot')

    logger.info('sample + mod permutations: %s' % ','.join(samples_mods))

    for sm in samples_mods:
        data[sm + '_methfrac'] = data[sm + '_meth_calls']/(data[sm + '_meth_calls']+data[sm + '_unmeth_calls'])

    keep = _segplot_useable_mask(data, samples, mods, int(args.mincalls), int(args.minreads))
    data = data.loc[keep]

    logger.info('useable sites: %d' % len(data))

    if len(data) < 1:
        sys.exit('exiting: no useable sites, possibly due to lack of coverage or lack of modification calls')

    # metadata columns copied onto every plotted point; warn once per sample, not per segment
    sample_meta = {}
    if metadata is not None:
        for s in samples:
            if s in metadata.index:
                sample_meta[s] = {col: metadata.loc[s, col] for col in metadata.columns}
            else:
                logger.warning(f'no metadata for sample: {s}')

    plot_data = dd(dict)

    order = []

    for seg in data.index:
        group = data['seg_name'].loc[seg]

        if group not in order:
            order.append(group)

        for s in samples:
            for m in mods:
                sm = s + '_' + m
                uid = f'{seg}:{sm}'  # f-string also tolerates non-string segment ids
                plot_data[uid]['sample']  = sm
                plot_data[uid]['modbase'] = data[sm + '_methfrac'].loc[seg]
                plot_data[uid]['group']   = group
                plot_data[uid].update(sample_meta.get(s, {}))

                if args.ridge:
                    plot_data[uid]['samplegroup'] = f'{sm} {group}'

    # transpose then rebuild so each column gets its own inferred dtype
    plot_data = pd.DataFrame.from_dict(plot_data).T
    plot_data = pd.DataFrame(plot_data.to_dict())

    basename = os.path.splitext(args.segmeth)[0]

    basename += '.mc%d' % int(args.mincalls)
    basename += '.mr%d' % int(args.minreads)

    plot_data.to_csv(basename+'.segplot_data.csv')
    logger.info('plot data written to %s.segplot_data.csv' % basename)

    pt_sz = int(args.pointsize)

    if args.categories is not None:
        order = args.categories.split(',')
        plot_data = plot_data[plot_data['group'].isin(order)]

    if args.violin:
        if args.group_by_annotation:
            sns_plot = sns.violinplot(x='group', y='modbase', data=plot_data, hue='sample', dodge=True, order=order, hue_order=samples_mods, palette=palette)
        else:
            sns_plot = sns.violinplot(x='sample', y='modbase', data=plot_data, hue='group', dodge=True, hue_order=order, palette=palette)

        basename += '.violin'

        if usemeta:
            legend_elements = []
            for metacat, colour in meta_colours.items():
                catlabel = '+'.join(metacat.split('_'))
                legend_elements.append(matplotlib.patches.Patch(facecolor=colour, label=catlabel))

            meta_title = '+'.join(usemeta)
            sns_plot.legend(handles=legend_elements, title=meta_title)

    elif args.ridge:
        sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

        sample_order = CategoricalDtype(samples_mods, ordered=True)
        plot_data['sample'] = plot_data['sample'].astype(sample_order)

        if args.categories:
            cat_order = CategoricalDtype(order, ordered=True)
            plot_data['group'] = plot_data['group'].astype(cat_order)

        # row order of the FacetGrid follows the order of appearance in plot_data
        if args.group_by_annotation:
            plot_data = plot_data.sort_values(['group','sample'])
        else:
            plot_data = plot_data.sort_values(['sample','group'])

        ridge_min = float(args.min)
        ridge_max = float(args.max)

        plot_data = plot_data.query('@ridge_min <= modbase <= @ridge_max')

        sns_plot = sns.FacetGrid(plot_data, row='samplegroup', hue='sample', aspect=15, height=.5, palette=palette)
        sns_plot.map(sns.kdeplot, 'modbase', bw_adjust=float(args.ridge_smoothing), clip_on=False, fill=True, alpha=float(args.ridge_alpha), linewidth=1.5)
        sns_plot.map(sns.kdeplot, 'modbase', clip_on=False, color='w', lw=2, bw_adjust=float(args.ridge_smoothing), alpha=float(args.ridge_alpha))
        sns_plot.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)
        sns_plot.figure.subplots_adjust(hspace=float(args.ridge_spacing))

        seen_samples = set()
        seen_groups = set()

        s_col = sns.color_palette(args.palette, n_colors=len(set(plot_data['sample'])))

        s = 0

        # label each ridge row once, then blank the default facet titles below
        for ax in sns_plot.axes[:, 0]:
            samplename, groupname = _segplot_ridge_labels(ax.title.get_text())

            if args.group_by_annotation and groupname not in seen_groups:
                ax.text(0, .15, groupname, color='k', ha='left', va='center', transform=ax.transAxes)

            if args.group_by_annotation and samplename not in seen_samples:
                ax.text(0, -.19, samplename, size=10, color='k', ha='left', va='center', transform=ax.transAxes)

            if not args.group_by_annotation:
                ax.text(0, .15, samplename, weight='bold', color=s_col[s], ha='left', va='center', transform=ax.transAxes)
                s = (s + 1) % len(s_col)

            if not args.group_by_annotation and groupname not in seen_groups:
                ax.text(0, -.19, groupname, size=10, color='k', ha='left', va='center', transform=ax.transAxes)

            seen_groups.add(groupname)
            seen_samples.add(samplename)

        sns_plot.set_titles("")
        sns_plot.set(yticks=[], ylabel="")
        sns_plot.despine(bottom=True, left=True)

        if usemeta:
            legend_elements = []
            for metacat, colour in meta_colours.items():
                catlabel = '+'.join(metacat.split('_'))
                legend_elements.append(matplotlib.patches.Patch(facecolor=colour, label=catlabel))

            meta_title = '+'.join(usemeta)
            ax = sns_plot.axes[0,0]
            ax.legend(handles=legend_elements, title=meta_title, facecolor='white', framealpha=0.5, loc='best')

        basename += '.ridge'

    else:
        if args.group_by_annotation:
            sns_plot = sns.stripplot(x='group', y='modbase', data=plot_data, hue='sample', dodge=True, jitter=True, size=pt_sz, order=order, hue_order=samples_mods, palette=palette)
        else:
            sns_plot = sns.stripplot(x='sample', y='modbase', data=plot_data, hue='group', dodge=True, jitter=True, size=pt_sz, hue_order=order, palette=palette)

    if args.group_by_annotation:
        basename += '.group_by_annotation'

    if not args.ridge:
        sns_plot.set_xlabel("")
        sns_plot.set_ylabel(args.ylabel)

        if args.tiltlabel:
            sns_plot.set_xticklabels(sns_plot.get_xticklabels(), rotation=45)

        if args.vertlabel:
            sns_plot.set_xticklabels(sns_plot.get_xticklabels(), rotation=90)

        sns_plot.set_ylim(float(args.min),float(args.max))

        if args.width is None:
            args.width = 1 + len(order) * len(samples_mods)
            logger.info('auto set --width: %d' % args.width)

    fig = sns_plot.figure

    # ridge plots are sized by the FacetGrid (height/aspect) instead
    if not args.ridge:
        fig.set_size_inches(float(args.width), float(args.height))

    outfn = basename

    if args.svg:
        outfn += '.segplot.svg'
    else:
        outfn += '.segplot.png'

    if args.outfile is not None:
        outfn = args.outfile
        if args.svg and not outfn.endswith('.svg'):
            logger.warning('warning: %s does not have extension .svg, appending' % outfn)
            outfn += '.svg'

    fig.savefig(outfn, bbox_inches='tight')
    logger.info('plot saved to %s' % outfn)


def locus(args):
    '''
    plot methylation profile of a region in CpG space
    '''
    # set up

    assert ':' in args.interval
    assert '-' in args.interval

    if args.panelratios:
        if args.plot_coverage or args.phasediff:
            if len(args.panelratios.split(',')) != 6:
                logger.warning('locus -p/--panelratios requires 6 terms e.g. --panelratios 1,5,1,3,3,3') 
                sys.exit(1)
        else:
            if len(args.panelratios.split(',')) != 5:
                logger.warning('locus -p/--panelratios requires 5 terms e.g. --panelratios 1,5,1,3,3') 
                sys.exit(1)

    if args.phasediff and not args.phased:
        logger.warning('--phasediff cannot be called without --phase')
        sys.exit(1)

    if args.motif is not None:
        assert iupac(args.motif)

    chrom, pos= args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))
    
    meth_thresh = 0.8
    can_thresh = 0.8

    if args.meth_thresh:
        meth_thresh = float(args.meth_thresh)

    if args.can_thresh:
        can_thresh = float(args.can_thresh)

    logger.info(f'methylated base threshold: {meth_thresh}')
    logger.info(f'canonical base threshold: {can_thresh}')

    data = dd(list)
    user_colours = {}

    markers = ['.',',','o','v','^','<','>','8','s','p','P','*','h','H','X','D','d']

    if args.readmarker not in markers:
        sys.exit('%s is not a valid marker type, valid types: %s' % (args.readmarker, ','.join(markers)))

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    variants = {} # pos --> [ref, alt]
    split_vid = None

    if args.splitvar:
        split_vid = args.splitvar
        if args.phased:
            logger.warning('calling both --phased and --splitvar is not currently supported')
            sys.exit(1)

    if args.variants is not None:
        var_tbx = pysam.Tabixfile(args.variants)

        with open(args.variants) as _:
            for rec in var_tbx.fetch(chrom, elt_start, elt_end):
                vchrom, vpos, vid, vref, valt, vqual, vfilt, vinfo = rec.strip().split()[:8]
                vpos = int(vpos)-1  # VCF is 1-based, BAM is 0-based

                if vfilt not in ('PASS', '.'):
                    continue

                if ',' in valt:
                    logger.warning(f'{args.variants}: multiple ALT alleles not supported')
                    continue

                vtype = None
                vlen = None

                if vref != valt and vref in ('A', 'T', 'C', 'G') and valt in ('A', 'T', 'G', 'C'):
                    vtype = 'SNV'
                    vlen = 1
                
                for info_field in vinfo.split(';'):
                    if '=' in info_field:
                        i = info_field.split('=')
                        if i[0] == 'SVTYPE':
                            vtype = i[1]
                        if i[0] == 'SVLEN':
                            vlen = int(i[1])

                if split_vid:
                    if split_vid == vid:
                        logger.info(f'found variant ID {split_vid} for sample split')
                    else:
                        continue
                
                if vtype in ('SNV', 'INS', 'DEL'):
                    variants[vpos] = [vref, valt, vtype, vlen]
        
        if split_vid and len(variants) == 0:
            logger.warning(f'ID {split_vid} not found in {args.variants}, ensure ID is present and variant is not filtered')
            sys.exit(1)

        logger.info(f'found {len(variants)} variants in {args.variants} for region {args.interval}')

        var_colours = {varpos:vcol for (varpos,vcol) in zip(variants, sns.color_palette(args.variantpalette, n_colors=len(variants)))}

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth_db = c[:2]
                for m_db in meth_db.split(','):
                    data[bam].append(m_db)

                if len(c) == 3:
                    user_colours[bam] = c[2]

    if args.motif is not None:
        if len(args.motif) != int(args.motifsize):
            logger.warning('motif size (set with --motifsize) %d does not match length of --motif (%s), changed --motifsize' % (int(args.motifsize), args.motif))
            args.motifsize = len(args.motif)

    if args.bams is not None:
        logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

        if None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bams')
            sys.exit(1)

        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        elif is_bam(args.bams):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    c = line.strip().split()
                    if not c[0].endswith('.bam'):
                        if os.path.exists(c[0]+'.bam'):
                            logger.warning(f'{c[0]} doesnt end with .bam, appending a .bam because {c[0]}.bam exists')
                            c[0] += '.bam'

                    if len(c) == 1:
                        bams.append(c[0])
                    elif len(c) == 2:
                        bams.append(c[0])
                        user_colours[bams[-1]] = c[1]
                    else:
                        sys.exit(f'unparsable line in {args.bams}: {line.strip()}')

        for bam in bams:
            if ':' in bam:
                bam, ucol = bam.split(':')
                user_colours[bam] = ucol

            if not os.path.exists(bam):
                sys.exit(f'.bam file not found: {bam}')

            with pysam.AlignmentFile(bam) as fh:
                if not fh.check_index():
                    sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

    colour_mapping = {}

    if args.colormap:
        mapped_user_colours = {}
        
        if args.colormap == 'auto':
            unique_annotations = list(set(user_colours.values()))
            map_colours = sns.color_palette(args.samplepalette, len(unique_annotations))
            colour_mapping = dict(zip(unique_annotations, map_colours))

            logger.info('automatic colour mapping:')
            for ann, col in colour_mapping.items():
                logger.info(f'\t{ann}: {mcolors.rgb2hex(col)}')
                colour_mapping[ann] = mcolors.rgb2hex(col)

        else:
            with open(args.colormap) as cm:
                for line in cm:
                    c = line.strip().split()
                    if len(c) != 2:
                        logger.error('malformed --colormap file (should be two columns)')
                        sys.exit(1)
                    colour_mapping[c[0]] = mcolors.rgb2hex(c[1])

            logger.info(f'loaded colour mapping from {args.colormap}:')
            for ann, col in colour_mapping.items():
                logger.info(f'\t{ann}: {col}')

        for sample, ann in user_colours.items():
            mapped_user_colours[sample] = colour_mapping[ann]
    
        user_colours = mapped_user_colours

    for c in user_colours.values():
        if not matplotlib.colors.is_color_like(c):
            logger.error(f'invalid colour {c} found in {args.bams}, use "--colormap auto" to map colours onto annotations if desired')
            sys.exit(1)

    ct_bams = []

    if args.ctbam is not None:
        for ctbam_fn in args.ctbam.split(','):
            if ctbam_fn not in data:
                sys.exit(f'{ctbam_fn} passed to --ctbam has no corresponding .bam passed to -b/--bam')
            
            ct_bams.append(ctbam_fn)
            logger.info(f'Noted C/T substitution .bam: {ctbam_fn}')

    ct_bams = set(ct_bams)

    # table for plotting

    meth_table = dd(dict)

    sample_order = []

    reads = {}
    orig_bam = {}

    use_mods = None
    if args.mods:
        use_mods = args.mods.split(',')

    phases = {}

    if args.phased:
        for bam in data:
            phased_reads = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=args.include_unphased, min_mapq=int(args.min_mapq), HP_only=args.ignore_ps, primary_only=args.primary_only)
            phases[bam] = list(set([p for p in phased_reads.values() if p]))
            logger.info('phases for bam %s: %s' % (bam, ','.join(phases[bam])))

    if args.splitvar:
        for bam in data:
            phased_reads = get_variant_reads(bam, chrom, elt_start, elt_end, variants, min_mapq=int(args.min_mapq), primary_only=args.primary_only)
            phases[bam] = list(set([p for p in phased_reads.values() if p]))
            logger.info('splitvar phases for bam %s: %s' % (bam, ','.join(phases[bam])))

    for bam, meth_dbs in data.items():
        if meth_dbs is None:
            if bam in ct_bams:
                logger.info(f'Assuming modtype m for C/T substitution bam {bam}')
                mods = ['m']
            else:
                mods = mods_methbam(bam)
                logger.info('found mods: %s in bam %s' % (','.join(mods), bam))

        else:
            for meth_db in meth_dbs:
                mods = sorted(get_modnames(meth_db))
                logger.info('found mods: %s in db %s' % (','.join(mods), meth_db))

        for mod in mods:
            if use_mods:
                if mod not in use_mods:
                    logger.info('skipping %s, not specified in -m/--mods %s' % (mod, args.mods))
                    continue
            else:
                use_mods = mods

            if args.phased:
                for phase in phases[bam]:
                    bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + phase + '.' + mod
                    orig_bam[bamname] = bam
                    reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, meth_thresh=meth_thresh, can_thresh=can_thresh, phase=phase, methbam=methbam, HP_only=args.ignore_ps, restrict_motif=args.motif, restrict_ref=args.ref, ct_bams=ct_bams)

                    for name, read in reads[bamname].items():
                        for loc in read.llrs.keys():
                            uuid = str(uuid4())
                            meth_table[uuid]['loc'] = loc
                            meth_table[uuid]['modstat'] = read.llrs[loc]
                            meth_table[uuid]['read'] = name
                            meth_table[uuid]['sample'] = bamname
                            meth_table[uuid]['call'] = read.meth_calls[loc]

                            if bamname not in sample_order:
                                sample_order.append(bamname)
            elif args.splitvar:
                for allele in ('ref', 'alt'):
                    bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + allele + '.' + mod
                    orig_bam[bamname] = bam
                    reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, meth_thresh=meth_thresh, can_thresh=can_thresh, allele=allele, variant=variants, methbam=methbam, HP_only=args.ignore_ps, restrict_motif=args.motif, restrict_ref=args.ref, ct_bams=ct_bams)

                    for name, read in reads[bamname].items():
                        for loc in read.llrs.keys():
                            uuid = str(uuid4())
                            meth_table[uuid]['loc'] = loc
                            meth_table[uuid]['modstat'] = read.llrs[loc]
                            meth_table[uuid]['read'] = name
                            meth_table[uuid]['sample'] = bamname
                            meth_table[uuid]['call'] = read.meth_calls[loc]

                            if bamname not in sample_order:
                                sample_order.append(bamname)

            else:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod
                orig_bam[bamname] = bam
                reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, meth_thresh=meth_thresh, can_thresh=can_thresh, methbam=methbam, restrict_motif=args.motif, restrict_ref=args.ref, ct_bams=ct_bams)

                for name, read in reads[bamname].items():
                    for loc in read.llrs.keys():
                        uuid = str(uuid4())
                        meth_table[uuid]['loc'] = loc
                        meth_table[uuid]['modstat'] = read.llrs[loc]
                        meth_table[uuid]['read'] = name
                        meth_table[uuid]['sample'] = bamname
                        meth_table[uuid]['call'] = read.meth_calls[loc]

                        if bamname not in sample_order:
                            sample_order.append(bamname)

    meth_table = pd.DataFrame.from_dict(meth_table).T

    if 'loc' not in meth_table:
        sys.exit('%s: insufficient coverage for plot' % args.interval)

    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['modstat'] = pd.to_numeric(meth_table['modstat'])

    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    if len(meth_table['orig_loc']) == 0:
        sys.exit('%s: insufficient coverage for plot' % args.interval)        

    # optional mincall filter

    if int(args.mincalls) > 0:
        drop_ix = []

        logger.info('assessing coverage per --mincalls, per-read sites to consider: %d' % len(meth_table['loc']))
        tick = 5000
        for i, new_loc in enumerate(meth_table['loc'], 1):
            for sample in sample_order:
                call_count = len(meth_table.loc[(meth_table['sample'] == sample) & (meth_table['loc'] == new_loc)])

                if call_count < int(args.mincalls):
                    drop_ix += list(meth_table.loc[(meth_table['loc'] == new_loc)].index)
            
            if i % tick == 0:
                logger.info('%s %s: processed %d sites' % (args.data, args.interval, i))
        
        drop_ix = list(set(drop_ix))
        logger.info('dropping %d positions' % len(drop_ix))
        meth_table.drop(drop_ix, inplace=True)

        meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

        if len(meth_table['orig_loc']) == 0:
            sys.exit('%s: insufficient coverage for plot' % args.interval)

    # create mod space

    coord_to_cpg = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc

    # calibrate plotting parameters
    logger.info('region size: %d' % (elt_end-elt_start))

    modspace_n = len(list(set(coord_to_cpg.values())))
    logger.info('mod space positions: %d' % modspace_n)

    if args.modspace is not None:
        logger.info('user set --modspace %d, mod_n:ticks: %.3f' % (int(args.modspace), modspace_n/float(args.modspace)))
        args.modspace = int(args.modspace)
    else:
        if modspace_n <= 1200:
            args.modspace = 1
        else:
            args.modspace = round(modspace_n/1200)

        logger.info('auto set --modspace %d' % args.modspace)

    if args.smoothwindowsize is not None:
        logger.info('user set --smoothwindowsize %d, mod_n:smoothwindowsize: %.3f' % (int(args.smoothwindowsize), modspace_n/float(args.smoothwindowsize)))
        args.smoothwindowsize = int(args.smoothwindowsize)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            logger.info('-s/--smoothwindowsize must be an even integer, adjusted value: %d' % args.smoothwindowsize)

    else:
        args.smoothwindowsize = round(0.0167*modspace_n + 18)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1

        logger.info('auto set --smoothwindowsize %d' % args.smoothwindowsize)

    # data for coverage plot

    cover_bams = {}

    if args.plot_coverage:
        cov_bams = []

        if args.plot_coverage.endswith('.bam'):
            cov_bams = args.plot_coverage.split(',')

        if args.plot_coverage.endswith('.txt'):
            with open(args.plot_coverage) as _:
                for fn in _:
                    cov_bams.append(fn.strip())

        for bam_fn in cov_bams:
            logger.info('gather coverage from %s...' % bam_fn)

            meth_seg_starts = deepcopy(np.asarray(meth_table['orig_loc']))
            meth_seg_starts += elt_start

            cover = bam_pileupcover(bam_fn, chrom, meth_seg_starts, meth_seg_starts, procs=int(args.coverprocs), log=args.logcover)

            c_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])

            cover_table = dd(dict)
            sorted_locs = sorted(list(set(meth_table['loc'])))

            for c, loc in zip(cover, sorted_locs):
                uuid = str(uuid4())
                cover_table[uuid]['loc'] = loc
                cover_table[uuid]['sample'] = c_name
                cover_table[uuid]['cover'] = c
            
            cover_table = pd.DataFrame.from_dict(cover_table).T

            windowed_cover = slide_window_cover(cover_table, c_name, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))
            cover_bams[c_name] = smooth(np.asarray(list(windowed_cover.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)


    # highlights

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            
            if h_s > h_e:
                logger.warning(f'highlight {h} end < start, skipped')
                continue

            if h_e < elt_start or h_s > elt_end:
                logger.warning(f'highlight {h} out of plotting region, skipped')
                continue

            if h_s < elt_start:
                logger.warning(f'highlight position {h_s} outside region, adjusted to {elt_start}')
                h_s = elt_start

            if h_e > elt_end:
                logger.warning(f'highlight position {h_e} outside region, adjusted to {elt_end}')
                h_e = elt_end

            h_start.append(h_s - elt_start)
            h_end.append(h_e - elt_start)

            h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()

                if c[0] != chrom:
                    continue

                h_s, h_e = map(int, c[1:3])

                assert h_s < h_e

                if h_s < elt_start:
                    continue

                if h_e > elt_end:
                    continue

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1
                
                h_start.append(h_s - elt_start)
                h_end.append(h_e - elt_start)

                h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)
            
            logger.info('found %d within-interval highlight regions in %s' % (len(h_start), args.highlight_bed))

        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot
    p_ratios=[1,5,1,3,3]
    if args.plot_coverage:
        p_ratios.append(3)

    if args.phased and args.phasediff:
        p_ratios.append(3)

    if args.skip_align_plot:
        p_ratios[1] = 0
    
    if not args.include_raw_plot:
        p_ratios[3] = 0

    fig = plt.figure()
    gs = gridspec.GridSpec(len(p_ratios),1,height_ratios=p_ratios, hspace=0)

    img_w = 16
    img_h = 8

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        user_p_ratios = list(map(int, args.panelratios.split(',')))
        if len(user_p_ratios) != len(p_ratios):
            logger.warning(f'--panelratios must have {len(p_ratios)} values')
            sys.exit(1)

        p_ratios = user_p_ratios

        if not args.include_raw_plot:
            p_ratios[3] = 0

        gs = gridspec.GridSpec(len(p_ratios),1,height_ratios=p_ratios, hspace=0)

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sns.color_palette(args.samplepalette, n_colors=len(sample_order))[i]

    if args.color_by_phase:
        if args.phased:
            basenames = ['.'.join(sample.split('.')[:-2]) for sample in sample_color]

            base_colors = {}

            for i, basename in enumerate(basenames):
                base_colors[basename] = sns.color_palette(args.samplepalette, n_colors=len(basenames))[i]

            for sample in sample_color:
                basename = '.'.join(sample.split('.')[:-2])
                sample_color[sample] = base_colors[basename]
        
        else:
            logger.warning('--color_by_phase has no effect without --phased')


    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    if args.color_by_hp:
        if args.phased:
            hp_color = {}
            for hp in ('1','2'):
                hp_color[hp] = sns.color_palette(args.samplepalette, n_colors=2)[int(hp)-1]
            
            for sample in sample_color:
                hp = sample.split('.')[-2]
                assert hp in ('1','2')
                sample_color[sample] = hp_color[hp]

        else:
            logger.warning('--color_by_hp has no effect without --phased')

    cover_color = {}

    if args.plot_coverage:
        for i, c_name in enumerate(cover_bams.keys()):
            cover_color[c_name] = sns.color_palette(args.coverpalette, n_colors=len(cover_bams))[i]

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf, index=args.csi)
        genes = build_genes(gtf, chrom, elt_start, elt_end, tx=args.show_transcripts)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]
        
        genes = new_genes

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    y = 1

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            logger.warning('no transcript for gene: %s' % genes[ensg].name)
            continue

        logger.info('gene in region: %s' % genes[ensg].name)

        y = 1

        while genemap[y].find(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start):
            y += (float(args.exonheight)+0.7)

        tx_offset = float(args.exonheight)/2.0

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start], [tx_offset+y, tx_offset+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start-elt_start, y], exon_len, float(args.exonheight), edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        if not args.hidegenelabels:
            label = genes[ensg].name

            if genes[ensg].strand in ('+', '-'):
                if genes[ensg].strand == '+':
                    label += '>>'
                
                else:
                    label = '<<' + label

            lg_x = max(genes[ensg].tx_start-elt_start, 0)

            nudge_up = 0.0
            if genes[ensg].tx_start-elt_start < 0:
                nudge_up = tx_offset+0.1

            gtxt = ax0.text(lg_x, y+nudge_up, label, zorder=4, fontsize='small')
            bb_w = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(elt_end-elt_start)
            gtxt.set_x(lg_x-txt_w*1.42)

            genemap[y].add_interval(Interval(genes[ensg].tx_start-elt_start-(txt_w*1.5), genes[ensg].tx_end-elt_start+(txt_w*1.5)))

        else:
            genemap[y].add_interval(Interval(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start))

    gene_height = max(4, len(genemap))

    if not args.hidegenelabels:
        gene_height = max(8, len(genemap))

    if args.bed:
        logger.info('loading annotations from %s' % args.bed)
        gene_height += 3

        y += 1

        bed_annotations = get_bed_annotations(args.bed, chrom, elt_start, elt_end)
        bed_colours = sns.color_palette(args.genepalette, n_colors=len(bed_annotations))

        for i, ann in enumerate(bed_annotations):
            bed_colour = bed_colours[i]

            if ann.colour is not None:
                bed_colour = ann.colour

            exon_patches.append(matplotlib.patches.Rectangle([ann.start-elt_start, y], (ann.end-ann.start), float(args.exonheight), edgecolor=bed_colour, facecolor=bed_colour, zorder=3))

            if ann.label is not None and not args.hidebedlabel:
                lg_x  = max(ann.start-elt_start, 0)
                label = ann.label

                if ann.strand is not None:
                    if ann.strand == '+':
                        label += '>>'
                    
                    else:
                        label = '<<' + label

                gtxt  = ax0.text(lg_x, y, label, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_frame_on(False)

    if not args.skip_align_plot:
        logger.info('building read alignment plot...')

        readstack = dd(list)
        varcache = {}
        pos_cache = {}
        masked_count = 0

        # Process all BAM files first to collect position and variant data
        for bamname in reads:
            fetch_reads_bam = None

            # in case the user has a dodgy nfs
            retries = 0
            bam_open = False

            while not bam_open:
                try:
                    fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])
                    bam_open = True

                except OSError as e:
                    if retries < 5:
                        logger.warning(f'retry open on {orig_bam[bamname]}, possibly due to network filesystem')
                        sleep(1)
                        retries += 1
                        
                    else:
                        sys.exit(e)
            
            # Pre-filter reads that match our criteria
            filtered_reads = []
            for read in fetch_reads_bam.fetch(chrom, elt_start, elt_end):
                if read.mapq < int(args.min_mapq):
                    continue

                if not args.allreads and (read.is_supplementary or read.is_secondary or read.is_duplicate):
                    continue

                # Check if read is masked
                masked = False
                if len(readmask) > 0:
                    for mask_start, mask_end in readmask:
                        if read.reference_start >= mask_start and read.reference_end <= mask_end:
                            logger.debug('masked read: %s' % read.query_name)
                            masked_count += 1
                            masked = True
                            break
                
                if masked:
                    continue
                
                filtered_reads.append(read)
            
            # Process filtered reads in batch
            for read in filtered_reads:
                if read.query_name not in pos_cache:
                    pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
                else:
                    pos_cache[read.query_name][0].append(read.reference_start)
                    pos_cache[read.query_name][1].append(read.reference_end)

                # Process variants if needed
                if args.variants and read.query_name not in varcache:
                    varcache[read.query_name] = check_variants(variants, read, allele='alt')
            
            fetch_reads_bam.close()

        # Build readstack with position data
        max_y = 1
        for bamname in reads:
            for readname, read in reads[bamname].items():
                if readname not in pos_cache or read.call_count == 0:
                    continue

                read.ypos = max_y
                max_y += 1
                read.starts, read.ends = pos_cache[readname]
                readstack[bamname].append(read)

        if args.readmask:
            logger.info('masked %d reads due to --readmask' % masked_count)

        # Process each sample separately to maintain sample grouping
        pack_y = 1
        for bamname in readstack:
            # Group reads by y-position
            y_groups = dd(list)
            for read in readstack[bamname]:
                y_groups[read.ypos].append(read)
            
            # Sort reads by start position for better packing
            sorted_reads = []
            for y_pos in y_groups:
                sorted_reads.extend(sorted(y_groups[y_pos], key=lambda r: min(r.starts)))
            
            # Pack reads for this sample only
            packed_rows = []
            for read in sorted_reads:
                # Try to place read in an existing row
                placed = False
                for row_idx, row in enumerate(packed_rows):
                    can_place = True
                    for existing_read in row:
                        if read.overlap(existing_read):
                            can_place = False
                            break
                    
                    if can_place:
                        row.append(read)
                        read.ypos = pack_y + row_idx
                        placed = True
                        break
                
                # If can't place in existing rows, create a new row
                if not placed:
                    packed_rows.append([read])
                    read.ypos = pack_y + len(packed_rows) - 1
            
            # Update pack_y for the next sample
            if packed_rows:
                pack_y += len(packed_rows)
            else:
                pack_y += 1

        ax1.set_ylim(0, pack_y+1)

        bam_ylims = {}
        bam_xlims = {}

        # Batch plotting for better performance - process each sample separately
        for bamname in readstack:
            rm = args.readmarker
            ms = float(args.readmarkersize)
            lw = float(args.readlinewidth)
            la = float(args.readlinealpha)
            ma = float(args.markeralpha)
            uec = sample_color[bamname]
            
            if args.readopenmarkeredgecolor is not None:
                uec = args.readopenmarkeredgecolor
            
            # Collect all plotting data first for this sample
            unmeth_markers = []
            meth_markers = []
            read_lines = []
            variant_markers = []
            
            for read in readstack[bamname]:
                # Collect methylation markers
                for call_pos, call in read.meth_calls.items():
                    if call == -1:
                        unmeth_markers.append((call_pos, read.ypos))
                    elif call == 1:
                        meth_markers.append((call_pos, read.ypos))
                
                # Collect read lines
                for i in range(len(read.starts)):
                    readline_start = max(read.starts[i], elt_start) - elt_start
                    readline_end = min(read.ends[i], elt_end) - elt_start
                    read_lines.append(((readline_start, read.ypos), (readline_end, read.ypos)))
                
                # Collect variant markers
                if args.variants and read.read_name in varcache:
                    for alt_site in varcache[read.read_name]:
                        variant_markers.append((alt_site-elt_start, read.ypos, var_colours[alt_site]))
            
            # Plot in batches for this sample
            if unmeth_markers:
                x, y = zip(*unmeth_markers)
                ax1.scatter(x, y, marker=rm, edgecolor=uec, facecolor='white', s=ms**2, zorder=3, alpha=ma)
            
            if meth_markers:
                x, y = zip(*meth_markers)
                ax1.scatter(x, y, marker=rm, edgecolor='black', facecolor='black', s=ms**2, zorder=3, alpha=ma)
            
            # Plot read lines for this sample
            for (x1, y1), (x2, y2) in read_lines:
                ax1.add_line(matplotlib.lines.Line2D([x1, x2], [y1, y2], lw=lw, zorder=2, color=sample_color[bamname], alpha=la))
            
            # Plot variant markers for this sample
            for x, y, color in variant_markers:
                ax1.plot(x, y, marker='v', fillstyle='full', mec='black', mfc=color, markersize=int(args.variantsize), zorder=4, alpha=1)

            bam_ylims[bamname] = [min([r[0][1] for r in read_lines]), max([r[0][1] for r in read_lines])]
            bam_xlims[bamname] = [min([r[1][0] for r in read_lines]), max([r[1][0] for r in read_lines])]
        
        if args.samplebox:
            xlim_max = max([v[1] for v in bam_xlims.values()])

            sample_margin = abs(ax1.get_xbound()[1])-xlim_max
            sample_box_width = sample_margin*.2
            sample_box_pad = sample_margin*.1

            sample_box_x1 = xlim_max + sample_box_pad

            sample_box_patches = []

            for bamname, sample_box_y in bam_ylims.items():
                sample_box_patches.append(matplotlib.patches.Rectangle([sample_box_x1, sample_box_y[0]-0.5], sample_box_width, sample_box_y[1]-sample_box_y[0]+1, edgecolor='black', facecolor=sample_color[bamname], zorder=1))
                if not args.notext:
                    ax1.text(sample_box_x1+sample_box_width+sample_box_pad, sum(sample_box_y)/2, bamname, ha='left', va='center', fontsize='x-small')

            for sb in sample_box_patches:
                ax1.add_patch(sb)

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_start):
                    h_e = h_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax1.get_ybound())
                    yheight = max(ax1.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])
    ax2.set_yticks([])

    x1 = []
    x2 = []

    step = int(args.modspace)

    orig_positions = list(set(meth_table['orig_loc']))

    for i, x in enumerate(orig_positions):
        if i in (0, len(orig_positions)-1):
            x2.append(x)
            x1.append(coord_to_cpg[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_cpg[x])

    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = int(args.nticks) + 1
    tick_interval = (elt_end-elt_start)/n_ticks
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), int(tick_interval)))

    revised_tick_list = []
    for t in tick_list:
        if t < 0:
            revised_tick_list.append(0)
        else:
            revised_tick_list.append(t)

    tick_list = sorted(list(set(revised_tick_list)))

    xt_labels = [str(int(t+elt_start)) for t in tick_list]
    xt_labels[0] = chrom

    ax0.set_xticks(tick_list)

    if n_ticks > 11:
        ax0.set_xticklabels(xt_labels, rotation=45)
    else:
        ax0.set_xticklabels(xt_labels)
    
    # llr plot

    ax4 = plt.subplot(gs[3])
    ax4.set_xticks([])
    ax4.yaxis.tick_right()

    if not args.include_raw_plot:
        ax4.set_yticks([])

    else:
        logger.info('building llr plot... (if this is slow and/or raw signal is not helpful, try --skip_raw_plot)')

        ax4.axhline(y=0, c='#bbbbbb', linestyle='--',lw=1)

        for mod in use_mods:
            upper = 1.0
            lower = 0.0

            if not methbam:
                upper, lower = get_cutoffs(list(data.values())[0][0], mod)

            ax4.axhline(y=upper, c='k', linestyle='--',lw=1)
            ax4.axhline(y=lower, c='k', linestyle='--',lw=1)

        legend='auto'
        if args.hidelegend:
            legend=False

        ax4 = sns.lineplot(x='loc', y='modstat', hue='sample', data=meth_table, palette=sample_color, zorder=2, legend=legend)
        
        if args.statname:
            ax4.set_ylabel(args.statname)

        ax4.set_xlim(ax2.get_xlim())

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_cpg_start):
                    h_e = h_cpg_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax4.get_ybound())
                    yheight = max(ax4.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax4.add_patch(h)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax5 = plt.subplot(gs[4])

    order_stack = 2

    smoothalpha = float(args.smoothalpha)
    slw = float(args.smoothlinewidth)

    if smoothalpha > 1.0 or smoothalpha < 0.0:
        logger.warning('--smoothalpha must be between 0 and 1, set to 1.0')
        smoothalpha = 1.0

    smoothed_values = dd(dict)

    sample_filtered_data = {}
    for sample in sample_order:
        sample_filtered_data[sample] = meth_table[meth_table['sample'] == sample]

    skipped_samples = []

    if args.shuffle:
        random.shuffle(sample_order)

    for sample in sample_order:
        sample_data = sample_filtered_data[sample]

        windowed_methfrac, meth_n = slide_window(sample_data, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))
        methfrac_values = np.asarray(list(windowed_methfrac.values()))

        smoothed_methfrac = smooth(methfrac_values, window_len=int(args.smoothwindowsize), window=args.smoothfunc)

        ix = np.asarray(list(windowed_methfrac.keys()))

        for i, smooth_val in enumerate(smoothed_methfrac):
            smoothed_values[sample][ix[i]] = smooth_val

        masked_segs = mask_methfrac(list(meth_n.values()), cutoff=args.maskcutoff)

        frac_masked = 0.0
        if len(meth_n) > 0:
            frac_masked = len(list(itertools.chain(*masked_segs))) / len(meth_n.values())

        logger.info('%s:%d-%d (%s), sample %s fraction masked: %.3f' % (chrom, elt_start, elt_end, ''.join(use_mods), sample, frac_masked))

        if frac_masked > float(args.maxmaskedfrac):
            logger.warning('%s:%d-%d (%s), skip sample %s due to --maxmaskedfrac %.3f' % (chrom, elt_start, elt_end, ''.join(use_mods), sample, float(args.maxmaskedfrac)))
            skipped_samples.append(sample)
            continue

        if frac_masked > 0.1 and args.bams is not None and args.ref is None:
            logger.warning('*** WARNING: specifying a reference genome (indexed via samtools faidx) with --restrict_ref is strongly recommended when using mod .bams ***')

        ax5.plot(list(windowed_methfrac.keys()), smoothed_methfrac, marker='', lw=slw, color=sample_color[sample], alpha=smoothalpha, label=sample)

        order_stack += 1

        if not args.nomask:
            for seg in masked_segs:
                if len(seg) > 2:
                    mf_seg = np.asarray(smoothed_methfrac)[seg]
                    pos_seg = np.asarray(list(windowed_methfrac.keys()))[seg]
                
                    ax5.plot(pos_seg, mf_seg, marker='', lw=slw, color='#ffffff', alpha=0.8, zorder=order_stack)

                    order_stack += 1

    smoothed_df = pd.DataFrame(smoothed_values)

    if not args.include_raw_plot:
        if not args.hidelegend:
            ax5.legend()
    else:
        ax4.set_zorder(order_stack)  # adjust z-level of legend after smoothed plot is finished

    if args.color_by_hp:
        handles, labels = ax4.get_legend_handles_labels()
        phases = list(set([label.split('.')[-2] for label in labels]))

        found_phases = {}

        new_labels = []
        new_handles = []

        if args.phase_labels:
            if ':' not in args.phase_labels:
                sys.exit('incorrect syntax: %s' % args.phase_labels)

            phase_labels = dict([pl.split(':') for pl in args.phase_labels.split(',')])

        for i, label in enumerate(labels):
            hp = label.split('.')[-2]
            if hp not in found_phases:
                phase_name = hp

                if args.phase_labels:
                    if hp in phase_labels:
                        phase_name = phase_labels[hp]

                new_labels.append(phase_name)
                new_handles.append(handles[i])
                
            found_phases[hp] = True

        if not args.hidelegend:
            if not args.include_raw_plot:
                ax5.legend(new_handles, new_labels)
            else:
                ax4.legend(new_handles, new_labels)

    if args.colormap:
        if not args.include_raw_plot:
            ax5.legend([mpatches.Patch(color=c) for c in colour_mapping.values()], colour_mapping.keys())
        else:
            ax4.legend([mpatches.Patch(color=c) for c in colour_mapping.values()], colour_mapping.keys())

    ax5.set_xlim(ax2.get_xlim())
    ax5.set_ylim((float(args.ymin),float(args.ymax)))

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax5.get_ybound())
                yheight = max(ax5.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax5.add_patch(h)

    if not args.include_raw_plot:
        ax5.set_zorder(1)

    # coverage plot (optional)

    if args.plot_coverage:
        logger.info('building coverage plot...')
        ax6 = plt.subplot(gs[5])

        for sample in cover_bams:
            offset = list(windowed_methfrac.keys())[0]
            matched_cover = [cover_bams[sample][i-offset] for i in windowed_methfrac.keys()]
            
            ax6.plot(list(windowed_methfrac.keys()), matched_cover, marker='', lw=slw, color=cover_color[sample], zorder=3, label=sample, alpha=smoothalpha)

        ax6.legend()
        ax6.set_xlim(ax2.get_xlim())

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_cpg_start):
                    h_e = h_cpg_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax6.get_ybound())
                    yheight = max(ax6.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax6.add_patch(h)

    # phase difference plot (optional)
    if args.phased and args.phasediff:
        logger.info('phase difference plot...')
        ax7 = plt.subplot(gs[-1])

        smoothed_values = pd.DataFrame(smoothed_values)

        phase_lookup = dd(list)

        for sn in smoothed_values.columns:
            if sn in skipped_samples:
                continue

            sn_parts = sn.split('.')
            sn_ph = '.'.join(sn_parts[:-2])
            if ':' in sn_parts[-2]:
                sn_ph = sn_ph + '.' + sn_parts[-2].split(':')[1] + '.' + sn_parts[-1]
            phase_lookup[sn_ph].append(sn)

        samples = []

        for sample in phase_lookup:
            if len(phase_lookup[sample]) == 2:
                samples.append(sample)
                ph1 = smoothed_values[phase_lookup[sample][0]]
                ph2 = smoothed_values[phase_lookup[sample][1]]
                smoothed_values[sample] = (ph1-ph2).abs()

        pd_colors = {}

        for pd_sample, phased_samples in phase_lookup.items():
            pd_colors[pd_sample] = sample_color[phased_samples[0]]

        if args.shuffle:
            random.shuffle(samples)

        phase_diffs = smoothed_values[samples]

        legend='auto'
        if args.hidelegend:
            legend=False

        ax7 = sns.lineplot(data=phase_diffs, lw=slw, palette=pd_colors, dashes=False, zorder=2, legend=legend, alpha=smoothalpha)
        phasediff_df = pd.DataFrame(phase_diffs)

        ax7.set_xlim(ax2.get_xlim())
        ax7.set_ylim((float(args.ymin),float(args.ymax)))
        ax7.set_ylim()

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_cpg_start):
                    h_e = h_cpg_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax7.get_ybound())
                    yheight = max(ax7.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax7.add_patch(h)


    fn_prefix = '.%s_%d_%d.%s' % (chrom, elt_start, elt_end, ''.join(use_mods))

    if methbam:
        fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        if len(args.bams.split(',')) > 1:
            fn_prefix += '.cohort'
    else:
        fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    if args.phased:
        if args.phasediff:
            fn_prefix += '.phasediff'
        else:
            fn_prefix += '.phased'

    if args.splitvar:
        fn_prefix += f'.split_{args.splitvar}'

    param_str = '.ms%d.smw%d' % (args.modspace, int(args.smoothwindowsize))

    fn_prefix += param_str

    if args.max_read_density:
        fn_prefix += '.mrd%.2f' % float(args.max_read_density)

    if int(args.mincalls) > 0:
        fn_prefix += '.mc%d' % int(args.mincalls)

    fn_prefix += f'.mt{args.meth_thresh}.ct{args.can_thresh}'

    if args.notext:
        for i, g in enumerate(gs):
            if i != 2:
                remove_all_text(plt.subplot(g))

        fn_prefix += '.notext'

    outfn = fn_prefix

    if args.svg:
        outfn += '.locus.meth.svg'
    else:
        outfn += '.locus.meth.png'

    if args.outfile is not None:
        outfn = args.outfile
        if args.svg:
            if outfn.split('.')[-1] != 'svg':
                logger.warning('warning: %s does not have extension .svg, appending')
                outfn += '.svg'

    if args.smoothed_csv:
        if args.smoothed_csv == 'auto':
            args.smoothed_csv = fn_prefix + '.locus.meth.csv'

        smoothed_df.to_csv(args.smoothed_csv)
        logger.info(f'smoothed values written to {args.smoothed_csv}')

        if args.phasediff:
            phasediff_csv = '.'.join(args.smoothed_csv.split('.')[:-1]) + '.phasediff.csv'
            phasediff_df.to_csv(phasediff_csv)
            logger.info(f'smoothed phasediff values written to {phasediff_csv}')

    fig.savefig(outfn, bbox_inches='tight')
    logger.info('plot saved to %s' % outfn)

    if args.highlight_subplot and args.highlight and colour_mapping:
        reverse_colour_mapping = {v:k for k,v in colour_mapping.items()}

        fn_prefix = '.%s_%d_%d.%s' % (chrom, elt_start, elt_end, ''.join(use_mods))

        if methbam:
            fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        else:
            fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix


        for h, h_s in enumerate(h_cpg_start):
            h_e = h_cpg_end[h]
            
            if not args.phasediff:
                plt.clf()
                plt.rcdefaults()
                sns.reset_defaults()
                plt.close('all')

                h_subset = smoothed_df.query(f'index >= {h_s} and index <= {h_e}').dropna(axis=1, how='all')

                h_plotdata = dd(dict)

                for i, sample in enumerate(h_subset.columns):
                    col = mcolors.rgb2hex(sample_color[sample])

                    h_plotdata[i]['sample'] = sample
                    h_plotdata[i]['median'] = h_subset[sample].median()
                    h_plotdata[i]['group']  = reverse_colour_mapping[col]
                    h_plotdata[i]['colour'] = col

                h_plotdata = pd.DataFrame.from_dict(h_plotdata).T

            
                sns.boxplot(data=h_plotdata, x='group', hue='group', y='median',palette=colour_mapping, boxprops=dict(alpha=0.4), whiskerprops=dict(alpha=0.4), capprops=dict(alpha=0.4), fill=False, width=0.3)
                sns.swarmplot(data=h_plotdata, x='group', hue='group', y='median', palette=colour_mapping, legend=False)
                out_prefix = f'{fn_prefix}.highlight_subplot_{h}'

                if args.svg:
                    out_fn = out_prefix + '.svg'
                else:
                    out_fn = out_prefix + '.png'

                plt.title(f'{fn_prefix} highlight {h}')
                plt.xlabel('median methylation in highlight')
                plt.savefig(out_fn)
                h_plotdata.to_csv(out_prefix+'.csv')
                logger.info(f'saved highlight plot to {out_fn}')
            

            if args.phasediff:
                plt.clf()
                plt.rcdefaults()
                sns.reset_defaults()
                plt.close('all')

                h_subset_pd = phasediff_df.query(f'index >= {h_s} and index <= {h_e}').dropna(axis=1, how='all')
                

                h_plotdata = dd(dict)

                for i, sample in enumerate(h_subset_pd.columns):
                    col = mcolors.rgb2hex(pd_colors[sample])

                    h_plotdata[i]['sample'] = sample
                    h_plotdata[i]['median'] = h_subset_pd[sample].median()
                    h_plotdata[i]['group']  = reverse_colour_mapping[col]
                    h_plotdata[i]['colour'] = col

                h_plotdata = pd.DataFrame.from_dict(h_plotdata).T
                
                sns.boxplot(data=h_plotdata, x='group', hue='group', y='median', palette=colour_mapping, boxprops=dict(alpha=0.4), whiskerprops=dict(alpha=0.4), capprops=dict(alpha=0.4), fill=False, width=0.3)
                sns.swarmplot(data=h_plotdata, x='group', hue='group', y='median', palette=colour_mapping, legend=False)
                out_prefix = f'{fn_prefix}.highlight_subplot_{h}.phasediff'

                if args.svg:
                    out_fn = out_prefix+'.svg'
                else:
                    out_fn = out_prefix+'.png'

                plt.title(f'{fn_prefix} phasediff hl {h}')
                plt.xlabel('median phasediff in highlight')
                plt.savefig(out_fn)
                h_plotdata.to_csv(out_prefix+'.csv')
                logger.info(f'saved phasediff highlight plot to {out_fn}')


def region(args):
    '''
    plotting function for windowed view of larger regions
    '''

    assert ':' in args.interval
    assert '-' in args.interval

    args.predict_dmr = False

    chrom, pos = args.interval.split(':')
    start, end = pos.split('-')

    start = int(start.replace(',',''))
    end = int(end.replace(',',''))

    assert start < end

    if args.motif is not None:
        assert iupac(args.motif)

    if end-start < 500000:
        logger.warning('locus smaller than 0.5 Mbp, "methylartist locus" may yield better results')

    scale_width = 1.0

    if args.scale_fullwidth:
        scale_width = (end-start)/float(args.scale_fullwidth)
        logger.info('scale width to %.3f' % scale_width)

    ref = pysam.FastaFile(args.ref)
    motifs = iupac(args.motif)
    region_seq = ref.fetch(chrom, start, end).upper()
    
    motif_count = 0
    for motif in motifs:
        motif_count += region_seq.count(motif) + region_seq.count(rc(motif))

    if args.windows is None:
        args.windows = round(motif_count / 30)
        logger.info('set window count to %d' % args.windows)

    args.windows = int(args.windows)

    if args.windows < 500:
        args.windows = 500
        logger.info('resetting windows to a minimum of 500')

    motifs_per_window = int(motif_count / int(args.windows))

    if motifs_per_window == 0:
        motifs_per_window = 1

    logger.info('motif count: %d, per window: %d' % (motif_count, motifs_per_window))

    w_starts, w_ends = find_motif_windows(region_seq, motifs, start, end, motifs_per_window)
    assert len(w_starts) == len(w_ends)
    logger.info('using %d windows normalised for %s content' % (len(w_starts), args.motif))

    if args.smoothwindowsize is None:
        w = len(w_starts)
        args.smoothwindowsize = round(0.02*w + 4)
        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            
        logger.info('set --smoothwindowsize to %d' % args.smoothwindowsize)

    args.smoothwindowsize = int(args.smoothwindowsize)

    if args.modspace is None:
        args.modspace = round(len(w_starts)/300)
        if args.modspace == 0:
            args.modspace = 1

        logger.info('set --modspace to %d' % args.modspace)

    args.modspace = int(args.modspace)

    data = dd(list)
    mods = []
    user_colours = {}

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams (or use -b and specify --bedmethyl)')

    meth_thresh = 0.8
    can_thresh = 0.8

    if args.meth_thresh:
        meth_thresh = float(args.meth_thresh)

    if args.can_thresh:
        can_thresh = float(args.can_thresh)

    if not args.bedmethyl:
        logger.info(f'methylated base threshold: {meth_thresh}')
        logger.info(f'canonical base threshold: {can_thresh}')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth_db = c[:2]
                for m_db in meth_db.split(','):
                    data[bam].append(m_db)
                    mods += sorted(get_modnames(m_db))

                if len(c) == 3:
                    user_colours[bam] = c[2]

    ct_bams = None
    bams = None

    if args.bams is not None:
        if (not args.bedmethyl) and None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bams')
            sys.exit(1)

        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')
        
        elif is_bam(args.bams):
            bams = args.bams.split(',')

        elif args.bedmethyl:
            logger.info('input is expected to be bgzip/tabix bedMethyl')

            beds = args.bams.split(',')
            if beds[0].endswith('.gz'):
                bams = beds
            else:
                logger.info('-b input does not end in .gz, assuming this is a file with a list of bedMethyl .gz files...')
                with open(args.bams) as bed_list: 
                    for line in bed_list:
                        c = line.strip().split()

                        if not c[0].endswith('.gz'):
                            sys.exit('cannot identify input type')

                        if len(c) == 1:
                            bams.append(c[0])
                        elif len(c) == 2:
                            bams.append(c[0])
                            user_colours[bams[-1]] = c[1]
                        else:
                            sys.exit(f'unparsable line in {args.bams}: {line.strip()}')


        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    c = line.strip().split()
                    if not c[0].endswith('.bam'):
                        if os.path.exists(c[0]+'.bam'):
                            logger.warning(f'{c[0]} doesnt end with .bam, appending a .bam because {c[0]}.bam exists')
                            c[0] += '.bam'

                    if len(c) == 1:
                        bams.append(c[0])
                    elif len(c) == 2:
                        bams.append(c[0])
                        user_colours[bams[-1]] = c[1]
                    else:
                        sys.exit(f'unparsable line in {args.bams}: {line.strip()}')

    colour_mapping = {}

    if args.colormap:
        mapped_user_colours = {}
        
        if args.colormap == 'auto':
            unique_annotations = list(set(user_colours.values()))
            map_colours = sns.color_palette(args.samplepalette, len(unique_annotations))
            colour_mapping = dict(zip(unique_annotations, map_colours))

            logger.info('automatic colour mapping:')
            for ann, col in colour_mapping.items():
                logger.info(f'\t{ann}: {mcolors.rgb2hex(col)}')

        else:
            with open(args.colormap) as cm:
                for line in cm:
                    c = line.strip().split()
                    if len(c) != 2:
                        logger.error('malformed --colormap file (should be two columns)')
                        sys.exit(1)
                    colour_mapping[c[0]] = c[1]

            logger.info(f'loaded colour mapping from {args.colormap}:')
            for ann, col in colour_mapping.items():
                logger.info(f'\t{ann}: {col}')

        for sample, ann in user_colours.items():
            mapped_user_colours[sample] = colour_mapping[ann]
    
        user_colours = mapped_user_colours

    for c in user_colours.values():
        if not matplotlib.colors.is_color_like(c):
            logger.error(f'invalid colour {c} found in {args.bams}, use --colormap to map colours onto annotations if desired')
            sys.exit(1)

    ct_bams = []

    if args.ctbam is not None:
        if args.bedmethyl:
            sys.exit('--ctbam not compatible with --bed')

        for ctbam_fn in args.ctbam.split(','):
            if ctbam_fn not in bams:
                sys.exit(f'{ctbam_fn} passed to --ctbam has no corresponding .bam passed to -b/--bam')
            
            ct_bams.append(ctbam_fn)
            logger.info(f'Noted C/T substitution .bam: {ctbam_fn}')

    ct_bams = set(ct_bams)

    if bams:
        for bam in bams:
            if ':' in bam:
                bam, ucol = bam.split(':')
                user_colours[bam] = ucol

            if not os.path.exists(bam):
                sys.exit(f'.bam file not found: {bam}')

            if not args.bedmethyl:
                with pysam.AlignmentFile(bam) as fh:
                    if not fh.check_index():
                        sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

    if not mods:
        if bam in ct_bams:
            logger.info(f'Assuming modtype m for C/T substitution bam {bam}')
            mods = ['m']
        else:
            if not args.bedmethyl:
                mods = mods_methbam(bam)
            else:
                mods = mods_bedmethyl(bam)

    mods = list(set(mods))

    logger.info('found mods: %s' % ','.join(mods))

    if args.mods:
        for mod in args.mods.split(','):
            assert mod in mods, 'mod %s not found' % mod
        mods = args.mods.split(',')

    logger.info('using mods %s' % ','.join(mods))

    if end-start > 500000:
        logger.info('region size %s greater than 0.5 Mbp, setting --skip_align_plot True' % (end-start))
        args.skip_align_plot = True
    
    if args.force_align_plot == args.skip_align_plot == True:
        logger.info('--skip_align_plot overridden via --force_align_plot')
        args.skip_align_plot = False

    if args.bedmethyl:
        logger.info('setting --skip_align_plot due to --bedmethyl')
        args.skip_align_plot = True
        if args.force_align_plot:
            logger.warning('--force_align_plot overridden due to --bedmethyl')

    reads = {}
    orig_bam = {}

    phases = [None]
    if args.phased:
        if args.bedmethyl:
            sys.exit('--bedmethyl not compatible with --phased (generate phased bedMethyl output before running segmeth)')

        phases = ['1','2']

    for phase in phases:
        for bam, meth_dbs in data.items():
            for mod in mods:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod

                if args.phased:
                    bamname += '.' + phase

                orig_bam[bamname] = bam

                if not args.skip_align_plot:
                    reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, meth_thresh=meth_thresh, can_thresh=can_thresh, phase=phase, methbam=methbam, restrict_motif=args.motif, restrict_ref=args.ref, ct_bams=ct_bams)

    pool = mp.Pool(processes=int(args.procs))

    meth_segs = dd(dict)
    sample_names = {}
    shallow_windows = dd(list)
    min_window_calls = int(args.min_window_calls)

    for phase in phases:
        results = []

        for seg_start, seg_end in zip(w_starts, w_ends):
            for bam_fn, meth_dbs in data.items():
                seg_strand = '.'
                seg_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])
                res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mods, meth_dbs, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam, args.bedmethyl, ct_bams, meth_thresh, can_thresh, None, None])
                results.append(res)

        logger.info('parsing segments (phase %s)...' % str(phase))

        for res in tqdm(results):
            meth_result, seg = res.get()
            if meth_result is None:
                continue

            seg_id = '%s:%d-%d' % seg[:3]

            seg_chrom, seg_start, seg_end, seg_name, seg_strand, motif_count, gmm_means, dmr_metrics, lowmeth_count, highmeth_count = map(str, seg[:-1])

            meth_segs[seg_id]['seg_id']     = seg_id
            meth_segs[seg_id]['seg_chrom']  = seg_chrom
            meth_segs[seg_id]['seg_start']  = seg_start
            meth_segs[seg_id]['seg_end']    = seg_end
            meth_segs[seg_id]['seg_name']   = seg_name
            meth_segs[seg_id]['seg_strand'] = seg_strand

            for modname, meth_data in meth_result.items():
                no_calls = 0
                meth_calls = 0
                unmeth_calls = 0

                if -1 in meth_data:
                    unmeth_calls = meth_data[-1]

                if 0 in meth_data:
                    no_calls = meth_data[0]

                if 1 in meth_data:
                    meth_calls = meth_data[1]

                if meth_calls + unmeth_calls < min_window_calls:
                    if seg_id not in shallow_windows[modname]:
                        shallow_windows[modname].append(seg_id) # TODO phase aware

                sample_name = seg_name + '.' + modname
                
                if args.phased:
                    sample_name += '.' + phase

                sample_names[sample_name] = True

                meth_segs[seg_id][sample_name + '_meth_calls'] = meth_calls
                meth_segs[seg_id][sample_name + '_unmeth_calls'] = unmeth_calls

                if meth_calls + unmeth_calls == 0:
                    meth_segs[seg_id][sample_name + '_frac'] = 0 # changed from NaN
                else:
                    meth_segs[seg_id][sample_name + '_frac'] = meth_calls/float(meth_calls+unmeth_calls)

    if args.eff:
        for sample_eff in args.eff.split(','):
            if ':' not in sample_eff:
                logger.warning(f'malformed --eff argument: {args.eff}')
                continue
            
            eff_samplename, eff_value = sample_eff.split(':')

            if eff_samplename in sample_names:
                for seg_id in meth_segs:
                    meth_segs[seg_id][eff_samplename + '_frac'] = meth_segs[seg_id][eff_samplename + '_frac']*float(eff_value)

                logger.info(f'adjusted sample {eff_samplename}, efficiency: {eff_value}')

            else:
                logger.warning(f'samplename from --eff not found: {eff_samplename}')

    for mod in mods: # TODO phase aware
        shallow_frac = len(shallow_windows[mod])/len(w_starts)*100.0
        if shallow_frac > 0.0:
            logger.info('%.2f percent of windows for mod %s had less than %d calls' % (shallow_frac, mod, min_window_calls))
            if shallow_frac > float(args.maxuncovered):
                sys.exit('greater than %.2f windows are uncovered, aborting.' % float(args.maxuncovered))

    deleted_segs = 0
    for mod in mods: # TODO phase aware
        for seg_id in shallow_windows[mod]:
            if seg_id in meth_segs:
                del meth_segs[seg_id]
                deleted_segs += 1

    if deleted_segs > 0:
        logger.info('removed %d segs with less than %d calls in at least one mod' % (deleted_segs, min_window_calls))

    meth_segs = pd.DataFrame.from_dict(meth_segs).T

    if 'seg_start' not in meth_segs:
        logger.warning('no methylation calls.')
        sys.exit()

    meth_segs['seg_start'] = pd.to_numeric(meth_segs['seg_start'])

    meth_segs.sort_values('seg_start', inplace=True)
    meth_segs['pos'] = np.arange(len(meth_segs.index))

    for sample in sample_names:
        meth_segs[sample] = smooth(np.asarray(meth_segs[sample + '_frac']), window_len=int(args.smoothwindowsize), window=args.smoothfunc)
        meth_segs[sample] = meth_segs[sample].rolling(window=10, min_periods=1, center=True).mean()

    coord_to_pos = {}
    for orig_loc, new_loc in zip(meth_segs['seg_start'], meth_segs['pos']):
        coord_to_pos[orig_loc] = new_loc

    # coverage bams (if present)

    cover_bams = {}
    if args.plot_coverage:
        cov_bams = []

        if args.plot_coverage.endswith('.bam'):
            cov_bams = args.plot_coverage.split(',')

        if args.plot_coverage.endswith('.txt'):
            with open(args.plot_coverage) as _:
                for fn in _:
                    cov_bams.append(fn.strip())

        for bam_fn in cov_bams:
            logger.info('gather coverage from %s...' % bam_fn)
            meth_seg_starts = [int(s.split(':')[1].split('-')[0]) for s in meth_segs.index]
            meth_seg_ends = [int(s.split(':')[1].split('-')[1]) for s in meth_segs.index]
            cover = bam_bincover(bam_fn, chrom, meth_seg_starts, meth_seg_ends, procs=int(args.procs), log=args.logcover)
            c_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])
            cover_bams[c_name] = smooth(np.asarray(cover), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

    # highlights

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s)
            h_end.append(h_e)

            h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()

                if c[0] != chrom:
                    continue

                h_s, h_e = map(int, c[1:3])

                if h_e < start or h_s > end:
                    continue

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1
                
                h_start.append(h_s)
                h_end.append(h_e)

                h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)

        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()
        
        logger.info('found %d highlights for %s' % (len(h_start), chrom))

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot

    fig = plt.figure()
    height_ratios = [1,5,1,3]

    if args.skip_align_plot:
        height_ratios = [1,0,1,4]

    if args.plot_coverage:
        height_ratios.append(3)

    gs = gridspec.GridSpec(len(height_ratios),1,height_ratios=height_ratios, hspace=0)

    img_w = 16
    img_h = 8

    if args.skip_align_plot:
        if '--height' not in sys.argv[0]:
            args.height = 4.5

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    if args.scale_fullwidth:
        img_w = img_w*scale_width

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        if args.plot_coverage and len(p_ratios) < 5:
            logger.warning('--panelratios must have 5 values if used with --plot_coverage')
            sys.exit(1)

        gs = gridspec.GridSpec(len(height_ratios),1,height_ratios=p_ratios, hspace=0)

    sample_order = list(sample_names.keys())

    logger.info('sample order: %s' % ','.join(sample_order))

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sns.color_palette(args.samplepalette, n_colors=len(sample_order))[i]

    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    if args.color_by_hp:
        if args.phased:
            hp_color = {}
            for hp in ('1','2'):
                hp_color[hp] = sns.color_palette(args.samplepalette, n_colors=2)[int(hp)-1]
            
            for sample in sample_color:
                hp = sample.split('.')[-1]
                assert hp in ('1','2')
                sample_color[sample] = hp_color[hp]

        else:
            logger.warning('--color_by_hp has no effect without --phase')

    cover_color = {}

    if args.plot_coverage:
        for i, c_name in enumerate(cover_bams.keys()):
            cover_color[c_name] = sns.color_palette(args.coverpalette, n_colors=len(cover_bams))[i]

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf, index=args.csi)
        genes = build_genes(gtf, chrom, start, end, tx=args.show_transcripts)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []
    gene_specific_colours = {}

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

        if '.' in args.genes and os.path.exists(args.genes):
            logger.info('assuming --genes is a file, reading gene list from %s' % args.genes)

            genes_of_interest = []
            with open(args.genes) as genes_file:
                for line in genes_file:
                    c = line.strip().split()
                    genes_of_interest.append(c[0])
                    if len(c) > 1:
                        gene_specific_colours[c[0]] = c[1]

    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]
        
        genes = new_genes

    if args.gtf:
        logger.info('%d genes in region' % len(genes))

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            continue

        y = 1

        while genemap[y].find(genes[ensg].tx_start, genes[ensg].tx_end):
            y += 1
            if args.gene_track_height:
                if y > int(args.gene_track_height):
                    y = int(args.gene_track_height)
                    break
        
        if genes[ensg].name in gene_specific_colours:
            gene_colours[i] = gene_specific_colours[genes[ensg].name]

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start, genes[ensg].tx_end], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start, y], exon_len, float(args.exonheight), edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        genemap[y].add_interval(Interval(genes[ensg].tx_start, genes[ensg].tx_end))

        if args.labelgenes:
            lg_x  = max(genes[ensg].tx_start, start)
            gtxt  = ax0.text(lg_x, y+0.8, genes[ensg].name, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)
            bb_w  = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(end-start)
            gtxt.set_x(lg_x-txt_w/2)

    gene_height = max(3, len(genemap))

    if args.labelgenes:
        gene_height = max(6, len(genemap))

    if args.bed:
        logger.info('loading annotations from %s' % args.bed)
        gene_height += 3

        y += 1

        bed_annotations = get_bed_annotations(args.bed, chrom, start, end)
        bed_colours = sns.color_palette(args.genepalette, n_colors=len(bed_annotations))

        for i, ann in enumerate(bed_annotations):
            if ann.end < start or ann.start > end:
                continue

            if ann.start < start:
                ann.start = start
            
            if ann.end > end:
                end = ann.end

            bed_colour = bed_colours[i]

            if ann.colour is not None:
                bed_colour = ann.colour

            exon_patches.append(matplotlib.patches.Rectangle([ann.start, y], (ann.end-ann.start), float(args.exonheight), edgecolor=bed_colour, facecolor=bed_colour, zorder=3))

            if ann.label is not None and not args.hidebedlabel:
                lg_x  = max(ann.start, 0)
                label = ann.label

                if ann.strand is not None:
                    if ann.strand == '+':
                        label += '>>'
                    
                    else:
                        label = '<<' + label

                gtxt  = ax0.text(lg_x, y, label, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                if args.highlight_centerline:
                    clw = float(args.highlight_centerline)
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                else:
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)
        
    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    if args.skip_align_plot:
        logger.info('skipped alignment plot due to --skip_align_plot')

    else:
        logger.info('building read alignment plot...')

        readstack = dd(list)

        max_y  = 1
        pack_y = 1

        masked_count = 0

        for bamname in reads:
            fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])
            pos_cache = {}

            for read in fetch_reads_bam.fetch(chrom, start, end):
                if read.mapq < int(args.min_mapq):
                    continue

                if read.is_supplementary or read.is_secondary or read.is_duplicate:
                    if not args.allreads:
                        continue

                masked = False
                if len(readmask) > 0:
                    for mask_start, mask_end in readmask:
                        if read.reference_start >= mask_start and read.reference_end <= mask_end:
                            logger.debug('masked read: %s' % read.query_name)
                            masked_count += 1
                            masked = True
                
                if masked:
                    continue

                if read.query_name not in pos_cache:
                    pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
                else:
                    pos_cache[read.query_name][0].append(read.reference_start)
                    pos_cache[read.query_name][1].append(read.reference_end)

            for readname, read in reads[bamname].items():
                if readname not in pos_cache:
                    logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                    continue

                if read.call_count == 0:
                    continue

                read.ypos = max_y
                max_y += 1

                read.starts, read.ends = pos_cache[readname]
                readstack[bamname].append(read)

            fetch_reads_bam.close()

        if args.readmask:
            logger.info('masked %d reads due to --readmask' % masked_count)

        # read packing

        for bamname in readstack:
            reads = readstack[bamname]

            y = dd(list)

            for read in readstack[bamname]:
                y[read.ypos].append(read)

            for p in y:
                for q in y:
                    if p == q:
                        continue

                    for read_q in y[q]:
                            move = True

                            for read_p in y[p]:
                                if read_q.overlap(read_p):
                                    move = False

                            if move:
                                y[p].append(read_q)
                                y[q].remove(read_q)

            for p in y:
                if len(y[p]) > 0:
                    for read in sorted(y[p], key=lambda r: min(r.starts)):
                        read.ypos = pack_y
                    pack_y += 1

        ax1.set_ylim(0,pack_y+1)

        for bamname in readstack:
            for read in readstack[bamname]:
                for call_pos, call in read.meth_calls.items():
                    call_pos += start

                    if call == -1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=sample_color[bamname], mfc='white', markersize=2, zorder=3)

                    if call == 1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=3)

                for i in range(len(read.starts)):
                    readline_start = max(read.starts[i], start)
                    readline_end   = min(read.ends[i], end)

                    ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos, read.ypos], zorder=2, color=sample_color[bamname], alpha=0.4))

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_start):
                    h_e = h_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax1.get_ybound())
                    yheight = max(ax1.get_ybound())-ymin
                    if args.highlight_centerline:
                        clw = float(args.highlight_centerline)
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                        pass
                    else:
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    x1 = []
    x2 = []

    step = int(args.modspace)

    for i, x in enumerate(meth_segs['seg_start']):
        if i in (0, len(meth_segs['seg_start'])-1):
            x2.append(x)
            x1.append(coord_to_pos[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_pos[x])

    
    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = None
            cpg_highlight_box = None

            if args.highlight_centerline:
                clw = float(args.highlight_centerline)
                orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), 1.0, 1.0, lw=clw, edgecolor=h_colors[i], facecolor=h_colors[i], zorder=2)
                cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), 1.0, 1.0, lw=clw, edgecolor=h_colors[i], facecolor=h_colors[i], zorder=3)
            else:
                orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
                cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)
    
    if args.highlight_centerline:
        clw = float(args.highlight_centerline)

        i = 0
        for x1_x, x2_x in zip(h_cpg_start, h_start):
            link_end1 = (x1_x, 1)
            link_end2 = (x2_x, 9)

            con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=h_colors[i], lw=clw)
            ax3.add_artist(con)
            i += 1

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = int(args.nticks) + 1
    tick_interval = (end-start)/n_ticks
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), int(tick_interval)))

    revised_tick_list = []
    for t in tick_list:
        if t < 0:
            revised_tick_list.append(0)
        else:
            revised_tick_list.append(t)

    tick_list = sorted(list(set(revised_tick_list)))

    xt_labels = [str(int(t)) for t in tick_list]
    xt_labels[0] = chrom

    ax0.set_xticks(tick_list)
    if n_ticks > 11:
        ax0.set_xticklabels(xt_labels, rotation=45)
    else:
        ax0.set_xticklabels(xt_labels)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax4 = plt.subplot(gs[3])

    smoothalpha = float(args.smoothalpha)
    slw = float(args.smoothlinewidth)

    if smoothalpha > 1.0 or smoothalpha < 0.0:
        logger.warning('--smoothalpha must be between 0 and 1, set to 1.0')
        smoothalpha = 1.0

    if args.shuffle:
        random.shuffle(sample_order)

    for sample in sample_order:
        ax4.plot(meth_segs['pos'], meth_segs[sample], marker='', lw=slw, color=sample_color[sample], zorder=2, label=sample, alpha=smoothalpha)

    ax4.legend()
    ax4.set_xlim(ax2.get_xlim())
    ax4.set_ylim((float(args.ymin),float(args.ymax)))

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax4.get_ybound())
                yheight = max(ax4.get_ybound())-ymin
                if args.highlight_centerline:
                    clw = float(args.highlight_centerline)
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                else:
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax4.add_patch(h)

    if args.colormap:
        ax4.legend([mpatches.Patch(color=c) for c in colour_mapping.values()], colour_mapping.keys())

    # coverage plot (optional)

    if args.plot_coverage:
        logger.info('building coverage plot...')
        ax5 = plt.subplot(gs[4])

        for sample in cover_bams:
            ax5.plot(meth_segs['pos'], cover_bams[sample], marker='', lw=slw, color=cover_color[sample], zorder=3, label=sample, alpha=smoothalpha)

        ax5.legend()
        ax5.set_xlim(ax2.get_xlim())
        ax5.set_ylim((float(args.cover_ymin), max(ax5.get_ybound())))

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_cpg_start):
                    h_e = h_cpg_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax5.get_ybound())
                    yheight = max(ax5.get_ybound())-ymin
                    if args.highlight_centerline:
                        clw = float(args.highlight_centerline)
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                    else:
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax5.add_patch(h)
            
    # output

    fn_prefix = '.%s_%d_%d.%s' % (chrom, start, end, ''.join(mods))

    if methbam:
        fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        if len(args.bams.split(',')) > 1:
            fn_prefix += '.cohort'
    else:
        fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    fn_prefix += '.s%d.w%d.m%d' % (args.smoothwindowsize, args.windows, args.modspace)

    if args.phased:
        fn_prefix += '.phased'

    if args.eff:
        fn_prefix += '.eff'

    if args.scale_fullwidth:
        fn_prefix += '.scale_width.%.3f' % scale_width

    if args.bedmethyl:
        fn_prefix += '.bedMethyl'
    else:
        fn_prefix += f'.mt{args.meth_thresh}.ct{args.can_thresh}'

    if args.segment_csv:
        if args.segment_csv == 'auto':
            args.segment_csv = fn_prefix + '.segments.csv'

        meth_segs[sample_order].to_csv(args.segment_csv, index_label='seg')
        logger.info(f'wrote segment data to {args.segment_csv}')

    outfn = fn_prefix

    if args.svg:
        outfn += '.region.meth.svg'
    else:
        outfn += '.region.meth.png'

    if args.outfile is not None:
        outfn = args.outfile
        if args.svg:
            if outfn.split('.')[-1] != 'svg':
                logger.warning('warning: %s does not have extension .svg, appending')
                outfn += '.svg'

    fig.savefig(outfn, bbox_inches='tight')
    logger.info('plot saved to %s' % outfn)


def composite(args):
    '''
    plot composite methylation profiles relative to a consensus element

    Each region listed in args.segdata (whitespace-delimited: chrom, start, end,
    with the strand '+'/'-' in column 4 or 5) is profiled in a worker process by
    get_meth_profile_composite. Results are drawn as three stacked panels:
      1. mean 'meth' value per consensus coordinate (seaborn lineplot, +/- sd),
         keeping only coordinates whose call count reaches a per-sample cutoff
         (median calls per element, or --meanplot_cutoff)
      2. the consensus element with --motif occurrences and optional --blocks
      3. up to --maxelts individual element profiles per sample
    The figure is saved as .png (.svg with --svg). With --output_table, the
    per-site values behind panel 1 are also written as a .tsv.
    Invalid input terminates the program via sys.exit.
    '''

    if not skbio_installed:
        sys.exit('scikit-bio is not installed but is required for this function. Please install e.g. via "pip install scikit-bio" or "conda install -c https://conda.anaconda.org/biocore scikit-bio"')

    te_ref_seq = single_seq_fa(args.teref).upper()

    assert os.path.exists(args.ref + '.fai'), 'ref fasta must be indexed'

    mod_names = []

    data = dd(list)

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    if args.color_by_phase and not args.phased:
        sys.exit('must specify --phased to use --color_by_phase')

    if args.motif is not None:
        assert iupac(args.motif)

    methbam = False
    user_colours = {}

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as data_fh:
            for line in data_fh:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth = c[:2]
                if ',' in meth:
                    sys.exit('multiple .db files per bam not supported for composite')

                data[bam] = meth

                mod_names.extend(get_modnames(meth))

                # optional third column: a user-chosen colour for this sample
                if len(c) == 3:
                    bam_noext = '.'.join(os.path.basename(bam).split('.')[:-1])
                    user_colours[bam_noext] = c[2]

    if args.bams is not None:
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        elif is_bam(args.bams):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    bams.append(line.strip().split()[0])

        for bam in bams:
            if ':' in bam:
                bam, _ = bam.split(':')

            if not os.path.exists(bam):
                sys.exit(f'.bam file not found: {bam}')

            with pysam.AlignmentFile(bam) as fh:
                if not fh.check_index():
                    sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        mod_names = mods_methbam(bam)

    mod_names = list(set(mod_names))

    logger.info('found mod names: %s' % ','.join(mod_names))

    use_mod = None

    if len(mod_names) == 1:
        use_mod = mod_names[0]
        logger.info('using the one available mod: %s' % use_mod)

    if use_mod is None:
        if args.mod is None:
            sys.exit('please specify a modification via --mod\navailable mods are: %s' % ','.join(mod_names))

        use_mod = args.mod

        if use_mod not in mod_names:
            sys.exit('please specify a modification via --mod\navailable mods are: %s' % ','.join(mod_names))

    meth_thresh = 0.8
    can_thresh = 0.8

    if args.meth_thresh:
        meth_thresh = float(args.meth_thresh)

    if args.can_thresh:
        can_thresh = float(args.can_thresh)

    logger.info(f'methylated base threshold: {meth_thresh}')
    logger.info(f'canonical base threshold: {can_thresh}')

    # output filename prefix, derived from the inputs
    if methbam:
        data_basename = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1])
        if len(args.bams.split(',')) > 1:
            data_basename += '.cohort'
    else:
        data_basename = '.'.join(os.path.basename(args.data).split('.')[:-1])

    seg_basename = '.'.join(os.path.basename(args.segdata).split('.')[:-1])

    outfn = data_basename + '.' + seg_basename

    if args.meanplot_cutoff:
        outfn += '.meanplot_cutoff_%d' % int(args.meanplot_cutoff)

    if args.phased:
        outfn += '.phased'

    outfn += f'.mt{args.meth_thresh}.ct{args.can_thresh}'

    outfn += '.composite'

    table_fn = outfn + '.table.tsv'

    if args.svg:
        outfn += '.svg'
    else:
        outfn += '.png'

    results = []
    out_res = dd(list) # per-sample list of (coords, profile, elt_info)

    # the with-block shuts down the worker pool once every result has been
    # collected, including when sys.exit fires part-way through
    with mp.Pool(processes=int(args.procs)) as pool:
        with open(args.segdata) as bed:
            for line in bed:
                c = line.strip().split()

                # tolerate blank lines and '#' comment/header lines
                if not c or c[0].startswith('#'):
                    continue

                # strand is taken from column 4 or, failing that, column 5;
                # the length checks avoid an IndexError on short lines
                if len(c) > 3 and c[3] in ('-', '+'):
                    str_col = 3
                elif len(c) > 4 and c[4] in ('-', '+'):
                    str_col = 4
                else:
                    sys.exit('strand (+/-) not found in cols 4 or 5 of %s' % args.segdata)

                seg_chrom  = c[0]
                seg_start  = int(c[1])
                seg_end    = int(c[2])
                seg_strand = c[str_col]

                # phased runs submit one job per haplotype ('1' and '2')
                phases = ('1', '2') if args.phased else (None,)

                for phase in phases:
                    res = pool.apply_async(get_meth_profile_composite, [args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, phase, meth_thresh, can_thresh])
                    results.append(res)

        # collect mod data, dropping samples with no result or no calls
        for res in results:
            per_bam_res = res.get()

            for bam, bam_res in per_bam_res.items():
                if bam_res is None:
                    continue

                coord_meth_pos, meth_profile, elt_info = bam_res

                if len(coord_meth_pos) == 0:
                    continue

                out_res[bam].append((coord_meth_pos, meth_profile, elt_info))

    if len(out_res) == 0:
        sys.exit('no successful alignments!')

    # set bounds
    mod_start = 0
    mod_end = len(te_ref_seq)

    if args.start:
        mod_start = int(args.start)

    if args.end:
        mod_end = int(args.end)

    assert mod_start < mod_end

    # set up plot

    sample_color = {}
    if args.color_by_phase:
        phase_palette = sns.color_palette(args.palette, n_colors=2)
        for i, phase in enumerate(('phase1', 'phase2')):
            sample_color[phase] = phase_palette[i]

    else:
        bam_palette = sns.color_palette(args.palette, n_colors=len(out_res))
        for i, bam in enumerate(out_res):
            # a colour given in the -d/--data file overrides the palette
            sample_color[bam] = user_colours.get(bam, bam_palette[i])

    fig = plt.figure()

    gs = gridspec.GridSpec(3,1,height_ratios=[3,1,8])

    g = 0

    # mean

    ax0 = plt.subplot(gs[g])
    g += 1
    ax0.set_ylim((-0.05,1.05))
    ax0.set_xlim((mod_start, mod_end))
    ax0.set_xticks([])

    meanplot_records = []

    for bam in out_res:

        meth_by_coord = dd(list)

        per_elt_calls = []

        for coord_meth_pos, meth_profile, elt_info in out_res[bam]:
            per_elt_calls.append(len(coord_meth_pos))
            for coord, m in zip(coord_meth_pos, meth_profile):
                meth_by_coord[coord].append((m, elt_info))

        median_call_count = int(np.median(per_elt_calls))
        cutoff = median_call_count

        if args.meanplot_cutoff:
            cutoff = int(args.meanplot_cutoff)

        logger.info('%s: median per element call count: %d' % (bam, median_call_count))
        logger.info('%s: per site call count cutoff: %d' % (bam, cutoff))

        # with --color_by_phase samples are labelled by the last '.'-separated
        # field of the key, matching the 'phase1'/'phase2' colour keys
        sample_name = bam.split('.')[-1] if args.color_by_phase else bam

        for coord in sorted(meth_by_coord):
            if len(meth_by_coord[coord]) < cutoff:
                continue

            for m, elt_info in meth_by_coord[coord]:
                # elt_info is unpacked as chrom_start_end_strand; rsplit keeps
                # chromosome names that themselves contain '_' intact
                elt_chrom, elt_start, elt_end, elt_strand = elt_info.rsplit('_', 3)

                meanplot_records.append({
                    'chrom':  elt_chrom,
                    'start':  elt_start,
                    'end':    elt_end,
                    'strand': elt_strand,
                    'sample': sample_name,
                    'coord':  coord,
                    'meth':   m,
                })

    if len(meanplot_records) == 0:
        sys.exit('no successful alignments!')

    meanplot_table = pd.DataFrame(meanplot_records)
    meanplot_table['start'] = np.array(meanplot_table['start'], dtype=int)
    meanplot_table['end'] = np.array(meanplot_table['end'], dtype=int)
    meanplot_table['coord'] = np.array(meanplot_table['coord'], dtype=int)
    meanplot_table['meth'] = np.array(meanplot_table['meth'], dtype=float)

    if args.output_table:
        meanplot_table.to_csv(table_fn, sep='\t', quoting=False, index=False)
        logger.info('wrote per-site table to %s' % table_fn)

    # NOTE(review): newer seaborn releases deprecate ci='sd' in favour of
    # errorbar='sd'; kept as-is for compatibility with older seaborn
    ax0 = sns.lineplot(x='coord', y='meth', data=meanplot_table, ci='sd', lw=2, hue='sample', palette=sample_color, ax=ax0)
    ax0.set_ylabel(args.meanplot_ylabel)

    # mod
    ax1 = plt.subplot(gs[g])
    g += 1
    ax1.set_xlim((mod_start, mod_end))

    box = matplotlib.patches.Rectangle([0, 0], mod_end-mod_start, 1.0, edgecolor='#555555', facecolor='#cfcfcf', zorder=1)
    ax1.add_patch(box)

    if args.blocks:
        with open(args.blocks) as blocks:
            for line in blocks:
                b_start, b_end, b_name, b_col = line.strip().split()
                b_start = int(b_start)
                b_end = int(b_end)

                box = matplotlib.patches.Rectangle([b_start, 0], b_end-b_start, 1.0, edgecolor='#555555', facecolor=b_col, zorder=2)
                ax1.add_patch(box)

    mod_locs = []

    motif = args.motif

    if motif:
        # te_ref_seq is upper-cased above, so compare against an upper-cased motif
        motif = motif.upper()

        # the last valid start is len(seq)-len(motif), inclusive; only starts
        # within [mod_start, mod_end] are marked
        last_start = len(te_ref_seq) - len(motif)
        for i in range(max(mod_start, 0), min(mod_end, last_start) + 1):
            if te_ref_seq.startswith(motif, i):
                mod_locs.append(i)

    ax1.vlines(mod_locs, 0, 1, lw=1, colors=('#FF4500'), zorder=3, alpha=0.5)

    ax1.spines['bottom'].set_visible(False)
    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_xlim((mod_start, mod_end))
    ax1.xaxis.set_ticks_position('top')

    # wiggles
    ax2 = plt.subplot(gs[g])
    g += 1

    max_elts = int(args.maxelts)
    min_elts = int(args.minelts)
    line_width = float(args.linewidth)
    line_alpha = float(args.alpha)

    z_max = len(out_res)*max_elts

    for bam in out_res:
        n_elts = len(out_res[bam])

        if n_elts < min_elts:
            sys.exit('fewer than --minelts (%d) usable elements (%d), giving up.' % (min_elts, n_elts))

        sample_size = min(max_elts, n_elts)

        logger.info('sample %s has %d useable elements, will sample %d' % (bam, n_elts, sample_size))

        # separate key so the loop variable `bam` is not overwritten
        color_key = bam.split('.')[-1] if args.color_by_phase else bam

        for coord_meth_pos, meth_profile, elt_info in random.sample(out_res[bam], sample_size):
            # random z-order so no single sample is always drawn on top
            ax2.plot(coord_meth_pos, meth_profile, lw=line_width, alpha=line_alpha, color=sample_color[color_key], zorder=random.randint(0,z_max))

    ax2.set_ylim((float(args.ymin),float(args.ymax)))
    ax2.set_xlim((mod_start, mod_end))
    ax2.set_xlabel('position')

    fig.set_size_inches(16, 6)

    if args.outfile is not None:
        outfn = args.outfile

    plt.savefig(outfn, bbox_inches='tight')
    logger.info('plotted to %s' % outfn)


def _wgmeth_outfn(args, phase=None):
    '''
    build the wgmeth output filename

    Default: <bam basename minus extension>[.<chrom>].mt<meth_thresh>.ct<can_thresh>
    .<mod>[.phase_<phase>].{DSS.txt|methyl.bed}. For unphased output, --outfile
    is used verbatim. For phased output, '.phase_<phase>' is inserted before the
    --outfile extension so the two haplotype files do not overwrite each other.
    '''
    outfn = '.'.join(os.path.basename(args.bam).split('.')[:-1])

    if args.chrom:
        outfn += '.%s' % args.chrom

    outfn += f'.mt{args.meth_thresh}.ct{args.can_thresh}'

    phase_tag = '' if phase is None else '.phase_%d' % phase

    if args.dss:
        outfn += '.%s%s.DSS.txt' % (str(args.mod), phase_tag)
    else:
        outfn += '.%s%s.methyl.bed' % (str(args.mod), phase_tag)

    if args.outfile is not None:
        outfn = args.outfile
        if phase is not None:
            root, ext = os.path.splitext(args.outfile)
            outfn = root + phase_tag + ext

    return outfn


def _write_wgmeth_table(meth_table, outfn, args):
    '''
    sort per-site calls by chromosome/position and write them to outfn

    meth_table must have 'chr', 'N' and 'X' columns. Its index is copied into a
    'pos' column and used as the site position. With --dss the output is
    DSS (chr, pos, N, X, with header). Otherwise it is 11-column bedMethyl with
    no header: score is N capped at 1000, and pct is 100*X/N.
    '''
    meth_table['pos'] = meth_table.index
    meth_table = meth_table.sort_values(['chr', 'pos'])

    logger.info('writing %s' % outfn)

    if args.dss:
        meth_table.to_csv(outfn, columns=['chr','pos','N','X'], index=False, sep='\t')
        return

    meth_table['score']  = meth_table['N'].where(meth_table['N'] <= 1000, 1000) # cap score at 1000
    meth_table['start']  = meth_table['pos'] - 1
    meth_table['end']    = meth_table['pos']
    meth_table['pct']    = meth_table['X']/meth_table['N']*100 # bedMethyl reports a percentage
    meth_table['name']   = args.mod
    meth_table['strand'] = '.'
    meth_table['colour'] = '255,0,0'

    # start/end are repeated as bedMethyl's thickStart/thickEnd columns
    meth_table.to_csv(outfn, columns=['chr','start','end','name', 'score', 'strand', 'start', 'end', 'colour', 'N', 'pct'], header=False, index=False, sep='\t')


def wgmeth(args):
    '''
    generates whole genome output in DSS or bedmethyl format

    Every chromosome in the .fai index (optionally only --chrom, and only if at
    least --minlen bp) is split into --binsize windows. Each window is processed
    in a worker via get_meth_calls_wg. The per-window tables are concatenated
    and written by _write_wgmeth_table: one file per phase with --phased,
    otherwise a single file. Invalid input or an empty result terminates via
    sys.exit. With exactly one available mod and no --mod, that mod is used.
    '''

    bin_size = int(args.binsize)

    methbam = False

    if args.methdb is None:
        if None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using mod .bams (no --methdb)')
            sys.exit(1)

        if args.fai is None:
            args.fai = args.ref + '.fai'
            if not os.path.exists(args.fai):
                logger.warning(f'--fai not specified, assumed {args.fai} but it cannot find it, use samtools faidx {args.ref} and re-run or specify --fai')
                sys.exit(1)

        methbam = True
    else:
        if args.fai is None:
            logger.warning('--fai is required for -d/--db input')
            sys.exit(1)

    motifsize = None

    if args.motif is not None:
        assert iupac(args.motif)
        motifsize = len(args.motif)
        logger.info('motif size %d (%s)' % (motifsize, args.motif))

    if methbam:
        if args.ctbam is not None:
            logger.info(f'Assuming modtype m for C/T substitution bam {args.bam}')
            mods = ['m']
        else:
            mods = sorted(mods_methbam(args.bam))
    else:
        mods = sorted(get_modnames(args.methdb))

    if methbam and len(mods) == 0:
        sys.exit('bam %s does not appear to contain MM/ML tags' % args.bam)

    if args.mod is None:
        if len(mods) == 1:
            # nothing to choose between: use the single available mod
            args.mod = mods[0]
            logger.info('using the one available mod: %s' % args.mod)

        elif len(mods) > 1:
            logger.warning('more than one mod exists, need to pick one with -m/--mod: %s' % ','.join(mods))
            sys.exit()

        else:
            logger.warning('must specify which to use with --mod, available mods: %s' % ','.join(mods))
            sys.exit()

    elif args.mod not in mods:
        logger.warning('mod %s not in known mods for db: %s' % (args.mod, ','.join(mods)))
        sys.exit()

    if args.dss:
        logger.info('output will be DSS format: https://www.bioconductor.org/packages/release/bioc/html/DSS.html')
    else:
        logger.info('output will be bedMethyl format: https://www.encodeproject.org/data-standards/wgbs/')

    meth_thresh = 0.8
    can_thresh = 0.8

    if args.meth_thresh:
        meth_thresh = float(args.meth_thresh)

    if args.can_thresh:
        can_thresh = float(args.can_thresh)

    logger.info(f'methylated base threshold: {meth_thresh}')
    logger.info(f'canonical base threshold: {can_thresh}')

    # phased runs yield two tables (indices 0 and 1) per window, unphased runs one
    phases = (0, 1) if args.phased else (0,)
    seg_meth_table_store = dd(list)

    # the pool is created only after validation, and the with-block shuts it
    # down once all results have been collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        with open(args.fai) as fai:
            for line in fai:
                chrom, chrlen = line.strip().split()[:2]
                chrlen = int(chrlen)

                if chrlen < int(args.minlen):
                    continue

                if args.chrom and args.chrom != chrom:
                    continue

                logger.info(f'processing chromosome {chrom}, len = {chrlen} bp')

                for seg_start in range(0, chrlen, bin_size):
                    seg_end = min(seg_start + bin_size, chrlen)

                    res = pool.apply_async(get_meth_calls_wg, [args, args.bam, args.methdb, chrom, seg_start, seg_end, args.phased, args.mod, motifsize, args.ctbam, meth_thresh, can_thresh])

                    results.append(res)

        for res in tqdm(results):
            seg_meth_table = res.get()

            if seg_meth_table is None:
                continue

            for phase in phases:
                seg_meth_table_store[phase].append(pd.DataFrame.from_dict(seg_meth_table[phase]).T)

    for phase in phases:
        if args.phased:
            no_calls_msg = 'no calls for phase %d: is this data phased?' % phase
        else:
            no_calls_msg = 'no calls, potential data formatting issue?'

        # pd.concat raises on an empty list, so check before concatenating
        if len(seg_meth_table_store[phase]) == 0:
            sys.exit(no_calls_msg)

        meth_table = pd.concat(seg_meth_table_store[phase])

        if len(meth_table) == 0:
            sys.exit(no_calls_msg)

        outfn = _wgmeth_outfn(args, phase if args.phased else None)
        _write_wgmeth_table(meth_table, outfn, args)


def main():
    '''
    command-line entry point: log the full invocation, parse the arguments and
    hand them to the subcommand function registered via set_defaults(func=...)
    '''
    invocation = ' '.join(sys.argv)
    logger.info('starting methylartist with command: %s' % invocation)

    cli_args = parse_args()
    run_subcommand = cli_args.func
    run_subcommand(cli_args)


def parse_args():
    parser = argparse.ArgumentParser(description='methylartist: tools for exploring nanopore modified base data')
    subparsers = parser.add_subparsers(title="tool", dest="tool")
    subparsers.required = True

    __version__ = "1.5.4"
    parser.add_argument('-v', '--version', action='version', version='%(prog)s {version}'.format(version=__version__))

    parser_nanopolish    = subparsers.add_parser('db-nanopolish')
    parser_megalodon     = subparsers.add_parser('db-megalodon')
    parser_custom        = subparsers.add_parser('db-custom')
    parser_guppy         = subparsers.add_parser('db-guppy')
    parser_substitution  = subparsers.add_parser('db-sub')
    parser_segmeth       = subparsers.add_parser('segmeth')
    parser_segplot       = subparsers.add_parser('segplot')
    parser_locus         = subparsers.add_parser('locus')
    parser_region        = subparsers.add_parser('region')
    parser_composite     = subparsers.add_parser('composite')
    parser_wgmeth        = subparsers.add_parser('wgmeth')
    parser_adjustcutoffs = subparsers.add_parser('adjustcutoffs')
    parser_scoredist     = subparsers.add_parser('scoredist')

    parser_segmeth.set_defaults(func=segmeth)
    parser_segplot.set_defaults(func=segplot)
    parser_locus.set_defaults(func=locus)
    parser_region.set_defaults(func=region)
    parser_composite.set_defaults(func=composite)
    parser_wgmeth.set_defaults(func=wgmeth)
    parser_nanopolish.set_defaults(func=db_nanopolish)
    parser_megalodon.set_defaults(func=db_megalodon)
    parser_custom.set_defaults(func=db_custom)
    parser_guppy.set_defaults(func=db_guppy)
    parser_substitution.set_defaults(func=db_sub)
    parser_adjustcutoffs.set_defaults(func=adjustcutoffs)
    parser_scoredist.set_defaults(func=scoredist)

    # options for methylation segments
    parser_segmeth.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_segmeth.add_argument('-b', '--bams', default=None, help='one or more .bams (or bedMethyl input) with Mm and Ml tags for modification calls (see samtags spec)')
    parser_segmeth.add_argument('-i', '--intervals', required=True, help='.bed file')
    parser_segmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_segmeth.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_segmeth.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_segmeth.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_segmeth.add_argument('--bedmethyl', default=False, action='store_true', help='input is (tabix-indexed) bedMethyl')
    parser_segmeth.add_argument('--meth_thresh', default=0.8, help='modified base threshold (default=0.8)')
    parser_segmeth.add_argument('--can_thresh', default=0.8, help='canonical base threshold (default=0.8)')
    parser_segmeth.add_argument('--ctbam', default=None, help='specify which .bam(s) are C/T substitution data (can be comma-delimited)')
    parser_segmeth.add_argument('--ref', default=None, help='reference genome .fa (build .fai index with samtools faidx) (required for mod bams)')
    parser_segmeth.add_argument('--motif', default=None, help='expected modification motif (e.g. CG for 5mCpG required for mod bams)')
    parser_segmeth.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_segmeth.add_argument('--excl_ambig', action='store_true', default=False, help='do not consider reads that align entirely within segment')
    parser_segmeth.add_argument('--spanning_only', action='store_true', default=False, help='only consider reads that span segment')
    parser_segmeth.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_segmeth.add_argument('--lowmeth_thresh', default=0.1, help='threshold for low-methylated read count column (default = 0.05)')
    parser_segmeth.add_argument('--highmeth_thresh', default=0.9, help='threshold for high-methylated read count column (default = 0.95)')
    parser_segmeth.add_argument('--dmr_minreads', default=8, help='minimum reads per group for DMR prediction (default=8)')
    parser_segmeth.add_argument('--dmr_minratio', default=0.3, help='minimum reads ratio for DMR prediction (default=0.3)')
    parser_segmeth.add_argument('--dmr_maxoverlap', default=0.0, help='max overlap between distributions (default = 0.0)')
    parser_segmeth.add_argument('--dmr_mindiff', default=0.4, help='minimum difference between means (default = 0.4)')
    parser_segmeth.add_argument('--dmr_minmotifs', default=20, help='minimum motif count for DMR prediction (default = 20)')
    parser_segmeth.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_segmeth.add_argument('--predict_dmr', action='store_true', default=False, help='enable DMR prediction in unphased data')

    # options for methylation strip / violin / ridge plots
    parser_segplot.add_argument('-s', '--segmeth', required=True, help='output from segmeth.py')
    parser_segplot.add_argument('-m', '--samples', default=None, help='samples, comma delimited')
    parser_segplot.add_argument('-d', '--mods', default=None, help='mods, comma delimited')
    parser_segplot.add_argument('-c', '--categories', default=None, help='categories, comma delimited, need to match seg_name column from input')
    parser_segplot.add_argument('-v', '--violin', default=False, action='store_true')
    parser_segplot.add_argument('-g', '--ridge', default=False, action='store_true')
    parser_segplot.add_argument('-n', '--mincalls', default=10, help='minimum number of calls to include site (methylated + unmethylated) (default=10)')
    parser_segplot.add_argument('-r', '--minreads', default=1, help='minimum reads in interval (default = 1)')
    parser_segplot.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_segplot.add_argument('-a', '--group_by_annotation', default=False, action='store_true', help='group plots by annotation rather than by sample')
    parser_segplot.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_segplot.add_argument('--metadata', default=None, help='sample metadata (tab-delimited with header, sample name as first column)')
    parser_segplot.add_argument('--usemeta', default=None, help='metadata to append to annotation (comma-delimited)')
    parser_segplot.add_argument('--width', default=None, help='figure width (default = automatic)')
    parser_segplot.add_argument('--height', default=6, help='figure height (default = 6)')
    parser_segplot.add_argument('--pointsize', default=1, help='point size for scatterplot (default = 1)')
    parser_segplot.add_argument('--min', default=-0.15, help='min (default = -0.15)')
    parser_segplot.add_argument('--max', default=1.15, help='max (default = 1.15)')
    parser_segplot.add_argument('--ylabel', default='pct methylation', help='set label for y-axis (default: pct methylation)')
    parser_segplot.add_argument('--tiltlabel', default=False, action='store_true')
    parser_segplot.add_argument('--vertlabel', default=False, action='store_true')
    parser_segplot.add_argument('--palette', default="tab10", help='palette (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_segplot.add_argument('--ridge_alpha', default=1.0, help='alpha (tranparency) for ridge plot fills (default = 1.0)')
    parser_segplot.add_argument('--ridge_spacing', default=-0.25, help='ridge plot spacing (generally negative, default = -0.25)')
    parser_segplot.add_argument('--ridge_smoothing', default=0.5, help='smoothing parameter for ridge plot, bigger is smoother (default=0.5)')
    parser_segplot.add_argument('--svg', default=False, action='store_true')

    # options for locus-specific plots
    parser_locus.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_locus.add_argument('-b', '--bams', default=None, help='one or more .bams with MM and ML tags for modification calls (see samtags spec)')
    parser_locus.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_locus.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_locus.add_argument('-C', '--csi', default=None, help='csi index for the gtf, optional')
    parser_locus.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_locus.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_locus.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_locus.add_argument('-t', '--slidingwindowstep', default=1, help='step size for initial sliding window (default=1)')
    parser_locus.add_argument('-p', '--panelratios',  default=None, help='Alter panel ratios: needs to be 5 comma-seperated integers. Default: 1,5,1,3,3')
    parser_locus.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_locus.add_argument('-r', '--ref', default=None, help='reference genome .fa (build .fai index with samtools faidx) (required for mod bams)')
    parser_locus.add_argument('-n', '--motif', default=None, help='expected modification motif (e.g. CG for 5mCpG required for mod bams)')
    parser_locus.add_argument('-c', '--plot_coverage', default=None, help='plot coverage from bam(s) (can be comma-delimited list)')
    parser_locus.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_locus.add_argument('--samplebox', default=False, action='store_true', help='draw sample box with labels next to alignments')
    parser_locus.add_argument('--skip_align_plot', default=False, action='store_true', help='do not plot alignments')
    parser_locus.add_argument('--include_raw_plot', default=False, action='store_true', help='include plot raw signal')
    parser_locus.add_argument('--meth_thresh', default=0.8, help='modified base threshold (default=0.8)')
    parser_locus.add_argument('--can_thresh', default=0.8, help='canonical base threshold (default=0.8)')
    parser_locus.add_argument('--ctbam', default=None, help='specify which .bam(s) are C/T substitution data (can be comma-delimited)')
    parser_locus.add_argument('--logcover', default=False, action='store_true', help='apply log2(count+1) to coverage data (--plot_coverage)')
    parser_locus.add_argument('--coverprocs', default=1, help='processes to use for coverage function (default=1)')
    parser_locus.add_argument('--bed', default=None, help='.bed file for additional annotations (BED3+3: chrom, start, end, label, strand, color)')
    parser_locus.add_argument('--hidebedlabel', default=False, action='store_true', help='hide lables from .bed track')
    parser_locus.add_argument('--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_locus.add_argument('--highlight_subplot', default=False, action='store_true', help='make sub-plots for highlighted regions of smoothed plots grouped by colormap (incl --phasediff plot), requires --highlight and --colormap')
    parser_locus.add_argument('--variants', default=None, help='variants to highlight, bgzipped/tabix VCF')
    parser_locus.add_argument('--splitvar', default=None, help='split variant on variant with ID (uses ID field from --variants VCF)')
    parser_locus.add_argument('--variantpalette', default='Set1', help='colour palette for variant ticks (default = Set1)')
    parser_locus.add_argument('--variantsize', default=6, help='size of variant ticks (default = 6)')
    parser_locus.add_argument('--motifsize', default=2, help='mod motif size, only used with -b/--bams (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_locus.add_argument('--allreads', default=False, action='store_true', help='show all alignments (secondary/supplementary alignments hidden by default)')
    parser_locus.add_argument('--phased', default=False, action='store_true', help='split samples into phases')
    parser_locus.add_argument('--phasediff', default=False, action='store_true', help='add absolute difference between phases as output')
    parser_locus.add_argument('--ignore_ps', default=False, action='store_true', help='do not use phase set (PS) when plotting phased data (HP only)')
    parser_locus.add_argument('--color_by_hp', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_locus.add_argument('--color_by_phase', default=False, action='store_true', help='color samples by phase (req --phased)')
    parser_locus.add_argument('--colormap', default=None, help='map annotations to colours, can be file with mapping or "auto"')
    parser_locus.add_argument('--phase_labels', default=None, help='if --color_by_hp substitute HP tags for labels. Format HP:Label comma-delimited e.g.: 1:Father,2:Mother')
    parser_locus.add_argument('--include_unphased', default=False, action='store_true', help='include an "unphased" category if called with --phased')
    parser_locus.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_locus.add_argument('--readmarker', default='o', help='marker for (un)methylated glpyhs in read panel (matplotlib markers, default=o)')
    parser_locus.add_argument('--markeralpha', default=1.0, help='alpha (transparency) for (un)methylation marker (default=1.0)')
    parser_locus.add_argument('--readmarkersize', default=2.0, help='marker size for (un)methylated glpyhs in read panel (default=2.0)')
    parser_locus.add_argument('--readlinewidth', default=1.0, help='width for lines representing read alignments (default=1.0)')
    parser_locus.add_argument('--readlinealpha', default=0.5, help='alpha (transparency) for read mapping lines (default=0.4)')
    parser_locus.add_argument('--readopenmarkeredgecolor', default=None, help='edge color for open (unmethylated) markers in read plot (default = sample color)')
    parser_locus.add_argument('--slidingwindowsize', default=2, help='size of initial sliding window for coverage check (default=2)')
    parser_locus.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_locus.add_argument('--smoothalpha', default=1.0, help='alpha (transparency) value for smoothed plot (default = 1.0)')
    parser_locus.add_argument('--smoothlinewidth', default=4.0, help='smooth line width (default = 4.0)')
    parser_locus.add_argument('--shuffle', default=False, action='store_true', help='shuffle sample order for smoothed plot (may reduce visual bias due to sample order)')
    parser_locus.add_argument('--smoothed_csv', default=None, help='output values from smoothed plot to CSV format (can specify filename or "auto")')
    parser_locus.add_argument('--maskcutoff', default=1, help='read count masking cutoff (default=1)')
    parser_locus.add_argument('--maxmaskedfrac', default=1.0, help='skip smoothed plot if fraction of sample masked (--maskcutoff) > this value (default = 1.0)')
    parser_locus.add_argument('--nomask', default=False, action='store_true', help='skip drawing segment masks')
    parser_locus.add_argument('--mincalls', default=0, help='drop modspace positions if call count (meth+unmeth) < --mincalls (default=0)')
    parser_locus.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_locus.add_argument('--modspace', default=None, help='spacing between links in top panel (default=auto)')
    parser_locus.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_locus.add_argument('--hidegenelabels', default=False, action='store_true', help='plot gene names')
    parser_locus.add_argument('--hidelegend', default=False, action='store_true', help='hide legends')
    parser_locus.add_argument('--exonheight', default=0.8, help='set exon height (default=0.8)')
    parser_locus.add_argument('--show_transcripts', default=False, action='store_true', help='plot all transcripts, use transcript_id/transcript_name attrs')
    parser_locus.add_argument('--ymin', default=-0.05, help='y-axis minimum for smoothed plot (default = -0.05)')
    parser_locus.add_argument('--ymax', default=1.05, help='y-axis maximum for smoothed plot (default = 1.05)')
    parser_locus.add_argument('--cover_ymin', default=0, help='y-axis minimum for coverage plot (default = 0)')
    parser_locus.add_argument('--nticks', default=10, help='tick count (default=10)')
    parser_locus.add_argument('--statname', default=None, help='label for raw statistic plot')
    parser_locus.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_locus.add_argument('--coverpalette', default="mako", help='colour palette name for coverage plot (default = "mako")')
    parser_locus.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_locus.add_argument('--genepalette', default="viridis", help='colour palette name for highlights (default = "viridis")')
    parser_locus.add_argument('--highlight_alpha', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_locus.add_argument('--excl_ambig', action='store_true', default=False)
    parser_locus.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_locus.add_argument('--unambig_highlight', action='store_true', default=False)
    parser_locus.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_locus.add_argument('--height', default=8, help='image width (inches, default=8)')
    parser_locus.add_argument('--notext', default=False, action='store_true', help='remove all text from figure')
    parser_locus.add_argument('--svg', action='store_true')

    # options for region plots
    parser_region.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_region.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_region.add_argument('-b', '--bams', default=None, help='one or more .bams with MM and ML tags for modification calls (see samtags spec)')
    parser_region.add_argument('-n', '--motif', required=True, help='normalise window sizes to motif occurance')
    parser_region.add_argument('-r', '--ref', required=True, help='ref genome fasta, required if normalising windows with -n/--norm_motif')
    parser_region.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_region.add_argument('-C', '--csi', default=None, help='csi index for the gtf, optional')
    parser_region.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_region.add_argument('-w', '--windows', default=None, help='set window count, default=auto')
    parser_region.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_region.add_argument('-m', '--mods', default=None, help='mods to consider (comma-delimited, default = all available)')
    parser_region.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_region.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_region.add_argument('-c', '--plot_coverage', default=None, help='plot coverage from bam(s) (can be comma-delimited list)')
    parser_region.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_region.add_argument('--meth_thresh', default=0.8, help='modified base threshold (default=0.8)')
    parser_region.add_argument('--can_thresh', default=0.8, help='canonical base threshold (default=0.8)')
    parser_region.add_argument('--bedmethyl', default=False, action='store_true', help='input is (tabix-indexed) bedMethyl')
    parser_region.add_argument('--ctbam', default=None, help='specify which .bam(s) are C/T substitution data (can be comma-delimited)')
    parser_region.add_argument('--bed', default=None, help='.bed file for additional annotations (BED3+3: chrom, start, end, label, strand, color)')
    parser_region.add_argument('--hidebedlabel', default=False, action='store_true', help='hide lables from .bed track')
    parser_region.add_argument('--logcover', default=False, action='store_true', help='apply log2(count+1) to coverage data (--plot_coverage)')
    parser_region.add_argument('--allreads', default=False, action='store_true', help='show all alignments (secondary/supplementary alignments hidden by default)')
    parser_region.add_argument('--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_region.add_argument('--motifsize', default=2, help='mod motif size, only used with -b/--bams (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_region.add_argument('--maxuncovered', default=50.0, help='maximum percentage of uncovered windows tolerated (default = 50.0)')
    parser_region.add_argument('--modspace', default=None, help='increase to increase spacing between links in top panel (default=auto)')
    parser_region.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_region.add_argument('--min_window_calls', default=1, help='minimum reads per window to include in plot (default = 1)')
    parser_region.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_region.add_argument('--smoothalpha', default=1.0, help='alpha (transparency) value for smoothed plot (default = 1.0)')
    parser_region.add_argument('--smoothlinewidth', default=4.0, help='smooth line width (default = 4.0)')
    parser_region.add_argument('--shuffle', default=False, action='store_true', help='shuffle sample order for smoothed plot (may reduce visual bias due to sample order)')
    parser_region.add_argument('--segment_csv', default=None, help='output values from smoothed segment plot to specified filename in CSV format (default=None)')
    parser_region.add_argument('--eff', default=None, help='conversion efficiency (for e.g. bs-seq or em-seq), input as comma-delimited sample:eff e.g. MySample1.m:0.9,MySample2.m:0.8')
    parser_region.add_argument('--ymin', default=-0.05, help='y-axis minimum for smoothed plot (default = -0.05)')
    parser_region.add_argument('--ymax', default=1.05, help='y-axis maximum for smoothed plot (default = 1.05')
    parser_region.add_argument('--cover_ymin', default=0, help='y-axis minimum for coverage plot (default = 0)')
    parser_region.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_region.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_region.add_argument('--coverpalette', default="mako", help='colour palette name for coverage plot (default = "mako")')
    parser_region.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_region.add_argument('--genepalette', default="viridis", help='colour palette name for highlights (default = "viridis")')
    parser_region.add_argument('--gene_track_height', default=None, help='maximum number of gene track layers')
    parser_region.add_argument('--highlight_alpha', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_region.add_argument('--highlight_centerline', default=None, help='change highlight to line (specify width)')
    parser_region.add_argument('--panelratios',  default=None, help='Alter panel ratios: needs to be 4 (or 5 if --plot_coverage) comma-seperated integers. Default: 1,5,3,3')
    parser_region.add_argument('--nticks', default=10, help='tick count (default=10)')
    parser_region.add_argument('--skip_align_plot', default=False, action='store_true', help='blank alignment plot, useful if unneeded or for runtime.')
    parser_region.add_argument('--force_align_plot', default=False, action='store_true', help='retain alignment plot even over regions > 5Mbp where it would be disabled automatically')
    parser_region.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_region.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_region.add_argument('--show_transcripts', default=False, action='store_true', help='plot all transcripts, use transcript_id/transcript_name attrs')
    parser_region.add_argument('--exonheight', default=0.8, help='set exon height (default=0.8)')
    parser_region.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_region.add_argument('--height', default=8, help='image width (inches, default=8)')
    parser_region.add_argument('--dmr_minreads', default=8, help='minimum reads per group for DMR prediction (default=8)')
    parser_region.add_argument('--dmr_minratio', default=0.3, help='minimum reads ratio for DMR prediction (default=0.3)')
    parser_region.add_argument('--dmr_maxoverlap', default=0.0, help='max overlap between distributions (default = 0.0)')
    parser_region.add_argument('--dmr_mindiff', default=0.4, help='minimum difference between means (default = 0.4)')
    parser_region.add_argument('--dmr_minmotifs', default=20, help='minimum motif count for DMR prediction (default = 20)')
    parser_region.add_argument('--write_dmrs', default=False, action='store_true', help='record differentially methylated windows to a file')
    parser_region.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_region.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_region.add_argument('--color_by_hp', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_region.add_argument('--colormap', default=None, help='map annotations to colours, can be file with mapping or "auto"')
    parser_region.add_argument('--scale_fullwidth', default=None, help='scale plot output relative to value (e.g. use length of chrom 1)')
    parser_region.add_argument('--svg', action='store_true', default=False)

    # options for composite plots
    parser_composite.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_composite.add_argument('-b', '--bams', default=None, help='one or more .bams with MM and ML tags for modification calls (see samtags spec)')
    parser_composite.add_argument('-s', '--segdata', required=True, help='BED3+1: chrom, start, end, strand')
    parser_composite.add_argument('-r', '--ref', required=True, help='ref genome fasta')
    parser_composite.add_argument('-t', '--teref', required=True, help='TE ref fasta')
    parser_composite.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_composite.add_argument('-c','--palette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_composite.add_argument('-a', '--alpha', default=0.3, help='alpha (default: 0.3)')
    parser_composite.add_argument('-w', '--linewidth', default=1, help='line width (default: 1)')
    parser_composite.add_argument('-l', '--lenfrac', default=0.95, help='fraction of TE length that must align (default 0.95)')
    parser_composite.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_composite.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_composite.add_argument('--meth_thresh', default=0.8, help='modified base threshold (default=0.8)')
    parser_composite.add_argument('--can_thresh', default=0.8, help='canonical base threshold (default=0.8)')
    parser_composite.add_argument('--meanplot_ylabel', default='% methylation', help='set y-axis label on mean plot')
    parser_composite.add_argument('--meanplot_cutoff', default=None, help='override site coverage cutoff for mean plot (see output for automatic value)')
    parser_composite.add_argument('--mod', default=None, help='modification to plot (mod codes will be listed, default: infer from sample name')
    parser_composite.add_argument('--motif', default='CG', help='modified motif to highlight (default = CG)')
    parser_composite.add_argument('--blocks', default=None, help='blocks to highlight (txt file with start, end, name, hex colour)')
    parser_composite.add_argument('--start', default=None, help='start plotting at this base (default None)')
    parser_composite.add_argument('--end', default=None, help='end plotting at this base (default None)')
    parser_composite.add_argument('--mincalls', default=100, help='minimum call count to include elt (default = 100)')
    parser_composite.add_argument('--minelts', default=1, help='minimum output elements (default = 1)')
    parser_composite.add_argument('--maxelts', default=200, help='maximum output elements, if > max random.sample() (default = 200)')
    parser_composite.add_argument('--slidingwindowsize', default=10, help='size of sliding window for meth frac (default 10)')
    parser_composite.add_argument('--slidingwindowstep', default=1, help='step size for meth frac (default 1)')
    parser_composite.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default 8)')
    parser_composite.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_composite.add_argument('--ymin', default=-0.05, help='y-axis minimum for smoothed plot')
    parser_composite.add_argument('--ymax', default=1.05, help='y-axis maximum for smoothed plot')
    parser_composite.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_composite.add_argument('--excl_ambig', action='store_true', default=False)
    parser_composite.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_composite.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_composite.add_argument('--color_by_phase', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_composite.add_argument('--output_table', default=False, action='store_true', help='output per-site data to table (.tsv)')
    parser_composite.add_argument('--svg', action='store_true', default=False)

    # options for whole genome output
    parser_wgmeth.add_argument('-b', '--bam', required=True, help='bam used for methylation calling')
    parser_wgmeth.add_argument('-d', '--methdb', default=None, help='methylation database')
    parser_wgmeth.add_argument('-s', '--binsize', default=1000000, help='bin size for parallelisation, default = 1000000')
    parser_wgmeth.add_argument('-f', '--fai', default=None, help='fasta index (.fai), default = --ref + .fai, required for .db files')
    parser_wgmeth.add_argument('-m', '--mod', default=None, help='output for specific mod (names vary, see output for hints)')
    parser_wgmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_wgmeth.add_argument('-c', '--chrom', default=None, help='limit analysis to one chromosome')
    parser_wgmeth.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_wgmeth.add_argument('-r', '--ref', default=None, help='reference genome .fa (build .fai index with samtools faidx) (required for mod bams)')
    parser_wgmeth.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_wgmeth.add_argument('-l', '--minlen', default=0, help='minimum chromosome length (default = 0)')
    parser_wgmeth.add_argument('--meth_thresh', default=0.8, help='modified base threshold (default=0.8)')
    parser_wgmeth.add_argument('--can_thresh', default=0.8, help='canonical base threshold (default=0.8)')
    parser_wgmeth.add_argument('--ctbam', default=None, help='specify which .bam(s) are C/T substitution data (can be comma-delimited)')
    parser_wgmeth.add_argument('--motif', default=None, help='expected modification motif (e.g. CG for 5mCpG required for mod bams)')
    parser_wgmeth.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_wgmeth.add_argument('--dss', default=False, action='store_true', help='output in DSS format (default = bedMethyl)')
    parser_wgmeth.add_argument('--phased', action='store_true', default=False, help='split output into phases (currently just 1,2)')
    parser_wgmeth.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')

    # options for custom db
    parser_custom.add_argument('-m', '--methdata', required=True, help='per-read methylation output table')
    parser_custom.add_argument('--header', default=False, action='store_true', help='input table has header')
    parser_custom.add_argument('--delimiter', default=None, help='column delimimter char (default = whitespace (i.e. tab or space)')
    parser_custom.add_argument('--readname', required=True, help='readname column number')
    parser_custom.add_argument('--chrom', required=True, help='chromosome column number')
    parser_custom.add_argument('--pos', required=True, help='genomic (i.e. on chromosome/contig) position column number, 0-based')
    parser_custom.add_argument('--strand', required=True, help='strand column number')
    parser_custom.add_argument('--modprob', required=True, help='column number for probability of modified base')
    parser_custom.add_argument('--canprob', default=None, help='column number for probability of canonical base (if not given, assume p=1-modprob)')
    parser_custom.add_argument('--modbasecol', default=None, help='column number for modified base/motif name (optional, can use --modbase instead)')
    parser_custom.add_argument('--modbase', default=None, help='specify modified base/motif name (overrides --modbasecol)')
    parser_custom.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_custom.add_argument('--minmodprob', default=0.8, help='probability threshold for calling modified base (default = 0.8)')
    parser_custom.add_argument('--mincanprob', default=None, help='probability threshold for calling canonical base (default = minmodprob)')
    parser_custom.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_custom.add_argument('--motifsize', default=2, help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for megalodon db
    parser_megalodon.add_argument('-m', '--methdata', required=True, help='megalodon per_read_text methylation output')
    parser_megalodon.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_megalodon.add_argument('-p', '--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_megalodon.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_megalodon.add_argument('--motifsize', default=2, help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for guppy db
    parser_guppy.add_argument('-s', '--samplename', required=True, help='name for sample')
    parser_guppy.add_argument('-f', '--fast5', required=True, help='fast5 with called bases')
    parser_guppy.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_guppy.add_argument('-m', '--motif', required=True, help='motif e.g. G[A]TC or [C]G')
    parser_guppy.add_argument('-n', '--modname', required=True, help='mod name in guppy fast5 modified base alphabet (5mC, 6mA, etc)')
    parser_guppy.add_argument('-b', '--bam', required=True, help='.bam file containing alignments of reads from fast5')
    parser_guppy.add_argument('-r', '--ref', required=True, help='reference genome fasta (samtools faidx indexed)')
    parser_guppy.add_argument('--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_guppy.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_guppy.add_argument('--include_unmatched', action='store_true', default=False, help='include sites where read base does not match genome base')
    parser_guppy.add_argument('--motifsize', default=2, help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_guppy.add_argument('--force', default=False, action='store_true')

    # options for nanopolish db
    parser_nanopolish.add_argument('-m', '--methdata', required=True, help='whole genome nanopolish methylation output, can be comma-delimited')
    parser_nanopolish.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_nanopolish.add_argument('-t', '--thresh', default=2.5, help='llr threshold (default = 2.5; if using --scalegroup the suggested setting is 2.0)')
    parser_nanopolish.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_nanopolish.add_argument('-s', '--scalegroup', default=False, action='store_true', help='scale threshold by number of CpGs in a group')
    parser_nanopolish.add_argument('-n', '--modname', default='CpG', help='modification type (tag if combining multiple mods, default = "CpG")')
    parser_nanopolish.add_argument('--motif', default='CG', help='mod motif (default = CG)')

    # options for base-substitution db
    parser_substitution.add_argument('-b', '--bam', required=True, help='bam file, requires MD tag')
    parser_substitution.add_argument('-d', '--db', required=True, help='database name (will append .db if necessary)')
    parser_substitution.add_argument('--append', default=False, action='store_true', help='append to database')

    # options for adjusting methylation / unmethylation cutoffs in methylartist db
    parser_adjustcutoffs.add_argument('-d', '--db', required=True, help='methylartist database')
    parser_adjustcutoffs.add_argument('--mod', required=True, help='modification to plot (will list for user if incorrect)')
    parser_adjustcutoffs.add_argument('-m', '--methylated', required=True, help='mark as methylated above cutoff value')
    parser_adjustcutoffs.add_argument('-u', '--unmethylated', required=True, help='mark as unmethylated below cutoff value')

    # options for score distribution exploration function
    parser_scoredist.add_argument('-d', '--db', default=None, help='methylartist database(s), can be comma-delimited')
    parser_scoredist.add_argument('-b', '--bam', default=None, help='one or more .bam files with MM and ML tags for modification calls (see samtags spec)')
    parser_scoredist.add_argument('-n', '--n', default=1000000, help='sample size (default = 1000000)')
    parser_scoredist.add_argument('-m', '--mod', required=True, help='modification to plot (will list for user if incorrect)')
    parser_scoredist.add_argument('-r', '--ref', default=None, help='reference genome fasta (samtools faidx indexed)')
    parser_scoredist.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_scoredist.add_argument('--motif', default=None, help='modified motif to highlight (e.g. CG)')
    parser_scoredist.add_argument('--xmin', default=None)
    parser_scoredist.add_argument('--xmax', default=None)
    parser_scoredist.add_argument('--lw', default=2, help='line width (default = 2)')
    parser_scoredist.add_argument('--palette', default="tab10", help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_scoredist.add_argument('--svg', action='store_true', default=False)

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    # Run the command-line interface only when this file is executed as a
    # script; importing the module must not trigger argument parsing.
    main()
