#!python
#
# Run simulations of random walks on a graph under specified conditions.
# Copyright (c) 2023, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import random
import statistics
import sys

from perlcompat import die, warn, getopts
import randwalk
import tbdump

# Per-trial step limit: simulate() counts a trial as aborted once the
# agent's step counter exceeds this value.
MAX_STEPS = 10000

def usage():
    """Report the command-line synopsis through die().

    The defaults quoted in the message mirror the fallback values that
    main() applies when an option is not given.
    """
    die(f"""\
usage: {sys.argv[0]} [-s #] [-n #] [-k #] [-t type] [-a agent] [-A #]
  -s #      seed of random number generator (default: 1)
  -n #      the number of vertices (default: 100)
  -k #      average degree (default: 2.5)
  -t type   type of graph ({randwalk.GRAPH_TYPES}) (default: random,tree)
  -a agent  type of agent ({randwalk.AGENT_TYPES}) (default: BiasedRW,NBRW,SARW)
  -A #      bias parameter alpha (default: 0)
""")

def header_str():
    """Return the tab-separated column header for the result table."""
    columns = [
        '# agent    ',
        'N',
        'M',
        'type',
        'count',
        'abort',
        'C',
        '95%',
        'H',
        '95%',
        'E[H]',
        '95%',
    ]
    return '\t'.join(columns)

def status_str(agent, g, count, naborts, covers, hittings, mean_hittings):
    """Record one trial's samples and return the running status line.

    The lists covers, hittings and mean_hittings are extended in place
    with the values taken from AGENT, so they accumulate across calls.
    The returned string is tab-separated: agent name, vertex count, edge
    count, graph name (first 7 characters), trial count, abort count, and
    then the pair returned by randwalk.mean_and_conf95() for each of the
    three sample lists.
    """
    label = agent.name()
    nvertices = g.nvertices()
    nedges = g.nedges()
    graph_name = g.get_graph_attribute('name')

    # Record this trial's samples.  The step counter at the end of the
    # trial is used as the cover-time sample.
    covers.append(agent.step)
    # NOTE: agent.hitting[v] records the hitting time at vertex v.
    hittings.append(agent.hitting[agent.target])
    mean_hittings.append(statistics.mean(agent.hitting.values()))

    # Summarize every sample list before any formatting takes place.
    summary = []
    for samples in (covers, hittings, mean_hittings):
        avg, conf = randwalk.mean_and_conf95(samples)
        summary += [avg, conf]

    fields = [
        f'{label:12}',
        f'{nvertices}',
        f'{nedges}',
        f'{graph_name[:7]}',
        f'{count}',
        f'{naborts}',
    ]
    # Each statistic is followed by one space, matching the table layout.
    fields.extend(f'{value:.0f} ' for value in summary)
    return '\t'.join(fields)

def simulate(agent_type, g, start_vertex=1, alpha=0, ntrials=100):
    """Run up to `ntrials` trials of `agent_type` on graph `g` and report.

    Every trial creates a fresh agent through randwalk.create_agent()
    (current=start_vertex, target=g.nvertices(), alpha=alpha) and advances
    it until agent.ncovered reaches g.nvertices().  A trial whose step
    counter exceeds MAX_STEPS is stopped and counted as an abort; its
    samples are still folded into the statistics by status_str().

    After each trial the status line is written to stderr terminated by a
    carriage return, so successive lines overwrite one another.  The last
    status line is printed to stdout when the experiment ends.  The
    experiment ends early once 10 trials have been aborted.

    Nothing is printed when no trial runs (ntrials <= 0).
    """
    covers = []
    hittings = []
    mean_hittings = []
    naborts = 0
    # Stays None when the loop body never executes (ntrials <= 0).
    stat = None
    for count in range(1, ntrials + 1):
        agent = randwalk.create_agent(agent_type,
                                      graph=g,
                                      current=start_vertex,
                                      target=g.nvertices(),
                                      alpha=alpha)
        # Perform an instance of simulation.
        while agent.ncovered < g.nvertices():
            agent.advance()
            if agent.step > MAX_STEPS:
                naborts += 1
                break
        stat = status_str(agent, g, count, naborts, covers, hittings,
                          mean_hittings)
        print(stat + '\r', file=sys.stderr, end='')
        # Abort the experiment if it takes too long.
        if naborts >= 10:
            break
    if stat is None:
        # No trial was run, so there is no status line to report.
        return
    # FIXME: Workaround when the stdout is redirected.
    if not sys.stdout.isatty():
        print('', file=sys.stderr)
    print(stat + ' ')

def main():
    """Parse the command line and run every graph/agent combination.

    Options that are not given fall back to: seed 1, 100 vertices, average
    degree 2.5, graph types 'random,tree', agent types
    'BiasedRW,NBRW,SARW' and alpha 0.  The graph and agent options are
    comma-separated lists.  One graph is created per graph type, and each
    agent type is simulated on it for 1000 trials starting from vertex 1.
    """
    opt = getopts('s:n:k:t:a:A:') or usage()
    seed = int(opt.s or 1)
    n_nodes = int(opt.n or 100)
    kavg = float(opt.k or 2.5)
    graph_types = opt.t or 'random,tree'
    agent_types = opt.a or 'BiasedRW,NBRW,SARW'
    alpha = float(opt.A or 0.)
    start_vertex = 1
    ntrials = 1000

    random.seed(seed)
    print(header_str())
    for graph_name in graph_types.split(','):
        g = randwalk.create_graph(graph_name, n_nodes, kavg)
        # NOTE(review): the DOT export always goes to a file named 'foo'
        # in the current directory and is overwritten for every graph
        # type -- this looks like debugging output; confirm it is wanted.
        with open('foo', 'w') as dot_file:
            dot_file.write(g.export_dot())
        for agent_name in agent_types.split(','):
            simulate(agent_name, g, start_vertex, alpha, ntrials)

# Run the simulations only when this file is executed as a script.
if __name__ == "__main__":
    main()
