Coverage for src/meshadmin/cli/utils.py: 23%

155 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 15:09 +0200

1import asyncio 

2import json 

3import os 

4import signal 

5from datetime import datetime, timedelta 

6from importlib.metadata import version 

7from pathlib import Path 

8 

9import httpx 

10import structlog 

11import typer 

12import yaml 

13from jwcrypto.jwk import JWK 

14from jwcrypto.jwt import JWT 

15from jwt import decode 

16 

17from meshadmin.cli.config import get_config 

18from meshadmin.common.utils import create_expiration_date, get_nebula_path 

19 

# Module-level structlog logger, named after this module; used by every
# function below for both sync (info) and async (ainfo/aerror) logging.
logger = structlog.get_logger(__name__)

21 

22 

def get_access_token():
    """Return an access token for authenticating against the meshadmin API.

    Sources are tried in this order:

    1. ``MESHADMIN_TEST_MODE == "true"``: the fixed string ``"test-token"``
       is returned without reading configuration or touching the network.
    2. ``OIDC_CLIENT_ID`` and ``OIDC_CLIENT_SECRET`` both set: a token is
       requested from ``config.keycloak_token_url`` with the
       ``client_credentials`` grant.
    3. ``config.authentication_path`` exists: the stored access token is
       reused while its ``exp`` claim lies at least ten seconds in the
       future; otherwise the stored refresh token is exchanged for a new
       token and the file is overwritten with the raw token response.

    Returns:
        The access token string, or ``None`` (after printing
        ``"authentication failed"``) when none of the sources is available.

    Raises:
        httpx.HTTPStatusError: If the token endpoint responds with an error
            status.
    """
    if os.getenv("MESHADMIN_TEST_MODE") == "true":
        return "test-token"

    config = get_config()

    oidc_id = os.environ.get("OIDC_CLIENT_ID")
    oidc_secret = os.environ.get("OIDC_CLIENT_SECRET")
    if oidc_id and oidc_secret:
        logger.info("using client credentials flow")
        token_url = config.keycloak_token_url
        credentials_form = {
            "grant_type": "client_credentials",
            "client_id": oidc_id,
            "client_secret": oidc_secret,
        }
        response = httpx.post(token_url, data=credentials_form)
        response.raise_for_status()
        return response.json()["access_token"]

    # Guard clause: without a stored authentication file there is nothing
    # left to try.
    if not config.authentication_path.exists():
        print("authentication failed")
        return None

    logger.info("using device flow")
    stored = json.loads(config.authentication_path.read_text())
    stored_token = stored["access_token"]

    # The token is only inspected for its expiry here, not validated, hence
    # signature and expiry verification are both switched off.
    claims = decode(
        stored_token, options={"verify_signature": False, "verify_exp": False}
    )

    # Keep using the stored token while it stays valid for >= 10 seconds.
    expiry_cutoff = (datetime.now() + timedelta(seconds=10)).timestamp()
    if claims["exp"] >= expiry_cutoff:
        return stored_token

    renewal_token = stored["refresh_token"]
    token_url = config.keycloak_token_url
    refresh_form = {
        "grant_type": "refresh_token",
        "refresh_token": renewal_token,
        "client_id": config.keycloak_admin_client,
    }
    response = httpx.post(token_url, data=refresh_form)
    response.raise_for_status()
    # Persist the complete token response so the next call sees the renewed
    # access/refresh pair.
    config.authentication_path.write_bytes(response.content)
    return response.json()["access_token"]

72 

73 

def get_context_config():
    """Resolve the currently selected context and return its settings.

    The context name comes from the ``MESH_CONTEXT`` environment variable
    when set; otherwise the first entry in the contexts file whose ``active``
    value is truthy is used.

    Returns:
        A dict with the keys ``name``, ``endpoint``, ``interface`` and
        ``network_dir`` (``config.networks_dir / name``).

    Raises:
        typer.Exit: With code 1 when the contexts file does not exist, or
            when no context is selected or the selected name is unknown.
    """
    config = get_config()
    if not config.contexts_file.exists():
        print("No contexts found")
        raise typer.Exit(1)

    with open(config.contexts_file) as handle:
        loaded = yaml.safe_load(handle)
    # An empty YAML document loads as None; treat it as "no contexts".
    contexts = loaded or {}

    selected = os.getenv("MESH_CONTEXT")
    if not selected:
        flagged = []
        for name, data in contexts.items():
            if data.get("active"):
                flagged.append(name)
        selected = flagged[0] if flagged else None

    if not (selected and selected in contexts):
        print("No active context. Please select a context with 'meshadmin context use'")
        raise typer.Exit(1)

    entry = contexts[selected]
    network_dir = config.networks_dir / selected

    resolved = {"name": selected}
    resolved["endpoint"] = entry["endpoint"]
    resolved["interface"] = entry["interface"]
    resolved["network_dir"] = network_dir
    return resolved

