Coverage for src/epublib/xml_element.py: 100%
373 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-07 13:19 -0300
1import dataclasses
2import enum
3import inspect
4import typing
5from abc import ABC
6from collections.abc import Generator, Iterable, Sequence
7from dataclasses import dataclass
8from datetime import datetime
9from functools import cache
10from pathlib import Path
11from typing import (
12 Annotated,
13 Any,
14 Callable,
15 ClassVar,
16 ForwardRef,
17 Protocol,
18 Self,
19 SupportsIndex,
20 cast,
21 get_args,
22 overload,
23 override,
24 runtime_checkable,
25)
27import bs4
29from epublib.exceptions import EPUBError
30from epublib.identifier import EPUBId
31from epublib.util import (
32 attr_to_str,
33 datetime_to_str,
34 get_absolute_href,
35 get_actual_tag_position,
36 get_relative_href,
37 new_id_in_tag,
38 parse_int,
39 remove_optional_type,
40 split_fragment,
41 strip_fragment,
42 strip_type_parameters,
43)
45type AttributeValue = str | datetime | bool | list[str] | EPUBId | int
48_sentinel_tag = bs4.BeautifulSoup("", "xml").new_tag("sentinel")
class SyncType(enum.Enum):
    """Which part of a tag an XMLAttribute's value is stored in."""

    ATTR = 1  # Sync with tag attribute
    STRING = 2  # Sync with tag string
    NAME = 3  # Sync with tag name
57@dataclass
58class XMLAttribute:
59 """
60 Represents the relation between the attribute of a XML tag and its
61 representation in an object.
62 """
64 init_name: str | None = None
65 name: str = dataclasses.field(init=False, repr=False)
66 sync: SyncType = SyncType.ATTR
67 get: str | Callable[[bs4.Tag], bs4.Tag | None] | None = None
68 create: str | Callable[[bs4.BeautifulSoup, bs4.Tag], bs4.Tag] | None = None
69 prefix: str = ""
70 typ: type[AttributeValue] = dataclasses.field(init=False, repr=False)
71 init: bool = dataclasses.field(init=False, repr=False)
73 def __post_init__(self):
74 pass
76 def get_tag(self, tag: bs4.Tag) -> bs4.Tag | None:
77 if self.get is None:
78 return tag
80 if isinstance(self.get, str):
81 return tag.select_one(f"& > {self.get}")
83 return self.get(tag)
85 def create_tag(self, soup: bs4.BeautifulSoup, tag: bs4.Tag) -> bs4.Tag:
86 if self.create is None:
87 return tag
89 if isinstance(self.create, str):
90 new_tag = soup.new_tag(self.create)
91 __ = tag.insert(0, new_tag)
92 return new_tag
94 return self.create(soup, tag)
@runtime_checkable
class _XMLAttributeMetadataProtocol(Protocol):
    """
    Runtime-checkable view of an ``Annotated[T, ...]`` alias, exposing its
    metadata tuple and the wrapped type. XMLElement._get_attributes uses it
    to narrow dataclass field types.
    """

    __metadata__: tuple[XMLAttribute, ...]
    __origin__: type[AttributeValue]
