#!/usr/bin/python3
#
# test-chroma - test ld-chroma-decoder using ld-chroma-encoder
# Copyright (C) 2022 Adam Sampson
#
# This file is part of ld-decode.
#
# test-chroma is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# The overall idea here is that we generate a test video using ffmpeg, encode
# it into TBC form using ld-chroma-encoder, then decode it using
# ld-chroma-decoder, and use ffmpeg's psnr filter to compare the decoded
# version to the original.
#
# The PSNR shouldn't be too low for any individual decode (in practice this
# varies depending on the source), and should be nearly the same for the same
# decoder settings regardless of which output format is being used.
#
# If this test fails, rerun it with --png and look at the images.

# XXX Add options to specify which decoders etc. to test

import argparse
import os
import shutil
import statistics
import subprocess
import sys

# Root of the build tree holding the tools under test. Assigned in main()
# (from --build, or derived from this script's location) and read by the
# test_* functions when they build tool paths.
build_dir = None

def resolve_tool(*candidates):
    """Return the first of the candidate paths that exists.

    If none of them exists, the first candidate is returned anyway."""

    existing = (location for location in candidates if os.path.exists(location))
    return next(existing, candidates[0])

def safe_unlink(filename):
    """Delete filename, quietly doing nothing if it is already absent.

    Only FileNotFoundError is ignored; any other error propagates."""

    try:
        os.remove(filename)
    except FileNotFoundError:
        return

def clean(args, suffixes):
    """Delete args.output + suffix for each given suffix, skipping any file
    that does not exist."""

    stale_paths = (args.output + ending for ending in suffixes)
    for stale in stale_paths:
        safe_unlink(stale)

# Prefix for every ffmpeg invocation: only report errors.
FFMPEG_CMD = ['ffmpeg', '-loglevel', 'error']

# ffmpeg options describing headerless raw video in each pixel layout.
_RAWVIDEO = ['-f', 'rawvideo', '-pix_fmt']
RGB_FORMAT = _RAWVIDEO + ['rgb48']
YUV_FORMAT = _RAWVIDEO + ['yuv444p16']

def test_encode(args, source, sc_locked, png_suffix):
    """Render the source to raw .rgb/.yuv video with ffmpeg, then encode that
    to .tbc with ld-chroma-encoder.

    args supplies output (base name for all files), input_format ('yuv', or
    anything else for RGB), system ('ntsc', or anything else for PAL) and
    png. source is the list of ffmpeg input options. sc_locked adds
    --sc-locked to the encoder command. png_suffix names the PNG of the
    first input frame, written only when args.png is set.

    Raises subprocess.CalledProcessError if ffmpeg or the encoder fails."""

    yuv_input = args.input_format == 'yuv'
    raw_suffix = '.yuv' if yuv_input else '.rgb'
    raw_format = YUV_FORMAT if yuv_input else RGB_FORMAT

    # Get rid of anything left over from a previous run
    clean(args, [raw_suffix, png_suffix, '.tbc', '.tbc.db'])

    ntsc = args.system == 'ntsc'
    if ntsc:
        width, height, rate = 758, 486, 'ntsc'
    else:
        width, height, rate = 928, 576, 'pal'
    pad_args = ['-filter:v', 'pad=%d:%d:-1:-1' % (width, height)]
    size_args = ['-s', '%dx%d' % (width, height), '-r', rate]

    # Have ffmpeg pad the source to the frame size and write it out raw
    raw_path = args.output + raw_suffix
    render_cmd = [*FFMPEG_CMD, *source, *pad_args, *size_args, *raw_format, raw_path]
    subprocess.check_call(render_cmd)

    if args.png:
        # Save the first raw frame as a PNG for visual inspection
        snapshot_cmd = [
            *FFMPEG_CMD, *size_args, *raw_format,
            '-i', raw_path,
            '-frames:v', '1', args.output + png_suffix,
        ]
        subprocess.check_call(snapshot_cmd)

    # Run ld-chroma-encoder over the raw video to produce the .tbc
    encoder = resolve_tool(
        os.path.join(build_dir, 'bin', 'ld-chroma-encoder'),
        os.path.join(build_dir, 'src', 'ld-chroma-decoder', 'encoder', 'ld-chroma-encoder'),
    )
    encode_cmd = [encoder, '--input-format', args.input_format]
    if sc_locked:
        encode_cmd.append('--sc-locked')
    if ntsc:
        encode_cmd.extend(['--system', 'ntsc'])
    encode_cmd.extend([raw_path, args.output + '.tbc'])
    subprocess.check_call(encode_cmd)

