#!/usr/bin/env python3

import subprocess
import sys
import os
import re
import signal
import shlex
from typing import List, Iterable
import argparse

class Commit(object):
    """One commit, parsed from ``git log -n1 --no-notes --pretty=raw``.

    Attributes (``sha``, ``tree``, ``author`` and ``committer`` are only set
    when the corresponding header appears in the output):

    sha        value of the ``commit`` header
    tree       value of the ``tree`` header
    parents    values of the ``parent`` headers, in output order
    author     raw value of the ``author`` header (not parsed further)
    committer  raw value of the ``committer`` header (not parsed further)
    message    commit message with the four-space log indent removed,
               ending in exactly one newline
    """

    parents : List[str]

    def __init__(self, git, rev):
        """Load commit *rev* through *git*.

        git -- object providing ``cmd(args, quiet=...)`` that returns the
               command's standard output as text.
        rev -- revision expression handed to ``git log``.

        Raises AssertionError when the message part of the output does not
        have the expected shape (trailing newline, four-space indent).
        """
        self.parents = list()
        buffer = git.cmd("git log -n1 --no-notes --pretty=raw".split() + [rev], quiet=True)
        # The first empty line separates the header block from the message.
        (headers, message) = buffer.split("\n\n", 1)
        for header in headers.split("\n"):
            if header.startswith(' '):
                # Continuation of a multi-line header such as ``gpgsig`` or
                # ``mergetag``.  It has no key of its own and may be a single
                # token (or blank), so it must not be parsed as "key value".
                continue
            # partition() tolerates a header without a value instead of
            # failing to unpack.
            (key, _, value) = header.strip().partition(' ')
            if key == 'commit':
                self.sha = value
            elif key == 'parent':
                self.parents.append(value)
            elif key == 'tree':
                self.tree = value
            elif key == 'author':
                self.author = value
            elif key == 'committer':
                self.committer = value
        # Sanity checks on git's output: every message line, blank ones
        # included, is expected to carry the four-space log indent.
        assert message.endswith('\n')
        lines = message[:-1].split('\n')
        assert all(x.startswith('    ') for x in lines)
        self.message = '\n'.join(x[4:] for x in lines) + '\n'

# Handle on the null device, opened once at import time in write mode and
# never closed; available for redirecting subprocess streams.
FNULL = open(os.devnull, 'w')