103@dataclass
104class BaseElement[S: bs4.BeautifulSoup = bs4.BeautifulSoup](ABC):
105 soup: S = dataclasses.field(repr=False)
106 tag: bs4.Tag = dataclasses.field(default=_sentinel_tag, repr=False)
108 def __post_init__(self):
109 pass
112@dataclass(kw_only=True)
113class XMLElement[S: bs4.BeautifulSoup = bs4.BeautifulSoup](
114 BaseElement[S],
115 ABC,
116):
117 """
118 Abstract base class for an XML element. Responsible for syncing object
119 and tag, and exposing important tag attributes as convenient
120 instance attributes
121 """
123 tag_name: ClassVar[str]
124 __cached_attributes: ClassVar[dict[type[Self], dict[str, XMLAttribute]]] = {}
126 def __post_init__(self):
127 if self.tag is _sentinel_tag:
128 self.create_tag()
130 super().__post_init__()
132 def get_tag_name(self) -> str:
133 try:
134 return self.tag_name
135 except AttributeError as error:
136 raise NotImplementedError(
137 f"{self.__class__.__name__} must define a class variable "
138 "`tag_name` with the name of the XML tag, or override the "
139 "`get_tag_name` method."
140 ) from error
142 @classmethod
143 def _get_attributes(cls) -> dict[str, XMLAttribute]:
144 """
145 Infer XML attributes from dataclass fields
146 """
147 if cls.__cached_attributes.get(cls):
148 return cls.__cached_attributes[cls]
150 attributes: dict[str, XMLAttribute] = {}
151 for field in dataclasses.fields(cls):
152 if (
153 typing.get_origin(field.type) is Annotated
154 and isinstance(field.type, _XMLAttributeMetadataProtocol)
155 and field.type.__metadata__
156 ):
157 attribute = field.type.__metadata__[0]
158 attribute.name = (
159 field.name.replace("_", "-")
160 if attribute.init_name is None
161 else attribute.init_name
162 )
163 attribute.typ = strip_type_parameters(field.type.__origin__)
164 attribute.init = field.init
165 attributes[field.name] = attribute
167 cls.__cached_attributes[cls] = attributes
168 return attributes
170 @override
171 def __setattr__(self, name: str, value: AttributeValue | None) -> None:
172 ret = super().__setattr__(name, value)
173 self.update_tag(name, value)
174 return ret
176 def create_tag(self):
177 self.tag: bs4.Tag = self.soup.new_tag(self.get_tag_name())
178 for name in self._get_attributes().keys():
179 self.update_tag(name, cast(AttributeValue | None, getattr(self, name)))
181 def update_tag(self, name: str, value: AttributeValue | None):
182 if self.tag is _sentinel_tag:
183 return
185 attribute = self._get_attributes().get(name)
186 if attribute is None:
187 return
189 value = self.attribute_to_str(name, value) if value is not None else None
191 tag = attribute.get_tag(self.tag)
192 if tag is None and value is not None:
193 tag = attribute.create_tag(self.soup, self.tag)
195 if tag is None:
196 return
198 match attribute.sync:
199 case SyncType.ATTR:
200 if value is None:
201 del tag[attribute.name]
202 else:
203 tag[attribute.name] = value
204 case SyncType.STRING:
205 if value is None and tag is not self.tag:
206 tag.decompose()
207 else:
208 tag.string = "" if value is None else value
209 case SyncType.NAME:
210 if not value:
211 raise EPUBError(
212 f"{self.__class__.__name__}.{name} cannot be empty or None"
213 )
215 self.tag.name = value
216 if attribute.prefix:
217 self.tag.prefix = attribute.prefix
219 @classmethod
220 def _read_from_tag(
221 cls,
222 tag: bs4.Tag,
223 attribute: XMLAttribute,
224 ) -> str | None:
225 tag_or_none = attribute.get_tag(tag)
227 if tag_or_none is None:
228 return None
230 tag = tag_or_none
232 match attribute.sync:
233 case SyncType.ATTR:
234 return attr_to_str(tag.get(attribute.name))
235 case SyncType.STRING:
236 return tag.get_text()
237 case SyncType.NAME:
238 return tag.name
240 @classmethod
241 def from_tag(
242 cls,
243 soup: S,
244 tag: bs4.Tag,
245 **kwargs: AttributeValue,
246 ) -> Self:
247 attributes = cls._get_attributes()
248 tag_kwargs = {
249 name: cls.str_to_attribute(
250 cls._read_from_tag(tag, attribute),
251 attribute.typ,
252 )
253 for name, attribute in attributes.items()
254 if attribute.init
255 }
257 instance = cls(
258 soup=soup,
259 tag=tag,
260 **tag_kwargs,
261 **kwargs,
262 )
264 return instance
266 def attribute_to_str(
267 self,
268 name: str, # type: ignore[reportUnusedParameter]
269 value: AttributeValue,
270 ) -> str:
271 """
272 Convert an attribute of this object to a string suitable for
273 XML serialization.
274 """
275 if isinstance(value, datetime):
276 return datetime_to_str(value)
278 if isinstance(value, bool):
279 return "yes" if value else "no"
281 if isinstance(value, int):
282 return str(value)
284 if isinstance(value, list):
285 return " ".join(str(el) for el in value)
287 return value
289 @classmethod
290 def str_to_attribute(
291 cls,
292 value: str | None,
293 typ: type[AttributeValue],
294 ) -> AttributeValue | None:
295 """
296 Convert a string from an XML attribute to an attribute of this
297 object.
298 """
299 if value is None:
300 return None
302 typ = remove_optional_type(typ)
303 if issubclass(typ, list):
304 return value.split()
306 if issubclass(typ, datetime):
307 return datetime.fromisoformat(value)
309 if issubclass(typ, bool):
310 return value != "no"
312 if issubclass(typ, int):
313 return parse_int(value)
315 if issubclass(typ, EPUBId):
316 return EPUBId(value)
318 return str(value)
321@dataclass(kw_only=True)
322class HrefElement[S: bs4.BeautifulSoup = bs4.BeautifulSoup](XMLElement[S], ABC):
323 """
324 XMLElement with a reference to a file. This class handles the logic
325 of syncing the 'href' (relative filename) and 'filename' (absolute
326 filename).
327 """
329 filename: str
330 href: Annotated[str, XMLAttribute()] = ""
331 own_filename: str
333 @property
334 def pk(self) -> str:
335 return self.filename
337 def href_to_filename(self, href: str) -> str:
338 return get_absolute_href(self.own_filename, href)
340 def filename_to_href(self, filename: str) -> str:
341 return get_relative_href(self.own_filename, filename)
343 def __post_init__(self):
344 if not self.href and self.filename:
345 self.href = self.filename_to_href(self.filename)
346 elif not self.filename and self.href:
347 self.filename = self.href_to_filename(self.href)
349 super().__post_init__()
351 @override
352 def __setattr__(self, name: str, value: AttributeValue | None) -> None:
353 super().__setattr__(name, value)
354 if hasattr(self, "own_filename"):
355 if name == "filename":
356 if not value:
357 super().__setattr__("href", value)
358 elif isinstance(value, str | Path):
359 super().__setattr__("href", self.filename_to_href(value))
361 elif name == "href":
362 if not value:
363 super().__setattr__("filename", value)
364 elif isinstance(value, str | Path):
365 super().__setattr__("filename", self.href_to_filename(value))
367 @classmethod
368 @override
369 def from_tag( # type: ignore[reportIncompatibleMethodOverride]
370 cls,
371 soup: S,
372 tag: bs4.Tag,
373 own_filename: str,
374 **kwargs: AttributeValue,
375 ) -> Self:
376 return super().from_tag(
377 soup,
378 tag,
379 filename="",
380 own_filename=own_filename,
381 **kwargs,
382 )
385# When generic constraints to generics become supported, we should use this:
386# XMLChildProtocol[S: bs4.BeautifulSoup = bs4.BeautifulSoup](Protocol)
387#
388# And then:
389# class XMLParent[S: bs4.BeautifulSoup = bs4.BeautifulSoup, I: XMLChildProtocol[S]](...)
class XMLChildProtocol(Protocol):
    """
    Structural type of the items managed by an XMLParent: a wrapped tag, a
    lookup key and an alternate constructor from an existing tag.
    """

    tag: bs4.Tag

    @property
    def pk(self) -> str:
        """Key compared by XMLParent.get."""
        ...

    @classmethod
    def from_tag(
        cls,
        soup: Any,  # type: ignore[reportAny]
        tag: bs4.Tag,
        **kwargs: Any,  # type: ignore[reportAny]
    ) -> Self:
        """Build the child representation of an existing ``tag``."""
        ...
