Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/config/client.py: 52%

185 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 15:50 +0200

1from __future__ import annotations 

2 

3import asyncio 

4import os 

5from contextlib import asynccontextmanager, contextmanager 

6from datetime import timedelta 

7from functools import wraps 

8from pathlib import Path 

9from typing import Annotated, ClassVar, Literal, Optional, Self 

10 

11import grpc 

12import yaml 

13from anyio.from_thread import BlockingPortal, start_blocking_portal 

14from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator 

15from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict 

16 

17from .common import CONFIG_PATH, ObjectMeta 

18from .env import JMP_LEASE 

19from .grpc import call_credentials 

20from .tls import TLSConfigV1Alpha1 

21from jumpstarter.client.grpc import ClientService 

22from jumpstarter.common.exceptions import ConfigurationError, FileNotFoundError 

23from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials 

24 

25 

26def _blocking_compat(f): 

27 @wraps(f) 

28 def wrapper(*args, **kwargs): 

29 try: 

30 asyncio.get_running_loop() 

31 except RuntimeError: 

32 return asyncio.run(f(*args, **kwargs)) 

33 else: 

34 return f(*args, **kwargs) 

35 

36 return wrapper 

37 

38 

class ClientConfigV1Alpha1Drivers(BaseSettings):
    """Driver-loading policy section of a client config.

    Fields can also be supplied through ``JMP_DRIVERS_``-prefixed environment
    variables; ``allow`` may then be given as a comma-separated string.
    """

    model_config = SettingsConfigDict(env_prefix="JMP_DRIVERS_")

    # NoDecode keeps pydantic-settings from JSON-decoding the raw environment
    # value, so decode_allow receives the plain string and splits it itself.
    allow: Annotated[list[str], NoDecode] = Field(default_factory=list)
    unsafe: bool = Field(default=False)

    @field_validator("allow", mode="before")
    @classmethod
    def decode_allow(cls, v: str | list[str]) -> list[str]:
        """Split a comma-separated string into a list of entries.

        Whitespace around each entry is stripped and empty entries are dropped,
        so ``"a, b,"`` becomes ``["a", "b"]`` and ``""`` becomes ``[]``.
        Non-string values are returned unchanged and left to pydantic's regular
        ``list[str]`` validation, so a bad value (e.g. ``None``) produces a
        ValidationError instead of an AttributeError.
        """
        if isinstance(v, str):
            return [entry.strip() for entry in v.split(",") if entry.strip()]
        return v

    @model_validator(mode="after")
    def decode_unsafe(self) -> Self:
        """Enable ``unsafe`` when the allow list contains the ``UNSAFE`` marker."""
        if "UNSAFE" in self.allow:
            self.unsafe = True

        return self

59 

60 

