#!python
import argparse
import glob
import json
import logging
import os
import shutil
import sys

import tqdm
from ase.data import chemical_symbols

# Batch-splitting strategies understood by get_batch_list() and the --strategy
# option: a fixed number of structures per batch, or at least `batch_size`
# atoms per batch.
EQUI_STRUCTURE_STRATEGY = "structures"
EQUI_ATOMS_STRATEGY = "atoms"

# Element symbols selected by the `MP` and `Alexandria` presets of the
# --elements option (see main()); kept in alphabetical order.
ALEXANDRIA_ELEMENTS = [
    "Ac",
    "Ag",
    "Al",
    "Ar",
    "As",
    "Au",
    "B",
    "Ba",
    "Be",
    "Bi",
    "Br",
    "C",
    "Ca",
    "Cd",
    "Ce",
    "Cl",
    "Co",
    "Cr",
    "Cs",
    "Cu",
    "Dy",
    "Er",
    "Eu",
    "F",
    "Fe",
    "Ga",
    "Gd",
    "Ge",
    "H",
    "He",
    "Hf",
    "Hg",
    "Ho",
    "I",
    "In",
    "Ir",
    "K",
    "Kr",
    "La",
    "Li",
    "Lu",
    "Mg",
    "Mn",
    "Mo",
    "N",
    "Na",
    "Nb",
    "Nd",
    "Ne",
    "Ni",
    "Np",
    "O",
    "Os",
    "P",
    "Pa",
    "Pb",
    "Pd",
    "Pm",
    "Pr",
    "Pt",
    "Pu",
    "Rb",
    "Re",
    "Rh",
    "Ru",
    "S",
    "Sb",
    "Sc",
    "Se",
    "Si",
    "Sm",
    "Sn",
    "Sr",
    "Ta",
    "Tb",
    "Tc",
    "Te",
    "Th",
    "Ti",
    "Tl",
    "Tm",
    "U",
    "V",
    "W",
    "Xe",
    "Y",
    "Yb",
    "Zn",
    "Zr",
]

# Both variables are set before TensorFlow is imported further down in this file.
# NOTE(review): presumably TF_CPP_MIN_LOG_LEVEL=3 silences TensorFlow's native
# log output and CUDA_VISIBLE_DEVICES=-1 hides all GPUs so preprocessing runs
# on CPU only - confirm against the TensorFlow/CUDA documentation.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# Statistic names in the column order of one padding bound; get_bound_for_batch()
# uses this order to line a batch's statistics up with rows of `padding_bounds`.
orders = ["n_neighbours", "n_atoms", "n_structures"]

import pandas as pd

from tensorpotential.data.databuilder import (
    GeometricalDataBuilder,
    ReferenceEnergyForcesStressesDataBuilder,
    split_batches_into_buckets,
)
from tensorpotential import constants as tc

import tensorflow as tf
from tensorpotential.data.process_df import (
    ASE_ATOMS,
    ENERGY_CORRECTED_COL,
    FORCES_COL,
    STRESS_COL,
)

import numpy as np


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python values.

    Used as ``json.dump(obj, f, cls=NpEncoder)`` so that statistics computed
    with NumPy can be written without manual conversion.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of *obj*.

        NumPy booleans, integers, floats and arrays are converted to ``bool``,
        ``int``, ``float`` and (nested) ``list`` respectively.  Anything else is
        delegated to the base class, which raises ``TypeError``.
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own branch;
        # without it a NumPy boolean makes json.dump fail with TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


# Lookup table from NumPy types to the matching TensorFlow dtypes.
# Both the scalar classes (np.int32, ...) and the dtype objects
# (np.dtype("int32"), ...) are registered: get_batch_dtypes() looks up whatever
# the databuilders report, while infer_tensorspec() looks up `array.dtype`.
remap_type_to_tf_dtype = {
    np.int32: tf.int32,
    np.float64: tf.float64,
    np.float32: tf.float32,
    np.dtype("int32"): tf.int32,
    np.dtype("float64"): tf.float64,
    np.dtype("float32"): tf.float32,
}


def get_elements(df_iterrows):
    """Collect the chemical symbols that occur in the given rows.

    Args:
        df_iterrows: iterable of ``(index, row)`` pairs; ``row[ASE_ATOMS]`` must
            provide ``get_chemical_symbols()``.

    Returns:
        Sorted list of the distinct symbols found.
    """
    unique_symbols = {
        symbol
        for _, row in df_iterrows
        for symbol in row[ASE_ATOMS].get_chemical_symbols()
    }
    return sorted(unique_symbols)


def get_batch_dtypes(databuilders_list):
    """Merge the batch dtypes of all databuilders, translated to TensorFlow dtypes.

    Args:
        databuilders_list: objects exposing ``get_batch_dtypes()`` that returns
            a mapping ``key -> NumPy type``.

    Returns:
        dict ``key -> tf dtype``; for duplicate keys the later builder wins.

    Raises:
        KeyError: if a reported type is missing from ``remap_type_to_tf_dtype``.
    """
    merged = {}
    for builder in databuilders_list:
        for key, np_type in builder.get_batch_dtypes().items():
            merged[key] = remap_type_to_tf_dtype[np_type]
    return merged


