Source code for betty.model.association

"""
Entity associations.
"""

from __future__ import annotations

import weakref
from abc import abstractmethod
from typing import Generic, cast, Any, Iterable, TypeVar, final

from basedtyping import Intersection
from typing_extensions import override

from betty.attr import MutableAttr
from betty.classtools import repr_instance
from betty.importlib import import_any
from betty.model import Entity
from betty.model.collections import EntityCollection, SingleTypeEntityCollection
from betty.typing import internal

_EntityT = TypeVar("_EntityT", bound=Entity)
_OwnerT = TypeVar("_OwnerT")
_AssociateT = TypeVar("_AssociateT")
_AssociationAttrValueT = TypeVar("_AssociationAttrValueT")
_AssociationAttrSetT = TypeVar("_AssociationAttrSetT")


[docs] class Association( Generic[_OwnerT, _AssociateT, _AssociationAttrValueT, _AssociationAttrSetT], MutableAttr[ Intersection[_OwnerT, Entity], _AssociationAttrValueT, _AssociationAttrSetT ], ): """ Define an association between two entity types. """
[docs] def __init__( self, owner_type_name: str, owner_attr_name: str, associate_type_name: str, ): super().__init__(owner_attr_name) self._owner_type_name = owner_type_name self._owner_attr_name = owner_attr_name self._associate_type_name = associate_type_name AssociationRegistry._register(self)
def __hash__(self) -> int: return hash( ( self._owner_type_name, self._owner_attr_name, self._associate_type_name, ) ) @override def __repr__(self) -> str: return repr_instance( self, owner_type=self._owner_type_name, owner_attr_name=self._owner_attr_name, associate_type=self._associate_type_name, ) @property def owner_type(self) -> type[_OwnerT]: """ The type of the owning entity that contains this association. """ return cast( type[_OwnerT], import_any(self._owner_type_name), ) @property def owner_attr_name(self) -> str: """ The name of the attribute on the owning entity that contains this association. """ return self._owner_attr_name @property def associate_type(self) -> type[_AssociateT]: """ The type of any associate entities. """ return cast( type[_AssociateT], import_any(self._associate_type_name), )
[docs] @abstractmethod def associate( self, owner: _OwnerT & Entity, associate: _AssociateT & Entity ) -> None: """ Associate two entities. """ pass
[docs] @abstractmethod def disassociate( self, owner: _OwnerT & Entity, associate: _AssociateT & Entity ) -> None: """ Disassociate two entities. """ pass
class _BidirectionalAssociation( Generic[_OwnerT, _AssociateT, _AssociationAttrValueT, _AssociationAttrSetT], Association[_OwnerT, _AssociateT, _AssociationAttrValueT, _AssociationAttrSetT], ): """ A bidirectional entity type association. """ def __init__( self, owner_type_name: str, owner_attr_name: str, associate_type_name: str, associate_attr_name: str, ): self._associate_attr_name = associate_attr_name super().__init__( owner_type_name, owner_attr_name, associate_type_name, ) def __hash__(self) -> int: return hash( ( self._owner_type_name, self._owner_attr_name, self._associate_type_name, self._associate_attr_name, ) ) @override def __repr__(self) -> str: return repr_instance( self, owner_type=self._owner_type_name, owner_attr_name=self._owner_attr_name, associate_type_name=self._associate_type_name, associate_attr_name=self._associate_attr_name, ) @property def associate_attr_name(self) -> str: """ The association's attribute name on the associate type. """ return self._associate_attr_name def inverse( self, ) -> _BidirectionalAssociation[_AssociateT, _OwnerT, Any, Any]: """ Get the inverse association. """ association = AssociationRegistry.get_association( self.associate_type, self.associate_attr_name ) assert isinstance(association, _BidirectionalAssociation) return association
[docs] @internal class ToOneAssociation( Generic[_OwnerT, _AssociateT], Association[ _OwnerT, _AssociateT, Intersection[_AssociateT, Entity] | None, Intersection[_AssociateT, Entity] | None, ], ): """ A unidirectional to-one entity type association. """
[docs] @override def new_attr(self, instance: _OwnerT & Entity) -> None: return None
[docs] @override def set_attr( self, instance: _OwnerT & Entity, value: Intersection[_AssociateT, Entity] | None, ) -> None: setattr(instance, self._attr_name, value)
[docs] @override def del_attr(self, instance: _OwnerT & Entity) -> None: self.set_attr(instance, None)
[docs] @override def associate( self, owner: _OwnerT & Entity, associate: _AssociateT & Entity ) -> None: self.set_attr(owner, associate)
[docs] @override def disassociate( self, owner: _OwnerT & Entity, associate: _AssociateT & Entity ) -> None: if associate == self.get_attr(owner): self.del_attr(owner)
[docs] @internal class ToManyAssociation( Generic[_OwnerT, _AssociateT], Association[ _OwnerT, _AssociateT, EntityCollection[_AssociateT], Iterable[Intersection[_AssociateT, Entity]], ], ): """ A to-many entity type association. """
[docs] @override def set_attr( self, instance: _OwnerT & Entity, value: Iterable[Intersection[_AssociateT, Entity]], ) -> None: """ Set the associates on the given owner. """ self.get_attr(instance).replace(*value)
[docs] @override def del_attr(self, instance: _OwnerT & Entity) -> None: self.get_attr(instance).clear()
[docs] @override def associate( self, owner: _OwnerT & Entity, associate: _AssociateT & Entity ) -> None: self.get_attr(owner).add(associate)
[docs] @override def disassociate( self, owner: _OwnerT & Entity, associate: _AssociateT & Entity ) -> None: self.get_attr(owner).remove(associate)
class _BidirectionalToOneAssociation( Generic[_OwnerT, _AssociateT], ToOneAssociation[_OwnerT, _AssociateT], _BidirectionalAssociation[ _OwnerT, _AssociateT, Intersection[_AssociateT, Entity] | None, Intersection[_AssociateT, Entity] | None, ], ): """ A bidirectional *-to-one entity type association. """ @override def set_attr( self, instance: _OwnerT & Entity & Entity, value: Intersection[_AssociateT, Entity] | None, ) -> None: previous_associate = self.get_attr(instance) if previous_associate == value: return super().set_attr(instance, value) if previous_associate is not None: self.inverse().disassociate(previous_associate, instance) if value is not None: self.inverse().associate(value, instance) class _BidirectionalToManyAssociation( Generic[_OwnerT, _AssociateT], ToManyAssociation[_OwnerT, _AssociateT], _BidirectionalAssociation[ _OwnerT, _AssociateT, EntityCollection[_AssociateT], Iterable[Intersection[_AssociateT, Entity]], ], ): """ A bidirectional *-to-many entity type association. """ @override def new_attr(self, instance: _OwnerT & Entity) -> EntityCollection[_AssociateT]: return _BidirectionalAssociateCollection( instance, self, )
[docs] @final class ToOne(Generic[_OwnerT, _AssociateT], ToOneAssociation[_OwnerT, _AssociateT]): """ A unidirectional to-one entity type association. """ pass
[docs] @final class OneToOne( Generic[_OwnerT, _AssociateT], _BidirectionalToOneAssociation[_OwnerT, _AssociateT], ): """ A bidirectional one-to-one entity type association. """ pass
[docs] @final class ManyToOne( Generic[_OwnerT, _AssociateT], _BidirectionalToOneAssociation[_OwnerT, _AssociateT], ): """ A bidirectional many-to-one entity type association. """ pass
[docs] @final class ToMany(Generic[_OwnerT, _AssociateT], ToManyAssociation[_OwnerT, _AssociateT]): """ A unidirectional to-many entity type association. """
[docs] @override def new_attr(self, instance: _OwnerT & Entity) -> EntityCollection[_AssociateT]: return SingleTypeEntityCollection[_AssociateT](self.associate_type)
[docs] @final class OneToMany( Generic[_OwnerT, _AssociateT], _BidirectionalToManyAssociation[_OwnerT, _AssociateT], ): """ A bidirectional one-to-many entity type association. """ pass
[docs] @final class ManyToMany( Generic[_OwnerT, _AssociateT], _BidirectionalToManyAssociation[_OwnerT, _AssociateT], ): """ A bidirectional many-to-many entity type association. """ pass
[docs] @final class AssociationRegistry: """ Inspect any known entity type associations. """ _associations = set[Association[Any, Any, Any, Any]]()
[docs] @classmethod def get_all_associations( cls, owner: type | object ) -> set[Association[Any, Any, Any, Any]]: """ Get all associations for an owner. """ owner_type = owner if isinstance(owner, type) else type(owner) return { association for association in cls._associations if association.owner_type in owner_type.__mro__ }
[docs] @classmethod def get_association( cls, owner: type[_OwnerT] | _OwnerT & Entity, owner_attr_name: str ) -> Association[_OwnerT, Any, Any, Any]: """ Get the association for a given owner and attribute name. """ for association in cls.get_all_associations(owner): if association.owner_attr_name == owner_attr_name: return association raise ValueError( f"No association exists for {owner if isinstance(owner, type) else owner.__class__}.{owner_attr_name}." )
[docs] @classmethod def get_associates( cls, owner: _EntityT, association: Association[_EntityT, _AssociateT, Any, Any], ) -> Iterable[_AssociateT]: """ Get the associates for a given owner and association. """ associates: _AssociateT | None | Iterable[_AssociateT] = association.get_attr( owner ) if isinstance(association, ToOneAssociation): if associates is None: return yield cast(_AssociateT, associates) return yield from cast(Iterable[_AssociateT], associates)
@classmethod def _register(cls, association: Association[Any, Any, Any, Any]) -> None: cls._associations.add(association)
class _BidirectionalAssociateCollection( Generic[_AssociateT, _OwnerT], SingleTypeEntityCollection[_AssociateT] ): __slots__ = "__owner", "_association" def __init__( self, owner: _OwnerT & Entity, association: _BidirectionalAssociation[_OwnerT, _AssociateT, Any, Any], ): super().__init__(association.associate_type) self._association = association self.__owner = weakref.ref(owner) @property def _owner(self) -> _OwnerT & Entity: owner = self.__owner() if owner is None: raise RuntimeError( "This associate collection's owner no longer exists in memory." ) return owner @override def _on_add(self, *entities: _AssociateT & Entity) -> None: super()._on_add(*entities) for associate in entities: self._association.inverse().associate(associate, self._owner) @override def _on_remove(self, *entities: _AssociateT & Entity) -> None: super()._on_remove(*entities) for associate in entities: self._association.inverse().disassociate(associate, self._owner)