Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/client/core.py: 64%
94 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 17:27 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 17:27 +0200
1"""
2Base classes for drivers and driver clients
3"""
5import logging
6from contextlib import asynccontextmanager
7from dataclasses import dataclass, field
8from typing import Any
10from anyio import create_task_group
11from google.protobuf import empty_pb2
12from grpc import StatusCode
13from grpc.aio import AioRpcError
14from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
16from jumpstarter.common import Metadata
17from jumpstarter.common.exceptions import JumpstarterException
18from jumpstarter.common.resources import ResourceMetadata
19from jumpstarter.common.serde import decode_value, encode_value
20from jumpstarter.common.streams import (
21 DriverStreamRequest,
22 ResourceStreamRequest,
23 StreamRequestMetadata,
24)
25from jumpstarter.streams.common import forward_stream
26from jumpstarter.streams.encoding import compress_stream
27from jumpstarter.streams.metadata import MetadataStream, MetadataStreamAttributes
28from jumpstarter.streams.progress import ProgressStream
29from jumpstarter.streams.router import RouterStream
class DriverError(JumpstarterException):
    """Error reported back from a driver call.

    Common base for the more specific driver errors defined in this module.
    """


class DriverMethodNotImplemented(DriverError, NotImplementedError):
    """The requested driver method is not implemented.

    Inherits from both DriverError and NotImplementedError, so callers may
    catch it as either.
    """


class DriverInvalidArgument(DriverError, ValueError):
    """A driver method was invoked with arguments it rejected.

    Inherits from both DriverError and ValueError, so callers may catch it
    as either.
    """


50@dataclass(kw_only=True)
51class AsyncDriverClient(
52 Metadata,
53 jumpstarter_pb2_grpc.ExporterServiceStub,
54 router_pb2_grpc.RouterServiceStub,
55):
56 """
57 Async driver client base class
59 Backing implementation of blocking driver client.
60 """
62 stub: Any
64 log_level: str = "INFO"
65 logger: logging.Logger = field(init=False)
67 def __post_init__(self):
68 if hasattr(super(), "__post_init__"):
69 super().__post_init__()
70 self.logger = logging.getLogger(self.__class__.__name__)
71 self.logger.setLevel(self.log_level)
73 # add default handler
74 if not self.logger.handlers:
75 handler = logging.StreamHandler()
76 handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s"))
77 self.logger.addHandler(handler)
79 async def call_async(self, method, *args):
80 """Make DriverCall by method name and arguments"""
82 request = jumpstarter_pb2.DriverCallRequest(
83 uuid=str(self.uuid),
84 method=method,
85 args=[encode_value(arg) for arg in args],
86 )
88 try:
89 response = await self.stub.DriverCall(request)
90 except AioRpcError as e:
91 match e.code():
92 case StatusCode.UNIMPLEMENTED:
93 raise DriverMethodNotImplemented(e.details()) from None
94 case StatusCode.INVALID_ARGUMENT:
95 raise DriverInvalidArgument(e.details()) from None
96 case StatusCode.UNKNOWN:
97 raise DriverError(e.details()) from None
98 case _:
99 raise DriverError(e.details()) from e
101 return decode_value(response.result)
103 async def streamingcall_async(self, method, *args):
104 """Make StreamingDriverCall by method name and arguments"""
106 request = jumpstarter_pb2.StreamingDriverCallRequest(
107 uuid=str(self.uuid),
108 method=method,
109 args=[encode_value(arg) for arg in args],
110 )
112 try:
113 async for response in self.stub.StreamingDriverCall(request):
114 yield decode_value(response.result)
115 except AioRpcError as e:
116 match e.code():
117 case StatusCode.UNIMPLEMENTED:
118 raise DriverMethodNotImplemented(e.details()) from None
119 case StatusCode.INVALID_ARGUMENT:
120 raise DriverInvalidArgument(e.details()) from None
121 case StatusCode.UNKNOWN:
122 raise DriverError(e.details()) from None
123 case _:
124 raise DriverError(e.details()) from e
126 @asynccontextmanager
127 async def stream_async(self, method):
128 context = self.stub.Stream(
129 metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method))
130 .model_dump(mode="json", round_trip=True)
131 .items(),
132 )
133 metadata = dict(list(await context.initial_metadata()))
134 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as stream:
135 yield stream
137 @asynccontextmanager
138 async def resource_async(
139 self,
140 stream,
141 content_encoding: str | None = None,
142 ):
143 context = self.stub.Stream(
144 metadata=StreamRequestMetadata.model_construct(
145 request=ResourceStreamRequest(uuid=self.uuid, x_jmp_content_encoding=content_encoding)
146 )
147 .model_dump(mode="json", round_trip=True)
148 .items(),
149 )
150 metadata = dict(list(await context.initial_metadata()))
151 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as rstream:
152 metadata = ResourceMetadata(**rstream.extra(MetadataStreamAttributes.metadata))
153 if metadata.x_jmp_accept_encoding is None:
154 stream = compress_stream(stream, content_encoding)
156 async with forward_stream(ProgressStream(stream=stream), rstream):
157 yield metadata.resource.model_dump(mode="json")
159 def __log(self, level: int, msg: str):
160 self.logger.log(level, msg)
162 @asynccontextmanager
163 async def log_stream_async(self):
164 async def log_stream():
165 async for response in self.stub.LogStream(empty_pb2.Empty()):
166 self.__log(logging.getLevelName(response.severity), response.message)
168 async with create_task_group() as tg:
169 tg.start_soon(log_stream)
170 try:
171 yield
172 finally:
173 tg.cancel_scope.cancel()