Coverage for src/epublib/xml_element.py: 93%

122 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-18 16:07 -0300

1import operator 

2from abc import ABC, abstractmethod 

3from dataclasses import dataclass, field, fields 

4from datetime import datetime 

5from itertools import islice 

6from types import UnionType 

7from typing import ( 

8 ClassVar, 

9 Self, 

10 cast, 

11 get_args, 

12 get_origin, 

13 overload, 

14 override, 

15) 

16 

17import bs4 

18 

19from epublib.identifier import EPUBId 

20 

# Placeholder default for XMLElement.tag. XMLElement.__post_init__ compares
# against it by identity to detect that the caller did not supply a tag.
sentinel_tag = bs4.BeautifulSoup("", "xml").new_tag("sentinel")


# Python-side value types that XMLElement converts to attribute strings
# (value_to_str) and parses back from them (to_value).
type ValueType = str | datetime | bool | list[str] | EPUBId

25 

26 

27@dataclass(kw_only=True) 

28class XMLElement(ABC): 

29 """Abstract base class for an XML element.""" 

30 

31 name: str 

32 tag: bs4.Tag = field(default=sentinel_tag) 

33 

34 obj_to_tag: ClassVar[dict[str, str]] = {} 

35 exclude_from_tag: ClassVar[list[str]] = ["tag"] 

36 

37 @property 

38 @abstractmethod 

39 def tag_name(self) -> str: 

40 raise NotImplementedError 

41 

42 def __post_init__(self): 

43 if self.tag is sentinel_tag: 

44 self.tag = self.create_tag(bs4.BeautifulSoup("", "xml")) 

45 

46 def value_to_str(self, _attr: str, /, value: ValueType) -> str: 

47 if isinstance(value, datetime): 

48 return value.isoformat() 

49 

50 if isinstance(value, bool): 

51 return "yes" if value else "no" 

52 

53 if isinstance(value, list): 

54 return " ".join(str(el) for el in value) 

55 

56 return value 

57 

58 @staticmethod 

59 def _resolve_type[T: ValueType | UnionType | None](typ: type[T]): 

60 origin: type[T] = get_origin(typ) or typ 

61 

62 if origin is UnionType: 

63 args = cast(tuple[type[T], ...], get_args(typ)) 

64 origin = cast( 

65 type[T], 

66 operator.or_( 

67 *( 

68 cast( 

69 type[T], 

70 get_origin(arg) or arg, # type: ignore[reportGeneralTypeIssues] 

71 ) 

72 for arg in args 

73 ) 

74 ), 

75 ) 

76 

77 return origin 

78 

79 @classmethod 

80 def to_value[T: ValueType | UnionType | None]( 

81 cls, 

82 value: str | None, 

83 typ: type[T], 

84 ) -> T | None: 

85 if value is None: 

86 return None 

87 

88 typ = cls._resolve_type(typ) 

89 

90 if issubclass(list, typ): 

91 return value.split() # type: ignore[reportReturnType] 

92 

93 if issubclass(datetime, typ): 

94 return datetime.fromisoformat(value) # type: ignore[reportReturnType] 

95 

96 if issubclass(bool, typ): 

97 return value != "no" # type: ignore[reportReturnType] 

98 

99 if issubclass(EPUBId, typ): 

100 return EPUBId(value) # type: ignore[reportReturnType] 

101 

102 return str(value) # type: ignore[reportReturnType] 

103 

104 @override 

105 def __setattr__(self, name: str, value: ValueType | None) -> None: 

106 ret = super().__setattr__(name, value) 

107 if name != "tag": 

108 self.update_tag(name, value) 

109 return ret 

110 

111 def create_tag(self, soup: bs4.BeautifulSoup, **kwargs: str) -> bs4.Tag: 

112 tag = soup.new_tag(self.tag_name) 

113 

114 for fld in fields(self): 

115 val: ValueType | None = getattr(self, fld.name, None) 

116 if val is not None and fld.name not in self.exclude_from_tag: 

117 attr = self.obj_to_tag.get(fld.name, fld.name) 

118 tag[attr.replace("_", "-")] = self.value_to_str(fld.name, val) 

119 

120 for key, val in kwargs.items(): 

121 tag[key] = val 

122 

123 return tag 

124 

125 @classmethod 

126 def from_tag(cls, tag: bs4.Tag, **kwargs: str) -> Self: 

127 return cls( 

128 tag=tag, 

129 **kwargs, # type: ignore[reportUnknownArgumentType] 

130 **{ 

131 field.name: cls.to_value( 

132 tag.attrs.get( # type: ignore[reportArgumentType] 

133 cls.obj_to_tag.get(field.name, field.name) 

134 .replace("_", "-") 

135 .lower() 

136 ), 

137 field.type, # type: ignore[reportArgumentType] 

138 ) 

139 for field in fields(cls) 

140 if field.name not in cls.exclude_from_tag 

141 }, 

142 ) 

143 

144 def update_tag(self, field: str, value: ValueType | None): 

145 if field in self.exclude_from_tag: 

146 return 

147 

148 attr = self.obj_to_tag.get(field, field).replace("_", "-").lower() 

149 if value is None: 

150 del self.tag[attr] 

151 else: 

152 self.tag[attr] = self.value_to_str(field, value) 

153 

154 @override 

155 def __repr__(self): 

156 name_field_name = self.obj_to_tag.get("name", "name") 

157 return f"{self.__class__.__name__}({name_field_name}={self.name})" 

158 

159 

160class XMLParent[I: XMLElement](ABC): 

161 """Abstract base class for an XML element that contains other XML elements.""" 

162 

163 default_item_type: type[I] = XMLElement # type: ignore[reportAssignmentType] 

164 tag_name: str | None = None 

165 

166 def __init__( 

167 self, 

168 tag: bs4.Tag, 

169 ) -> None: 

170 self.tag: bs4.Tag = tag 

171 self._items: list[I] = self.create_items() 

172 

173 @abstractmethod 

174 def create_items(self) -> list[I]: 

175 raise NotImplementedError 

176 

177 @overload 

178 def get[J: XMLElement](self, name: str, cls: type[J]) -> J | None: ... 

179 

180 @overload 

181 def get(self, name: str, cls: type[I] | None = None) -> I | None: ... 

182 

183 def get(self, name: str, cls: type[I] | None = None): 

184 if cls is None: 

185 cls = self.default_item_type 

186 

187 return next( 

188 ( 

189 item 

190 for item in self._items 

191 if item.name == name and isinstance(item, cls) 

192 ), 

193 None, 

194 ) 

195 

196 def __getitem__(self, name: str): 

197 value = self.get(name) 

198 if value is None: 

199 raise KeyError(name) 

200 return value 

201 

202 def add_item(self, item: I) -> I: 

203 self._items.append(item) 

204 __ = self.tag.append(item.tag) 

205 

206 return item 

207 

208 def insert_item(self, position: int, item: I) -> I: 

209 self._items.insert(position, item) 

210 try: 

211 nth_child = cast( 

212 bs4.Tag, 

213 next(islice(self.tag.find_all(True, recursive=False), position, None)), 

214 ) 

215 __ = nth_child.insert_before(item.tag) 

216 except StopIteration: 

217 __ = self.tag.append(item.tag) 

218 

219 return item 

220 

221 def remove_item(self, item: I) -> None: 

222 self._items.remove(item) 

223 item.tag.decompose() 

224 

225 @property 

226 def items(self): 

227 return tuple(self._items) 

228 

229 @override 

230 def __repr__(self): 

231 return f"{self.__class__.__name__}({len(self.items)} items)"