407@dataclass(kw_only=True)
408class XMLParent[I: XMLChildProtocol, S: bs4.BeautifulSoup = bs4.BeautifulSoup](
409 BaseElement[S],
410 ABC,
411):
412 """Abstract base class for an XML element that contains other XML elements."""
414 def __post_init__(self):
415 super().__post_init__()
416 self._items: list[I] = list(self.parse_items())
418 @classmethod
419 @cache
420 def _child_class(cls) -> type[I]:
421 try:
422 parent_base = next(
423 c
424 for c in cast(tuple[type[Any], ...], cls.__orig_bases__) # type: ignore[reportAttributeAccessIssue]
425 if issubclass(typing.get_origin(c) or c, XMLParent)
426 )
427 typ = get_args(parent_base)[0] # type: ignore[reportAttributeAccessIssue]
428 if isinstance(typ, ForwardRef) and typ.__forward_arg__ == cls.__name__:
429 return cast(type[I], cls)
430 assert inspect.isclass(typ)
431 return typ
432 except (AttributeError, IndexError, AssertionError, StopIteration):
433 raise NotImplementedError(
434 f"Cannot determine child class for {cls.__name__}. Specify "
435 "the generic type of override _child_class."
436 )
438 def get_child_tags(self) -> Iterable[bs4.Tag]:
439 parent_tag = self.parent_tag
440 if parent_tag is None:
441 return []
443 child_tag_name = getattr(self._child_class(), "tag_name", True)
445 return parent_tag.find_all(child_tag_name, recursive=False)
447 def _get_common_dataclass_attrs(
448 self,
449 exclude: Sequence[str] = (),
450 exclude_tag: bool = False,
451 exlcude_soup: bool = False,
452 include_self_as_parent: bool = True,
453 ) -> dict[str, AttributeValue]:
454 child_class = self._child_class()
455 child_field_names = {
456 field.name
457 for field in dataclasses.fields(
458 child_class, # type: ignore[reportArgumentType]
459 )
460 }
462 kwargs = {
463 field.name: getattr(self, field.name)
464 for field in dataclasses.fields(self)
465 if field.name in child_field_names
466 and field.name not in exclude
467 and (not exclude_tag or field.name != "tag")
468 and (not exlcude_soup or field.name != "soup")
469 }
471 if include_self_as_parent and "parent" in child_field_names:
472 kwargs["parent"] = self
474 return kwargs
476 def parse_items(self) -> Sequence[I]:
477 """Parse child items from self.tag and return their representations in a list."""
478 # This generic implementation will get all tag children of the
479 # parent element, and call the _child_class().from_tag method.
480 # If there are any dataclass attributes on the child class that
481 # have the same name as in this own class, they will be passed
482 # to the from_tag method.
484 child_class = self._child_class()
485 kwargs = self._get_common_dataclass_attrs(exclude_tag=True)
487 return [
488 child_class.from_tag(tag=tag, **kwargs) for tag in self.get_child_tags()
489 ]
491 @overload
492 def get[J: XMLChildProtocol](self, pk: str, cls: type[J]) -> J | None: ...
493 @overload
494 def get(self, pk: str, cls: type[I] | None = None) -> I | None: ...
496 def get(self, pk: str, cls: type[I] | None = None):
497 return next(
498 (
499 item
500 for item in self._items
501 if item.pk == pk and (cls is None or isinstance(item, cls))
502 ),
503 None,
504 )
506 def __getitem__(self, pk: str | SupportsIndex):
507 if isinstance(pk, SupportsIndex):
508 return self._items[pk]
510 value = self.get(pk)
511 if value is None:
512 raise KeyError(pk)
513 return value
515 def create_parent_tag(self) -> bs4.Tag:
516 return self.tag
518 @property
519 def parent_tag(self) -> bs4.Tag | None:
520 return self.tag
522 # When generic constraints to generics become supported, we should use this:
523 # def add_item[T: I](self, item: T) -> T:
524 # def insert_item[T: I](self, position: int, item: T) -> T:
525 def add_item(self, item: I) -> I:
526 return self.insert_item(len(self._items), item)
528 def insert_item(self, position: int | None, item: I) -> I:
529 parent_tag = self.parent_tag
530 if not parent_tag:
531 parent_tag = self.create_parent_tag()
533 assert item.tag is not self.tag
535 if position is None:
536 self._items.append(item)
537 __ = parent_tag.append(item.tag)
538 else:
539 self._items.insert(position, item)
540 child_tag_name: str | None = getattr(self._child_class(), "tag_name", None)
541 actual_position = get_actual_tag_position(
542 parent_tag,
543 position,
544 child_tag_name,
545 )
546 __ = parent_tag.insert(actual_position, item.tag)
548 return item
550 def create_child(self, **kwargs: AttributeValue | None) -> I:
551 common = self._get_common_dataclass_attrs(exclude_tag=True)
553 return self._child_class()(
554 **common,
555 **kwargs,
556 )
558 def insert(self, position: int | None, **kwargs: AttributeValue | None) -> I:
559 item = self.create_child(**kwargs)
560 return self.insert_item(position, item)
562 def add(self, **kwargs: AttributeValue | None) -> I:
563 return self.insert(None, **kwargs)
565 def remove(self, pk: str) -> None:
566 item = self.get(pk)
567 if item:
568 return self.remove_item(item)
570 def remove_item(self, item: I) -> None:
571 self._items.remove(item)
572 item.tag.decompose()
574 @property
575 def items(self):
576 return tuple(self._items)
578 def get_new_id(self, base: str | EPUBId) -> EPUBId:
579 return new_id_in_tag(EPUBId.to_valid(base), self.soup)
581 @override
582 def __repr__(self):
583 return f"{self.__class__.__name__}({len(self.items)} items)"
class HrefChildProtocol(XMLChildProtocol, Protocol):
    """
    Child that references a file, both as a relative ``href`` and as an
    absolute ``filename``.
    """

    href: str
    filename: str
