Coverage for src/castep_linter/fortran/fortran_node.py: 89%

75 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-23 18:07 +0000

1"""Module containing useful classes for parsing a fortran source tree from tree-sitter""" 

2from enum import Enum 

3from typing import List, Optional, Tuple 

4 

5from tree_sitter import Node 

6 

7 

class WrongNodeError(Exception):
    """Raised when a node of the wrong type is given to a function that needs a specific node type."""

10 

11 

class Fortran(Enum):
    """Tree-sitter Fortran node types that the linter refers to by name.

    Every member's value is the literal ``type`` string that tree-sitter
    reports for that kind of node, so a node can be matched by comparing its
    ``type`` with ``member.value``.
    """

    COMMENT = "comment"
    SUBROUTINE = "subroutine"
    SUBROUTINE_STMT = "subroutine_statement"
    FUNCTION = "function"
    FUNCTION_STMT = "function_statement"
    NAME = "name"
    SIZE = "size"
    INTRINSIC_TYPE = "intrinsic_type"
    ASSIGNMENT_STMT = "assignment_statement"
    ARGUMENT_LIST = "argument_list"
    SUBROUTINE_CALL = "subroutine_call"
    IDENTIFIER = "identifier"
    VARIABLE_DECLARATION = "variable_declaration"
    RELATIONAL_EXPR = "relational_expression"
    IF_STMT = "if_statement"
    PAREN_EXPRESSION = "parenthesized_expression"
    KEYWORD_ARGUMENT = "keyword_argument"
    STRING_LITERAL = "string_literal"
    NUMBER_LITERAL = "number_literal"
    TYPE_QUALIFIER = "type_qualifier"
    CALL_EXPRESSION = "call_expression"

    # Fallback member for node types not listed above.
    UNKNOWN = "unknown"

38 

39 

# Reverse index: tree-sitter node type string -> matching Fortran member.
FortranLookup = {member.value: member for member in Fortran}

41 

42 

class FortranNode:
    """Wrapper for a tree-sitter ``Node`` that adds Fortran-specific helpers.

    Attributes:
        node: The wrapped tree-sitter node.
        type: The tree-sitter type string for named nodes, or ``None`` for
            anonymous (unnamed) nodes.
    """

    def __init__(self, node: "Node"):
        self.node = node
        # Anonymous nodes get no type, so they never match a Fortran member.
        self.type: Optional[str] = node.type if node.is_named else None

    @property
    def ftype(self) -> "Fortran":
        """Return the node type as a member of the ``Fortran`` enum.

        Anonymous nodes and types not listed in ``Fortran`` map to
        ``Fortran.UNKNOWN``.
        """
        return FortranLookup.get(self.type, Fortran.UNKNOWN)

    def is_type(self, ftype: "Fortran") -> bool:
        """Return True if this node's type is *ftype*."""
        return self.ftype == ftype

    @property
    def children(self) -> List["FortranNode"]:
        """Return every child (named and anonymous) of this node, wrapped."""
        return [FortranNode(c) for c in self.node.children]

    def next_named_sibling(self) -> Optional["FortranNode"]:
        """Return the next named sibling of this node, or ``None`` if there is none."""
        sibling = self.node.next_named_sibling
        return FortranNode(sibling) if sibling else None

    def get(self, ftype: "Fortran") -> "FortranNode":
        """Return the first named child whose type is *ftype*.

        Raises:
            KeyError: if no named child has that type.
        """
        for child in self.node.named_children:
            if child.type == ftype.value:
                return FortranNode(child)

        err = f'"{ftype}" not found in children of node {self.raw}'
        raise KeyError(err)

    def get_children_by_name(self, ftype: "Fortran") -> List["FortranNode"]:
        """Return all named children whose type is *ftype* (matched by node type)."""
        return [FortranNode(c) for c in self.node.named_children if c.type == ftype.value]

    def split(self) -> Tuple["FortranNode", "FortranNode"]:
        """Return the ``left`` and ``right`` field children of a binary node.

        Intended for relational expressions and similar two-sided nodes.

        Raises:
            KeyError: if the node has no ``left`` or no ``right`` field.
        """
        left = self.node.child_by_field_name("left")
        if left is None:
            err = f"Unable to find left part of node pair: {self.raw}"
            raise KeyError(err)

        right = self.node.child_by_field_name("right")
        if right is None:
            err = f"Unable to find right part of node pair: {self.raw}"
            raise KeyError(err)

        return FortranNode(left), FortranNode(right)

    @property
    def raw(self) -> str:
        """Return the source text covered by this node as a ``str``.

        Bytes that are not valid UTF-8 are replaced with U+FFFD rather than
        raising ``UnicodeDecodeError``, so one badly-encoded character in the
        source cannot abort processing.
        """
        return self.node.text.decode("utf-8", errors="replace")

    def parse_string_literal(self) -> str:
        """Return the character value of a Fortran string literal node.

        Exactly one pair of enclosing delimiters is removed. Doubled
        delimiters, which is how Fortran writes a quote inside a literal
        (``'it''s'``), are collapsed to one. Quotes of the other kind are kept
        as-is, so ``"'x'"`` yields ``'x'``.

        Raises:
            WrongNodeError: if this node is not a ``string_literal``.
        """
        if self.type != "string_literal":
            err = f"Tried to parse {self.raw} as string literal"
            raise WrongNodeError(err)

        text = self.raw
        delim = text[:1]
        if delim in ("'", '"') and len(text) >= 2 and text.endswith(delim):
            return text[1:-1].replace(delim * 2, delim)

        # Shape not handled above (possibly a kind-prefixed literal -- not
        # verified against the grammar): keep the previous behaviour of
        # stripping all surrounding quote characters.
        return text.strip("\"'")