Coverage for src/incant/_codegen.py: 99%

144 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2023-06-27 02:25 +0200

1import linecache 

2 

3from inspect import Signature, iscoroutinefunction 

4from typing import Any, Callable, Counter, Dict, List, Literal, Optional, Union 

5 

6from attr import define 

7 

8from ._compat import signature 

9 

10 

@define
class ParameterDep:
    """A dependency satisfied by a parameter of the generated function.

    `type` and `default` hold `Signature.empty` when the parameter has no
    annotation or no default value, respectively.
    """

    # Parameter name, used verbatim in the generated signature and calls.
    arg_name: str
    # Annotation for the parameter, or `Signature.empty` for none.
    type: Any
    # Default value for the parameter, or `Signature.empty` for none.
    default: Any = Signature.empty

16 

17 

# How an invocation result is entered as a context manager:
# "sync" emits `with`, "async" emits `async with`.
CtxManagerKind = Literal["sync", "async"]

19 

20 

@define
class Invocation:
    """Produce an invocation (and possibly a local var) in a generated function.

    The invocation calls `factory` with `args` and, unless forced, makes the
    result available to later invocations and to the wrapped function.
    """

    # The callable to invoke.
    factory: Callable
    # Positional arguments, in order: a `ParameterDep` passes a parameter of the
    # generated function; any other callable passes that factory's result.
    args: List[Union[Callable, ParameterDep]]
    # Forced invocations are emitted without binding their result to a name.
    is_forced: bool = False
    # "sync"/"async" to enter the call as a (async) context manager,
    # `None` for a plain call.
    is_ctx_manager: Optional[CtxManagerKind] = None

29 

30 

def compile_invoke(
    fn: Callable,
    fn_args: List[Callable],
    fn_factory_args: List[str],
    outer_args: List[ParameterDep],
    invocations: List[Invocation],
    is_async: bool = False,
) -> Callable:
    """Generate the invocation wrapper for `fn`.

    The generated function takes `outer_args` as its parameters, emits
    `invocations` in list order and finally returns the result of calling
    `fn`. Each invocation is bound to a local variable, entered as a
    (possibly async) context manager, called only for its side effects (when
    forced), or -- if it is not a context manager and its result is consumed
    exactly once -- inlined into the expression that consumes it.

    :param fn: The function being wrapped.
    :param fn_args: Factories for the parameters of `fn` that are not passed
        straight through from `outer_args`, in parameter order.
    :param fn_factory_args: Names of parameters of `fn` that must be taken
        from `fn_args` even when an outer argument has the same name.
    :param outer_args: Arguments that the generated function needs to retain.
    :param invocations: Factory invocations to emit, in order.
    :param is_async: Generate an ``async def`` instead of a ``def``.
    :return: The generated function.
    """
    sig = signature(fn)
    fn_name = (
        f"invoke_{fn.__name__}" if fn.__name__.isidentifier() else "invoke_lambda"
    )
    globs: Dict[str, Any] = {}

    # 1. The generated signature: one parameter per outer argument.
    arg_lines = []
    for dep in outer_args:
        arg_snippet = ""
        if dep.type is not Signature.empty:
            # Some types, like new unions (`int|str`), do not have names; those,
            # and names already bound to a different object, get an alias.
            type_name = getattr(dep.type, "__name__", None)
            if (
                isinstance(type_name, str)
                and type_name.isidentifier()
                and (type_name not in globs or globs[type_name] is dep.type)
            ):
                arg_snippet = f": {type_name}"
                globs[type_name] = dep.type
            else:
                alias = f"_incant_arg_{dep.arg_name}"
                arg_snippet = f": {alias}"
                globs[alias] = dep.type
        if dep.default is not Signature.empty:
            default_name = f"_incant_default_{dep.arg_name}"
            arg_snippet = f"{arg_snippet} = {default_name}"
            globs[default_name] = dep.default
        arg_lines.append(f"{dep.arg_name}{arg_snippet}")
    outer_arg_names = {dep.arg_name for dep in outer_args}

    ret_type = ""
    ret_annotation = sig.return_annotation
    if ret_annotation is None:
        ret_type = " -> None"
    elif ret_annotation is not Signature.empty:
        tn = getattr(ret_annotation, "__name__", None)
        if not (isinstance(tn, str) and tn.isidentifier()) or (
            tn in globs and globs[tn] is not ret_annotation
        ):
            # Nameless annotations (e.g. `int | str`, strings) and names
            # already bound to a different object go under a reserved name,
            # so the generated annotation is the real one.
            tn = "_incant_return_type"
        globs[tn] = ret_annotation
        ret_type = f" -> {tn}"

    header = "async def" if is_async else "def"
    lines = [f"{header} {fn_name}({', '.join(arg_lines)}){ret_type}:"]

    # 2. Decide which invocations can be inlined. An inlined expression is
    # re-evaluated at every place it is referenced, so an invocation is
    # inlineable only if:
    # * it is not a context manager, and
    # * its result is consumed exactly once, counting `fn`'s own arguments and
    #   the arguments of *every* invocation (context managers included).
    uses = Counter(fn_args)
    for invoc in invocations:
        uses.update(a for a in invoc.args if not isinstance(a, ParameterDep))
    inlineable = {factory for factory, count in uses.items() if count == 1}

    local_vars_ix_by_factory = {
        local_var.factory: ix for ix, local_var in enumerate(invocations)
    }
    inline_exprs_by_factory: Dict[Callable, str] = {}

    def result_of(factory: Callable) -> str:
        """Return the source expression that yields `factory`'s result."""
        if factory in inline_exprs_by_factory:
            return inline_exprs_by_factory[factory]
        # Stored results are named after their invocation's index.
        return f"_incant_local_{local_vars_ix_by_factory[factory]}"

    # 3. The body: one statement (or `with` header) per non-inlined invocation.
    indent = "    "
    for i, invoc in enumerate(invocations):
        factory_name = getattr(invoc.factory, "__name__", None)
        if (
            isinstance(factory_name, str)
            and factory_name.isidentifier()
            and factory_name not in outer_arg_names
            and factory_name not in globs
            and factory_name != fn_name
        ):
            global_fn_name = factory_name
        else:
            global_fn_name = f"_incant_local_factory_{i}"
        globs[global_fn_name] = invoc.factory

        call_args = [
            a.arg_name if isinstance(a, ParameterDep) else result_of(a)
            for a in invoc.args
        ]
        call = f"{global_fn_name}({', '.join(call_args)})"

        if invoc.factory in inlineable and not invoc.is_ctx_manager:
            aw = "await " if iscoroutinefunction(invoc.factory) else ""
            inline_exprs_by_factory[invoc.factory] = f"{aw}{call}"
            continue

        # Named by invocation index (not by a running counter) so it always
        # matches `result_of`, even when earlier invocations were inlined or
        # forced and therefore bound no name.
        local_name = f"_incant_local_{i}"
        if invoc.is_ctx_manager is not None:
            aw = "async " if invoc.is_ctx_manager == "async" else ""
            target = "" if invoc.is_forced else f" as {local_name}"
            lines.append(f"{indent}{aw}with {call}{target}:")
            # Everything emitted afterwards runs inside this context manager.
            indent += "    "
        else:
            aw = "await " if iscoroutinefunction(invoc.factory) else ""
            target = "" if invoc.is_forced else f"{local_name} = "
            lines.append(f"{indent}{target}{aw}{call}")

    # 4. The final call to `fn`.
    fn_call_args = []
    next_factory = 0
    for name in sig.parameters:
        if name not in fn_factory_args and name in outer_arg_names:
            fn_call_args.append(name)
        else:
            # Every other parameter consumes the next factory from `fn_args`.
            fn_call_args.append(result_of(fn_args[next_factory]))
            next_factory += 1

    orig_name = fn.__name__
    if (
        orig_name.isidentifier()
        and orig_name not in globs
        and orig_name not in outer_arg_names
    ):
        inner_name = orig_name
    else:
        inner_name = "_incant_inner_fn"
    globs[inner_name] = fn
    aw = "await " if iscoroutinefunction(fn) else ""
    lines.append(f"{indent}return {aw}{inner_name}({', '.join(fn_call_args)})")

    script = "\n".join(lines)
    fname = _generate_unique_filename(fn.__name__, "invoke", lines)
    eval(compile(script, fname, "exec"), globs)
    return globs[fn_name]