def sizeof_fmt(file_name_or_size, suffix="B"):
    """Format a byte count with binary prefixes (Ki, Mi, ...).

    Args:
        file_name_or_size: a number of bytes, or a file path (``str``) whose
            size on disk is looked up with ``os.path.getsize``.
        suffix: unit appended after the prefix.

    Returns:
        String with one decimal place, e.g. ``"1.5KiB"``.
    """
    size = file_name_or_size
    if isinstance(size, str):
        size = os.path.getsize(size)

    binary_prefixes = ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi")
    for prefix in binary_prefixes:
        if abs(size) >= 1024.0:
            # too large for this prefix, move on to the next one
            size /= 1024.0
            continue
        return "%3.1f%s%s" % (size, prefix, suffix)
    # anything beyond Zi is reported in Yi
    return "%.1f%s%s" % (size, "Yi", suffix)


def databuilder_generator(dataframes_rows, databuilders_list):
    """Yield one merged feature dict per row.

    Args:
        dataframes_rows: iterable of ``(row_id, row)`` pairs.
        databuilders_list: objects exposing
            ``extract_from_row(row, structure_id=...)`` that returns a dict.

    Yields:
        dict combining the outputs of all builders for the row; for duplicate
        keys the later builder wins.  ``row_id`` is passed to every builder as
        an ``np.int32`` array.
    """
    for row_id, row in dataframes_rows:
        sample = {}
        for builder in databuilders_list:
            structure_id = np.array(row_id, dtype=np.int32)
            extracted = builder.extract_from_row(row, structure_id=structure_id)
            sample.update(extracted)
        yield sample


def infer_tensorspec(dataframes_rows, databuilders_list):
    """Derive a ``tf.TensorSpec`` per key from the first generated sample.

    Only the first row of ``dataframes_rows`` is consumed.  Entries that are not
    NumPy arrays are skipped; for arrays with at least one dimension the leading
    dimension is left unspecified (``None``).

    Returns:
        dict ``key -> tf.TensorSpec``.

    Raises:
        StopIteration: if ``dataframes_rows`` yields no rows.
    """
    first_sample = next(databuilder_generator(dataframes_rows, databuilders_list))

    spec_by_key = {}
    for name, value in first_sample.items():
        if not isinstance(value, np.ndarray):
            continue
        dims = list(value.shape)
        if dims:
            dims[0] = None
        spec_by_key[name] = tf.TensorSpec(
            shape=dims, dtype=remap_type_to_tf_dtype[value.dtype], name=name
        )
    return spec_by_key


def get_window_stat(batch_list):
    """Sum the real structure/atom/neighbour counts over a list of samples.

    Args:
        batch_list: samples whose count entries expose ``.numpy()``.

    Returns:
        dict with keys ``n_structures``, ``n_atoms`` and ``n_neighbours``
        (all zero for an empty list).
    """
    sample_key_by_stat = (
        ("n_structures", tc.N_STRUCTURES_BATCH_REAL),
        ("n_atoms", tc.N_ATOMS_BATCH_REAL),
        ("n_neighbours", tc.N_NEIGHBORS_REAL),
    )
    totals = {stat_name: 0 for stat_name, _ in sample_key_by_stat}

    for sample in batch_list:
        for stat_name, sample_key in sample_key_by_stat:
            totals[stat_name] += sample[sample_key].numpy()

    return totals


def compute_batches_df(dataset, batch_size, strategy=EQUI_STRUCTURE_STRATEGY):
    """Group the dataset into batches and tabulate the size of every batch.

    Args:
        dataset: iterable of samples (wrapped in a tqdm progress bar).
        batch_size: batch size passed on to ``get_batch_list``.
        strategy: batch-splitting strategy passed on to ``get_batch_list``.

    Returns:
        DataFrame with one row per batch and the columns ``n_atoms``,
        ``n_neighbours`` and ``n_structures``, sorted in descending order by
        neighbours, then atoms, then structures, with a fresh 0..N-1 index.
    """
    sample_stream = iter(tqdm.tqdm(dataset, mininterval=2))

    per_batch_stats = []
    while True:
        try:
            window = get_batch_list(sample_stream, batch_size, strategy=strategy)
        except StopIteration:
            # stream exhausted, no further batches
            break
        per_batch_stats.append(get_window_stat(window))

    stats_frame = pd.DataFrame(per_batch_stats)
    # every batch gets its own id, so the aggregation below keeps one row per batch
    stats_frame["bid"] = np.arange(len(stats_frame))
    aggregated = stats_frame.groupby("bid").agg(
        {"n_atoms": "sum", "n_neighbours": "sum", "n_structures": "sum"}
    )

    sort_columns = ["n_neighbours", "n_atoms", "n_structures"]
    return aggregated.sort_values(sort_columns, ascending=False).reset_index(
        drop=True
    )


