import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.datamanager import DataManager
from fedbiomed.common.constants import ProcessTypes
from fedbiomed.common.dataloader import PytorchDataLoader
from torchvision import datasets, transforms
from fedbiomed.common.dataset import MnistDataset
from torchvision import transforms
from torch.optim import Adam
class MyTrainingPlan(TorchTrainingPlan):
    """Training plan pairing a small convolutional classifier with Adam.

    The plan builds the network (`init_model`), its optimizer
    (`init_optimizer`), the dataset wrapper (`training_data`) and defines
    the per-batch loss (`training_step`).
    """

    class Net(nn.Module):
        """Two convolutions, one max-pool and two linear layers.

        The forward pass returns log-probabilities over 10 outputs.
        """

        def __init__(self, model_args):
            """Create the layers.

            Args:
                model_args: Accepted for interface compatibility; it is not
                    read anywhere in this constructor.
            """
            super().__init__()
            # Layers are registered in the same order they are applied.
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
            self.dropout1 = nn.Dropout(p=0.25)
            self.dropout2 = nn.Dropout(p=0.5)
            # 9216 == 64 * 12 * 12, i.e. the flattened feature map obtained
            # from a 1x28x28 input -- assumes inputs have that shape.
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            """Return log-softmax scores (taken over dim 1) for the batch `x`."""
            # Convolutional feature extractor.
            features = F.relu(self.conv1(x))
            features = F.relu(self.conv2(features))
            features = self.dropout1(F.max_pool2d(features, 2))

            # Classifier head on the flattened features (batch dim is kept).
            hidden = F.relu(self.fc1(torch.flatten(features, 1)))
            logits = self.fc2(self.dropout2(hidden))
            return F.log_softmax(logits, dim=1)

    def init_model(self, model_args):
        """Instantiate and return the `Net` module.

        Args:
            model_args: Forwarded unchanged to `Net`.
        """
        return self.Net(model_args=model_args)

    def init_optimizer(self, optimizer_args):
        """Return an Adam optimizer over the model's parameters.

        Args:
            optimizer_args: Mapping that must contain the key "lr"
                (a missing key raises `KeyError`).
        """
        trainable = self.model().parameters()
        learning_rate = optimizer_args["lr"]
        return Adam(trainable, lr=learning_rate)

    def init_dependencies(self):
        """Return the import statements this plan declares as dependencies.

        NOTE(review): presumably the framework executes these lines wherever
        the plan is loaded -- confirm against the TorchTrainingPlan docs.
        """
        dependencies = [
            "from fedbiomed.common.dataset import MnistDataset",
            "from torchvision import transforms",
            "from torch.optim import Adam",
        ]
        return dependencies

    def training_data(self):
        """Wrap a normalized `MnistDataset` in a shuffling `DataManager`."""
        # Single-channel normalization; the constants look like the usual
        # MNIST mean/std -- TODO confirm.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        mnist = MnistDataset(transform=normalize)
        return DataManager(mnist, shuffle=True)

    def training_step(self, data, target):
        """Compute the negative log-likelihood loss for one batch.

        Args:
            data: Batch passed to the model's `forward`.
            target: Labels passed to `nll_loss` together with the output.

        Returns:
            The loss value produced by `nll_loss`.
        """
        # The model already emits log-probabilities, hence nll_loss.
        log_probs = self.model().forward(data)
        return F.nll_loss(log_probs, target)