#!python
# encoding: utf-8
"""
TCNBeatTracker beat tracking algorithm.

"""

from __future__ import absolute_import, division, print_function

import argparse

from madmom.audio import SignalProcessor
from madmom.features import ActivationsProcessor
from madmom.features.beats import DBNBeatTrackingProcessor, TCNBeatProcessor
from madmom.io import write_beats
from madmom.processors import IOProcessor, io_arguments


def main():
    """
    Run the TCNBeatTracker command line program.

    The function builds the argument parser, parses the command line and
    assembles an `IOProcessor` out of two stages:

    - input: the beat activations are either loaded from file (if the `load`
      argument is set) or predicted with a `TCNBeatProcessor`;
    - output: the activations are either saved to file (if the `save`
      argument is set) or the beats are tracked with a
      `DBNBeatTrackingProcessor` and written with `write_beats`.

    The assembled processor is then passed, together with all parsed
    arguments, to the function selected by the processing mode (`args.func`).

    If no processing mode was given on the command line, the program exits
    with a usage message instead of failing with an `AttributeError`.

    """

    # define parser
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    The TCNBeatTracker program detects all beats in an audio file according to
    the method described in:

    "Multi-Task learning of tempo and beat: learning one to improve the other"
    Sebastian Böck, Matthew Davies and Peter Knees.
    Proc. of the 20th International Society for Music Information Retrieval
    Conference (ISMIR), 2019.

    Which is a multi-task version of the original method described in:

    "Temporal convolutional networks for musical audio beat tracking"
    Matthew Davies and Sebastian Böck.
    Proc. of the 27th European Signal Processing Conference (EUSIPCO), 2019.

    This program can be run in 'single' file mode to process a single audio
    file and write the detected beats to STDOUT or the given output file.

      $ TCNBeatTracker single INFILE [-o OUTFILE]

    If multiple audio files should be processed, the program can also be run
    in 'batch' mode to save the detected beats to files with the given suffix.

      $ TCNBeatTracker batch [-o OUTPUT_DIR] [-s OUTPUT_SUFFIX] FILES

    If no output directory is given, the program writes the files with the
    detected beats to the same location as the audio files.

    The 'pickle' mode can be used to store the used parameters to be able to
    exactly reproduce experiments.

    ''')
    # version
    parser.add_argument('--version', action='version',
                        version='TCNBeatTracker')
    # input/output options
    io_arguments(parser, output_suffix='.beats.txt')
    ActivationsProcessor.add_arguments(parser)
    # signal processing arguments
    SignalProcessor.add_arguments(parser, norm=False, gain=0)
    # peak picking arguments
    DBNBeatTrackingProcessor.add_arguments(parser)

    # parse arguments
    args = parser.parse_args()

    # argparse does not require a sub-command on Python 3 by default, so
    # `func` can be missing if no processing mode was chosen; fail early with
    # a usage message rather than with an AttributeError after all processors
    # have been created
    # NOTE(review): assumes `io_arguments` sets `func` only through the
    #               processing mode sub-parsers -- confirm
    if not hasattr(args, 'func'):
        parser.error('no processing mode given (see --help for the '
                     'available modes)')

    # set immutable arguments
    args.fps = 100

    # print arguments
    if args.verbose:
        print(args)

    # all parsed arguments are passed as keyword arguments to the processors
    kwargs = vars(args)

    # input processor
    if args.load:
        # load the activations from file
        in_processor = ActivationsProcessor(mode='r', **kwargs)
    else:
        # use a TCN to predict the beats
        in_processor = TCNBeatProcessor(**kwargs)

    # output processor
    if args.save:
        # save the TCN beat activations to file
        out_processor = ActivationsProcessor(mode='w', **kwargs)
    else:
        # track the beats with a DBN and output them
        beat_processor = DBNBeatTrackingProcessor(**kwargs)
        out_processor = [beat_processor, write_beats]

    # create an IOProcessor
    processor = IOProcessor(in_processor, out_processor)

    # and call the processing function
    args.func(processor, **kwargs)


# run the program only if this file is executed as a script, not on import
if __name__ == '__main__':
    main()
