Source code for betty.project.extension

"""Provide Betty's extension API."""

from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from collections.abc import MutableMapping
from typing import TypeVar, Iterable, TYPE_CHECKING, Generic, Self, Sequence

from typing_extensions import override

from betty.config import Configurable, Configuration
from betty.core import CoreComponent
from betty.locale.localizable import Localizable, _, call
from betty.plugin import Plugin, PluginRepository, PluginIdentifier, PluginIdToTypeMap
from betty.plugin.entry_point import EntryPointPluginRepository
from betty.project.factory import ProjectDependentFactory
from betty.typing import internal
from betty.requirement import AllRequirements

if TYPE_CHECKING:
    from betty.event_dispatcher import EventHandlerRegistry
    from betty.requirement import Requirement
    from betty.project import Project
    from pathlib import Path

# Type variable for the configuration type that configurable extensions are
# parameterized with; bound to betty.config.Configuration.
_ConfigurationT = TypeVar("_ConfigurationT", bound=Configuration)


class ExtensionError(Exception):
    """
    The base class for errors raised by Betty's extension API.
    """
class CyclicDependencyError(ExtensionError, RuntimeError):
    """
    Raised when extensions define a cyclic dependency.

    For example, two extensions that each depend on the other.
    """

    def __init__(self, extension_types: Iterable[type[Extension]]):
        """
        :param extension_types: The extension types that take part in the cycle.
        """
        names = ", ".join(
            extension_type.plugin_id() for extension_type in extension_types
        )
        message = f"The following extensions have cyclic dependencies: {names}"
        super().__init__(message)
class Extension(Plugin, CoreComponent, ProjectDependentFactory):
    """
    Integrate optional functionality with Betty :py:class:`betty.project.Project`s.

    Read more about :doc:`/development/plugin/extension`.

    To test your own subclasses, use :py:class:`betty.test_utils.project.extension.ExtensionTestBase`.
    """

    def __init__(self, project: Project):
        """
        :param project: The project this extension runs within.
        """
        # ``Extension`` is only a base class: instantiate concrete subclasses instead.
        assert type(self) is not Extension
        super().__init__()
        self._project = project

    @override
    @classmethod
    async def new_for_project(cls, project: Project) -> Self:
        """
        Create a new instance of this extension for the given project.
        """
        return cls(project)

    def register_event_handlers(self, registry: EventHandlerRegistry) -> None:
        """
        Register event handlers with the project.

        The default implementation registers no handlers.
        """

    @property
    def project(self) -> Project:
        """
        The project this extension runs within.
        """
        return self._project

    @classmethod
    def depends_on(cls) -> set[PluginIdentifier[Extension]]:
        """
        The extensions this one depends on, and comes after.

        By default, an extension has no dependencies.
        """
        return set()

    @classmethod
    def comes_after(cls) -> set[PluginIdentifier[Extension]]:
        """
        The extensions that this one comes after.

        The other extensions may or may not be enabled.
        """
        return set()

    @classmethod
    def comes_before(cls) -> set[PluginIdentifier[Extension]]:
        """
        The extensions that this one comes before.

        The other extensions may or may not be enabled.
        """
        return set()

    @classmethod
    async def requirement(cls) -> Requirement:
        """
        Define the requirement for this extension to be enabled.

        By default, this is the requirement that all of the extension's
        dependencies can be enabled.
        """
        return await Dependencies.new(cls)  # type: ignore[no-any-return]

    @classmethod
    def assets_directory_path(cls) -> Path | None:
        """
        Return the path on disk where the extension's assets are located.

        This may be anywhere in your Python package. By default, an extension
        has no assets directory.
        """
        return None
# Type variable for APIs that are generic over a specific extension type.
_ExtensionT = TypeVar("_ExtensionT", bound=Extension)
# The repository used throughout this module to resolve string extension IDs
# (``get()``) and to build ID-to-type maps (``map()``). It is backed by an
# EntryPointPluginRepository; "betty.extension" is presumably the entry point
# group that extensions are discovered from — confirm in betty.plugin.entry_point.
EXTENSION_REPOSITORY: PluginRepository[Extension] = EntryPointPluginRepository(
    "betty.extension"
)
"""
The project extension plugin repository.

Read more about :doc:`/development/plugin/extension`.
"""
class Theme(Extension):
    """
    An extension that provides a front-end theme.
    """
class ConfigurableExtension(
    Extension, Generic[_ConfigurationT], Configurable[_ConfigurationT]
):
    """
    An extension that can be configured.

    Subclasses must implement :py:meth:`ConfigurableExtension.default_configuration`,
    which provides the configuration each new instance starts with.
    """

    def __init__(self, project: Project):
        """
        :param project: The project this extension runs within.
        """
        super().__init__(project)
        # Start out with whatever configuration the subclass declares as its default.
        self._configuration = self.default_configuration()

    @classmethod
    @abstractmethod
    def default_configuration(cls) -> _ConfigurationT:
        """
        Get this extension's default configuration.
        """
