Coverage for session_buddy / di / container.py: 53.42%

57 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1from __future__ import annotations 

2 

3import inspect 

4import typing as t 

5 

6from oneiric.core.resolution import Candidate, Resolver 

7 

# Module-level type variable.
# NOTE(review): ``class Inject[T]`` in this module uses PEP 695 syntax, which
# introduces its own ``T`` scoped to that class. This TypeVar is therefore not
# the one Inject is parameterised over, and nothing visible in this module
# references it. It is presumably kept for other modules that import it —
# verify before removing.
T = t.TypeVar("T")

9 

10 

11class Inject[T]: 

12 """Typing helper to mirror DI-injected parameters.""" 

13 

14 

class ServiceContainer:
    """Registry of named services backed by an oneiric ``Resolver``.

    Keys may be strings or objects exposing ``__module__`` and
    ``__qualname__`` (for example classes). ``_key_name`` describes how a key
    is normalised to the string name used internally. Every registration is
    filed under the ``"service"`` domain of the resolver.

    Resolved services are cached per name. Once a name is cached, lookups
    return the cached object without consulting the resolver again, until
    ``reset()`` clears the cache.
    """

    def __init__(self) -> None:
        self._resolver = Resolver()
        # Services already created or supplied, keyed by normalised name.
        self._instances: dict[str, t.Any] = {}

    def set(self, key: object, instance: t.Any) -> None:
        """Register a ready-made *instance* under *key*.

        The instance goes straight into the local cache. It is also
        registered with the resolver under provider ``"instance"``, so
        resolver-based lookups can see it too.
        """
        name = self._key_name(key)
        self._instances[name] = instance
        self._resolver.register(
            Candidate(
                domain="service",
                key=name,
                provider="instance",
                # ``instance`` is this call's parameter, so the lambda captures
                # this call's value (no late-binding surprise).
                factory=lambda: instance,
            )
        )

    def register_factory(
        self,
        key: object,
        factory: t.Callable[[], t.Any],
        *,
        provider: str | None = None,
    ) -> None:
        """Register a zero-argument *factory* that builds the service for *key*.

        The factory is not called here. The lookup methods call it when the
        resolver selects it and nothing is cached for the key yet. A factory
        that returns an awaitable can only be resolved through ``get_async``.

        NOTE(review): an instance already cached for *key* is not invalidated.
        Re-registering after the first lookup therefore has no visible effect
        until ``reset()``. Which candidate wins among several registrations
        depends on oneiric's resolver precedence — confirm before relying on
        overrides.
        """
        name = self._key_name(key)
        self._resolver.register(
            Candidate(domain="service", key=name, provider=provider, factory=factory)
        )

    def get_sync(self, key: object) -> t.Any:
        """Return the service for *key*, creating and caching it if needed.

        Raises:
            KeyError: nothing is registered for *key*.
            RuntimeError: the selected factory returned an awaitable. Use
                ``get_async`` for such services.
        """
        name = self._key_name(key)
        if name in self._instances:
            return self._instances[name]
        instance = self._resolve_candidate(name).factory()
        if inspect.isawaitable(instance):
            # Close the un-awaited coroutine. Otherwise it is garbage-collected
            # with a "coroutine ... was never awaited" RuntimeWarning.
            if inspect.iscoroutine(instance):
                instance.close()
            msg = f"Async factory registered for sync get: {name}"
            raise RuntimeError(msg)
        self._instances[name] = instance
        return instance

    def get(self, key: object) -> t.Any:
        """Alias for :meth:`get_sync`."""
        return self.get_sync(key)

    async def get_async(self, key: object) -> t.Any:
        """Return the service for *key*, awaiting the factory result if needed.

        Works with both sync and async factories. The (awaited) result is
        cached under the key's normalised name.

        Raises:
            KeyError: nothing is registered for *key*.
        """
        name = self._key_name(key)
        if name in self._instances:
            return self._instances[name]
        instance = self._resolve_candidate(name).factory()
        if inspect.isawaitable(instance):
            instance = await instance
        self._instances[name] = instance
        return instance

    def reset(self) -> None:
        """Drop every cached instance and every registration."""
        self._instances.clear()
        self._resolver = Resolver()

    def _resolve_candidate(self, name: str) -> Candidate:
        """Return the resolver's candidate for service *name*.

        Raises:
            KeyError: the resolver returned no candidate for *name*.
        """
        candidate = self._resolver.resolve("service", name)
        if not candidate:
            msg = f"Service not registered: {name}"
            raise KeyError(msg)
        return candidate

    def _key_name(self, key: object) -> str:
        """Normalise *key* to the string name used for storage and lookup.

        The rules are applied in order:

        - Strings are used verbatim.
        - Objects exposing both ``__module__`` and ``__qualname__`` (such as
          classes) become ``"module.qualname"``.
        - Anything else falls back to ``str(key)``.
        """
        if isinstance(key, str):
            return key
        if hasattr(key, "__module__") and hasattr(key, "__qualname__"):
            return f"{key.__module__}.{key.__qualname__}"
        return str(key)

86 

87 

# Shared, module-level container, instantiated once when this module is
# imported. Other modules presumably import ``depends`` to register and
# resolve services through one common registry — confirm usage before
# replacing it.
depends = ServiceContainer()

# Names exported by ``from <this module> import *``. The module-level ``T``
# TypeVar is deliberately not included.
__all__ = ["Inject", "ServiceContainer", "depends"]