Coverage for tests/test_fix_nonlocal.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-06-14 07:07 +0200

1import ast 

2import sys 

3 

4from inline_snapshot import snapshot 

5 

6from pysource_codegen._codegen import fix_nonlocal 

7from pysource_codegen._codegen import unparse 

8from pysource_codegen._utils import ast_dump 

9 

10known_errors = snapshot( 

11 [ 

12 "no binding for nonlocal 'x' found", 

13 "name 'x' is parameter and nonlocal", 

14 "name 'x' is used prior to nonlocal declaration", 

15 "name 'x' is assigned to before nonlocal declaration", 

16 "name 'x' is parameter and global", 

17 "name 'x' is assigned to before global declaration", 

18 "name 'x' is used prior to global declaration", 

19 "annotated name 'x' can't be global", 

20 "name 'x' is nonlocal and global", 

21 "annotated name 'x' can't be nonlocal", 

22 "nonlocal binding not allowed for type parameter 'x'", 

23 "annotated name 'name_3' can't be global", 

24 "name 'name_0' is used prior to global declaration", 

25 ] 

26) 

27 

28 

29def check_code(src, snapshot_value): 

30 try: 

31 compile(src, "<string>", "exec") 

32 except SyntaxError as error: 

33 error_str = str(error) 

34 assert error_str.split(" (")[0] in known_errors 

35 else: 

36 assert False, "error expected" 

37 

38 tree = ast.parse(src) 

39 

40 print("original tree:") 

41 print(ast_dump(tree)) 

42 print("original src:") 

43 print(src) 

44 print("error:", str(error_str)) 

45 

46 tree = fix_nonlocal(tree) 

47 new_src = unparse(tree).strip() + "\n" 

48 

49 print() 

50 print("transformed tree:") 

51 print(ast_dump(tree)) 

52 print("transformed src:") 

53 print(new_src) 

54 

55 compile(new_src, "<string>", "exec") 

56 

57 assert new_src == snapshot_value 

58 

59 

60def test_global_0(): 

61 check_code( 

62 """ 

63def a(x): 

64 global x 

65 """, 

66 snapshot( 

67 """\ 

68def a(x): 

69 pass 

70""" 

71 ), 

72 ) 

73 

74 

75def test_global_1(): 

76 check_code( 

77 """ 

78def a(): 

79 x = 0 

80 global x 

81 """, 

82 snapshot( 

83 """\ 

84def a(): 

85 x = 0 

86 pass 

87""" 

88 ), 

89 ) 

90 

91 

92def test_global_2(): 

93 check_code( 

94 """ 

95def a(): 

96 print(x) 

97 global x 

98 """, 

99 snapshot( 

100 """\ 

101def a(): 

102 print(x) 

103 pass 

104""" 

105 ), 

106 ) 

107 

108 

109def test_global_3(): 

110 check_code( 

111 """ 

112def a(): 

113 x:int 

114 global x 

115 """, 

116 snapshot( 

117 """\ 

118def a(): 

119 x: int 

120 pass 

121""" 

122 ), 

123 ) 

124 

125 

126def test_global_4(): 

127 check_code( 

128 """ 

129 

130def a(): 

131 x=5 

132 def b(): 

133 nonlocal x 

134 global x 

135 """, 

136 snapshot( 

137 """\ 

138def a(): 

139 x = 5 

140 

141 def b(): 

142 nonlocal x 

143 pass 

144""" 

145 ), 

146 ) 

147 

148 

149def test_global_5(): 

150 check_code( 

151 """ 

152def name_4(): 

153 global name_3 

154 name_3: int 

155 """, 

156 snapshot( 

157 """\ 

158def name_4(): 

159 global name_3 

160 pass 

161""" 

162 ), 

163 ) 

164 

165 

166def test_nonlocal_0(): 

167 check_code( 

168 """ 

169def b(): 

170 def a(): 

171 nonlocal x 

172 """, 

173 snapshot( 

174 """\ 

175def b(): 

176 

177 def a(): 

178 pass 

179""" 

180 ), 

181 ) 

182 

183 

184def test_nonlocal_1(): 

185 check_code( 

186 """ 

187def b(): 

188 x=0 

189 def a(x): 

190 nonlocal x 

191 """, 

192 snapshot( 

193 """\ 

194def b(): 

195 x = 0 

196 

197 def a(x): 

198 pass 

199""" 

200 ), 

201 ) 

202 

203 

204def test_nonlocal_2(): 

205 check_code( 

206 """ 

207def b(): 

208 x=0 

209 def a(): 

210 print(x) 

211 nonlocal x 

212 """, 

213 snapshot( 

214 """\ 

215def b(): 

216 x = 0 

217 

218 def a(): 

219 print(x) 

220 pass 

221""" 

222 ), 

223 ) 

224 

225 

226def test_nonlocal_3(): 

227 check_code( 

228 """ 

229def b(): 

230 x=0 

231 def a(): 

232 x=5 

233 nonlocal x 

234 """, 

235 snapshot( 

236 """\ 

237def b(): 

238 x = 0 

239 

240 def a(): 

241 x = 5 

242 pass 

243""" 

244 ), 

245 ) 

246 

247 

248def test_nonlocal_4(): 

249 check_code( 

250 """ 

251def b(): 

252 x=0 

253 def a(): 

254 x:int 

255 nonlocal x 

256 """, 

257 snapshot( 

258 """\ 

259def b(): 

260 x = 0 

261 

262 def a(): 

263 x: int 

264 pass 

265""" 

266 ), 

267 ) 

268 

269 

270def test_nonlocal_5(): 

271 if sys.version_info >= (3, 12): 

272 check_code( 

273 """ 

274def b(): 

275 x=0 

276 def a[x:int](): 

277 nonlocal x 

278 """, 

279 snapshot( 

280 """\ 

281def b(): 

282 x = 0 

283 

284 def a[x: int](): 

285 pass 

286""" 

287 ), 

288 )