Coverage for intelligence_toolkit/tests/unit/compare_case_groups/test_build_dataframes.py: 100%
121 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
6import pandas as pd
7import polars as pl
8import pytest
10from intelligence_toolkit.compare_case_groups.build_dataframes import (
11 build_attribute_df,
12 build_grouped_df,
13 build_ranked_df,
14 filter_df,
15)
class TestBuildRankedGroups:
    """Tests for ``build_ranked_df`` with and without a temporal column name."""

    @pytest.fixture()
    def sample_data(self):
        """Return the ``(ldf, gdf, adf)`` frames handed to ``build_ranked_df``.

        ``ldf`` carries the ``"temporal column"`` values plus their window
        rank/delta columns, ``gdf`` has one row per group, and ``adf`` holds
        the same attribute rows as ``ldf`` without any temporal columns.
        """
        ldf = pl.DataFrame(
            {
                "Group": ["A", "A", "B", "B"],
                "attribute_value": [1, 2, 3, 4],
                "attribute_rank": [1, 2, 1, 2],
                "group_rank": [1, 1, 1, 1],
                "temporal column": ["2021-01", "2021-02", "2021-01", "2021-02"],
                "temporal column_window_rank": [1, 2, 1, 2],
                "temporal column_window_delta": [0, 1, 0, 1],
            }
        )

        gdf = pl.DataFrame({"Group": ["A", "B"], "Global Rank": [1, 2]})

        adf = pl.DataFrame(
            {
                "Group": ["A", "A", "B", "B"],
                "attribute_value": [1, 2, 3, 4],
                "attribute_rank": [1, 2, 1, 2],
                "group_rank": [1, 1, 1, 1],
            }
        )

        return ldf, gdf, adf

    def test_build_ranked_df_temporal_columns(self, sample_data):
        """The window, window-rank and window-delta columns are present."""
        ldf, gdf, adf = sample_data
        temporal = "temporal column"
        groups = ["Group"]

        result_df = build_ranked_df(ldf, gdf, adf, temporal, groups)

        assert "temporal column_window" in result_df.columns
        assert "temporal column_window_rank" in result_df.columns
        assert "temporal column_window_delta" in result_df.columns

    def test_build_ranked_df_temporal(self, sample_data):
        """Column values match the fixture when a temporal column is given."""
        ldf, gdf, adf = sample_data
        temporal = "temporal column"
        groups = ["Group"]

        result_df = build_ranked_df(ldf, gdf, adf, temporal, groups)

        expected_values = {
            "Group": ["A", "A", "B", "B"],
            "attribute_value": [1, 2, 3, 4],
            "attribute_rank": [1, 2, 1, 2],
            "group_rank": [1, 1, 1, 1],
            "temporal column_window_rank": [1, 2, 1, 2],
            "temporal column_window_delta": [0, 1, 0, 1],
        }

        for col, values in expected_values.items():
            assert result_df[col].to_list() == values

    def test_build_ranked_df_no_temporal_columns(self, sample_data):
        """No temporal window columns appear when ``temporal`` is empty."""
        ldf, gdf, adf = sample_data
        temporal = ""
        groups = ["Group"]

        result_df = build_ranked_df(ldf, gdf, adf, temporal, groups)

        # The names must use the fixture's "temporal column" prefix (with a
        # space); an underscore spelling never exists, so checking for its
        # absence would pass regardless of what the function returns.
        assert "temporal column_window" not in result_df.columns
        assert "temporal column_window_rank" not in result_df.columns
        assert "temporal column_window_delta" not in result_df.columns

    def test_build_ranked_df_no_temporal(self, sample_data):
        """Column values match the fixture when ``temporal`` is empty."""
        ldf, gdf, adf = sample_data
        temporal = ""
        groups = ["Group"]

        result_df = build_ranked_df(ldf, gdf, adf, temporal, groups)

        expected_values = {
            "Group": ["A", "A", "B", "B"],
            "attribute_value": [1, 2, 3, 4],
            "attribute_rank": [1, 2, 1, 2],
            "group_rank": [1, 1, 1, 1],
        }

        for col, values in expected_values.items():
            assert result_df[col].to_list() == values

    def test_build_ranked_df_attribute_rank_type(self, sample_data):
        """Both rank columns come back as ``pl.Int32``."""
        ldf, gdf, adf = sample_data
        temporal = "temporal column"
        groups = ["Group"]

        result_df = build_ranked_df(ldf, gdf, adf, temporal, groups)

        assert result_df["attribute_rank"].dtype == pl.Int32
        assert result_df["group_rank"].dtype == pl.Int32

    def test_build_ranked_df_sorted(self, sample_data):
        """The result is already ordered by the group columns."""
        ldf, gdf, adf = sample_data
        temporal = "temporal column"
        groups = ["Group"]

        result_df = build_ranked_df(ldf, gdf, adf, temporal, groups)

        # Re-sorting by the groups must be a no-op on an already sorted frame.
        assert result_df.equals(result_df.sort(by=groups))
