Coverage for phml\embedded\__init__.py: 99%

183 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-04-12 14:26 -0500

1""" 

2Embedded has all the logic for processing python elements, attributes, and text blocks. 

3""" 

4from __future__ import annotations 

5 

6import ast 

7import re 

8import types 

9from functools import cached_property 

10from html import escape 

11from pathlib import Path 

12from shutil import get_terminal_size 

13from traceback import FrameSummary, extract_tb 

14from typing import Any, Iterator, TypedDict 

15 

16from phml.embedded.built_in import built_in_funcs, built_in_types 

17from phml.helpers import normalize_indent 

18from phml.nodes import Element, Literal 

19 

# Keyword arguments forwarded to `html.escape` when embedded python yields a string.
# `quote=False` leaves quote characters untouched; only `&`, `<` and `>` are escaped.
ESCAPE_OPTIONS = {"quote": False}

# Module-wide caches shared by every embedded import.
# `import x` statements: bound name -> imported module.
__IMPORTS__ = {}
# `from x import y` statements: module name -> {bound name -> imported object}.
__FROM_IMPORTS__ = {}

27 

28 

# PERF: Only allow assignments, methods, imports, and classes?
class EmbeddedTryCatch:
    """Context manager around embedded python execution. Will parse the traceback
    and the content being executed to create a detailed error message. The final
    error message is raised in a custom EmbeddedPythonException.

    Only ordinary errors (``Exception`` subclasses) are wrapped. Interpreter-control
    exceptions such as ``SystemExit``, ``KeyboardInterrupt`` and ``GeneratorExit``
    propagate unchanged.
    """

    def __init__(
        self,
        path: str | Path | None = None,
        content: str | None = None,
        pos: tuple[int, int] | None = None,
    ) -> None:
        """
        Args:
            path: Label for the embedded code; defaults to ``"<python>"``.
            content: The source being executed; used to render the error.
            pos: ``(line, column)`` start of the content; defaults to ``(0, 0)``.
        """
        self._path = str(path or "<python>")
        self._content = content or ""
        self._pos = pos or (0, 0)

    def __enter__(self) -> EmbeddedTryCatch:
        return self

    def __exit__(self, _, exc_val, exc_tb):
        # Wrap only real errors; BaseException-only types (Ctrl-C, interpreter exit,
        # generator close) must not be turned into an ordinary Exception.
        if isinstance(exc_val, Exception):
            raise EmbeddedPythonException(
                self._path,
                self._content,
                self._pos,
                exc_val,
                exc_tb,
            ) from exc_val

58 

59 

60class EmbeddedPythonException(Exception): 

61 def __init__(self, path, content, pos, exc_val, exc_tb) -> None: 

62 self.max_width, _ = get_terminal_size((20, 0)) 

63 self.msg = exc_val.msg if hasattr(exc_val, "msg") else str(exc_val) 

64 if isinstance(exc_val, SyntaxError): 

65 self.l_slice = (exc_val.lineno or 0, exc_val.end_lineno or 0) 

66 self.c_slice = (exc_val.offset or 0, exc_val.end_offset or 0) 

67 else: 

68 fs: FrameSummary = extract_tb(exc_tb)[-1] 

69 self.l_slice = (fs.lineno or 0, fs.end_lineno or 0) 

70 self.c_slice = (fs.colno or 0, fs.end_colno or 0) 

71 

72 self._content = content 

73 self._path = path 

74 self._pos = pos 

75 

76 def format_line(self, line, c_width, leading: str = " "): 

77 return f"{leading.ljust(c_width, ' ')}│{line}" 

78 

79 def generate_exception_lines(self, lines: list[str], width: int): 

80 max_width = self.max_width - width - 3 

81 result = [] 

82 for i, line in enumerate(lines): 

83 if len(line) > max_width: 

84 parts = [ 

85 line[j : j + max_width] for j in range(0, len(line), max_width) 

86 ] 

87 result.append(self.format_line(parts[0], width, str(i + 1))) 

88 for part in parts[1:]: 

89 result.append(self.format_line(part, width)) 

