Coverage for intelligence_toolkit/extract_record_data/data_extractor.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2from json import loads 

3 

4import pandas as pd 

5 

6import intelligence_toolkit.AI.utils as utils 

7import intelligence_toolkit.extract_record_data.prompts as prompts 

8from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

9 

10 

async def extract_record_data(
    ai_configuration,
    generation_guidance,
    record_arrays,
    data_schema,
    input_texts,
    df_update_callback,
    callback_batch,
):
    """Extract structured data from each input text and merge it into one object.

    Each text in ``input_texts`` is turned into an extraction prompt and sent
    through ``_extract_data_parallel`` with ``data_schema`` as a strict
    JSON-schema response format. The JSON replies are parsed and merged in
    order with ``merge_json_objects``.

    After every merge, the arrays named by ``record_arrays`` are flattened into
    DataFrames with ``extract_df``. If a callback is given, the resulting dict
    is passed to ``df_update_callback`` so callers can show incremental
    progress.

    Args:
        ai_configuration: Model configuration, forwarded unchanged to the
            generation helper.
        generation_guidance: Extra instructions inserted into each prompt.
        record_arrays: Iterable of key paths (lists of field names), each
            naming an array in the merged object to flatten.
        data_schema: JSON schema that every model response must conform to.
        input_texts: Texts to extract data from, one prompt per text.
        df_update_callback: Optional callable that receives the
            ``{dotted_path: DataFrame}`` dict after each merge. ``None``
            disables updates.
        callback_batch: Optional progress callback forwarded (wrapped in a
            list) to the generation helper.

    Returns:
        A tuple ``(merged_object, dfs)``. ``dfs`` maps each dot-joined record
        path to its DataFrame. Both are empty when no responses are produced.
    """
    current_object_json = {}
    # Bound up front so an empty response list returns {} instead of raising
    # UnboundLocalError at the final return statement.
    dfs = {}

    new_objects = await _extract_data_parallel(
        ai_configuration=ai_configuration,
        input_texts=input_texts,
        generation_guidance=generation_guidance,
        data_schema=data_schema,
        callbacks=[callback_batch] if callback_batch is not None else None,
    )

    for new_object in new_objects:
        new_object_json = loads(new_object)
        # For conflicting scalar values the later response wins; the conflict
        # list is not currently surfaced to the caller.
        current_object_json, _conflicts = merge_json_objects(
            current_object_json, new_object_json
        )
        # Rebuild every table from the merged object so each callback sees
        # the cumulative state so far.
        dfs = {
            ".".join(record_array): extract_df(current_object_json, record_array)
            for record_array in record_arrays
        }
        if df_update_callback is not None:
            df_update_callback(dfs)

    return current_object_json, dfs

44 

45 

async def _extract_data_parallel(
    ai_configuration,
    input_texts,
    generation_guidance,
    data_schema,
    callbacks: list[ProgressBatchCallback] | None = None,
):
    """Build one extraction prompt per input text and generate responses for all.

    Every prompt is built from ``prompts.data_extraction_prompt``, filled with
    the text and the shared ``generation_guidance``. All prompts are handed to
    ``utils.map_generate_text`` together with a strict ``json_schema`` response
    format named ``"record_object"`` that wraps ``data_schema``. Whether the
    helper runs requests concurrently is up to ``utils.map_generate_text``; the
    function name suggests it does.

    Args:
        ai_configuration: Model configuration passed through to the helper.
        input_texts: Texts to extract from, one prompt each.
        generation_guidance: Extra instructions shared by every prompt.
        data_schema: JSON schema the responses must follow.
        callbacks: Optional progress callbacks passed through to the helper.

    Returns:
        Whatever ``utils.map_generate_text`` returns. The caller parses each
        item with ``json.loads``.
    """
    schema_spec = {"name": "record_object", "strict": True, "schema": data_schema}
    response_format = {"type": "json_schema", "json_schema": schema_spec}

    def _messages_for(text):
        # One message list per input text; the guidance is identical for all.
        return utils.prepare_messages(
            prompts.data_extraction_prompt,
            {"input_text": text, "generation_guidance": generation_guidance},
        )

    message_batches = list(map(_messages_for, input_texts))

    return await utils.map_generate_text(
        ai_configuration,
        message_batches,
        response_format=response_format,
        callbacks=callbacks,
    )

