Coverage for phml\utilities\transform\transform.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-04-12 14:26 -0500

1"""phml.utilities.transform.transform 

2 

3Utility methods that revolve around transforming or manipulating the ast. 

4""" 

5 

6from functools import wraps 

7from typing import Callable 

8 

9from phml.nodes import Element, Literal, Node, Parent 

10from phml.utilities.misc import heading_rank 

11from phml.utilities.travel.travel import walk 

12from phml.utilities.validate.check import Test, check 

13 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "filter_nodes", "remove_nodes", "map_nodes", "find_and_replace",
    "shift_heading", "replace_node", "modify_children",
]

23 

24 

def filter_nodes(
    tree: Parent,
    condition: Test,
    strict: bool = True,
):
    """Take a given tree and keep only the nodes that pass ``condition``.

    The tree is processed depth first and modified in place. When a parent
    node fails the condition it is dropped, but its (already filtered)
    children are moved up into its place.

    Same as remove_nodes but keeps the nodes that match.

    Args:
        tree (Parent): The tree node to filter.
        condition (Test): The condition to apply to each node.
        strict (bool): Forwarded to ``check`` for every node tested.

    Returns:
        Parent: The given tree after being filtered.
    """

    def filter_children(node: Parent) -> Parent:
        kept = []
        for child in node:
            if isinstance(child, Parent):
                # Filter the subtree first so that, if this child is dropped,
                # only its surviving descendants are hoisted into its place.
                child = filter_children(child)
                if check(child, condition, strict=strict):
                    kept.append(child)
                else:
                    kept.extend(child.children or [])
            elif check(child, condition, strict=strict):
                kept.append(child)

        if node.children is not None:
            node[:] = kept
        return node

    # The docstring has always promised the filtered tree back; return it.
    return filter_children(tree)

61 

62 

def remove_nodes(
    tree: Parent,
    condition: Test,
    strict: bool = True,
):
    """Remove, in place, every node below ``tree`` that matches ``condition``.

    A removed parent takes its whole subtree with it. The children that
    survive are then processed recursively.

    Same as filter_nodes except removes nodes that match.

    Args:
        tree (Parent): The parent node to start recursively removing from.
        condition (Test): The condition to apply to each node.
        strict (bool): Forwarded to ``check`` for every node tested.
    """

    def prune(parent):
        # Nodes without a children list have nothing to prune.
        if parent.children is None:
            return

        parent.children = [
            child
            for child in parent
            if not check(child, condition, strict=strict)
        ]

        for survivor in parent:
            if isinstance(survivor, Parent):
                prune(survivor)

    prune(tree)

86 

87 

def map_nodes(tree: Parent, transform: Callable[[Node], Node]):
    """Takes a tree and a callable that returns a node and maps each node.

    The tree is modified in place. Each child is replaced by
    ``transform(child)``. If the result is an Element, its own children are
    then mapped in the same way, top-down.

    Signature for the transform function should be as follows:

    1. Takes a single argument that is the node.
    2. Returns any type of node that is assigned to the original node's slot.

    ```python
    def to_links(node):
        return (
            Element("a", {}, node.parent, children=node.children)
            if node.type == "element"
            else node
        )
    ```

    Args:
        tree (Parent): Tree to transform.
        transform (Callable): The Callable that returns a node that is assigned
            to each node.
    """

    def recursive_map(node: Parent):
        # Use the enumeration position rather than ``node.index(child)``.
        # index() is a linear search per child, which makes this quadratic.
        # If it matches by equality, it can also resolve to an earlier equal
        # sibling, so later duplicates would never be transformed.
        for idx, child in enumerate(node):
            node[idx] = transform(child)
            if isinstance(node[idx], Element):
                recursive_map(node[idx])

    recursive_map(tree)

117 

118 

def replace_node(
    start: Parent,
    condition: Test,
    replacement: Node | list[Node] | None,
    all_nodes: bool = False,
    strict: bool = True,
):
    """Search for a specific node in the tree and replace it with either
    a node or list of nodes. If replacement is None the found node is just removed.

    A list replacement is spliced in at the matched node's position.
    ``start`` itself is never replaced.

    Args:
        start (Parent): The starting point. It is excluded from matching.
        condition (test): Test condition to find the correct node.
        replacement (Node | list[Node] | None): What to replace the node with.
        all_nodes (bool): Replace every match instead of stopping after the first.
        strict (bool): Forwarded to ``check`` for every node tested.
    """

    # Convert iterator to static list to avoid errors while editing tree
    for node in list(walk(start)):
        # Identity, not equality: a descendant that merely compares equal to
        # ``start`` must still be eligible for replacement.
        if node is start or not check(node, condition, strict=strict):
            continue

        parent = node.parent
        if parent is None:
            continue

        # Locate the node by identity. An equality-based ``index`` could
        # resolve to an earlier sibling that compares equal.
        idx = next(
            (i for i, child in enumerate(parent) if child is node),
            None,
        )
        if idx is None:
            # The node is not among its parent's children (e.g. a stale
            # parent pointer), so there is nothing to replace.
            continue

        if replacement is None:
            del parent[idx]
        elif isinstance(replacement, list):
            parent[idx : idx + 1] = replacement
        else:
            parent[idx] = replacement

        if not all_nodes:
            break

151 

152 

def find_and_replace(start: Parent, *replacements: tuple[str, str | Callable]):
    """Takes a node and replaces text in Literal.Text
    nodes with matching replacements.

    First value in each replacement tuple is the regex to match and
    the second value is what to replace it with. Replacements are applied
    in the order given, each one to the output of the previous one.
    A string replacement is inserted literally for every non-overlapping
    match; backslashes in it are not treated as group references.

    NOTE(review): the callable form of the second value (meant to return a
    string or a new node that splits the text element) is accepted but is
    currently ignored. Such replacements are skipped.

    Args:
        start (Parent): Root of the tree whose text literals are edited in place.
        *replacements: ``(regex, replacement)`` pairs.
    """
    from re import sub  # pylint: disable=import-outside-toplevel

    for node in walk(start):
        if not Literal.is_text(node):
            continue
        for replacement in replacements:
            pattern, value = replacement[0], replacement[1]
            if isinstance(value, str):
                # Substitute all matches in a single pass. The previous
                # finditer loop spliced using offsets from the original
                # string, and those went stale once a substitution changed
                # the text length. A function repl keeps ``value`` literal.
                node.content = sub(
                    pattern, lambda _match, text=value: text, node.content
                )

174 

175 

def shift_heading(node: Element, amount: int):
    """Move a heading element up or down by ``amount`` levels.

    The new level is ``heading_rank(node) + amount``, clamped to the range
    1..6. It is written back as the element's tag (``h1`` .. ``h6``).
    NOTE(review): assumes ``node`` is a heading element that
    ``heading_rank`` can rank.
    """

    shifted = heading_rank(node) + amount
    clamped = max(1, min(shifted, 6))
    node.tag = f"h{clamped}"

186 

187 

def modify_children(func: Callable[[Node, int, Parent], Node]):
    """Decorate ``func`` so it is applied to every child of a Parent.

    Calling the returned wrapper with a Parent node invokes ``func`` once per
    child, in order, with these positional arguments:

        child (Node): The current child.
        index (int): Position of that child within the parent.
        parent (Parent): The Parent node the wrapper was called with.

    Whatever ``func`` returns is written back into the parent at that index,
    before the next child is visited. The wrapper itself returns None.
    """

    @wraps(func)
    def inner(start: Parent):
        for position, node in enumerate(start):
            replacement = func(node, position, start)
            start[position] = replacement

    return inner