Coverage for src/meshadmin/cli/utils.py: 24%

147 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 11:46 +0200

1import asyncio 

2import json 

3import os 

4import signal 

5from datetime import datetime, timedelta 

6from importlib.metadata import version 

7from pathlib import Path 

8 

9import httpx 

10import structlog 

11import typer 

12import yaml 

13from jwcrypto.jwk import JWK 

14from jwcrypto.jwt import JWT 

15from jwt import decode 

16 

17from meshadmin.cli.config import get_config 

18from meshadmin.common.utils import create_expiration_date, get_nebula_path 

19 

20logger = structlog.get_logger(__name__) 

21 

22 

def get_access_token() -> "str | None":
    """Return an access token for talking to the meshadmin server.

    When the ``MESHADMIN_TEST_MODE`` environment variable is ``"true"`` the
    fixed string ``"test-token"`` is returned and nothing else is touched.

    Otherwise the stored authentication JSON (``config.authentication_path``)
    is read. If its access token is still valid for at least ten more
    seconds it is returned as-is; if not, the stored refresh token is
    exchanged at ``config.keycloak_token_url`` and the response replaces the
    stored authentication file.

    Returns:
        The access token, or ``None`` (after printing
        ``"authentication failed"``) when no authentication file exists.

    Raises:
        httpx.HTTPStatusError: if the refresh request returns an error status.
        KeyError: if the stored data or the refresh response lacks an
            expected field.
    """
    if os.getenv("MESHADMIN_TEST_MODE") == "true":
        return "test-token"

    config = get_config()
    if not config.authentication_path.exists():
        print("authentication failed")
        return None

    auth = json.loads(config.authentication_path.read_text())
    access_token = auth["access_token"]

    # Only the expiry claim is needed here, so the token is decoded without
    # signature or expiry verification.
    decoded_token = decode(
        access_token, options={"verify_signature": False, "verify_exp": False}
    )

    # Reuse the stored token only if it remains valid for 10 more seconds.
    if decoded_token["exp"] >= (datetime.now() + timedelta(seconds=10)).timestamp():
        return access_token

    res = httpx.post(
        config.keycloak_token_url,
        data={
            "grant_type": "refresh_token",
            "refresh_token": auth["refresh_token"],
            "client_id": config.keycloak_admin_client,
        },
    )
    res.raise_for_status()

    # Parse the response BEFORE overwriting the auth file: a malformed body
    # must not replace working credentials (and the refresh token) on disk.
    refreshed_token = res.json()["access_token"]

    config.authentication_path.write_bytes(res.content)
    # The file holds bearer credentials; keep it private to the owner, the
    # same way the nebula config file is handled elsewhere in this module.
    config.authentication_path.chmod(0o600)
    return refreshed_token

55 

56 

def get_context_config():
    """Resolve the currently selected context and return its settings.

    The context name comes from the ``MESH_CONTEXT`` environment variable;
    when that is unset or empty, the first entry of the contexts file whose
    ``active`` value is truthy is used instead.

    Returns:
        A dict with the keys ``name``, ``endpoint``, ``interface`` and
        ``network_dir`` (``config.networks_dir`` joined with the name).

    Raises:
        typer.Exit: with code 1 when the contexts file is missing or no
            known context is selected.
    """
    config = get_config()
    if not config.contexts_file.exists():
        print("No contexts found")
        raise typer.Exit(1)

    with open(config.contexts_file) as handle:
        contexts = yaml.safe_load(handle) or {}

    selected = os.getenv("MESH_CONTEXT")
    if not selected:
        # Fall back to the first context flagged as active, if any.
        flagged = [name for name, entry in contexts.items() if entry.get("active")]
        selected = next(iter(flagged), None)

    if not (selected and selected in contexts):
        print("No active context. Please select a context with 'meshadmin context use'")
        raise typer.Exit(1)

    entry = contexts[selected]
    network_dir = config.networks_dir / selected

    resolved = {"name": selected}
    resolved["endpoint"] = entry["endpoint"]
    resolved["interface"] = entry["interface"]
    resolved["network_dir"] = network_dir
    return resolved

