Coverage for src/dataknobs_fsm/functions/library/streaming.py: 0% (277 statements)

1"""Built-in streaming functions for FSM. 

2 

3This module provides streaming-related functions that can be referenced 

4in FSM configurations for processing large data sets efficiently. 

5""" 

6 

7import csv 

8import json 

9from pathlib import Path 

10from typing import Any, Dict, List, Union 

11 

12from dataknobs_fsm.functions.base import ITransformFunction, TransformError 

13from dataknobs_fsm.streaming.core import IStreamSource 

14 

15 

class ChunkReader(ITransformFunction):
    """Read data in chunks from a source."""

    def __init__(
        self,
        source: Union[str, IStreamSource],
        chunk_size: int = 1000,
        format: str = "auto",  # "auto", "json", "csv", "lines"
    ):
        """Initialize the chunk reader.

        Args:
            source: Data source (file path or stream source).
            chunk_size: Number of records per chunk.
            format: Data format to expect.
        """
        self.source = source
        self.chunk_size = chunk_size
        self.format = format

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by reading the next chunk from the source.

        Args:
            data: Input data (may contain chunk state).

        Returns:
            Data with the next chunk of records.
        """
        # Get or initialize chunk state
        chunk_state = data.get("_chunk_state", {})

        if isinstance(self.source, str):
            # File source
            file_path = Path(self.source)
            if not file_path.exists():
                raise TransformError(f"File not found: {self.source}")

            # Determine format
            format = self.format
            if format == "auto":
                format = self._detect_format(file_path)

            # Read chunk based on format
            if format == "json":
                chunk = await self._read_json_chunk(file_path, chunk_state)
            elif format == "csv":
                chunk = await self._read_csv_chunk(file_path, chunk_state)
            elif format == "lines":
                chunk = await self._read_lines_chunk(file_path, chunk_state)
            else:
                raise TransformError(f"Unsupported format: {format}")

        else:
            # Stream source
            chunk = await self._read_stream_chunk(self.source, chunk_state)

        return {
            **data,
            "chunk": chunk["records"],
            "has_more": chunk["has_more"],
            "_chunk_state": chunk["state"],
        }

    def _detect_format(self, file_path: Path) -> str:
        """Detect file format from extension."""
        suffix = file_path.suffix.lower()
        if suffix == ".json":
            return "json"
        elif suffix == ".csv":
            return "csv"
        else:
            return "lines"

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        source_str = str(self.source) if isinstance(self.source, str) else "stream"
        return f"Read {self.chunk_size} records from {source_str} in {self.format} format"

    async def _read_json_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk from a JSON file."""
        offset = state.get("offset", 0)

        # For JSON, we need to load the entire file (or use a streaming JSON parser)
        with open(file_path) as f:
            data = json.load(f)

        if isinstance(data, list):
            chunk = data[offset:offset + self.chunk_size]
            has_more = offset + self.chunk_size < len(data)
            new_offset = offset + len(chunk)
        else:
            # Single object
            if offset == 0:
                chunk = [data]
                has_more = False
                new_offset = 1
            else:
                chunk = []
                has_more = False
                new_offset = offset

        return {
            "records": chunk,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_csv_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk from a CSV file."""
        offset = state.get("offset", 0)
        records = []

        with open(file_path) as f:
            reader = csv.DictReader(f)

            # Skip to offset
            for _ in range(offset):
                try:
                    next(reader)
                except StopIteration:
                    break

            # Read chunk
            for _ in range(self.chunk_size):
                try:
                    records.append(next(reader))
                except StopIteration:
                    break

        # Heuristic: a full chunk suggests more data may remain (a source whose
        # length is an exact multiple of chunk_size yields one final empty read)
        has_more = len(records) == self.chunk_size
        new_offset = offset + len(records)

        return {
            "records": records,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_lines_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk of lines from a file."""
        offset = state.get("offset", 0)
        records = []

        with open(file_path) as f:
            # Skip to offset
            for _ in range(offset):
                if not f.readline():
                    break

            # Read chunk
            for _ in range(self.chunk_size):
                line = f.readline()
                if not line:
                    break
                records.append({"line": line.strip()})

        has_more = len(records) == self.chunk_size
        new_offset = offset + len(records)

        return {
            "records": records,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_stream_chunk(
        self, source: IStreamSource, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk from a stream source."""
        records = []

        async for record in source.read(self.chunk_size):
            records.append(record)

        has_more = len(records) == self.chunk_size

        return {
            "records": records,
            "has_more": has_more,
            "state": {"stream_position": source.position if hasattr(source, "position") else None},
        }

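# Usage sketch (illustrative, not part of the original module): drive a
# ChunkReader to exhaustion by passing its output back in as input, which
# carries the "_chunk_state" offset forward between calls.
async def _example_read_all_csv_rows(path: str) -> List[Dict[str, Any]]:
    reader = ChunkReader(path, chunk_size=500, format="csv")
    data: Dict[str, Any] = {}
    records: List[Dict[str, Any]] = []
    while True:
        data = await reader.transform(data)
        records.extend(data["chunk"])
        if not data["has_more"]:
            return records
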

class RecordParser(ITransformFunction):
    """Parse records from various formats."""

    def __init__(
        self,
        format: str,
        field: str = "raw",
        output_field: str = "parsed",
        options: Dict[str, Any] | None = None,
    ):
        """Initialize the record parser.

        Args:
            format: Format to parse ("json", "csv", "xml", "yaml").
            field: Field containing raw data to parse.
            output_field: Field to store parsed data.
            options: Format-specific parsing options.
        """
        self.format = format
        self.field = field
        self.output_field = output_field
        self.options = options or {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by parsing records.

        Args:
            data: Input data containing raw records.

        Returns:
            Data with parsed records.
        """
        raw_data = data.get(self.field)
        if raw_data is None:
            return data

        try:
            if self.format == "json":
                parsed = self._parse_json(raw_data)
            elif self.format == "csv":
                parsed = self._parse_csv(raw_data)
            elif self.format == "yaml":
                parsed = self._parse_yaml(raw_data)
            elif self.format == "xml":
                parsed = self._parse_xml(raw_data)
            else:
                raise TransformError(f"Unsupported format: {self.format}")

            return {
                **data,
                self.output_field: parsed,
            }

        except Exception as e:
            raise TransformError(f"Failed to parse {self.format}: {e}") from e

    def _parse_json(self, raw: Union[str, bytes]) -> Any:
        """Parse JSON data."""
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return json.loads(raw)

    def _parse_csv(self, raw: Union[str, bytes]) -> List[Dict[str, Any]]:
        """Parse CSV data."""
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")

        import io
        reader = csv.DictReader(io.StringIO(raw), **self.options)
        return list(reader)

    def _parse_yaml(self, raw: Union[str, bytes]) -> Any:
        """Parse YAML data."""
        import yaml
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return yaml.safe_load(raw)

    def _parse_xml(self, raw: Union[str, bytes]) -> Dict[str, Any]:
        """Parse XML data."""
        import xml.etree.ElementTree as ET
        if isinstance(raw, str):
            raw = raw.encode("utf-8")

        root = ET.fromstring(raw)
        return self._xml_to_dict(root)

    def _xml_to_dict(self, element) -> Dict[str, Any]:
        """Convert an XML element to a dictionary."""
        result = {}

        # Add attributes
        if element.attrib:
            result["@attributes"] = element.attrib

        # Add text content
        if element.text and element.text.strip():
            result["text"] = element.text.strip()

        # Add children
        for child in element:
            child_data = self._xml_to_dict(child)
            if child.tag in result:
                # Convert to list if multiple children with same tag
                if not isinstance(result[child.tag], list):
                    result[child.tag] = [result[child.tag]]
                result[child.tag].append(child_data)
            else:
                result[child.tag] = child_data

        return result

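    # Illustrative mapping (example added for clarity, not in the original):
    # _xml_to_dict on <item id="1">hi</item> yields
    # {"@attributes": {"id": "1"}, "text": "hi"}, and repeated child tags
    # are collected into a list under the shared tag name.
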

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        return f"Parse {self.format} data from '{self.field}' to '{self.output_field}'"

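# Usage sketch (illustrative, not part of the original module): parse a raw
# JSON payload from the "raw" field into "parsed" via direct invocation.
def _example_parse_json_record() -> Dict[str, Any]:
    parser = RecordParser("json")
    return parser.transform({"raw": '{"id": 1, "name": "widget"}'})
    # -> {"raw": ..., "parsed": {"id": 1, "name": "widget"}}
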

class FileAppender(ITransformFunction):
    """Append data to a file."""

    def __init__(
        self,
        file_path: str,
        format: str = "json",  # "json", "csv", "lines"
        field: str = "data",
        buffer_size: int = 100,
        create_if_missing: bool = True,
    ):
        """Initialize the file appender.

        Args:
            file_path: Path to the file to append to.
            format: Format to write data in.
            field: Field containing data to append.
            buffer_size: Number of records to buffer before writing.
            create_if_missing: Create the file if it doesn't exist.
        """
        self.file_path = Path(file_path)
        self.format = format
        self.field = field
        self.buffer_size = buffer_size
        self.create_if_missing = create_if_missing
        self._buffer: List[Any] = []

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by appending it to the file.

        Args:
            data: Input data containing records to append.

        Returns:
            Data with append status.
        """
        records = data.get(self.field)
        if records is None:
            return data

        # Add to buffer
        if isinstance(records, list):
            self._buffer.extend(records)
        else:
            self._buffer.append(records)

        # Write if buffer is full
        written = 0
        if len(self._buffer) >= self.buffer_size:
            written = await self._write_buffer()

        return {
            **data,
            "appended_count": written,
            "buffer_size": len(self._buffer),
        }

    async def _write_buffer(self) -> int:
        """Write the buffer to the file."""
        if not self._buffer:
            return 0

        # Create file if needed
        if self.create_if_missing and not self.file_path.exists():
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            self.file_path.touch()

        count = len(self._buffer)

        if self.format == "json":
            # Append to a JSON array. Note: this re-reads and re-serializes the
            # whole array on every flush, so cost grows with file size.
            existing = []
            if self.file_path.exists() and self.file_path.stat().st_size > 0:
                with open(self.file_path) as f:
                    existing = json.load(f)

            existing.extend(self._buffer)

            with open(self.file_path, "w") as f:
                json.dump(existing, f, indent=2)

        elif self.format == "csv":
            # Append to CSV (csv is already imported at module level)
            file_exists = self.file_path.exists() and self.file_path.stat().st_size > 0

            with open(self.file_path, "a", newline="") as f:
                if self._buffer and isinstance(self._buffer[0], dict):
                    writer = csv.DictWriter(f, fieldnames=self._buffer[0].keys())
                    if not file_exists:
                        writer.writeheader()
                    writer.writerows(self._buffer)
                else:
                    writer = csv.writer(f)
                    writer.writerows(self._buffer)

        elif self.format == "lines":
            # Append lines
            with open(self.file_path, "a") as f:
                for record in self._buffer:
                    if isinstance(record, dict):
                        f.write(json.dumps(record) + "\n")
                    else:
                        f.write(str(record) + "\n")

        else:
            raise TransformError(f"Unsupported format: {self.format}")

        self._buffer.clear()
        return count

    async def flush(self) -> int:
        """Flush any remaining buffered data."""
        return await self._write_buffer()

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        return f"Append {self.format} data from '{self.field}' to {self.file_path}"

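# Usage sketch (illustrative, not part of the original module): buffer two
# records, trigger a write, then flush the remainder explicitly.
async def _example_append_lines(path: str) -> None:
    appender = FileAppender(path, format="lines", field="data", buffer_size=2)
    await appender.transform({"data": ["first", "second"]})  # buffer full, written
    await appender.transform({"data": "third"})              # still buffered
    await appender.flush()                                   # force the final write
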

class StreamAggregator(ITransformFunction):
    """Aggregate streaming data using various functions."""

    def __init__(
        self,
        aggregations: Dict[str, Dict[str, Any]],
        group_by: List[str] | None = None,
        window_size: int | None = None,
    ):
        """Initialize the stream aggregator.

        Args:
            aggregations: Dictionary of aggregation specifications.
                Keys are output field names, values are:
                {"function": "sum|avg|min|max|count", "field": "source_field"}
            group_by: Fields to group by before aggregating.
            window_size: Number of records in the sliding window.
        """
        self.aggregations = aggregations
        self.group_by = group_by
        self.window_size = window_size
        self._window: List[Dict[str, Any]] = []
        self._groups: Dict[tuple, List[Dict[str, Any]]] = {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by aggregating the stream.

        Args:
            data: Input data (single record or batch).

        Returns:
            Data with aggregation results.
        """
        # Add to window/groups; a bare record is treated as a batch of one
        records = data.get("records", [data])

        if self.group_by:
            # Group-based aggregation
            for record in records:
                key = tuple(record.get(field) for field in self.group_by)
                if key not in self._groups:
                    self._groups[key] = []
                self._groups[key].append(record)

                # Apply window size per group
                if self.window_size and len(self._groups[key]) > self.window_size:
                    self._groups[key] = self._groups[key][-self.window_size:]

            # Compute aggregations per group
            results = []
            for key, group_records in self._groups.items():
                result = dict(zip(self.group_by, key, strict=False))
                for output_field, agg_spec in self.aggregations.items():
                    result[output_field] = self._compute_aggregation(group_records, agg_spec)
                results.append(result)

            return {**data, "aggregations": results}

        else:
            # Global aggregation
            self._window.extend(records)

            # Apply window size
            if self.window_size and len(self._window) > self.window_size:
                self._window = self._window[-self.window_size:]

            # Compute aggregations
            result = {}
            for output_field, agg_spec in self.aggregations.items():
                result[output_field] = self._compute_aggregation(self._window, agg_spec)

            return {**data, "aggregation": result}

    def _compute_aggregation(
        self, records: List[Dict[str, Any]], spec: Dict[str, Any]
    ) -> Any:
        """Compute a single aggregation."""
        func = spec["function"]
        field = spec.get("field")

        if func == "count":
            return len(records)

        if not field:
            raise TransformError(f"Field required for {func} aggregation")

        values: List[Any] = [r.get(field) for r in records if r.get(field) is not None]

        if not values:
            return None

        if func == "sum":
            return sum(values)  # type: ignore
        elif func == "avg":
            return sum(values) / len(values)  # type: ignore
        elif func == "min":
            return min(values)  # type: ignore
        elif func == "max":
            return max(values)  # type: ignore
        else:
            raise TransformError(f"Unknown aggregation function: {func}")

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        agg_list = list(self.aggregations.keys())[:3]
        agg_str = ", ".join(agg_list)
        if len(self.aggregations) > 3:
            agg_str += "..."
        group_str = f" grouped by {', '.join(self.group_by)}" if self.group_by else ""
        return f"Aggregate {agg_str}{group_str}"

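# Usage sketch (illustrative, not part of the original module): grouped sum
# and count over a single batch of records.
def _example_aggregate_scores() -> Dict[str, Any]:
    agg = StreamAggregator(
        aggregations={
            "total": {"function": "sum", "field": "score"},
            "n": {"function": "count"},
        },
        group_by=["team"],
    )
    batch = {"records": [{"team": "a", "score": 3}, {"team": "a", "score": 5}]}
    return agg.transform(batch)
    # -> {..., "aggregations": [{"team": "a", "total": 8, "n": 2}]}
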

# Convenience functions for creating streaming functions
def read_chunks(source: str, size: int = 1000, **kwargs) -> ChunkReader:
    """Create a ChunkReader."""
    return ChunkReader(source, size, **kwargs)


def parse(format: str, **kwargs) -> RecordParser:
    """Create a RecordParser."""
    return RecordParser(format, **kwargs)


def append_to_file(path: str, **kwargs) -> FileAppender:
    """Create a FileAppender."""
    return FileAppender(path, **kwargs)


def aggregate(**aggregations: Dict[str, Any]) -> StreamAggregator:
    """Create a StreamAggregator."""
    return StreamAggregator(aggregations)
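

# Usage sketch (illustrative, not part of the original module): chain the
# convenience helpers to stream a file's lines in chunks and append each
# record (written as a JSON object per line) to another file.
async def _example_copy_lines(src: str, dst: str) -> None:
    reader = read_chunks(src, size=200, format="lines")
    writer = append_to_file(dst, format="lines", field="chunk", buffer_size=200)
    data: Dict[str, Any] = {}
    while True:
        data = await reader.transform(data)
        data = await writer.transform(data)
        if not data["has_more"]:
            break
    await writer.flush()  # write any records still buffered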