#!python

import logging
import os
import sys

# Root logger configuration shared by this launcher script:
# timestamp, single-letter level name, then the message.
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
logging.basicConfig(
    format=LOG_FMT,
    datefmt="%Y/%m/%d %H:%M:%S",
    level=logging.INFO,
)
# Calling getLogger() without a name returns the root logger.
log = logging.getLogger()

# Debugging aid for data-parallel runs: when NUM_VIRTUAL_DEVICES is set, the
# first physical CPU is split into that many logical devices.
#
# Ordering matters here:
#   1. Logical devices can only be configured before the TensorFlow runtime is
#      initialized, so this runs before any distribution strategy is created.
#   2. Listing logical devices initializes the runtime, so the listing is
#      deferred until after the strategy has been constructed.
_use_virtual_devices = "NUM_VIRTUAL_DEVICES" in os.environ
if _use_virtual_devices:
    # Simulate multiple CPUs with virtual devices
    N_VIRTUAL_DEVICES = int(os.environ["NUM_VIRTUAL_DEVICES"])
    if N_VIRTUAL_DEVICES < 1:
        raise ValueError(
            "NUM_VIRTUAL_DEVICES must be a positive integer, "
            f"got {N_VIRTUAL_DEVICES}"
        )
    log.info(f"Setting {N_VIRTUAL_DEVICES} virtual devices")
    # Hide every CUDA device; set before TensorFlow is imported in this branch.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    import tensorflow as tf

    physical_devices = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        # One default-configured logical device per requested virtual device
        # (no memory_limit is set).
        [tf.config.LogicalDeviceConfiguration() for _ in range(N_VIRTUAL_DEVICES)],
    )

# Distribution strategy handed to the CLI entry point; None / "" means that no
# strategy was selected by this script.
strategy = None
strategy_desc = ""
if "TF_CONFIG" in os.environ:
    tf_config = os.environ["TF_CONFIG"]
    log.info(f"TF_CONFIG detected: {tf_config}")
    log.info("Switching to MultiWorker distributed strategy")

    import tensorflow as tf

    # NOTE(review): created before anything below touches the logical device
    # list; the multi-worker strategy is expected to be constructed before the
    # runtime is initialized - confirm against the TensorFlow version in use.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    strategy_desc = "Multi Host / Multi GPU"

if _use_virtual_devices:
    # Report the resulting logical devices (this call initializes the runtime).
    log.info("Available devices:")
    devices = tf.config.list_logical_devices()
    for i, device in enumerate(devices):
        log.info(f"{i}) {device}")

# The CLI entry point is imported only after the environment / device setup
# earlier in this script has run (presumably because it pulls in TensorFlow).
from tensorpotential.cli.gracemaker import main

if __name__ == "__main__":
    # Forward the command-line arguments (without the program name) together
    # with whichever distribution strategy was selected above.
    main(
        sys.argv[1:],
        strategy=strategy,
        strategy_desc=strategy_desc,
    )
