Coverage for src/meshadmin/server/networks/services.py: 94%

230 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 11:34 +0200

import copy
import hashlib
import ipaddress
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Optional

import structlog
import yaml
from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils.timezone import now
from jwcrypto.jwk import JWK
from jwcrypto.jwt import JWT

from meshadmin.common.utils import create_ca, print_ca, sign_keys
from meshadmin.server import assets
from meshadmin.server.networks.models import (
    CA,
    Group,
    Host,
    HostCert,
    HostConfig,
    Network,
    NetworkMembership,
    Rule,
    SigningCA,
    Template,
)

31 

# Model class returned by Django's get_user_model() (the project's configured
# user model); used below for type annotations.
User = get_user_model()
# Structured logger shared by every service function in this module.
logger = structlog.get_logger(__name__)

34 

35 

def create_available_hosts_iterator(cidr, unavailable_ips):
    """Lazily yield the usable host addresses of ``cidr`` that are not taken.

    Args:
        cidr: IPv4 network in CIDR notation (e.g. ``"10.0.0.0/24"``). It is parsed
            strictly by ``ipaddress.IPv4Network``, so host bits must not be set.
        unavailable_ips: iterable of addresses to skip; entries are compared by
            their string form, so both ``str`` and ``IPv4Address`` values work.

    Returns:
        A generator of ``ipaddress.IPv4Address`` in the order produced by
        ``IPv4Network.hosts()``, minus the unavailable addresses.
    """
    network = ipaddress.IPv4Network(cidr)
    # Build the lookup set once. Testing membership against a list cost
    # O(len(unavailable_ips)) per candidate address. A one-shot iterator was also
    # consumed by the first `in` test, so later addresses looked free.
    taken = {str(ip) for ip in unavailable_ips}
    return (host for host in network.hosts() if str(host) not in taken)

42 

43 

def network_available_hosts_iterator(network):
    """Return a lazy iterator over the free host addresses of ``network``.

    Addresses already assigned to any ``Host`` of the network are skipped; the
    candidates come from ``network.cidr``.
    """
    # Only the address column is needed, so skip building full Host instances.
    reserved_ips = list(
        Host.objects.filter(network=network).values_list("assigned_ip", flat=True)
    )
    return create_available_hosts_iterator(network.cidr, reserved_ips)

50 

51 

def create_network_ca(ca_name, network):
    """Generate a new CA named ``ca_name`` and store it for ``network``.

    Returns the saved ``CA`` row. The row holds the certificate, the key and the
    ``print_ca`` rendering of the certificate.
    """
    ca_cert, ca_key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=ca_cert,
        key=ca_key,
        cert_print=print_ca(ca_cert),
    )

59 

60 

