#!/usr/bin/env python
import argparse
import fnmatch
import random
import re

import h5py
import numpy as np
# Command-line interface: two optional switches (--groups and
# --random-edge-filtration) followed by the input HDF5 path and the output
# path. Parsing happens at import time because this file is a plain script.
parser = argparse.ArgumentParser(
    description='Generate graph data for flagser.',
    usage='python h5toflagser.py [--groups L5_TTPC1,L4_*] [--random-edge-filtration] h5file.h5 output.flag')
# --groups defaults to the sentinel string 'all', which the rest of the
# script interprets as "do not filter groups".
parser.add_argument(
    '--groups', type=str, default='all',
    help='all, or a comma separated list of groups to extract, e.g. L5_TTPC1,L5_TTPC2 (without spaces). You can use e.g. L5_* to match the whole layer 5.')
parser.add_argument('--random-edge-filtration', action='store_true',
                    help='add this argument if each edge should be assigned a random filtration value between 0 and 1.')
parser.add_argument('h5file', type=str, help='the input h5-file')
parser.add_argument('output', type=str, help='the output filename')
args = parser.parse_args()

# Open the input file read-only and locate the group that holds the
# per-population connectivity matrices.
f = h5py.File(args.h5file, "r")
connectivity_group = f.get('connectivity')
if connectivity_group is None:
    # Fall back to the misspelled key 'connectivty'. Presumably some input
    # files were written with that typo -- NOTE(review): confirm.
    connectivity_group = f.get('connectivty')
if connectivity_group is None:
    # Fail with an explicit message instead of an AttributeError on
    # None.keys() further down.
    f.close()
    raise SystemExit(
        "error: '" + args.h5file +
        "' contains neither a 'connectivity' nor a 'connectivty' group.")
# Names of the top-level entries of the connectivity group; iterated twice
# below (once for vertices, once for edges).
groups = connectivity_group.keys()

# Either the sentinel string 'all' or a list of compiled patterns, one per
# comma-separated entry of --groups.
#
# Entries are shell-style wildcards (the documented example is "L5_*"), so
# they are translated with fnmatch.translate() rather than compiled as raw
# regular expressions. The translated pattern is anchored at the end, which
# means an exact name such as "L6_BP" no longer also selects "L6_BPC" through
# prefix matching. Empty entries (e.g. from a trailing comma) are skipped;
# compiled as a raw regex an empty entry would have matched every group.
groups_to_extract = 'all' if args.groups == 'all' else [
    re.compile(fnmatch.translate(pattern.strip()))
    for pattern in args.groups.split(',')
    if pattern.strip()]

# True when the user did not restrict the extraction; the loops below use
# this flag to bypass pattern matching entirely.
all_groups = (args.groups == 'all')

# Create (or truncate) the output file and start its 'dim 0' section.
of = open(args.output, 'w')
of.write('dim 0\n')


def is_group_to_be_extracted(group):
    """Return True if ``group`` is selected by the ``--groups`` option.

    Reads the module-level ``groups_to_extract``, which is either the
    sentinel string ``'all'`` or a list of compiled patterns. A group is
    selected when the ``match`` method of at least one pattern returns a
    truthy result for its name.

    Args:
        group: Name of the group to test.

    Returns:
        bool: True if the group should be extracted, False otherwise.
    """
    # Handle the sentinel here so the function is safe to call on its own;
    # iterating over the string 'all' would yield characters, which have no
    # ``match`` method.
    if groups_to_extract == 'all':
        return True
    return any(group_regex.match(group) for group_regex in groups_to_extract)


# First pass over the selected groups: record, for every group, the running
# total of rows seen so far (its index offset in the combined numbering) and
# emit one "0 " entry per row of the group's own cMat into the 'dim 0'
# section. Presumably each row is one vertex with filtration value 0 --
# NOTE(review): confirm against the flagser input format.
population_offsets = {}
offset = 0
number_of_groups = 0
for group in groups:
    if not (all_groups or is_group_to_be_extracted(group)):
        continue
    number_of_groups += 1
    population_offsets[group] = offset
    # The group's matrix onto itself gives the population size as its
    # first dimension.
    own_matrix_path = group + '/' + group + '/cMat'
    population_size = np.shape(connectivity_group.get(own_matrix_path))[0]
    of.write("0 " * population_size)
    offset += population_size

# Second pass: for every selected (group, group2) pair, write one line per
# nonzero entry (i, j) of the pair's cMat in the form
# "<offset(group) + i> <offset(group2) + j>[ <random filtration>]".
current = 0
edge_nr = 0
try:
    of.write('\ndim 1\n')
    for group in groups:
        if not all_groups and not is_group_to_be_extracted(group):
            continue
        current += 1
        print('Writing group ' + group + ', which is number ' + str(current) +
              ' of ' + str(number_of_groups) + '.')
        source_offset = population_offsets[group]
        for group2 in connectivity_group.get(group):
            if not all_groups and not is_group_to_be_extracted(group2):
                continue
            # Read the dataset into memory once; both the edge count and the
            # edge list are derived from this single copy.
            matrix = np.asarray(
                connectivity_group.get(group + '/' + group2 + '/cMat'))
            rows, cols = np.nonzero(matrix)
            # Count the entries that actually produce an output line. Summing
            # the matrix would over-count if an entry is larger than 1 and
            # could yield a non-integer total for float data.
            edge_nr += len(rows)
            # Shift the per-matrix indices into the combined numbering and
            # convert to plain Python ints before formatting.
            sources = (rows + source_offset).tolist()
            targets = (cols + population_offsets[group2]).tolist()
            for source, target in zip(sources, targets):
                line = str(source) + ' ' + str(target)
                if args.random_edge_filtration:
                    line += ' ' + "{0:.4f}".format(random.uniform(0, 1))
                of.write(line + '\n')
finally:
    # Flush and release both files even if writing fails part-way.
    of.close()
    f.close()

print('Successfully wrote a graph with ' + str(edge_nr) + ' edges.')
