Coverage for src/certus/nodes/core.py: 100%

58 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-04 09:44 +0100

1"""Core node models.""" 

2 

3import dataclasses 

4import math 

5import typing 

6 

7from certus import utils 

8 

9NodeType = typing.Union["Composite", "Token"] 

10 

11 

@dataclasses.dataclass
class Token:
    """
    Data model for a token leaf node.

    Parameters
    ----------
    value : str
        Value of the token in the output.
    logprob : float
        Log-probability of the token.
    start : int
        Position of the first token character in the response.

    Attributes
    ----------
    confidence : float
        Confidence (linear probability) of the token: ``exp(logprob)``
        passed through ``utils.clamp`` with bounds 0.0 and 1.0.
    """

    value: str
    logprob: float
    start: int

    @property
    def confidence(self) -> float:
        """Return the linear probability of the token.

        Derived from the current ``logprob`` on every access instead of
        being cached: the dataclass is mutable, so a cached value would go
        stale if ``logprob`` were reassigned after the first read.
        """
        return utils.clamp(math.exp(self.logprob), 0.0, 1.0)

46 

47 

@dataclasses.dataclass
class Composite:
    """
    Data model for a node made up of other nodes.

    Parameters
    ----------
    children : list of Token or Composite
        Nodes contained within this composite.

    Attributes
    ----------
    leaves : list of Token
        All leaf nodes downstream from the composite, as collected by
        ``gather_leaves``.
    value : str
        Value of the composite. Taken as the concatenation of the
        composite's leaf nodes' values, with no separator inserted.
    logprob : float
        Log-probability of the composite. Taken as the sum of the
        log-probability for each leaf node of the composite.
    start : int
        Position of the first character in the composite. Taken as the
        minimum of the starts for each leaf node in the composite.
    confidence : float
        Confidence of the composite: the geometric mean of the leaf
        probabilities, computed as ``exp(logprob / len(leaves))`` and
        passed through ``utils.clamp`` with bounds 0.0 and 1.0.

    Notes
    -----
    Every derived attribute is computed on first access and then cached on
    the instance. The caches are never invalidated, so ``children`` (and any
    nested composite) should not be modified once a derived attribute has
    been read.
    """

    children: typing.Sequence[NodeType] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Lazily populated caches for the derived attributes below. They are
        # plain instance attributes rather than dataclass fields, so they take
        # no part in the generated ``__init__``, ``__repr__`` or ``__eq__``.
        self._leaves: list[Token] | None = None
        self._value: str | None = None
        self._logprob: float | None = None
        self._start: int | None = None
        self._confidence: float | None = None

    @property
    def value(self) -> str:
        """Return (and cache) the leaf values joined with no separator."""
        if self._value is None:
            self._value = "".join(leaf.value for leaf in self.leaves)

        return self._value

    @property
    def logprob(self) -> float:
        """Return (and cache) the sum of the leaf log-probabilities.

        A composite with no leaves has a log-probability of ``0``.
        """
        if self._logprob is None:
            self._logprob = sum(leaf.logprob for leaf in self.leaves)

        return self._logprob

    @property
    def start(self) -> int:
        """Return (and cache) the earliest leaf start in the composite.

        Raises
        ------
        ValueError
            If the composite has no leaves, so no start position exists.
        """
        if self._start is None:
            leaves = self.leaves
            if not leaves:
                # min() would fail here with an opaque message; say why.
                raise ValueError(
                    "Cannot determine the start of a composite with no leaves."
                )
            self._start = min(leaf.start for leaf in leaves)

        return self._start

    @property
    def confidence(self) -> float:
        """Return (and cache) the confidence of the composite.

        This is the geometric mean of the leaf probabilities, i.e. the
        exponential of the mean leaf log-probability, passed through
        ``utils.clamp`` with bounds 0.0 and 1.0. With no leaves the mean
        log-probability is taken as ``-inf``, giving ``exp(-inf) == 0.0``
        before clamping.
        """
        if self._confidence is None:
            leaves = self.leaves
            mean_logprob = self.logprob / len(leaves) if leaves else float("-inf")
            self._confidence = utils.clamp(math.exp(mean_logprob), 0.0, 1.0)

        return self._confidence

    @property
    def leaves(self) -> list[Token]:
        """Return (and cache) the leaf nodes downstream of this composite.

        The cached list object itself is returned, so mutating it affects any
        derived attribute computed afterwards.
        """
        if self._leaves is None:
            self._leaves = gather_leaves(self)

        return self._leaves

125 

126 

127def gather_leaves(node: Token | Composite) -> list[Token]: 

128 """ 

129 Get the leaf nodes downstream of a node. 

130 

131 Parameters 

132 ---------- 

133 node : Token or Composite 

134 A leaf node or one in which to delve for more. 

135 

136 Returns 

137 ------- 

138 list of Token 

139 Leaf nodes in the composite tree. 

140 """ 

141 if isinstance(node, Composite): 

142 return [leaf for child in node.children for leaf in gather_leaves(child)] 

143 if isinstance(node, Token): 

144 return [node] 

145 

146 raise ValueError(f"Invalid node type: {node}, {node.__class__}")