591@dataclass(kw_only=True)
592class ParentOfHref[
593 I: HrefChildProtocol,
594 S: bs4.BeautifulSoup = bs4.BeautifulSoup,
595](
596 XMLParent[I, S],
597 ABC,
598):
599 """
600 An XML element that contains other XML elements that have hrefs.
601 """
603 own_filename: str
605 @overload
606 def get[J: HrefChildProtocol](
607 self,
608 filename: str | Path,
609 cls: type[J],
610 ignore_fragment: bool = False,
611 ) -> J | None: ...
612 @overload
613 def get(
614 self,
615 filename: str | Path,
616 cls: type[I] | None = None,
617 ignore_fragment: bool = False,
618 ) -> I | None: ...
620 @override
621 def get( # type: ignore[reportIncompatibleMethodOverride]
622 self,
623 filename: str | Path,
624 cls: type[I] | None = None,
625 ignore_fragment: bool = False,
626 ):
627 filename = strip_fragment(str(filename)) if ignore_fragment else str(filename)
629 return next(
630 (
631 item
632 for item in self._items
633 if (strip_fragment(item.filename) if ignore_fragment else item.filename)
634 == filename
635 and (cls is None or isinstance(item, cls))
636 ),
637 None,
638 )
640 @override
641 def remove( # type: ignore[reportIncompatibleMethodOverride]
642 self,
643 filename: str | Path,
644 ignore_fragment: bool = True,
645 ) -> None:
646 item = self.get(filename, ignore_fragment=ignore_fragment)
647 if item:
648 self.remove_item(item)
650 def remove_all(self, filename: str | Path) -> None:
651 while self.get(filename, ignore_fragment=True):
652 self.remove(filename, ignore_fragment=True)
654 @override
655 def _get_common_dataclass_attrs(
656 self,
657 exclude: Sequence[str] = (),
658 exclude_tag: bool = False,
659 exlcude_soup: bool = False,
660 include_self_as_parent: bool = True,
661 ) -> dict[str, AttributeValue]:
662 child_class = self._child_class()
664 return super()._get_common_dataclass_attrs(
665 exclude=(
666 "filename",
667 *exclude,
668 *child_class._get_attributes().keys(), # type: ignore[reportPrivateUsage]
669 ),
670 exclude_tag=exclude_tag,
671 exlcude_soup=exlcude_soup,
672 include_self_as_parent=include_self_as_parent,
673 )
class ParentProtocol(Protocol):
    """
    What a node's ``parent`` must offer: read access to its items plus
    positional insertion and removal of items.
    """

    @property
    def items(self) -> Sequence[XMLChildProtocol]:
        """The parent's items, in order."""
        ...

    def insert_item(  # type: ignore[reportAny]
        self,
        position: int,
        item: Any,  # type: ignore[reportAny]
    ) -> Any:  # type: ignore[reportAny]
        """Insert ``item`` at ``position``."""
        ...

    def remove_item(self, item: Any) -> None:  # type: ignore[reportAny]
        """Remove ``item``."""
        ...
