/* 
* Copyright (C) 2020-2026 MEmilio
*
* Authors: Martin Siggel, Daniel Abele, Martin J. Kuehn, Jan Kleinert
*
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//Includes from pymio
#include "pybind_util.h"
#include "compartments/flow_simulation.h"
#include "epidemiology/age_group.h"
#include "utils/parameter_set.h"
#include "utils/custom_index_array.h"
#include "utils/index.h"
#include "mobility/graph_simulation.h"
#include "mobility/metapopulation_mobility_instant.h"
#include "io/mobility_io.h"
#include "io/result_io.h"

//Includes from MEmilio
#include "ode_seir/model.h"
#include "ode_seir/infection_state.h"
#include "memilio/data/analyze_result.h"
#include "memilio/compartments/parameter_studies.h"
#include "memilio/data/analyze_result.h"

#include "pybind11/pybind11.h"


namespace py = pybind11;

namespace pymio
{

/**
 * Explicit specialization of pretty_name for mio::oseir::InfectionState.
 * @return The fixed string "InfectionState".
 */
template <>
std::string pretty_name<mio::oseir::InfectionState>()
{
    return std::string("InfectionState");
}

} // namespace pymio


PYBIND11_MAKE_OPAQUE(std::vector<mio::Graph<mio::SimulationNode<mio::Simulation<>>, mio::MobilityEdge<double>>);
PYBIND11_MAKE_OPAQUE(std::vector<mio::Graph<mio::SimulationNode<mio::FlowSimualtion<>>, mio::MobilityEdge<double>>);

PYBIND11_MODULE(_simulation_test_oseir, m)
{
    pymio::iterable_enum<mio::oseir::InfectionState>(m, "InfectionState")
	    .value("Susceptible", mio::oseir::InfectionState::Susceptible)
	    .value("Exposed", mio::oseir::InfectionState::Exposed)
	    .value("Infected", mio::oseir::InfectionState::Infected)
	    .value("Recovered", mio::oseir::InfectionState::Recovered);



    pymio::bind_CustomIndexArray<mio::UncertainValue, mio::AgeGroup>(m, "AgeGroupArray");

    pymio::bind_ParameterSet<mio::oseir::ParametersBase<double>, pymio::EnablePickling::Required>(m, "ParametersBase");

    py::class_<mio::oseir::Parameters<double>, pymio::EnablePickling::Required, mio::oseir::ParametersBase<double>>(m, "Parameters")
		.def(py::init<mio::AgeGroup>())
		.def("check_constraints", &mio::oseir::Parameters<double>::check_constraints)
		.def("apply_constraints", &mio::oseir::Parameters<double>::apply_constraints);

   
    pymio::bind_Population(m, "Populations", mio::Tag<mio::oseir::Model<double>::Populations>{});
   
    pymio::bind_CompartmentalModel<FlowModel<FP, mio::oseir::InfectionState, mio::Populations<FP, AgeGroup, InfectionState>, Parameters<FP>, mio::oseir::Flows>, CompartmentalModel<FP, Comp, Pop, Params>>(m, "ModelBase");
    py::class_<mio::oseir::Model<double>, pymio::EnablePickling::Required, FlowModel<FP, mio::oseir::InfectionState, mio::Populations<FP, AgeGroup, InfectionState>, Parameters<FP>, mio::oseir::Flows>, CompartmentalModel<FP, Comp, Pop, Params>>(m, "Model");
    
        .def(py::init<int>(), py::arg("num_agegroups"));


    pymio::bind_Simulation<mio::Simulation<double,mio::oseir::Model<double>>>(m, "Simulation");

	m.def(
		"simulate", &mio::simulate<double,mio::oseir::Model<double>>,
		py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none(),
		"Simulate a mio::oseir:: from t0 to tmax."
		);

	pymio::bind_ModelNode<mio::oseir::Model<double>>(m, "ModelNode");
	pymio::bind_SimulationNode<mio::oseir::Simulation<>>(m, "SimulationNode");
	pymio::bind_ModelGraph<mio::oseir::Model<double>>(m, "ModelGraph");
	pymio::bind_MobilityGraph<mio::oseir::Simulation<>>(m, "MobilityGraph");
	pymio::bind_GraphSimulation<mio::Graph<mio::SimulationNode<mio::Simulation<>>, mio::MobilityEdge<double>>>(m, "MobilitySimulation");

	pymio::bind_Flow_Simulation<mio::Simulation<double, mio::FlowSimulation<double,mio::oseir::Model<double>>>>(m, "FlowSimulation");

	m.def(
		"simulate_flows", &mio::flow_simulation<mio::oseir::Model<double>>,
		py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"), py::arg("integrator") = py::none(),
		"Simulate a mio::oseir:: with flows from t0 to tmax."
		);

	pymio::bind_SimulationNode<mio::oseir::FlowSimulation<>>(m, "SimulationNode");
	pymio::bind_MobilityGraph<mio::oseir::FlowSimulation<>>(m, "MobilityGraph");
	pymio::bind_GraphSimulation<mio::Graph<mio::SimulationNode<mio::FlowSimulation<>>, mio::MobilityEdge<double>>>(m, "MobilitySimulation");
	
    
    m.def("interpolate_simulation_result",
        static_cast<mio::TimeSeries<double> (*)(const mio::TimeSeries<double>&, const double)>(
            &mio::interpolate_simulation_result),
        py::arg("ts"), py::arg("abs_tol") = 1e-14);

    m.def("interpolate_simulation_result",
        static_cast<mio::TimeSeries<double> (*)(const mio::TimeSeries<double>&, const std::vector<double>&)>(
            &mio::interpolate_simulation_result),
        py::arg("ts"), py::arg("interpolation_times"));

    m.def("interpolate_ensemble_results", &mio::interpolate_ensemble_results<mio::TimeSeries<double>>);
    

    m.attr("__version__") = "dev";
}
