#!python
# encoding: utf-8
"""
TempoDetector tempo estimation algorithm.

"""

from __future__ import absolute_import, division, print_function

import argparse
from functools import partial

import numpy as np

from madmom.audio import SignalProcessor
from madmom.features import (ActivationsProcessor, RNNBeatProcessor,
                             TempoEstimationProcessor)
from madmom.io import write_events, write_tempo
from madmom.processors import IOProcessor, io_arguments


def write_mirex(tempi, filename):
    """
    Write the most dominant tempi and the relative strength to a file (in
    MIREX format).

    A single line with three values is written: the lower tempo, the higher
    tempo and the relative strength of the lower tempo.

    Parameters
    ----------
    tempi : numpy array
        Array with the detected tempi (first column) and their strengths
        (second column).
    filename : str or file handle
        Output file.

    Notes
    -----
    If only a single tempo is given, a second one is generated by doubling
    (tempo below 68 bpm) or halving the detected tempo. If `tempi` is empty,
    NaN is written for all three values.

    """
    # make the given tempi a 2d array
    tempi = np.array(tempi, ndmin=2)
    # Note: because of `ndmin=2` an empty input results in an array of shape
    #       (1, 0), i.e. its length is 1 although it contains no tempo at
    #       all, thus check the size before relying on the length
    num_tempi = len(tempi) if tempi.size else 0
    # default values (kept if no tempo was detected)
    t1 = t2 = strength = np.nan
    # only one tempo was detected
    if num_tempi == 1:
        t1 = tempi[0][0]
        strength = 1.
        # generate a fake second tempo (safer to output a real value than NaN)
        # the boundary of 68 bpm is taken from Tzanetakis 2013 ICASSP paper
        if t1 < 68:
            t2 = t1 * 2.
        else:
            t2 = t1 / 2.
    # consider only the two strongest tempi and strengths
    elif num_tempi > 1:
        t1, t2 = tempi[:2, 0]
        strength = tempi[0, 1] / sum(tempi[:2, 1])
    # for MIREX, the lower tempo must be given first
    # (comparisons with NaN are always False, so nothing is swapped then)
    if t1 > t2:
        t1, t2, strength = t2, t1, 1. - strength
    # format as a numpy array
    out = np.array([t1, t2, strength], ndmin=2)
    # write to output
    write_events(out, filename, fmt=['%.2f', '%.2f', '%.2f'])


def main():
    """TempoDetector"""

    # set up the command line parser
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='''
    The TempoDetector program detects the dominant tempi according to the
    algorithm described in:

    "Accurate Tempo Estimation based on Recurrent Neural Networks and
     Resonating Comb Filters"
    Sebastian Böck, Florian Krebs and Gerhard Widmer.
    Proceedings of the 16th International Society for Music Information
    Retrieval Conference (ISMIR), 2015.

    This program can be run in 'single' file mode to process a single audio
    file and write the detected tempi to STDOUT or the given output file.

      $ TempoDetector single INFILE [-o OUTFILE]

    If multiple audio files should be processed, the program can also be run
    in 'batch' mode to save the detected tempi to files with the given suffix.

      $ TempoDetector batch [-o OUTPUT_DIR] [-s OUTPUT_SUFFIX] FILES

    If no output directory is given, the program writes the files with the
    detected tempi to the same location as the audio files.

    The 'pickle' mode can be used to store the used parameters to be able to
    exactly reproduce experiments.

    ''')
    # version information
    parser.add_argument('--version', action='version',
                        version='TempoDetector.2016')
    # options for input/output handling
    io_arguments(parser, output_suffix='.bpm.txt', online=True)
    ActivationsProcessor.add_arguments(parser)
    # options for signal processing
    SignalProcessor.add_arguments(parser, norm=False, gain=0)
    # options for tempo estimation
    TempoEstimationProcessor.add_arguments(
        parser, method='comb', min_bpm=40., max_bpm=250., act_smooth=0.14,
        hist_smooth=9, hist_buffer=10., alpha=0.79)
    # the output formats exclude each other
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument(
        '--mirex', dest='tempo_format', action='store_const', const='mirex',
        help='use the MIREX output format (lower tempo first)')
    format_group.add_argument(
        '--all', dest='tempo_format', action='store_const', const='all',
        help='output all detected tempi in raw format')

    # parse the command line
    args = parser.parse_args()

    # the frame rate is fixed and can not be changed by the user
    args.fps = 100

    # show the arguments if requested
    if args.verbose:
        print(args)

    # all arguments as a dictionary, handed over to the processors
    options = vars(args)

    # input: either load the activations from file or use a RNN to predict
    # the beats
    in_processor = (ActivationsProcessor(mode='r', **options) if args.load
                    else RNNBeatProcessor(**options))

    # output: either save the RNN beat activations to file or estimate the
    # tempo based on the beat activation function and write it
    if args.save:
        out_processor = ActivationsProcessor(mode='w', **options)
    else:
        tempo_estimator = TempoEstimationProcessor(**options)
        # the event writer is borrowed for outputting multiple values
        raw_writer = partial(write_events, fmt='%.2f\t%.3f')
        # map the requested format to its writer; the MIREX format puts the
        # slower tempo first, everything unknown uses the normal output
        writers = {'mirex': write_mirex, 'raw': raw_writer, 'all': raw_writer}
        output = writers.get(args.tempo_format, write_tempo)
        # process them sequentially
        out_processor = [tempo_estimator, output]

    # combine input and output into an IOProcessor
    processor = IOProcessor(in_processor, out_processor)

    # finally call the processing function (single/batch processing)
    args.func(processor, **options)


# run the program only when executed as a script (not when imported)
if __name__ == '__main__':
    main()
