#!python

import argparse
import badbyte
from badbyte.utils.colors import RED, GREEN, RST
from badbyte.utils.functions import unhexify
from badbyte.utils.functions import analyze, generate_characters

# Startup banner with the tool name in ASCII art and the package version.
# The literal is a *raw* f-string: the art contains backslashes that do not
# form valid escape sequences (e.g. "\ ", "\|", "\_"), which a non-raw string
# would flag with an invalid-escape warning. The printed text is unchanged.
print(rf"""

{RED}  _               _{GREEN} _           _       {RST}
{RED} | |__   __ _  __| {GREEN}| |__  _   _| |_ ___ {RST}
{RED} | '_ \ / _` |/ _` {GREEN}| '_ \| | | | __/ _ \{RST}
{RED} | |_) | (_| | (_| {GREEN}| |_) | |_| | ||  __/{RST}
{RED} |_.__/ \__,_|\__,_{GREEN}|_.__/ \__, |\__\___|{RST}
                        {GREEN}   |___/             {RST}
                        by C3l1n v{badbyte.__version__}
""")

# Command-line interface: one positional mode selector plus optional
# payload settings.
parser = argparse.ArgumentParser()

# Only the first character of the mode is inspected later on.
parser.add_argument(
    "mode",
    help="g[enerate] | p[arse]",
)

# Flag: ask for a user-supplied payload while parsing.
parser.add_argument(
    "-c",
    "--custom",
    help="use custom payload in parse (usefull when you try to search for subsequent bad cahrs).",
    action='store_true',
)

# Text placed before / after the payload.
parser.add_argument(
    "--pre",
    default="BAD_START",
    type=str,
    help="payload prefix - default BAD_START",
)
parser.add_argument(
    "--post",
    default="BAD_STOP",
    type=str,
    help="payload postfix - default BAD_STOP",
)

# Already-known bad characters, given as a hex string.
parser.add_argument(
    "--bad",
    default="",
    type=str,
    help="Banned characters as hexstring i.e. '3D 0D 0A'",
)

# Parsed at module level; the script body below reads `args`.
args = parser.parse_args()


if __name__ == "__main__":
    # Script entry point: either print a generated payload ("g...") or
    # compare a pasted hexdump against the payload ("p..."); any other
    # mode prints the usage help.

    # Characters the user already knows are bad, decoded from the --bad
    # hex string (presumably unhexify yields byte values - verify in
    # badbyte.utils.functions).
    bad = list(unhexify(args.bad)) if args.bad else []

    # Prefix/postfix are encoded as latin1 so each character maps to one byte.
    PREFIX = args.pre.encode("latin1")
    POSTFIX = args.post.encode("latin1")
    payload = generate_characters(PREFIX, POSTFIX, bad)

    # Slice instead of indexing: an empty mode string ("") then falls
    # through to the help output rather than raising IndexError.
    mode = args.mode[:1]

    if mode == "g":
        print(f"-> Known bad characters: {RED}{bytes(bad)} {RST}")
        print(payload)
    elif mode == "p":
        print(f"{GREEN}Please paste hexdump (every non whitespace is treated as hexstring). "
              f"Press CTRL+D to end.{RST}")
        # Collect pasted lines until EOF, then join once (lines are
        # concatenated without a separator, as before).
        lines = []
        try:
            while True:
                lines.append(input())
        except EOFError:
            pass
        hexdump = "".join(lines)

        if args.custom:
            payload = input(f"{GREEN}Enter payload as python bytestring {RST}: ")
            # Accept input written as a bytes literal: drop the b'...' / b"..."
            # wrapper (two leading characters and the final one).
            if payload.startswith(("b'", 'b"')):
                payload = payload[2:-1]
            # Interpret backslash escapes (e.g. \x41) and convert to raw bytes.
            payload = payload.encode('utf-8').decode("unicode_escape").encode("latin1")
        analyze(hexdump, PREFIX, POSTFIX, payload)
    else:
        parser.print_help()
