Coverage for tests/test_node.py: 100%
87 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-13 09:46 +0100
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-13 09:46 +0100
1"""Tests for the `certus.node` module."""
3import math
4from unittest import mock
6import hypothesis as hyp
7import hypothesis.strategies as st
8import pytest
10from certus import node
# Log-probabilities are non-positive. With only an upper bound set,
# Hypothesis excludes NaN by default but may still generate -inf.
ST_LOGPROBS = st.floats(max_value=0)
ST_TOKEN_NODES = st.builds(node.TokenNode, logprob=ST_LOGPROBS)

# Any node tree: either a lone token node or a composite of (possibly
# nested) nodes, each composite holding one to three children.
ST_NODES = st.recursive(
    ST_TOKEN_NODES,
    lambda children: st.builds(
        node.CompositeNode, children=st.lists(children, min_size=1, max_size=3)
    ),
    max_leaves=10,
)

# Composite roots are constructed directly instead of filtering ST_NODES
# for `CompositeNode` instances, so no draws are wasted generating and then
# rejecting bare token nodes.
ST_COMPOSITE_NODES = st.builds(
    node.CompositeNode, children=st.lists(ST_NODES, min_size=1, max_size=3)
)

# A childless composite; used where `gather_leaves` is mocked out.
ST_EMPTY_COMPOSITE_NODES = st.builds(node.CompositeNode, children=st.just([]))
ST_LEAF_LISTS = st.lists(ST_TOKEN_NODES, min_size=1)
@hyp.given(st.text(), ST_LOGPROBS)
def test_token_node_init(value, logprob):
    """Check a new token node keeps its inputs and has no cached confidence."""
    built = node.TokenNode(value, logprob)

    assert built.value == value
    assert built.logprob == logprob
    # Confidence is computed lazily, so nothing is cached at construction.
    assert built._confidence is None
@hyp.given(st.text(), ST_LOGPROBS)
def test_token_node_confidence_one_time(value, logprob):
    """
    Check a token node computes its confidence at most once.

    The clamp helper is patched to hand back its first argument untouched,
    so the confidence equals the raw linear probability. Reading the
    property twice must then trigger exactly one clamp call.
    """

    def _passthrough(prob, _low, _high):
        return prob

    with mock.patch.object(node.utils, "clamp", side_effect=_passthrough) as fake_clamp:
        token = node.TokenNode(value, logprob)
        first = token.confidence
        second = token.confidence

    assert first == second == math.exp(logprob)

    fake_clamp.assert_called_once_with(first, 0.0, 1.0)
@hyp.given(ST_COMPOSITE_NODES)
def test_composite_node_init(composite):
    """Check a composite node keeps its children and starts with empty caches."""
    kids = composite.children
    allowed_types = (node.TokenNode, node.CompositeNode)

    assert isinstance(kids, list)
    for kid in kids:
        assert isinstance(kid, allowed_types)

    # Every derived attribute is lazily computed, so all caches start unset.
    for cache_name in ("_value", "_logprob", "_confidence", "_leaves"):
        assert getattr(composite, cache_name) is None
@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_value_one_time(composite, leaves):
    """
    Check a composite node computes its value at most once.

    `gather_leaves` is replaced by a mock returning the generated tokens,
    so reading the property twice must consult it exactly once.
    """
    patcher = mock.patch.object(node, "gather_leaves", return_value=leaves)
    with patcher as fake_gather:
        first = composite.value
        second = composite.value

    assert first == second == " ".join(token.value for token in leaves)

    fake_gather.assert_called_once_with(composite)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_logprob_one_time(composite, leaves):
    """
    Check a composite node computes its log-probability at most once.

    `gather_leaves` is replaced by a mock returning the generated tokens,
    so reading the property twice must consult it exactly once.
    """
    patcher = mock.patch.object(node, "gather_leaves", return_value=leaves)
    with patcher as fake_gather:
        first = composite.logprob
        second = composite.logprob

    assert first <= 0
    assert first == second == sum(token.logprob for token in leaves)

    fake_gather.assert_called_once_with(composite)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_confidence_one_time(composite, leaves):
    """
    Check a composite node computes its confidence at most once.

    `gather_leaves` is mocked to return the generated tokens, and the
    clamp helper is mocked to pass its first argument straight through.
    Reading the property twice must call each mock exactly once.
    """

    def _passthrough(prob, _low, _high):
        return prob

    gather_patch = mock.patch.object(node, "gather_leaves", return_value=leaves)
    clamp_patch = mock.patch.object(node.utils, "clamp", side_effect=_passthrough)
    with gather_patch as fake_gather, clamp_patch as fake_clamp:
        first = composite.confidence
        second = composite.confidence

    assert 0 <= first <= 1

    # Expected confidence: exp of the mean per-token log-probability.
    mean_logprob = sum(token.logprob for token in leaves) / len(leaves)
    assert first == second == math.exp(mean_logprob)

    fake_gather.assert_called_once_with(composite)
    fake_clamp.assert_called_once_with(first, 0.0, 1.0)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_leaves_one_time(composite, leaves):
    """
    Check a composite node gathers its leaves at most once.

    `gather_leaves` is replaced by a mock returning the generated tokens,
    so reading the property twice must consult it exactly once.
    """
    patcher = mock.patch.object(node, "gather_leaves", return_value=leaves)
    with patcher as fake_gather:
        first, second = composite.leaves, composite.leaves

    assert first == second == leaves

    fake_gather.assert_called_once_with(composite)
@hyp.given(ST_COMPOSITE_NODES)
def test_gather_leaves_composite_node(composite):
    """
    Check gathering from a composite returns exactly its token descendants.

    The expected tokens come from an independent recursive walk of the
    tree. The gathered list is compared against them as a multiset, so the
    check does not depend on the traversal order `gather_leaves` uses. It
    does catch missing, duplicated, or foreign tokens, which a length
    comparison alone would let through.
    """
    leaves = node.gather_leaves(composite)

    def _walk(node_: node.CompositeNode | node.TokenNode) -> list[node.TokenNode]:
        # Reference traversal: a token is its own only leaf; a composite
        # contributes the leaves of each of its children.
        if isinstance(node_, node.TokenNode):
            return [node_]
        return [leaf for child in node_.children for leaf in _walk(child)]

    expected = _walk(composite)

    assert isinstance(leaves, list)
    assert all(isinstance(leaf, node.TokenNode) for leaf in leaves)
    assert len(leaves) == len(expected)

    # Multiset match: every expected token must be present exactly once.
    unmatched = list(leaves)
    for token in expected:
        assert token in unmatched
        unmatched.remove(token)
    assert not unmatched
@hyp.given(ST_TOKEN_NODES)
def test_gather_leaves_solo_token_node(token):
    """Check a lone token node gathers to a one-element list holding itself."""
    gathered = node.gather_leaves(token)
    assert gathered == [token]
@hyp.given(st.builds(node.CompositeNode, children=ST_LEAF_LISTS))
def test_gather_leaves_composite_all_father(composite):
    """Check a composite whose children are all tokens gathers exactly those children."""
    gathered = node.gather_leaves(composite)
    assert gathered == composite.children
def test_gather_leaves_raises_for_other_node_type():
    """
    Check an unknown node type throws an error.

    The error message is expected to name the offending type.
    """

    class NotNode:
        pass

    with pytest.raises(ValueError, match=r"Invalid node type:.*NotNode"):
        # Passing a non-node is deliberate; suppress only pyright's
        # argument-type rule rather than every diagnostic on the line.
        _ = node.gather_leaves(NotNode())  # pyright: ignore[reportArgumentType]