#!python
import argparse
import logging
import os
import tempfile
from typing import Optional

import torch

import cerebras.pytorch as cstorch
from cerebras.appliance import logger
from cerebras.appliance.errors import ApplianceVersionError

# Width of the synthetic feature vectors consumed by ``Model``.
INPUT_DIM: int = 10

logging.basicConfig(level="INFO")


class Model(torch.nn.Module):
    """Minimal single-layer network used only to exercise compilation.

    Projects an ``INPUT_DIM``-wide input through one linear layer to 2 outputs.
    """

    def __init__(self):
        super(Model, self).__init__()
        # The attribute name ``fc1`` is kept stable: it determines the
        # parameter names in this module's state_dict.
        self.fc1 = torch.nn.Linear(in_features=INPUT_DIM, out_features=2)

    def forward(self, x):
        """Apply the single linear layer to ``x`` and return the result."""
        projected = self.fc1(x)
        return projected


def input_fn(
    batch_size,
    input_dtype=torch.float16,
    target_dtype=torch.int32,
    sample_count=1,
):
    """Build a DataLoader over constant synthetic (features, label) samples.

    Args:
        batch_size: Number of samples per batch, forwarded to the DataLoader.
        input_dtype: dtype of the feature vector, which is all ones and
            ``INPUT_DIM`` wide.
        target_dtype: dtype of the scalar label, which is always 0.
        sample_count: Passed as ``num_samples`` to ``SyntheticDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a
        ``cstorch.utils.data.SyntheticDataset``.
    """
    features = torch.ones(INPUT_DIM, dtype=input_dtype)
    label = torch.tensor(0, dtype=target_dtype)
    synthetic = cstorch.utils.data.SyntheticDataset(
        [features, label],
        num_samples=sample_count,
    )
    return torch.utils.data.DataLoader(synthetic, batch_size=batch_size)


def check_install(
    mgmt_address: Optional[str] = None,
    mgmt_namespace: Optional[str] = None,
):
    """Compile a tiny training step against the Cerebras appliance.

    Builds a one-layer ``Model``, an SGD optimizer and a synthetic dataloader
    (batch size 1). It then runs one traced training step through a ``"CSX"``
    backend created with ``compile_only=True``. Exceptions raised while
    creating the backend, compiling or running the step propagate to the
    caller.

    Args:
        mgmt_address: Appliance management address, forwarded unchanged to
            ``cstorch.distributed.ClusterConfig`` (may be ``None``).
        mgmt_namespace: Appliance management namespace, forwarded unchanged
            to ``cstorch.distributed.ClusterConfig`` (may be ``None``).
    """
    # Use the temporary directory as a context manager so compile artifacts
    # are removed deterministically, also when an exception propagates. The
    # CLI's error path calls os._exit, which skips interpreter-exit
    # finalizers, so relying on garbage collection would leak the directory.
    with tempfile.TemporaryDirectory() as artifact_dir:
        backend = cstorch.backend(
            "CSX",
            artifact_dir=artifact_dir,
            compile_only=True,
            validate_only=False,
            cluster_config=cstorch.distributed.ClusterConfig(
                mgmt_address=mgmt_address,
                mgmt_namespace=mgmt_namespace,
            ),
        )
        model = cstorch.compile(Model(), backend=backend)

        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001)

        batch_size = 1
        dataloader = cstorch.utils.data.DataLoader(input_fn, batch_size)
        executor = cstorch.utils.data.DataExecutor(dataloader, num_steps=1)

        @cstorch.trace
        def training_step(batch):
            inputs, targets = batch
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            return loss

        for batch in executor:
            # Only successful compilation matters here; the loss is unused.
            training_step(batch)


if __name__ == "__main__":
    # The text must be passed as ``description``: the first positional
    # argument of ArgumentParser is ``prog`` (the program name in usage).
    parser = argparse.ArgumentParser(
        description="CLI utility to confirm Cerebras Installation"
    )
    parser.add_argument(
        "-m",
        "--mgmt_address",
        default=None,
        type=str,
        help="Address of Cerebras Appliance",
    )
    parser.add_argument(
        "-n",
        "--mgmt_namespace",
        default=None,
        type=str,
        help="Namespace of Cerebras Appliance",
    )
    args = parser.parse_args()
    try:
        check_install(args.mgmt_address, args.mgmt_namespace)
    except ApplianceVersionError as exc:
        logger.error(f"Cerebras Component Mismatch Check Installation:\n{exc}")
        # os._exit terminates the process immediately without running cleanup
        # handlers; exit code 5 marks a component version mismatch.
        os._exit(5)  # Somewhat arbitrary just for testing
    else:
        logging.info("Cerebras Components Verified")