103 

104 

async def get_config_from_mesh(mesh_admin_endpoint, private_auth_key):
    """Fetch this host's configuration from the meshadmin server.

    A short-lived RS256 JWT signed with ``private_auth_key`` (and carrying
    the key's thumbprint as ``kid``) is sent as a bearer token to
    ``{mesh_admin_endpoint}/api/v1/config`` together with the installed
    meshadmin version.

    Args:
        mesh_admin_endpoint: Base URL of the meshadmin server.
        private_auth_key: Key object used to sign the token; it must provide
            ``thumbprint()``.

    Returns:
        A tuple ``(config, update_interval, upgrade_required)``: the response
        body as text, the integer value of the ``X-Update-Interval`` header
        (``5`` when absent), and whether ``X-Upgrade-Requested`` equals
        ``"true"``.

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.
    """
    signing_header = {"alg": "RS256", "kid": private_auth_key.thumbprint()}
    token_claims = {
        "exp": create_expiration_date(10),
        "kid": private_auth_key.thumbprint(),
    }
    signed = JWT(header=signing_header, claims=token_claims)
    signed.make_signed_token(private_auth_key)
    bearer = signed.serialize()

    config_url = f"{mesh_admin_endpoint}/api/v1/config"
    async with httpx.AsyncClient() as client:
        request_headers = {
            "Authorization": f"Bearer {bearer}",
            "X-Meshadmin-Version": version("meshadmin"),
        }
        response = await client.get(config_url, headers=request_headers)
        response.raise_for_status()

        body = response.text
        interval = int(response.headers.get("X-Update-Interval", "5"))
        needs_upgrade = response.headers.get("X-Upgrade-Requested") == "true"
        return body, interval, needs_upgrade

129 

130 

async def cleanup_ephemeral_hosts(mesh_admin_endpoint, private_auth_key):
    """Ask the meshadmin server to clean up ephemeral hosts.

    A short-lived RS256 JWT signed with ``private_auth_key`` is sent as a
    bearer token in a POST to
    ``{mesh_admin_endpoint}/api/v1/cleanup-ephemeral``.

    Args:
        mesh_admin_endpoint: Base URL of the meshadmin server.
        private_auth_key: Key object used to sign the token; it must provide
            ``thumbprint()``.

    Returns:
        The decoded JSON body of the server response.

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.
    """
    signing_header = {"alg": "RS256", "kid": private_auth_key.thumbprint()}
    token_claims = {
        "exp": create_expiration_date(10),
        "kid": private_auth_key.thumbprint(),
    }
    signed = JWT(header=signing_header, claims=token_claims)
    signed.make_signed_token(private_auth_key)
    bearer = signed.serialize()

    cleanup_url = f"{mesh_admin_endpoint}/api/v1/cleanup-ephemeral"
    async with httpx.AsyncClient() as client:
        response = await client.post(
            cleanup_url,
            headers={"Authorization": f"Bearer {bearer}"},
        )
        response.raise_for_status()
        return response.json()

149 

150 

async def _report_auth_failure(message, exc):
    """Log and print an HTTP 401 failure reported by the meshadmin server.

    ``message`` is used as the structured log event and, prefixed with
    ``"Error: "``, printed to stdout followed by the server's response body.
    """
    await logger.aerror(
        message,
        error=str(exc),
        response_text=exc.response.text,
    )
    print(f"Error: {message}")
    print(f"Server message: {exc.response.text}")


