#!/usr/bin/env python3

import argparse


# Sizes used by the MZ header: the image length is counted in 512-byte pages,
# the header length and the allocations in 16-byte paragraphs.
PAGE_SIZE = 0x200
PARAGRAPH_SIZE = 0x10
MINIMAL_HEADER_SIZE = 0x1c

# Offsets of the 16-bit little-endian header fields this tool reads or writes.
LAST_PAGE_SIZE = 0x02
PAGE_COUNT = 0x04
RELOCATION_COUNT = 0x06
HEADER_PARAGRAPHS = 0x08
MINIMUM_ALLOCATION = 0x0a
MAXIMUM_ALLOCATION = 0x0c
INITIAL_SS = 0x0e
INITIAL_SP = 0x10
INITIAL_IP = 0x14
INITIAL_CS = 0x16
RELOCATION_TABLE = 0x18

# Padding pattern used when none is given on the command line. The text goes
# through `unicode_escape`, so the trailing `\0` becomes a NUL byte.
DEFAULT_FILL = 'github.com/excitoon/reasm\\0'


def _require(condition, message):
    """Abort with `message` unless `condition` holds.

    Used instead of `assert` so the checks are not stripped by `python -O`.
    Raises `SystemExit`, which exits with status 1 when it reaches the top level.
    """
    if not condition:
        raise SystemExit('error: ' + message)


def _read_word(image, offset):
    """Return the 16-bit little-endian value stored at `offset`."""
    return int.from_bytes(image[offset:offset + 2], 'little')


def _write_word(image, offset, value):
    """Store `value` at `offset` as 16-bit little-endian (OverflowError if it does not fit)."""
    image[offset:offset + 2] = value.to_bytes(2, 'little')


def _read_sizes(image, header_size):
    """Return `(last_page_size, all_pages, text_size)` decoded from the header.

    `text_size` is the size described by the page fields minus the header.
    A last-page size of zero means that the last page is full.
    """
    last_page_size = _read_word(image, LAST_PAGE_SIZE)
    _require(0 <= last_page_size <= PAGE_SIZE, 'last page size is out of range')
    all_pages = _read_word(image, PAGE_COUNT)
    text_size = (all_pages - bool(last_page_size)) * PAGE_SIZE + last_page_size - header_size
    return last_page_size, all_pages, text_size


