Coverage for src/dataknobs_fsm/functions/library/transformers.py: 0%

274 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 14:11 -0700

1"""Built-in transformer functions for FSM. 

2 

3This module provides commonly used transformation functions that can be 

4referenced in FSM configurations. 

5""" 

6 

7import copy 

8import json 

9import re 

10from datetime import datetime 

11from typing import Any, Callable, Dict, List, Union 

12 

13from dataknobs_fsm.functions.base import ITransformFunction, TransformError 

14 

15 

class FieldMapper(ITransformFunction):
    """Map fields from source to target names.

    Both source and target names may use dot notation ("a.b.c") to read
    from / write into nested dictionaries.
    """

    # Sentinel distinguishing "path not present" from a stored None value.
    _MISSING = object()

    def __init__(
        self,
        field_map: Dict[str, str],
        drop_unmapped: bool = False,
        copy_unmapped: bool = True,
    ):
        """Initialize the field mapper.

        Args:
            field_map: Dictionary mapping source field names to target names.
                Dotted names address nested dictionaries on either side.
            drop_unmapped: If True, drop fields not in the mapping.
            copy_unmapped: If True, copy unmapped fields as-is.
        """
        self.field_map = field_map
        self.drop_unmapped = drop_unmapped
        self.copy_unmapped = copy_unmapped

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by mapping field names.

        Args:
            data: Input data.

        Returns:
            Transformed data with mapped field names.
        """
        result: Dict[str, Any] = {}

        # Map specified fields.
        for source, target in self.field_map.items():
            # BUG FIX: a dotted source path ("a.b") was previously gated on
            # `source in data`, i.e. a literal top-level key "a.b", so nested
            # source lookups never ran. Resolve the dotted path directly and
            # skip the mapping only when the path is genuinely absent.
            if "." in source:
                value = self._get_nested(data, source)
                if value is self._MISSING:
                    continue
            elif source in data:
                value = data[source]
            else:
                continue

            if "." in target:
                self._set_nested(result, target, value)
            else:
                result[target] = value

        # Handle unmapped fields (copied only when both flags permit).
        if not self.drop_unmapped and self.copy_unmapped:
            for key, value in data.items():
                if key not in self.field_map and key not in result:
                    result[key] = value

        return result

    def _get_nested(self, data: Dict, path: str) -> Any:
        """Get value from nested dict via dot notation; _MISSING if absent."""
        value: Any = data
        for part in path.split("."):
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return self._MISSING
        return value

    def _set_nested(self, data: Dict, path: str, value: Any) -> None:
        """Set value in nested dictionary using dot notation.

        Intermediate dictionaries are created as needed.
        """
        parts = path.split(".")
        current = data
        for part in parts[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]
        current[parts[-1]] = value

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        mappings = list(self.field_map.items())
        mapping_str = ", ".join(f"{s}->{t}" for s, t in mappings[:3])
        if len(mappings) > 3:
            mapping_str += "..."
        return f"Map fields: {mapping_str}"

97 

98 

class ValueNormalizer(ITransformFunction):
    """Normalize string values in data fields."""

    def __init__(
        self,
        normalizations: Dict[str, str],
        fields: List[str] | None = None,
    ):
        """Initialize the value normalizer.

        Args:
            normalizations: Mapping from field name (or "*" as a wildcard)
                to a normalization name or list of names. Supported names:
                - "lowercase": Convert to lowercase
                - "uppercase": Convert to uppercase
                - "trim": Remove leading/trailing whitespace
                - "snake_case": Convert to snake_case
                - "camel_case": Convert to camelCase
                - "pascal_case": Convert to PascalCase
                - "remove_special": Remove special characters
                - "normalize_spaces": Replace multiple spaces with single space
            fields: Fields to normalize. If None, apply to all string fields.
        """
        self.normalizations = normalizations
        self.fields = fields

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with string values normalized.

        Args:
            data: Input data.

        Returns:
            Transformed data with normalized values.
        """
        result = copy.deepcopy(data)

        # Restrict to the configured fields, or touch every key.
        targets = self.fields if self.fields else list(result.keys())

        for name in targets:
            try:
                current = result[name]
            except KeyError:
                continue

            # Only string values are normalized; everything else passes through.
            if not isinstance(current, str):
                continue

            # Field-specific rules win; "*" provides a shared fallback.
            rules = self.normalizations.get(
                name, self.normalizations.get("*", [])
            )
            if isinstance(rules, str):
                rules = [rules]

            for rule in rules:
                current = self._apply_normalization(current, rule)

            result[name] = current

        return result

    def _apply_normalization(self, value: str, normalization: str) -> str:
        """Apply one normalization rule; unknown rules leave the value as-is."""
        simple = {
            "lowercase": str.lower,
            "uppercase": str.upper,
            "trim": str.strip,
        }
        if normalization in simple:
            return simple[normalization](value)
        if normalization == "snake_case":
            # Insert underscores at word boundaries, then lowercase.
            interim = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', value)
            return re.sub('([a-z0-9])([A-Z])', r'\1_\2', interim).lower()
        if normalization == "camel_case":
            words = value.replace("-", "_").split("_")
            return words[0].lower() + "".join(w.capitalize() for w in words[1:])
        if normalization == "pascal_case":
            words = value.replace("-", "_").split("_")
            return "".join(w.capitalize() for w in words)
        if normalization == "remove_special":
            return re.sub(r'[^a-zA-Z0-9\s]', '', value)
        if normalization == "normalize_spaces":
            return re.sub(r'\s+', ' ', value).strip()
        return value

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        fields = self.fields if self.fields else ["all fields"]
        norm_types = set()
        for val in self.normalizations.values():
            if isinstance(val, list):
                norm_types.update(val)
            else:
                norm_types.add(val)
        return f"Normalize {', '.join(str(f) for f in fields[:3])} using {', '.join(list(norm_types)[:3])}"

