Coverage for src/dataknobs_fsm/functions/library/transformers.py: 0%
274 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 14:11 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 14:11 -0700
1"""Built-in transformer functions for FSM.
3This module provides commonly used transformation functions that can be
4referenced in FSM configurations.
5"""
7import copy
8import json
9import re
10from datetime import datetime
11from typing import Any, Callable, Dict, List, Union
13from dataknobs_fsm.functions.base import ITransformFunction, TransformError
class FieldMapper(ITransformFunction):
    """Map fields from source names to target names.

    Both source and target names may use dot notation (for example
    ``"user.name"``) to address a value inside nested dictionaries.
    """

    # Sentinel that distinguishes "nothing found" from a stored ``None``.
    _MISSING = object()

    def __init__(
        self,
        field_map: Dict[str, str],
        drop_unmapped: bool = False,
        copy_unmapped: bool = True,
    ):
        """Initialize the field mapper.

        Args:
            field_map: Dictionary mapping source field names to target names.
            drop_unmapped: If True, drop fields not in the mapping.
            copy_unmapped: If True, copy unmapped fields as-is.
        """
        self.field_map = field_map
        self.drop_unmapped = drop_unmapped
        self.copy_unmapped = copy_unmapped

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by mapping field names.

        Args:
            data: Input data.

        Returns:
            A new dictionary with mapped field names. Values are not copied,
            so mutable values are shared with ``data``.
        """
        result: Dict[str, Any] = {}

        # Map specified fields; sources that cannot be found are skipped.
        for source, target in self.field_map.items():
            value = self._lookup(data, source)
            if value is self._MISSING:
                continue

            if "." in target:
                self._set_nested(result, target, value)
            else:
                result[target] = value

        # Unmapped fields are carried over only when neither flag forbids it.
        # Only top-level keys are compared against the mapping, so the parent
        # of a dotted source (e.g. "user" for "user.name") is still copied.
        if not self.drop_unmapped and self.copy_unmapped:
            for key, value in data.items():
                if key not in self.field_map and key not in result:
                    result[key] = value

        return result

    def _lookup(self, data: Dict[str, Any], source: str) -> Any:
        """Find the value for ``source`` or return the ``_MISSING`` sentinel.

        A dotted source is first followed as a nested path. If that path does
        not exist, a top-level key with that exact name is used instead.
        """
        if "." in source:
            value = self._walk(data, source)
            if value is not self._MISSING:
                return value
        return data[source] if source in data else self._MISSING

    def _walk(self, data: Any, path: str) -> Any:
        """Follow a dotted path through nested dicts; ``_MISSING`` if absent."""
        current = data
        for part in path.split("."):
            if not isinstance(current, dict) or part not in current:
                return self._MISSING
            current = current[part]
        return current

    def _get_nested(self, data: Dict, path: str) -> Any:
        """Get value from nested dictionary using dot notation.

        Returns:
            The value at ``path``, or None when any step of the path is absent.
        """
        value = self._walk(data, path)
        return None if value is self._MISSING else value

    def _set_nested(self, data: Dict, path: str, value: Any) -> None:
        """Set value in nested dictionary using dot notation.

        Intermediate dictionaries are created when they do not exist yet.
        """
        *parents, leaf = path.split(".")
        current = data
        for part in parents:
            if part not in current:
                current[part] = {}
            current = current[part]
        current[leaf] = value

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            A summary listing at most the first three mappings.
        """
        mappings = list(self.field_map.items())
        mapping_str = ", ".join(f"{s}->{t}" for s, t in mappings[:3])
        if len(mappings) > 3:
            mapping_str += "..."
        return f"Map fields: {mapping_str}"
