Coverage for src/dataknobs_fsm/functions/library/streaming.py: 0%
277 statements
1"""Built-in streaming functions for FSM.
3This module provides streaming-related functions that can be referenced
4in FSM configurations for processing large data sets efficiently.
5"""
7import csv
8import json
9from pathlib import Path
10from typing import Any, Dict, List, Union
12from dataknobs_fsm.functions.base import ITransformFunction, TransformError
13from dataknobs_fsm.streaming.core import IStreamSource
class ChunkReader(ITransformFunction):
    """Read data in chunks from a source."""

    def __init__(
        self,
        source: Union[str, IStreamSource],
        chunk_size: int = 1000,
        format: str = "auto",  # "auto", "json", "csv", "lines"
    ):
        """Initialize the chunk reader.

        Args:
            source: Data source (file path or stream source).
            chunk_size: Number of records per chunk.
            format: Data format to expect.
        """
        self.source = source
        self.chunk_size = chunk_size
        self.format = format

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by reading next chunk from source.

        Args:
            data: Input data (may contain chunk state).

        Returns:
            Data with next chunk of records.
        """
        # Get or initialize chunk state
        chunk_state = data.get("_chunk_state", {})

        if isinstance(self.source, str):
            # File source
            file_path = Path(self.source)
            if not file_path.exists():
                raise TransformError(f"File not found: {self.source}")

            # Determine format
            format = self.format
            if format == "auto":
                format = self._detect_format(file_path)

            # Read chunk based on format
            if format == "json":
                chunk = await self._read_json_chunk(file_path, chunk_state)
            elif format == "csv":
                chunk = await self._read_csv_chunk(file_path, chunk_state)
            elif format == "lines":
                chunk = await self._read_lines_chunk(file_path, chunk_state)
            else:
                raise TransformError(f"Unsupported format: {format}")

        else:
            # Stream source
            chunk = await self._read_stream_chunk(self.source, chunk_state)

        return {
            **data,
            "chunk": chunk["records"],
            "has_more": chunk["has_more"],
            "_chunk_state": chunk["state"],
        }

    def _detect_format(self, file_path: Path) -> str:
        """Detect file format from extension."""
        suffix = file_path.suffix.lower()
        if suffix == ".json":
            return "json"
        elif suffix == ".csv":
            return "csv"
        else:
            return "lines"

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        source_str = str(self.source) if isinstance(self.source, str) else "stream"
        return f"Read {self.chunk_size} records from {source_str} in {self.format} format"

    async def _read_json_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read chunk from JSON file."""
        offset = state.get("offset", 0)

        # For JSON, we need to load the entire file (or use a streaming JSON parser)
        with open(file_path) as f:
            data = json.load(f)

        if isinstance(data, list):
            chunk = data[offset:offset + self.chunk_size]
            has_more = offset + self.chunk_size < len(data)
            new_offset = offset + len(chunk)
        else:
            # Single object
            if offset == 0:
                chunk = [data]
                has_more = False
                new_offset = 1
            else:
                chunk = []
                has_more = False
                new_offset = offset

        return {
            "records": chunk,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_csv_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read chunk from CSV file."""
        offset = state.get("offset", 0)
        records = []

        with open(file_path) as f:
            reader = csv.DictReader(f)

            # Skip to offset
            for _ in range(offset):
                try:
                    next(reader)
                except StopIteration:
                    break

            # Read chunk
            for _ in range(self.chunk_size):
                try:
                    records.append(next(reader))
                except StopIteration:
                    break

        has_more = len(records) == self.chunk_size
        new_offset = offset + len(records)

        return {
            "records": records,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_lines_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read chunk of lines from file."""
        offset = state.get("offset", 0)
        records = []

        with open(file_path) as f:
            # Skip to offset
            for _ in range(offset):
                if not f.readline():
                    break

            # Read chunk
            for _ in range(self.chunk_size):
                line = f.readline()
                if not line:
                    break
                records.append({"line": line.strip()})

        has_more = len(records) == self.chunk_size
        new_offset = offset + len(records)

        return {
            "records": records,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_stream_chunk(
        self, source: IStreamSource, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read chunk from stream source."""
        records = []

        async for record in source.read(self.chunk_size):
            records.append(record)

        has_more = len(records) == self.chunk_size

        return {
            "records": records,
            "has_more": has_more,
            "state": {"stream_position": source.position if hasattr(source, "position") else None},
        }
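# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal loop that drives ChunkReader until the source is exhausted.
# The file name "events.csv" and the per-record handling are assumptions
# chosen only for demonstration.
async def _example_chunk_reader() -> None:
    reader = ChunkReader("events.csv", chunk_size=500, format="csv")
    data: Dict[str, Any] = {}
    while True:
        data = await reader.transform(data)  # carries "_chunk_state" forward
        for record in data["chunk"]:
            _ = record  # each CSV row arrives as a dict keyed by header
        if not data["has_more"]:
            break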
class RecordParser(ITransformFunction):
    """Parse records from various formats."""

    def __init__(
        self,
        format: str,
        field: str = "raw",
        output_field: str = "parsed",
        options: Dict[str, Any] | None = None,
    ):
        """Initialize the record parser.

        Args:
            format: Format to parse ("json", "csv", "xml", "yaml").
            field: Field containing raw data to parse.
            output_field: Field to store parsed data.
            options: Format-specific parsing options.
        """
        self.format = format
        self.field = field
        self.output_field = output_field
        self.options = options or {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by parsing records.

        Args:
            data: Input data containing raw records.

        Returns:
            Data with parsed records.
        """
        raw_data = data.get(self.field)
        if raw_data is None:
            return data

        try:
            if self.format == "json":
                parsed = self._parse_json(raw_data)
            elif self.format == "csv":
                parsed = self._parse_csv(raw_data)
            elif self.format == "yaml":
                parsed = self._parse_yaml(raw_data)
            elif self.format == "xml":
                parsed = self._parse_xml(raw_data)
            else:
                raise TransformError(f"Unsupported format: {self.format}")

            return {
                **data,
                self.output_field: parsed,
            }

        except Exception as e:
            raise TransformError(f"Failed to parse {self.format}: {e}") from e

    def _parse_json(self, raw: Union[str, bytes]) -> Any:
        """Parse JSON data."""
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return json.loads(raw)

    def _parse_csv(self, raw: Union[str, bytes]) -> List[Dict[str, Any]]:
        """Parse CSV data."""
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")

        import io
        reader = csv.DictReader(io.StringIO(raw), **self.options)
        return list(reader)

    def _parse_yaml(self, raw: Union[str, bytes]) -> Any:
        """Parse YAML data."""
        import yaml
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return yaml.safe_load(raw)

    def _parse_xml(self, raw: Union[str, bytes]) -> Dict[str, Any]:
        """Parse XML data."""
        import xml.etree.ElementTree as ET
        if isinstance(raw, str):
            raw = raw.encode("utf-8")

        root = ET.fromstring(raw)
        return self._xml_to_dict(root)

    def _xml_to_dict(self, element) -> Dict[str, Any]:
        """Convert XML element to dictionary."""
        result = {}

        # Add attributes
        if element.attrib:
            result["@attributes"] = element.attrib

        # Add text content
        if element.text and element.text.strip():
            result["text"] = element.text.strip()

        # Add children
        for child in element:
            child_data = self._xml_to_dict(child)
            if child.tag in result:
                # Convert to list if multiple children with same tag
                if not isinstance(result[child.tag], list):
                    result[child.tag] = [result[child.tag]]
                result[child.tag].append(child_data)
            else:
                result[child.tag] = child_data

        return result

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        return f"Parse {self.format} data from '{self.field}' to '{self.output_field}'"
class FileAppender(ITransformFunction):
    """Append data to a file."""

    def __init__(
        self,
        file_path: str,
        format: str = "json",  # "json", "csv", "lines"
        field: str = "data",
        buffer_size: int = 100,
        create_if_missing: bool = True,
    ):
        """Initialize the file appender.

        Args:
            file_path: Path to file to append to.
            format: Format to write data in.
            field: Field containing data to append.
            buffer_size: Number of records to buffer before writing.
            create_if_missing: Create file if it doesn't exist.
        """
        self.file_path = Path(file_path)
        self.format = format
        self.field = field
        self.buffer_size = buffer_size
        self.create_if_missing = create_if_missing
        self._buffer: List[Any] = []

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by appending to file.

        Args:
            data: Input data containing records to append.

        Returns:
            Data with append status.
        """
        records = data.get(self.field)
        if records is None:
            return data

        # Add to buffer
        if isinstance(records, list):
            self._buffer.extend(records)
        else:
            self._buffer.append(records)

        # Write if buffer is full
        written = 0
        if len(self._buffer) >= self.buffer_size:
            written = await self._write_buffer()

        return {
            **data,
            "appended_count": written,
            "buffer_size": len(self._buffer),
        }

    async def _write_buffer(self) -> int:
        """Write buffer to file."""
        if not self._buffer:
            return 0

        # Create file if needed
        if self.create_if_missing and not self.file_path.exists():
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            self.file_path.touch()

        count = len(self._buffer)

        if self.format == "json":
            # Append to JSON array
            existing = []
            if self.file_path.exists() and self.file_path.stat().st_size > 0:
                with open(self.file_path) as f:
                    existing = json.load(f)

            existing.extend(self._buffer)

            with open(self.file_path, "w") as f:
                json.dump(existing, f, indent=2)

        elif self.format == "csv":
            # Append to CSV
            file_exists = self.file_path.exists() and self.file_path.stat().st_size > 0

            with open(self.file_path, "a", newline="") as f:
                if self._buffer and isinstance(self._buffer[0], dict):
                    writer = csv.DictWriter(f, fieldnames=self._buffer[0].keys())
                    if not file_exists:
                        writer.writeheader()
                    writer.writerows(self._buffer)
                else:
                    writer = csv.writer(f)
                    writer.writerows(self._buffer)

        elif self.format == "lines":
            # Append lines
            with open(self.file_path, "a") as f:
                for record in self._buffer:
                    if isinstance(record, dict):
                        f.write(json.dumps(record) + "\n")
                    else:
                        f.write(str(record) + "\n")

        else:
            raise TransformError(f"Unsupported format: {self.format}")

        self._buffer.clear()
        return count

    async def flush(self) -> int:
        """Flush any remaining buffered data."""
        return await self._write_buffer()

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        return f"Append {self.format} data from '{self.field}' to {self.file_path}"
class StreamAggregator(ITransformFunction):
    """Aggregate streaming data using various functions."""

    def __init__(
        self,
        aggregations: Dict[str, Dict[str, Any]],
        group_by: List[str] | None = None,
        window_size: int | None = None,
    ):
        """Initialize the stream aggregator.

        Args:
            aggregations: Dictionary of aggregation specifications.
                Keys are output field names, values are:
                {"function": "sum|avg|min|max|count", "field": "source_field"}
            group_by: Fields to group by before aggregating.
            window_size: Number of records in sliding window.
        """
        self.aggregations = aggregations
        self.group_by = group_by
        self.window_size = window_size
        self._window: List[Dict[str, Any]] = []
        self._groups: Dict[tuple, List[Dict[str, Any]]] = {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by aggregating stream.

        Args:
            data: Input data (single record or batch).

        Returns:
            Data with aggregation results.
        """
        # Add to window/groups
        records = data.get("records", [data])

        if self.group_by:
            # Group-based aggregation
            for record in records:
                key = tuple(record.get(field) for field in self.group_by)
                if key not in self._groups:
                    self._groups[key] = []
                self._groups[key].append(record)

                # Apply window size per group
                if self.window_size and len(self._groups[key]) > self.window_size:
                    self._groups[key] = self._groups[key][-self.window_size:]

            # Compute aggregations per group
            results = []
            for key, group_records in self._groups.items():
                result = dict(zip(self.group_by, key, strict=False))
                for output_field, agg_spec in self.aggregations.items():
                    result[output_field] = self._compute_aggregation(group_records, agg_spec)
                results.append(result)

            return {**data, "aggregations": results}

        else:
            # Global aggregation
            self._window.extend(records)

            # Apply window size
            if self.window_size and len(self._window) > self.window_size:
                self._window = self._window[-self.window_size:]

            # Compute aggregations
            result = {}
            for output_field, agg_spec in self.aggregations.items():
                result[output_field] = self._compute_aggregation(self._window, agg_spec)

            return {**data, "aggregation": result}

    def _compute_aggregation(
        self, records: List[Dict[str, Any]], spec: Dict[str, Any]
    ) -> Any:
        """Compute a single aggregation."""
        func = spec["function"]
        field = spec.get("field")

        if func == "count":
            return len(records)

        if not field:
            raise TransformError(f"Field required for {func} aggregation")

        values: List[Any] = [r.get(field) for r in records if r.get(field) is not None]

        if not values:
            return None

        if func == "sum":
            return sum(values)  # type: ignore
        elif func == "avg":
            return sum(values) / len(values)  # type: ignore
        elif func == "min":
            return min(values)  # type: ignore
        elif func == "max":
            return max(values)  # type: ignore
        else:
            raise TransformError(f"Unknown aggregation function: {func}")

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        agg_list = list(self.aggregations.keys())[:3]
        agg_str = ", ".join(agg_list)
        if len(self.aggregations) > 3:
            agg_str += "..."
        group_str = f" grouped by {', '.join(self.group_by)}" if self.group_by else ""
        return f"Aggregate {agg_str}{group_str}"
# Convenience functions for creating streaming functions
def read_chunks(source: str, size: int = 1000, **kwargs) -> ChunkReader:
    """Create a ChunkReader."""
    return ChunkReader(source, size, **kwargs)


def parse(format: str, **kwargs) -> RecordParser:
    """Create a RecordParser."""
    return RecordParser(format, **kwargs)


def append_to_file(path: str, **kwargs) -> FileAppender:
    """Create a FileAppender."""
    return FileAppender(path, **kwargs)


def aggregate(**aggregations: Dict[str, Any]) -> StreamAggregator:
    """Create a StreamAggregator."""
    return StreamAggregator(aggregations)
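# --- Usage sketch (illustrative, not part of the original module) ----------
# The paths, field names, and aggregation spec below are assumptions chosen
# only to show how the convenience constructors are called.
def _example_convenience_helpers() -> None:
    reader = read_chunks("data/records.json", size=250)
    parser = parse("csv", field="raw_csv")
    writer = append_to_file("out/summary.csv", format="csv", field="rows")
    totals = aggregate(total_amount={"function": "sum", "field": "amount"})
    for fn in (reader, parser, writer, totals):
        print(fn.get_transform_description())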