import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.datamanager import DataManager
from fedbiomed.common.constants import ProcessTypes
from fedbiomed.common.dataloader import PytorchDataLoader
from torchvision import datasets, transforms
from fedbiomed.common.dataset import CustomDataset
from torchvision import transforms
import pandas as pd
from PIL import Image
import os
import numpy as np
class CelebaTrainingPlan(TorchTrainingPlan):
    """Training plan for binary 'Smiling' classification on CelebA face images.

    Bundles three things:

    * ``Net`` - a small convolutional network producing log-probabilities
      over two classes.
    * ``CelebaDataset`` - a dataset that loads one image per request from
      disk together with its label read from ``target.csv``.
    * The plan hooks (``init_model``, ``init_dependencies``,
      ``training_data``, ``training_step``) that wire the two together.
    """

    # Defines model
    def init_model(self):
        """Create and return a new instance of the network to be trained."""
        return self.Net()

    # Here we define the custom dependencies that will be needed by our custom Dataloader
    def init_dependencies(self):
        """Return the import statements required by ``CelebaDataset``.

        Returns:
            A list of import statements, each given as a string.
        """
        deps = [
            "from fedbiomed.common.dataset import CustomDataset",
            "from torchvision import transforms",
            "import pandas as pd",
            "from PIL import Image",
            "import os",
            "import numpy as np",
        ]
        return deps

    # Torch modules class
    class Net(nn.Module):
        """Four conv blocks followed by a two-layer classifier.

        Each conv block is a 3x3 convolution (stride 1, no padding), a 2x2
        max-pooling and a ReLU. The classifier maps the flattened features
        to 128 hidden units and then to 2 output classes.
        """

        def __init__(self):
            super().__init__()
            # Convolution layers: 3 input channels, then 32 channels throughout.
            self.conv1 = nn.Conv2d(3, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 32, 3, 1)
            self.conv3 = nn.Conv2d(32, 32, 3, 1)
            self.conv4 = nn.Conv2d(32, 32, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            # Classifier. 3168 = 32 channels * 11 * 9, which is the feature-map
            # size left after the four conv/pool blocks for a 218 x 178 input
            # (218 -> 216 -> 108 -> 106 -> 53 -> 51 -> 25 -> 23 -> 11 and
            #  178 -> 176 -> 88 -> 86 -> 43 -> 41 -> 20 -> 18 -> 9).
            # Other input sizes require changing this value.
            self.fc1 = nn.Linear(3168, 128)
            self.fc2 = nn.Linear(128, 2)

        def forward(self, x):
            """Run the network on a batch of images.

            Args:
                x: Batch of 3-channel images; the classifier size assumes
                    218 x 178 pixels per image.

            Returns:
                Log-probabilities over the two classes (``log_softmax``
                along dimension 1), one row per input image.
            """
            x = self.conv1(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv2(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv3(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv4(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.dropout1(x)
            # Keep the batch dimension, flatten everything else for the classifier.
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)

            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    class CelebaDataset(CustomDataset):
        """Custom Dataset for loading CelebA face images"""

        # We do not load the full image data up front; each image is read from
        # disk in get_item. In our case each image is 218 * 178 pixels with
        # 3 colour channels and there are 67533 images, which would take at
        # least 7 GB of RAM. Loading images on demand costs more time during
        # training but keeps the memory usage low.

        def read(self):
            """Locate the dataset files and load the labels.

            Expects ``self.path`` to be a directory containing ``target.csv``
            (tab-separated, first column holding the image file names, plus a
            ``Smiling`` column) and a ``data`` sub-directory with the images.
            ``self.path`` is not set in this class - presumably it is provided
            by ``CustomDataset``; confirm against the base class.
            """
            # os.path.join instead of manual '/' concatenation: portable, and
            # it also accepts path-like objects, not only plain strings.
            self.csv_path = os.path.join(self.path, 'target.csv')
            self.img_dir = os.path.join(self.path, 'data')

            # Read the CSV file that contains labels for each image
            df = pd.read_csv(self.csv_path, sep="\t", index_col=0)
            self.img_names = df.index.values
            self.y = df['Smiling'].values

        def get_item(self, index):
            """Load one sample from disk.

            Args:
                index: Position of the sample in ``target.csv``.

            Returns:
                Tuple ``(img, label)`` where ``img`` is the image converted by
                ``transforms.ToTensor`` and ``label`` is the matching
                ``Smiling`` value wrapped in a tensor.
            """
            img_path = os.path.join(self.img_dir, self.img_names[index])
            # Image.open is lazy and holds on to the file; the context manager
            # guarantees the handle is released once the pixels are copied out.
            with Image.open(img_path) as pil_img:
                pixels = np.asarray(pil_img)
            img = transforms.ToTensor()(pixels)
            label = torch.tensor(self.y[index])
            return img, label

        def __len__(self):
            """Return the number of labelled samples."""
            return self.y.shape[0]

    # The training_data creates the Dataloader to be used for training in the
    # general class Torchnn of fedbiomed
    def training_data(self):
        """Build the data manager used for training.

        Returns:
            A ``DataManager`` wrapping a ``CelebaDataset``, with
            ``shuffle=True`` forwarded as a loader argument.
        """
        dataset = self.CelebaDataset()
        loader_arguments = {'shuffle': True}
        return DataManager(dataset, **loader_arguments)

    # This function must return the loss to backward it
    def training_step(self, data, target):
        """Compute the loss for one batch.

        Args:
            data: Batch of input images.
            target: Class indices matching ``data``.

        Returns:
            The negative log-likelihood loss of the model output w.r.t.
            ``target``.
        """
        # Call the module itself rather than .forward(): nn.Module.__call__
        # runs registered hooks, which a direct .forward() call skips.
        output = self.model()(data)
        loss = torch.nn.functional.nll_loss(output, target)
        return loss