Coverage for src/castep_linter/fortran/fortran_node.py: 89%
75 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-23 18:07 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-23 18:07 +0000
1"""Module containing useful classes for parsing a fortran source tree from tree-sitter"""
2from enum import Enum
3from typing import List, Optional, Tuple
5from tree_sitter import Node
class WrongNodeError(Exception):
    """Raised when a typed helper receives a node of an unexpected kind."""
class Fortran(Enum):
    """Tree-sitter node types that appear in a parsed Fortran source tree.

    Each member's value is the node-type string emitted by the tree-sitter
    Fortran grammar; ``UNKNOWN`` is a placeholder for anything not listed.
    """

    # Comments
    COMMENT = "comment"
    # Program units and their opening statements
    SUBROUTINE = "subroutine"
    SUBROUTINE_STMT = "subroutine_statement"
    FUNCTION = "function"
    FUNCTION_STMT = "function_statement"
    # Names and intrinsic-related nodes
    NAME = "name"
    SIZE = "size"
    INTRINSIC_TYPE = "intrinsic_type"
    # Statements, calls and expressions
    ASSIGNMENT_STMT = "assignment_statement"
    ARGUMENT_LIST = "argument_list"
    SUBROUTINE_CALL = "subroutine_call"
    IDENTIFIER = "identifier"
    VARIABLE_DECLARATION = "variable_declaration"
    RELATIONAL_EXPR = "relational_expression"
    IF_STMT = "if_statement"
    PAREN_EXPRESSION = "parenthesized_expression"
    KEYWORD_ARGUMENT = "keyword_argument"
    # Literals and qualifiers
    STRING_LITERAL = "string_literal"
    NUMBER_LITERAL = "number_literal"
    TYPE_QUALIFIER = "type_qualifier"
    CALL_EXPRESSION = "call_expression"

    # Fallback for node types not represented above
    UNKNOWN = "unknown"
# Reverse index: tree-sitter node-type string -> corresponding Fortran member.
FortranLookup = {member.value: member for member in Fortran.__members__.values()}
class FortranNode:
    """Wrapper for a tree-sitter ``Node`` that adds Fortran-aware helpers.

    Attributes:
        node: the wrapped tree-sitter node.
        type: the node's grammar type string when the node is named, or
            ``None`` for anonymous (unnamed) nodes.
    """

    def __init__(self, node: Node):
        self.node = node
        # Only named nodes carry a grammar type we care about; anonymous
        # nodes are recorded with a type of ``None``.
        self.type: Optional[str] = node.type if node.is_named else None

    @property
    def ftype(self) -> Fortran:
        """Return the node type as a member of the ``Fortran`` enum.

        Unnamed nodes and node types absent from the enum map to
        ``Fortran.UNKNOWN``.
        """
        if self.type in FortranLookup:
            return FortranLookup[self.type]
        return Fortran.UNKNOWN

    def is_type(self, ftype: Fortran) -> bool:
        """Return True if this node's ``ftype`` equals the supplied member."""
        return self.ftype == ftype

    @property
    def children(self) -> List["FortranNode"]:
        """Return every child (named and anonymous) wrapped as FortranNode."""
        return [FortranNode(c) for c in self.node.children]

    def next_named_sibling(self) -> Optional["FortranNode"]:
        """Return the next named sibling of this node, or None if there is none."""
        sibling = self.node.next_named_sibling
        if sibling is None:
            return None
        return FortranNode(sibling)

    def get(self, ftype: Fortran) -> "FortranNode":
        """Return the first named child whose type matches ``ftype``.

        Raises:
            KeyError: if no named child has the requested type.
        """
        for c in self.node.named_children:
            if c.type == ftype.value:
                return FortranNode(c)

        err = f'"{ftype}" not found in children of node {self.raw}'
        raise KeyError(err)

    def get_children_by_name(self, ftype: Fortran) -> List["FortranNode"]:
        """Return all named children whose type matches ``ftype`` (possibly empty)."""
        return [FortranNode(c) for c in self.node.named_children if c.type == ftype.value]

    def split(self) -> Tuple["FortranNode", "FortranNode"]:
        """Split a binary node into its ``left`` and ``right`` field children.

        Raises:
            KeyError: if either the ``left`` or ``right`` field is missing.
        """
        left = self.node.child_by_field_name("left")
        if left is None:
            err = f"Unable to find left part of node pair: {self.raw}"
            raise KeyError(err)

        right = self.node.child_by_field_name("right")
        if right is None:
            err = f"Unable to find right part of node pair: {self.raw}"
            raise KeyError(err)

        return FortranNode(left), FortranNode(right)

    @property
    def raw(self) -> str:
        """Return the source text spanned by this node, decoded with ``bytes.decode()``."""
        return self.node.text.decode()

    def parse_string_literal(self) -> str:
        """Return the character value of a string-literal node.

        The closing delimiter (``'`` or ``"``) determines the quote character.
        Everything up to and including the first occurrence of that delimiter
        (e.g. an optional kind prefix such as ``c_char_``) and the closing
        delimiter are dropped. Doubled delimiters inside the literal, which are
        Fortran's escape for a literal quote (``'it''s'``), are collapsed to a
        single character.

        If the text does not look like a delimited literal, the previous
        behaviour (stripping quote characters from both ends) is used.

        Raises:
            WrongNodeError: if this node is not a ``string_literal``.
        """
        if not self.is_type(Fortran.STRING_LITERAL):
            err = f"Tried to parse {self.raw} as string literal"
            raise WrongNodeError(err)

        text = self.raw
        delim = text[-1:]
        if delim not in ("'", '"'):
            # Not terminated by a quote: fall back to the legacy behaviour.
            return text.strip("\"'")

        start = text.find(delim)
        if start == len(text) - 1:
            # Only a closing delimiter is present; nothing to pair it with.
            return text.strip("\"'")

        # Remove exactly one matching delimiter pair, then unescape doubled
        # delimiters. The other quote character is left untouched.
        body = text[start + 1 : -1]
        return body.replace(delim * 2, delim)