class Git(object):
    """Runs git commands inside the work tree that contains a directory.

    ``directory`` holds the top-level path reported by
    ``git rev-parse --show-toplevel``; every command is started with that
    path as its working directory and with stdin attached to the null device.
    """

    def cmd(self, args: List[str], quiet: bool = False, **kw) -> str:
        """Run *args* and return its standard output as text.

        args  -- full argument vector, including the leading ``git``.
        quiet -- when false, echo the shell-quoted command prefixed by ``+``.
        kw    -- extra keyword arguments forwarded to ``subprocess.Popen``
                 (for example ``stderr=...``).

        Raises Exception carrying the captured output when the exit status
        is non-zero.
        """
        if not quiet:
            print("+", ' '.join(map(shlex.quote, args)))
        proc = subprocess.Popen(args, cwd=self.directory,
                                encoding='utf8',
                                stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, **kw)
        # stderr is not piped here, so only stdout carries data.
        (out, _) = proc.communicate()
        if proc.returncode != 0:
            raise Exception("git failed: " + out)
        return out

    def cmd_test(self, args: List[str], **kw) -> bool:
        """Run *args* as a yes/no query.

        Returns True for exit status 0 and False for exit status 1; any
        other status raises Exception.  Output is not captured.
        """
        proc = subprocess.Popen(args, cwd=self.directory,
                                stdin=subprocess.DEVNULL, **kw)
        code = proc.wait()
        if code not in (0, 1):
            raise Exception("git failed")
        return code == 0

    def __init__(self, directory=None):
        """Locate the work tree containing *directory* (default: ``.``).

        Raises Exception when git fails or reports an empty top-level path.
        """
        if directory is None:
            directory = '.'
        # cmd() needs a working directory before the top level is known.
        self.directory = directory
        self.directory = self.cmd("git rev-parse --show-toplevel".split(), quiet=True).strip()
        if self.directory == '':
            raise Exception("can't find top-level.  (bare repo?)")

    def is_clean(self) -> bool:
        """Return True when neither the index nor the work tree lists changed files."""
        staged = self.cmd("git diff-index --cached --name-only HEAD".split(), quiet=True)
        if staged != '':
            return False
        return '' == self.cmd("git diff-files --name-only".split(), quiet=True)

    def is_ancestor(self, a: str, b: str) -> bool:
        """Return True when commit *a* is an ancestor of commit *b*."""
        return self.cmd_test(['git', 'merge-base', '--is-ancestor', a, b])

    def rev_parse(self, commit: str) -> str:
        """Return the stripped output of ``git rev-parse`` for *commit*."""
        return self.cmd(['git', 'rev-parse', commit], quiet=True).strip()

    def checkout(self, branch: str) -> None:
        """Check out *branch*, discarding git's stderr."""
        self.cmd(['git', 'checkout', branch], stderr=subprocess.DEVNULL)

    def force_checkout(self, branch: str) -> None:
        """Check out *branch* with ``-f``, discarding git's stderr."""
        self.cmd(['git', 'checkout', '-f', branch], stderr=subprocess.DEVNULL)

    def detach(self) -> None:
        """Detach HEAD by checking out the object name HEAD resolves to."""
        self.cmd(['git', 'checkout', self.rev_parse('HEAD')], stderr=subprocess.DEVNULL)

    def merge(self, commit: str, message: str) -> bool:
        """Merge *commit* into HEAD using *message*; return True on success.

        When the merge exits with status 1 it is aborted and False is
        returned.  Any other failure raises from ``cmd_test``.
        """
        merged = self.cmd_test(['git', 'merge', '--no-edit', '-m', message, commit],
                               stdout=subprocess.DEVNULL)
        if not merged:
            self.cmd(['git', 'merge', '--abort'])
        return merged

    def current_branch(self) -> str:
        """Return the checked-out branch name without ``refs/heads/``.

        Raises Exception("not on a branch") when HEAD's symbolic ref is not
        a non-empty name under ``refs/heads/``.
        """
        branch = self.cmd(['git', 'symbolic-ref', 'HEAD'], quiet=True).strip()
        if branch == 'refs/heads/' or not branch.startswith('refs/heads/'):
            raise Exception("not on a branch")
        return branch[len('refs/heads/'):]

    def trees_differ(self, a, b) -> bool:
        """Return True when ``git diff --quiet a b`` reports a difference."""
        # Must go through this instance, not the module-level ``git`` object.
        return not self.cmd_test(['git', 'diff', '--quiet', a, b])

    def unmerged(self) -> Iterable[str]:
        """Yield the path from each line of ``git ls-files --unmerged``.

        A path is yielded once per output line, so it can repeat.
        """
        for line in self.cmd(['git', 'ls-files', '--unmerged']).splitlines():
            # Everything after the first tab on the line is the path.
            _, filename = line.split("\t", 1)
            yield filename


# Module-level Git instance used by the functions below.  Constructing it
# runs `git rev-parse --show-toplevel` from the current directory, so loading
# this module raises when that command fails.
git = Git()