def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network, make ``user`` its admin and give it an initial signing CA.

    All rows are created in a single transaction. If any step fails (for example
    CA generation), no network without a signing CA is left behind.

    Args:
        network_name: name of the new network.
        network_cidr: address range of the network in CIDR notation.
        user: user added as a member with the ``ADMIN`` role.
        update_interval: stored on the network as ``update_interval``.

    Returns:
        The created ``Network``.
    """
    logger.info("creating network")
    with transaction.atomic():
        network = Network.objects.create(
            name=network_name, cidr=network_cidr, update_interval=update_interval
        )
        NetworkMembership.objects.create(
            network=network, user=user, role=NetworkMembership.Role.ADMIN
        )
        # Reuse the shared CA factory instead of duplicating its steps here.
        ca = create_network_ca("auto created initial ca", network)
        SigningCA.objects.create(network=network, ca=ca)
    logger.info("created network", network=str(network))
    return network

82 

83 

def apply_group_config_overrides(config_data: dict, groups: list[Group]) -> dict:
    """Return a deep copy of ``config_data`` with the groups' config overrides applied.

    Each override's ``key`` is a dotted path into the config (e.g. ``"listen.port"``).
    Missing intermediate mappings are created along the way. Value conversion
    depends on the key:

    - keys ending in ``.serve_dns`` become booleans (``"true"``, case-insensitive);
    - keys ending in ``.port`` become ints;
    - all other values are stored verbatim.

    Overrides are applied in ascending group-name order, so on conflicting keys the
    alphabetically last group wins. A failing override is logged and skipped. The
    input dict is never mutated.
    """
    result = copy.deepcopy(config_data)
    pending = [
        (group.name, override)
        for group in groups
        for override in group.config_overrides.all()
    ]
    # Stable sort: overrides of the same group keep their original relative order.
    pending.sort(key=lambda item: item[0])

    for group_name, override in pending:
        try:
            *parents, leaf = override.key.split(".")
            node = result
            for part in parents:
                if part not in node:
                    node[part] = {}
                node = node[part]

            raw = override.value
            if override.key.endswith(".serve_dns"):
                value = raw.lower() == "true"
            elif override.key.endswith(".port"):
                value = int(raw)
            else:
                value = raw

            node[leaf] = value
            logger.info(
                "applied group config override",
                group=group_name,
                path=override.key,
                value=value,
            )
        except Exception as e:
            # Best effort: one malformed override must not break config generation.
            logger.error(
                "failed to apply group config override",
                group=group_name,
                path=override.key,
                error=str(e),
            )

    return result

135 

136 

def _cert_ip(assigned_ip, cidr: str) -> str:
    """Return ``assigned_ip`` in CIDR notation, using the prefix length of ``cidr``.

    Nebula takes the overlay subnet from the mask in the host certificate, so the
    mask must be the network's own prefix. A fixed ``/24`` cut hosts off from each
    other in networks larger than a /24.
    """
    prefixlen = ipaddress.IPv4Network(cidr, strict=False).prefixlen
    return f"{assigned_ip}/{prefixlen}"


def _cert_cache_key(*parts) -> int:
    """Return a deterministic, signed 64-bit fingerprint of ``parts``.

    This replaces the built-in ``hash()``. String hashing is salted per interpreter
    process (``PYTHONHASHSEED``), so a ``hash()`` value persisted in
    ``HostCert.hash`` stopped matching after a restart or in another worker, which
    triggered needless certificate re-signing. The result stays in the signed
    64-bit range that ``hash()`` produced.
    """
    # default=str: the fingerprint only needs a stable textual form of each part.
    payload = json.dumps(parts, separators=(",", ":"), default=str)
    digest = hashlib.sha256(payload.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def _build_firewall_rules(groups) -> tuple[list[dict], list[dict]]:
    """Translate the rules of ``groups`` into Nebula ``(inbound, outbound)`` rule lists.

    Only the attributes that are set (not ``None`` / non-empty) appear in each rule.
    """
    inbound_rules = []
    outbound_rules = []
    for group in groups:
        for rule in group.rules.all():
            rule_data = {}
            if rule.local_cidr is not None:
                rule_data["local_cidr"] = rule.local_cidr
            if rule.cidr is not None:
                rule_data["cidr"] = rule.cidr
            peer_groups = rule.groups.all()
            if len(peer_groups) > 0:
                rule_data["groups"] = [peer.name for peer in peer_groups]
            if rule.group is not None:
                rule_data["group"] = rule.group.name
            if rule.proto is not None:
                rule_data["proto"] = rule.proto
            if rule.port is not None:
                rule_data["port"] = rule.port

            if rule.direction == Rule.Direction.INBOUND:
                inbound_rules.append(rule_data)
            else:
                outbound_rules.append(rule_data)
    return inbound_rules, outbound_rules


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Return the Nebula YAML configuration for the host with id ``host_id``.

    Frozen configs: if the host has ``config_freeze`` set and ``ignore_freeze`` is
    false, the most recently stored ``HostConfig`` is returned unchanged when one
    exists.

    Otherwise the config is rendered from the bundled ``config.yml`` template. It
    fills in:

    - PKI: the certificates of all network CAs, plus a host certificate signed by
      the network's signing CA. The host certificate is re-signed only when its
      inputs change.
    - Lighthouse settings and the static host map.
    - Relay settings and the tun device.
    - Firewall rules from the host's groups.
    - Finally, the groups' config overrides.

    A new ``HostConfig`` row is saved when no stored config has the same SHA-256.

    Raises:
        Host.DoesNotExist: if there is no host with ``host_id``.
        AssertionError: if the host has no public key.
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    # NOTE(review): `assert` is stripped under `python -O`; kept so the exception
    # type callers see does not change.
    assert host.public_key

    config_template = (assets.asset_path / "config.yml").read_text()
    config_data = yaml.safe_load(config_template)
    # Concatenate the certificates of every CA in the network, not only the signer.
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    # Materialize the host's groups once. They feed the certificate, the cache key,
    # the firewall and the config overrides.
    host_groups = list(host.groups.all())
    group_names = frozenset(group.name for group in host_groups)
    assigned_ip = _cert_ip(host.assigned_ip, network.cidr)

    query_key_hash = _cert_cache_key(
        ca.cert,
        host.public_key,
        host.name,
        assigned_ip,
        sorted(group_names),
        sorted({str(group.updated_at) for group in host_groups}),
        sorted(
            {str(rule.updated_at) for group in host_groups for rule in group.rules.all()}
        ),
        sorted(
            {
                str(override.updated_at)
                for group in host_groups
                for override in group.config_overrides.all()
            }
        ),
    )
    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    if host_cert is None or host_cert.hash != query_key_hash:
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=group_names,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=query_key_hash
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouses = list(Host.objects.filter(network=network, is_lighthouse=True))
    if host.is_lighthouse:
        config_data["lighthouse"]["am_lighthouse"] = True
        config_data["lighthouse"]["hosts"] = []
    else:
        config_data["lighthouse"]["am_lighthouse"] = False
        config_data["lighthouse"]["hosts"] = [
            lighthouse.assigned_ip for lighthouse in lighthouses
        ]
    config_data["static_host_map"] = {
        lighthouse.assigned_ip: [f"{lighthouse.public_ip_or_hostname}:4242"]
        for lighthouse in lighthouses
    }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay

    config_data["tun"]["dev"] = host.interface

    inbound_rules, outbound_rules = _build_firewall_rules(host_groups)
    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    config_data = apply_group_config_overrides(config_data, host_groups)

    yaml_config = yaml.safe_dump(config_data, indent=4)
    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    # NOTE(review): this lookup is not restricted to this host. If a host's config
    # reverts to an earlier stored version, no new row is written, so the "latest"
    # row used by config_freeze above stays the newer one.
    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config

271 

272 

def create_template(
    name: str,
    network_name: str,
    is_lighthouse: bool = False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment ``Template`` in the network named ``network_name``.

    Args:
        name: template name.
        network_name: name of an existing network.
        groups: names of groups in that network that enrolled hosts join.
        The remaining arguments are stored on the template unchanged.

    Returns:
        The created ``Template``.

    Raises:
        Network.DoesNotExist: if no network has that name.
        LookupError: if one of ``groups`` does not exist in the network. No template
            is created in that case.
    """
    network = Network.objects.get(name=network_name)

    # Resolve every group before creating anything. Previously an unknown group
    # name raised only after the template was saved, leaving an orphan behind.
    resolved_groups = []
    for group in groups:
        try:
            resolved_groups.append(Group.objects.get(name=group, network=network))
        except Group.DoesNotExist:
            raise LookupError(
                f"Group does not exist in network {group}/{network_name}"
            ) from None

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )
    template.groups.add(*resolved_groups)
    return template

