Coverage for tests/test_node.py: 100%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-13 09:46 +0100

1"""Tests for the `certus.node` module.""" 

2 

3import math 

4from unittest import mock 

5 

6import hypothesis as hyp 

7import hypothesis.strategies as st 

8import pytest 

9 

10from certus import node 

11 

# Non-positive log-probabilities. Once a bound is given Hypothesis excludes
# NaN by default, but -inf (probability zero) can still be drawn.
ST_LOGPROBS = st.floats(max_value=0)
ST_TOKEN_NODES = st.builds(node.TokenNode, logprob=ST_LOGPROBS)


def _composites_of(children: st.SearchStrategy) -> st.SearchStrategy:
    """Return a strategy for composite nodes with one to three `children` draws."""
    return st.builds(
        node.CompositeNode, children=st.lists(children, min_size=1, max_size=3)
    )


# Wrap the recursive tree strategy in one more composite layer so every draw
# is a `CompositeNode` by construction. Filtering out bare token nodes with
# `.filter` instead throws away draws and can trip Hypothesis's
# `filter_too_much` health check. Each of the (at most three) top-level
# subtrees is still bounded by `max_leaves=10`.
ST_COMPOSITE_NODES = _composites_of(
    st.recursive(ST_TOKEN_NODES, _composites_of, max_leaves=10)
)

# `st.just([])` would hand every generated node the *same* list object, since
# Hypothesis does not copy `just` values. `st.builds(list)` gives each
# composite its own fresh, empty list.
ST_EMPTY_COMPOSITE_NODES = st.builds(node.CompositeNode, children=st.builds(list))
ST_LEAF_LISTS = st.lists(ST_TOKEN_NODES, min_size=1)

25 

26 

@hyp.given(st.text(), ST_LOGPROBS)
def test_token_node_init(value, logprob):
    """Check a fresh token node keeps its inputs and has no cached confidence."""
    tok = node.TokenNode(value, logprob)

    # Constructor arguments are stored unchanged.
    assert tok.value == value
    assert tok.logprob == logprob

    # Confidence is computed lazily, so nothing is cached yet.
    assert tok._confidence is None

35 

36 

@hyp.given(st.text(), ST_LOGPROBS)
def test_token_node_confidence_one_time(value, logprob):
    """
    Check a token node computes its confidence once and then caches it.

    `utils.clamp` is patched to hand back its probability argument untouched,
    so the confidence should be exactly `exp(logprob)`. Reading the property
    twice while expecting a single clamp call shows the result is cached
    after the first access.
    """

    def _identity(probability, _low, _high):
        return probability

    with mock.patch.object(node.utils, "clamp", side_effect=_identity) as clamp:
        tok = node.TokenNode(value, logprob)
        first = tok.confidence
        second = tok.confidence

    assert first == second == math.exp(logprob)

    clamp.assert_called_once_with(first, 0.0, 1.0)

56 

57 

@hyp.given(ST_COMPOSITE_NODES)
def test_composite_node_init(composite):
    """Check a generated composite holds a list of nodes and has empty caches."""
    kids = composite.children
    node_types = (node.TokenNode, node.CompositeNode)

    assert isinstance(kids, list)
    for kid in kids:
        assert isinstance(kid, node_types)

    # Every derived property is computed lazily, so all caches start unset.
    for cache_attr in ("_value", "_logprob", "_confidence", "_leaves"):
        assert getattr(composite, cache_attr) is None

69 

70 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_value_one_time(composite, leaves):
    """
    Check a composite node computes its value once and then caches it.

    `gather_leaves` is patched to return the drawn token nodes, so the value
    should be their values joined by single spaces. Reading the property
    twice must reach the patched gatherer only once.
    """
    expected = " ".join([leaf.value for leaf in leaves])

    with mock.patch.object(node, "gather_leaves", return_value=leaves) as gatherer:
        first, second = composite.value, composite.value

    assert first == second == expected

    gatherer.assert_called_once_with(composite)

86 