async def start_nebula(network_dir: Path, mesh_admin_endpoint: str):
    """Run nebula and keep its configuration in sync with the server.

    Starts the nebula binary with the config file inside ``network_dir`` and
    then loops: sleep for the current update interval, fetch the config from
    the server, trigger a self-upgrade when the server requests one, adopt a
    changed update interval, rewrite the config file and send ``SIGHUP`` when
    the config changed, and finally ask the server to clean up ephemeral
    hosts. A nebula process that has exited is restarted at the end of each
    cycle.

    The loop ends when the server answers either request with HTTP 401; the
    nebula process is then terminated (and killed if it does not exit within
    five seconds).

    Args:
        network_dir: Directory holding the nebula config; also used as the
            working directory of the nebula process.
        mesh_admin_endpoint: Base URL of the meshadmin server.

    Raises:
        AssertionError: If the config file or the private key file does not
            exist when the function is called.
    """
    config = get_config()
    await logger.ainfo("starting nebula")

    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # the exception type and messages are unchanged.
    conf_path = network_dir / config.config_path
    if not conf_path.exists():
        raise AssertionError(f"Config at {conf_path} does not exist")

    private_auth_key_path = config.contexts_file.parent / config.private_key
    if not private_auth_key_path.exists():
        raise AssertionError(
            f"private_key at {private_auth_key_path} does not exist"
        )

    async def start_process():
        return await asyncio.create_subprocess_exec(
            get_nebula_path(),
            "-config",
            str(conf_path),
            cwd=network_dir,
        )

    proc = await start_process()

    # Default update interval in seconds
    update_interval = 5

    while True:
        await asyncio.sleep(update_interval)
        try:
            # The key file is re-read every cycle; only its path is invariant.
            private_auth_key = JWK.from_json(private_auth_key_path.read_text())

            # Check for config updates
            try:
                (
                    new_config,
                    new_update_interval,
                    upgrade_required,
                ) = await get_config_from_mesh(mesh_admin_endpoint, private_auth_key)
                if upgrade_required:
                    await logger.ainfo("Starting meshadmin self-upgrade")
                    await perform_self_upgrade()

                if update_interval != new_update_interval:
                    await logger.ainfo(
                        "update interval changed",
                        old_interval=update_interval,
                        new_interval=new_update_interval,
                    )
                    update_interval = new_update_interval

                old_config = conf_path.read_text()
                if new_config != old_config:
                    await logger.ainfo("config changed, reloading")
                    conf_path.write_text(new_config)
                    conf_path.chmod(0o600)

                    try:
                        # SIGHUP asks the running nebula to reload its config.
                        proc.send_signal(signal.SIGHUP)
                    except ProcessLookupError:
                        await logger.ainfo("process died, restarting")
                        proc = await start_process()
                else:
                    await logger.ainfo("config not changed")
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 401:
                    await _report_auth_failure(
                        "Could not get config because of authentication error. Host may have been deleted.",
                        e,
                    )
                    break
                await logger.aerror("error getting config", error=str(e))

            # Cleanup ephemeral hosts
            try:
                result = await cleanup_ephemeral_hosts(
                    mesh_admin_endpoint, private_auth_key
                )
                if result.get("removed_count", 0) > 0:
                    await logger.ainfo(
                        "removed stale ephemeral hosts",
                        count=result["removed_count"],
                    )
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 401:
                    await _report_auth_failure(
                        "Could not clean up ephemeral hosts because of authentication error. Host may have been deleted.",
                        e,
                    )
                    break
                await logger.aerror("error during cleanup operation", error=str(e))

        except Exception:
            # Catch-all for the whole cycle (key loading, config refresh,
            # cleanup) so one failed cycle never stops the supervisor loop.
            # The event name is kept as-is for existing log consumers.
            await logger.aexception("could not refresh token")

        # Restart nebula whenever it has exited, whether or not this cycle
        # raised, so a crashed process is not left down until the next error.
        if proc.returncode is not None:
            await logger.ainfo("process died, restarting")
            proc = await start_process()

    # Clean shutdown if we get here
    if proc.returncode is None:
        await logger.ainfo("shutting down nebula process")
        proc.terminate()
        try:
            await asyncio.wait_for(proc.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            await logger.awarning("nebula process didn't terminate, killing it")
            proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            await proc.wait()

268 

269 

async def perform_self_upgrade():
    """Upgrade the installed meshadmin tool with ``uv tool install --upgrade``.

    The command's stdout and stderr are captured and attached to the log
    entry describing the outcome.

    Returns:
        ``True`` if the command exited with status 0; ``False`` if it exited
        with a non-zero status or if an exception occurred while running it
        (the exception is logged, not re-raised).
    """
    try:
        process = await asyncio.create_subprocess_exec(
            "uv",
            "tool",
            "install",
            "--upgrade",
            "meshadmin",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        if process.returncode == 0:
            await logger.ainfo(
                "meshadmin successfully upgraded",
                # errors="replace": undecodable output must not raise here,
                # or a successful upgrade would be reported as a failure by
                # the handler below.
                stdout=stdout.decode(errors="replace").strip() if stdout else None,
            )
            return True
        else:
            await logger.aerror(
                "meshadmin upgrade failed",
                returncode=process.returncode,
                stderr=stderr.decode(errors="replace").strip() if stderr else None,
            )
            return False
    except Exception as e:
        await logger.aerror("Error during meshadmin self-upgrade", error=str(e))
        return False