198 

199 

class TypeConverter(ITransformFunction):
    """Convert field types in data."""

    # Recognized type-name strings and the converters they resolve to.
    # Unknown names fall back to str, matching the original behavior.
    _NAMED_CONVERTERS = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
        "datetime": datetime.fromisoformat,
        "json": json.loads,
    }

    def __init__(
        self,
        conversions: Dict[str, Union[str, type, Callable]],
        strict: bool = False,
    ):
        """Initialize the type converter.

        Args:
            conversions: Dictionary mapping field names to target types.
                Can be type names (str, int, float, bool, list, dict),
                type objects, or callable converters.
            strict: If True, raise error on conversion failure.
        """
        self.conversions = conversions
        self.strict = strict

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with configured fields converted.

        Args:
            data: Input data.

        Returns:
            Transformed data with converted types.

        Raises:
            TransformError: If a conversion fails and ``strict`` is True.
        """
        result = copy.deepcopy(data)

        for field, target_type in self.conversions.items():
            if field not in result:
                continue

            try:
                result[field] = self._convert_value(result[field], target_type)
            except Exception as e:
                if self.strict:
                    raise TransformError(
                        f"Failed to convert field '{field}': {e}"
                    ) from e
                # Non-strict: leave the original value untouched on failure.

        return result

    def _convert_value(self, value: Any, target_type: Union[str, type, Callable]) -> Any:
        """Convert a single value to the target type."""
        # None is never converted.
        if value is None:
            return None

        # A plain callable (not a class) is used as the converter directly.
        if callable(target_type) and not isinstance(target_type, type):
            return target_type(value)

        # Resolve type-name strings via the lookup table.
        if isinstance(target_type, str):
            target_type = self._NAMED_CONVERTERS.get(target_type, str)

        if isinstance(value, str):
            # "true"/"yes"/"1"/"on" (any case) become True; all else False.
            if target_type is bool:
                return value.lower() in ["true", "yes", "1", "on"]
            # ISO-8601 strings parse into datetime objects.
            if target_type is datetime.fromisoformat:
                return datetime.fromisoformat(value)

        # Standard constructor-style conversion.
        return target_type(value)  # type: ignore

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        conversions = list(self.conversions.items())
        conv_str = ", ".join(f"{k}:{v}" for k, v in conversions[:3])
        if len(conversions) > 3:
            conv_str += "..."
        return f"Convert types: {conv_str}"

287 

288 

class DataEnricher(ITransformFunction):
    """Enrich data with additional fields."""

    def __init__(
        self,
        enrichments: Dict[str, Any],
        overwrite: bool = False,
    ):
        """Initialize the data enricher.

        Args:
            enrichments: Dictionary of fields to add/update. A callable value
                is invoked with the original input data to compute the field.
            overwrite: If True, overwrite existing fields.
        """
        self.enrichments = enrichments
        self.overwrite = overwrite

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with enrichment fields applied.

        Args:
            data: Input data.

        Returns:
            Transformed data with enrichments.

        Raises:
            TransformError: If a callable enrichment raises.
        """
        result = copy.deepcopy(data)

        for field, value in self.enrichments.items():
            # Existing fields are preserved unless overwrite is enabled.
            if field in result and not self.overwrite:
                continue

            if not callable(value):
                result[field] = value
                continue

            # Callables receive the *original* input data, not the copy.
            try:
                result[field] = value(data)
            except Exception as e:
                raise TransformError(
                    f"Failed to compute enrichment for '{field}': {e}"
                ) from e

        return result

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        names = list(self.enrichments.keys())
        shown = ', '.join(names[:3])
        suffix = '...' if len(names) > 3 else ''
        return f"Enrich data with fields: {shown}{suffix}"