def _read_median_psnr(psnr_file):
    """Return the median of the psnr_avg values in an ffmpeg psnr stats file.

    Raises statistics.StatisticsError if the file contains no psnr_avg
    fields."""

    psnrs = []
    with open(psnr_file) as f:
        for line in f:
            # Each line is a series of whitespace-separated key:value fields
            for field in line.split():
                key, sep, value = field.partition(':')
                if sep and key == 'psnr_avg':
                    psnrs.append(float(value))
    if not psnrs:
        raise statistics.StatisticsError('no psnr_avg values found in %s' % psnr_file)
    return statistics.median(psnrs)

def test_decode(args, decoder, phase_locked, output_format, png_suffix):
    """Decode a .tbc file, compare it with the original .rgb/.yuv, and return the
    median pSNR.

    args supplies output (base name for all files), input_format, system
    and png. decoder is passed to ld-chroma-decoder's -f option and
    output_format to its --output-format option. phase_locked adds
    --ntsc-phase-comp, but only when args.system is 'ntsc'. png_suffix names
    the PNG of the first decoded frame, written only when args.png is set.

    Raises subprocess.CalledProcessError if the decoder or ffmpeg fails, and
    statistics.StatisticsError if ffmpeg produced no per-frame PSNR values."""

    clean(args, ['.decoded', '.psnr', png_suffix])

    ntsc = args.system == 'ntsc'
    size_args = ['-s', '758x486' if ntsc else '928x576']

    # Work out ffmpeg input format corresponding to ld-chroma-decoder output format.
    # These must be new lists: extending RGB_FORMAT/YUV_FORMAT in place would
    # change the module-level constants for every later encode and decode.
    if output_format == 'rgb':
        decoded_format = RGB_FORMAT + size_args
    elif output_format == 'yuv':
        decoded_format = YUV_FORMAT + size_args
    else:
        # ffmpeg can read the Y4M header, but psnr fails if framerates mismatch
        # NOTE(review): 'pal' is used for both systems; the reference below is
        # opened without -r, so presumably this matches ffmpeg's default raw
        # video rate -- confirm before changing.
        decoded_format = ['-r', 'pal']

    # Decode the .tbc using ld-chroma-decoder
    tbc_file = args.output + '.tbc'
    decoded_file = args.output + '.decoded'
    cmd = [
        resolve_tool(
            os.path.join(build_dir, 'bin', 'ld-chroma-decoder'),
            os.path.join(build_dir, 'src', 'ld-chroma-decoder', 'ld-chroma-decoder'),
        ),
        '--quiet',
        '-f', decoder,
        '--chroma-nr', '0',
        '--luma-nr', '0',
        '--simple-pal',
        '--output-format', output_format,
        tbc_file, decoded_file,
    ]
    if ntsc:
        cmd += ['--ffrl', '39', '--pad', '2']
        if phase_locked:
            cmd += ['--ntsc-phase-comp']
    subprocess.check_call(cmd)

    if args.png:
        # Convert decoded to PNG
        subprocess.check_call(
            FFMPEG_CMD
            + decoded_format + ['-i', decoded_file]
            + ['-frames:v', '1', args.output + png_suffix]
            )

    # Describe the reference file written by test_encode
    if args.input_format == 'yuv':
        suffix = '.yuv'
        ref_format = YUV_FORMAT
        lavf_format = 'yuv444p16'
    else:
        suffix = '.rgb'
        ref_format = RGB_FORMAT
        lavf_format = 'rgb48'

    # Compare the decoded video to the original .rgb/.yuv using ffmpeg; the
    # decoded input is first converted to the reference's pixel format
    ref_file = args.output + suffix
    psnr_file = args.output + '.psnr'
    subprocess.check_call(
        FFMPEG_CMD
        + decoded_format + ['-i', decoded_file]
        + ref_format + size_args + ['-i', ref_file]
        + ['-lavfi', '[0:v] format=pix_fmts=%s, split [decoded];'
                     '[decoded][1:v] psnr=stats_file=%s' % (lavf_format, psnr_file),
           '-f', 'null', '-']
        )

    # Read the per-frame stats back from ffmpeg
    return _read_median_psnr(psnr_file)