def get_batch_list(iterator, batch_size, strategy=EQUI_STRUCTURE_STRATEGY, **kwargs):
    """Pull the next list of samples from ``iterator`` using ``strategy``.

    Args:
        iterator: sample iterator; it is advanced by this call.
        batch_size: number of structures (``structures`` strategy) or minimum
            number of atoms (``atoms`` strategy) per batch.
        strategy: ``EQUI_STRUCTURE_STRATEGY`` or ``EQUI_ATOMS_STRATEGY``.
        **kwargs: forwarded to the strategy implementation.

    Raises:
        NotImplementedError: for any other strategy.
        StopIteration: propagated from the strategy when the iterator is
            already exhausted.
    """
    if strategy == EQUI_STRUCTURE_STRATEGY:
        return get_equi_structures_batch_list(iterator, batch_size, **kwargs)
    if strategy == EQUI_ATOMS_STRATEGY:
        return get_equi_batch_list(iterator, batch_size, key="n_atoms", **kwargs)
    # TODO: balance batch size by number of atoms/neighbours
    raise NotImplementedError(f"Strategy {strategy} not implemented")


def get_equi_structures_batch_list(iterator, batch_size, **kwargs):
    """Take up to ``batch_size`` samples from ``iterator``.

    Returns:
        List with exactly ``batch_size`` samples, or fewer when the iterator
        runs out part-way through (the remainder batch).

    Raises:
        StopIteration: if the iterator is exhausted before the first sample
            could be taken.
    """
    samples = []
    for _ in range(batch_size):
        try:
            samples.append(next(iterator))
        except StopIteration:
            if samples:
                # partial batch: hand back what was collected
                break
            raise
    return samples


def get_equi_batch_list(iterator, batch_size, key="n_atoms", **kwargs):
    """Take samples from ``iterator`` until they hold at least ``batch_size`` atoms.

    The running total is built from each sample's ``"batch_tot_nat_real"``
    entry.  The ``key`` argument is accepted but not read here: the atom count
    is always the quantity being balanced.

    Returns:
        List of samples whose atom total reaches ``batch_size``, or a shorter
        remainder list when the iterator runs out first.

    Raises:
        StopIteration: if the iterator is exhausted before the first sample
            could be taken.
    """
    picked = []
    atoms_so_far = 0
    while atoms_so_far < batch_size:
        try:
            sample = next(iterator)
        except StopIteration:
            if picked:
                # partial batch: hand back what was collected
                break
            raise
        picked.append(sample)
        atoms_so_far += sample["batch_tot_nat_real"]
    return picked


# @tf.function(jit_compile=False)
def get_batch_stat_tf(batch):
    """Extract the size statistics of a joined batch.

    Args:
        batch: mapping with the entries ``"batch_total_num_structures"``,
            ``"batch_tot_nat_real"`` and ``"n_neigh"``.

    Returns:
        dict with keys ``n_structures``, ``n_atoms`` and ``n_neighbours``.
    """
    n_structures = batch["batch_total_num_structures"]
    n_atoms = batch["batch_tot_nat_real"]
    n_neighbours = batch["n_neigh"]
    return dict(
        n_structures=n_structures,
        n_atoms=n_atoms,
        n_neighbours=n_neighbours,
    )


# @tf.function  # (jit_compile=False)
def get_bound_for_batch(batch, padding_bounds):
    """Pick the first padding bound that is large enough for ``batch``.

    Args:
        batch: joined batch, as accepted by ``get_batch_stat_tf``.
        padding_bounds: 2-D array whose rows hold the bounds in the column
            order given by the module-level ``orders`` list.

    Returns:
        The first row of ``padding_bounds`` that is >= the batch statistics in
        every column.

    Raises:
        IndexError: if no row is large enough.
    """
    batch_stats = get_batch_stat_tf(batch)
    required = [batch_stats[name] for name in orders]
    fits = np.all(padding_bounds >= required, axis=1)
    first_fit = np.nonzero(fits)[0][0]
    return padding_bounds[first_fit]