class TestFilterDf:
    """Tests for ``filter_df`` driven by ``"column:value"`` filter strings."""

    @pytest.fixture()
    def dataset(self) -> pl.DataFrame:
        """Return a five-row frame covering groups A, B and C."""
        columns = {
            "Group": ["A", "A", "B", "B", "C"],
            "attribute_value": ["X", "B", "BCD", "ABC", "X"],
            "attribute_rank": [1, 2, 1, 2, 1],
            "group_rank": [1, 1, 1, 1, 1],
        }
        return pl.DataFrame(columns)

    def test_filter_empty(self, dataset) -> None:
        """An empty filter list gives back a frame equal to the input."""
        assert filter_df(dataset, []).equals(dataset)

    def test_filter_single(self, dataset) -> None:
        """A single ``Group:A`` filter keeps only the rows of group A."""
        only_group_a = dataset.filter(pl.col("Group") == "A")
        filtered = filter_df(dataset, ["Group:A"])
        assert filtered.equals(only_group_a)

    def test_filter_multiple(self, dataset) -> None:
        """Two filters are expected to match rows satisfying both."""
        in_group_a = pl.col("Group") == "A"
        has_value_x = pl.col("attribute_value") == "X"
        filtered = filter_df(dataset, ["Group:A", "attribute_value:X"])
        assert filtered.equals(dataset.filter(in_group_a & has_value_x))

    def test_filter_multiple_attr_inexistent(self, dataset) -> None:
        """A filter value absent from the data leaves no rows."""
        filtered = filter_df(dataset, ["Group:A", "attribute_value:F"])
        assert len(filtered) == 0
class TestBuildAttributeDf:
    """Tests for ``build_attribute_df`` over one or two grouping columns."""

    @pytest.fixture()
    def dataset_2(self) -> pl.DataFrame:
        """Input frame whose aggregate columns each contain one ``None``."""
        return pl.DataFrame(
            {
                "Group": ["A", "A", "B", "B"],
                "Temporal": [1, 2, 1, 2],
                "Aggregate1": [10, 20, None, 40],
                "Aggregate2": [5, None, 15, 20],
            }
        )

    @pytest.fixture()
    def dataset_3(self) -> pl.DataFrame:
        """Input frame with a second grouping column, ``Group2``."""
        return pl.DataFrame(
            {
                "Group": ["A", "A", "B", "B"],
                "Group2": ["AX", "AE", "BZ", "BY"],
                "Aggregate1": [40, 20, 30, 50],
                "Aggregate2": [10, 1, 15, 7],
            }
        )

    @pytest.fixture()
    def expected_dataset_1(self) -> pl.DataFrame:
        """Expected rows for ``test_build_attribute_df``, sorted for comparison."""
        # Every group is expected to list all eight attribute values.
        attribute_values = [
            "Aggregate1:20",
            "Aggregate1:30",
            "Aggregate1:40",
            "Aggregate1:50",
            "Aggregate2:1",
            "Aggregate2:10",
            "Aggregate2:15",
            "Aggregate2:7",
        ]
        return pl.DataFrame(
            {
                "Group": ["A"] * 8 + ["B"] * 8,
                "attribute_value": attribute_values * 2,
                "attribute_count": [1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1],
                "attribute_rank": [
                    1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0,
                    2.0, 1.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0,
                ],
            }
        ).sort(by=["Group", "attribute_value"])

    @pytest.fixture()
    def expected_dataset_2(self) -> pl.DataFrame:
        """Expected rows for ``dataset_2``, sorted for comparison."""
        attribute_values = [
            "Aggregate1:10",
            "Aggregate1:20",
            "Aggregate1:40",
            "Aggregate2:15",
            "Aggregate2:20",
            "Aggregate2:5",
        ]
        return pl.DataFrame(
            {
                "Group": ["A"] * 6 + ["B"] * 6,
                "attribute_value": attribute_values * 2,
                "attribute_count": [1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0],
                "attribute_rank": [1, 1, 2, 2, 2, 1, 2, 2, 1, 1, 1, 2],
            }
        ).sort(by=["Group", "attribute_value"])

    @pytest.fixture()
    def expected_dataset_3(self) -> pl.DataFrame:
        """Expected rows for ``dataset_3``, cast to ``UInt32`` and sorted.

        The first 8 rows have a count of 1; the remaining 24 rows have a
        count of 0.
        """
        return (
            pl.DataFrame(
                {
                    "Group": ["A"] * 4 + ["B"] * 4 + ["A"] * 12 + ["B"] * 12,
                    "Group2": (
                        ["AE"] * 2
                        + ["AX"] * 2
                        + ["BY"] * 2
                        + ["BZ"] * 2
                        + ["AE"] * 6
                        + ["AX"] * 6
                        + ["BY"] * 6
                        + ["BZ"] * 6
                    ),
                    "attribute_value": [
                        # Rows with attribute_count == 1.
                        "Aggregate1:20", "Aggregate2:1",
                        "Aggregate1:40", "Aggregate2:10",
                        "Aggregate1:50", "Aggregate2:7",
                        "Aggregate1:30", "Aggregate2:15",
                        # Rows with attribute_count == 0, Group2 == "AE".
                        "Aggregate1:40", "Aggregate2:10",
                        "Aggregate1:50", "Aggregate2:7",
                        "Aggregate1:30", "Aggregate2:15",
                        # Rows with attribute_count == 0, Group2 == "AX".
                        "Aggregate1:20", "Aggregate2:1",
                        "Aggregate1:50", "Aggregate2:7",
                        "Aggregate1:30", "Aggregate2:15",
                        # Rows with attribute_count == 0, Group2 == "BY".
                        "Aggregate1:20", "Aggregate2:1",
                        "Aggregate1:40", "Aggregate2:10",
                        "Aggregate1:30", "Aggregate2:15",
                        # Rows with attribute_count == 0, Group2 == "BZ".
                        "Aggregate1:20", "Aggregate2:1",
                        "Aggregate1:40", "Aggregate2:10",
                        "Aggregate1:50", "Aggregate2:7",
                    ],
                    "attribute_count": [1] * 8 + [0] * 24,
                    "attribute_rank": [1.0] * 8 + [4.0] * 24,
                }
            )
            .with_columns(
                [
                    pl.col("attribute_rank").cast(pl.UInt32),
                    pl.col("attribute_count").cast(pl.UInt32),
                ]
            )
            .sort(by=["Group", "Group2", "attribute_value"])
        )

    def test_build_attribute_df(self, expected_dataset_1) -> None:
        """Single grouping column, no missing values."""
        df1 = pl.DataFrame(
            {
                "Group": ["A", "A", "B", "B"],
                "Temporal": [1, 2, 1, 2],
                "Aggregate1": [40, 20, 30, 50],
                "Aggregate2": [10, 1, 15, 7],
            }
        )

        result_df1 = build_attribute_df(
            df1, ["Group"], ["Aggregate1", "Aggregate2"]
        ).sort(by=["Group", "attribute_value"])

        # ``equals`` only returns a bool; it must be asserted to be checked.
        assert result_df1.equals(expected_dataset_1)

    def test_with_missing_values(self, dataset_2, expected_dataset_2):
        """Single grouping column with ``None`` entries in the aggregates."""
        result_df2 = build_attribute_df(
            dataset_2,
            ["Group"],
            ["Aggregate1", "Aggregate2"],
        ).sort(by=["Group", "attribute_value"])

        assert result_df2.equals(expected_dataset_2)

    def test_with_additional_group(self, dataset_3, expected_dataset_3):
        """Two grouping columns, ``Group`` and ``Group2``."""
        result_df3 = build_attribute_df(
            dataset_3,
            ["Group", "Group2"],
            ["Aggregate1", "Aggregate2"],
        ).sort(by=["Group", "Group2", "attribute_value"])

        assert result_df3.equals(expected_dataset_3)