class ValueNormalizer(ITransformFunction):
    """Normalize string values in data fields."""

    def __init__(
        self,
        normalizations: Dict[str, Union[str, List[str]]],
        fields: List[str] | None = None,
    ):
        """Initialize the value normalizer.

        Args:
            normalizations: Dictionary mapping a field name (or ``"*"`` as the
                fallback for every other field) to one normalization name or
                a list of names applied in order:
                - "lowercase": Convert to lowercase
                - "uppercase": Convert to uppercase
                - "trim": Remove leading/trailing whitespace
                - "snake_case": Convert to snake_case
                - "camel_case": Convert to camelCase
                - "pascal_case": Convert to PascalCase
                - "remove_special": Remove special characters
                - "normalize_spaces": Replace multiple spaces with single space
            fields: List of fields to normalize. If None, apply to all string fields.
        """
        self.normalizations = normalizations
        self.fields = fields

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by normalizing values.

        Args:
            data: Input data.

        Returns:
            A deep copy of ``data`` with the selected string fields normalized.
        """
        result = copy.deepcopy(data)

        # An empty or missing field list means "every top-level field".
        fields_to_process = self.fields if self.fields else list(result.keys())

        for field in fields_to_process:
            if field not in result:
                continue

            value = result[field]
            if not isinstance(value, str):
                continue

            # A field-specific entry wins over the "*" wildcard entry.
            field_normalizations = self.normalizations.get(
                field, self.normalizations.get("*", [])
            )
            if isinstance(field_normalizations, str):
                field_normalizations = [field_normalizations]

            for normalization in field_normalizations:
                value = self._apply_normalization(value, normalization)

            result[field] = value

        return result

    def _apply_normalization(self, value: str, normalization: str) -> str:
        """Apply a single normalization to a value.

        Unknown normalization names leave the value unchanged.
        """
        if normalization == "lowercase":
            return value.lower()
        if normalization == "uppercase":
            return value.upper()
        if normalization == "trim":
            return value.strip()
        if normalization == "snake_case":
            # First split before capitalized words, then before any capital
            # that directly follows a lowercase letter or digit.
            s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', value)
            return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
        if normalization == "camel_case":
            parts = value.replace("-", "_").split("_")
            return parts[0].lower() + "".join(p.capitalize() for p in parts[1:])
        if normalization == "pascal_case":
            parts = value.replace("-", "_").split("_")
            return "".join(p.capitalize() for p in parts)
        if normalization == "remove_special":
            return re.sub(r'[^a-zA-Z0-9\s]', '', value)
        if normalization == "normalize_spaces":
            return re.sub(r'\s+', ' ', value).strip()
        return value

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            A summary naming at most three fields and three normalizations.
        """
        fields = self.fields if self.fields else ["all fields"]

        # De-duplicate in first-seen order. A set of strings would make the
        # description change between processes because of hash randomization.
        norm_types: Dict[Any, None] = {}
        for spec in self.normalizations.values():
            names = spec if isinstance(spec, (list, tuple)) else [spec]
            norm_types.update(dict.fromkeys(names))

        return f"Normalize {', '.join(str(f) for f in fields[:3])} using {', '.join(list(norm_types)[:3])}"
