Coverage for crateweb/nlp_classification/tests/models_tests.py: 100%
82 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-27 10:34 -0500
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-27 10:34 -0500
1import re
2from unittest import mock
4from django.test import TestCase
6from crate_anon.crateweb.core.constants import (
7 DJANGO_DEFAULT_CONNECTION,
8)
9from crate_anon.crateweb.nlp_classification.models import (
10 Option,
11 Question,
12 Task,
13)
14from crate_anon.crateweb.nlp_classification.tests.factories import (
15 ColumnFactory,
16 SampleSpecFactory,
17 SourceRecordFactory,
18 TableDefinitionFactory,
19)
20from crate_anon.nlp_manager.constants import (
21 FN_SRCFIELD,
22 FN_SRCPKFIELD,
23 FN_SRCPKVAL,
24 FN_SRCTABLE,
25)
26from crate_anon.nlp_manager.regex_parser import (
27 FN_CONTENT,
28 FN_END,
29 FN_START,
30)
class TaskTests(TestCase):
    """Checks the string representation of ``Task``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_is_name(self) -> None:
        """``str()`` of a task equals the name it was constructed with."""
        expected_name = "Test"
        self.assertEqual(str(Task(name=expected_name)), expected_name)
class QuestionTests(TestCase):
    """Checks the string representation of ``Question``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_is_title(self) -> None:
        """``str()`` of a question equals the title it was constructed with."""
        expected_title = "Test"
        self.assertEqual(str(Question(title=expected_title)), expected_title)
