import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.datamanager import DataManager
from fedbiomed.common.constants import ProcessTypes
from fedbiomed.common.dataloader import PytorchDataLoader
from torchvision import datasets, transforms
from fedbiomed.common.dataset import TabularDataset
import pandas as pd
class MyTrainingPlan(TorchTrainingPlan):
    """Fed-BioMed training plan: a two-layer MLP regressor on tabular data.

    The network (``Net``) is trained on the input columns declared in
    ``training_data`` to predict the ``price`` column, minimizing the
    root-mean-squared error computed in ``training_step``.
    """

    # Model
    def init_model(self):
        """Build the network from the experiment's model arguments.

        Returns:
            A ``MyTrainingPlan.Net`` instance configured from ``self.model_args()``.
        """
        model_args = self.model_args()
        model = self.Net(model_args)
        return model

    # Dependencies
    def init_dependencies(self):
        """Return the extra import statements this plan declares as dependencies.

        Returns:
            A list of import statements, as strings.
        """
        deps = ["from fedbiomed.common.dataset import TabularDataset",
                "import pandas as pd"]
        return deps

    # network
    class Net(nn.Module):
        """Fully connected network: Linear -> ReLU -> Linear.

        Args:
            model_args: mapping providing ``'in_features'`` (input width) and
                ``'out_features'`` (output width). NOTE(review): ``in_features``
                presumably must equal the number of input columns used in
                ``training_data`` (6) — confirm against the experiment config.
        """

        def __init__(self, model_args):
            super().__init__()
            self.in_features = model_args['in_features']
            self.out_features = model_args['out_features']
            # Hidden layer width is fixed at 5 units.
            self.fc1 = nn.Linear(self.in_features, 5)
            self.fc2 = nn.Linear(5, self.out_features)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

    def training_step(self, data, target):
        """Compute the RMSE loss for one batch.

        Args:
            data: batch of input features.
            target: batch of targets, either 1-D ``(batch,)`` or already
                carrying a trailing feature axis like the model output.

        Returns:
            Scalar tensor: square root of the mean squared error.
        """
        # Call the module rather than ``.forward`` so nn.Module hooks run.
        output = self.model()(data).float()
        # Add the trailing feature axis only when the target lacks it. An
        # unconditional ``unsqueeze(1)`` would turn a 2-D ``(batch, 1)`` target
        # into ``(batch, 1, 1)``, which broadcasts against a ``(batch, 1)``
        # output into a ``(batch, batch, 1)`` pairwise comparison and silently
        # produces a wrong loss. 1-D targets are handled exactly as before.
        if target.dim() == output.dim() - 1:
            target = target.unsqueeze(-1)
        # Functional form: same result as nn.MSELoss() (reduction='mean')
        # without constructing a new loss module on every step.
        return torch.sqrt(F.mse_loss(output, target))

    def training_data(self):
        """Create the DataManager that TorchTrainingPlan uses for training.

        Features are the six listed input columns and the target is ``price``.
        Both are converted to ``float32`` tensors. NOTE(review): ``transmission``
        must already be numerically encoded for the float32 conversion to
        succeed — confirm against the node's dataset.

        Returns:
            A ``DataManager`` wrapping a ``TabularDataset``.
        """
        dataset = TabularDataset(
            input_columns=['year', 'transmission', 'mileage', 'tax', 'mpg', 'engineSize'],
            target_columns=['price'],
            transform=lambda xs: torch.as_tensor(xs, dtype=torch.float32),
            target_transform=lambda xs: torch.as_tensor(xs, dtype=torch.float32),
        )
        return DataManager(dataset=dataset)