86 

87 

async def get_config_from_mesh(mesh_admin_endpoint, private_auth_key):
    """Fetch this host's configuration from the mesh server.

    A short-lived RS256 JWT is signed with ``private_auth_key`` and sent as
    a bearer token to ``GET {mesh_admin_endpoint}/api/v1/config`` together
    with the installed meshadmin version.

    Returns:
        A tuple ``(config_text, update_interval, upgrade_required)`` where
        ``update_interval`` is the integer value of the ``X-Update-Interval``
        response header (5 when absent) and ``upgrade_required`` is True only
        when the ``X-Upgrade-Requested`` header equals ``"true"``.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    key_id = private_auth_key.thumbprint()
    signed = JWT(
        header={"alg": "RS256", "kid": key_id},
        claims={"exp": create_expiration_date(10), "kid": key_id},
    )
    signed.make_signed_token(private_auth_key)
    bearer = signed.serialize()

    url = f"{mesh_admin_endpoint}/api/v1/config"
    async with httpx.AsyncClient() as client:
        request_headers = {
            "Authorization": f"Bearer {bearer}",
            "X-Meshadmin-Version": version("meshadmin"),
        }
        res = await client.get(url, headers=request_headers)
        res.raise_for_status()

        interval_header = res.headers.get("X-Update-Interval", "5")
        upgrade_header = res.headers.get("X-Upgrade-Requested")
        return res.text, int(interval_header), upgrade_header == "true"

112 

113 

async def cleanup_ephemeral_hosts(mesh_admin_endpoint, private_auth_key):
    """Ask the mesh server to clean up ephemeral hosts.

    Signs a short-lived RS256 JWT with ``private_auth_key`` and POSTs to
    ``{mesh_admin_endpoint}/api/v1/cleanup-ephemeral`` using it as a bearer
    token.

    Returns:
        The decoded JSON body of the server's response.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    key_id = private_auth_key.thumbprint()
    signed = JWT(
        header={"alg": "RS256", "kid": key_id},
        claims={"exp": create_expiration_date(10), "kid": key_id},
    )
    signed.make_signed_token(private_auth_key)
    auth_header = {"Authorization": f"Bearer {signed.serialize()}"}

    url = f"{mesh_admin_endpoint}/api/v1/cleanup-ephemeral"
    async with httpx.AsyncClient() as client:
        res = await client.post(url, headers=auth_header)
        res.raise_for_status()
        return res.json()

132 

133 

