Coverage for intelligence_toolkit/detect_case_patterns/record_counter.py: 17%
46 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4from collections import defaultdict
6import numpy as np
8from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
11class RecordCounter:
12 def __init__(self, df):
13 self.counter = 0
14 self.df = df
15 self.periods = sorted(df["Period"].unique())
16 self.atts = sorted(df["Full Attribute"].unique())
17 att_to_ids_df = (
18 df[["Subject ID", "Full Attribute"]]
19 .groupby("Full Attribute")
20 .agg(list)
21 .reset_index()
22 )
23 self.att_to_ids = dict(
24 zip(
25 att_to_ids_df["Full Attribute"],
26 [set(x) for x in att_to_ids_df["Subject ID"]],
27 strict=False,
28 )
29 )
30 # do same for Period
31 time_to_ids_df = (
32 df[["Subject ID", "Period"]].groupby("Period").agg(list).reset_index()
33 )
34 self.att_to_ids.update(
35 dict(
36 zip(
37 time_to_ids_df["Period"],
38 [set(x) for x in time_to_ids_df["Subject ID"]],
39 strict=False,
40 )
41 )
42 )
43 self.cache = {}
45 def count_records(self, atts):
46 key = ";".join(sorted(atts))
47 if key in self.cache:
48 return self.cache[key]
50 type_to_vals = defaultdict(list)
51 for att in atts:
52 type_to_vals[att.split(ATTRIBUTE_VALUE_SEPARATOR)[0]].append(att)
53 ids = set()
54 for ix, (_typ, vals) in enumerate(type_to_vals.items()):
55 combined_atts = set()
56 for val in vals:
57 combined_atts.update(self.att_to_ids[val])
58 if ix == 0:
59 ids.update(combined_atts)
60 else:
61 ids.intersection_update(combined_atts)
62 count = len(ids)
63 self.cache[key] = count
64 return count
66 def compute_period_mean_sd_max(self, atts):
67 counts = []
68 for p in self.periods:
69 counts.append(self.count_records([p, *atts]))
70 np_mean = np.mean(counts) if len(counts) > 0 else 0
71 np_sd = np.std(counts) if len(counts) > 0 else 0
72 np_max = np.max(counts) if len(counts) > 0 else 0
73 return np_mean, np_sd, np_max
75 def create_time_series_rows(self, atts):
76 rows = []
77 for p in self.periods:
78 count = self.count_records([p, *atts])
79 rows.append([p, " & ".join(atts), count])
80 return rows