Coverage for intelligence_toolkit/detect_case_patterns/record_counter.py: 17%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from collections import defaultdict 

5 

6import numpy as np 

7 

8from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR 

9 

10 

11class RecordCounter: 

12 def __init__(self, df): 

13 self.counter = 0 

14 self.df = df 

15 self.periods = sorted(df["Period"].unique()) 

16 self.atts = sorted(df["Full Attribute"].unique()) 

17 att_to_ids_df = ( 

18 df[["Subject ID", "Full Attribute"]] 

19 .groupby("Full Attribute") 

20 .agg(list) 

21 .reset_index() 

22 ) 

23 self.att_to_ids = dict( 

24 zip( 

25 att_to_ids_df["Full Attribute"], 

26 [set(x) for x in att_to_ids_df["Subject ID"]], 

27 strict=False, 

28 ) 

29 ) 

30 # do same for Period 

31 time_to_ids_df = ( 

32 df[["Subject ID", "Period"]].groupby("Period").agg(list).reset_index() 

33 ) 

34 self.att_to_ids.update( 

35 dict( 

36 zip( 

37 time_to_ids_df["Period"], 

38 [set(x) for x in time_to_ids_df["Subject ID"]], 

39 strict=False, 

40 ) 

41 ) 

42 ) 

43 self.cache = {} 

44 

45 def count_records(self, atts): 

46 key = ";".join(sorted(atts)) 

47 if key in self.cache: 

48 return self.cache[key] 

49 

50 type_to_vals = defaultdict(list) 

51 for att in atts: 

52 type_to_vals[att.split(ATTRIBUTE_VALUE_SEPARATOR)[0]].append(att) 

53 ids = set() 

54 for ix, (_typ, vals) in enumerate(type_to_vals.items()): 

55 combined_atts = set() 

56 for val in vals: 

57 combined_atts.update(self.att_to_ids[val]) 

58 if ix == 0: 

59 ids.update(combined_atts) 

60 else: 

61 ids.intersection_update(combined_atts) 

62 count = len(ids) 

63 self.cache[key] = count 

64 return count 

65 

66 def compute_period_mean_sd_max(self, atts): 

67 counts = [] 

68 for p in self.periods: 

69 counts.append(self.count_records([p, *atts])) 

70 np_mean = np.mean(counts) if len(counts) > 0 else 0 

71 np_sd = np.std(counts) if len(counts) > 0 else 0 

72 np_max = np.max(counts) if len(counts) > 0 else 0 

73 return np_mean, np_sd, np_max 

74 

75 def create_time_series_rows(self, atts): 

76 rows = [] 

77 for p in self.periods: 

78 count = self.count_records([p, *atts]) 

79 rows.append([p, " & ".join(atts), count]) 

80 return rows