340 

341 

class FieldFilter(ITransformFunction):
    """Filter fields from data."""

    def __init__(
        self,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
    ):
        """Initialize the field filter.

        Args:
            include: List of fields to include (whitelist).
            exclude: List of fields to exclude (blacklist).

        Raises:
            ValueError: If both include and exclude are given.
        """
        if include and exclude:
            raise ValueError("Cannot specify both include and exclude")

        self.include = include
        self.exclude = exclude

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``data`` restricted by the configured field filter.

        Args:
            data: Input data.

        Returns:
            Transformed data with filtered fields.
        """
        if self.include:
            # Whitelist mode: keep only the listed fields.
            return {key: val for key, val in data.items() if key in self.include}
        if self.exclude:
            # Blacklist mode: drop the listed fields.
            return {key: val for key, val in data.items() if key not in self.exclude}
        # Neither list configured: shallow copy, unfiltered.
        return data.copy()

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        if self.include:
            return f"Include only fields: {self._preview(self.include)}"
        if self.exclude:
            return f"Exclude fields: {self._preview(self.exclude)}"
        return "No field filtering"

    @staticmethod
    def _preview(names: List[str]) -> str:
        """Render up to three field names, with an ellipsis when truncated."""
        shown = ', '.join(names[:3])
        if len(names) > 3:
            shown += "..."
        return shown

395 

396 

class ValueReplacer(ITransformFunction):
    """Replace specific values in data fields."""

    def __init__(
        self,
        replacements: Dict[str, Dict[Any, Any]],
        default_replacements: Dict[Any, Any] | None = None,
    ):
        """Initialize the value replacer.

        Args:
            replacements: Dictionary mapping field names to replacement mappings.
            default_replacements: Default replacements for all fields.
        """
        self.replacements = replacements
        self.default_replacements = default_replacements or {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by replacing values.

        Args:
            data: Input data.

        Returns:
            Transformed data with replaced values.
        """
        result = copy.deepcopy(data)

        for field, value in result.items():
            # Field-specific replacements win; otherwise use the defaults.
            field_replacements = self.replacements.get(field, self.default_replacements)

            # BUG FIX: the previous `value in field_replacements` test raised
            # TypeError for unhashable values (lists, dicts). EAFP lookup
            # treats both "not mapped" and "unhashable" as "no replacement".
            try:
                result[field] = field_replacements[value]
            except (KeyError, TypeError):
                continue

        return result

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        fields = list(self.replacements.keys())[:3]
        field_str = ', '.join(fields)
        if len(self.replacements) > 3:
            field_str += "..."
        return f"Replace values in fields: {field_str if fields else 'all fields'}"

441 

442 

