"""Species"""
{% set species_data = data | unique_rows("name") %}
{% set strain_data = data | unique_rows("strain") %}
{% raw -%}
from pydantic import BaseModel, Field, ConfigDict
from typing import Literal, Union, Optional
from typing_extensions import Annotated
from aind_data_schema_models.registries import Registry
{% endraw %}

class StrainModel(BaseModel):
    """Base model for a strain

    Frozen model: every strain carries a name and a species, plus an
    optional registry reference that defaults to None.
    """

    model_config = ConfigDict(frozen=True)

    name: str
    species: str
    # Registry reference is optional; both parts default to None when omitted.
    registry: Optional[Registry] = None
    registry_identifier: Optional[str] = None


{% for _, row in strain_data.iterrows() %}
{% if row['strain'] != 'default' %}
class {{ row['strain'] | to_class_name_underscored }}(StrainModel):
    """Model {{row['strain']}}"""
    name: Literal["{{ row['strain'] }}"] = "{{ row['strain'] }}"
    species: Literal["{{ row['name'] }}"] = "{{ row['name'] }}"
    {#- Empty cells presumably arrive as NaN, the only value for which `x == x`
        is false, so the self-comparisons below mean "this cell is populated".
        The identifier is emitted only when BOTH cells are populated; guarding
        it with the abbreviation alone rendered the literal string "nan" for
        rows that have a registry but no identifier. #}
    {%- set reg_abb = row['strain_registry_abbreviation'] %}
    {%- set reg_id = row['strain_registry_identifier'] %}
    {%- set has_registry = reg_abb == reg_abb %}
    {%- set has_identifier = has_registry and reg_id == reg_id %}
    registry: Optional[Registry] = {% if has_registry %}Field(default=Registry.{{ reg_abb | upper }}){% else %}Field(default=None){% endif %}
    registry_identifier: Optional[str] = {% if has_identifier %}Field(default="{{ reg_id }}"){% else %}Field(default=None){% endif %}
{% endif %}
{% endfor %}
class Strain:
    """Strain"""
    # One attribute per strain row other than "default": the upper-cased class
    # name of the strain, bound to an instance of its StrainModel subclass.
{% for _, row in strain_data.iterrows() %}
{% if row['strain'] != 'default' %}
    {{ row['strain'] | to_class_name | upper }} = {{ row['strain'] | to_class_name_underscored }}()
{% endif %}
{%- endfor %}

    # Direct subclasses of StrainModel (the classes themselves, not instances).
    ALL = tuple(StrainModel.__subclasses__())

    # Union of the strain classes, discriminated on the "name" field.
    ONE_OF = Annotated[Union[{% for _, row in strain_data.iterrows() %}{% if row['strain'] != 'default' %}{{ row['strain'] | to_class_name_underscored }}{{ ", " if not loop.last else "" }}{% endif %}{% endfor %}], Field(discriminator="name")]


class SpeciesModel(BaseModel):
    """Base model for species

    Frozen model; all four fields are required and have no defaults here.
    """
    model_config = ConfigDict(frozen=True)  # instances are immutable
    name: str  # presumably the scientific name; used as the ONE_OF discriminator
    common_name: str
    registry: Registry  # registry that issued registry_identifier
    registry_identifier: str


# One SpeciesModel subclass per species row: name, common_name and
# registry_identifier are pinned with Literal types; registry gets a default.
{% for _, row in species_data.iterrows() %}
class {{ row['name'] | to_class_name_underscored }}(SpeciesModel):
    """Model {{row['name']}}"""
    name: Literal["{{ row['name'] }}"] = "{{ row['name'] }}"
    common_name: Literal["{{ row['common_name'] }}"] = "{{ row['common_name'] }}"
    registry: Registry = Registry.{{ row['species_registry_abbreviation'] | upper }}
    registry_identifier: Literal["{{ row['species_registry_identifier'] }}"] = "{{ row['species_registry_identifier'] }}"

{% endfor %}
class Species:
    """Namespace of species instances with lookup helpers"""
{% for _, row in species_data.iterrows() %}
    {{ row['common_name'] | to_class_name | upper }} = {{ row['name'] | to_class_name_underscored }}()
{%- endfor %}

    ALL = tuple(SpeciesModel.__subclasses__())

    ONE_OF = Annotated[Union[{% for _, row in species_data.iterrows() %}{{ row['name'] | to_class_name_underscored }}{{ ", " if not loop.last else "" }}{% endfor %}], Field(discriminator="name")]

    # Lookup tables from name / common_name to a default-constructed instance.
    name_map = {entry.name: entry for entry in (model() for model in ALL)}
    common_name_map = {entry.common_name: entry for entry in (model() for model in ALL)}

    @classmethod
    def from_name(cls, name: str):
        """Return the species whose name matches, or None if there is none"""
        return cls.name_map.get(name)

    @classmethod
    def from_common_name(cls, common_name: str):
        """Return the species whose common_name matches, or None if there is none"""
        return cls.common_name_map.get(common_name)