87 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_logprob_one_time(composite, leaves):
    """
    Check a composite node computes its log-probability once and caches it.

    `gather_leaves` is patched to return the drawn token nodes, so the
    log-probability should equal the sum of theirs. Reading the property
    twice must reach the patched gatherer only once.
    """
    with mock.patch.object(node, "gather_leaves", return_value=leaves) as gatherer:
        reads = [composite.logprob for _ in range(2)]

    first, second = reads

    # A sum of non-positive log-probabilities can never be positive.
    assert first <= 0
    assert first == second == sum(leaf.logprob for leaf in leaves)

    gatherer.assert_called_once_with(composite)

104 

105 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_confidence_one_time(composite, leaves):
    """
    Check a composite node computes its confidence once and then caches it.

    `gather_leaves` is patched to return the drawn token nodes, and
    `utils.clamp` to pass its probability straight through. The confidence
    should therefore be the exponential of the mean leaf log-probability.
    Across two reads, each patched helper must be called exactly once.
    """
    mean_logprob = sum(leaf.logprob for leaf in leaves) / len(leaves)

    patch_gatherer = mock.patch.object(node, "gather_leaves", return_value=leaves)
    patch_clamp = mock.patch.object(
        node.utils, "clamp", side_effect=lambda p, _, __: p
    )

    with patch_gatherer as gatherer, patch_clamp as clamp:
        first = composite.confidence
        second = composite.confidence

    assert 0 <= first <= 1
    assert first == second == math.exp(mean_logprob)

    gatherer.assert_called_once_with(composite)
    clamp.assert_called_once_with(first, 0.0, 1.0)

127 

128 

@hyp.given(ST_EMPTY_COMPOSITE_NODES, ST_LEAF_LISTS)
def test_composite_node_leaves_one_time(composite, leaves):
    """
    Check a composite node gathers its leaves once and then caches them.

    `gather_leaves` is patched to return the drawn token nodes. Reading the
    `leaves` property twice must return those nodes both times while
    reaching the patched gatherer only once.
    """
    patch_gatherer = mock.patch.object(node, "gather_leaves", return_value=leaves)

    with patch_gatherer as gatherer:
        first = composite.leaves
        second = composite.leaves

    assert first == second == leaves

    gatherer.assert_called_once_with(composite)

144 

145 

@hyp.given(ST_COMPOSITE_NODES)
def test_gather_leaves_composite_node(composite):
    """
    Check gathering from a composite returns every token in the tree, in order.

    The expected leaves come from a depth-first, left-to-right walk of the
    children. This pins both which token nodes come back and their order
    (the order in which a composite's value joins them), not merely how
    many there are.
    """
    leaves = node.gather_leaves(composite)

    def _walk_leaves(
        node_: node.CompositeNode | node.TokenNode,
    ) -> list[node.TokenNode]:
        # A token is its own single leaf; a composite concatenates its
        # children's leaves from left to right.
        if isinstance(node_, node.TokenNode):
            return [node_]

        return [leaf for child in node_.children for leaf in _walk_leaves(child)]

    assert isinstance(leaves, list)
    assert all(isinstance(leaf, node.TokenNode) for leaf in leaves)
    # Comparing whole lists implies equal length, and also catches wrong,
    # duplicated or reordered leaves that a count-only check would miss.
    assert leaves == _walk_leaves(composite)

160 

161 

@hyp.given(ST_TOKEN_NODES)
def test_gather_leaves_solo_token_node(token):
    """Check a lone token node is gathered as its own single leaf."""
    gathered = node.gather_leaves(token)
    assert gathered == [token]

166 

167 

@hyp.given(st.builds(node.CompositeNode, children=ST_LEAF_LISTS))
def test_gather_leaves_composite_all_father(composite):
    """
    Check gathering from an all-father yields exactly the children passed in.

    An "all-father" here is a single composite whose children are all
    token nodes, so its leaves are simply its children.
    """
    gathered = node.gather_leaves(composite)
    assert gathered == composite.children

172 

173 

def test_gather_leaves_raises_for_other_node_type():
    """Check `gather_leaves` raises `ValueError` for an unknown node type."""

    class NotNode:
        """Stand-in object whose type `gather_leaves` does not recognise."""

    bogus = NotNode()

    with pytest.raises(ValueError, match=r"Invalid node type:.*NotNode"):
        node.gather_leaves(bogus)  # type: ignore[reportArgumentType]