class TestBuildGroupedDf:
    """Tests for ``build_grouped_df`` when grouping by the ``city`` column."""

    @pytest.fixture()
    def main_dataset(self) -> pl.DataFrame:
        """Return eleven city/country rows; some cities occur more than once."""
        cities = ["Westview"] * 4 + [
            "Eastview",
            "Southtview",
            "Northview",
            "Gotham",
            "Anycity",
            "Simcity",
            "Anycity",
        ]
        countries = ["ANY"] * 6 + ["NEW", "OLD", "ANY", "KEY", "ANY"]
        return pl.DataFrame({"city": cities, "country": countries})

    def test_build_grouped_df(self, main_dataset):
        """Counts and ranks per city match the expected frame."""
        expected_df = pl.DataFrame(
            {
                "city": [
                    "Anycity",
                    "Eastview",
                    "Gotham",
                    "Northview",
                    "Simcity",
                    "Southtview",
                    "Westview",
                ],
                "group_count": [2, 1, 1, 1, 1, 1, 4],
                "group_rank": [2, 7, 7, 7, 7, 7, 1],
            }
        )
        grouped = build_grouped_df(main_dataset, ["city"])
        assert grouped.equals(expected_df)

    def test_build_grouped_ints(self, main_dataset) -> None:
        """A non-string entry in ``groups`` raises ``ValueError``."""
        mixed_groups = ["city", 123]

        with pytest.raises(ValueError, match="All elements in groups must be strings"):
            build_grouped_df(main_dataset, mixed_groups)

    def test_build_grouped_df_missing(self, main_dataset):
        """Counts and ranks per city after the Gotham rows are removed."""
        without_gotham = main_dataset.filter(pl.col("city") != "Gotham")
        grouped = build_grouped_df(without_gotham, ["city"])

        # The expected count and rank columns are compared as UInt32.
        expected_df = pl.DataFrame(
            {
                "city": [
                    "Anycity",
                    "Eastview",
                    "Northview",
                    "Simcity",
                    "Southtview",
                    "Westview",
                ],
                "group_count": [2, 1, 1, 1, 1, 4],
                "group_rank": [2, 6, 6, 6, 6, 1],
            }
        ).with_columns(
            [
                pl.col("group_count").cast(pl.UInt32),
                pl.col("group_rank").cast(pl.UInt32),
            ]
        )
        assert grouped.equals(expected_df)