class RecursiveChildProtocol(XMLChildProtocol, Protocol):
    """Child that can report the depth of the subtree below it."""

    def max_depth(self, base: int = 1) -> int:
        """Depth of the deepest descendant, counting this level as ``base``."""
        ...
693class RecursiveParent[
694 I: RecursiveChildProtocol,
695 S: bs4.BeautifulSoup = bs4.BeautifulSoup,
696](XMLParent[I, S], ABC):
697 def max_depth(self, base: int = 1) -> int:
698 if not self.items:
699 return base
701 return max(item.max_depth(base + 1) for item in self.items)
class RecursiveHrefChildProtocol(
    RecursiveChildProtocol,
    HrefChildProtocol,
    Protocol,
):
    """
    Tree node that references a file, knows its parent and can search,
    enumerate and prune its own subtree.
    """

    def items_referencing(
        self,
        filename: str,
        ignore_fragment: bool = False,
    ) -> Generator[XMLChildProtocol]:
        """Yield the nodes of this subtree that reference ``filename``."""
        ...

    @classmethod
    def _get_attributes(cls) -> dict[str, XMLAttribute]:
        """Mapping of XML-synced field names to their XMLAttribute."""
        ...

    @property
    def parent(self) -> ParentProtocol | None:
        """Container holding this node, if any."""
        ...

    @property
    def items(self) -> Sequence[Self]:
        """Direct child nodes."""
        ...

    @property
    def nodes(self) -> Generator[Self]:
        """Every node of this subtree."""
        ...

    def remove_nodes(self, filename: str, ignore_fragments: bool = True) -> None:
        """Remove the nodes of this subtree that reference ``filename``."""
        ...
