Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/driver/base.py: 69%

135 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 17:27 +0200

1""" 

2Base classes for drivers and driver clients 

3""" 

4 

5from __future__ import annotations 

6 

7import logging 

8import os 

9from abc import ABCMeta, abstractmethod 

10from contextlib import asynccontextmanager 

11from dataclasses import field 

12from inspect import isasyncgenfunction, iscoroutinefunction 

13from itertools import chain 

14from typing import Any 

15from uuid import UUID, uuid4 

16 

17import aiohttp 

18from anyio import to_thread 

19from grpc import StatusCode 

20from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc 

21from pydantic import TypeAdapter 

22from pydantic.dataclasses import dataclass 

23 

24from .decorators import ( 

25 MARKER_DRIVERCALL, 

26 MARKER_MAGIC, 

27 MARKER_STREAMCALL, 

28 MARKER_STREAMING_DRIVERCALL, 

29) 

30from jumpstarter.common import Metadata 

31from jumpstarter.common.resources import ClientStreamResource, PresignedRequestResource, Resource, ResourceMetadata 

32from jumpstarter.common.serde import decode_value, encode_value 

33from jumpstarter.common.streams import ( 

34 DriverStreamRequest, 

35 ResourceStreamRequest, 

36) 

37from jumpstarter.config.env import JMP_DISABLE_COMPRESSION 

38from jumpstarter.streams.aiohttp import AiohttpStreamReaderStream 

39from jumpstarter.streams.common import create_memory_stream 

40from jumpstarter.streams.encoding import Compression, compress_stream 

41from jumpstarter.streams.metadata import MetadataStream 

42from jumpstarter.streams.progress import ProgressStream 

43 

# Content encodings this side accepts for client-supplied resource streams
# (consulted when answering a ResourceStreamRequest in ``Driver.Stream``).
# Setting the environment variable named by JMP_DISABLE_COMPRESSION to "1"
# disables all of them. Both branches produce a ``set`` so the constant has a
# single, predictable type (the previous ``{}`` was an empty *dict*).
SUPPORTED_CONTENT_ENCODINGS = (
    set()
    if os.environ.get(JMP_DISABLE_COMPRESSION) == "1"
    else {
        Compression.GZIP,
        Compression.XZ,
        Compression.BZ2,
    }
)

53 

54 

55@dataclass(kw_only=True) 

56class Driver( 

57 Metadata, 

58 jumpstarter_pb2_grpc.ExporterServiceServicer, 

59 router_pb2_grpc.RouterServiceServicer, 

60 metaclass=ABCMeta, 

61): 

62 """Base class for drivers 

63 

64 Drivers should at the minimum implement the `client` method. 

65 

66 Regular or streaming driver calls can be marked with the `export` decorator. 

67 Raw stream constructors can be marked with the `exportstream` decorator. 

68 """ 

69 

70 children: dict[str, Driver] = field(default_factory=dict) 

71 

72 resources: dict[UUID, Any] = field(default_factory=dict, init=False) 

73 """Dict of client side resources""" 

74 

75 log_level: str = "INFO" 

76 logger: logging.Logger = field(init=False) 

77 

78 def __post_init__(self): 

79 if hasattr(super(), "__post_init__"): 

80 super().__post_init__() 

81 

82 self.logger = logging.getLogger(self.__class__.__name__) 

83 self.logger.setLevel(self.log_level) 

84 

85 def close(self): 

86 for child in self.children.values(): 

87 child.close() 

88 

89 def reset(self): 

90 for child in self.children.values(): 

91 child.reset() 

92 

93 @classmethod 

94 @abstractmethod 

95 def client(cls) -> str: 

96 """ 

97 Return full import path of the corresponding driver client class 

98 """ 

99 

100 def extra_labels(self) -> dict[str, str]: 

101 return {} 

102 

103 async def DriverCall(self, request, context): 

104 """ 

105 :meta private: 

106 """ 

107 try: 

108 method = await self.__lookup_drivercall(request.method, context, MARKER_DRIVERCALL) 

109 

110 args = [decode_value(arg) for arg in request.args] 

111 

112 if iscoroutinefunction(method): 

113 result = await method(*args) 

114 else: 

115 result = await to_thread.run_sync(method, *args) 

116 

117 return jumpstarter_pb2.DriverCallResponse( 

118 uuid=str(uuid4()), 

119 result=encode_value(result), 

120 ) 

121 except NotImplementedError as e: 

122 await context.abort(StatusCode.UNIMPLEMENTED, str(e)) 

123 except ValueError as e: 

124 await context.abort(StatusCode.INVALID_ARGUMENT, str(e)) 

125 except TimeoutError as e: 

126 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e)) 

127 except Exception as e: 

128 await context.abort(StatusCode.UNKNOWN, str(e)) 

129 

130 async def StreamingDriverCall(self, request, context): 

131 """ 

132 :meta private: 

133 """ 

134 try: 

135 method = await self.__lookup_drivercall(request.method, context, MARKER_STREAMING_DRIVERCALL) 

136 