90 else: 

91 result.append(self.format_line(line, width, str(i + 1))) 

92 return result 

93 

94 def __str__(self) -> str: 

95 message = "" 

96 if self._path != "": 

97 pos = ( 

98 self._pos[0] + (self.l_slice[0] or 0), 

99 self.c_slice[0] or self._pos[1], 

100 ) 

101 if pos[0] > self._content.count("\n"): 

102 message = f"{self._path} Failed to execute phml embedded python" 

103 else: 

104 message = f"[{pos[0]+1}:{pos[1]}] {self._path} Failed to execute phml embedded python" 

105 if self._content != "": 

106 lines = self._content.split("\n") 

107 target_lines = lines[self.l_slice[0] - 1 : self.l_slice[1]] 

108 if len(target_lines) > 0: 

109 if self.l_slice[0] == self.l_slice[1]: 

110 target_lines[0] = ( 

111 target_lines[0][: self.c_slice[0]] 

112 + "\x1b[31m" 

113 + target_lines[0][self.c_slice[0] : self.c_slice[1]] 

114 + "\x1b[0m" 

115 + target_lines[0][self.c_slice[1] :] 

116 ) 

117 else: 

118 target_lines[0] = ( 

119 target_lines[0][: self.c_slice[0] + 1] 

120 + "\x1b[31m" 

121 + target_lines[0][self.c_slice[0] + 1 :] 

122 + "\x1b[0m" 

123 ) 

124 for i, line in enumerate(target_lines[1:-1]): 

125 target_lines[i + 1] = "\x1b[31m" + line + "\x1b[0m" 

126 target_lines[-1] = ( 

127 "\x1b[31m" 

128 + target_lines[-1][: self.c_slice[-1] + 1] 

129 + "\x1b[0m" 

130 + target_lines[-1][self.c_slice[-1] + 1 :] 

131 ) 

132 

133 lines = [ 

134 *lines[: self.l_slice[0] - 1], 

135 *target_lines, 

136 *lines[self.l_slice[1] :], 

137 ] 

138 

139 w_fmt = len(f"{len(lines)}") 

140 content = "\n".join( 

141 self.generate_exception_lines(lines, w_fmt), 

142 ) 

143 line_width = self.max_width - w_fmt - 2 

144 

145 exception = f"{self.msg}" 

146 if len(target_lines) > 0: 

147 exception += f" at <{self.l_slice[0]}:{self.c_slice[0]}-{self.l_slice[1]}:{self.c_slice[1]}>" 

148 ls = [ 

149 exception[i : i + line_width] 

150 for i in range(0, len(exception), line_width) 

151 ] 

152 exception_line = self.format_line(ls[0], w_fmt, "#") 

153 for l in ls[1:]: 

154 exception_line += "\n" + self.format_line(l, w_fmt) 

155 

156 message += ( 

157 f"\n{'─'.ljust(w_fmt, '─')}┬─{'─'*(line_width)}\n" 

158 + exception_line 

159 + "\n" 

160 + f"{'═'.ljust(w_fmt, '═')}╪═{'═'*(line_width)}\n" 

161 + f"{content}" 

162 ) 

163 

164 return message 

165 

166 

def parse_import_values(_import: str) -> list[str | tuple[str, str]]:
    """Parse the names imported by a ``from ... import <values>`` statement.

    Args:
        _import (str): The text after ``import``, e.g. ``"a, b as c"``. Optional
            surrounding parentheses (``"(a, b)"``) and a trailing ``# comment`` are
            ignored.

    Returns:
        list[str | tuple[str, str]]: One entry per imported name, in source order.
        Plain names are strings; aliased names are ``(name, alias)`` tuples.
    """
    # `#` cannot appear in an import name, so anything after it is a comment.
    text = _import.split("#", 1)[0].strip()
    if text.startswith("(") and text.endswith(")"):
        text = text[1:-1]

    values: list[str | tuple[str, str]] = []
    # The alias is a single name; it must not run on into the following names.
    for match in re.finditer(r"([^,\s]+)\s+as\s+([^,\s]+)|([^,\s]+)", text):
        if match.group(1) is not None:
            values.append((match.group(1), match.group(2)))
        else:
            values.append(match.group(3))
    return values

