Coverage for src/certus/node.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-13 09:46 +0100

1"""Module for the token node class.""" 

2 

3import dataclasses 

4import math 

5import typing 

6 

7from . import utils 

8 

# Alias for "any node in the tree": either a TokenNode leaf or a CompositeNode.
# The members are given as strings because both classes are defined further
# down in this module.
9NodeType = typing.Union["TokenNode", "CompositeNode"] 

10 

11 

@dataclasses.dataclass
class TokenNode:
    """
    Data model for a token leaf node.

    Parameters
    ----------
    value : str
        Value of the token in the output.
    logprob : float
        Log-probability of the token.

    Attributes
    ----------
    confidence : float
        Confidence (probability) of the token. Read-only; computed from
        ``logprob`` on first access and cached, so a later change to
        ``logprob`` is not reflected.
    """

    value: str
    logprob: float

    def __post_init__(self):
        # Cache slot for ``confidence``; ``None`` means "not computed yet".
        self._confidence: float | None = None

    @property
    def confidence(self) -> float:
        """
        Return the linear probability of the token.

        Returns
        -------
        float
            ``exp(logprob)`` passed through ``utils.clamp`` with the
            bounds 0.0 and 1.0.
        """
        if self._confidence is None:
            try:
                probability = math.exp(self.logprob)
            except OverflowError:
                # math.exp raises instead of returning inf for very large
                # positive inputs. The true value is far above the upper
                # bound, so use the bound itself rather than crashing.
                probability = 1.0

            self._confidence = utils.clamp(probability, 0.0, 1.0)

        return self._confidence

43 

44 

@dataclasses.dataclass
class CompositeNode:
    """
    Data model for a node made up of other nodes.

    Parameters
    ----------
    children : list of TokenNode or CompositeNode
        Nodes contained within this composite.

    Attributes
    ----------
    value : str
        Values of all downstream token (leaf) nodes, joined by single
        spaces.
    logprob : float
        Sum of the log-probabilities of all downstream leaf nodes.
    confidence : float
        Confidence of the composite. Derived as the geometric mean of
        the probabilities of all downstream token (leaf) nodes, i.e.
        the exponential of their mean log-probability.
    leaves : list of TokenNode
        Leaf nodes downstream of this composite.

    Notes
    -----
    All derived attributes are read-only. Each is computed on first
    access and cached, so later changes to ``children`` are not
    reflected.
    """

    children: list[NodeType]

    def __post_init__(self):
        # Cache slots for the derived properties; ``None`` means
        # "not computed yet".
        self._value: str | None = None
        self._logprob: float | None = None
        self._confidence: float | None = None
        self._leaves: list[TokenNode] | None = None

    @property
    def value(self) -> str:
        """Return the space-joined values of the composite's leaves."""
        if self._value is None:
            self._value = " ".join(leaf.value for leaf in self.leaves)

        return self._value

    @property
    def logprob(self) -> float:
        """Return the sum of the log-probs of the composite's leaves."""
        if self._logprob is None:
            self._logprob = sum(leaf.logprob for leaf in self.leaves)

        return self._logprob

    @property
    def confidence(self) -> float:
        """
        Return the confidence of the composite.

        Returns
        -------
        float
            Exponential of the mean leaf log-probability, passed through
            ``utils.clamp`` with the bounds 0.0 and 1.0.

        Raises
        ------
        ZeroDivisionError
            If the composite has no downstream leaf nodes.
        """
        if self._confidence is None:
            mean_logprob = self.logprob / len(self.leaves)
            try:
                probability = math.exp(mean_logprob)
            except OverflowError:
                # math.exp raises instead of returning inf for very large
                # positive inputs. The true value is far above the upper
                # bound, so use the bound itself rather than crashing.
                probability = 1.0

            self._confidence = utils.clamp(probability, 0.0, 1.0)

        return self._confidence

    @property
    def leaves(self) -> list[TokenNode]:
        """Return the leaf nodes downstream of this composite node."""
        if self._leaves is None:
            self._leaves = gather_leaves(self)

        return self._leaves

102 

103 

def gather_leaves(node: TokenNode | CompositeNode) -> list[TokenNode]:
    """
    Collect every token node at or below a node.

    Parameters
    ----------
    node : TokenNode or CompositeNode
        Starting point. A token node yields itself; a composite is
        searched recursively through its children, in order.

    Returns
    -------
    list of TokenNode
        Leaf nodes found, in depth-first, left-to-right order.

    Raises
    ------
    ValueError
        If ``node``, or anything reached through ``children``, is
        neither a TokenNode nor a CompositeNode.
    """
    if isinstance(node, CompositeNode):
        found = []
        for child in node.children:
            # Each child contributes its own leaves, keeping child order.
            found.extend(gather_leaves(child))
        return found

    if not isinstance(node, TokenNode):
        raise ValueError(f"Invalid node type: {node}, {node.__class__}")

    return [node]