Coverage for src/certus/node.py: 100%
51 statements
coverage.py v7.10.3, created at 2025-08-13 09:46 +0100
1"""Module for the token node class."""
3import dataclasses
4import math
5import typing
7from . import utils
# Any node that may appear in a tree: a token leaf or a nested composite.
# Forward-referenced by name because both classes are defined below.
NodeType = typing.Union[
    "TokenNode",
    "CompositeNode",
]
@dataclasses.dataclass
class TokenNode:
    """
    Data model for a single token, the leaf of a node tree.

    Parameters
    ----------
    value : str
        Text of the token as it appears in the output.
    logprob : float
        Natural-log probability of the token.

    Attributes
    ----------
    confidence : float
        Linear probability of the token, ``exp(logprob)`` bounded to
        [0, 1] by ``utils.clamp``. Computed on first access and cached, so
        reassigning ``logprob`` afterwards does not update it.
    """

    value: str
    logprob: float

    def __post_init__(self):
        # Cache slot for ``confidence``; filled on first access.
        self._confidence: float | None = None

    @property
    def confidence(self) -> float:
        """Set or return the linear probability of the token."""
        cached = self._confidence
        if cached is not None:
            return cached

        probability = math.exp(self.logprob)
        self._confidence = utils.clamp(probability, 0.0, 1.0)
        return self._confidence
@dataclasses.dataclass
class CompositeNode:
    """
    Data model for a node made up of other nodes.

    Parameters
    ----------
    children : list of TokenNode or CompositeNode
        Nodes contained within this composite.

    Attributes
    ----------
    value : str
        Values of all downstream token (leaf) nodes, joined with single
        spaces.
    logprob : float
        Sum of the log-probabilities of all downstream token (leaf) nodes.
    confidence : float
        Confidence of the composite. Derived as the geometric mean of the
        probabilities of all downstream token (leaf) nodes, i.e. the
        exponential of their mean log-probability, bounded to [0, 1] by
        ``utils.clamp``. A composite with no leaves has a confidence of 1.0.
    leaves : list of TokenNode
        All token (leaf) nodes downstream of this composite.

    Notes
    -----
    Every derived attribute is computed on first access and cached.
    Mutating ``children`` (or any descendant) after that point does not
    refresh the cached values.
    """

    children: list[NodeType]

    def __post_init__(self):
        # Lazily-populated caches for the derived properties below.
        self._value: str | None = None
        self._logprob: float | None = None
        self._confidence: float | None = None
        self._leaves: list[TokenNode] | None = None

    @property
    def value(self) -> str:
        """Set or return the leaf values of the composite joined by spaces."""
        if self._value is None:
            self._value = " ".join(leaf.value for leaf in self.leaves)

        return self._value

    @property
    def logprob(self) -> float:
        """Set or return the sum of the log-probs of the composite."""
        if self._logprob is None:
            self._logprob = sum(leaf.logprob for leaf in self.leaves)

        return self._logprob

    @property
    def confidence(self) -> float:
        """
        Set or return the confidence of the composite.

        This is the exponential of the mean leaf log-probability (the
        geometric mean of the leaf probabilities), bounded to [0, 1].
        """
        if self._confidence is None:
            n_leaves = len(self.leaves)
            # An empty composite has a summed log-prob of 0 (probability 1);
            # use a mean of 0 instead of dividing by zero.
            mean_logprob = self.logprob / n_leaves if n_leaves else 0.0
            self._confidence = utils.clamp(math.exp(mean_logprob), 0.0, 1.0)

        return self._confidence

    @property
    def leaves(self) -> list[TokenNode]:
        """Return the leaf nodes downstream of this composite node."""
        if self._leaves is None:
            self._leaves = gather_leaves(self)

        return self._leaves
104def gather_leaves(node: TokenNode | CompositeNode) -> list[TokenNode]:
105 """
106 Get the leaf nodes downstream of a node.
108 Parameters
109 ----------
110 node : TokenNode or CompositeNode
111 A leaf node or one in which to delve for more.
113 Returns
114 -------
115 list of TokenNode
116 Leaf nodes in the composite tree.
117 """
118 if isinstance(node, CompositeNode):
119 return [leaf for child in node.children for leaf in gather_leaves(child)]
120 if isinstance(node, TokenNode):
121 return [node]
123 raise ValueError(f"Invalid node type: {node}, {node.__class__}")