Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/config/client.py: 52%
185 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 15:50 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 15:50 +0200
1from __future__ import annotations
3import asyncio
4import os
5from contextlib import asynccontextmanager, contextmanager
6from datetime import timedelta
7from functools import wraps
8from pathlib import Path
9from typing import Annotated, ClassVar, Literal, Optional, Self
11import grpc
12import yaml
13from anyio.from_thread import BlockingPortal, start_blocking_portal
14from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
15from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
17from .common import CONFIG_PATH, ObjectMeta
18from .env import JMP_LEASE
19from .grpc import call_credentials
20from .tls import TLSConfigV1Alpha1
21from jumpstarter.client.grpc import ClientService
22from jumpstarter.common.exceptions import ConfigurationError, FileNotFoundError
23from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials
def _blocking_compat(f):
    """Let the coroutine function *f* be called from synchronous code too.

    The returned wrapper checks whether an event loop is running in the
    current thread:

    * no running loop -- the coroutine is driven to completion with
      ``asyncio.run`` and its result is returned;
    * running loop -- the coroutine object is returned as-is so that the
      (async) caller can ``await`` it.
    """

    def _loop_is_running() -> bool:
        # asyncio.get_running_loop() raises RuntimeError when no loop is running.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    @wraps(f)
    def wrapper(*args, **kwargs):
        if _loop_is_running():
            return f(*args, **kwargs)
        return asyncio.run(f(*args, **kwargs))

    return wrapper
class ClientConfigV1Alpha1Drivers(BaseSettings):
    """Driver allow-list settings of a client config.

    Values may come from the config document or from environment variables
    using the ``JMP_DRIVERS_`` prefix.
    """

    model_config = SettingsConfigDict(env_prefix="JMP_DRIVERS_")

    # NoDecode: the raw value is handed to ``decode_allow`` below rather than
    # being decoded by pydantic-settings first.
    allow: Annotated[list[str], NoDecode] = Field(default_factory=list)
    unsafe: bool = Field(default=False)

    @field_validator("allow", mode="before")
    @classmethod
    def decode_allow(cls, v: str | list[str]) -> list[str]:
        """Accept ``allow`` either as a list or as a comma-separated string.

        Lists are passed through unchanged. Strings are split on ``,``; each
        entry is stripped of surrounding whitespace and empty entries are
        dropped, so ``""`` yields ``[]`` instead of ``[""]`` and
        ``"a, b"`` yields ``["a", "b"]``.
        """
        if isinstance(v, list):
            return v
        stripped = (part.strip() for part in v.split(","))
        return [entry for entry in stripped if entry]

    @model_validator(mode="after")
    def decode_unsafe(self) -> Self:
        """Turn ``unsafe`` on when the literal entry ``"UNSAFE"`` is in ``allow``."""
        if "UNSAFE" in self.allow:
            self.unsafe = True

        return self
