Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import ast 

2import inspect 

3import textwrap 

4import tokenize 

5import warnings 

6from bisect import bisect_right 

7from typing import Iterable 

8from typing import Iterator 

9from typing import List 

10from typing import Optional 

11from typing import Tuple 

12from typing import Union 

13 

14from _pytest.compat import overload 

15 

16 

class Source:
    """An immutable object holding a source code fragment.

    When using Source(...), the source lines are deindented.
    """

    def __init__(self, obj: object = None) -> None:
        """Build a Source from *obj*.

        *obj* may be:

        * falsy (``None``, ``""``, an empty list/Source, ...): empty source;
        * another Source: its lines are copied;
        * a tuple/list of lines: trailing newlines are stripped from each
          element, then the lines are deindented;
        * a string: split on ``"\\n"``, then deindented;
        * anything else: its code object is located with ``getrawcode`` and
          its source text is fetched with ``inspect.getsource``, then deindented.
        """
        if not obj:
            self.lines = []  # type: List[str]
        elif isinstance(obj, Source):
            # Copy rather than alias: ``lines`` is a plain list, so sharing it
            # would let a mutation through one Source show up in the other.
            self.lines = list(obj.lines)
        elif isinstance(obj, (tuple, list)):
            self.lines = deindent(x.rstrip("\n") for x in obj)
        elif isinstance(obj, str):
            self.lines = deindent(obj.split("\n"))
        else:
            rawcode = getrawcode(obj)
            src = inspect.getsource(rawcode)
            self.lines = deindent(src.split("\n"))

    def __eq__(self, other: object) -> bool:
        """Two Sources are equal when their line lists are equal."""
        if not isinstance(other, Source):
            return NotImplemented
        return self.lines == other.lines

    # Defining __eq__ makes instances unhashable; state it explicitly.
    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str:
        raise NotImplementedError()

    @overload  # noqa: F811
    def __getitem__(self, key: slice) -> "Source":  # noqa: F811
        raise NotImplementedError()

    def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:  # noqa: F811
        """Return a single line for an int key, or a new Source for a slice.

        Raises IndexError for slices with a step other than None or 1.
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            newsource = Source()
            newsource.lines = self.lines[key.start : key.stop]
            return newsource

    def __iter__(self) -> Iterator[str]:
        """Iterate over the source lines."""
        return iter(self.lines)

    def __len__(self) -> int:
        """Return the number of lines."""
        return len(self.lines)

    def strip(self) -> "Source":
        """Return a new Source with leading and trailing blank lines removed.

        A line counts as blank when it is empty or whitespace-only.
        """
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def indent(self, indent: str = " " * 4) -> "Source":
        """Return a copy of this Source with every line prefixed by *indent*."""
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> "Source":
        """Return the Source statement containing *lineno* (counted from 0)."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
        """Return a ``(start, end)`` tuple spanning the minimal statement
        region that contains *lineno* (counted from 0).

        Raises IndexError if *lineno* is outside ``0 <= lineno < len(self)``.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The parsed AST node is not needed here; avoid shadowing the ``ast`` module.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> "Source":
        """Return a new Source with the common leading whitespace removed."""
        newsource = Source()
        newsource.lines[:] = deindent(self.lines)
        return newsource

    def __str__(self) -> str:
        """Join the lines with ``"\\n"`` (no trailing newline is added)."""
        return "\n".join(self.lines)

114 

115 

116# 

117# helper functions 

118# 

119 

120 

def findsource(obj) -> Tuple[Optional[Source], int]:
    """Return the full source that defines *obj*, plus where *obj* starts in it.

    On success the result is ``(source, lineno)``. ``source`` holds every
    line returned by ``inspect.findsource`` with trailing whitespace removed;
    it is NOT deindented. ``lineno`` is the line number
    ``inspect.findsource`` reported. If ``inspect.findsource`` raises
    anything, ``(None, -1)`` is returned instead.
    """
    try:
        all_lines, first_line = inspect.findsource(obj)
    except Exception:
        # Best effort: any failure to locate the source means "not found".
        return None, -1
    found = Source()
    found.lines = [text.rstrip() for text in all_lines]
    return found, first_line

129 

130 

def getrawcode(obj, trycall: bool = True):
    """Return the code object for *obj*, or *obj* itself if none is found.

    Tries, in order:

    * ``obj.__code__`` (functions);
    * ``obj.f_code`` (frames), then ``__code__`` on that result;
    * when *trycall* is true and nothing with ``co_firstlineno`` was found,
      the code of ``obj.__call__`` for callable, non-class objects.
    """
    try:
        return obj.__code__
    except AttributeError:
        pass

    candidate = getattr(obj, "f_code", obj)
    candidate = getattr(candidate, "__code__", candidate)

    # Already a code object (or call-fallback disabled): nothing more to do.
    if not trycall or hasattr(candidate, "co_firstlineno"):
        return candidate
    # Classes are callable too, but their __call__ is the metaclass's.
    if not hasattr(candidate, "__call__") or inspect.isclass(candidate):
        return candidate

    call_code = getrawcode(candidate.__call__, trycall=False)
    if hasattr(call_code, "co_firstlineno"):
        return call_code
    return candidate

144 

145 

def deindent(lines: Iterable[str]) -> List[str]:
    """Strip the whitespace prefix common to all *lines* and return a new list.

    The lines are joined with ``"\\n"``, passed through ``textwrap.dedent``,
    and split again with ``str.splitlines``. As a consequence, a trailing
    empty line is dropped. ``splitlines`` also breaks on line boundaries
    other than ``"\\n"`` (e.g. ``"\\r"``).
    """
    joined = "\n".join(lines)
    dedented = textwrap.dedent(joined)
    return dedented.splitlines()

148 

149 

def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
    """Return ``(start, end)`` of the statement region containing *lineno*.

    All line numbers are 0-based. ``start`` is the closest recorded statement
    start at or before *lineno*. ``end`` is the next recorded start (exclusive),
    or None when *lineno* lies in the last region.

    Recorded starts are:

    * every statement and except handler;
    * every decorator of a class/function definition;
    * the line just before the first statement of a ``finally``/``else`` body.

    NOTE(review): if *lineno* precedes every recorded start, ``insert_index``
    is 0 and ``values[-1]`` wraps around to the last statement. If *node*
    contains no statements at all, this raises IndexError.
    """
    # flatten all statements and except handlers into one lineno-list
    # AST's line numbers start indexing at 1
    values = []  # type: List[int]
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # Since Python 3.8 the lineno of a decorated definition points at
            # the ``def``/``class`` keyword, not at the first decorator. Record
            # the decorator lines too, so that a lineno on a decorator is not
            # attributed to the preceding statement.
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                for decorator in x.decorator_list:
                    values.append(decorator.lineno - 1)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)  # type: Optional[List[ast.stmt]]
                if val:
                    # treat the finally/orelse part as its own statement
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end

170 

171 

def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
    """Return ``(astnode, start, end)`` for the statement containing *lineno*.

    *lineno*, ``start`` and ``end`` are 0-based indexes into ``source.lines``,
    and ``end`` is exclusive.

    If *astnode* is None, ``str(source)`` is parsed with ``ast.parse`` (with
    warnings suppressed), and errors raised by ``ast.parse`` (e.g.
    SyntaxError) propagate. Otherwise the given tree is reused; the parsed or
    reused tree is returned so callers can cache it.

    *assertion* is accepted but not used by this function.
    """
    if astnode is None:
        content = str(source)
        # See #4260:
        # don't produce duplicate warnings when compiling source to find ast
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # we need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself
        block_finder = inspect.BlockFinder()
        # If we start with an indented line, put blockfinder in "started" mode.
        # Slicing with [:1] makes an empty line yield False instead of IndexError.
        block_finder.started = source.lines[start][:1].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(it.__next__):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # Best effort: on any other tokenizer failure keep the AST-based end.
            pass

    # The end might still point to a comment or empty line, correct it.
    # Never strip past ``start``, so the region cannot become inverted here.
    while end > start:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end