207 

208 

def compile_incant_wrapper(
    fn: Callable, incant_plan: List[Union[int, str]], num_pos_args: int, num_kwargs: int
):
    """Generate a function that forwards its arguments to `fn` per `incant_plan`.

    Each entry of `incant_plan` yields one argument of the call to `fn`, in
    order: an ``int`` selects that index of the wrapper's ``*args``; anything
    else is emitted as the name of a wrapper parameter and forwarded.
    ``*args`` is accepted only when `num_pos_args` is non-zero, and
    ``**kwargs`` only when `num_kwargs` exceeds the number of named entries in
    the plan, so that surplus keyword arguments are absorbed.

    :return: The generated wrapper function.
    """
    wrapper_name = (
        "incant_lambda" if fn.__name__ == "<lambda>" else f"incant_{fn.__name__}"
    )
    named = [entry for entry in incant_plan if isinstance(entry, str)]

    params = ["*args"] if num_pos_args else []
    params += named
    if len(named) < num_kwargs:
        params.append("**kwargs")

    source = [
        f"def {wrapper_name}({', '.join(params)}):",
        " return _incant_inner_fn(",
    ]
    source.extend(
        f" args[{entry}]," if isinstance(entry, int) else f" {entry},"
        for entry in incant_plan
    )
    source.append(" )")

    namespace = {"_incant_inner_fn": fn}
    filename = _generate_unique_filename(fn.__name__, "incant", source)
    eval(compile("\n".join(source), filename, "exec"), namespace)
    return namespace[wrapper_name]

239 

240 

def _generate_unique_filename(func_name: str, func_type: str, source: List[str]) -> str:
    """
    Create a "filename" suitable for a function being generated.

    The filename is registered in `linecache.cache` together with `source`, so
    tools that read source through linecache (tracebacks, `inspect.getsource`)
    can show the generated code. If the name is already registered with
    different source, a ``-2``, ``-3``, ... suffix is tried until a free slot
    (or one holding identical source) is found.

    :param func_name: Name of the function the code is generated for.
    :param func_type: Kind of generated function (e.g. "invoke", "incant").
    :param source: The generated source lines, with or without line endings.
    :return: The registered filename.
    """
    # linecache entries hold lines *with* their line endings; consumers such
    # as `inspect.getsource` join the cached lines verbatim.
    cached_lines = [line if line.endswith("\n") else f"{line}\n" for line in source]
    size = sum(len(line) for line in cached_lines)
    extra = ""
    count = 1

    while True:
        unique_filename = f"<incant generated {func_type} of {func_name}{extra}>"
        # `setdefault` checks for and reserves the slot in a single call, so
        # concurrent callers do not overwrite each other's entries. The mtime
        # of `None` keeps `linecache.checkcache` from evicting the entry.
        cache_line = (size, None, cached_lines, unique_filename)
        if linecache.cache.setdefault(unique_filename, cache_line) == cache_line:
            return unique_filename

        # Looks like this spot is taken. Try again.
        count += 1
        extra = f"-{count}"