def build_parser():
    """Create the command-line parser of the ``grace_preprocess`` tool.

    Returns:
        argparse.ArgumentParser with the input/output options, the batching and
        databuilder options, the task-sharding options and one flag per stage.
    """
    parser = argparse.ArgumentParser(
        prog="grace_preprocess",
        description="Precompute dataset and save into TF.Dataset format",
    )
    add = parser.add_argument

    # input / output
    add("input", help="input pkl.gz file", type=str, nargs="+")
    add("-o", "--output", help="output file name", type=str, default="tf_dataset")
    add(
        "--sharded-input",
        action="store_true",
        default=False,
        help="Flag to show that input files are sharded",
    )

    # dataset content and batching
    add(
        "-e",
        "--elements",
        type=str,
        default=None,
        help="List of elements. Possible presets: `ALL` (except last 23 elements), `Alexandria` or `MP`",
    )
    add("-b", "--batch_size", type=int, default=8)
    add("-bu", "--max-n-buckets", type=int, default=5)
    add("-c", "--cutoff", type=float, default=5)
    add("--compression", type=str, default="GZIP")

    # reference data columns
    column_options = (
        ("--energy-col", ENERGY_CORRECTED_COL),
        ("--forces-col", FORCES_COL),
        ("--stress-col", STRESS_COL),
    )
    for option, default_column in column_options:
        add(option, type=str, default=default_column)
    add("--is-fit-stress", action="store_true", default=False)

    add(
        "-s",
        "--strategy",
        type=str,
        default=EQUI_STRUCTURE_STRATEGY,
        help=f"Strategy to batch splitting. Possible values: {EQUI_STRUCTURE_STRATEGY} (default), {EQUI_ATOMS_STRATEGY}",
    )

    # task sharding
    add("--task-id", type=int, default=0, help="ZERO based ID of task")
    add("--total-task-num", type=int, default=1, help="Total number of tasks")
    add("--rerun", action="store_true", default=False, help="Enforce to rerun process")

    # stage selection
    stage_helps = (
        (1, "Run stage 1, precompute samples (non-batched)"),
        (2, "Run stage 2, compute padding bounds"),
        (3, "Run stage 3, padding batches"),
        (4, "Run stage 4, compute statistics"),
    )
    for stage_no, stage_help in stage_helps:
        add(
            f"--stage-{stage_no}",
            action="store_true",
            default=False,
            help=stage_help,
        )

    return parser


def dataframe_iterrows_generator(
    file_paths,
    shard_by_file=False,
    shard_by_row=False,
    task_index=0,
    total_num_tasks=1,
    verbose=False,
):
    """Yield ``(index, row)`` pairs from gzip-compressed pickled dataframes.

    Args:
        file_paths: list of pickle files to read with ``pd.read_pickle``.
        shard_by_file: keep only every ``total_num_tasks``-th file, starting at
            ``task_index``.
        shard_by_row: keep only every ``total_num_tasks``-th row of each file,
            starting at ``task_index``.
        task_index: zero-based index of this task.
        total_num_tasks: total number of tasks sharing the work.
        verbose: show tqdm progress bars and log every loaded file.

    Yields:
        The ``(index, row)`` pairs of ``DataFrame.iterrows()`` for every file.

    Any exception raised while reading or iterating a file is logged and the
    generator continues with the next file.
    """
    selected_paths = file_paths[task_index::total_num_tasks] if shard_by_file else file_paths
    paths = tqdm.tqdm(selected_paths, mininterval=2) if verbose else selected_paths

    for file_path in paths:
        try:
            frame = pd.read_pickle(file_path, compression="gzip")
            if verbose:
                logging.info(f"Loaded {file_path}, total shape: {frame.shape}")
            if shard_by_row:
                frame = frame.iloc[task_index::total_num_tasks]

            rows = frame.iterrows()
            if verbose:
                rows = tqdm.tqdm(rows, mininterval=2, total=len(frame))
            yield from rows
        except Exception as e:
            logging.error(f"Error reading file: {file_path}. Error: {e}")


def join_and_pad_generator(
    dataset,
    databuilders_list,
    padding_bounds,
    batch_size,
    strategy=EQUI_STRUCTURE_STRATEGY,
    verbose=False,
    **kwargs,
):
    """Yield joined and padded batches built from the samples of ``dataset``.

    For every list of samples returned by ``get_batch_list`` the databuilders
    join the samples into one batch dict, the smallest fitting padding bound is
    looked up, and every databuilder pads the batch to that bound.

    Args:
        dataset: iterable of samples.
        databuilders_list: objects exposing ``join_to_batch(batch_list)`` and
            ``pad_batch(batch, max_pad_dict)``.
        padding_bounds: array of (neighbours, atoms, structures) bounds.
        batch_size: passed on to ``get_batch_list``.
        strategy: passed on to ``get_batch_list``.
        verbose: wrap ``dataset`` in a tqdm progress bar.
        **kwargs: passed on to ``get_batch_list``.

    Yields:
        The batch dict after ``pad_batch`` has been applied by every builder.
    """
    source = tqdm.tqdm(dataset, mininterval=2) if verbose else dataset
    sample_stream = iter(source)

    while True:
        try:
            samples = get_batch_list(
                sample_stream, batch_size, strategy=strategy, **kwargs
            )
        except StopIteration:
            # no samples left
            return

        # Join the samples into a single batch.
        #  Note! the joined dict may have more keys than a single sample
        joined = {}
        for builder in databuilders_list:
            joined.update(builder.join_to_batch(samples))
        joined["n_neigh"] = len(joined["bond_vector"])

        limits = get_bound_for_batch(joined, padding_bounds)
        pad_limits = {
            tc.PAD_MAX_N_NEIGHBORS: limits[0],
            tc.PAD_MAX_N_ATOMS: limits[1],
            tc.PAD_MAX_N_STRUCTURES: limits[2],
        }

        for builder in databuilders_list:
            builder.pad_batch(joined, pad_limits)

        yield joined