175 

176 

class ImportStruct(TypedDict):
    """Typed mapping pairing an import ``key`` with its imported ``values``."""

    # NOTE(review): presumably the module name; not used in the visible code — confirm.
    key: str
    # A single name string or a list of names.
    values: str | list[str]

180 

181 

class Module:
    """Object used to access the global imports. Readonly data."""

    def __init__(self, module: str, *, imports: list[str] | None = None) -> None:
        """Look up ``module``, or names cached from it, in the global import caches.

        The found objects are also written into this module's ``globals()``.

        Args:
            module: Name of an already cached module.
            imports: Names cached from ``from module import ...``. When omitted or
                empty, the module cached from ``import module`` is used instead.

        Raises:
            ValueError: If the module, or any requested name, has not been cached.
        """
        self.objects = imports or []
        if imports is not None and len(imports) > 0:
            if module not in __FROM_IMPORTS__:
                raise ValueError(f"Unkown module {module!r}")
            cached = __FROM_IMPORTS__[module]
            try:
                found = {name: cached[name] for name in imports}
            except KeyError as kerr:
                # Report every missing name, not just the first one the lookup hit.
                missing = [name for name in imports if name not in cached]
                # Point the traceback at the caller of Module(...) instead of here.
                back_frame = kerr.__traceback__.tb_frame.f_back
                back_tb = types.TracebackType(
                    tb_next=None,
                    tb_frame=back_frame,
                    tb_lasti=back_frame.f_lasti,
                    tb_lineno=back_frame.f_lineno,
                )
                raise ValueError(
                    f"{', '.join(missing)!r} {'are' if len(missing) > 1 else 'is'} not found in cached imported module {module!r}",
                ).with_traceback(back_tb)

            globals().update(found)
            self.module = module
        else:
            if module not in __IMPORTS__:
                raise ValueError(f"Unkown module {module!r}")

            globals().update({module: __IMPORTS__[module]})
            self.module = module

    def collect(self) -> Any:
        """Collect the imports and return the single import or a tuple of multiple imports."""
        if len(self.objects) > 0:
            cached = __FROM_IMPORTS__[self.module]
            if len(self.objects) == 1:
                return cached[self.objects[0]]
            return tuple(cached[name] for name in self.objects)
        return __IMPORTS__[self.module]

228 

229 