74 

75 

def extract_df(json_data, record_path):
    """Flatten the records found at ``record_path`` inside ``json_data``.

    Thin wrapper around ``pandas.json_normalize``. ``record_path`` is the key
    (or list of keys) leading to the list of records that become the rows of
    the returned DataFrame.
    """
    frame = pd.json_normalize(json_data, record_path=record_path)
    return frame

79 

80 

def merge_json_objects(json_obj1, json_obj2):
    """Recursively merge two JSON-like dicts and report conflicting keys.

    For a key present in both objects:
      * both values are dicts: they are merged recursively;
      * both values are lists: they are concatenated, ``json_obj1``'s items
        first;
      * the values otherwise differ: ``json_obj2``'s value wins and the key is
        recorded as a conflict (nested conflicts use dotted paths such as
        ``"a.b"``);
      * the values are equal: the value is kept as is.
    Keys present in only one object are copied through unchanged.

    The merged object lists ``json_obj1``'s keys first, in their original
    order, followed by keys that appear only in ``json_obj2``. Conflicts are
    reported in the same order, so results are deterministic across
    interpreter runs.

    Args:
        json_obj1: Base object.
        json_obj2: Object merged on top of ``json_obj1``.

    Returns:
        A tuple ``(merged_object, conflicts)``. Neither input is mutated, but
        values are not deep-copied; unmerged values are shared with the inputs.
    """
    merged_object = {}
    conflicts = []

    def merge_values(key, value1, value2):
        if isinstance(value1, dict) and isinstance(value2, dict):
            merged_value, sub_conflicts = merge_json_objects(value1, value2)
            conflicts.extend(f"{key}.{sub_key}" for sub_key in sub_conflicts)
            return merged_value
        if isinstance(value1, list) and isinstance(value2, list):
            return value1 + value2
        if value1 != value2:
            conflicts.append(key)
            return value2
        return value1

    # dict.fromkeys de-duplicates while keeping first-seen order. Iterating a
    # set here made key and conflict order depend on string hash randomisation.
    for key in dict.fromkeys([*json_obj1, *json_obj2]):
        if key in json_obj1 and key in json_obj2:
            merged_object[key] = merge_values(key, json_obj1[key], json_obj2[key])
        elif key in json_obj1:
            merged_object[key] = json_obj1[key]
        else:
            merged_object[key] = json_obj2[key]

    return merged_object, conflicts

110 

111 

def extract_array_fields(schema):
    """Return the key paths of every array-typed property in a JSON schema.

    Walks ``properties`` at every nesting level, descending into object
    properties and into array ``items``. A property counts as an array when
    its ``type`` is ``"array"`` or a JSON Schema type list containing
    ``"array"`` (for example ``["array", "null"]`` for a nullable array). For
    each such property, the list of field names leading to it is collected.

    Paths are returned in depth-first declaration order; an array nested in
    another array's items appears right after its parent.

    Args:
        schema: A JSON schema dict. A list of sub-schemas is also walked.

    Returns:
        A list of paths, each a list of field names, e.g.
        ``[["people"], ["people", "pets"]]``.
    """
    array_fields = []

    def _is_array_type(type_spec):
        # JSON Schema allows "type" to be a single name or a list of names.
        if isinstance(type_spec, list):
            return "array" in type_spec
        return type_spec == "array"

    def extract_array_fields_recursive(node, field_path):
        if isinstance(node, dict):
            for field_name, field_value in node.get("properties", {}).items():
                if not isinstance(field_value, dict):
                    continue
                child_path = field_path + [field_name]
                if _is_array_type(field_value.get("type")):
                    array_fields.append(child_path)
                    # Arrays of objects can contain further arrays in their items.
                    extract_array_fields_recursive(
                        field_value.get("items", {}), child_path
                    )
                else:
                    extract_array_fields_recursive(field_value, child_path)
        elif isinstance(node, list):
            for item in node:
                extract_array_fields_recursive(item, field_path)

    extract_array_fields_recursive(schema, [])
    return array_fields