class OptionTests(TestCase):
    """Checks the string representation of ``Option``."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_is_description(self) -> None:
        """``str()`` of an option equals its description."""
        expected_description = "Test"
        option = Option(description=expected_description)
        self.assertEqual(str(option), expected_description)
class SampleSpecTests(TestCase):
    """Checks the string representation of a factory-built sample spec."""

    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_str_shows_sample_spec_source(self) -> None:
        """``str()`` mentions the size, source, seed and search term."""
        note_column = ColumnFactory(
            table_definition=TableDefinitionFactory(
                table_name="note",
                pk_column_name="id",
                db_connection_name="test",
            ),
            name="note",
        )
        spec = SampleSpecFactory(
            source_column=note_column,
            search_term="CRP",
            size=100,
            seed=12345,
        )

        expected = (
            "100 records from 'test.note.note' "
            "with seed 12345 and search term 'CRP'"
        )
        self.assertEqual(str(spec), expected)
class SourceRecordTests(TestCase):
    """Tests for the records produced by ``SourceRecordFactory``.

    Database access is replaced with mocks throughout, so these tests only
    check how the record queries its connections and how it derives values
    from the (patched) NLP dictionary and source text.
    """

    # Declared explicitly so that this class agrees with the other test
    # classes in this module.
    databases = {DJANGO_DEFAULT_CONNECTION}

    def test_nlp_dict_fetched(self) -> None:
        """``nlp_dict`` is the row returned by ``fetchone_as_dict``.

        Also checks the column list, table name, WHERE clause and parameters
        passed to the patched ``DatabaseConnection.fetchone_as_dict``.
        """
        nlp_table_definition = TableDefinitionFactory(
            table_name="nlp_table", pk_column_name="id"
        )
        ColumnFactory(table_definition=nlp_table_definition, name="extra")
        source_record = SourceRecordFactory(
            nlp_table_definition=nlp_table_definition, nlp_pk_value="12345"
        )

        # Not realistic
        fake_row = {"fake": "source_record"}
        mock_fetch = mock.Mock(return_value=fake_row)

        with mock.patch.multiple(
            "crate_anon.crateweb.nlp_classification.models.DatabaseConnection",
            fetchone_as_dict=mock_fetch,
        ):
            self.assertEqual(source_record.nlp_dict, fake_row)

        mock_fetch.assert_called_with(
            [
                FN_SRCFIELD,
                FN_SRCTABLE,
                FN_SRCPKFIELD,
                FN_SRCPKVAL,
                FN_CONTENT,
                FN_START,
                FN_END,
                "extra",
            ],
            "nlp_table",
            where="id = %s",
            params=["12345"],
        )

    def test_source_text_fetched(self) -> None:
        """``source_text`` is read from the source column of the fetched row.

        The record's ``get_source_database_connection`` is patched to return
        a mock connection whose ``fetchone_as_dict`` call is then verified.
        """
        test_pk_value = "12345"

        source_table_definition = TableDefinitionFactory(
            table_name="source_table",
            pk_column_name="source_pk_field",
        )
        source_record = SourceRecordFactory(
            source_column=ColumnFactory(
                table_definition=source_table_definition, name="source_field"
            ),
            source_pk_value=test_pk_value,
        )

        mock_fetch_from_source = mock.Mock(
            return_value={"source_field": "source text"}
        )
        source_connection = mock.Mock(fetchone_as_dict=mock_fetch_from_source)

        with mock.patch.multiple(
            source_record,
            get_source_database_connection=mock.Mock(
                return_value=source_connection
            ),
        ):
            self.assertEqual(source_record.source_text, "source text")

        mock_fetch_from_source.assert_called_with(
            ["source_field"],
            "source_table",
            where="source_pk_field = %s",
            params=[test_pk_value],
        )

    def test_str_reports_table_definition_info(self) -> None:
        """``str()`` is ``Item <connection>.<table>.<pk column>=<pk value>``."""
        test_pk_value = "12345"

        source_table_definition = TableDefinitionFactory(
            table_name="source_table",
            pk_column_name="id",
            db_connection_name="test",
        )
        source_record = SourceRecordFactory(
            source_column=ColumnFactory(
                table_definition=source_table_definition
            ),
            source_pk_value=test_pk_value,
        )

        self.assertEqual(
            str(source_record), f"Item test.source_table.id={test_pk_value}"
        )

    def test_source_text_before_match(self) -> None:
        """``before`` is the source text preceding the ``FN_START`` offset."""
        source_record = SourceRecordFactory()
        fake_source_text = "before match after"

        match = re.search("match", fake_source_text)
        # re.search() returns None when nothing matches; fail with a clear
        # assertion rather than an AttributeError on the next line.
        self.assertIsNotNone(match)
        fake_nlp_dict = {FN_START: match.start()}

        with mock.patch.multiple(
            source_record,
            _source_text=fake_source_text,
            _nlp_dict=fake_nlp_dict,
        ):
            self.assertEqual(source_record.before, "before ")

    def test_source_text_after_match(self) -> None:
        """``after`` is the source text following the ``FN_END`` offset."""
        source_record = SourceRecordFactory()
        fake_source_text = "before match after"

        match = re.search("match", fake_source_text)
        # re.search() returns None when nothing matches; fail with a clear
        # assertion rather than an AttributeError on the next line.
        self.assertIsNotNone(match)
        fake_nlp_dict = {FN_END: match.end()}

        with mock.patch.multiple(
            source_record,
            _source_text=fake_source_text,
            _nlp_dict=fake_nlp_dict,
        ):
            self.assertEqual(source_record.after, " after")

    def test_match_text_from_source_record_content(self) -> None:
        """``match`` is the ``FN_CONTENT`` entry of the NLP dictionary."""
        source_record = SourceRecordFactory()

        with mock.patch.multiple(
            source_record,
            _nlp_dict={FN_CONTENT: "match"},
        ):
            self.assertEqual(source_record.match, "match")

    def test_extra_fields_copied_from_nlp_dict(self) -> None:
        """``extra_nlp_fields`` holds only the named extra NLP columns."""
        source_record = SourceRecordFactory()

        # Not a complete real world example
        fake_nlp_dict = {
            FN_CONTENT: "CRP was <13 mg/dl",
            "value_text": "13",
            "units": "mg/dl",
        }

        with mock.patch.multiple(
            source_record,
            _nlp_dict=fake_nlp_dict,
            _extra_nlp_column_names=["value_text", "units"],
        ):
            self.assertEqual(
                source_record.extra_nlp_fields,
                {"value_text": "13", "units": "mg/dl"},
            )