def infer_batch_signature(batch, batch_dtypes):
    """Build a ``tf.TensorSpec`` for every entry of a batch.

    Args:
        batch: mapping ``key -> value``.
        batch_dtypes: mapping ``key -> tf dtype``; must cover every batch key.

    Returns:
        dict ``key -> tf.TensorSpec``.  NumPy arrays keep their shape with the
        leading dimension (if any) left unspecified; every other value gets a
        spec with ``shape=None``.
    """
    signature = {}
    for name, value in batch.items():
        shape = None
        if isinstance(value, np.ndarray):
            shape = list(value.shape)
            if shape:
                shape[0] = None
        signature[name] = tf.TensorSpec(shape=shape, dtype=batch_dtypes[name], name=name)
    return signature


def save_batches_df(batches_df, batches_df_fname):
    """Pickle ``batches_df`` to ``batches_df_fname``, creating its folder if needed.

    Args:
        batches_df: DataFrame to store (written with ``DataFrame.to_pickle``).
        batches_df_fname: target file name; missing parent directories are
            created first.
    """
    target_dir = os.path.dirname(batches_df_fname)
    # A bare file name has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    batches_df.to_pickle(batches_df_fname)

def load_and_reduce_stats(output_path, total_num_tasks, max_n_buckets):
    """Combine the stage-2 statistics of all tasks into padding bounds.

    Args:
        output_path: root output folder; the statistics are read from
            ``<output_path>/stage2/batches_df-<i>-of-<N>.pkl.gz``.
        total_num_tasks: number of tasks ``N`` whose files are expected.
        max_n_buckets: passed to ``split_batches_into_buckets``.

    Returns:
        Sorted list of ``[max_neighbours, max_atoms, max_structures]`` lists
        (int32 values converted to Python ints), one per bucket.

    Raises:
        RuntimeError: if the statistics file of any task is missing.
    """
    stage2_dir = os.path.join(output_path, "stage2")

    per_task_frames = []
    for task_no in range(1, total_num_tasks + 1):
        stats_path = os.path.join(
            stage2_dir, f"batches_df-{task_no}-of-{total_num_tasks}.pkl.gz"
        )
        if not os.path.isfile(stats_path):
            raise RuntimeError(f"Stage 2 is not completed, no such file {stats_path}")
        per_task_frames.append(pd.read_pickle(stats_path))
    batches_df = pd.concat(per_task_frames, axis=0)

    buckets_list = split_batches_into_buckets(batches_df, max_n_buckets)

    # padding overhead: distance of every batch to the maximum of its bucket
    split_loss = sum((bucket.max(axis=0) - bucket).sum() for bucket in buckets_list)
    total_samples = batches_df[["n_atoms", "n_neighbours", "n_structures"]].sum()
    relative_split_loss = split_loss / total_samples
    logging.info(f"Relative split loss: {relative_split_loss.to_dict()}")

    padding_bounds = []
    for bucket in buckets_list:
        max_nneigh = bucket["n_neighbours"].max()
        max_nat = bucket["n_atoms"].max()
        max_nstruct = bucket["n_structures"].max()

        # pad at least one atom and structure if it has to pad nneigh
        needs_neigh_padding = np.any(bucket["n_neighbours"] != max_nneigh)
        needs_atom_padding = np.any(bucket["n_atoms"] != max_nat) or needs_neigh_padding
        if needs_atom_padding:
            # structures are padded whenever atoms are padded
            max_nat += 1
            max_nstruct += 1

        padding_bounds.append((max_nneigh, max_nat, max_nstruct))

    padding_bounds.sort()  # neigh, atoms, structures
    return np.array(padding_bounds).astype(np.int32).tolist()


def load_params(output_path):
    """Read the elements map and cutoff stored in ``<output_path>/params.json``.

    Returns:
        Tuple ``(elements_map, cutoff)``, or ``None`` when the file does not
        exist or cannot be parsed (the parse error is logged).
    """
    params_path = os.path.join(output_path, "params.json")
    if not os.path.isfile(params_path):
        return None

    try:
        with open(params_path, "r") as f:
            stored = json.load(f)
        return stored["elements_map"], stored["cutoff"]
    except Exception as e:
        logging.error(f"Can't parse {params_path}: {e}")
        return None


def save_params(output_path, elements_map, cutoff):
    """Write the elements map and cutoff to ``<output_path>/params.json``.

    ``output_path`` is created if it does not exist yet; an existing
    ``params.json`` is overwritten.
    """
    os.makedirs(output_path, exist_ok=True)
    params_path = os.path.join(output_path, "params.json")
    payload = {"elements_map": elements_map, "cutoff": cutoff}
    with open(params_path, "w") as f:
        json.dump(payload, f)