class TypeConverter(ITransformFunction):
    """Convert the types of selected fields."""

    # Converters selectable by name; unknown names fall back to ``str``.
    _NAMED_CONVERTERS = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
        "datetime": datetime.fromisoformat,
        "json": json.loads,
    }

    def __init__(
        self,
        conversions: Dict[str, Union[str, type, Callable]],
        strict: bool = False,
    ):
        """Store the conversion table.

        Args:
            conversions: Field name to target type. A target may be a type
                name (str, int, float, bool, list, dict, datetime, json), a
                type object, or any callable taking the old value.
            strict: When True a failed conversion raises; otherwise the
                original value is kept.
        """
        self.conversions = conversions
        self.strict = strict

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with the configured fields converted.

        Args:
            data: Input data.

        Returns:
            The converted copy; fields absent from ``data`` are ignored.

        Raises:
            TransformError: In strict mode, when a conversion fails.
        """
        converted = copy.deepcopy(data)

        for name, wanted in self.conversions.items():
            if name in converted:
                try:
                    converted[name] = self._convert_value(converted[name], wanted)
                except Exception as e:
                    # Non-strict mode deliberately leaves the old value alone.
                    if self.strict:
                        raise TransformError(
                            f"Failed to convert field '{name}': {e}"
                        ) from e

        return converted

    def _convert_value(self, value: Any, target_type: Union[str, type, Callable]) -> Any:
        """Convert one value; ``None`` always stays ``None``."""
        if value is None:
            return None

        # Plain callables (anything callable that is not a class) run as-is.
        if callable(target_type) and not isinstance(target_type, type):
            return target_type(value)

        converter = target_type
        if isinstance(converter, str):
            converter = self._NAMED_CONVERTERS.get(converter, str)

        # bool("false") would be True, so strings are matched against words.
        if converter == bool and isinstance(value, str):
            return value.lower() in ["true", "yes", "1", "on"]

        return converter(value)  # type: ignore

    def get_transform_description(self) -> str:
        """Summarize the first three conversions."""
        pairs = list(self.conversions.items())
        shown = ", ".join(f"{k}:{v}" for k, v in pairs[:3])
        suffix = "..." if len(pairs) > 3 else ""
        return f"Convert types: {shown}{suffix}"
class DataEnricher(ITransformFunction):
    """Add extra fields to a record."""

    def __init__(
        self,
        enrichments: Dict[str, Any],
        overwrite: bool = False,
    ):
        """Store the enrichment table.

        Args:
            enrichments: Field name to value. A callable value is invoked
                with the input record and its result is stored instead.
            overwrite: When True, fields already present are replaced.
        """
        self.enrichments = enrichments
        self.overwrite = overwrite

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with the enrichment fields applied.

        Args:
            data: Input data.

        Returns:
            The enriched copy.

        Raises:
            TransformError: When a callable enrichment raises.
        """
        enriched = copy.deepcopy(data)

        for name, spec in self.enrichments.items():
            # Existing fields are protected unless overwriting was requested.
            if not self.overwrite and name in enriched:
                continue

            if not callable(spec):
                enriched[name] = spec
                continue

            # Callables receive the original input, not the partial result.
            try:
                enriched[name] = spec(data)
            except Exception as e:
                raise TransformError(
                    f"Failed to compute enrichment for '{name}': {e}"
                ) from e

        return enriched

    def get_transform_description(self) -> str:
        """Summarize the first three enrichment fields."""
        names = list(self.enrichments.keys())
        shown = ', '.join(names[:3])
        suffix = '...' if len(names) > 3 else ''
        return f"Enrich data with fields: {shown}{suffix}"
