Coverage for tests/nodes/test_core.py: 100%

94 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-02 11:07 +0100

1"""Tests for the `certus.nodes.core` module.""" 

2 

3import math 

4from unittest import mock 

5 

6import hypothesis as hyp 

7import hypothesis.strategies as st 

8import pytest 

9 

10from certus.nodes import core 

11 

12from . import common 

13 

# Strategy producing composites that have no children. ``st.builds(list)``
# creates a fresh empty list for every example; ``st.just([])`` would hand the
# very same list object to every generated composite.
ST_EMPTY_COMPOSITE_NODES = st.builds(core.Composite, children=st.builds(list))

15 

16 

@hyp.given(st.text(), common.ST_LOGPROBS, common.ST_STARTS)
def test_token_init(value, logprob, start):
    """Verify a new token exposes its arguments and has no cached confidence."""
    tok = core.Token(value, logprob, start)

    # Each constructor argument must be readable back under the same name.
    expected = {"value": value, "logprob": logprob, "start": start}
    for attr, want in expected.items():
        assert getattr(tok, attr) == want

    # The confidence cache starts out unset.
    assert tok._confidence is None

26 

27 

@hyp.given(st.text(), common.ST_LOGPROBS, common.ST_STARTS)
def test_token_confidence_one_time(value, logprob, start):
    """
    Verify a token works out its confidence a single time.

    The clamp utility is swapped for a mock that hands back the linear
    probability it receives. Reading the confidence property twice must
    give the same number both times, and the mock must have been called
    exactly once.
    """

    def _passthrough(prob, _low, _high):
        return prob

    with mock.patch.object(core.utils, "clamp", side_effect=_passthrough) as fake_clamp:
        tok = core.Token(value, logprob, start)
        first = tok.confidence
        second = tok.confidence

    assert first == second == math.exp(logprob)

    fake_clamp.assert_called_once_with(first, 0.0, 1.0)

47 

48 

@hyp.given(common.ST_COMPOSITE_NODES)
def test_composite_node_init(composite):
    """Verify a fresh composite holds a list of nodes and unset caches."""
    kids = composite.children
    assert isinstance(kids, list)
    for kid in kids:
        assert isinstance(kid, (core.Token, core.Composite))

    # None of the lazily computed attributes may be populated yet.
    for cache_name in ("_value", "_logprob", "_start", "_confidence", "_leaves"):
        assert getattr(composite, cache_name) is None

61 

62 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_value_one_time(composite, leaves):
    """
    Verify a composite works out its value a single time.

    The leaf gatherer is mocked to return the generated leaves, so we can
    confirm it is invoked exactly once across two reads of the property.
    """
    with mock.patch.object(core, "gather_leaves", return_value=leaves) as gatherer:
        first = composite.value
        second = composite.value

    # The value is the leaf values concatenated in order.
    joined = "".join([token.value for token in leaves])
    assert first == second == joined

    gatherer.assert_called_once_with(composite)

78 

79 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_logprob_one_time(composite, leaves):
    """
    Verify a composite works out its log-probability a single time.

    The leaf gatherer is mocked to return the generated leaves, so we can
    confirm it is invoked exactly once across two reads of the property.
    """
    with mock.patch.object(core, "gather_leaves", return_value=leaves) as gatherer:
        first = composite.logprob
        second = composite.logprob

    assert first <= 0

    # The log-probability is the sum over the leaves.
    total = sum(token.logprob for token in leaves)
    assert first == second == total

    gatherer.assert_called_once_with(composite)

96 

97 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_start_one_time(composite, leaves):
    """
    Verify a composite works out its start a single time.

    The leaf gatherer is mocked to return the generated leaves, so we can
    confirm it is invoked exactly once across two reads of the property.
    """
    with mock.patch.object(core, "gather_leaves", return_value=leaves) as gatherer:
        first = composite.start
        second = composite.start

    assert first >= 0

    # The start is the smallest start among the leaves.
    earliest = min(token.start for token in leaves)
    assert first == second == earliest

    gatherer.assert_called_once_with(composite)

114 

115 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_confidence_one_time(composite, leaves):
    """
    Verify a composite works out its confidence a single time.

    Two collaborators are mocked: the leaf gatherer, which returns the
    generated leaves, and the clamp utility, which hands back the
    probability it is given. Each must be called exactly once across two
    reads of the property.
    """
    with mock.patch.object(core, "gather_leaves", return_value=leaves) as gatherer:
        with mock.patch.object(core.utils, "clamp", side_effect=lambda p, _, __: p) as fake_clamp:
            first = composite.confidence
            second = composite.confidence

    assert 0 <= first <= 1

    # Expected confidence: exponential of the mean leaf log-probability.
    mean_logprob = sum(token.logprob for token in leaves) / len(leaves)
    assert first == second == math.exp(mean_logprob)

    gatherer.assert_called_once_with(composite)
    fake_clamp.assert_called_once_with(first, 0.0, 1.0)

137 

138 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, common.st_token_lists())
def test_composite_node_leaves_one_time(composite, leaves):
    """
    Verify a composite works out its leaves a single time.

    The leaf gatherer is mocked to return the generated leaves, so we can
    confirm it is invoked exactly once across two reads of the property.
    """
    with mock.patch.object(core, "gather_leaves", return_value=leaves) as gatherer:
        first = composite.leaves
        second = composite.leaves

    assert first == second == leaves

    gatherer.assert_called_once_with(composite)

154 

155 

@hyp.given(common.ST_COMPOSITE_NODES)
def test_gather_leaves_composite_node(composite):
    """Verify gathering from a composite yields a list holding every token."""
    gathered = core.gather_leaves(composite)

    assert isinstance(gathered, list)
    for leaf in gathered:
        assert isinstance(leaf, core.Token)

    # Count the tokens independently by walking the tree with an explicit stack.
    token_count = 0
    pending = [composite]
    while pending:
        node = pending.pop()
        if isinstance(node, core.Token):
            token_count += 1
        else:
            pending.extend(node.children)

    assert len(gathered) == token_count

170 

171 

@hyp.given(common.st_tokens())
def test_gather_leaves_solo_token_node(token):
    """Verify gathering from a lone token yields a list holding just it."""
    gathered = core.gather_leaves(token)
    assert gathered == [token]

176 

177 

@hyp.given(st.builds(core.Composite, children=common.st_token_lists()))
def test_gather_leaves_composite_all_father(composite):
    """Verify gathering from an all-father returns exactly its children."""
    gathered = core.gather_leaves(composite)
    assert gathered == composite.children

182 

183 

def test_gather_leaves_raises_for_other_node_type():
    """Verify gathering from an unrecognised node type raises an error."""

    class NotNode:
        """Stand-in object that is neither a token nor a composite."""

    impostor = NotNode()
    expected_message = r"Invalid node type:.*NotNode"

    with pytest.raises(ValueError, match=expected_message):
        core.gather_leaves(impostor)  # type: ignore[reportArgumentType]