Coverage for smartmdao / models.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 20:01 +0200

import inspect
from dataclasses import dataclass, fields, is_dataclass
from typing import Callable, Optional, List, get_type_hints

4 

@dataclass(eq=False)
class Step:
    """
    A single node in the computation graph.

    Wraps a callable ``fn`` and describes which variables it produces.
    ``eq=False`` stops the dataclass machinery from generating ``__eq__``,
    so equality and hashing stay identity-based: two Steps wrapping the
    same function are distinct and can both live in a set or as dict keys.

    Attributes:
        fn: The callable this step executes (possibly decorated).
        manual_outputs: Explicit output variable names. When non-empty they
            take precedence over names inferred from ``fn``.
    """

    fn: Callable
    manual_outputs: Optional[List[str]] = None

    @property
    def name(self) -> str:
        """The step's name, taken from ``fn.__name__``."""
        return self.fn.__name__

    def get_signature(self) -> inspect.Signature:
        """
        Return the signature of the underlying (undecorated) function.

        Decorator layers that expose ``__wrapped__`` (e.g. wrappers built
        with ``functools.wraps``, like ``@cached``) are peeled off with
        ``inspect.unwrap`` so the real input parameters are reported
        instead of the wrapper's ``*args, **kwargs``.
        """
        original_fn = inspect.unwrap(self.fn)
        return inspect.signature(original_fn)

    def resolve_output_names(self) -> List[str]:
        """
        Determine the variable names this step produces.

        Resolution order:
          1. ``manual_outputs``, if it is non-empty (an empty list is treated
             the same as ``None``);
          2. the field names of ``fn``'s return annotation, when that
             annotation is a dataclass type;
          3. ``[self.name]``.

        Returns:
            A new list of output names; mutating it does not affect this Step.
        """
        if self.manual_outputs:
            # Copy so callers cannot mutate this Step's configuration.
            return list(self.manual_outputs)

        # get_type_hints resolves string annotations (as produced by
        # ``from __future__ import annotations`` or forward references)
        # into real objects, which plain inspect.signature does not do.
        try:
            original_fn = inspect.unwrap(self.fn)
            hints = get_type_hints(original_fn)
            ann = hints.get('return')
        except Exception:
            # Best-effort fallback: get_type_hints can fail, e.g. when a
            # forward reference cannot be resolved from the function's
            # globals (closures / locally defined types). The raw annotation
            # may then be an unresolved string, which is not a type and
            # falls through to the default below.
            sig = self.get_signature()
            ann = sig.return_annotation

        # If the function returns a dataclass, each real field is an output.
        # dataclasses.fields() is used instead of __dataclass_fields__ because
        # the latter also lists ClassVar and InitVar pseudo-fields, which are
        # not attributes of the returned instance.
        if isinstance(ann, type) and is_dataclass(ann):
            return [f.name for f in fields(ann)]

        # Default: a single output named after the function.
        return [self.name]