Coverage for intelligence_toolkit/extract_record_data/data_extractor.py: 0%
64 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2from json import loads
4import pandas as pd
6import intelligence_toolkit.AI.utils as utils
7import intelligence_toolkit.extract_record_data.prompts as prompts
8from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
async def extract_record_data(
    ai_configuration,
    generation_guidance,
    record_arrays,
    data_schema,
    input_texts,
    df_update_callback,
    callback_batch,
):
    """Extract one record object per input text, merge them, and tabulate arrays.

    Args:
        ai_configuration: Passed through unchanged to ``_extract_data_parallel``.
        generation_guidance: Guidance text substituted into the extraction prompt.
        record_arrays: Iterable of key paths (each a list of strings); every
            path is handed to ``extract_df`` as its ``record_path``.
        data_schema: Schema used for the ``json_schema`` response format.
        input_texts: Iterable of texts; one extraction request is built per text.
        df_update_callback: Optional callable invoked once with the dict of
            DataFrames; skipped when ``None``.
        callback_batch: Optional progress callback; when not ``None`` it is
            wrapped in a single-element list for ``_extract_data_parallel``.

    Returns:
        A ``(merged_object, dfs)`` tuple where ``merged_object`` is the result of
        folding every parsed response through ``merge_json_objects`` and ``dfs``
        maps ``".".join(record_array)`` to the DataFrame for that path.
    """
    new_objects = await _extract_data_parallel(
        ai_configuration=ai_configuration,
        input_texts=input_texts,
        generation_guidance=generation_guidance,
        data_schema=data_schema,
        callbacks=[callback_batch] if callback_batch is not None else None,
    )

    # Fold every response into a single object. merge_json_objects also returns
    # the list of conflicting key paths; it is discarded here.
    current_object_json = {}
    for new_object in new_objects:
        current_object_json, _ = merge_json_objects(
            current_object_json, loads(new_object)
        )

    dfs = {
        ".".join(record_array): extract_df(current_object_json, record_array)
        for record_array in record_arrays
    }
    if df_update_callback is not None:
        df_update_callback(dfs)
    return current_object_json, dfs
async def _extract_data_parallel(
    ai_configuration,
    input_texts,
    generation_guidance,
    data_schema,
    callbacks: list[ProgressBatchCallback] | None = None,
):
    """Build one extraction request per input text and submit them together.

    Each text is combined with ``generation_guidance`` through
    ``utils.prepare_messages`` using ``prompts.data_extraction_prompt``. The
    resulting message lists are passed to ``utils.map_generate_text`` along with
    a strict ``json_schema`` response format wrapping ``data_schema``.

    Returns whatever ``utils.map_generate_text`` returns (the caller parses each
    item with ``json.loads``, so presumably one JSON string per input text —
    confirm against ``utils``).
    """
    schema_spec = {"name": "record_object", "strict": True, "schema": data_schema}
    answer_format = {"type": "json_schema", "json_schema": schema_spec}

    mapped_messages = []
    for text in input_texts:
        prompt_variables = {
            "input_text": text,
            "generation_guidance": generation_guidance,
        }
        mapped_messages.append(
            utils.prepare_messages(prompts.data_extraction_prompt, prompt_variables)
        )

    responses = await utils.map_generate_text(
        ai_configuration,
        mapped_messages,
        response_format=answer_format,
        callbacks=callbacks,
    )
    return responses
def extract_df(json_data, record_path):
    """Return a DataFrame of the records found at ``record_path`` in ``json_data``.

    Thin wrapper over ``pandas.json_normalize``: ``json_data`` is the data
    argument and ``record_path`` is the key path leading to the list of records.
    """
    frame = pd.json_normalize(json_data, record_path)
    return frame
def merge_json_objects(json_obj1, json_obj2):
    """Recursively merge two JSON-like dicts and report conflicting keys.

    Merge rules for a key present in both inputs:
      * both values are dicts  -> merged recursively;
      * both values are lists  -> concatenated (``json_obj1`` items first);
      * values differ          -> the ``json_obj2`` value wins and the key is
        recorded as a conflict;
      * values are equal       -> the ``json_obj1`` value is kept.
    Keys present in only one input are copied across as-is.

    The result keeps a deterministic key order: keys of ``json_obj1`` in their
    original order, followed by keys found only in ``json_obj2``. Conflicts are
    listed in the order the keys appear in ``json_obj2``. Neither input dict is
    modified, but values that are not merged are shared by reference.

    Returns:
        ``(merged_object, conflicts)`` where ``conflicts`` is a list of key
        paths; nested conflicts are reported as ``"parent.child"``.
    """
    conflicts = []

    def merge_values(key, value1, value2):
        if isinstance(value1, dict) and isinstance(value2, dict):
            merged_value, sub_conflicts = merge_json_objects(value1, value2)
            # Prefix nested conflict paths with the key they were found under.
            conflicts.extend(f"{key}.{sub_key}" for sub_key in sub_conflicts)
            return merged_value
        if isinstance(value1, list) and isinstance(value2, list):
            return value1 + value2
        if value1 != value2:
            conflicts.append(key)
            return value2
        return value1

    # Start from a shallow copy of the first object and fold the second one in.
    # Iterating the dicts directly (rather than a set of keys) keeps key order
    # and conflict order stable across runs.
    merged_object = dict(json_obj1)
    for key, value2 in json_obj2.items():
        if key in merged_object:
            merged_object[key] = merge_values(key, merged_object[key], value2)
        else:
            merged_object[key] = value2

    return merged_object, conflicts
def extract_array_fields(schema):
    """List the key path of every array-typed property in a schema.

    Walks ``schema`` depth-first through each dict's ``"properties"`` mapping.
    A property whose ``"type"`` equals ``"array"`` contributes its path (a list
    of property names from the root) and the walk continues into its
    ``"items"``; any other dict-valued property is descended into directly.
    A list is walked element by element under the same path. Property values
    that are not dicts are ignored.

    Returns:
        A list of paths in pre-order, each path being a list of field names.
    """

    def walk(node, prefix):
        if isinstance(node, list):
            for element in node:
                yield from walk(element, prefix)
            return
        if not isinstance(node, dict):
            return
        for name, spec in node.get("properties", {}).items():
            if not isinstance(spec, dict):
                continue
            path = [*prefix, name]
            if spec.get("type") == "array":
                # Report the array itself before any arrays nested in its items.
                yield path
                yield from walk(spec.get("items", {}), path)
            else:
                yield from walk(spec, path)

    return list(walk(schema, []))