Coverage for src/epublib/xml_element.py: 93%
122 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-18 16:07 -0300
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-18 16:07 -0300
1import operator
2from abc import ABC, abstractmethod
3from dataclasses import dataclass, field, fields
4from datetime import datetime
5from itertools import islice
6from types import UnionType
7from typing import (
8 ClassVar,
9 Self,
10 cast,
11 get_args,
12 get_origin,
13 overload,
14 override,
15)
17import bs4
19from epublib.identifier import EPUBId
# Shared placeholder used as the default value of ``XMLElement.tag``.
# ``XMLElement.__post_init__`` compares against it by identity to detect that
# no tag was supplied and a fresh one must be created.
sentinel_tag = bs4.BeautifulSoup("", "xml").new_tag("sentinel")


# Python-side value types that ``XMLElement`` converts to and from attribute
# strings (see ``XMLElement.value_to_str`` and ``XMLElement.to_value``).
type ValueType = str | datetime | bool | list[str] | EPUBId
27@dataclass(kw_only=True)
28class XMLElement(ABC):
29 """Abstract base class for an XML element."""
31 name: str
32 tag: bs4.Tag = field(default=sentinel_tag)
34 obj_to_tag: ClassVar[dict[str, str]] = {}
35 exclude_from_tag: ClassVar[list[str]] = ["tag"]
37 @property
38 @abstractmethod
39 def tag_name(self) -> str:
40 raise NotImplementedError
42 def __post_init__(self):
43 if self.tag is sentinel_tag:
44 self.tag = self.create_tag(bs4.BeautifulSoup("", "xml"))
46 def value_to_str(self, _attr: str, /, value: ValueType) -> str:
47 if isinstance(value, datetime):
48 return value.isoformat()
50 if isinstance(value, bool):
51 return "yes" if value else "no"
53 if isinstance(value, list):
54 return " ".join(str(el) for el in value)
56 return value
58 @staticmethod
59 def _resolve_type[T: ValueType | UnionType | None](typ: type[T]):
60 origin: type[T] = get_origin(typ) or typ
62 if origin is UnionType:
63 args = cast(tuple[type[T], ...], get_args(typ))
64 origin = cast(
65 type[T],
66 operator.or_(
67 *(
68 cast(
69 type[T],
70 get_origin(arg) or arg, # type: ignore[reportGeneralTypeIssues]
71 )
72 for arg in args
73 )
74 ),
75 )
77 return origin
79 @classmethod
80 def to_value[T: ValueType | UnionType | None](
81 cls,
82 value: str | None,
83 typ: type[T],
84 ) -> T | None:
85 if value is None:
86 return None
88 typ = cls._resolve_type(typ)
90 if issubclass(list, typ):
91 return value.split() # type: ignore[reportReturnType]
93 if issubclass(datetime, typ):
94 return datetime.fromisoformat(value) # type: ignore[reportReturnType]
96 if issubclass(bool, typ):
97 return value != "no" # type: ignore[reportReturnType]
99 if issubclass(EPUBId, typ):
100 return EPUBId(value) # type: ignore[reportReturnType]
102 return str(value) # type: ignore[reportReturnType]
104 @override
105 def __setattr__(self, name: str, value: ValueType | None) -> None:
106 ret = super().__setattr__(name, value)
107 if name != "tag":
108 self.update_tag(name, value)
109 return ret
111 def create_tag(self, soup: bs4.BeautifulSoup, **kwargs: str) -> bs4.Tag:
112 tag = soup.new_tag(self.tag_name)
114 for fld in fields(self):
115 val: ValueType | None = getattr(self, fld.name, None)
116 if val is not None and fld.name not in self.exclude_from_tag:
117 attr = self.obj_to_tag.get(fld.name, fld.name)
118 tag[attr.replace("_", "-")] = self.value_to_str(fld.name, val)
120 for key, val in kwargs.items():
121 tag[key] = val
123 return tag
125 @classmethod
126 def from_tag(cls, tag: bs4.Tag, **kwargs: str) -> Self:
127 return cls(
128 tag=tag,
129 **kwargs, # type: ignore[reportUnknownArgumentType]
130 **{
131 field.name: cls.to_value(
132 tag.attrs.get( # type: ignore[reportArgumentType]
133 cls.obj_to_tag.get(field.name, field.name)
134 .replace("_", "-")
135 .lower()
136 ),
137 field.type, # type: ignore[reportArgumentType]
138 )
139 for field in fields(cls)
140 if field.name not in cls.exclude_from_tag
141 },
142 )
144 def update_tag(self, field: str, value: ValueType | None):
145 if field in self.exclude_from_tag:
146 return
148 attr = self.obj_to_tag.get(field, field).replace("_", "-").lower()
149 if value is None:
150 del self.tag[attr]
151 else:
152 self.tag[attr] = self.value_to_str(field, value)
154 @override
155 def __repr__(self):
156 name_field_name = self.obj_to_tag.get("name", "name")
157 return f"{self.__class__.__name__}({name_field_name}={self.name})"
160class XMLParent[I: XMLElement](ABC):
161 """Abstract base class for an XML element that contains other XML elements."""
163 default_item_type: type[I] = XMLElement # type: ignore[reportAssignmentType]
164 tag_name: str | None = None
166 def __init__(
167 self,
168 tag: bs4.Tag,
169 ) -> None:
170 self.tag: bs4.Tag = tag
171 self._items: list[I] = self.create_items()
173 @abstractmethod
174 def create_items(self) -> list[I]:
175 raise NotImplementedError
177 @overload
178 def get[J: XMLElement](self, name: str, cls: type[J]) -> J | None: ...
180 @overload
181 def get(self, name: str, cls: type[I] | None = None) -> I | None: ...
183 def get(self, name: str, cls: type[I] | None = None):
184 if cls is None:
185 cls = self.default_item_type
187 return next(
188 (
189 item
190 for item in self._items
191 if item.name == name and isinstance(item, cls)
192 ),
193 None,
194 )
196 def __getitem__(self, name: str):
197 value = self.get(name)
198 if value is None:
199 raise KeyError(name)
200 return value
202 def add_item(self, item: I) -> I:
203 self._items.append(item)
204 __ = self.tag.append(item.tag)
206 return item
208 def insert_item(self, position: int, item: I) -> I:
209 self._items.insert(position, item)
210 try:
211 nth_child = cast(
212 bs4.Tag,
213 next(islice(self.tag.find_all(True, recursive=False), position, None)),
214 )
215 __ = nth_child.insert_before(item.tag)
216 except StopIteration:
217 __ = self.tag.append(item.tag)
219 return item
221 def remove_item(self, item: I) -> None:
222 self._items.remove(item)
223 item.tag.decompose()
225 @property
226 def items(self):
227 return tuple(self._items)
229 @override
230 def __repr__(self):
231 return f"{self.__class__.__name__}({len(self.items)} items)"