727class HrefRoot[
728 I: RecursiveHrefChildProtocol,
729 S: bs4.BeautifulSoup = bs4.BeautifulSoup,
730](
731 RecursiveParent[I, S],
732 ParentOfHref[I, S],
733 ABC,
734):
735 """Root of a tree of HrefElements."""
737 def items_referencing(
738 self,
739 filename: str,
740 ignore_fragment: bool = False,
741 ) -> Generator[Self | I]:
742 for item in self.items:
743 yield from (
744 cast(I, it) for it in item.items_referencing(filename, ignore_fragment)
745 )
747 @property
748 def nodes(self) -> Generator[I | Self]:
749 for item in self.items:
750 yield from item.nodes
752 def remove_nodes(self, filename: Path | str, ignore_fragments: bool = True) -> None:
753 filename = strip_fragment(str(filename)) if ignore_fragments else str(filename)
755 index = 0
756 while index < len(self.items):
757 item = self.items[index]
758 item.remove_nodes(filename, ignore_fragments)
759 item_filename = (
760 strip_fragment(item.filename) if ignore_fragments else item.filename
761 )
763 if item_filename == filename:
764 for child in item.items:
765 __ = self.insert_item(index, child)
766 index += 1
768 self.remove_item(item)
769 index -= 1
771 index += 1
774@dataclass(kw_only=True)
775class HrefRecursiveElement[
776 I: RecursiveHrefChildProtocol,
777 S: bs4.BeautifulSoup = bs4.BeautifulSoup,
778](
779 HrefRoot[I, S],
780 HrefElement[S],
781 ABC,
782):
783 """Node of a tree of HrefElements."""
785 parent: ParentProtocol | None = None
787 @property
788 @override
789 def nodes(self) -> Generator[I | Self]:
790 yield self
791 for item in self.items:
792 yield from item.nodes
794 @override
795 def items_referencing(
796 self,
797 filename: str,
798 ignore_fragment: bool = False,
799 ) -> Generator[Self | I]:
800 my_base, my_fragment = split_fragment(self.filename)
801 base, fragment = split_fragment(filename)
802 if my_base == base and (
803 ignore_fragment or fragment is None or my_fragment == fragment
804 ):
805 yield self
807 yield from super().items_referencing(filename, ignore_fragment)
809 @override
810 def _get_common_dataclass_attrs(
811 self,
812 exclude: Sequence[str] = (),
813 exclude_tag: bool = False,
814 exlcude_soup: bool = False,
815 include_self_as_parent: bool = True,
816 ) -> dict[str, AttributeValue]:
817 return super()._get_common_dataclass_attrs(
818 (
819 *exclude,
820 *self._get_attributes().keys(),
821 ),
822 exclude_tag,
823 exlcude_soup,
824 include_self_as_parent,
825 )
827 def add_item_after_self(self, item: I) -> I:
828 if self.parent is None:
829 raise EPUBError(f"{self} has no parent")
831 if hasattr(item, "parent"):
832 item.parent = self.parent # type: ignore[reportAttributeAccessIssue]
834 try:
835 index = self.parent.items.index(self)
836 except ValueError as error:
837 raise EPUBError(f"{self} not found in parent's items") from error
839 self.parent.insert_item(index + 1, item)
840 return item
842 def add_after_self(self, **kwargs: AttributeValue | None) -> I:
843 return self.add_item_after_self(self.create_child(**kwargs))