class ArrayFlattener(ITransformFunction):
    """Flatten nested arrays in data."""

    def __init__(
        self,
        fields: List[str],
        depth: int = 1,
    ):
        """Initialize the array flattener.

        Args:
            fields: List of fields containing arrays to flatten.
            depth: Number of levels to flatten (0 = fully flatten).
        """
        self.fields = fields
        self.depth = depth

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with configured arrays flattened.

        Args:
            data: Input data.

        Returns:
            Transformed data with flattened arrays.
        """
        result = copy.deepcopy(data)

        for name in self.fields:
            if name not in result:
                continue
            candidate = result[name]
            # Only list values are flattened; other types pass through.
            if isinstance(candidate, list):
                result[name] = self._flatten(candidate, self.depth)

        return result

    def _flatten(self, arr: List, depth: int) -> List:
        """Flatten ``arr`` recursively.

        depth == 0 flattens completely; depth > 1 recurses one level at a
        time; depth <= 1 unwraps exactly one level of nesting.
        """
        flattened: List = []
        for element in arr:
            if not isinstance(element, list):
                flattened.append(element)
            elif depth == 0:
                # Full flatten: recurse without limit.
                flattened.extend(self._flatten(element, 0))
            elif depth > 1:
                # Partial flatten: consume one level per recursion.
                flattened.extend(self._flatten(element, depth - 1))
            else:
                # Last level: splice the sublist in as-is.
                flattened.extend(element)
        return flattened

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        fields = ', '.join(self.fields[:3])
        if len(self.fields) > 3:
            fields += "..."
        depth_str = "fully" if self.depth == 0 else f"to depth {self.depth}"
        return f"Flatten arrays in {fields} {depth_str}"

511 

512 

class DataSplitter(ITransformFunction):
    """Split data into multiple records based on a field."""

    def __init__(
        self,
        split_field: str,
        output_field: str = "records",
    ):
        """Initialize the data splitter.

        Args:
            split_field: Field containing array to split on.
            output_field: Name of output field containing split records.
        """
        self.split_field = split_field
        self.output_field = output_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Fan ``data`` out into one record per element of the split field.

        Args:
            data: Input data.

        Returns:
            A dict with a single key (``output_field``) holding the records.

        Raises:
            TransformError: If the split field is missing or not a list.
        """
        if self.split_field not in data:
            raise TransformError(f"Split field '{self.split_field}' not found")

        values = data[self.split_field]
        if not isinstance(values, list):
            raise TransformError("Split field must be a list")

        # Everything except the split field is shared by every record.
        template = {k: v for k, v in data.items() if k != self.split_field}

        records = []
        for item in values:
            # Deep-copy per record so records never share mutable state.
            entry = copy.deepcopy(template)
            entry[self.split_field] = item
            records.append(entry)

        return {self.output_field: records}

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        return f"Split data on field '{self.split_field}' into '{self.output_field}'"

560 

561 

class ChainTransformer(ITransformFunction):
    """Chain multiple transformers together."""

    def __init__(self, transformers: List[ITransformFunction]):
        """Initialize the chain transformer.

        Args:
            transformers: List of transformers to apply in sequence.
        """
        self.transformers = transformers

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Pipe ``data`` through every transformer, in order.

        Args:
            data: Input data.

        Returns:
            Transformed data after all transformers.
        """
        current = data
        for step in self.transformers:
            current = step.transform(current)
        return current

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        count = len(self.transformers)
        plural = "" if count == 1 else "s"
        return f"Chain {count} transformer{plural} in sequence"

591 

592 

593# Convenience functions for creating transformers 

def map_fields(mapping: Dict[str, str], **kwargs) -> FieldMapper:
    """Build a FieldMapper from a source->target name mapping."""
    mapper = FieldMapper(mapping, **kwargs)
    return mapper

597 

598 

def normalize(**normalizations: str) -> ValueNormalizer:
    """Build a ValueNormalizer from field=rule keyword arguments."""
    normalizer = ValueNormalizer(normalizations)
    return normalizer

602 

603 

def convert_types(**conversions: Union[str, type, Callable]) -> TypeConverter:
    """Build a TypeConverter from field=type keyword arguments."""
    converter = TypeConverter(conversions)
    return converter

607 

608 

def enrich(**enrichments: Any) -> DataEnricher:
    """Build a DataEnricher from field=value keyword arguments."""
    enricher = DataEnricher(enrichments)
    return enricher

612 

613 

def filter_fields(include: List[str] | None = None, exclude: List[str] | None = None) -> FieldFilter:
    """Build a FieldFilter with an include whitelist or exclude blacklist."""
    field_filter = FieldFilter(include, exclude)
    return field_filter

617 

618 

def replace_values(**replacements: Dict[Any, Any]) -> ValueReplacer:
    """Build a ValueReplacer from field=mapping keyword arguments."""
    replacer = ValueReplacer(replacements)
    return replacer

622 

623 

def flatten(*fields: str, depth: int = 1) -> ArrayFlattener:
    """Build an ArrayFlattener over the named fields (depth 0 = fully)."""
    flattener = ArrayFlattener([*fields], depth)
    return flattener

627 

628 

def split_on(field: str, output: str = "records") -> DataSplitter:
    """Build a DataSplitter splitting on ``field`` into ``output``."""
    splitter = DataSplitter(field, output)
    return splitter

632 

633 

def chain(*transformers: ITransformFunction) -> ChainTransformer:
    """Build a ChainTransformer applying the given transformers in order."""
    chained = ChainTransformer([*transformers])
    return chained