class ClientConfigV1Alpha1(BaseSettings):
    """Client configuration (``jumpstarter.dev/v1alpha1``, kind ``ClientConfig``).

    A config can be built from ``JMP_``-prefixed environment variables
    (:meth:`from_env`) or read from a YAML file (:meth:`from_file`,
    :meth:`load`). Configs addressed by alias are stored as
    ``<alias>.yaml`` inside ``CLIENT_CONFIGS_PATH``.

    The RPC helpers decorated with ``_blocking_compat`` (``get_exporter``,
    ``list_exporters``, ``create_lease``, ...) can be called from both
    synchronous and asynchronous code.
    """

    CLIENT_CONFIGS_PATH: ClassVar[Path] = CONFIG_PATH / "clients"

    model_config = SettingsConfigDict(env_prefix="JMP_")

    # Bookkeeping fields: filled in by from_file()/save() and excluded from
    # the YAML written by save()/dump_yaml().
    alias: str = Field(default="default")
    path: Path | None = Field(default=None)

    apiVersion: Literal["jumpstarter.dev/v1alpha1"] = Field(default="jumpstarter.dev/v1alpha1")
    kind: Literal["ClientConfig"] = Field(default="ClientConfig")

    metadata: ObjectMeta = Field(default_factory=ObjectMeta)

    endpoint: str | None = Field(default=None)
    tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1)
    token: str | None = Field(default=None)
    grpcOptions: dict[str, str | int] | None = Field(default_factory=dict)

    drivers: ClientConfigV1Alpha1Drivers = Field(default_factory=ClientConfigV1Alpha1Drivers)

    async def channel(self):
        """Create a gRPC channel to ``endpoint`` using TLS and token call credentials.

        Raises:
            ConfigurationError: If ``endpoint`` or ``token`` is not set.
        """
        if self.endpoint is None or self.token is None:
            raise ConfigurationError("endpoint or token not set in client config")

        credentials = grpc.composite_channel_credentials(
            await ssl_channel_credentials(self.endpoint, self.tls),
            call_credentials("Client", self.metadata, self.token),
        )

        return aio_secure_channel(self.endpoint, credentials, self.grpcOptions)

    async def _client_service(self) -> ClientService:
        """Build a ``ClientService`` on a channel from :meth:`channel` for this config's namespace."""
        # NOTE(review): the channel created here is never explicitly closed by
        # the callers below -- confirm whether ClientService owns its lifetime.
        return ClientService(channel=await self.channel(), namespace=self.metadata.namespace)

    @contextmanager
    def lease(
        self,
        selector: str | None = None,
        lease_name: str | None = None,
        duration: timedelta = timedelta(minutes=30),
    ):
        """Synchronous counterpart of :meth:`lease_async`.

        Starts a blocking portal, enters :meth:`lease_async` through it and
        yields the resulting lease; both are exited when the ``with`` block ends.
        """
        with start_blocking_portal() as portal:
            with portal.wrap_async_context_manager(self.lease_async(selector, lease_name, duration, portal)) as lease:
                yield lease

    @_blocking_compat
    async def get_exporter(
        self,
        name: str,
    ):
        """Return the result of the ``GetExporter`` call for ``name``."""
        svc = await self._client_service()
        return await svc.GetExporter(name=name)

    @_blocking_compat
    async def list_exporters(
        self,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
    ):
        """Return the result of the ``ListExporters`` call with the given paging/filter arguments."""
        svc = await self._client_service()
        return await svc.ListExporters(page_size=page_size, page_token=page_token, filter=filter)

    @_blocking_compat
    async def create_lease(
        self,
        selector: str,
        duration: timedelta,
    ):
        """Return the result of the ``CreateLease`` call for ``selector`` and ``duration``."""
        svc = await self._client_service()
        return await svc.CreateLease(
            selector=selector,
            duration=duration,
        )

    @_blocking_compat
    async def delete_lease(
        self,
        name: str,
    ):
        """Issue the ``DeleteLease`` call for ``name``; nothing is returned."""
        svc = await self._client_service()
        await svc.DeleteLease(
            name=name,
        )

    @_blocking_compat
    async def list_leases(
        self,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
    ):
        """Return the result of the ``ListLeases`` call with the given paging/filter arguments."""
        svc = await self._client_service()
        return await svc.ListLeases(
            page_size=page_size,
            page_token=page_token,
            filter=filter,
        )

    @_blocking_compat
    async def update_lease(
        self,
        name,
        duration: timedelta,
    ):
        """Return the result of the ``UpdateLease`` call setting ``duration`` on lease ``name``."""
        svc = await self._client_service()
        return await svc.UpdateLease(name=name, duration=duration)

    @asynccontextmanager
    async def lease_async(
        self,
        selector: str,
        lease_name: str | None,
        duration: timedelta,
        portal: BlockingPortal,
    ):
        """Enter a ``Lease`` configured from this client config and yield it.

        When ``lease_name`` is empty, the environment variable named by
        ``JMP_LEASE`` is consulted. ``release=True`` is passed to ``Lease``
        only when no name was found in either place.
        """
        from jumpstarter.client import Lease

        # if no lease_name provided, check if it is set in the environment
        lease_name = lease_name or os.environ.get(JMP_LEASE, "")
        # when no lease name is provided, release the lease on exit
        release_lease = lease_name == ""

        async with Lease(
            channel=await self.channel(),
            namespace=self.metadata.namespace,
            name=lease_name,
            selector=selector,
            duration=duration,
            portal=portal,
            allow=self.drivers.allow,
            unsafe=self.drivers.unsafe,
            release=release_lease,
            tls_config=self.tls,
            grpc_options=self.grpcOptions,
        ) as lease:
            yield lease

    @classmethod
    def from_file(cls, path: os.PathLike):
        """Load and validate a client config from the YAML file at ``path``.

        ``alias`` is set to the file's base name up to its first ``.`` and
        ``path`` to the given location.
        """
        # Read as UTF-8 explicitly instead of relying on the platform default encoding.
        with open(path, encoding="utf-8") as f:
            v = cls.model_validate(yaml.safe_load(f))
        v.alias = os.path.basename(path).split(".")[0]
        v.path = Path(path)
        return v

    @classmethod
    def ensure_exists(cls):
        """Check if the clients config dir exists, otherwise create it."""
        os.makedirs(cls.CLIENT_CONFIGS_PATH, exist_ok=True)

    @classmethod
    def try_from_env(cls):
        """Like :meth:`from_env`, but return ``None`` when validation fails."""
        try:
            return cls.from_env()
        except ValidationError:
            return None

    @classmethod
    def from_env(cls):
        """Build a config from defaults and ``JMP_``-prefixed environment variables."""
        return cls()

    @classmethod
    def _get_path(cls, alias: str) -> Path:
        """Get the regular path of a client config given an alias."""
        return (cls.CLIENT_CONFIGS_PATH / alias).with_suffix(".yaml")

    @classmethod
    def load(cls, alias: str) -> Self:
        """Load a client config by alias.

        Raises:
            FileNotFoundError: The project's ``FileNotFoundError`` (imported
                from ``jumpstarter.common.exceptions``) if the file is missing.
        """
        path = cls._get_path(alias)
        if not path.exists():
            raise FileNotFoundError(f"Client config '{path}' does not exist.")
        return cls.from_file(path)

    @classmethod
    def save(cls, config: Self, path: Optional[os.PathLike] = None) -> Path:
        """Saves a client config as YAML.

        With ``path=None`` the config is written to its alias location inside
        ``CLIENT_CONFIGS_PATH`` (created if needed); otherwise to ``path``.
        ``config.path`` is updated to the location written, which is returned.
        """
        if path is None:
            # Ensure the clients dir exists
            cls.ensure_exists()
            # Set the config path before saving
            config.path = cls._get_path(config.alias)
        else:
            config.path = Path(path)
        # Write as UTF-8 explicitly so the file round-trips through from_file().
        with config.path.open(mode="w", encoding="utf-8") as f:
            yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), f, sort_keys=False)
        return config.path

    @classmethod
    def dump_yaml(cls, config: Self) -> str:
        """Return ``config`` as a YAML string, without the ``path`` and ``alias`` fields."""
        return yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), sort_keys=False)

    @classmethod
    def exists(cls, alias: str) -> bool:
        """Check if a client config exists by alias."""
        return cls._get_path(alias).exists()

    @classmethod
    def list(cls) -> ClientConfigListV1Alpha1:
        """List the available client configs.

        Every ``*.yaml`` file in ``CLIENT_CONFIGS_PATH`` is loaded, in sorted
        file-name order. ``current_config`` is the alias of the user config's
        current client, or ``None`` when there is no user config or no
        current client.
        """
        from .user import UserConfigV1Alpha1

        if not cls.CLIENT_CONFIGS_PATH.exists():
            # Return an empty list if the dir does not exist
            return ClientConfigListV1Alpha1(
                current_config=None,
                items=[],
            )

        # Only accept YAML files; os.listdir() order is arbitrary, so sort to
        # make the listing deterministic.
        file_names = sorted(name for name in os.listdir(cls.CLIENT_CONFIGS_PATH) if name.endswith(".yaml"))

        current_config = None
        if UserConfigV1Alpha1.exists():
            current_client = UserConfigV1Alpha1.load().config.current_client
            current_config = current_client.alias if current_client is not None else None

        return ClientConfigListV1Alpha1(
            current_config=current_config,
            items=[cls.from_file(cls.CLIENT_CONFIGS_PATH / name) for name in file_names],
        )

    @classmethod
    def delete(cls, alias: str) -> Path:
        """Delete a client config by alias and return the path that was removed.

        Raises:
            FileNotFoundError: The project's ``FileNotFoundError`` (imported
                from ``jumpstarter.common.exceptions``) if the file is missing.
        """
        path = cls._get_path(alias)
        if not path.exists():
            raise FileNotFoundError(f"Client config '{path}' does not exist.")
        path.unlink()
        return path