def get_databuilders(
    elements_map,
    cutoff,
    args_parse,
    cutoff_dict=None,
):
    """Create the databuilders used to turn dataframe rows into samples.

    Args:
        elements_map: mapping passed to ``GeometricalDataBuilder``.
        cutoff: cutoff passed to ``GeometricalDataBuilder``.
        args_parse: parsed command-line arguments; ``is_fit_stress``,
            ``energy_col``, ``forces_col`` and ``stress_col`` are read.
        cutoff_dict: optional cutoff dictionary passed to
            ``GeometricalDataBuilder``.

    Returns:
        ``[GeometricalDataBuilder, ReferenceEnergyForcesStressesDataBuilder]``.
    """
    geometry_builder = GeometricalDataBuilder(
        elements_map=elements_map,
        cutoff=cutoff,
        cutoff_dict=cutoff_dict,  # TODO: provide parameter
        is_fit_stress=args_parse.is_fit_stress,
    )
    reference_builder = ReferenceEnergyForcesStressesDataBuilder(
        is_fit_stress=args_parse.is_fit_stress,
        energy_col=args_parse.energy_col,
        forces_col=args_parse.forces_col,
        stress_col=args_parse.stress_col,
        # stress_units=stress_units,
    )
    return [geometry_builder, reference_builder]


