#!python
# encoding: utf-8
"""
PianoTranscriptor (piano) note transcription algorithm.

"""

from __future__ import absolute_import, division, print_function

import argparse

import numpy as np

from madmom.audio.filters import midi2hz
from madmom.audio.signal import SignalProcessor
from madmom.features import ActivationsProcessor
from madmom.features.notes import (ADSRNoteTrackingProcessor,
                                   CNNPianoNoteProcessor)
from madmom.io import write_midi, write_notes
from madmom.processors import IOProcessor, io_arguments


def write_mirex(notes, filename):
    """
    Write the notes to a file in MIREX format (onset, offset, f0).

    Parameters
    ----------
    notes : numpy array, shape (num_notes, 3) or more columns
        Notes, row format 'onset_time' 'note_number' 'duration' ['velocity'].
        The 'duration' column is required, since the offset written to the
        file is computed as 'onset_time' + 'duration'.
    filename : str or file handle
        File to write the notes to.

    Raises
    ------
    ValueError
        If `notes` is not two-dimensional or has no 'duration' column.

    """
    notes = np.asarray(notes)
    # fail early with a clear message instead of an opaque IndexError when
    # the duration column (needed for the offsets) is missing
    if notes.ndim != 2 or notes.shape[1] < 3:
        raise ValueError('`notes` must be a 2D array with at least the '
                         'columns onset, note number and duration, got '
                         'shape %s' % (notes.shape,))
    # note format: onset, offset, f0
    onsets = notes[:, 0]
    offsets = onsets + notes[:, 2]
    freqs = midi2hz(notes[:, 1])
    # write the notes, all three columns with two decimals
    write_notes(np.c_[onsets, offsets, freqs], filename,
                fmt=['%.2f', '%.2f', '%.2f'])


def main():
    """Run the PianoTranscriptor command line program."""

    # help text shown by the argument parser (kept verbatim, hence the raw
    # description formatter below)
    description = '''
    The PianoTranscriptor program detects all (piano) notes in an audio
    file according to the algorithm described in:

    "Deep Polyphonic ADSR Piano Note Transcription",
    Rainer Kelz, Sebastian Böck and Gerhard Widmer
    Proceedings of the 44th International Conference on Acoustics, Speech and
    Signal Processing (ICASSP), 2019.

    This program can be run in 'single' file mode to process a single audio
    file and write the detected notes to STDOUT or the given output file.

      $ PianoTranscriptor single INFILE [-o OUTFILE]

    If multiple audio files should be processed, the program can also be run
    in 'batch' mode to save the detected notes to files with the given suffix.

      $ PianoTranscriptor batch [-o OUTPUT_DIR] [-s OUTPUT_SUFFIX] FILES

    If no output directory is given, the program writes the files with the
    detected notes to the same location as the audio files.

    The 'pickle' mode can be used to store the used parameters to be able to
    exactly reproduce experiments.

    '''
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=description)
    # version information
    parser.add_argument('--version', action='version',
                        version='PianoTranscriptor.2019')
    # arguments for input/output and loading/saving of activations
    io_arguments(parser, output_suffix='.notes.txt')
    ActivationsProcessor.add_arguments(parser)
    # arguments for the signal processing
    SignalProcessor.add_arguments(parser, norm=False, gain=0, start=True,
                                  stop=True)
    # alternative output formats, both stored in `output_format`
    format_flags = (('--midi', 'midi', 'save as MIDI'),
                    ('--mirex', 'mirex', 'use the MIREX output format'))
    for flag, format_name, help_text in format_flags:
        parser.add_argument(flag, dest='output_format', action='store_const',
                            const=format_name, help=help_text)

    args = parser.parse_args()

    # values which are fixed and not exposed as command line options
    args.fps = 50
    args.pitch_offset = 21

    # MIDI files get their own suffix
    if args.output_format == 'midi':
        args.output_suffix = '.mid'

    if args.verbose:
        print(args)

    # all arguments as a dictionary, passed on to the processors as keywords
    options = vars(args)

    # input: predict the note activations with the CNN unless they should be
    # loaded from file
    if not args.load:
        in_processor = CNNPianoNoteProcessor(**options)
    else:
        in_processor = ActivationsProcessor(mode='r', **options)

    # output: either save the activations or track the notes and write them
    if args.save:
        out_processor = ActivationsProcessor(mode='w', **options)
    else:
        note_tracker = ADSRNoteTrackingProcessor(**options)
        # writer function for each supported output format
        writers = {None: write_notes,
                   'midi': write_midi,
                   'mirex': write_mirex}
        if args.output_format not in writers:
            raise ValueError('unknown output format: %s' % args.output_format)
        out_processor = [note_tracker, writers[args.output_format]]

    # chain input and output and hand over to the selected processing function
    args.func(IOProcessor(in_processor, out_processor), **options)


# run the program only when executed as a script, not when imported
if __name__ == '__main__':
    main()
