#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : jac-hook-ipdb.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 12/26/2023
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import argparse


# Python source snippet that main() inserts into (--hook) or strips from
# (--unhook) the target file.  When the snippet runs, it installs a
# ``sys.excepthook`` that prints the traceback and then starts an ipdb
# post-mortem session; in interactive mode (``sys.ps1`` is set) or when
# stderr is not a tty it defers to ``sys.__excepthook__`` instead.
#
# NOTE: main() recognises a hooked file by an exact substring match against
# this text, so any edit to the literal (even whitespace) means files hooked
# with the previous text are reported as "Not hooked." and are left as-is.
g_ipdb_hook_code = r"""### Maintained by <jac-hook-ipdb>
import sys

def _custom_exception_hook(type, value, tb):
    if hasattr(sys, 'ps1') or not sys.stderr.isatty():
        # we are in interactive mode or we don't have a tty-like
        # device, so we call the default hook
        sys.__excepthook__(type, value, tb)
    else:
        import traceback, ipdb
        # we are NOT in interactive mode, print the exception...
        traceback.print_exception(type, value, tb)
        # ...then start the debugger in post-mortem mode.
        ipdb.post_mortem(tb)


def hook_exception_ipdb():
    # Add a hook to ipdb when an exception is raised.
    if not hasattr(_custom_exception_hook, 'origin_hook'):
        _custom_exception_hook.origin_hook = sys.excepthook
        sys.excepthook = _custom_exception_hook


def unhook_exception_ipdb():
    # Remove the hook to ipdb when an exception is raised.
    assert hasattr(_custom_exception_hook, 'origin_hook')
    sys.excepthook = _custom_exception_hook.origin_hook


hook_exception_ipdb()

### End of <jac-hook-ipdb>

"""


# Command-line interface: one positional target file plus two boolean
# switches selecting the action.  Parsing happens at module level and the
# result is read by main() through the global ``args``.
parser = argparse.ArgumentParser()
parser.add_argument('filename', type=str)
for _switch in ('--hook', '--unhook'):
    # Both switches default to False and become True when present.
    parser.add_argument(_switch, action='store_true')
del _switch  # keep the module namespace limited to ``parser`` and ``args``
args = parser.parse_args()


def _header_end(content):
    """Return the offset just past the leading shebang / encoding lines.

    A ``#!`` line is only honoured by the OS on line 1 and a PEP 263
    encoding declaration only on line 1 or 2, so injected code has to go
    *after* them.  Only newline-terminated lines are considered; ``0`` is
    returned when *content* has no such header.
    """
    import re

    # Same shape as the PEP 263 magic-comment pattern.
    coding_cookie = re.compile(r'[ \t\f]*#.*?coding[:=][ \t]*[-\w.]+')

    end = 0
    pos = 0
    for lineno in range(2):
        newline = content.find('\n', pos)
        if newline == -1:
            # Unterminated last line: nothing can be inserted after it.
            break
        line = content[pos:newline]
        if lineno == 0 and line.startswith('#!'):
            end = newline + 1
        if coding_cookie.match(line):
            end = newline + 1
            break
        stripped = line.strip()
        if stripped and not stripped.startswith('#'):
            # Real code: an encoding comment after it would not be honoured.
            break
        pos = newline + 1
    return end


def main():
    """Insert the ipdb hook snippet into ``args.filename`` or remove it.

    With ``--hook`` the snippet ``g_ipdb_hook_code`` is inserted right after
    any leading shebang / encoding-declaration lines (or at the very top when
    there are none).  With ``--unhook`` every exact occurrence of the snippet
    is removed.  When both switches are given ``--hook`` wins; when neither
    is given the file is left untouched.

    The file is decoded with the encoding Python itself would use for it
    (BOM / PEP 263 declaration, UTF-8 otherwise) and written back with that
    same encoding.

    NOTE: the snippet is still placed before a module docstring and before
    ``from __future__`` imports; files relying on those are not handled.

    Raises:
        OSError: if the file cannot be read or written.
        SyntaxError: if the file carries an invalid encoding declaration
            (raised by ``tokenize.open``).
    """
    import tokenize

    if not (args.hook or args.unhook):
        # Previously this case rewrote the file with identical content.
        print('Nothing to do: pass --hook or --unhook.')
        return

    # tokenize.open() detects the source encoding the way the interpreter
    # does, instead of relying on the locale default.
    with tokenize.open(args.filename) as f:
        encoding = f.encoding
        content = f.read()

    if args.hook:
        if g_ipdb_hook_code in content:
            print('Already hooked.')
            return
        split = _header_end(content)
        content = content[:split] + g_ipdb_hook_code + content[split:]
        print('Hooked: {}'.format(args.filename))
    else:
        if g_ipdb_hook_code not in content:
            print('Not hooked.')
            return
        # Position-independent removal, so files hooked at the very top
        # (the old behaviour) are handled as well.
        content = content.replace(g_ipdb_hook_code, '')
        print('Unhooked: {}'.format(args.filename))

    with open(args.filename, 'w', encoding=encoding) as f:
        f.write(content)

# Run the hook/unhook action only when this file is executed as a script.
if __name__ == '__main__':
    main()

