Coverage for intelligence_toolkit/generate_mock_data/data_generator.py: 51%

109 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2import random 

3from json import loads 

4 

5import pandas as pd 

6 

7import intelligence_toolkit.AI.utils as utils 

8import intelligence_toolkit.generate_mock_data.prompts as prompts 

9import intelligence_toolkit.generate_mock_data.schema_builder as schema_builder 

10from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

11 

12 

def _parse_ai_json(response_text):
    """Parse a model response as JSON.

    Raises:
        ValueError: if ``response_text`` cannot be parsed as JSON; the
            underlying parse error is chained as the cause.
    """
    try:
        return loads(response_text)
    except (TypeError, ValueError) as e:
        # JSONDecodeError subclasses ValueError; TypeError covers a non-str
        # response (e.g. None).
        msg = f"AI did not return a valid JSON response. Please try again. {e}"
        raise ValueError(msg) from e


def _extract_all_dfs(json_obj, record_arrays):
    """Map each dotted record-array path (e.g. ``"a.b"``) to its DataFrame."""
    return {
        ".".join(record_array): extract_df(json_obj, record_array)
        for record_array in record_arrays
    }


async def generate_data(
    ai_configuration,
    generation_guidance,
    data_schema,
    num_records_overall,
    records_per_batch,
    duplicate_records_per_batch,
    related_records_per_batch,
    temperature,
    df_update_callback,
    callback_batch,
    parallel_batches=5,
):
    """Generate mock data matching ``data_schema`` in seeded batches.

    The first batch is produced (synchronously) without seeds. Further
    rounds sample records from the primary record array of the accumulated
    object and use each as the seed for one parallel generation request,
    until at least ``num_records_overall`` records have been requested.
    Every response is merged into one JSON object with
    ``merge_json_objects``.

    Args:
        ai_configuration: Model configuration forwarded to the AI helpers.
        generation_guidance: Free-text guidance inserted into the prompts.
        data_schema: JSON schema used for strict structured output. Its
            first array field (per ``extract_array_fields``) is the
            primary record array.
        num_records_overall: Target total number of primary records.
        records_per_batch: Records requested per generation call.
        duplicate_records_per_batch: Near duplicates of the seed requested
            per seeded call.
        related_records_per_batch: Close relations of the seed requested
            per seeded call.
        temperature: Sampling temperature forwarded to the model.
        df_update_callback: Optional callable invoked with the dict of
            DataFrames after the first batch and after every round.
        callback_batch: Optional progress callback forwarded to
            ``generate_seeded_data``.
        parallel_batches: Maximum number of seeded calls per round.

    Returns:
        ``(merged_json_object, dfs)``. ``dfs`` maps each dotted array path
        to a DataFrame extracted from the merged object.

    Raises:
        ValueError: if any model response is not valid JSON.
    """
    record_arrays = extract_array_fields(data_schema)
    primary_record_array = record_arrays[0]

    first_object = generate_unseeded_data(
        ai_configuration=ai_configuration,
        generation_guidance=generation_guidance,
        primary_record_array=primary_record_array,
        total_records=records_per_batch,
        data_schema=data_schema,
        temperature=temperature,
    )
    current_object_json = _parse_ai_json(first_object)

    # One dict object is kept and updated in place, so callers holding a
    # reference from the callback see later updates.
    dfs = {}
    dfs.update(_extract_all_dfs(current_object_json, record_arrays))
    if df_update_callback is not None:
        df_update_callback(dfs)

    num_records = records_per_batch
    while num_records < num_records_overall:
        remainder = num_records_overall - num_records
        # Ceiling division: number of batches still needed for the remainder.
        required = -(-remainder // records_per_batch)
        batches = min(parallel_batches, int(required))
        sample_records = sample_from_record_array(
            current_object_json, primary_record_array, batches
        )
        # Count only the batches actually requested this round.
        num_records += records_per_batch * batches
        # Use each sampled record as the seed for one parallel generation.
        new_objects = await generate_seeded_data(
            ai_configuration=ai_configuration,
            sample_records=sample_records,
            generation_guidance=generation_guidance,
            primary_record_array=primary_record_array,
            total_records=records_per_batch,
            near_duplicate_records=duplicate_records_per_batch,
            close_relation_records=related_records_per_batch,
            data_schema=data_schema,
            temperature=temperature,
            callbacks=[callback_batch] if callback_batch is not None else None,
        )

        for new_object in new_objects:
            new_object_json = _parse_ai_json(new_object)
            current_object_json, _ = merge_json_objects(
                current_object_json, new_object_json
            )

        dfs.update(_extract_all_dfs(current_object_json, record_arrays))
        if df_update_callback is not None:
            df_update_callback(dfs)
    return current_object_json, dfs

96 

97 

def generate_unseeded_data(
    ai_configuration,
    generation_guidance,
    primary_record_array,
    total_records,
    data_schema,
    temperature,
):
    """Ask the model for an initial batch of records with no seed example.

    The unseeded prompt is filled with the guidance, the primary record
    array path and the record count. ``data_schema`` is enforced through a
    strict ``json_schema`` response format.

    Returns:
        Whatever ``utils.generate_text`` returns. The caller parses it as a
        JSON string (presumably the raw model text — confirm in utils).
    """
    prompt_variables = {
        "generation_guidance": generation_guidance,
        "primary_record_array": primary_record_array,
        "total_records": total_records,
    }
    messages = utils.prepare_messages(
        prompts.unseeded_data_generation_prompt, prompt_variables
    )

    schema_spec = {"name": "answer_object", "strict": True, "schema": data_schema}
    structured_output = {"type": "json_schema", "json_schema": schema_spec}

    return utils.generate_text(
        ai_configuration,
        messages,
        response_format=structured_output,
        temperature=temperature,
    )

125 

126 

async def generate_seeded_data(
    ai_configuration,
    sample_records,
    generation_guidance,
    primary_record_array,
    total_records,
    near_duplicate_records,
    close_relation_records,
    data_schema,
    temperature,
    callbacks: list[ProgressBatchCallback] | None = None,
):
    """Generate one batch of records per seed record, concurrently.

    One seeded prompt is built for each entry of ``sample_records``. All
    prompts are sent through ``utils.map_generate_text`` with a strict
    ``json_schema`` response format derived from ``data_schema``.

    Returns:
        The awaited result of ``utils.map_generate_text``. The caller
        iterates it and parses each item as JSON (presumably one response
        string per seed — confirm in utils).
    """
    schema_spec = {"name": "answer_object", "strict": True, "schema": data_schema}
    structured_output = {"type": "json_schema", "json_schema": schema_spec}

    # The target counts are the same for every seed, so build the text once.
    target_lines = [
        f"Total records: {total_records!s}",
        f"Near duplicates of seed: {near_duplicate_records!s}",
        f"Close relations of seed: {close_relation_records!s}",
    ]
    record_targets = "\n".join(target_lines)

    message_batches = []
    for seed in sample_records:
        variables = {
            "seed_record": seed,
            "generation_guidance": generation_guidance,
            "primary_record_array": primary_record_array,
            "record_targets": record_targets,
        }
        message_batches.append(
            utils.prepare_messages(prompts.seeded_data_generation_prompt, variables)
        )

    return await utils.map_generate_text(
        ai_configuration,
        message_batches,
        response_format=structured_output,
        temperature=temperature,
        callbacks=callbacks,
    )

169 

170 

def select_random_records(num_records, category_to_count):
    """Assign disjoint random record indices to each category.

    Draws ``sum(category_to_count.values())`` distinct indices from
    ``range(num_records)`` in a single ``random.sample`` call. The draw is
    then split into consecutive runs whose lengths follow the category
    counts, in the dict's iteration order.

    Returns:
        dict mapping each category to its list of indices.

    Raises:
        ValueError: from ``random.sample`` if the total count exceeds
            ``num_records``.
    """
    picked = random.sample(range(num_records), sum(category_to_count.values()))
    assignments = {}
    cursor = 0
    for category, count in category_to_count.items():
        assignments[category] = picked[cursor:cursor + count]
        cursor += count
    return assignments

180 

181 

def extract_df(json_data, record_path):
    """Flatten the records at ``record_path`` inside ``json_data`` into a DataFrame.

    ``record_path`` is passed to ``pandas.json_normalize`` unchanged. A list
    of keys walks down nested objects to reach the record list.
    """
    flattened = pd.json_normalize(json_data, record_path=record_path)
    return flattened

185 

186 

def merge_json_objects(json_obj1, json_obj2):
    """Merge two JSON-like dicts without mutating either input.

    Rules for a key present in both objects:
      * dict + dict: merged recursively;
      * list + list: concatenated (``value1 + value2``);
      * any other unequal pair: ``json_obj2``'s value wins and the key is
        recorded as a conflict;
      * equal values: ``json_obj1``'s value is kept.
    A key present in only one object is copied as-is.

    Key order is deterministic: ``json_obj1``'s keys in order, followed by
    keys that only ``json_obj2`` has, in its order. (Iterating a set union
    made the order depend on string hash randomization.)

    Returns:
        ``(merged_object, conflicts)``. ``conflicts`` lists dotted key paths
        (e.g. ``"outer.inner"``) where values disagreed.
    """
    merged_object = {}
    conflicts = []

    def merge_values(key, value1, value2):
        if isinstance(value1, dict) and isinstance(value2, dict):
            merged_value, sub_conflicts = merge_json_objects(value1, value2)
            conflicts.extend(f"{key}.{sub_key}" for sub_key in sub_conflicts)
            return merged_value
        if isinstance(value1, list) and isinstance(value2, list):
            return value1 + value2
        if value1 != value2:
            conflicts.append(key)
            return value2
        return value1

    # dict.fromkeys gives an insertion-ordered union of the two key sets.
    for key in dict.fromkeys([*json_obj1, *json_obj2]):
        if key in json_obj1 and key in json_obj2:
            merged_object[key] = merge_values(key, json_obj1[key], json_obj2[key])
        elif key in json_obj1:
            merged_object[key] = json_obj1[key]
        else:
            merged_object[key] = json_obj2[key]

    return merged_object, conflicts

216 

217 

def extract_array_fields(schema: dict) -> list[list[str]]:
    """Return the key path to every ``"type": "array"`` property in a JSON schema.

    Paths are produced in depth-first pre-order: an array's own path comes
    before any arrays nested inside its ``items`` schema. Non-array
    properties are descended into via their own ``properties``. A list
    node has each element walked with the same path prefix.
    """

    def _walk(node, path):
        if isinstance(node, list):
            for element in node:
                yield from _walk(element, path)
            return
        if not isinstance(node, dict):
            return
        for name, sub_schema in node.get("properties", {}).items():
            if not isinstance(sub_schema, dict):
                continue
            sub_path = [*path, name]
            if sub_schema.get("type") == "array":
                yield sub_path
                yield from _walk(sub_schema.get("items", {}), sub_path)
            else:
                yield from _walk(sub_schema, sub_path)

    return list(_walk(schema, []))

241 

242 

def sample_from_record_array(current_object, record_array, k):
    """Pick up to ``k`` records from the array at ``record_array`` in ``current_object``.

    If the array holds ``k`` or fewer records, the list returned by
    ``schema_builder.get_subobject`` is returned as-is (not a copy).
    Otherwise ``k`` distinct records are drawn with ``random.sample``.
    """
    records = schema_builder.get_subobject(current_object, record_array)
    if len(records) <= k:
        return records
    return random.sample(records, k)