class FieldFilter(ITransformFunction):
    """Keep or remove top-level fields of a record."""

    def __init__(
        self,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
    ):
        """Store the filter lists.

        Args:
            include: Fields to keep (whitelist).
            exclude: Fields to remove (blacklist).

        Raises:
            ValueError: When both lists are non-empty.
        """
        if include and exclude:
            raise ValueError("Cannot specify both include and exclude")

        self.include = include
        self.exclude = exclude

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dictionary holding only the permitted fields.

        Args:
            data: Input data.

        Returns:
            The filtered fields, or a shallow copy of ``data`` when neither
            list is set (an empty list counts as not set).
        """
        if self.include:
            kept = {}
            for key, value in data.items():
                if key in self.include:
                    kept[key] = value
            return kept

        if self.exclude:
            kept = {}
            for key, value in data.items():
                if key not in self.exclude:
                    kept[key] = value
            return kept

        return data.copy()

    def get_transform_description(self) -> str:
        """Summarize the active filter, listing at most three field names."""
        if self.include:
            label, names = "Include only fields", self.include
        elif self.exclude:
            label, names = "Exclude fields", self.exclude
        else:
            return "No field filtering"

        shown = ', '.join(names[:3])
        if len(names) > 3:
            shown += "..."
        return f"{label}: {shown}"
class ValueReplacer(ITransformFunction):
    """Replace specific values in data fields."""

    def __init__(
        self,
        replacements: Dict[str, Dict[Any, Any]],
        default_replacements: Dict[Any, Any] | None = None,
    ):
        """Initialize the value replacer.

        Args:
            replacements: Dictionary mapping field names to replacement mappings.
            default_replacements: Replacements used for every field that has
                no entry of its own in ``replacements``.
        """
        self.replacements = replacements
        self.default_replacements = default_replacements or {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by replacing values.

        Only top-level field values are compared against the replacement
        mappings.

        Args:
            data: Input data.

        Returns:
            A deep copy of ``data`` with matching values replaced.
        """
        result = copy.deepcopy(data)

        # Only existing keys are reassigned, so the dictionary size does not
        # change while it is being iterated.
        for field, value in result.items():
            field_replacements = self.replacements.get(field, self.default_replacements)

            try:
                matched = value in field_replacements
            except TypeError:
                # Unhashable values (lists, dicts, sets) can never be keys of
                # the replacement mapping, so there is nothing to replace.
                continue

            if matched:
                result[field] = field_replacements[value]

        return result

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            A summary naming at most three fields, or "all fields" when no
            field-specific replacements are configured.
        """
        fields = list(self.replacements.keys())[:3]
        field_str = ', '.join(fields)
        if len(self.replacements) > 3:
            field_str += "..."
        return f"Replace values in fields: {field_str if fields else 'all fields'}"
class ArrayFlattener(ITransformFunction):
    """Flatten nested lists stored in selected fields."""

    def __init__(
        self,
        fields: List[str],
        depth: int = 1,
    ):
        """Store the flattening configuration.

        Args:
            fields: Names of the fields whose lists are flattened.
            depth: How many nesting levels to remove (0 = remove all).
        """
        self.fields = fields
        self.depth = depth

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``data`` with the configured lists flattened.

        Args:
            data: Input data.

        Returns:
            The copy; missing fields and non-list values are left untouched.
        """
        output = copy.deepcopy(data)

        for name in self.fields:
            if name in output and isinstance(output[name], list):
                output[name] = self._flatten(output[name], self.depth)

        return output

    def _flatten(self, arr: List, depth: int) -> List:
        """Flatten ``arr`` by ``depth`` levels; depth 0 flattens completely."""
        flat: List = []
        for element in arr:
            if not isinstance(element, list):
                flat.append(element)
            elif depth == 0:
                # Unlimited: keep descending until no lists remain.
                flat.extend(self._flatten(element, 0))
            elif depth > 1:
                flat.extend(self._flatten(element, depth - 1))
            else:
                # Last permitted level: splice the inner list in unchanged.
                flat.extend(element)
        return flat

    def get_transform_description(self) -> str:
        """Summarize the fields (at most three) and the flattening depth."""
        shown = ', '.join(self.fields[:3])
        if len(self.fields) > 3:
            shown += "..."
        if self.depth == 0:
            depth_str = "fully"
        else:
            depth_str = f"to depth {self.depth}"
        return f"Flatten arrays in {shown} {depth_str}"