def remerge_multi_merges() -> None:
    """Try to replace a chain of merges at HEAD with one equivalent merge.

    Starting at HEAD, first-parent links are followed while each commit has
    exactly two parents and its second parent is an ancestor of the second
    parent of the merge visited just before it.  For each merge in that
    chain other than HEAD, oldest first, the second parent of HEAD is merged
    onto that merge's first parent using HEAD's message.  The first result
    whose tree matches HEAD's tree becomes the new tip of the current
    branch.

    Raises Exception when the repository is not clean or HEAD is not on a
    branch.  On any failure, and on SIGINT, the original branch is
    force-checked-out again.
    """
    if not git.is_clean():
        raise Exception("repo not clean")

    head = Commit(git, 'HEAD')

    merges : List[Commit] = [head]
    prev = None
    while True:
        commit = merges[-1]
        if len(commit.parents) != 2:
            break
        if prev is not None:
            if not git.is_ancestor(commit.parents[1], prev.parents[1]):
                break
        prev = commit
        merges.append(Commit(git, commit.parents[0]))
    # The commit that ended the walk is not part of the chain.
    merges = merges[:-1]

    current_branch = git.current_branch()

    def sigint_handler(signum, frame):
        # Put the work tree back on the branch before exiting.
        git.force_checkout(current_branch)
        sys.exit(0)
    signal.signal(signal.SIGINT, sigint_handler)

    try:

        git.detach()

        # merges[0] is HEAD itself, so stop once only it is left.
        while len(merges) >= 2:
            last = merges[-1]

            print("trying base:", git.cmd(['git', 'log', '-n1', '--oneline', last.parents[0]]).strip())
            git.checkout(last.parents[0])
            merged = git.merge(head.parents[1], head.message)

            if not merged:
                print("merge failed")
            elif git.trees_differ('HEAD', head.sha):
                # Compared once per candidate; the result decides both the
                # report and whether the candidate is accepted.
                print("merge differed:",
                      git.cmd(['git', 'rev-parse', '--short', 'HEAD'], quiet=True).strip(),
                      git.cmd(['git', 'rev-parse', '--short', head.sha], quiet=True).strip())
            else:
                git.cmd(['git', 'update-ref', '-m', 'remerge', 'refs/heads/' + current_branch, 'HEAD'])
                git.checkout(current_branch)
                print("remerged!")
                break

            merges.pop()

        else:
            print("nothing changed")
            git.checkout(current_branch)

    except BaseException:
        # Cleanup only: restore the branch, then let the error propagate.
        git.force_checkout(current_branch)
        raise

def pick_cherries(git: Git, cherries: List[str], contiguous: bool, keep: bool) -> None:
    """Replay *cherries* on top of the commit HEAD currently points at.

    git        -- repository wrapper used to run the commands.
    cherries   -- commits to replay, oldest first; nothing happens if empty.
    contiguous -- True when the commits form one unbroken series, which
                  allows a rebase instead of a cherry-pick.
    keep       -- when True, a failed operation is left in progress instead
                  of being aborted.

    Whatever exception the failing git command raised is re-raised.
    """
    if not cherries:
        return

    onto = git.rev_parse("HEAD")

    if not contiguous:
        # Plain cherry-pick.  It can fail when commits of this series were
        # already cherry-picked into some of the merged branches.
        picks = ['git', 'cherry-pick', '--allow-empty'] + list(cherries)
        try:
            git.cmd(picks)
        except BaseException:
            if not keep:
                git.cmd(['git', 'cherry-pick', '--abort'])
            raise
        return

    # A contiguous series can be moved with rebase, which also drops
    # commits of the series that were cherry-picked down into branches.
    oldest, newest = cherries[0], cherries[-1]
    git.checkout(newest)
    try:
        git.cmd(['git', 'rebase', f"{oldest}^", '--onto', onto])
    except BaseException:
        if not keep:
            git.cmd(['git', 'rebase', '--abort'])
        raise


