# Coverage for src/certus/nodes/core.py: 100% (58 statements)
# Report generated by coverage.py v7.10.6, created at 2025-09-04 09:44 +0100
1"""Core node models."""
3import dataclasses
4import math
5import typing
7from certus import utils
# Any node in the tree: either a branch (Composite) or a leaf (Token).
# The members are written as string forward references because both
# classes are defined further down in this module.
NodeType = typing.Union["Composite", "Token"]
@dataclasses.dataclass
class Token:
    """
    Data model for a token leaf node.

    Parameters
    ----------
    value : str
        Value of the token in the output.
    logprob : float
        Log-probability of the token.
    start : int
        Position of the first token character in the response.

    Attributes
    ----------
    confidence : float
        Confidence (probability) of the token.

    Notes
    -----
    ``confidence`` is computed on first access and cached afterwards, so
    it does not follow a later reassignment of ``logprob``.
    """

    value: str
    logprob: float
    start: int

    def __post_init__(self) -> None:
        # Cache slot for ``confidence``. It is a plain instance attribute,
        # not a dataclass field, so it stays out of the generated
        # ``__init__``, ``__repr__`` and ``__eq__``.
        self._confidence: float | None = None

    @property
    def confidence(self) -> float:
        """Return the linear probability of the token (read-only, cached)."""
        if self._confidence is None:
            try:
                probability = math.exp(self.logprob)
            except OverflowError:
                # ``math.exp`` raises instead of returning infinity once the
                # result no longer fits in a float. Substitute infinity so
                # the clamp below still decides the final value rather than
                # the property crashing on an out-of-range log-probability.
                probability = math.inf
            self._confidence = utils.clamp(probability, 0.0, 1.0)

        return self._confidence
@dataclasses.dataclass
class Composite:
    """
    Data model for a node made up of other nodes.

    Parameters
    ----------
    children : list of Token or Composite
        Nodes contained within this composite.

    Attributes
    ----------
    leaves : list of Token
        All leaf nodes downstream from the composite.
    value : str
        Value of the composite. Taken as the concatenation of the
        composite's leaf nodes' values, with no separator inserted.
    logprob : float
        Log-probability of the composite. Taken as the sum of the
        log-probability for each leaf node of the composite; ``0.0``
        when there are no leaves.
    start : int
        Position of the first character in the composite. Taken as the
        minimum of the starts for each leaf node in the composite.
        Accessing it on a composite without leaves raises ``ValueError``.
    confidence : float
        Confidence of the composite. Derived as the geometric mean of
        the probabilities of all downstream token (leaf) nodes, i.e. the
        exponential of their mean log-probability. A composite without
        leaves uses a mean log-probability of negative infinity.

    Notes
    -----
    Every derived attribute is computed on first access and cached, so
    it does not follow later changes to ``children``.
    """

    children: typing.Sequence[NodeType] = dataclasses.field(default_factory=list)

    def __post_init__(self) -> None:
        # Cache slots for the derived properties. They are plain instance
        # attributes rather than dataclass fields, so only ``children``
        # takes part in the generated ``__init__``, ``__repr__`` and ``__eq__``.
        self._leaves: list[Token] | None = None
        self._value: str | None = None
        self._logprob: float | None = None
        self._start: int | None = None
        self._confidence: float | None = None

    @property
    def value(self) -> str:
        """Return the concatenation of the composite's leaf values."""
        if self._value is None:
            self._value = "".join(leaf.value for leaf in self.leaves)

        return self._value

    @property
    def logprob(self) -> float:
        """Return the sum of the log-probs of the composite's leaves."""
        if self._logprob is None:
            # A float start value keeps the result a float even when there
            # are no leaves (a bare ``sum`` would return the int 0).
            self._logprob = sum((leaf.logprob for leaf in self.leaves), 0.0)

        return self._logprob

    @property
    def start(self) -> int:
        """
        Return the earliest start in the composite.

        Raises
        ------
        ValueError
            If the composite has no leaves (``min`` of an empty sequence).
        """
        if self._start is None:
            self._start = min(leaf.start for leaf in self.leaves)

        return self._start

    @property
    def confidence(self) -> float:
        """Return the confidence of the composite (read-only, cached)."""
        if self._confidence is None:
            leaves = self.leaves
            # Without leaves there is nothing to average; negative infinity
            # maps to a probability of zero under ``math.exp``.
            mean_logprob = self.logprob / len(leaves) if leaves else float("-inf")
            try:
                probability = math.exp(mean_logprob)
            except OverflowError:
                # ``math.exp`` raises instead of returning infinity once the
                # result no longer fits in a float. Substitute infinity so
                # the clamp below still decides the final value.
                probability = math.inf
            self._confidence = utils.clamp(probability, 0.0, 1.0)

        return self._confidence

    @property
    def leaves(self) -> list[Token]:
        """Return the leaf nodes downstream of this composite node."""
        if self._leaves is None:
            self._leaves = gather_leaves(self)

        return self._leaves
def gather_leaves(node: Token | Composite) -> list[Token]:
    """
    Collect every leaf token reachable from a node.

    Parameters
    ----------
    node : Token or Composite
        A leaf node, or a composite whose children are searched
        recursively.

    Returns
    -------
    list of Token
        The leaves in depth-first, left-to-right order. A token yields
        a single-element list containing itself.

    Raises
    ------
    ValueError
        If the node, or any node reached through ``children``, is
        neither a Token nor a Composite.
    """
    # A token is its own (only) leaf.
    if isinstance(node, Token):
        return [node]

    if not isinstance(node, Composite):
        raise ValueError(f"Invalid node type: {node}, {node.__class__}")

    # Flatten each child's leaves in order; children may themselves be
    # composites, hence the recursive call.
    collected = []
    for child in node.children:
        collected.extend(gather_leaves(child))
    return collected