306 

307 

def get_server_signing_key():
    """Load the enrollment-token signing key, creating it on first use.

    The key is a P-256 EC JWK stored as JSON in ``enrollment_signing.key``. The path
    is resolved relative to the current working directory.

    Returns:
        jwcrypto.jwk.JWK: the private signing key.
    """
    key_path = Path("enrollment_signing.key")
    if not key_path.exists():
        key = JWK.generate(kty="EC", crv="P-256")
        try:
            # O_EXCL: never overwrite a key another process created in the meantime;
            # that would invalidate every token it already signed.
            # Mode 0o600 at creation time: the private key is never briefly
            # world-readable, as it was with open() followed by chmod().
            fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            # Another process won the race; load its key below. It may still be
            # writing it — a very small window, accepted here.
            pass
        else:
            with os.fdopen(fd, "w") as f:
                f.write(key.export_private())
            os.chmod(key_path, 0o600)
            return key
    with open(key_path, "r") as f:
        return JWK.from_json(f.read())

319 

320 

def generate_enrollment_token(template: Template):
    """Issue an ES256-signed enrollment JWT for ``template``.

    The claims identify the template: ``jti`` is its enrollment key and ``sub`` is
    ``template:<id>``. They also copy the template's network and enrollment
    settings. An ``exp`` claim is included only when the template has an expiry.

    Returns:
        The token as produced by ``JWT.serialize()``.
    """
    issued_at = int(now().timestamp())
    claims = {
        "jti": str(template.enrollment_key),
        "iss": "meshadmin",
        "sub": f"template:{template.id}",
        "iat": issued_at,
        "template_id": template.id,
        "network_id": template.network_id,
        "is_lighthouse": template.is_lighthouse,
        "is_relay": template.is_relay,
        "use_relay": template.use_relay,
        "reusable": template.reusable,
        "usage_limit": template.usage_limit,
        "ephemeral_peers": template.ephemeral_peers,
    }
    if template.expires_at:
        claims["exp"] = int(template.expires_at.timestamp())

    jwt = JWT(header={"alg": "ES256"}, claims=claims)
    jwt.make_signed_token(get_server_signing_key())
    return jwt.serialize()