async def start_nebula(network_dir: Path, mesh_admin_endpoint: str):
    """Run nebula for one network and keep its config in sync with the server.

    Starts the nebula binary with the config file found under
    ``network_dir`` and then loops forever. On every iteration it:

    1. sleeps for the current update interval (5 seconds initially, later
       whatever the server last announced),
    2. fetches the config from the mesh server, runs a self-upgrade when the
       server requests one, and on a config change rewrites the file and
       sends SIGHUP to nebula so it reloads,
    3. asks the server to clean up ephemeral hosts,
    4. restarts nebula if the process has exited.

    The loop ends only when the server answers 401 to either request; the
    nebula process is then terminated (killed after a 5 second grace
    period) and reaped.

    Args:
        network_dir: Directory holding the network's config; also used as
            the working directory of the nebula process.
        mesh_admin_endpoint: Base URL of the mesh server.

    Raises:
        AssertionError: if the config file or the private key file is
            missing at start-up.
    """
    config = get_config()
    await logger.ainfo("starting nebula")
    conf_path = network_dir / config.config_path
    assert conf_path.exists(), f"Config at {conf_path} does not exist"

    private_auth_key_path = config.contexts_file.parent / config.private_key
    assert private_auth_key_path.exists(), (
        f"private_key at {private_auth_key_path} does not exist"
    )

    async def start_process():
        return await asyncio.create_subprocess_exec(
            get_nebula_path(),
            "-config",
            str(conf_path),
            cwd=network_dir,
        )

    proc = await start_process()

    # Default update interval in seconds
    update_interval = 5

    while True:
        await asyncio.sleep(update_interval)
        try:
            # The key is re-read every round so a replaced key file is
            # picked up; the path itself does not change.
            private_auth_key = JWK.from_json(private_auth_key_path.read_text())

            # Check for config updates
            try:
                (
                    new_config,
                    new_update_interval,
                    upgrade_required,
                ) = await get_config_from_mesh(mesh_admin_endpoint, private_auth_key)
                if upgrade_required:
                    await logger.ainfo("Starting meshadmin self-upgrade")
                    await perform_self_upgrade()

                if update_interval != new_update_interval:
                    await logger.ainfo(
                        "update interval changed",
                        old_interval=update_interval,
                        new_interval=new_update_interval,
                    )
                    update_interval = new_update_interval

                old_config = conf_path.read_text()
                if new_config != old_config:
                    await logger.ainfo("config changed, reloading")
                    conf_path.write_text(new_config)
                    conf_path.chmod(0o600)

                    try:
                        # SIGHUP makes the running process reload its config.
                        proc.send_signal(signal.SIGHUP)
                    except ProcessLookupError:
                        await logger.ainfo("process died, restarting")
                        proc = await start_process()
                else:
                    await logger.ainfo("config not changed")
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 401:
                    await logger.aerror(
                        "Could not get config because of authentication error. Host may have been deleted.",
                        error=str(e),
                        response_text=e.response.text,
                    )
                    print(
                        "Error: Could not get config because of authentication error. Host may have been deleted."
                    )
                    print(f"Server message: {e.response.text}")
                    break
                else:
                    await logger.aerror("error getting config", error=str(e))

            # Cleanup ephemeral hosts
            try:
                result = await cleanup_ephemeral_hosts(
                    mesh_admin_endpoint, private_auth_key
                )
                if result.get("removed_count", 0) > 0:
                    await logger.ainfo(
                        "removed stale ephemeral hosts",
                        count=result["removed_count"],
                    )
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 401:
                    await logger.aerror(
                        "Could not clean up ephemeral hosts because of authentication error. Host may have been deleted.",
                        error=str(e),
                        response_text=e.response.text,
                    )
                    print(
                        "Error: Could not clean up ephemeral hosts because of authentication error. Host may have been deleted."
                    )
                    print(f"Server message: {e.response.text}")
                    break
                else:
                    await logger.aerror("error during cleanup operation", error=str(e))

        except Exception:
            # Best effort: one failed round must not stop the supervisor.
            await logger.aexception("could not refresh token")

        # Watchdog: checked on every round, whether or not the round failed,
        # so a nebula process that exited on its own is always brought back.
        if proc.returncode is not None:
            await logger.ainfo("process died, restarting")
            proc = await start_process()

    # Clean shutdown if we get here
    if proc.returncode is None:
        await logger.ainfo("shutting down nebula process")
        proc.terminate()
        try:
            await asyncio.wait_for(proc.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            await logger.awarning("nebula process didn't terminate, killing it")
            proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            await proc.wait()

251 

252 

async def perform_self_upgrade():
    """Upgrade the installed meshadmin tool by running ``uv tool install --upgrade``.

    The command's stdout and stderr are captured and logged: stdout on
    success, stderr together with the return code on failure.

    Returns:
        True when the command exits with code 0; False when it exits
        non-zero or when any exception occurs while running it.
    """
    upgrade_cmd = ("uv", "tool", "install", "--upgrade", "meshadmin")
    try:
        process = await asyncio.create_subprocess_exec(
            *upgrade_cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        out, err = await process.communicate()

        if process.returncode != 0:
            await logger.aerror(
                "meshadmin upgrade failed",
                returncode=process.returncode,
                stderr=err.decode().strip() if err else None,
            )
            return False

        await logger.ainfo(
            "meshadmin successfully upgraded",
            stdout=out.decode().strip() if out else None,
        )
        return True
    except Exception as e:
        await logger.aerror("Error during meshadmin self-upgrade", error=str(e))
        return False