137 args = [decode_value(arg) for arg in request.args] 

138 

139 if isasyncgenfunction(method): 

140 async for result in method(*args): 

141 yield jumpstarter_pb2.StreamingDriverCallResponse( 

142 uuid=str(uuid4()), 

143 result=encode_value(result), 

144 ) 

145 else: 

146 for result in await to_thread.run_sync(method, *args): 

147 yield jumpstarter_pb2.StreamingDriverCallResponse( 

148 uuid=str(uuid4()), 

149 result=encode_value(result), 

150 ) 

151 except NotImplementedError as e: 

152 await context.abort(StatusCode.UNIMPLEMENTED, str(e)) 

153 except ValueError as e: 

154 await context.abort(StatusCode.INVALID_ARGUMENT, str(e)) 

155 except TimeoutError as e: 

156 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e)) 

157 except Exception as e: 

158 await context.abort(StatusCode.UNKNOWN, str(e)) 

159 

160 @asynccontextmanager 

161 async def Stream(self, request, context): 

162 """ 

163 :meta private: 

164 """ 

165 match request: 

166 case DriverStreamRequest(method=driver_method): 

167 method = await self.__lookup_drivercall(driver_method, context, MARKER_STREAMCALL) 

168 

169 async with method() as stream: 

170 yield stream 

171 

172 case ResourceStreamRequest(): 

173 remote, resource = create_memory_stream() 

174 

175 resource_uuid = uuid4() 

176 

177 self.resources[resource_uuid] = resource 

178 

179 async with MetadataStream( 

180 stream=remote, 

181 metadata=ResourceMetadata.model_construct( 

182 resource=ClientStreamResource( 

183 uuid=resource_uuid, x_jmp_content_encoding=request.x_jmp_content_encoding 

184 ), 

185 x_jmp_accept_encoding=request.x_jmp_content_encoding 

186 if request.x_jmp_content_encoding in SUPPORTED_CONTENT_ENCODINGS 

187 else None, 

188 ).model_dump(mode="json", round_trip=True), 

189 ) as stream: 

190 yield stream 

191 

192 def report(self, *, root=None, parent=None, name=None): 

193 """ 

194 Create DriverInstanceReport 

195 

196 :meta private: 

197 """ 

198 

199 if root is None: 

200 root = self 

201 

202 return jumpstarter_pb2.DriverInstanceReport( 

203 uuid=str(self.uuid), 

204 parent_uuid=str(parent.uuid) if parent else None, 

205 labels=self.labels 

206 | self.extra_labels() 

207 | ({"jumpstarter.dev/client": self.client()}) 

208 | ({"jumpstarter.dev/name": name} if name else {}), 

209 ) 

210 

211 def enumerate(self, *, root=None, parent=None, name=None): 

212 """ 

213 Get list of self and child devices 

214 

215 :meta private: 

216 """ 

217 if root is None: 

218 root = self 

219 

220 return [(self.uuid, parent, name, self)] + list( 

221 chain(*[child.enumerate(root=root, parent=self, name=cname) for (cname, child) in self.children.items()]) 

222 ) 

223 

224 @asynccontextmanager 

225 async def resource(self, handle: str, timeout: int = 300): 

226 handle = TypeAdapter(Resource).validate_python(handle) 

227 match handle: 

228 case ClientStreamResource(uuid=uuid, x_jmp_content_encoding=content_encoding): 

229 async with self.resources[uuid] as stream: 

230 try: 

231 yield compress_stream(stream, content_encoding) 

232 finally: 

233 del self.resources[uuid] 

234 case PresignedRequestResource(headers=headers, url=url, method=method): 

235 client_timeout = aiohttp.ClientTimeout(total=timeout) 

236 match method: 

237 case "GET": 

238 async with aiohttp.request( 

239 method, url, headers=headers, raise_for_status=True, timeout=client_timeout 

240 ) as resp: 

241 async with AiohttpStreamReaderStream(reader=resp.content) as stream: 

242 yield ProgressStream(stream=stream, logging=True) 

243 case "PUT": 

244 remote, stream = create_memory_stream() 

245 async with aiohttp.request( 

246 method, url, headers=headers, raise_for_status=True, data=remote, timeout=client_timeout 

247 ) as resp: 

248 async with stream: 

249 yield ProgressStream(stream=stream, logging=True) 

250 case _: 

251 # INVARIANT: method is always one of GET or PUT, see PresignedRequestResource 

252 raise ValueError("unreachable") 

253 

254 async def __lookup_drivercall(self, name, context, marker): 

255 """Lookup drivercall by method name 

256 

257 Methods are checked against magic markers 

258 to avoid accidentally calling non-exported 

259 methods 

260 """ 

261 method = getattr(self, name, None) 

262 

263 if method is None: 

264 await context.abort(StatusCode.NOT_FOUND, f"method {name} not found on driver") 

265 

266 if getattr(method, marker, None) != MARKER_MAGIC: 

267 await context.abort(StatusCode.NOT_FOUND, f"method {name} missing marker {marker}") 

268 

269 return method