342 

343 

def verify_enrollment_token(token_string):
    """Validate an enrollment JWT against the server key and return its template id.

    Returns:
        The ``template_id`` claim of the token.

    Raises:
        ValueError: if the token is expired, has an invalid signature, is malformed,
            or lacks a ``template_id`` claim.
    """
    signing_key = get_server_signing_key()
    try:
        token = JWT(jwt=token_string)
        token.validate(signing_key)
        payload = json.loads(token.token.objects["payload"])
        template_id = payload.get("template_id")
        if not template_id:
            raise ValueError("Invalid token: missing template_id")
        return template_id
    except Exception as e:
        error_message = str(e)
        if "Expired" in error_message:
            raise ValueError("Enrollment token has expired") from e
        if "Invalid signature" in error_message:
            raise ValueError("Invalid enrollment token signature") from e
        # Never log the token itself: it is a bearer credential that lets anyone
        # enroll hosts. A short digest is enough to correlate repeated failures.
        token_fingerprint = hashlib.sha256(str(token_string).encode()).hexdigest()[:12]
        logger.error(
            "invalid enrollment token",
            token_sha256=token_fingerprint,
            error=error_message,
        )
        raise ValueError(f"Invalid enrollment token: {error_message}") from e

365 

366 

def _check_template_usage(template):
    """Raise ``ValueError`` if ``template`` allows no further enrollments.

    Non-reusable templates allow exactly one enrollment. Reusable ones are capped by
    ``usage_limit`` when it is set; a falsy limit means no cap.
    """
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit:
        if template.usage_count >= template.usage_limit:
            logger.error("enrollment key usage limit exceeded", template_id=template.id)
            raise ValueError("Enrollment key usage limit exceeded")


def _unique_hostname(network, preferred_hostname):
    """Return ``preferred_hostname``, or the first free ``<name>-<n>`` (n >= 1) in ``network``."""
    existing_hostnames = set(
        Host.objects.filter(
            network=network, name__startswith=preferred_hostname
        ).values_list("name", flat=True)
    )
    final_hostname = preferred_hostname
    suffix = 1
    while final_hostname in existing_hostnames:
        final_hostname = f"{preferred_hostname}-{suffix}"
        suffix += 1
    return final_hostname