class EmbeddedImport:
    """Data representation of an import."""

    module: str
    """Package where the import(s) are from."""

    objects: list[str|tuple[str, str]]
    """The imported objects."""

    def __init__(
        self, module: str, values: str | list[str] | None = None
    ) -> None:
        """Record the import and execute it immediately.

        Args:
            module: Module name (``x`` in ``from x import ...`` / ``import x``).
            values: Names imported from ``module``, either as a list or as the raw
                text after ``import`` (parsed by ``parse_import_values``). ``None`` or
                an empty value means a plain ``import module``.
        """
        self.module = module

        if isinstance(values, list):
            self.objects = values
        else:
            self.objects = parse_import_values(values or "")

        # Touch the cached property so the import runs (and fails) at construction.
        self.data

    @staticmethod
    def _bound_name(obj: str | tuple[str, str]) -> str:
        """Name an imported object is bound to: its alias if it has one."""
        return obj if isinstance(obj, str) else obj[1]

    def _format_objects(self) -> str:
        """Render the imported objects the way they appear after ``import``."""
        return ", ".join(
            obj if isinstance(obj, str) else f"{obj[0]} as {obj[1]}"
            for obj in self.objects
        )

    def _parse_from_import(self):
        # The cache is keyed by the names the exec'd import binds (the alias for
        # `a as b`), so both the "already cached?" check and the lookup use bound names.
        names = [self._bound_name(obj) for obj in self.objects]
        cached = __FROM_IMPORTS__.get(self.module, {})

        if any(name not in cached for name in names):
            local_env = {}
            exec_val = compile(str(self), "_embedded_import_", "exec")
            exec(exec_val, {}, local_env)
            __FROM_IMPORTS__.setdefault(self.module, {}).update(local_env)

        cached = __FROM_IMPORTS__[self.module]
        return {name: cached[name] for name in names}

    def _parse_import(self):
        # NOTE(review): `import a.b` binds `a` and `import a as b` binds `b`, so the
        # lookup by self.module below raises KeyError for those forms.
        if self.module not in __IMPORTS__:
            local_env = {}
            exec_val = compile(str(self), "_embedded_import_", "exec")
            exec(exec_val, {}, local_env)
            __IMPORTS__.update(local_env)

        return {self.module: __IMPORTS__[self.module]}

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(name, value)`` pairs: every cached name of the module for a
        ``from`` import, or the single module binding for a plain import."""
        if len(self.objects) > 0:
            if self.module not in __FROM_IMPORTS__:
                raise KeyError(f"{self.module} is not a known exposed module")
            yield from __FROM_IMPORTS__[self.module].items()
        else:
            if self.module not in __IMPORTS__:
                raise KeyError(f"{self.module} is not a known exposed module")
            yield self.module, __IMPORTS__[self.module]

    @cached_property
    def data(self) -> dict[str, Any]:
        """The actual imports stored by a name to value mapping."""
        if len(self.objects) > 0:
            return self._parse_from_import()
        return self._parse_import()

    def __getitem__(self, key: str) -> Any:
        return self.data[key]

    def __repr__(self) -> str:
        if len(self.objects) > 0:
            return f"FROM({self.module}).IMPORT({self._format_objects()})"
        return f"IMPORT({self.module})"

    def __str__(self) -> str:
        if len(self.objects) > 0:
            return f"from {self.module} import {self._format_objects()}"
        return f"import {self.module}"

313 

314 

class Embedded:
    """Logic for parsing and storing locals and imports of dynamic python code."""

    context: dict[str, Any]
    """Variables and locals found in the python code block."""

    imports: list[EmbeddedImport]
    """Imports needed for the python in this scope. Imports are stored in the module globally
    to reduce duplicate imports.
    """

    def __init__(self, content: str | Element, path: str | None = None) -> None:
        """Parse and execute ``content``.

        Args:
            content: Python source, or a python element holding one text node or nothing.
            path: Label used in error messages; defaults to ``"<python>"``.

        Raises:
            ValueError: If an element has more than one child or a non-text child.
        """
        self._path = path or "<python>"
        self._pos = (0, 0)
        if isinstance(content, Element):
            if len(content) > 1 or (
                len(content) == 1 and not Literal.is_text(content[0])
            ):
                # TODO: Custom error
                raise ValueError(
                    "Expected python elements to contain one text node or nothing",
                )
            if content.position is not None:
                start = content.position.start
                self._pos = (start.line, start.column)
            # An empty element is accepted above ("or nothing") and means no code.
            content = content[0].content if len(content) == 1 else ""
        content = normalize_indent(content)
        self.imports = []
        self.context = {}
        if len(content) > 0:
            with EmbeddedTryCatch(path, content, self._pos):
                self.parse_data(content)

    def __add__(self, _o) -> Embedded:
        """Merge ``_o``'s imports and context into this instance (in place) and return it."""
        self.imports.extend(_o.imports)
        self.context.update(_o.context)
        return self

    def __contains__(self, key: str) -> bool:
        return key in self.context

    def __getitem__(self, key: str) -> Any:
        if key in self.context:
            return self.context[key]
        # Fall back to the names bound by this scope's imports.
        for _import in self.imports:
            if key in _import.data:
                return _import.data[key]

        raise KeyError(f"Key is not in Embedded context or imports: {key}")

    def split_contexts(self, content: str) -> tuple[list[str], list[EmbeddedImport]]:
        """Split python source into executable blocks and top-level imports.

        Top-level ``import`` / ``from ... import`` lines become ``EmbeddedImport``
        objects. Each ``def`` / ``class`` (with the decorator lines directly above it
        and its indented body) becomes its own block; runs of other top-level lines
        are grouped into blocks between them.
        """
        # `\b` keeps names such as `default` or `classes` from looking like a definition.
        re_context = re.compile(r"(?:async\s+)?(?:class|def)\b")
        re_import = re.compile(
            r"from (?P<key>.+) import (?P<values>.+)|import (?P<value>.+)",
        )

        imports: list[EmbeddedImport] = []
        blocks: list[str] = []
        current: list[str] = []

        lines = content.split("\n")
        i = 0
        while i < len(lines):
            line = lines[i]
            imp_match = re_import.match(line)
            if imp_match is not None:
                data = imp_match.groupdict()
                imports.append(
                    EmbeddedImport(data["key"] or data["value"], data["values"])
                )
                i += 1
            elif re_context.match(line) is not None:
                # Decorators directly above the definition belong to its block.
                decorators: list[str] = []
                while current and current[-1].startswith("@"):
                    decorators.insert(0, current.pop())
                if current:
                    blocks.append("\n".join(current))
                current = []

                block = [*decorators, line]
                i += 1
                # The body is every following blank or indented line; a blank line
                # inside a body must not end it. `i` is left on the first line after
                # the body so that line is processed next, not skipped.
                while i < len(lines) and (
                    lines[i].strip() == "" or lines[i][:1].isspace()
                ):
                    block.append(lines[i])
                    i += 1
                blocks.append("\n".join(block))
            else:
                current.append(line)
                i += 1

        if current:
            blocks.append("\n".join(current))

        return blocks, imports

    def parse_data(self, content: str):
        """Execute ``content`` block by block and store the resulting names in ``self.context``."""
        blocks, self.imports = self.split_contexts(content)

        local_env = {}
        global_env = {key: value for _import in self.imports for key, value in _import}
        context = {**global_env}

        for block in blocks:
            exec_val = compile(block, self._path, "exec")
            exec(exec_val, global_env, local_env)
            context.update(local_env)
            # update global env with found locals so they can be used inside methods and classes
            global_env.update(local_env)

        self.context = context

