Coverage for tests/nodes/test_core.py: 100%
94 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-02 11:07 +0100
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-02 11:07 +0100
1"""Tests for the `certus.nodes.core` module."""
3import math
4from unittest import mock
6import hypothesis as hyp
7import hypothesis.strategies as st
8import pytest
10from certus.nodes import core
12from . import common
# Strategy producing composites built with an empty ``children`` list.
# ``st.builds(list)`` calls ``list()`` for every example, so each composite
# receives its own empty list; ``st.just([])`` would instead reuse a single
# mutable list object across all generated examples.
ST_EMPTY_COMPOSITE_NODES = st.builds(core.Composite, children=st.builds(list))
@hyp.given(st.text(), common.ST_LOGPROBS, common.ST_STARTS)
def test_token_init(value, logprob, start):
    """Verify a new token stores its arguments and has no cached confidence."""
    token = core.Token(value, logprob, start)

    # Each public attribute should simply echo the constructor argument.
    for attribute, expected in (
        ("value", value),
        ("logprob", logprob),
        ("start", start),
    ):
        assert getattr(token, attribute) == expected

    assert token._confidence is None
@hyp.given(st.text(), common.ST_LOGPROBS, common.ST_STARTS)
def test_token_confidence_one_time(value, logprob, start):
    """
    Verify a token's confidence is computed once and then reused.

    The clamp utility is swapped for a mock that hands back the linear
    probability it receives, untouched. Reading the property twice must
    give the same value while the mock records exactly one call.
    """
    with mock.patch.object(
        core.utils, "clamp", side_effect=lambda p, _, __: p
    ) as clamp:
        token = core.Token(value, logprob, start)
        first = token.confidence
        second = token.confidence

    assert first == second
    assert second == math.exp(logprob)

    clamp.assert_called_once_with(first, 0.0, 1.0)
@hyp.given(common.ST_COMPOSITE_NODES)
def test_composite_node_init(composite):
    """Verify a fresh composite holds node children and no cached values."""
    kids = composite.children
    assert isinstance(kids, list)
    for kid in kids:
        assert isinstance(kid, (core.Token, core.Composite))

    # None of the private cache slots may be filled before first access.
    for slot in ("_value", "_logprob", "_start", "_confidence", "_leaves"):
        assert getattr(composite, slot) is None
@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_value_one_time(composite, leaves):
    """
    Verify a composite's value is computed once and then reused.

    The leaf gatherer is mocked to return the generated leaves, which
    lets us check how many times it was called.
    """
    patcher = mock.patch.object(core, "gather_leaves", return_value=leaves)
    with patcher as gather_leaves:
        first = composite.value
        second = composite.value

    assert first == second
    assert second == "".join(token.value for token in leaves)

    gather_leaves.assert_called_once_with(composite)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_logprob_one_time(composite, leaves):
    """
    Verify a composite's log-probability is computed once and then reused.

    The leaf gatherer is mocked to return the generated leaves, which
    lets us check how many times it was called.
    """
    patcher = mock.patch.object(core, "gather_leaves", return_value=leaves)
    with patcher as gather_leaves:
        first = composite.logprob
        second = composite.logprob

    assert first <= 0
    assert first == second
    assert second == sum(token.logprob for token in leaves)

    gather_leaves.assert_called_once_with(composite)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_start_one_time(composite, leaves):
    """
    Verify a composite's start is computed once and then reused.

    The leaf gatherer is mocked to return the generated leaves, which
    lets us check how many times it was called.
    """
    patcher = mock.patch.object(core, "gather_leaves", return_value=leaves)
    with patcher as gather_leaves:
        first = composite.start
        second = composite.start

    assert first >= 0
    assert first == second
    assert second == min(token.start for token in leaves)

    gather_leaves.assert_called_once_with(composite)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_confidence_one_time(composite, leaves):
    """
    Verify a composite's confidence is computed once and then reused.

    Two collaborators are mocked: the leaf gatherer, which returns the
    generated leaves, and the clamp utility, which hands back the
    probability it receives untouched. Each must be called exactly once
    across the two property reads.
    """
    with mock.patch.object(core, "gather_leaves", return_value=leaves) as gather_leaves:
        with mock.patch.object(core.utils, "clamp") as clamp:
            clamp.side_effect = lambda p, _, __: p
            first = composite.confidence
            second = composite.confidence

    assert 0 <= first <= 1
    assert first == second
    # Expected value: exponential of the mean leaf log-probability.
    assert second == math.exp(sum(token.logprob for token in leaves) / len(leaves))

    gather_leaves.assert_called_once_with(composite)
    clamp.assert_called_once_with(first, 0.0, 1.0)
@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_leaves_one_time(composite, leaves):
    """
    Verify a composite's leaves are gathered once and then reused.

    The leaf gatherer is mocked to return the generated leaves, which
    lets us check how many times it was called.
    """
    patcher = mock.patch.object(core, "gather_leaves", return_value=leaves)
    with patcher as gather_leaves:
        first = composite.leaves
        second = composite.leaves

    assert first == second
    assert second == leaves

    gather_leaves.assert_called_once_with(composite)
@hyp.given(common.ST_COMPOSITE_NODES)
def test_gather_leaves_composite_node(composite):
    """Verify gathering from a composite yields a list of all its tokens."""
    leaves = core.gather_leaves(composite)

    assert isinstance(leaves, list)
    for leaf in leaves:
        assert isinstance(leaf, core.Token)

    # Count the tokens independently by walking the tree with an explicit
    # stack: tokens are tallied, anything else has its children queued.
    token_total = 0
    pending = [composite]
    while pending:
        node = pending.pop()
        if isinstance(node, core.Token):
            token_total += 1
        else:
            pending.extend(node.children)

    assert len(leaves) == token_total
@hyp.given(common.st_tokens())
def test_gather_leaves_solo_token_node(token):
    """Verify gathering from a lone token yields a one-item list holding it."""
    gathered = core.gather_leaves(token)
    assert gathered == [token]
@hyp.given(st.builds(core.Composite, children=common.st_token_lists()))
def test_gather_leaves_composite_all_father(composite):
    """Verify a composite whose children are all tokens gathers exactly them."""
    gathered = core.gather_leaves(composite)
    assert gathered == composite.children
def test_gather_leaves_raises_for_other_node_type():
    """Verify gathering from an unrecognised node type raises a ValueError."""

    class NotNode:
        pass

    stranger = NotNode()
    with pytest.raises(ValueError, match=r"Invalid node type:.*NotNode"):
        core.gather_leaves(stranger)  # type: ignore[reportArgumentType]