def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a new host using the template referenced by ``enrollment_key``.

    Everything runs in one transaction, with the template and network rows locked.
    Concurrent enrollments therefore cannot exceed the template's usage limit, lose
    a usage-count increment, or receive the same IP address or hostname. It also
    means an existing host deleted for re-enrollment is restored if a later step
    fails.

    Args:
        enrollment_key: signed enrollment JWT (see ``generate_enrollment_token``).
        public_auth_key: the host's public authentication key as JWK JSON; its
            thumbprint identifies the host across enrollments.
        enroll_on_existence: controls what happens when a host with the same auth
            key already exists. A truthy value aborts with ``ValueError``; a falsy
            value deletes the existing host and enrolls it afresh.
        public_ip: public IP or hostname; required for lighthouse templates.
        preferred_hostname: desired host name. If it is taken in the network,
            ``-1``, ``-2``, ... is appended.
        public_net_key: public key stored as ``Host.public_key`` (signed into the
            host certificate).
        interface: tun device name stored on the host.

    Returns:
        The created ``Host``.

    Raises:
        ValueError: on any of the following:
            - invalid or expired token;
            - unknown template;
            - usage limit reached;
            - host already enrolled;
            - lighthouse template without ``public_ip``;
            - no free IP address in the network.
    """
    with transaction.atomic():
        try:
            template_id = verify_enrollment_token(enrollment_key)
            # Row lock: enrollments with the same key are serialized, so the usage
            # check below and the counter increment cannot race.
            template = Template.objects.select_for_update().get(id=template_id)
        except ValueError as e:
            logger.error("invalid enrollment token", error=str(e))
            raise ValueError(f"Invalid enrollment token: {str(e)}") from e
        except Template.DoesNotExist:
            logger.error("template not found", template_id=template_id)
            raise ValueError("Template not found") from None

        _check_template_usage(template)

        jwk_public_auth = JWK.from_json(public_auth_key)
        thumbprint = jwk_public_auth.thumbprint()
        host: Optional[Host] = Host.objects.filter(public_auth_kid=thumbprint).first()
        if host:
            # NOTE(review): despite its name, a truthy `enroll_on_existence` aborts
            # here. Behaviour kept as-is for existing callers.
            if enroll_on_existence:
                logger.info(
                    "host already exists, aborting enrollment",
                    host_id=host.id,
                    enroll_on_existence=enroll_on_existence,
                )
                raise ValueError("Host already enrolled")
            host.delete()

        # Lock the network row so concurrent enrollments into the same network,
        # even via different templates, cannot pick the same IP or hostname.
        network = Network.objects.select_for_update().get(pk=template.network_id)
        ipv4_iterator = network_available_hosts_iterator(network)

        if template.is_lighthouse and not public_ip:
            raise ValueError("Cannot enroll a lighthouse without public_ip")

        final_hostname = _unique_hostname(network, preferred_hostname)
        try:
            assigned_ip = next(ipv4_iterator)
        except StopIteration:
            # A bare StopIteration would escape as a confusing, uncaught error.
            raise ValueError("No free IP address left in network") from None

        host = Host.objects.create(
            network=network,
            name=final_hostname,
            assigned_ip=assigned_ip,
            is_relay=template.is_relay,
            is_lighthouse=template.is_lighthouse,
            public_ip_or_hostname=public_ip,
            public_key=public_net_key,
            public_auth_key=public_auth_key,
            public_auth_kid=thumbprint,
            is_ephemeral=template.ephemeral_peers,
            interface=interface,
        )

        template.usage_count += 1
        template.save()

        host.groups.add(*template.groups.all())
        host.save()
    return host

459 

460 

def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create the group ``group_name`` in the network whose primary key is ``network_pk``.

    Returns the new ``Group``. ``Network.DoesNotExist`` propagates when no network
    has that primary key.
    """
    owning_network = Network.objects.get(pk=network_pk)
    return Group.objects.create(
        network=owning_network,
        name=group_name,
        description=description,
    )