417 

418 

def _validate_kwargs(code: ast.Module, kwargs: dict[str, Any]):
    """Give every variable name used in ``code`` a value in ``kwargs``.

    Every ``ast.Name`` found in ``code`` (read or assigned) that is not listed in
    ``built_in_funcs`` / ``built_in_types`` and not already a key of ``kwargs`` is
    added with the value ``None``. ``kwargs`` is modified in place.
    """
    excluded = {*built_in_funcs, *built_in_types}
    for node in ast.walk(code):
        if not isinstance(node, ast.Name):
            continue
        if node.id in excluded:
            continue
        kwargs.setdefault(node.id, None)

432 

433 

def update_ast_node_pos(dest, source):
    """Copy the source-location span of one python ast node onto another.

    ``lineno``, ``end_lineno``, ``col_offset`` and ``end_col_offset`` are read from
    ``source`` and set on ``dest``.
    """
    for field in ("lineno", "end_lineno", "col_offset", "end_col_offset"):
        setattr(dest, field, getattr(source, field))

442 

443 

# Name of the synthetic variable that `exec_embedded` assigns the final value of embedded code to.
RESULT: str = "_phml_embedded_result_"

445 

446 

def _result_assign(value: ast.expr, source: ast.AST) -> ast.Assign:
    """Build ``RESULT = value`` positioned at ``source`` so the tree compiles."""
    target = ast.Name(id=RESULT, ctx=ast.Store())
    assign = ast.Assign(targets=[target], value=value)
    update_ast_node_pos(dest=target, source=source)
    update_ast_node_pos(dest=assign, source=source)
    return assign


