Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/driver/base.py: 69%
135 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 17:27 +0200
1"""
2Base classes for drivers and driver clients
3"""
5from __future__ import annotations
7import logging
8import os
9from abc import ABCMeta, abstractmethod
10from contextlib import asynccontextmanager
11from dataclasses import field
12from inspect import isasyncgenfunction, iscoroutinefunction
13from itertools import chain
14from typing import Any
15from uuid import UUID, uuid4
17import aiohttp
18from anyio import to_thread
19from grpc import StatusCode
20from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
21from pydantic import TypeAdapter
22from pydantic.dataclasses import dataclass
24from .decorators import (
25 MARKER_DRIVERCALL,
26 MARKER_MAGIC,
27 MARKER_STREAMCALL,
28 MARKER_STREAMING_DRIVERCALL,
29)
30from jumpstarter.common import Metadata
31from jumpstarter.common.resources import ClientStreamResource, PresignedRequestResource, Resource, ResourceMetadata
32from jumpstarter.common.serde import decode_value, encode_value
33from jumpstarter.common.streams import (
34 DriverStreamRequest,
35 ResourceStreamRequest,
36)
37from jumpstarter.config.env import JMP_DISABLE_COMPRESSION
38from jumpstarter.streams.aiohttp import AiohttpStreamReaderStream
39from jumpstarter.streams.common import create_memory_stream
40from jumpstarter.streams.encoding import Compression, compress_stream
41from jumpstarter.streams.metadata import MetadataStream
42from jumpstarter.streams.progress import ProgressStream
# Content encodings that Driver.Stream is willing to echo back as
# ``x_jmp_accept_encoding`` for client resource streams.
# If the environment variable named by JMP_DISABLE_COMPRESSION is set to "1",
# no encodings are supported. This is evaluated once, at import time.
# Both branches are sets: a bare ``{}`` would be an empty *dict*, not a set.
SUPPORTED_CONTENT_ENCODINGS = (
    set()
    if os.environ.get(JMP_DISABLE_COMPRESSION) == "1"
    else {
        Compression.GZIP,
        Compression.XZ,
        Compression.BZ2,
    }
)
55@dataclass(kw_only=True)
56class Driver(
57 Metadata,
58 jumpstarter_pb2_grpc.ExporterServiceServicer,
59 router_pb2_grpc.RouterServiceServicer,
60 metaclass=ABCMeta,
61):
62 """Base class for drivers
64 Drivers should at the minimum implement the `client` method.
66 Regular or streaming driver calls can be marked with the `export` decorator.
67 Raw stream constructors can be marked with the `exportstream` decorator.
68 """
70 children: dict[str, Driver] = field(default_factory=dict)
72 resources: dict[UUID, Any] = field(default_factory=dict, init=False)
73 """Dict of client side resources"""
75 log_level: str = "INFO"
76 logger: logging.Logger = field(init=False)
78 def __post_init__(self):
79 if hasattr(super(), "__post_init__"):
80 super().__post_init__()
82 self.logger = logging.getLogger(self.__class__.__name__)
83 self.logger.setLevel(self.log_level)
85 def close(self):
86 for child in self.children.values():
87 child.close()
89 def reset(self):
90 for child in self.children.values():
91 child.reset()
93 @classmethod
94 @abstractmethod
95 def client(cls) -> str:
96 """
97 Return full import path of the corresponding driver client class
98 """
100 def extra_labels(self) -> dict[str, str]:
101 return {}
103 async def DriverCall(self, request, context):
104 """
105 :meta private:
106 """
107 try:
108 method = await self.__lookup_drivercall(request.method, context, MARKER_DRIVERCALL)
110 args = [decode_value(arg) for arg in request.args]
112 if iscoroutinefunction(method):
113 result = await method(*args)
114 else:
115 result = await to_thread.run_sync(method, *args)
117 return jumpstarter_pb2.DriverCallResponse(
118 uuid=str(uuid4()),
119 result=encode_value(result),
120 )
121 except NotImplementedError as e:
122 await context.abort(StatusCode.UNIMPLEMENTED, str(e))
123 except ValueError as e:
124 await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
125 except TimeoutError as e:
126 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e))
127 except Exception as e:
128 await context.abort(StatusCode.UNKNOWN, str(e))
130 async def StreamingDriverCall(self, request, context):
131 """
132 :meta private:
133 """
134 try:
135 method = await self.__lookup_drivercall(request.method, context, MARKER_STREAMING_DRIVERCALL)
137 args = [decode_value(arg) for arg in request.args]
139 if isasyncgenfunction(method):
140 async for result in method(*args):
141 yield jumpstarter_pb2.StreamingDriverCallResponse(
142 uuid=str(uuid4()),
143 result=encode_value(result),
144 )
145 else:
146 for result in await to_thread.run_sync(method, *args):
147 yield jumpstarter_pb2.StreamingDriverCallResponse(
148 uuid=str(uuid4()),
149 result=encode_value(result),
150 )
151 except NotImplementedError as e:
152 await context.abort(StatusCode.UNIMPLEMENTED, str(e))
153 except ValueError as e:
154 await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
155 except TimeoutError as e:
156 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e))
157 except Exception as e:
158 await context.abort(StatusCode.UNKNOWN, str(e))
160 @asynccontextmanager
161 async def Stream(self, request, context):
162 """
163 :meta private:
164 """
165 match request:
166 case DriverStreamRequest(method=driver_method):
167 method = await self.__lookup_drivercall(driver_method, context, MARKER_STREAMCALL)
169 async with method() as stream:
170 yield stream
172 case ResourceStreamRequest():
173 remote, resource = create_memory_stream()
175 resource_uuid = uuid4()
177 self.resources[resource_uuid] = resource
179 async with MetadataStream(
180 stream=remote,
181 metadata=ResourceMetadata.model_construct(
182 resource=ClientStreamResource(
183 uuid=resource_uuid, x_jmp_content_encoding=request.x_jmp_content_encoding
184 ),
185 x_jmp_accept_encoding=request.x_jmp_content_encoding
186 if request.x_jmp_content_encoding in SUPPORTED_CONTENT_ENCODINGS
187 else None,
188 ).model_dump(mode="json", round_trip=True),
189 ) as stream:
190 yield stream
192 def report(self, *, root=None, parent=None, name=None):
193 """
194 Create DriverInstanceReport
196 :meta private:
197 """
199 if root is None:
200 root = self
202 return jumpstarter_pb2.DriverInstanceReport(
203 uuid=str(self.uuid),
204 parent_uuid=str(parent.uuid) if parent else None,
205 labels=self.labels
206 | self.extra_labels()
207 | ({"jumpstarter.dev/client": self.client()})
208 | ({"jumpstarter.dev/name": name} if name else {}),
209 )
211 def enumerate(self, *, root=None, parent=None, name=None):
212 """
213 Get list of self and child devices
215 :meta private:
216 """
217 if root is None:
218 root = self
220 return [(self.uuid, parent, name, self)] + list(
221 chain(*[child.enumerate(root=root, parent=self, name=cname) for (cname, child) in self.children.items()])
222 )
224 @asynccontextmanager
225 async def resource(self, handle: str, timeout: int = 300):
226 handle = TypeAdapter(Resource).validate_python(handle)
227 match handle:
228 case ClientStreamResource(uuid=uuid, x_jmp_content_encoding=content_encoding):
229 async with self.resources[uuid] as stream:
230 try:
231 yield compress_stream(stream, content_encoding)
232 finally:
233 del self.resources[uuid]
234 case PresignedRequestResource(headers=headers, url=url, method=method):
235 client_timeout = aiohttp.ClientTimeout(total=timeout)
236 match method:
237 case "GET":
238 async with aiohttp.request(
239 method, url, headers=headers, raise_for_status=True, timeout=client_timeout
240 ) as resp:
241 async with AiohttpStreamReaderStream(reader=resp.content) as stream:
242 yield ProgressStream(stream=stream, logging=True)
243 case "PUT":
244 remote, stream = create_memory_stream()
245 async with aiohttp.request(
246 method, url, headers=headers, raise_for_status=True, data=remote, timeout=client_timeout
247 ) as resp:
248 async with stream:
249 yield ProgressStream(stream=stream, logging=True)
250 case _:
251 # INVARIANT: method is always one of GET or PUT, see PresignedRequestResource
252 raise ValueError("unreachable")
254 async def __lookup_drivercall(self, name, context, marker):
255 """Lookup drivercall by method name
257 Methods are checked against magic markers
258 to avoid accidentally calling non-exported
259 methods
260 """
261 method = getattr(self, name, None)
263 if method is None:
264 await context.abort(StatusCode.NOT_FOUND, f"method {name} not found on driver")
266 if getattr(method, marker, None) != MARKER_MAGIC:
267 await context.abort(StatusCode.NOT_FOUND, f"method {name} missing marker {marker}")
269 return method