def main():
    """Parse the command line, run every encode/decode combination, and exit.

    Exit status is 0 if every check passed, 1 if any encode, decode or PSNR
    check failed, and 77 if ffmpeg or the tools under test cannot be found."""

    global build_dir

    parser = argparse.ArgumentParser(description='Test ld-chroma-decoder using ld-chroma-encoder')
    group = parser.add_argument_group("Encoding and decoding")
    group.add_argument('input', metavar='input', nargs='?', default=None,
                       help='input video file (default colour bars)')
    group.add_argument('output', metavar='output', nargs='?', default='testout/test',
                       help='base name for output files (default testout/test)')
    group.add_argument('--input-format', choices=['rgb', 'yuv'], default='rgb',
                       help='input format is RGB48 or YUV444P16')
    group.add_argument('--build', metavar='DIR',
                       help='build tree to test (default same as this script)')
    group.add_argument('--system', choices=['pal', 'ntsc'], default='pal',
                       help='select color system (default pal)')
    group.add_argument('--png', action='store_true',
                       help='output PNG files for first frame of input and output videos')
    group = parser.add_argument_group("Sanity checks")
    group.add_argument('--expect-psnr', metavar='DB', type=float, default=15.0,
                       help='expect median PSNR of at least (default 15)')
    group.add_argument('--expect-psnr-range', metavar='DB', type=float, default=1,
                       help='expect PSNRs for different formats to be within (default 1)')
    args = parser.parse_args()

    # Find the build tree: --build if given, otherwise the parent of the
    # directory containing this script
    if args.build:
        build_dir = args.build
    else:
        prog_path = os.path.realpath(sys.argv[0])
        build_dir = os.path.dirname(os.path.dirname(prog_path))

    if shutil.which('ffmpeg') is None:
        print('SKIP: ffmpeg not found in PATH')
        sys.exit(77)

    # Look for the tools in the same places test_encode/test_decode do, so a
    # tree that only has them in the fallback location is not skipped.
    # resolve_tool returns the first candidate when none of them exists.
    tool_paths = [
        resolve_tool(
            os.path.join(build_dir, 'bin', 'ld-chroma-encoder'),
            os.path.join(build_dir, 'src', 'ld-chroma-decoder', 'encoder', 'ld-chroma-encoder'),
        ),
        resolve_tool(
            os.path.join(build_dir, 'bin', 'ld-chroma-decoder'),
            os.path.join(build_dir, 'src', 'ld-chroma-decoder', 'ld-chroma-decoder'),
        ),
    ]
    missing = [path for path in tool_paths if not os.path.exists(path)]
    if missing:
        print('SKIP: required tool(s) not built:', ', '.join(missing))
        sys.exit(77)

    # Remove display environment variables, as the decoding tools shouldn't
    # depend on having a display
    for var in ('DISPLAY', 'WAYLAND_DISPLAY'):
        os.environ.pop(var, None)

    # Ensure the directory containing output files exists
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Convert input spec into ffmpeg options
    if args.input:
        source = ['-i', args.input]
    else:
        # Generate colour bars
        if args.system == 'ntsc':
            source = ['-f', 'lavfi', '-i', 'smptebars=duration=0.5:size=758x486:rate=ntsc']
        else:
            source = ['-f', 'lavfi', '-i', 'pal75bars=duration=0.5:size=922x576:rate=pal']

    print('Running encode-decode tests with source', source[-1])
    columns = '%-18s %-15s %-8s %8s'

    if args.system == 'ntsc':
        decoder_modes = ('ntsc1d', 'ntsc2d', 'ntsc3d')
        print('\n' + columns % ('Phase-Comp (Dec)', 'Decoder', 'Format', 'PSNR (dB)'))
    else:
        decoder_modes = ('pal2d', 'transform2d', 'transform3d')
        print('\n' + columns % ('SC-locked (Enc)', 'Decoder', 'Format', 'PSNR (dB)'))

    failed = False

    # For each combination of parameters...
    for sc_locked in (False, True):
        # Encode
        try:
            test_encode(args, source, sc_locked, '.input.png')
        except subprocess.CalledProcessError as e:
            print('Encoding failed:', e)
            failed = True
            continue

        # Tag used in PNG names to tell the two passes apart
        if args.system == 'ntsc':
            sc_locked_str = 'pc' if sc_locked else 'll'
        else:
            sc_locked_str = 'sc' if sc_locked else 'll'

        for decoder in decoder_modes:
            format_psnrs = []
            for output_format in ('rgb', 'yuv', 'y4m'):
                png_suffix = '.output-%s-%s-%s.png' % (sc_locked_str, decoder, output_format)
                # Decode and compare
                try:
                    psnr = test_decode(args, decoder, sc_locked, output_format, png_suffix)
                except subprocess.CalledProcessError as e:
                    print('Decoding failed:', e)
                    failed = True
                    continue
                print(columns % (sc_locked, decoder, output_format, '%.2f' % psnr))
                format_psnrs.append(psnr)

                # Check PSNR for this case
                if psnr < args.expect_psnr:
                    print('FAIL: PSNR too low (expect %s dB)' % args.expect_psnr)
                    failed = True

            # Check PSNR for this group of formats. If every decode failed
            # there is nothing to compare (and the failure is already
            # recorded), so don't call max()/min() on an empty list.
            if format_psnrs:
                psnr_range = max(format_psnrs) - min(format_psnrs)
                if psnr_range > args.expect_psnr_range:
                    print('FAIL: PSNR range for different formats too high (expect %s dB)' % args.expect_psnr_range)
                    failed = True

    if failed:
        print('\nTest failed')
        sys.exit(1)
    else:
        print('\nTest completed successfully')
        sys.exit(0)

# Run the tests only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