def exec_embedded(code: str, _path: str | None = None, **context: Any) -> Any:
    """Execute embedded python and return the extracted value.

    The value comes from, in order of preference:

    * the first top-level ``return`` (statements after it are dropped; a bare
      ``return`` yields ``None``),
    * a final expression statement,
    * a final assignment: ``x = ...``, or ``x: T = ...`` / ``x += ...`` with a
      plain-name target.

    String results are HTML escaped using ``ESCAPE_OPTIONS``.

    Note:
        No local or global variables will be retained from the embedded python code.

    Args:
        code (str): The embedded python code.
        _path (str | None): Label for the code used in error messages.
        **context (Any): The additional context to provide to the embedded python.

    Returns:
        Any: The value of the last assignment or value defined

    Raises:
        EmbeddedPythonException: If the code fails to parse or run, or its last
            statement provides no value.
    """
    from phml.utilities import blank

    context = {
        "blank": blank,
        **context,
    }

    # last line must be an assignment or the value to be used
    with EmbeddedTryCatch(_path, code):
        code = normalize_indent(code)
        tree = ast.parse(code)
        _validate_kwargs(tree, context)

        last = tree.body[-1]
        returns = [node for node in tree.body if isinstance(node, ast.Return)]

        if len(returns) > 0:
            ret = returns[0]
            value = ret.value
            if value is None:
                # A bare `return` has no value node; Assign requires one.
                value = ast.Constant(value=None)
                update_ast_node_pos(dest=value, source=ret)
            idx = tree.body.index(ret)
            tree.body = [*tree.body[:idx], _result_assign(value, ret)]
        elif isinstance(last, ast.Expr):
            tree.body[-1] = _result_assign(last.value, last)
        elif isinstance(last, ast.Assign):
            target = ast.Name(id=RESULT, ctx=ast.Store())
            update_ast_node_pos(dest=target, source=last)
            last.targets.append(target)
        elif (
            isinstance(last, (ast.AnnAssign, ast.AugAssign))
            and isinstance(last.target, ast.Name)
            and last.value is not None
        ):
            # Re-read the assigned name afterwards: `x += 1` -> `RESULT = x`.
            name = ast.Name(id=last.target.id, ctx=ast.Load())
            update_ast_node_pos(dest=name, source=last.target)
            tree.body.append(_result_assign(name, last))

        ccode = compile(tree, "_phml_embedded_", "exec")
        # One namespace for globals and locals: with separate dicts exec behaves like a
        # class body, so functions, lambdas and generator expressions defined in the
        # code could not see names the code assigns (they would get the None
        # placeholders from _validate_kwargs instead).
        namespace: dict[str, Any] = {**context}
        exec(ccode, namespace)

        result = namespace[RESULT]
        if isinstance(result, str):
            return escape(result, **ESCAPE_OPTIONS)
        return result

509 

510 

def exec_embedded_blocks(code: str, _path: str = "", **context: Any) -> str:
    r"""Execute embedded python inside `{{}}` blocks. The resulting values are substituted
    in for the found blocks.

    A ``{{`` preceded by a backslash (``\{{``) is not a block. It is copied to the
    output unchanged, backslash included.

    Note:
        No local or global variables will be retained from the embedded python code.

    Args:
        code (str): Text containing ``{{ ... }}`` python blocks.
        _path (str): Prefix for error labels; block N is reported as
            ``_path + " block #N"``.
        **context (Any): The additional context to provide to the embedded python.

    Returns:
        str: The value of the passed in string with the python blocks replaced.
    """
    # Same pattern for every search, so the `\{{` escape is honored for the first
    # block too.
    block_start = re.compile(r"(?<!\\)\{\{")

    parts: list[str] = []
    count = 0
    match = block_start.search(code)
    while match is not None:
        parts.append(code[: match.start()])
        code = code[match.end() :]

        # Scan to the matching `}}`: the opening `{{` counts as 2, and nested braces
        # move the balance up or down.
        balance = 2
        index = 0
        while balance > 0 and index < len(code):
            if code[index] == "}":
                balance -= 1
            elif code[index] == "{":
                balance += 1
            index += 1

        count += 1
        value = exec_embedded(
            code[: index - 2].strip(),
            _path + f" block #{count}",
            **context,
        )
        parts.append(str(value))
        code = code[index:]
        match = block_start.search(code)

    parts.append(code)
    return "".join(parts)