class ClientConfigV1Alpha1(BaseSettings):
    """Client configuration (``jumpstarter.dev/v1alpha1``, kind ``ClientConfig``).

    A config can be loaded from a YAML file (:meth:`from_file`, :meth:`load`) or
    built from ``JMP_``-prefixed environment variables (:meth:`from_env`).  The
    ``alias`` and ``path`` fields are bookkeeping only and are excluded when the
    config is serialized back to YAML.
    """

    # Directory holding one "<alias>.yaml" file per saved client config.
    CLIENT_CONFIGS_PATH: ClassVar[Path] = CONFIG_PATH / "clients"

    model_config = SettingsConfigDict(env_prefix="JMP_")

    alias: str = Field(default="default")
    path: Path | None = Field(default=None)

    apiVersion: Literal["jumpstarter.dev/v1alpha1"] = Field(default="jumpstarter.dev/v1alpha1")
    kind: Literal["ClientConfig"] = Field(default="ClientConfig")

    metadata: ObjectMeta = Field(default_factory=ObjectMeta)

    endpoint: str | None = Field(default=None)
    tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1)
    token: str | None = Field(default=None)
    grpcOptions: dict[str, str | int] | None = Field(default_factory=dict)

    drivers: ClientConfigV1Alpha1Drivers = Field(default_factory=ClientConfigV1Alpha1Drivers)

    async def channel(self):
        """Open a secure gRPC channel to ``endpoint``, authenticated with ``token``.

        Raises:
            ConfigurationError: if ``endpoint`` or ``token`` is not set.
        """
        if self.endpoint is None or self.token is None:
            raise ConfigurationError("endpoint or token not set in client config")

        credentials = grpc.composite_channel_credentials(
            await ssl_channel_credentials(self.endpoint, self.tls),
            call_credentials("Client", self.metadata, self.token),
        )

        return aio_secure_channel(self.endpoint, credentials, self.grpcOptions)

    async def _client_service(self) -> ClientService:
        """Build a ClientService on a fresh channel, scoped to this config's namespace."""
        # NOTE(review): a new channel is opened per call and is not explicitly
        # closed here; confirm whether ClientService/the channel should be closed.
        return ClientService(channel=await self.channel(), namespace=self.metadata.namespace)

    @contextmanager
    def lease(
        self,
        selector: str | None = None,
        lease_name: str | None = None,
        duration: timedelta = timedelta(minutes=30),
    ):
        """Synchronous context-manager wrapper around :meth:`lease_async`.

        A blocking portal is started for the lifetime of the ``with`` block and
        the async lease context manager is driven through it.
        """
        with start_blocking_portal() as portal:
            with portal.wrap_async_context_manager(self.lease_async(selector, lease_name, duration, portal)) as lease:
                yield lease

    @_blocking_compat
    async def get_exporter(
        self,
        name: str,
    ):
        """Fetch a single exporter by name."""
        svc = await self._client_service()
        return await svc.GetExporter(name=name)

    @_blocking_compat
    async def list_exporters(
        self,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
    ):
        """List exporters, optionally paginated and filtered."""
        svc = await self._client_service()
        return await svc.ListExporters(page_size=page_size, page_token=page_token, filter=filter)

    @_blocking_compat
    async def create_lease(
        self,
        selector: str,
        duration: timedelta,
    ):
        """Create a lease for an exporter matching ``selector``."""
        svc = await self._client_service()
        return await svc.CreateLease(
            selector=selector,
            duration=duration,
        )

    @_blocking_compat
    async def delete_lease(
        self,
        name: str,
    ):
        """Delete the lease called ``name``; returns None."""
        svc = await self._client_service()
        await svc.DeleteLease(
            name=name,
        )

    @_blocking_compat
    async def list_leases(
        self,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
    ):
        """List leases, optionally paginated and filtered."""
        svc = await self._client_service()
        return await svc.ListLeases(
            page_size=page_size,
            page_token=page_token,
            filter=filter,
        )

    @_blocking_compat
    async def update_lease(
        self,
        name: str,
        duration: timedelta,
    ):
        """Change the duration of the lease called ``name``."""
        svc = await self._client_service()
        return await svc.UpdateLease(name=name, duration=duration)

    @asynccontextmanager
    async def lease_async(
        self,
        selector: str | None,
        lease_name: str | None,
        duration: timedelta,
        portal: BlockingPortal,
    ):
        """Acquire (or attach to) a lease and yield the ``Lease`` object.

        When ``lease_name`` is empty, the environment variable named by
        ``JMP_LEASE`` is consulted.  If neither provides a name, ``release=True``
        is passed to ``Lease`` so the lease is released on exit.
        """
        from jumpstarter.client import Lease

        # if no lease_name provided, check if it is set in the environment
        lease_name = lease_name or os.environ.get(JMP_LEASE, "")
        # when no lease name is provided, release the lease on exit
        release_lease = lease_name == ""

        async with Lease(
            channel=await self.channel(),
            namespace=self.metadata.namespace,
            name=lease_name,
            selector=selector,
            duration=duration,
            portal=portal,
            allow=self.drivers.allow,
            unsafe=self.drivers.unsafe,
            release=release_lease,
            tls_config=self.tls,
            grpc_options=self.grpcOptions,
        ) as lease:
            yield lease

    @classmethod
    def from_file(cls, path: os.PathLike):
        """Load a config from the YAML file at ``path``.

        The alias is the file name without its final extension (``a.b.yaml`` ->
        ``a.b``), matching the ``<alias>.yaml`` naming used by :meth:`_get_path`.
        """
        path = Path(path)
        with path.open() as f:
            config = cls.model_validate(yaml.safe_load(f))
        config.alias = path.stem
        config.path = path
        return config

    @classmethod
    def ensure_exists(cls):
        """Check if the clients config dir exists, otherwise create it."""
        os.makedirs(cls.CLIENT_CONFIGS_PATH, exist_ok=True)

    @classmethod
    def try_from_env(cls):
        """Like :meth:`from_env`, but return None when validation fails."""
        try:
            return cls.from_env()
        except ValidationError:
            return None

    @classmethod
    def from_env(cls):
        """Build a config from ``JMP_``-prefixed environment variables and defaults."""
        return cls()

    @classmethod
    def _get_path(cls, alias: str) -> Path:
        """Get the regular path of a client config given an alias."""
        # Append rather than use with_suffix(): with_suffix would replace any dotted
        # part of the alias ("prod.old" -> "prod.yaml") and point at the wrong file.
        return cls.CLIENT_CONFIGS_PATH / f"{alias}.yaml"

    @classmethod
    def load(cls, alias: str) -> Self:
        """Load a client config by alias.

        Raises:
            FileNotFoundError: (jumpstarter's) if no config exists for ``alias``.
        """
        path = cls._get_path(alias)
        if not path.exists():
            raise FileNotFoundError(f"Client config '{path}' does not exist.")
        return cls.from_file(path)

    @classmethod
    def save(cls, config: Self, path: Optional[os.PathLike] = None) -> Path:
        """Save ``config`` as YAML and return the path written.

        Without ``path`` the file goes to the clients dir under the config's
        alias (the dir is created if needed).  ``config.path`` is updated.
        """
        if path is None:
            # Ensure the clients dir exists
            cls.ensure_exists()
            # Set the config path before saving
            config.path = cls._get_path(config.alias)
        else:
            config.path = Path(path)
        with config.path.open(mode="w") as f:
            f.write(cls.dump_yaml(config))
        return config.path

    @classmethod
    def dump_yaml(cls, config: Self) -> str:
        """Serialize ``config`` to YAML, omitting the ``path`` and ``alias`` bookkeeping fields."""
        return yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), sort_keys=False)

    @classmethod
    def exists(cls, alias: str) -> bool:
        """Check if a client config exists by alias."""
        return cls._get_path(alias).exists()

    @classmethod
    def list(cls) -> ClientConfigListV1Alpha1:
        """List the saved client configs.

        Every regular ``*.yaml`` file in the clients dir is loaded, in sorted
        file-name order.  ``current_config`` is the alias of the user config's
        ``current_client`` when a user config exists.  An empty list is returned
        when the clients dir does not exist.
        """
        from .user import UserConfigV1Alpha1

        configs_dir = cls.CLIENT_CONFIGS_PATH
        if not configs_dir.exists():
            # Return an empty list if the dir does not exist
            return ClientConfigListV1Alpha1(
                current_config=None,
                items=[],
            )

        # Only regular YAML files are configs; sort because os.listdir order is arbitrary.
        config_files = sorted(
            configs_dir / name
            for name in os.listdir(configs_dir)
            if name.endswith(".yaml") and (configs_dir / name).is_file()
        )

        current_config = None
        if UserConfigV1Alpha1.exists():
            current_client = UserConfigV1Alpha1.load().config.current_client
            current_config = current_client.alias if current_client is not None else None

        return ClientConfigListV1Alpha1(
            current_config=current_config,
            items=[cls.from_file(config_file) for config_file in config_files],
        )

    @classmethod
    def delete(cls, alias: str) -> Path:
        """Delete a client config by alias and return the removed path.

        Raises:
            FileNotFoundError: (jumpstarter's) if no config exists for ``alias``.
        """
        path = cls._get_path(alias)
        if not path.exists():
            raise FileNotFoundError(f"Client config '{path}' does not exist.")
        path.unlink()
        return path

