Coverage for crateweb/nlp_classification/tests/models_tests.py: 100%

82 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-27 10:34 -0500

1import re 

2from unittest import mock 

3 

4from django.test import TestCase 

5 

6from crate_anon.crateweb.core.constants import ( 

7 DJANGO_DEFAULT_CONNECTION, 

8) 

9from crate_anon.crateweb.nlp_classification.models import ( 

10 Option, 

11 Question, 

12 Task, 

13) 

14from crate_anon.crateweb.nlp_classification.tests.factories import ( 

15 ColumnFactory, 

16 SampleSpecFactory, 

17 SourceRecordFactory, 

18 TableDefinitionFactory, 

19) 

20from crate_anon.nlp_manager.constants import ( 

21 FN_SRCFIELD, 

22 FN_SRCPKFIELD, 

23 FN_SRCPKVAL, 

24 FN_SRCTABLE, 

25) 

26from crate_anon.nlp_manager.regex_parser import ( 

27 FN_CONTENT, 

28 FN_END, 

29 FN_START, 

30) 

31 

32 

class TaskTests(TestCase):
    """Checks the string form of ``Task``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_is_name(self) -> None:
        """``str()`` of a task built with a name yields that name."""
        rendered = str(Task(name="Test"))
        self.assertEqual(rendered, "Test")

39 

40 

class QuestionTests(TestCase):
    """Checks the string form of ``Question``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_is_title(self) -> None:
        """``str()`` of a question built with a title yields that title."""
        rendered = str(Question(title="Test"))
        self.assertEqual(rendered, "Test")

47 

48 

class OptionTests(TestCase):
    """Checks the string form of ``Option``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_is_description(self) -> None:
        """``str()`` of an option yields the description it was given."""
        rendered = str(Option(description="Test"))
        self.assertEqual(rendered, "Test")

55 

56 

class SampleSpecTests(TestCase):
    """Checks the string form of objects built by ``SampleSpecFactory``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_shows_sample_spec_source(self) -> None:
        """``str()`` reports the size, source, seed and search term."""
        note_table = TableDefinitionFactory(
            table_name="note",
            pk_column_name="id",
            db_connection_name="test",
        )
        note_column = ColumnFactory(table_definition=note_table, name="note")
        spec = SampleSpecFactory(
            source_column=note_column, search_term="CRP", size=100, seed=12345
        )

        expected = (
            "100 records from 'test.note.note' "
            "with seed 12345 and search term 'CRP'"
        )
        self.assertEqual(str(spec), expected)

77 

78 

class SourceRecordTests(TestCase):
    """Tests for records built by ``SourceRecordFactory``.

    Row fetches and cached private attributes are replaced with mocks, so
    each test pins one property or the string form in isolation.
    """

    def test_nlp_dict_fetched(self) -> None:
        """``nlp_dict`` returns the fetched row; check the fetch arguments."""
        nlp_table = TableDefinitionFactory(
            table_name="nlp_table", pk_column_name="id"
        )
        # An additional column on the NLP table; the fetch assertion below
        # expects its name after the standard field names.
        ColumnFactory(table_definition=nlp_table, name="extra")
        record = SourceRecordFactory(
            nlp_table_definition=nlp_table, nlp_pk_value="12345"
        )

        # Not realistic
        fake_row = {"fake": "source_record"}
        fetch_mock = mock.Mock(return_value=fake_row)

        with mock.patch.multiple(
            "crate_anon.crateweb.nlp_classification.models.DatabaseConnection",
            fetchone_as_dict=fetch_mock,
        ):
            self.assertEqual(record.nlp_dict, fake_row)

        fetch_mock.assert_called_with(
            [
                FN_SRCFIELD,
                FN_SRCTABLE,
                FN_SRCPKFIELD,
                FN_SRCPKVAL,
                FN_CONTENT,
                FN_START,
                FN_END,
                "extra",
            ],
            "nlp_table",
            where="id = %s",
            params=["12345"],
        )

    def test_source_text_fetched(self) -> None:
        """``source_text`` comes from the row fetched via the source link."""
        pk_value = "12345"

        source_table = TableDefinitionFactory(
            table_name="source_table",
            pk_column_name="source_pk_field",
        )
        source_column = ColumnFactory(
            table_definition=source_table, name="source_field"
        )
        record = SourceRecordFactory(
            source_column=source_column,
            source_pk_value=pk_value,
        )

        fetch_mock = mock.Mock(return_value={"source_field": "source text"})
        connection = mock.Mock(fetchone_as_dict=fetch_mock)

        # Hand the record a fake connection instead of a real database one.
        with mock.patch.object(
            record,
            "get_source_database_connection",
            mock.Mock(return_value=connection),
        ):
            self.assertEqual(record.source_text, "source text")

        fetch_mock.assert_called_with(
            ["source_field"],
            "source_table",
            where="source_pk_field = %s",
            params=[pk_value],
        )

    def test_str_reports_table_definition_info(self) -> None:
        """``str()`` names the connection, table, pk column and pk value."""
        test_pk_value = "12345"

        source_table = TableDefinitionFactory(
            table_name="source_table",
            pk_column_name="id",
            db_connection_name="test",
        )
        record = SourceRecordFactory(
            source_column=ColumnFactory(table_definition=source_table),
            source_pk_value=test_pk_value,
        )

        expected = f"Item test.source_table.id={test_pk_value}"
        self.assertEqual(str(record), expected)

    def test_source_text_before_match(self) -> None:
        """``before`` is the source text up to the recorded start offset."""
        record = SourceRecordFactory()
        text = "before match after"
        start_offset = re.search("match", text).start()

        with mock.patch.multiple(
            record,
            _source_text=text,
            _nlp_dict={FN_START: start_offset},
        ):
            self.assertEqual(record.before, "before ")

    def test_source_text_after_match(self) -> None:
        """``after`` is the source text from the recorded end offset on."""
        record = SourceRecordFactory()
        text = "before match after"
        end_offset = re.search("match", text).end()

        with mock.patch.multiple(
            record,
            _source_text=text,
            _nlp_dict={FN_END: end_offset},
        ):
            self.assertEqual(record.after, " after")

    def test_match_text_from_source_record_content(self) -> None:
        """``match`` is the content value held in the NLP dictionary."""
        record = SourceRecordFactory()

        with mock.patch.object(record, "_nlp_dict", {FN_CONTENT: "match"}):
            self.assertEqual(record.match, "match")

    def test_extra_fields_copied_from_nlp_dict(self) -> None:
        """``extra_nlp_fields`` holds only the named extra columns."""
        record = SourceRecordFactory()
        extra_names = ["value_text", "units"]

        # Not a complete real world example
        nlp_row = {
            FN_CONTENT: "CRP was <13 mg/dl",
            "value_text": "13",
            "units": "mg/dl",
        }

        with mock.patch.multiple(
            record,
            _nlp_dict=nlp_row,
            _extra_nlp_column_names=extra_names,
        ):
            self.assertEqual(
                record.extra_nlp_fields,
                {"value_text": "13", "units": "mg/dl"},
            )