def main(args):
    """Run exactly one stage of the dataset preprocessing pipeline.

    Args:
        args: list of command-line arguments (without the program name), see
            ``build_parser``.

    Stages (exactly one of ``--stage-1`` .. ``--stage-4`` must be given):

    1. Read the input dataframes, build per-structure samples with the
       databuilders and save them to ``<output>/stage1/shard_<i>-of-<N>``.
       Also writes ``<output>/params.json`` (elements map and cutoff).
    2. Group the stage-1 samples into batches and save the per-batch sizes to
       ``<output>/stage2/batches_df-<i>-of-<N>.pkl.gz``.
    3. Reduce the stage-2 statistics of all tasks into padding bounds, then
       join and pad the batches and save them to
       ``<output>/stage3/shard_<i>-of-<N>``.
    4. (task 0 only) Aggregate statistics over all stage-3 shards and write
       them to ``<output>/stage3/stats.json``.

    The process exits with status 1 when the stage flags are invalid, when a
    stage needs ``params.json`` but it is missing/unreadable, or when stage 4
    finds no data to aggregate.
    """
    # TODO: apply reference energy (esa_dict)
    # TODO: is_fit_stress, stress_units?
    # TODO: if  shift -> compute esa_dict from lstsq (optional), scale (necessary!)
    # TODO: Extract elements and element_map
    # TODO: apply weighting
    # TODO: user_cutoff_dict = args_yaml.get(tc.INPUT_CUTOFF_DICT)

    parser = build_parser()
    args_parse = parser.parse_args(args)
    input_fnames = args_parse.input
    sharded_input = args_parse.sharded_input
    elements = args_parse.elements
    cutoff = args_parse.cutoff
    batch_size = args_parse.batch_size
    strategy = args_parse.strategy
    max_n_buckets = args_parse.max_n_buckets
    compression = args_parse.compression
    rerun = args_parse.rerun

    total_num_tasks = args_parse.total_task_num
    task_index = args_parse.task_id

    stage_1 = args_parse.stage_1
    stage_2 = args_parse.stage_2
    stage_3 = args_parse.stage_3
    stage_4 = args_parse.stage_4

    LOG_FMT = f"%(asctime)s %(levelname).1s [Worker {task_index+1}/{total_num_tasks}]- %(message)s"
    logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%Y/%m/%d %H:%M:%S")

    stages_flags = [stage_1, stage_2, stage_3, stage_4]
    stage_options = "--stage-1, --stage-2, --stage-3 or --stage-4"

    if sum(stages_flags) == 0:
        logging.error(f"Error: one of the {stage_options} options must be provided")
        sys.exit(1)
    elif sum(stages_flags) > 1:
        logging.error(
            f"Error: only one of the {stage_options} options may be provided"
        )
        sys.exit(1)

    logging.info(f"Input filename: {input_fnames}")
    output_path = args_parse.output

    logging.info(f"Output path: {output_path}")
    logging.info(f"Cutoff: {cutoff}")
    logging.info(f"Batch size: {batch_size}")
    logging.info(f"Buckets: {max_n_buckets}")
    logging.info(f"Compression: {compression}")
    logging.info(f"Elements: {elements}")
    stage1_path = os.path.join(
        output_path, "stage1", f"shard_{task_index + 1}-of-{total_num_tasks}"
    )
    stage1_path_temp = stage1_path + ".tmp"
    batches_df_fname = os.path.join(
        output_path,
        "stage2",
        f"batches_df-{task_index + 1}-of-{total_num_tasks}.pkl.gz",
    )
    stage3_path = os.path.join(
        output_path, "stage3", f"shard_{task_index + 1}-of-{total_num_tasks}"
    )
    stage3_path_temp = stage3_path + ".tmp"
    params_fname = os.path.join(output_path, "params.json")

    # params.json is written by stage 1; when it exists, its elements map and
    # cutoff take precedence over the command-line cutoff.
    params = load_params(output_path)
    elements_map = None
    if params is not None:
        logging.info(f"Params loaded: {params}")
        elements_map, cutoff = params

        # TODO: provide parameters
        databuilders_list = get_databuilders(elements_map, cutoff, args_parse)
        batch_dtypes = get_batch_dtypes(databuilders_list)
        batch_dtypes["n_neigh"] = tf.int32

    # ------------------------------------------------------------------
    # Stage 1: initial (non-batched) dataset
    # ------------------------------------------------------------------
    if stage_1:
        if not os.path.isdir(stage1_path) or rerun:

            if sharded_input:
                # each task gets its own subset of the input files
                input_fnames = input_fnames[task_index::total_num_tasks]
                logging.info(
                    f"Sharded {len(input_fnames)} input file(s): {input_fnames}"
                )
            else:
                logging.info(f"Total {len(input_fnames)} input file(s): {input_fnames}")

            def df_rows(verbose):
                # Fresh row iterator over this task's share of the input; rows
                # are sharded here only when the files were not sharded above.
                return dataframe_iterrows_generator(
                    input_fnames,
                    shard_by_row=not sharded_input,
                    task_index=task_index,
                    total_num_tasks=total_num_tasks,
                    verbose=verbose,
                )

            if elements is None:
                if elements_map is None:
                    logging.warning(
                        "Elements are not provided, so they will be extracted from ALL dataframes. "
                        + "It may lead to unnecessary slowdown. Better use --elements option"
                    )
                    logging.info("Extracting elements")
                    elements = get_elements(df_iterrows=df_rows(True))
                    logging.info(f"Extracted elements: {elements}")
                else:
                    elements = sorted(elements_map.keys())
                    logging.info(f"Loaded elements: {elements}")
            else:
                if elements == "ALL":
                    elements = chemical_symbols[1:-23]
                elif elements in ["MP", "Alexandria"]:
                    elements = ALEXANDRIA_ELEMENTS
                else:
                    elements = sorted(elements.split())
                logging.info(f"Provided elements: {elements}")

            elements_map = {e: i for i, e in enumerate(elements)}
            logging.info(f"Elements map: {elements_map}")
            save_params(output_path, elements_map, cutoff)

            databuilders_list = get_databuilders(elements_map, cutoff, args_parse)

            logging.info("Stage 1: Computing neigh-list dataset")

            logging.info(f"Cleaning temp folder {stage1_path_temp}")
            shutil.rmtree(stage1_path_temp, ignore_errors=True)

            # the signature is inferred from the first sample only
            output_signature = infer_tensorspec(df_rows(False), databuilders_list)

            dataset = tf.data.Dataset.from_generator(
                lambda: databuilder_generator(df_rows(True), databuilders_list),
                output_signature=output_signature,
            )

            # save into a temp folder first, so that an interrupted run never
            # leaves a half-written folder under the final name
            logging.info(f"Processing and saving to {stage1_path_temp}")
            dataset.save(
                stage1_path_temp,
                compression=compression,
            )
            logging.info(f"Renaming {stage1_path_temp} to {stage1_path}")
            if os.path.isdir(stage1_path):
                shutil.rmtree(stage1_path)
            shutil.move(stage1_path_temp, stage1_path)
            logging.info("Stage 1 complete")

        else:
            logging.info(f"Stage 1 is already finished, folder {stage1_path} exists")

    # ------------------------------------------------------------------
    # Stage 2: per-batch statistics
    # ------------------------------------------------------------------
    elif stage_2:
        if os.path.isfile(batches_df_fname) and not rerun:
            logging.info(f"Stage 2: {batches_df_fname} already exists, stopping")
        else:
            logging.info("Stage 2: Computing batches statistics")
            logging.info(f"Loading {stage1_path}, batch size={batch_size}")
            dataset = tf.data.Dataset.load(stage1_path, compression=compression)
            batches_df = compute_batches_df(
                dataset, batch_size=batch_size, strategy=strategy
            )
            save_batches_df(batches_df, batches_df_fname)
            logging.info(f"Batches statistics saved to {batches_df_fname}")

    # ------------------------------------------------------------------
    # Stage 3: join and pad batches
    # ------------------------------------------------------------------
    elif stage_3:
        if os.path.isdir(stage3_path) and not rerun:
            logging.info(f"Folder {stage3_path} already exists, skipping")
        else:
            if params is None:
                # databuilders_list / batch_dtypes are only defined when
                # params.json could be loaded; fail with a clear message
                # instead of a NameError further down
                logging.error(
                    f"Stage 3 requires {params_fname}, which is written by stage 1"
                )
                sys.exit(1)

            logging.info(f"Reducing padding bounds in {output_path}")
            padding_bounds = load_and_reduce_stats(
                output_path, total_num_tasks, max_n_buckets
            )
            os.makedirs(os.path.join(output_path, "stage3"), exist_ok=True)
            with open(
                os.path.join(output_path, "stage3", "reduced_padding_bounds.json"), "w"
            ) as f:
                json.dump(padding_bounds, f)

            logging.info(f"Loaded padding_bounds={padding_bounds}")
            padding_bounds = np.array(padding_bounds)

            logging.info(f"Stage 3: Padding dataset and saving to {stage3_path}")
            dataset = tf.data.Dataset.load(stage1_path, compression=compression)

            logging.info(f"Cleaning temp folder {stage3_path_temp}")
            shutil.rmtree(stage3_path_temp, ignore_errors=True)

            logging.info(f"Processing and saving to temp folder {stage3_path_temp}")

            # build one batch up-front only to infer the output signature
            batch = next(
                iter(
                    join_and_pad_generator(
                        dataset,
                        databuilders_list,
                        padding_bounds,
                        batch_size,
                        strategy=strategy,
                    )
                )
            )
            batch_signature = infer_batch_signature(batch, batch_dtypes)

            padded_dataset = tf.data.Dataset.from_generator(
                lambda: join_and_pad_generator(
                    dataset,
                    databuilders_list,
                    padding_bounds,
                    batch_size,
                    verbose=True,
                    strategy=strategy,
                ),
                output_signature=batch_signature,
            )

            padded_dataset.save(
                stage3_path_temp,
                compression=compression,
            )
            logging.info(f"Renaming {stage3_path_temp} to {stage3_path}")
            if os.path.isdir(stage3_path):
                shutil.rmtree(stage3_path)
            shutil.move(stage3_path_temp, stage3_path)
            logging.info(f"Saved to {stage3_path}")

        logging.info("Done")

    # ------------------------------------------------------------------
    # Stage 4: aggregated statistics (first task only)
    # ------------------------------------------------------------------
    elif stage_4:
        if task_index == 0:
            if params is None:
                # elements_map and cutoff for the report come from params.json
                logging.error(
                    f"Stage 4 requires {params_fname}, which is written by stage 1"
                )
                sys.exit(1)

            stats_fname = os.path.join(output_path, "stage3", "stats.json")

            # "shard_*-of-*" also matches the "<shard>.tmp" folders of
            # unfinished stage-3 runs; those must not be aggregated
            datasets_fnames = sorted(
                fname
                for fname in glob.glob(
                    os.path.join(output_path, "stage3", "shard_*-of-*")
                )
                if not fname.endswith(".tmp")
            )
            logging.info(f"Number of found shards: {len(datasets_fnames)}")
            if not datasets_fnames:
                logging.error(
                    f"No stage 3 shards found in {os.path.join(output_path, 'stage3')}"
                )
                sys.exit(1)

            datasets_list = [
                tf.data.Dataset.load(filepath, compression=compression)
                for filepath in datasets_fnames
            ]

            dataset = tf.data.Dataset.from_tensor_slices(datasets_list).interleave(
                lambda x: x
            )

            b_count = 0
            sum_sqr_forces = 0
            tot_nat = 0
            tot_nneigh = 0
            tot_nstruct = 0
            for b in tqdm.tqdm(dataset):
                b_count += 1
                nat_real = b[tc.N_ATOMS_BATCH_REAL].numpy()
                n_neigh_real = b[tc.N_NEIGHBORS_REAL].numpy()
                n_struct = b[tc.N_STRUCTURES_BATCH_REAL].numpy()

                # keep only the first `nat_real` rows of the forces array
                cur_forces = b[tc.DATA_REFERENCE_FORCES].numpy()
                cur_forces = cur_forces[:nat_real]

                sum_sqr_forces += np.sum(cur_forces**2)
                tot_nat += nat_real
                tot_nneigh += n_neigh_real
                tot_nstruct += n_struct

            if b_count == 0:
                # avoids a division by zero in the averages below
                logging.error("Stage 3 shards contain no batches, nothing to aggregate")
                sys.exit(1)

            ave_n_neigh = tot_nneigh / tot_nat
            # currently avg per-component. Maybe must be per-vector ?, i.e. sqrt(3) times larger ?
            scale = np.sqrt(sum_sqr_forces / tot_nat / 3)

            stats = {
                "element_map": elements_map,
                "scale": scale,
                "avg_n_neigh": ave_n_neigh,
                "total_num_of_neighs": tot_nneigh,
                "total_num_of_atoms": tot_nat,
                "sum_sqr_forces": sum_sqr_forces,
                "total_num_structures": tot_nstruct,
                "total_num_of_batches": b_count,
                "cutoff": cutoff,
                "batch_size": batch_size,
            }
            logging.info(f"Aggregated stats: {stats}")

            with open(stats_fname, "w") as f:
                json.dump(stats, f, cls=NpEncoder)


# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