class ClientConfigListV1Alpha1(BaseModel):
    """A list of client configs together with the alias of the current one.

    Serialized with the aliases ``apiVersion`` / ``currentConfig``; fields can
    also be populated by their Python names (``populate_by_name=True``).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

    api_version: Literal["jumpstarter.dev/v1alpha1"] = Field(alias="apiVersion", default="jumpstarter.dev/v1alpha1")
    current_config: Optional[str] = Field(alias="currentConfig")
    items: list[ClientConfigV1Alpha1]
    kind: Literal["ClientConfigList"] = Field(default="ClientConfigList")

    def dump_json(self):
        """Return the list as a JSON string (4-space indent, alias field names)."""
        return self.model_dump_json(indent=4, by_alias=True)

    def dump_yaml(self):
        """Return the list as a YAML string (2-space indent, alias field names)."""
        payload = self.model_dump(mode="json", by_alias=True)
        return yaml.safe_dump(payload, indent=2)

    @classmethod
    def rich_add_columns(cls, table):
        """Add the CURRENT / ALIAS / ENDPOINT / PATH columns to ``table``."""
        for heading in ("CURRENT", "ALIAS", "ENDPOINT", "PATH"):
            table.add_column(heading)

    def rich_add_rows(self, table):
        """Add one row per config to ``table``; the current config is marked with ``*``."""
        for entry in self.items:
            marker = "*" if self.current_config == entry.alias else ""
            table.add_row(marker, entry.alias, entry.endpoint, str(entry.path))

    def rich_add_names(self, names):
        """Append the alias of every config to ``names``."""
        for entry in self.items:
            names.append(entry.alias)