Coverage for smart_pipeline / models.py: 100%
26 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:35 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:35 +0200
1import inspect
2from dataclasses import dataclass, is_dataclass
3from typing import Callable, Optional, List, get_type_hints
@dataclass(eq=False)
class Step:
    """
    Represents a single node in the computation graph.

    ``eq=False`` keeps the identity-based ``__eq__``/``__hash__`` inherited
    from ``object``, so two steps wrapping the same function remain distinct
    nodes and can be used as dict keys or set members.

    Attributes:
        fn: The callable this step executes. Decorator layers that expose
            ``__wrapped__`` (e.g. built with ``functools.wraps``) are peeled
            off with ``inspect.unwrap`` when inspecting its signature and
            return annotation.
        manual_outputs: Explicit output variable names. When non-empty they
            take precedence over names inferred from ``fn``.
    """
    fn: Callable
    manual_outputs: Optional[List[str]] = None

    @property
    def name(self) -> str:
        """The ``__name__`` of ``fn`` as seen on the outermost callable."""
        return self.fn.__name__

    def get_signature(self) -> inspect.Signature:
        """
        Return the signature of the underlying function.

        Decorator layers (like ``@cached``) that set ``__wrapped__`` are
        peeled off with ``inspect.unwrap`` so the real inputs are reported.
        """
        original_fn = inspect.unwrap(self.fn)
        return inspect.signature(original_fn)

    def resolve_output_names(self) -> List[str]:
        """
        Determine the variable names this step produces.

        Resolution order:
          1. ``manual_outputs`` if it is set and non-empty (an empty list
             falls through to inference). A copy is returned so callers
             cannot mutate this step's configuration through the result.
          2. If the unwrapped function's return annotation is a dataclass
             type, the names of its instance fields in declaration order.
             ``ClassVar`` and ``InitVar`` pseudo-fields are excluded because
             they are not attributes of the returned object.
          3. Otherwise ``[self.name]``.

        Returns:
            A new list of output names.
        """
        from dataclasses import fields as dataclass_fields

        if self.manual_outputs:
            return list(self.manual_outputs)

        # get_type_hints resolves string annotations (common with
        # 'from __future__ import annotations' or forward references).
        try:
            original_fn = inspect.unwrap(self.fn)
            hints = get_type_hints(original_fn)
            ann = hints.get('return')
        except Exception:
            # Deliberate best-effort fallback when hints cannot be evaluated
            # (e.g. a forward reference to a name only visible in a closure).
            # The raw annotation may then be an unresolved string, in which
            # case the dataclass check below simply does not match.
            sig = self.get_signature()
            ann = sig.return_annotation

        # Use dataclasses.fields() rather than __dataclass_fields__: the
        # latter also holds ClassVar/InitVar entries that are not real fields.
        if isinstance(ann, type) and is_dataclass(ann):
            return [f.name for f in dataclass_fields(ann)]

        # Default: use the function name.
        return [self.name]