Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/client/core.py: 64%

94 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 17:27 +0200

1""" 

2Base classes for drivers and driver clients 

3""" 

4 

5import logging 

6from contextlib import asynccontextmanager 

7from dataclasses import dataclass, field 

8from typing import Any 

9 

10from anyio import create_task_group 

11from google.protobuf import empty_pb2 

12from grpc import StatusCode 

13from grpc.aio import AioRpcError 

14from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc 

15 

16from jumpstarter.common import Metadata 

17from jumpstarter.common.exceptions import JumpstarterException 

18from jumpstarter.common.resources import ResourceMetadata 

19from jumpstarter.common.serde import decode_value, encode_value 

20from jumpstarter.common.streams import ( 

21 DriverStreamRequest, 

22 ResourceStreamRequest, 

23 StreamRequestMetadata, 

24) 

25from jumpstarter.streams.common import forward_stream 

26from jumpstarter.streams.encoding import compress_stream 

27from jumpstarter.streams.metadata import MetadataStream, MetadataStreamAttributes 

28from jumpstarter.streams.progress import ProgressStream 

29from jumpstarter.streams.router import RouterStream 

30 

31 

class DriverError(JumpstarterException):
    """Raised when a driver call returns an error."""

36 

37 

class DriverMethodNotImplemented(DriverError, NotImplementedError):
    """Raised when the requested driver method is not implemented.

    Also derives from the builtin ``NotImplementedError``, so callers may catch
    it either as a ``DriverError`` or as ``NotImplementedError``.
    """

42 

43 

class DriverInvalidArgument(DriverError, ValueError):
    """Raised when a driver method is invoked with invalid arguments.

    Also derives from the builtin ``ValueError``, so callers may catch it either
    as a ``DriverError`` or as ``ValueError``.
    """

48 

49 

50@dataclass(kw_only=True) 

51class AsyncDriverClient( 

52 Metadata, 

53 jumpstarter_pb2_grpc.ExporterServiceStub, 

54 router_pb2_grpc.RouterServiceStub, 

55): 

56 """ 

57 Async driver client base class 

58 

59 Backing implementation of blocking driver client. 

60 """ 

61 

62 stub: Any 

63 

64 log_level: str = "INFO" 

65 logger: logging.Logger = field(init=False) 

66 

67 def __post_init__(self): 

68 if hasattr(super(), "__post_init__"): 

69 super().__post_init__() 

70 self.logger = logging.getLogger(self.__class__.__name__) 

71 self.logger.setLevel(self.log_level) 

72 

73 # add default handler 

74 if not self.logger.handlers: 

75 handler = logging.StreamHandler() 

76 handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s")) 

77 self.logger.addHandler(handler) 

78 

79 async def call_async(self, method, *args): 

80 """Make DriverCall by method name and arguments""" 

81 

82 request = jumpstarter_pb2.DriverCallRequest( 

83 uuid=str(self.uuid), 

84 method=method, 

85 args=[encode_value(arg) for arg in args], 

86 ) 

87 

88 try: 

89 response = await self.stub.DriverCall(request) 

90 except AioRpcError as e: 

91 match e.code(): 

92 case StatusCode.UNIMPLEMENTED: 

93 raise DriverMethodNotImplemented(e.details()) from None 

94 case StatusCode.INVALID_ARGUMENT: 

95 raise DriverInvalidArgument(e.details()) from None 

96 case StatusCode.UNKNOWN: 

97 raise DriverError(e.details()) from None 

98 case _: 

99 raise DriverError(e.details()) from e 

100 

101 return decode_value(response.result) 

102 

103 async def streamingcall_async(self, method, *args): 

104 """Make StreamingDriverCall by method name and arguments""" 

105 

106 request = jumpstarter_pb2.StreamingDriverCallRequest( 

107 uuid=str(self.uuid), 

108 method=method, 

109 args=[encode_value(arg) for arg in args], 

110 ) 

111 

112 try: 

113 async for response in self.stub.StreamingDriverCall(request): 

114 yield decode_value(response.result) 

115 except AioRpcError as e: 

116 match e.code(): 

117 case StatusCode.UNIMPLEMENTED: 

118 raise DriverMethodNotImplemented(e.details()) from None 

119 case StatusCode.INVALID_ARGUMENT: 

120 raise DriverInvalidArgument(e.details()) from None 

121 case StatusCode.UNKNOWN: 

122 raise DriverError(e.details()) from None 

123 case _: 

124 raise DriverError(e.details()) from e 

125 

126 @asynccontextmanager 

127 async def stream_async(self, method): 

128 context = self.stub.Stream( 

129 metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method)) 

130 .model_dump(mode="json", round_trip=True) 

131 .items(), 

132 ) 

133 metadata = dict(list(await context.initial_metadata())) 

134 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as stream: 

135 yield stream 

136 

137 @asynccontextmanager 

138 async def resource_async( 

139 self, 

140 stream, 

141 content_encoding: str | None = None, 

142 ): 

143 context = self.stub.Stream( 

144 metadata=StreamRequestMetadata.model_construct( 

145 request=ResourceStreamRequest(uuid=self.uuid, x_jmp_content_encoding=content_encoding) 

146 ) 

147 .model_dump(mode="json", round_trip=True) 

148 .items(), 

149 ) 

150 metadata = dict(list(await context.initial_metadata())) 

151 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as rstream: 

152 metadata = ResourceMetadata(**rstream.extra(MetadataStreamAttributes.metadata)) 

153 if metadata.x_jmp_accept_encoding is None: 

154 stream = compress_stream(stream, content_encoding) 

155 

156 async with forward_stream(ProgressStream(stream=stream), rstream): 

157 yield metadata.resource.model_dump(mode="json") 

158 

159 def __log(self, level: int, msg: str): 

160 self.logger.log(level, msg) 

161 

162 @asynccontextmanager 

163 async def log_stream_async(self): 

164 async def log_stream(): 

165 async for response in self.stub.LogStream(empty_pb2.Empty()): 

166 self.__log(logging.getLevelName(response.severity), response.message) 

167 

168 async with create_task_group() as tg: 

169 tg.start_soon(log_stream) 

170 try: 

171 yield 

172 finally: 

173 tg.cancel_scope.cancel()