class DataSplitter(ITransformFunction):
    """Turn one record holding a list into one record per list element."""

    def __init__(
        self,
        split_field: str,
        output_field: str = "records",
    ):
        """Store the field names.

        Args:
            split_field: Field whose list is split.
            output_field: Key under which the produced records are returned.
        """
        self.split_field = split_field
        self.output_field = output_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Produce one record per element of the split field.

        Args:
            data: Input data.

        Returns:
            A dictionary with a single key, ``output_field``, holding the
            list of records. Each record is a deep copy of the other fields
            plus one element stored under ``split_field``.

        Raises:
            TransformError: When the split field is missing or not a list.
        """
        if self.split_field not in data:
            raise TransformError(f"Split field '{self.split_field}' not found")

        elements = data[self.split_field]
        if not isinstance(elements, list):
            raise TransformError("Split field must be a list")

        shared = {k: v for k, v in data.items() if k != self.split_field}

        # Each record gets its own deep copy so records never share state.
        records = [
            {**copy.deepcopy(shared), self.split_field: element}
            for element in elements
        ]
        return {self.output_field: records}

    def get_transform_description(self) -> str:
        """Describe which field is split and where the records are placed."""
        return f"Split data on field '{self.split_field}' into '{self.output_field}'"
class ChainTransformer(ITransformFunction):
    """Run several transformers one after another."""

    def __init__(self, transformers: List[ITransformFunction]):
        """Store the pipeline.

        Args:
            transformers: Transformers in the order they are applied.
        """
        self.transformers = transformers

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Feed ``data`` through every transformer in order.

        Args:
            data: Input data.

        Returns:
            The output of the last transformer, or ``data`` itself when the
            chain is empty.
        """
        current = data
        for step in self.transformers:
            current = step.transform(current)
        return current

    def get_transform_description(self) -> str:
        """Report how many transformers the chain holds."""
        count = len(self.transformers)
        plural = '' if count == 1 else 's'
        return f"Chain {count} transformer{plural} in sequence"
593# Convenience functions for creating transformers
def map_fields(mapping: Dict[str, str], **kwargs) -> FieldMapper:
    """Build a FieldMapper.

    Args:
        mapping: Source field names mapped to target field names.
        **kwargs: Passed through unchanged to the FieldMapper constructor.

    Returns:
        The configured FieldMapper.
    """
    mapper = FieldMapper(mapping, **kwargs)
    return mapper
def normalize(**normalizations: str) -> ValueNormalizer:
    """Build a ValueNormalizer.

    Args:
        **normalizations: Field name mapped to the normalization to apply.

    Returns:
        A ValueNormalizer built from the keyword arguments.
    """
    normalizer = ValueNormalizer(normalizations)
    return normalizer
def convert_types(**conversions: Union[str, type, Callable]) -> TypeConverter:
    """Build a TypeConverter.

    Args:
        **conversions: Field name mapped to its target type.

    Returns:
        A TypeConverter built from the keyword arguments.
    """
    converter = TypeConverter(conversions)
    return converter
def enrich(**enrichments: Any) -> DataEnricher:
    """Build a DataEnricher.

    Args:
        **enrichments: Field name mapped to the value (or callable) to add.

    Returns:
        A DataEnricher built from the keyword arguments.
    """
    enricher = DataEnricher(enrichments)
    return enricher
def filter_fields(include: List[str] | None = None, exclude: List[str] | None = None) -> FieldFilter:
    """Build a FieldFilter.

    Args:
        include: Fields to keep.
        exclude: Fields to remove.

    Returns:
        The configured FieldFilter.
    """
    field_filter = FieldFilter(include=include, exclude=exclude)
    return field_filter
def replace_values(**replacements: Dict[Any, Any]) -> ValueReplacer:
    """Build a ValueReplacer.

    Args:
        **replacements: Field name mapped to its old-value/new-value mapping.

    Returns:
        A ValueReplacer built from the keyword arguments.
    """
    replacer = ValueReplacer(replacements)
    return replacer
def flatten(*fields: str, depth: int = 1) -> ArrayFlattener:
    """Build an ArrayFlattener.

    Args:
        *fields: Names of the fields whose lists are flattened.
        depth: Number of nesting levels to remove.

    Returns:
        The configured ArrayFlattener.
    """
    field_names = [*fields]
    return ArrayFlattener(field_names, depth)
def split_on(field: str, output: str = "records") -> DataSplitter:
    """Build a DataSplitter.

    Args:
        field: Field whose list is split into records.
        output: Key under which the records are returned.

    Returns:
        The configured DataSplitter.
    """
    splitter = DataSplitter(split_field=field, output_field=output)
    return splitter
def chain(*transformers: ITransformFunction) -> ChainTransformer:
    """Build a ChainTransformer.

    Args:
        *transformers: Transformers in the order they should run.

    Returns:
        A ChainTransformer holding the given transformers.
    """
    pipeline = [*transformers]
    return ChainTransformer(pipeline)