Coverage for intelligence_toolkit/generate_mock_data/data_generator.py: 51%
109 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2import random
3from json import loads
5import pandas as pd
7import intelligence_toolkit.AI.utils as utils
8import intelligence_toolkit.generate_mock_data.prompts as prompts
9import intelligence_toolkit.generate_mock_data.schema_builder as schema_builder
10from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
def _parse_json_response(response):
    """Parse a model response as JSON.

    Raises:
        ValueError: with a user-facing message when the response is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError``) or is not a
            string/bytes at all (``TypeError``).
    """
    try:
        return loads(response)
    except (TypeError, ValueError) as e:
        msg = f"AI did not return a valid JSON response. Please try again. {e}"
        raise ValueError(msg) from e


def _refresh_dfs(dfs, current_object_json, record_arrays, df_update_callback):
    """Re-extract one DataFrame per record array into ``dfs`` (in place) and notify the callback."""
    for record_array in record_arrays:
        dfs[".".join(record_array)] = extract_df(current_object_json, record_array)
    if df_update_callback is not None:
        df_update_callback(dfs)


async def generate_data(
    ai_configuration,
    generation_guidance,
    data_schema,
    num_records_overall,
    records_per_batch,
    duplicate_records_per_batch,
    related_records_per_batch,
    temperature,
    df_update_callback,
    callback_batch,
    parallel_batches=5,
):
    """Generate mock data in LLM batches until about ``num_records_overall`` records exist.

    The first batch is generated without a seed. Each following round samples
    up to ``parallel_batches`` records from the data accumulated so far, and
    each sampled record seeds one parallel batch. Every response is parsed as
    JSON and merged into the running object with ``merge_json_objects``, which
    concatenates lists.

    The record count is an estimate: it assumes every batch yields
    ``records_per_batch`` primary records.

    Args:
        ai_configuration: configuration forwarded to the text-generation helpers.
        generation_guidance: free-text guidance inserted into the prompts.
        data_schema: JSON schema for the model's structured output. Its first
            array field (in property order) is the primary record array.
        num_records_overall: target number of primary records.
        records_per_batch: records requested per batch.
        duplicate_records_per_batch: near-duplicates of the seed requested per seeded batch.
        related_records_per_batch: close relations of the seed requested per seeded batch.
        temperature: sampling temperature.
        df_update_callback: optional callable that receives the dict of
            DataFrames after the first batch and after every round.
        callback_batch: optional progress callback for the parallel seeded calls.
        parallel_batches: maximum number of seeded batches per round.

    Returns:
        A tuple ``(merged_json_object, dfs)``. ``dfs`` maps each dotted array
        path (e.g. ``"people.pets"``) to a DataFrame.

    Raises:
        ValueError: if the schema has no array field, or the model returns
            invalid JSON.
    """
    record_arrays = extract_array_fields(data_schema)
    if not record_arrays:
        msg = "The data schema must contain at least one array field to hold generated records."
        raise ValueError(msg)
    primary_record_array = record_arrays[0]

    first_object = generate_unseeded_data(
        ai_configuration=ai_configuration,
        generation_guidance=generation_guidance,
        primary_record_array=primary_record_array,
        total_records=records_per_batch,
        data_schema=data_schema,
        temperature=temperature,
    )
    current_object_json = _parse_json_response(first_object)

    # The same dict object is updated in place and passed to the callback on every refresh.
    dfs = {}
    _refresh_dfs(dfs, current_object_json, record_arrays, df_update_callback)

    num_records = records_per_batch
    while num_records < num_records_overall:
        remainder = num_records_overall - num_records
        # Ceiling division: number of batches still needed to cover the remainder.
        batches_needed = int(-(-remainder // records_per_batch))
        batches = min(parallel_batches, batches_needed)
        sample_records = sample_from_record_array(
            current_object_json, primary_record_array, batches
        )
        if not sample_records:
            # Nothing to seed from: another round could never add records.
            break

        # Each sampled record seeds exactly one parallel batch.
        new_objects = await generate_seeded_data(
            ai_configuration=ai_configuration,
            sample_records=sample_records,
            generation_guidance=generation_guidance,
            primary_record_array=primary_record_array,
            total_records=records_per_batch,
            near_duplicate_records=duplicate_records_per_batch,
            close_relation_records=related_records_per_batch,
            data_schema=data_schema,
            temperature=temperature,
            callbacks=[callback_batch] if callback_batch is not None else None,
        )
        # Count only the batches actually requested, not parallel_batches.
        num_records += records_per_batch * len(sample_records)

        for new_object in new_objects:
            new_object_json = _parse_json_response(new_object)
            current_object_json, _conflicts = merge_json_objects(
                current_object_json, new_object_json
            )

        _refresh_dfs(dfs, current_object_json, record_arrays, df_update_callback)

    return current_object_json, dfs
def generate_unseeded_data(
    ai_configuration,
    generation_guidance,
    primary_record_array,
    total_records,
    data_schema,
    temperature,
):
    """Request the first, unseeded batch of records from the model.

    Fills ``prompts.unseeded_data_generation_prompt`` with the guidance, the
    primary record array path and the target record count. It then calls
    ``utils.generate_text`` with a strict JSON-schema response format built
    from ``data_schema``.

    Returns:
        Whatever ``utils.generate_text`` returns. The caller parses it with
        ``json.loads``, so presumably a JSON string (TODO confirm).
    """
    prompt_variables = {
        "generation_guidance": generation_guidance,
        "primary_record_array": primary_record_array,
        "total_records": total_records,
    }
    messages = utils.prepare_messages(
        prompts.unseeded_data_generation_prompt, prompt_variables
    )

    # Strict structured output: the model must answer with JSON matching data_schema.
    schema_spec = {"name": "answer_object", "strict": True, "schema": data_schema}
    response_format = {"type": "json_schema", "json_schema": schema_spec}

    return utils.generate_text(
        ai_configuration,
        messages,
        response_format=response_format,
        temperature=temperature,
    )
async def generate_seeded_data(
    ai_configuration,
    sample_records,
    generation_guidance,
    primary_record_array,
    total_records,
    near_duplicate_records,
    close_relation_records,
    data_schema,
    temperature,
    callbacks: list[ProgressBatchCallback] | None = None,
):
    """Request one seeded batch per record in ``sample_records`` via ``utils.map_generate_text``.

    Every prompt is built from ``prompts.seeded_data_generation_prompt`` with
    its seed record, the shared guidance, the primary record array path and a
    three-line summary of the record targets. All requests share the same
    strict JSON-schema response format.

    Returns:
        The awaited result of ``utils.map_generate_text``. The caller iterates
        it and JSON-parses each item, so presumably one response per prompt
        (TODO confirm).
    """
    response_format = {
        "type": "json_schema",
        "json_schema": {"name": "answer_object", "strict": True, "schema": data_schema},
    }

    # The targets text does not depend on the seed, so it is built once.
    target_lines = [
        "Total records: " + str(total_records),
        "Near duplicates of seed: " + str(near_duplicate_records),
        "Close relations of seed: " + str(close_relation_records),
    ]
    record_targets = "\n".join(target_lines)

    mapped_messages = []
    for seed in sample_records:
        variables = {
            "seed_record": seed,
            "generation_guidance": generation_guidance,
            "primary_record_array": primary_record_array,
            "record_targets": record_targets,
        }
        mapped_messages.append(
            utils.prepare_messages(prompts.seeded_data_generation_prompt, variables)
        )

    return await utils.map_generate_text(
        ai_configuration,
        mapped_messages,
        response_format=response_format,
        temperature=temperature,
        callbacks=callbacks,
    )
def select_random_records(num_records, category_to_count):
    """Assign distinct random indices from ``range(num_records)`` to categories.

    One sample of ``sum(category_to_count.values())`` distinct indices is drawn
    with ``random.sample``. Consecutive slices of that sample go to each
    category, in dict order, sized by its count.

    Returns:
        A dict mapping each category to its list of indices.

    Raises:
        ValueError: from ``random.sample`` when the total count exceeds ``num_records``.
    """
    pool = random.sample(range(num_records), sum(category_to_count.values()))
    assignment = {}
    for category, count in category_to_count.items():
        # Take the next `count` indices and keep the rest for later categories.
        assignment[category], pool = pool[:count], pool[count:]
    return assignment
def extract_df(json_data, record_path):
    """Flatten the records found at ``record_path`` inside ``json_data`` into a DataFrame.

    This is a thin wrapper over ``pandas.json_normalize``. ``record_path`` is
    the list of keys leading down to the array of records.
    """
    frame = pd.json_normalize(data=json_data, record_path=record_path)
    return frame
def merge_json_objects(json_obj1, json_obj2):
    """Recursively merge two JSON-like dicts.

    Rules for a key present in both objects:
      * dict + dict: merged recursively;
      * list + list: concatenated (``json_obj1``'s items first);
      * otherwise, if the values differ, ``json_obj2``'s value wins and the
        key is recorded as a conflict.
    A key present in only one object is copied as-is.

    Keys keep first-seen order: all of ``json_obj1``'s keys, then keys that
    appear only in ``json_obj2``. The previous set-based union made key order
    and conflict order depend on string-hash randomization.

    Returns:
        ``(merged_object, conflicts)``. ``conflicts`` lists dotted key paths
        (e.g. ``"d.x"``) whose scalar values disagreed.
    """
    merged_object = {}
    conflicts = []

    def merge_values(key, value1, value2):
        if isinstance(value1, dict) and isinstance(value2, dict):
            merged_value, sub_conflicts = merge_json_objects(value1, value2)
            conflicts.extend(f"{key}.{sub_key}" for sub_key in sub_conflicts)
            return merged_value
        if isinstance(value1, list) and isinstance(value2, list):
            return value1 + value2
        if value1 != value2:
            conflicts.append(key)
            return value2
        return value1

    # Dict unpacking preserves insertion order: json_obj1's keys first, then new keys from json_obj2.
    for key in {**json_obj1, **json_obj2}:
        if key in json_obj1 and key in json_obj2:
            merged_object[key] = merge_values(key, json_obj1[key], json_obj2[key])
        elif key in json_obj1:
            merged_object[key] = json_obj1[key]
        else:
            merged_object[key] = json_obj2[key]

    return merged_object, conflicts
def extract_array_fields(schema: dict) -> list[list[str]]:
    """List the key paths of every ``"type": "array"`` property in a JSON schema.

    The walk is depth-first over ``properties`` at any nesting level. It
    descends into an array's ``items`` and into non-array property
    sub-schemas, and a list of schemas is walked item by item under the
    current path. An array path is emitted before any arrays nested inside it.

    Example: ``[["people"], ["people", "pets"]]`` for a ``people`` array whose
    items have a ``pets`` array.
    """

    def walk(node, path):
        if isinstance(node, list):
            for child in node:
                yield from walk(child, path)
            return
        if not isinstance(node, dict):
            return
        for name, sub_schema in node.get("properties", {}).items():
            if not isinstance(sub_schema, dict):
                continue
            sub_path = [*path, name]
            if sub_schema.get("type") == "array":
                yield sub_path
                yield from walk(sub_schema.get("items", {}), sub_path)
            else:
                yield from walk(sub_schema, sub_path)

    return list(walk(schema, []))
def sample_from_record_array(current_object, record_array, k):
    """Return up to ``k`` randomly chosen records from the array at ``record_array``.

    The array is looked up with ``schema_builder.get_subobject``. When it holds
    ``k`` or fewer records, it is returned unchanged: the same object, not a
    copy or a shuffle.
    """
    records = schema_builder.get_subobject(current_object, record_array)
    if len(records) <= k:
        return records
    return random.sample(records, k)