def _write_sizes(image, total_size):
    """Encode `total_size` (header plus text) into the page fields."""
    _write_word(image, LAST_PAGE_SIZE, total_size % PAGE_SIZE)
    _write_word(image, PAGE_COUNT, (total_size + PAGE_SIZE - 1) // PAGE_SIZE)


def _parse_fill(arguments):
    """Return the padding pattern from the optional single argument.

    Backslash escapes in the argument are expanded and the result is encoded
    as UTF-8. A missing or empty argument selects `DEFAULT_FILL`.
    """
    _require(len(arguments) <= 1, 'expected at most one argument (fill pattern)')
    text = arguments[0] if arguments and arguments[0] else DEFAULT_FILL
    fill = text.encode('raw_unicode_escape').decode('unicode_escape').encode()
    # An empty pattern can never fill anything.
    _require(len(fill) > 0, 'fill pattern must not be empty')
    return fill


def _padding(fill, count):
    """Return `count` bytes made by repeating `fill`, cut off at `count`."""
    repeats = -(-count // len(fill))  # ceiling division
    return (fill * repeats)[:count]


def _set_entry_point(image, header_size, arguments):
    """`set-entry-point CS IP`: store the initial CS:IP."""
    _require(len(arguments) == 2, 'set-entry-point expects two arguments: cs ip')
    cs, ip = map(int, arguments)
    _write_word(image, INITIAL_IP, ip)
    _write_word(image, INITIAL_CS, cs)
    return image


def _set_stack(image, header_size, arguments):
    """`set-stack SS SP`: store the initial SS:SP."""
    _require(len(arguments) == 2, 'set-stack expects two arguments: ss sp')
    ss, sp = map(int, arguments)
    _write_word(image, INITIAL_SS, ss)
    _write_word(image, INITIAL_SP, sp)
    return image


def _set_allocation(image, field, arguments):
    """Store an allocation given in bytes (a multiple of 16) as paragraphs at `field`."""
    _require(len(arguments) == 1, 'expected one argument: allocation in bytes')
    allocation = int(arguments[0])
    _require(allocation & 0x0f == 0, 'allocation must be a multiple of 16 bytes')
    _write_word(image, field, allocation >> 4)
    return image


def _set_minimum_allocation(image, header_size, arguments):
    """`set-minimum-allocation BYTES`."""
    return _set_allocation(image, MINIMUM_ALLOCATION, arguments)


def _set_maximum_allocation(image, header_size, arguments):
    """`set-maximum-allocation BYTES`."""
    return _set_allocation(image, MAXIMUM_ALLOCATION, arguments)


def _set_relocations(image, header_size, arguments):
    """`set-relocations ADDRESS...`: rewrite the relocation table.

    Each address is split into a 16-bit offset (low word) and a segment
    (the remaining bits, shifted right by four). The table is written at
    the offset already recorded in the header; if it does not fit, the header
    is extended with zero bytes.
    """
    addresses = [int(argument) for argument in arguments]
    relocs_base = _read_word(image, RELOCATION_TABLE)
    table_end = relocs_base + len(addresses) * 4
    if table_end > header_size:
        # Grow by whole pages so that only the page count changes and the
        # last-page size stays valid. The number of pages is derived from the
        # actual shortfall: the difference of the page-rounded totals can be
        # zero even though the table does not fit.
        extra_pages = -(-(table_end - header_size) // PAGE_SIZE)
        growth = extra_pages * PAGE_SIZE
        image[header_size:header_size] = bytes(growth)
        header_size += growth
        _write_word(image, HEADER_PARAGRAPHS, header_size // PARAGRAPH_SIZE)
        _write_word(image, PAGE_COUNT, _read_word(image, PAGE_COUNT) + extra_pages)
        # FIXME (carried over from the original, which does not say what is left to do).
    for index, address in enumerate(addresses):
        offset = address & 0xffff
        segment = (address & 0xffff0000) // 0x10
        entry = relocs_base + index * 4
        image[entry:entry + 4] = offset.to_bytes(2, 'little') + segment.to_bytes(2, 'little')
    _write_word(image, RELOCATION_COUNT, len(addresses))
    return image


def _strip_tail(image, header_size, arguments):
    """`strip-tail`: drop everything past the size described by the header."""
    _, _, text_size = _read_sizes(image, header_size)
    return image[:header_size + text_size]


def _add_text_size(image, header_size, arguments):
    """`add-text-size BYTES`: enlarge the size recorded in the header.

    Only the header fields change; no bytes are appended.
    """
    # You will need to call `fix-text-size` after that. No tail allowed.
    # If text size is too big, DOS may complain:
    # - https://github.com/microsoft/MS-DOS/blob/2d04cacc5322951f187bb17e017c12920ac8ebe2/v2.0/source/EXEC.ASM#L593 ;
    # - https://github.com/microsoft/MS-DOS/blob/2d04cacc5322951f187bb17e017c12920ac8ebe2/v4.0/src/DOS/EXEC.ASM#L458 .
    _require(len(arguments) == 1, 'add-text-size expects one argument: size in bytes')
    add_size = int(arguments[0])
    _, _, text_size = _read_sizes(image, header_size)
    _require(text_size + header_size == len(image), 'file size does not match the header (no tail allowed)')
    _write_sizes(image, header_size + text_size + add_size)
    return image


def _fix_text_size(image, header_size, arguments):
    """`fix-text-size [FILL]`: append padding until the file is as long as the header says."""
    fill = _parse_fill(arguments)
    _, _, text_size = _read_sizes(image, header_size)
    missing = text_size + header_size - len(image)
    _require(missing > 0, 'file is not shorter than the size recorded in the header')
    return image + _padding(fill, missing)


def _fill_last_page(image, header_size, arguments):
    """`fill-last-page [FILL]`: pad the text to a whole number of pages."""
    # Moves tail forward if applicable.
    # If the application expects last page to be fully loaded (including part of tail probably), it may fail.
    fill = _parse_fill(arguments)
    last_page_size, _, text_size = _read_sizes(image, header_size)
    end = header_size + text_size
    _require(end <= len(image), 'file is shorter than the size recorded in the header')
    body, tail = image[:end], image[end:]
    padding = _padding(fill, (PAGE_SIZE - last_page_size) % PAGE_SIZE)
    # Zero marks the last page as full; the page count itself does not change.
    _write_word(body, LAST_PAGE_SIZE, 0)
    return body + padding + tail


def _preallocate_minimum(image, header_size, arguments):
    """`preallocate-minimum [FILL]`: turn the minimum allocation into padding in the file.

    The text is padded to a page boundary plus the minimum allocation, the
    minimum allocation becomes zero and the maximum is reduced by the same amount.
    """
    # Assumes `fill-last-page`.
    fill = _parse_fill(arguments)
    last_page_size, _, text_size = _read_sizes(image, header_size)
    minimum_allocation = _read_word(image, MINIMUM_ALLOCATION) * PARAGRAPH_SIZE
    maximum_allocation = _read_word(image, MAXIMUM_ALLOCATION) * PARAGRAPH_SIZE
    end = header_size + text_size
    _require(end <= len(image), 'file is shorter than the size recorded in the header')
    body, tail = image[:end], image[end:]
    if maximum_allocation == 0:
        # Special value:
        # - https://github.com/microsoft/MS-DOS/blob/2d04cacc5322951f187bb17e017c12920ac8ebe2/v2.0/source/EXEC.ASM#L330 ;
        # - https://github.com/microsoft/MS-DOS/blob/2d04cacc5322951f187bb17e017c12920ac8ebe2/v4.0/src/DOS/EXEC.ASM#L300 .
        maximum_allocation = minimum_allocation
    add_bytes = (PAGE_SIZE - last_page_size) % PAGE_SIZE + minimum_allocation
    _write_sizes(body, end + add_bytes)
    _write_word(body, MINIMUM_ALLOCATION, 0)
    _write_word(body, MAXIMUM_ALLOCATION, (maximum_allocation - minimum_allocation) // PARAGRAPH_SIZE)
    return body + _padding(fill, add_bytes) + tail


# Operation name -> handler(image, header_size, arguments) returning the bytes to write.
_HANDLERS = {
    'set-entry-point': _set_entry_point,
    'set-stack': _set_stack,
    'set-relocations': _set_relocations,
    'set-minimum-allocation': _set_minimum_allocation,
    'set-maximum-allocation': _set_maximum_allocation,
    'strip-tail': _strip_tail,
    'add-text-size': _add_text_size,
    'fix-text-size': _fix_text_size,
    'fill-last-page': _fill_last_page,
    'preallocate-minimum': _preallocate_minimum,
}


def _build_parser():
    """Create the command-line parser: input path, output path, operation and its arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str, help='path to input file')
    parser.add_argument('output', type=str, help='path to output file')
    parser.add_argument('operation', type=str, help='operation (e.g. set-entry-point, set-stack)')
    parser.add_argument('arguments', type=str, nargs='*', help='arguments (e.g. cs, ip)')
    return parser


def main(argv=None):
    """Apply one operation to an MZ executable and write the result.

    `argv` is the argument list without the program name; `None` means
    `sys.argv[1:]`. The input is read and checked completely before the
    output file is opened, so a failed run does not touch the output.
    Raises `SystemExit` when the arguments or the input file are rejected.
    """
    args = _build_parser().parse_args(argv)
    _require(args.operation in _HANDLERS, 'unknown operation: ' + args.operation)

    with open(args.path, 'rb') as f:
        image = bytearray(f.read())
    _require(len(image) >= MINIMAL_HEADER_SIZE, 'file is too short to hold a header')
    # Both byte orders of the signature are accepted.
    _require(image[:2] in (b'MZ', b'ZM'), 'file does not start with an MZ signature')
    header_size = _read_word(image, HEADER_PARAGRAPHS) * PARAGRAPH_SIZE
    _require(len(image) >= header_size, 'file is shorter than its header')

    result = _HANDLERS[args.operation](image, header_size, args.arguments)

    with open(args.output, 'wb') as f:
        f.write(result)


if __name__ == '__main__':
    main()