# A dependency graph of extension types. Each key is an extension type, and its
# value is the set of extension types that must come before it: its dependencies,
# plus any optional "comes after"/"comes before" orderings between types in the graph.
ExtensionTypeGraph = MutableMapping[type[Extension], set[type[Extension]]]
async def build_extension_type_graph(
    extension_types: Iterable[type[Extension]],
) -> ExtensionTypeGraph:
    """
    Build a dependency graph of the given extension types.

    Each key in the returned graph is an extension type, and its value is the set
    of extension types that must come before it. The graph contains the given
    extension types and, transitively, all extension types they depend on.
    Optional orderings (:py:meth:`Extension.comes_before` and
    :py:meth:`Extension.comes_after`) are only recorded between extension types
    that are already part of the graph.

    :param extension_types: The extension types to build the graph for. Any
        iterable is accepted, including single-pass iterators such as generators.
    :return: The extension type graph.
    """
    # The input is traversed twice below. Materialize it once so that a
    # single-pass iterator is not exhausted by the first traversal, which would
    # silently drop every optional ordering constraint.
    extension_type_list = list(extension_types)

    async def _resolve(identifier: PluginIdentifier[Extension]) -> type[Extension]:
        # Identifiers are either plugin IDs, or extension types themselves.
        if isinstance(identifier, str):
            return await EXTENSION_REPOSITORY.get(identifier)
        return identifier

    extension_types_graph: ExtensionTypeGraph = defaultdict(set)

    # Add dependencies to the extension graph.
    for extension_type in extension_type_list:
        await _extend_extension_type_graph(extension_types_graph, extension_type)

    # Now all dependencies have been collected, extend the graph with optional
    # extension orders. Only types already in the graph are linked, so optional
    # orderings never pull new extensions in.
    for extension_type in extension_type_list:
        for before_identifier in extension_type.comes_before():
            before = await _resolve(before_identifier)
            if before in extension_types_graph:
                # ``before`` must come after ``extension_type``.
                extension_types_graph[before].add(extension_type)
        for after_identifier in extension_type.comes_after():
            after = await _resolve(after_identifier)
            if after in extension_types_graph:
                extension_types_graph[extension_type].add(after)

    return extension_types_graph
async def _extend_extension_type_graph( graph: ExtensionTypeGraph, extension_type: type[Extension] ) -> None: dependencies = [ await EXTENSION_REPOSITORY.get(dependency_identifier) if isinstance(dependency_identifier, str) else dependency_identifier for dependency_identifier in extension_type.depends_on() ] # Ensure each extension type appears in the graph, even if they're isolated. graph.setdefault(extension_type, set()) for dependency in dependencies: seen_dependency = dependency in graph graph[extension_type].add(dependency) if not seen_dependency: await _extend_extension_type_graph(graph, dependency)
class Dependencies(AllRequirements):
    """
    Check a dependent's dependency requirements.
    """

    @internal
    def __init__(
        self,
        dependent: type[Extension],
        extension_id_to_type_map: PluginIdToTypeMap[Extension],
        dependency_requirements: Sequence[Requirement],
    ):
        """
        :param dependent: The extension type whose dependencies are checked.
        :param extension_id_to_type_map: Maps extension IDs to extension types,
            used to resolve dependencies that are declared by ID.
        :param dependency_requirements: The requirements of each dependency.
        """
        super().__init__(*dependency_requirements)
        self._dependent = dependent
        self._extension_id_to_type_map = extension_id_to_type_map

    @classmethod
    async def new(cls, dependent: type[Extension]) -> Self:
        """
        Create a new instance.

        Each dependency's own requirement is built recursively. A RecursionError
        raised while doing so is treated as a cyclic dependency.

        :raises CyclicDependencyError: If the recursion limit is hit while building
            the dependencies' requirements. The error names the extension whose
            dependencies were being resolved at that point.
        """
        try:
            dependency_requirements = [
                await (
                    await EXTENSION_REPOSITORY.get(dependency_identifier)
                    if isinstance(dependency_identifier, str)
                    else dependency_identifier
                ).requirement()
                for dependency_identifier in dependent.depends_on()
            ]
        except RecursionError:
            raise CyclicDependencyError([dependent]) from None
        else:
            return cls(
                dependent, await EXTENSION_REPOSITORY.map(), dependency_requirements
            )

    @override
    def summary(self) -> Localizable:
        return _("{dependent_label} requires {dependency_labels}.").format(
            dependent_label=self._dependent.plugin_label(),
            dependency_labels=call(
                lambda localizer: ", ".join(
                    self._dependency_type(dependency_identifier)
                    .plugin_label()
                    .localize(localizer)
                    for dependency_identifier in self._dependent.depends_on()
                ),
            ),
        )

    def _dependency_type(
        self, dependency_identifier: PluginIdentifier[Extension]
    ) -> type[Extension]:
        """
        Resolve a dependency identifier to its extension type.

        Dependencies may be declared by plugin ID or by extension type. Only IDs
        are looked up in the ID-to-type map; indexing the map with a type would
        raise a KeyError.
        """
        if isinstance(dependency_identifier, str):
            return self._extension_id_to_type_map[dependency_identifier]
        return dependency_identifier