Coverage for src/meshadmin/cli/utils.py: 26%

131 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-22 07:26 +0200

1import asyncio 

2import json 

3import os 

4import signal 

5from datetime import datetime, timedelta 

6from importlib.metadata import version 

7from pathlib import Path 

8 

9import httpx 

10import structlog 

11import typer 

12import yaml 

13from jwcrypto.jwk import JWK 

14from jwcrypto.jwt import JWT 

15from jwt import decode 

16 

17from meshadmin.cli.config import get_config 

18from meshadmin.common.utils import create_expiration_date, get_nebula_path 

19 

# Module-level structlog logger; start_nebula awaits its async methods (ainfo, aerror, aexception, awarning).
20logger = structlog.get_logger(__name__) 

21 

22 

def get_access_token():
    """Return an access token for the meshadmin API, refreshing it if needed.

    When the ``MESHADMIN_TEST_MODE`` environment variable equals ``"true"``
    the fixed string ``"test-token"`` is returned and no configuration or
    file is touched.

    Otherwise the stored authentication file is read. If its access token
    is still valid for at least ten more seconds it is returned unchanged;
    if not, the stored refresh token is exchanged at the configured Keycloak
    token URL, the raw response is written back to the authentication file
    and the new access token is returned.

    Returns:
        The access token string, or ``None`` (after printing
        ``authentication failed``) when no authentication file exists.

    Raises:
        httpx.HTTPStatusError: if the token endpoint answers with an error
            status during the refresh.
    """
    if os.getenv("MESHADMIN_TEST_MODE") == "true":
        return "test-token"

    config = get_config()
    if not config.authentication_path.exists():
        print("authentication failed")
        return None

    auth = json.loads(config.authentication_path.read_text())
    access_token = auth["access_token"]

    # Only the ``exp`` claim is needed here, so neither the signature nor
    # the expiry is verified while decoding.
    decoded_token = decode(
        access_token, options={"verify_signature": False, "verify_exp": False}
    )

    # Reuse the stored token only if it stays valid for at least ten more
    # seconds; a token about to expire could be rejected mid-request.
    if decoded_token["exp"] >= (datetime.now() + timedelta(seconds=10)).timestamp():
        return access_token

    res = httpx.post(
        config.keycloak_token_url,
        data={
            "grant_type": "refresh_token",
            "refresh_token": auth["refresh_token"],
            "client_id": config.keycloak_admin_client,
        },
    )
    res.raise_for_status()
    config.authentication_path.write_bytes(res.content)
    # The file holds access and refresh tokens: keep it owner-only, the same
    # way the nebula config file is restricted after it is written.
    config.authentication_path.chmod(0o600)
    return res.json()["access_token"]

55 

56 

def get_context_config():
    """Resolve the context the CLI should operate on.

    The context name comes from the ``MESH_CONTEXT`` environment variable
    when it is set; otherwise the first entry of the contexts file whose
    ``active`` value is truthy is used.

    Returns:
        A dict with the keys ``name``, ``endpoint``, ``interface`` and
        ``network_dir`` (the networks directory joined with the name).

    Raises:
        typer.Exit: with code 1 when the contexts file does not exist, or
            when no context is selected or the selected one is unknown.
    """
    cfg = get_config()
    contexts_path = cfg.contexts_file
    if not contexts_path.exists():
        print("No contexts found")
        raise typer.Exit(1)

    with open(contexts_path) as stream:
        known = yaml.safe_load(stream) or {}

    selected = os.getenv("MESH_CONTEXT")
    if not selected:
        # No explicit override: fall back to the first entry flagged active.
        flagged = [name for name in known if known[name].get("active")]
        selected = flagged[0] if flagged else None

    if not (selected and selected in known):
        print("No active context. Please select a context with 'meshadmin context use'")
        raise typer.Exit(1)

    entry = known[selected]
    resolved = {"name": selected}
    resolved["endpoint"] = entry["endpoint"]
    resolved["interface"] = entry["interface"]
    resolved["network_dir"] = cfg.networks_dir / selected
    return resolved

86 

87 

async def get_config_from_mesh(mesh_admin_endpoint, private_auth_key):
    """Download this host's config text from the mesh admin server.

    A JWT is signed with ``private_auth_key`` (RS256, key thumbprint as
    ``kid`` in both header and claims, expiry from
    ``create_expiration_date(10)``) and sent as a bearer token to
    ``<mesh_admin_endpoint>/api/v1/config`` together with the installed
    ``meshadmin`` version.

    Returns:
        A ``(config_text, update_interval)`` tuple; the interval is read
        from the ``X-Update-Interval`` response header, defaulting to 5.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    key_id = private_auth_key.thumbprint()
    signed = JWT(
        header={"alg": "RS256", "kid": key_id},
        claims={"exp": create_expiration_date(10), "kid": key_id},
    )
    signed.make_signed_token(private_auth_key)
    bearer = signed.serialize()

    url = f"{mesh_admin_endpoint}/api/v1/config"
    async with httpx.AsyncClient() as http:
        request_headers = {
            "Authorization": f"Bearer {bearer}",
            "X-Meshadmin-Version": version("meshadmin"),
        }
        response = await http.get(url, headers=request_headers)
        response.raise_for_status()
        interval_header = response.headers.get("X-Update-Interval", "5")
        return response.text, int(interval_header)

111 

112 

async def cleanup_ephemeral_hosts(mesh_admin_endpoint, private_auth_key):
    """Ask the mesh admin server to clean up ephemeral hosts.

    A JWT is signed with ``private_auth_key`` (RS256, key thumbprint as
    ``kid`` in both header and claims, expiry from
    ``create_expiration_date(10)``) and POSTed as a bearer token to
    ``<mesh_admin_endpoint>/api/v1/cleanup-ephemeral``.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    key_id = private_auth_key.thumbprint()
    signed = JWT(
        header={"alg": "RS256", "kid": key_id},
        claims={"exp": create_expiration_date(10), "kid": key_id},
    )
    signed.make_signed_token(private_auth_key)
    auth_headers = {"Authorization": f"Bearer {signed.serialize()}"}
    url = f"{mesh_admin_endpoint}/api/v1/cleanup-ephemeral"

    async with httpx.AsyncClient() as http:
        response = await http.post(url, headers=auth_headers)
        response.raise_for_status()
        return response.json()