295 

class ClientConfigListV1Alpha1(BaseModel):
    """Collection of client configs plus the alias of the currently selected one.

    Serializes with camelCase aliases (``apiVersion``, ``currentConfig``) and can
    render itself into rich-style tables via the ``rich_*`` helpers.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

    api_version: Literal["jumpstarter.dev/v1alpha1"] = Field(alias="apiVersion", default="jumpstarter.dev/v1alpha1")
    current_config: Optional[str] = Field(alias="currentConfig")
    items: list[ClientConfigV1Alpha1]
    kind: Literal["ClientConfigList"] = Field(default="ClientConfigList")

    def dump_json(self):
        """Return the list as 4-space-indented JSON using field aliases."""
        return self.model_dump_json(indent=4, by_alias=True)

    def dump_yaml(self):
        """Return the list as YAML (2-space indent) using field aliases."""
        payload = self.model_dump(mode="json", by_alias=True)
        return yaml.safe_dump(payload, indent=2)

    @classmethod
    def rich_add_columns(cls, table):
        """Add the CURRENT / ALIAS / ENDPOINT / PATH header columns to ``table``."""
        for heading in ("CURRENT", "ALIAS", "ENDPOINT", "PATH"):
            table.add_column(heading)

    def rich_add_rows(self, table):
        """Add one row per client; the current client is marked with ``*``."""
        for client in self.items:
            is_current = client.alias == self.current_config
            marker = "*" if is_current else ""
            table.add_row(marker, client.alias, client.endpoint, str(client.path))

    def rich_add_names(self, names):
        """Append every client alias to ``names``."""
        for alias in (client.alias for client in self.items):
            names.append(alias)