"""Protocols"""
{% raw -%}
import re
from typing import Literal, Optional, Union

from pydantic import ConfigDict, Field
from typing_extensions import Annotated

from aind_data_schema_models.pid_names import BaseName
from aind_data_schema_models.registries import Registry
{% endraw %}

class ProtocolModel(BaseName):
    """Base model for protocol.

    The generated protocol classes in this module subclass it and give every
    field a default, so each subclass can be instantiated with no arguments.
    """
    # frozen=True: pydantic rejects attribute assignment after construction.
    model_config = ConfigDict(frozen=True)
    name: str  # protocol title (generated subclasses use the data row's title)
    version: int  # protocol version number
    registry: Registry  # registry that registry_identifier belongs to
    registry_identifier: str  # identifier within that registry (a DOI in generated subclasses)

{% for _, row in data.iterrows() %}
class {{ row['title'] | to_class_name_underscored }}_V{{ row['version'] }}(ProtocolModel):
    """Model {{ row['title'] }}"""
    name: str = "{{ row['title'] }}"
    version: int = {{ row['version'] }}
    registry: Registry = Registry.DOI
    registry_identifier: Literal["{{ row['DOI'] }}"] = "{{ row['DOI'] }}"


{% endfor %}

class Protocols:
    """Protocols: one frozen instance per generated protocol model, plus DOI lookup helpers."""
{% for _, row in data.iterrows() %}
    {{ row['title'] | to_class_name | upper }}_V{{ row['version'] }} = {{ row['title'] | to_class_name_underscored }}_V{{ row['version'] }}()
{% endfor %}

    # Direct subclasses of ProtocolModel, i.e. every generated protocol class above.
    ALL = tuple(ProtocolModel.__subclasses__())

    # Pydantic requires a discriminator that is a Literal field on every union member.
    # registry_identifier (the DOI) is used rather than name because titles may repeat
    # across versions. DOIs are assumed unique per model -- doi_map relies on that too.
    ONE_OF = Annotated[Union[{% for _, row in data.iterrows() %}{{ row['title'] | to_class_name_underscored }}_V{{ row['version'] }}{{ ", " if not loop.last else "" }}{% endfor %}], Field(discriminator="registry_identifier")]

    # DOI -> model instance. Each model is instantiated once; empty identifiers are skipped.
    doi_map = {p.registry_identifier: p for p in (m() for m in ALL) if p.registry_identifier}

    @classmethod
    def from_doi(cls, doi: str) -> Optional[ProtocolModel]:
        """Return the protocol model whose DOI exactly matches ``doi``, or None if unknown."""
        return cls.doi_map.get(doi)

    @classmethod
    def from_url(cls, url: str) -> Optional[ProtocolModel]:
        """Return the protocol model for a DOI given as a URL or DOI string.

        Accepts a bare DOI, a ``doi:`` prefixed DOI, or a doi.org URL with an
        optional http/https scheme and optional ``dx.`` or ``www.`` host prefix.
        The prefix match is case-insensitive and surrounding whitespace is ignored.
        Returns None when the DOI is unknown.
        """
        # flags must be passed by keyword: re.sub's 4th positional argument is count.
        doi = re.sub(
            r"^(?:(?:https?://)?(?:dx\.|www\.)?doi\.org/|doi:\s*)",
            "",
            url.strip(),
            flags=re.IGNORECASE,
        )
        return cls.from_doi(doi)