131 

132 

async def start_nebula(network_dir: Path, mesh_admin_endpoint: str):
    """Run nebula for one network and keep its config in sync with the server.

    Starts the nebula binary with the config file found under
    ``network_dir`` and then loops: sleep for the current update interval,
    fetch the config from ``mesh_admin_endpoint``, rewrite the local file
    and send SIGHUP when it changed, and trigger the ephemeral-host cleanup.
    The loop ends when the server answers either request with HTTP 401;
    the nebula process is then terminated (killed after a 5 second grace
    period).

    Args:
        network_dir: Directory holding the network's config file; also used
            as the working directory of the nebula process.
        mesh_admin_endpoint: Base URL of the mesh admin server.

    Raises:
        AssertionError: if the config file or the private key file is
            missing when the function starts.
    """
    config = get_config()
    await logger.ainfo("starting nebula")
    conf_path = network_dir / config.config_path
    # Explicit raises instead of ``assert`` so the checks are not stripped
    # under ``python -O``; AssertionError is kept so callers see the same
    # exception type as before.
    if not conf_path.exists():
        raise AssertionError(f"Config at {conf_path} does not exist")

    private_auth_key_path = config.contexts_file.parent / config.private_key
    if not private_auth_key_path.exists():
        raise AssertionError(f"private_key at {private_auth_key_path} does not exist")

    async def start_process():
        return await asyncio.create_subprocess_exec(
            get_nebula_path(),
            "-config",
            str(conf_path),
            cwd=network_dir,
        )

    proc = await start_process()

    # Default update interval in seconds
    update_interval = 5

    while True:
        await asyncio.sleep(update_interval)
        try:
            # The key file is re-read on every cycle; its path is fixed and
            # was computed once above.
            private_auth_key = JWK.from_json(private_auth_key_path.read_text())

            # Check for config updates
            try:
                new_config, new_update_interval = await get_config_from_mesh(
                    mesh_admin_endpoint, private_auth_key
                )

                if update_interval != new_update_interval:
                    await logger.ainfo(
                        "update interval changed",
                        old_interval=update_interval,
                        new_interval=new_update_interval,
                    )
                    update_interval = new_update_interval

                old_config = conf_path.read_text()
                if new_config != old_config:
                    await logger.ainfo("config changed, reloading")
                    conf_path.write_text(new_config)
                    conf_path.chmod(0o600)

                    try:
                        # SIGHUP tells the running process about the rewritten file.
                        proc.send_signal(signal.SIGHUP)
                    except ProcessLookupError:
                        await logger.ainfo("process died, restarting")
                        proc = await start_process()
                else:
                    await logger.ainfo("config not changed")
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 401:
                    await logger.aerror(
                        "Could not get config because of authentication error. Host may have been deleted.",
                        error=str(e),
                        response_text=e.response.text,
                    )
                    print(
                        "Error: Could not get config because of authentication error. Host may have been deleted."
                    )
                    print(f"Server message: {e.response.text}")
                    break
                else:
                    await logger.aerror("error getting config", error=str(e))

            # Cleanup ephemeral hosts
            try:
                result = await cleanup_ephemeral_hosts(
                    mesh_admin_endpoint, private_auth_key
                )
                if result.get("removed_count", 0) > 0:
                    await logger.ainfo(
                        "removed stale ephemeral hosts",
                        count=result["removed_count"],
                    )
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 401:
                    await logger.aerror(
                        "Could not clean up ephemeral hosts because of authentication error. Host may have been deleted.",
                        error=str(e),
                        response_text=e.response.text,
                    )
                    print(
                        "Error: Could not clean up ephemeral hosts because of authentication error. Host may have been deleted."
                    )
                    print(f"Server message: {e.response.text}")
                    break
                else:
                    await logger.aerror("error during cleanup operation", error=str(e))

        except Exception:
            # Best-effort loop: any other failure in this cycle (key read,
            # network error, bad response) is logged and the loop continues.
            await logger.aexception("could not refresh token")

        # Checked on every cycle so a dead process is restarted even when
        # the refresh itself succeeded.
        if proc.returncode is not None:
            await logger.ainfo("process died, restarting")
            proc = await start_process()

    # Clean shutdown if we get here
    if proc.returncode is None:
        await logger.ainfo("shutting down nebula process")
        proc.terminate()
        try:
            await asyncio.wait_for(proc.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            await logger.awarning("nebula process didn't terminate, killing it")
            proc.kill()
            # Wait for the killed child so it is reaped and its transport
            # is closed instead of lingering as a zombie.
            await proc.wait()
244 proc.kill()