def remerge_integration_branch(keep: bool) -> None:
    """Rebuild the current branch from the branches listed in ``.branches``.

    Steps:
      1. Read ``.branches``; every line that is not blank and does not start
         with ``#`` is shell-split into branch names.
      2. Commit ``.branches`` if it has unstaged modifications, then require
         a clean repository.
      3. Walk first parents from HEAD collecting "cherries": commits whose
         diff against their first parent is anything other than exactly
         ``.branches``.  The walk stops at the first merge, at the oldest
         commit that touched ``.branches``, or at the tip of a listed branch.
      4. With HEAD detached, reset to the first listed branch, write and
         commit ``.branches``, and merge the remaining branches: one octopus
         merge first, then one merge per branch if that fails.  In the
         per-branch pass a conflict limited to ``.branches`` is resolved by
         keeping our side.
      5. Replay the cherries, point the original branch at the result and
         check it out.

    keep -- when True, a failure leaves the repository as it is instead of
            force-checking-out the original branch.

    Raises Exception when ``.branches`` lists no branch, when the repository
    is not clean, or when a git command fails.
    """
    with open('.branches', 'r') as f:
        branches_content = f.read()

    branches = [
        branch
        for line in branches_content.splitlines()
        if not re.match(r'^\s*(\#|$)', line)
        for branch in shlex.split(line)
    ]

    # Fail before touching the repository: the rebuild below needs at least
    # one branch to reset to.
    if not branches:
        raise Exception("no branches listed in .branches")

    if ".branches" in git.cmd("git diff-files --name-only".split(), quiet=True).splitlines():
        git.cmd(['git', 'commit', '-m', 'branches', '.branches'])

    if not git.is_clean():
        raise Exception("repo not clean")

    current_branch = git.current_branch()

    # `git log` lists newest first, so the last entry is the oldest commit
    # that touched .branches.
    branches_commits = [line.strip() for line in
        git.cmd(["git", "log", "--format=%H", "--", ".branches"], quiet=True).splitlines()]
    first_branches_commit = None
    if branches_commits:
        first_branches_commit = branches_commits[-1]

    branches_shas = {git.rev_parse(b) for b in branches}

    # find first merge, and collect a list of cherries
    commit = Commit(git, 'HEAD')
    cherries : List[str] = list()
    summary : str = "" # a string like " cccccc ".  "c" is a cherry, a space is a skip
    while True:
        if len(commit.parents) > 1:
            break # found merge
        files = git.cmd(
            ["git", "diff", "--name-only", commit.parents[0], commit.sha], quiet=True)
        if files.strip() == ".branches":
            summary += " "
        else:
            cherries.append(commit.sha)
            summary += "c"
        if commit.sha == first_branches_commit:
            break # don't look back past the first commit to touch .branches
        if commit.sha in branches_shas:
            break # don't look back past the specified branches
        commit = Commit(git, commit.parents[0])

    # Contiguous means no skipped commit sits between two cherries.
    cherries_are_contiguous = bool(re.match(r'^\s*c*\s*$', summary))
    # Collected newest first; replay needs oldest first.
    cherries.reverse()

    git.detach()

    try:
        for octopus in (True, False):
            try:
                git.cmd(['git', 'reset', '--hard', branches[0]])

                with open('.branches', 'w') as f:
                    f.write(branches_content)
                git.cmd(['git', 'add', '.branches'])
                git.cmd(['git', 'commit', '-m', '.branches'])

                if len(branches) == 1:
                    break

                if not octopus:
                    for branch in branches[1:]:
                        try:
                            git.cmd(['git', 'merge', branch])
                        except Exception:
                            # Only a conflict confined to .branches is
                            # resolved automatically, keeping our version.
                            if set(git.unmerged()) == {".branches"}:
                                git.cmd(['git', 'checkout', '--ours', '.branches'])
                                git.cmd(['git', 'add', '.branches'])
                                git.cmd(['git', '-c', 'core.editor=true', 'merge', '--continue'])
                            else:
                                raise
                else:
                    git.cmd(['git', 'merge'] + branches[1:])
            except Exception:
                # Retry per branch only after an ordinary failure of the
                # octopus attempt.  KeyboardInterrupt and SystemExit are not
                # Exception subclasses, so they are never swallowed here.
                if not octopus:
                    raise
                else:
                    continue
            else:
                break

        pick_cherries(git, cherries, cherries_are_contiguous, keep)

    except BaseException:
        # Cleanup only: restore the branch unless asked to keep the state.
        if not keep:
            git.force_checkout(current_branch)
        raise

    git.cmd(['git', 'update-ref', '-m', 'remerge', 'refs/heads/' + current_branch, 'HEAD'])
    git.cmd(['git', 'checkout', current_branch])


def main():
    """Parse the command line and run the mode that fits the current directory.

    If a ``.branches`` file exists in the current directory the integration
    branch is rebuilt (honouring ``--keep``); otherwise redundant merges at
    HEAD are coalesced.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--keep", action="store_true")
    options = cli.parse_args()

    if not os.path.exists('.branches'):
        print("attempting to coalesce redundant merges")
        remerge_multi_merges()
        return

    print("resetting integration branch")
    remerge_integration_branch(options.keep)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
