Coverage for src/meshadmin/server/networks/services.py: 94%

232 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 19:26 +0200

import copy
import hashlib
import ipaddress
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Optional

import structlog
import yaml
from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils.timezone import now
from jwcrypto.jwk import JWK
from jwcrypto.jwt import JWT

from meshadmin.common.utils import create_ca, print_ca, sign_keys
from meshadmin.server import assets
from meshadmin.server.networks.models import (
    CA,
    Group,
    Host,
    HostCert,
    HostConfig,
    Network,
    NetworkMembership,
    Rule,
    SigningCA,
    Template,
)

31 

# Structured logger for this module, named after the module's dotted path.
logger = structlog.get_logger(__name__)
# User model class resolved via Django; used as the annotation of
# create_network's ``user`` parameter.
User = get_user_model()

34 

35 

def create_available_hosts_iterator(cidr, unavailable_ips):
    """Return a lazy iterator over host addresses of *cidr* that are not taken.

    Args:
        cidr: IPv4 network in CIDR notation, e.g. ``"10.0.0.0/24"``. It is
            parsed by ``ipaddress.IPv4Network`` in strict mode.
        unavailable_ips: addresses to skip, given as strings or ``ipaddress``
            address objects; ``None`` entries are ignored.

    Returns:
        A generator of ``ipaddress.IPv4Address`` objects, in the order
        produced by ``IPv4Network.hosts()``.
    """
    network = ipaddress.IPv4Network(cidr)
    # Build the lookup set once. A list membership test is linear for every
    # candidate address. Normalizing to str also lets IPv4Address entries
    # match instead of being silently ignored.
    taken = {str(ip) for ip in unavailable_ips if ip is not None}
    return (host for host in network.hosts() if str(host) not in taken)

42 

43 

def network_available_hosts_iterator(network):
    """Return a lazy iterator over the free host addresses of *network*.

    Addresses already assigned to any Host of this network are skipped. The
    set of reserved addresses is read from the database immediately, at call
    time; the returned iterator itself is lazy.
    """
    # Only the assigned IPs are needed, so don't materialize full Host rows.
    reserved_ips = list(
        Host.objects.filter(network=network).values_list("assigned_ip", flat=True)
    )
    return create_available_hosts_iterator(network.cidr, reserved_ips)

50 

51 

def create_network_ca(ca_name, network):
    """Generate a CA named *ca_name*, store it for *network*, and return the row.

    The certificate and key come from ``create_ca``. The output of
    ``print_ca`` for that certificate is stored as ``cert_print``.
    """
    ca_cert, ca_key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=ca_cert,
        key=ca_key,
        cert_print=print_ca(ca_cert),
    )

59 

60 

def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network, make *user* its admin, and give it an initial signing CA.

    Args:
        network_name: name of the new network.
        network_cidr: address range of the network (CIDR notation).
        user: user who becomes a member with the ADMIN role.
        update_interval: stored on the network as ``update_interval``.

    Returns:
        The created ``Network``.
    """
    logger.info("creating network")
    # Do everything in one transaction. If CA creation fails, the network and
    # membership rows are rolled back instead of leaving a network without the
    # SigningCA that generate_config_yaml relies on.
    with transaction.atomic():
        network = Network.objects.create(
            name=network_name, cidr=network_cidr, update_interval=update_interval
        )
        NetworkMembership.objects.create(
            network=network, user=user, role=NetworkMembership.Role.ADMIN
        )
        # Same CA creation path as any other network CA.
        ca = create_network_ca("auto created initial ca", network)
        SigningCA.objects.create(network=network, ca=ca)
    logger.info("created network", network=str(network))
    return network

82 

83 

def _coerce_override_value(key: str, value):
    """Convert a raw override *value* to the type expected for the config *key*.

    Keys ending in a known boolean suffix become ``bool``: real booleans pass
    through, and strings are ``True`` only for ``"true"`` (case-insensitive,
    surrounding whitespace ignored). Keys ending in a known numeric suffix go
    through ``int()``. All other values are returned unchanged.

    Raises:
        ValueError/TypeError/AttributeError: when the value cannot be converted.
    """
    if key.endswith(
        (
            ".serve_dns",
            ".punch",
            ".respond",
            ".message_metrics",
            ".lighthouse_metrics",
        )
    ):
        if isinstance(value, bool):
            return value
        return value.strip().lower() == "true"
    if key.endswith((".port", ".delay", ".respond_delay", ".interval")):
        return int(value)
    return value


def apply_group_config_overrides(config_data: dict, groups: list[Group]) -> dict:
    """Return a deep copy of *config_data* with the groups' config overrides applied.

    Each override's ``key`` is a dot-separated path into the config, e.g.
    ``"lighthouse.serve_dns"``. Missing intermediate mappings are created, and
    intermediate ``None`` values (empty YAML sections) are treated as empty
    mappings. Overrides are applied in group-name order, so a later group name
    wins on conflicting keys.

    An override that cannot be applied is logged and skipped. It never leaves
    a partial change behind. *config_data* itself is not modified.
    """
    config = copy.deepcopy(config_data)

    overrides = [
        (group.name, override)
        for group in groups
        for override in group.config_overrides.all()
    ]
    # Sort by group name for consistent ordering. The sort is stable, so
    # overrides of one group keep their queryset order.
    overrides.sort(key=lambda item: item[0])

    for group_name, override in overrides:
        try:
            # Convert first, so a bad value fails before the config is touched
            # (previously missing parents were created before int() could fail).
            value = _coerce_override_value(override.key, override.value)

            *parents, leaf = override.key.split(".")
            current = config
            for part in parents:
                child = current.get(part)
                if child is None:
                    child = current[part] = {}
                elif not isinstance(child, dict):
                    raise TypeError(f"config entry '{part}' is not a mapping")
                current = child

            current[leaf] = value
            logger.info(
                "applied group config override",
                group=group_name,
                path=override.key,
                value=value,
            )
        except Exception as e:
            # Deliberately best-effort: one bad override must not break config
            # generation for the host.
            logger.error(
                "failed to apply group config override",
                group=group_name,
                path=override.key,
                error=str(e),
            )

    return config

145 

146 

def _stable_hash(*parts) -> int:
    """Return a deterministic, signed 64-bit hash of *parts*.

    The builtin ``hash()`` of str/bytes/datetime values is salted per
    interpreter process (PYTHONHASHSEED), so it must not be persisted and
    compared later. This hash is derived from SHA-256 over a JSON encoding of
    *parts*, where non-JSON values go through ``str``. It is identical across
    processes and restarts.
    """
    payload = json.dumps(parts, default=str).encode()
    return int.from_bytes(hashlib.sha256(payload).digest()[:8], "big", signed=True)


def _firewall_rules(groups) -> tuple[list[dict], list[dict]]:
    """Build the Nebula firewall ``(inbound, outbound)`` rule lists from *groups*.

    Every rule of every group is emitted. Only the rule attributes that are
    not ``None`` (or, for ``groups``, non-empty) are included.
    """
    inbound_rules: list[dict] = []
    outbound_rules: list[dict] = []
    for group in groups:
        for rule in group.rules.all():
            rule_data = {}

            if rule.local_cidr is not None:
                rule_data["local_cidr"] = rule.local_cidr

            if rule.cidr is not None:
                rule_data["cidr"] = rule.cidr

            peer_group_names = [peer.name for peer in rule.groups.all()]
            if peer_group_names:
                rule_data["groups"] = peer_group_names

            if rule.group is not None:
                rule_data["group"] = rule.group.name

            if rule.proto is not None:
                rule_data["proto"] = rule.proto

            if rule.port is not None:
                rule_data["port"] = rule.port

            if rule.direction == Rule.Direction.INBOUND:
                inbound_rules.append(rule_data)
            else:
                outbound_rules.append(rule_data)
    return inbound_rules, outbound_rules


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Render and return the Nebula YAML config for the host with id *host_id*.

    If the host has ``config_freeze`` set and *ignore_freeze* is false, the
    most recently created stored ``HostConfig`` is returned unchanged when one
    exists.

    Otherwise the config is built from the ``config.yml`` asset:

    - PKI: the certificates of all the network's CAs, plus a host certificate
      signed by the network's signing CA. The certificate is re-signed only
      when its inputs change.
    - the lighthouse, relay, tun and firewall sections;
    - the host's group config overrides.

    A ``HostConfig`` row is stored when no row with the same SHA-256 exists.
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    assert host.public_key, "host has no public key; cannot sign a certificate"

    config_template = (assets.asset_path / "config.yml").read_text()
    config_data = yaml.safe_load(config_template)
    # Trust every CA of the network, not only the signing one (presumably so
    # certificates from other/older CAs remain valid).
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    # Evaluate the host's groups once. They feed the certificate, the
    # re-sign key, the firewall and the overrides.
    host_groups = list(host.groups.all())
    group_names = frozenset(group.name for group in host_groups)
    # Use the network's own prefix length instead of a hard-coded /24, so the
    # certificate covers the whole network range.
    prefix_len = ipaddress.ip_network(network.cidr, strict=False).prefixlen
    assigned_ip = f"{host.assigned_ip}/{prefix_len}"

    cert_key = _stable_hash(
        ca.cert,
        host.public_key,
        host.name,
        assigned_ip,
        sorted(group_names),
        sorted({str(group.updated_at) for group in host_groups}),
        sorted(
            {
                str(rule.updated_at)
                for group in host_groups
                for rule in group.rules.all()
            }
        ),
        sorted(
            {
                str(override.updated_at)
                for group in host_groups
                for override in group.config_overrides.all()
            }
        ),
    )
    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    # Compare as strings so the check does not depend on how the stored hash
    # round-trips from the database. NOTE(review): confirm HostCert.hash type.
    if host_cert is None or str(host_cert.hash) != str(cert_key):
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=group_names,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=cert_key
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouses = list(Host.objects.filter(network=network, is_lighthouse=True))
    if host.is_lighthouse:
        config_data["lighthouse"]["am_lighthouse"] = True
        config_data["lighthouse"]["hosts"] = []
    else:
        config_data["lighthouse"]["am_lighthouse"] = False
        config_data["lighthouse"]["hosts"] = [
            lighthouse.assigned_ip for lighthouse in lighthouses
        ]
    config_data["static_host_map"] = {
        lighthouse.assigned_ip: [f"{lighthouse.public_ip_or_hostname}:4242"]
        for lighthouse in lighthouses
    }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay

    config_data["tun"]["dev"] = host.interface

    inbound_rules, outbound_rules = _firewall_rules(host_groups)
    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    config_data = apply_group_config_overrides(config_data, host_groups)

    yaml_config = yaml.safe_dump(config_data, indent=4)

    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    # NOTE(review): dedup is by sha256 across all HostConfig rows. If a config
    # reverts to an earlier version, no new row is written, so the frozen
    # lookup above (latest created_at) could return a stale version.
    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config

281 

282 

def create_template(
    name: str,
    network_name: str,
    is_lighthouse: bool = False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment ``Template`` in the network named *network_name*.

    Args:
        name: template name.
        network_name: name of an existing network.
        groups: names of groups in that network to attach to the template.
        The remaining arguments are stored on the template as-is.

    Returns:
        The created ``Template``.

    Raises:
        Network.DoesNotExist: no network with that name exists.
        LookupError: one of *groups* does not exist in the network. Nothing
            is created in that case.
    """
    network = Network.objects.get(name=network_name)

    # Resolve every group before creating the template, so an unknown group
    # name no longer leaves an orphaned template behind.
    resolved_groups = []
    for group_name in groups:
        try:
            resolved_groups.append(Group.objects.get(name=group_name, network=network))
        except Group.DoesNotExist as e:
            raise LookupError(
                f"Group does not exist in network {group_name}/{network_name}"
            ) from e

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )

    for group in resolved_groups:
        template.groups.add(group)

    return template

316 

317 

def get_server_signing_key():
    """Return the server's enrollment-token signing key, creating it on first use.

    The key is stored as JWK JSON in ``enrollment_signing.key``, relative to
    the current working directory. When the file is missing, a new EC P-256
    key is generated and written with mode 0o600.
    """
    key_path = Path("enrollment_signing.key")
    if key_path.exists():
        return JWK.from_json(key_path.read_text())

    key = JWK.generate(kty="EC", crv="P-256")
    try:
        # O_EXCL: only one process can create the file, so concurrent first
        # calls cannot overwrite each other's key. Mode 0o600 is applied at
        # creation, so the private key is never briefly readable by others
        # (the old write-then-chmod left such a window).
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Another process created the key between our check and open; use it.
        return JWK.from_json(key_path.read_text())
    with os.fdopen(fd, "w") as f:
        f.write(key.export_private())
    return key

329 

330 

def generate_enrollment_token(template: Template, ttl: Optional[int] = None):
    """Return a signed (ES256) JWT enrolling hosts with *template*'s settings.

    Args:
        template: template whose id, network and host flags become claims.
            ``jti`` is ``str(template.enrollment_key)``.
        ttl: lifetime in seconds from now. When truthy, ``exp`` is
            ``iat + ttl``, capped at ``template.expires_at`` if that is set.
            Otherwise ``exp`` is ``template.expires_at`` when set. With
            neither, the token has no ``exp`` claim.

    Returns:
        The compact-serialized token string.
    """
    signing_key = get_server_signing_key()
    # Read the clock once, so "iat" and a ttl-based "exp" agree.
    issued_at = now().timestamp()
    claims = {
        "jti": str(template.enrollment_key),
        "iss": "meshadmin",
        "sub": f"template:{template.id}",
        "iat": int(issued_at),
        "template_id": template.id,
        "network_id": template.network_id,
        "is_lighthouse": template.is_lighthouse,
        "is_relay": template.is_relay,
        "use_relay": template.use_relay,
        "reusable": template.reusable,
        "usage_limit": template.usage_limit,
        "ephemeral_peers": template.ephemeral_peers,
    }

    expiry = None
    if ttl:
        expiry = issued_at + ttl
        if template.expires_at:
            # A ttl token must not outlive the template's own expiry. The
            # token's exp is the only expiry enforced by
            # verify_enrollment_token.
            expiry = min(expiry, template.expires_at.timestamp())
    elif template.expires_at:
        expiry = template.expires_at.timestamp()
    if expiry is not None:
        claims["exp"] = int(expiry)

    token = JWT(header={"alg": "ES256"}, claims=claims)
    token.make_signed_token(signing_key)
    return token.serialize()

354 

355 

def verify_enrollment_token(token_string):
    """Validate an enrollment JWT against the server key and return its template id.

    Returns:
        The ``template_id`` claim of the token.

    Raises:
        ValueError: if the token is expired, has a bad signature, is malformed,
            or lacks a truthy ``template_id`` claim. Which message is used
            depends on text found in the underlying error.
    """
    signing_key = get_server_signing_key()
    try:
        token = JWT(jwt=token_string)
        token.validate(signing_key)
        payload = json.loads(token.token.objects["payload"])
        template_id = payload.get("template_id")
        if not template_id:
            raise ValueError("Invalid token: missing template_id")
        return template_id
    except Exception as e:
        error_message = str(e)
        if "Expired" in error_message:
            raise ValueError("Enrollment token has expired") from e
        if "Invalid signature" in error_message:
            raise ValueError("Invalid enrollment token signature") from e
        # Never log the token itself: it is a bearer credential for
        # enrollment. Log a short fingerprint for correlation instead.
        logger.error(
            "invalid enrollment token",
            token_sha256=hashlib.sha256(str(token_string).encode()).hexdigest()[:12],
            error=error_message,
        )
        raise ValueError(f"Invalid enrollment token: {error_message}") from e

377 

378 

def _check_template_usage(template) -> None:
    """Raise ValueError if *template*'s enrollment key has no uses left.

    A non-reusable template may be used once. A reusable one is limited by
    ``usage_limit`` when that is set (truthy).
    """
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit and template.usage_count >= template.usage_limit:
        logger.error("enrollment key usage limit exceeded", template_id=template.id)
        raise ValueError("Enrollment key usage limit exceeded")


def _unique_hostname(network, preferred_hostname) -> str:
    """Return *preferred_hostname*, suffixed ``-1``, ``-2``, … until unused in *network*."""
    taken = set(
        Host.objects.filter(
            network=network, name__startswith=preferred_hostname
        ).values_list("name", flat=True)
    )
    candidate = preferred_hostname
    suffix = 1
    while candidate in taken:
        candidate = f"{preferred_hostname}-{suffix}"
        suffix += 1
    return candidate


def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a new host from a signed enrollment token and return it.

    Args:
        enrollment_key: enrollment JWT, checked by ``verify_enrollment_token``.
        public_auth_key: the host's public auth key as JWK JSON. Its thumbprint
            identifies an already-enrolled host.
        enroll_on_existence: when truthy and a host with the same auth key
            exists, enrollment is aborted. When falsy, the existing host is
            deleted and replaced.
        public_ip: public IP/hostname (required for lighthouse templates).
        preferred_hostname: desired name; suffixed if already taken.
        public_net_key: the host's Nebula public key.
        interface: tun device name stored on the host.

    Raises:
        ValueError: invalid/expired token, unknown template, usage limit
            reached, host already enrolled, lighthouse without *public_ip*,
            or no free IP address left in the network.
    """
    # One transaction for the whole enrollment. A failure part-way (e.g. no
    # free IP) rolls back the deletion of a previously enrolled host. The
    # template row lock makes the usage-limit check and the increment atomic
    # with respect to concurrent enrollments using the same key.
    with transaction.atomic():
        try:
            template_id = verify_enrollment_token(enrollment_key)
            template = Template.objects.select_for_update().get(id=template_id)
        except ValueError as e:
            logger.error("invalid enrollment token", error=str(e))
            raise ValueError(f"Invalid enrollment token: {str(e)}") from e
        except Template.DoesNotExist as e:
            logger.error("template not found", template_id=template_id)
            raise ValueError("Template not found") from e
        # NOTE(review): the token's jti (template.enrollment_key) is not compared
        # with the template's current enrollment_key, so changing that key does
        # not invalidate previously issued tokens. Confirm this is intended.

        _check_template_usage(template)

        jwk_public_auth = JWK.from_json(public_auth_key)
        thumbprint = jwk_public_auth.thumbprint()
        existing_host: Optional[Host] = Host.objects.filter(
            public_auth_kid=thumbprint
        ).first()

        # NOTE(review): a truthy enroll_on_existence *aborts* on an existing host,
        # which reads inverted relative to its name. Kept as-is; confirm with callers.
        if existing_host and enroll_on_existence:
            logger.info(
                "host already exists, aborting enrollment",
                host_id=existing_host.id,
                enroll_on_existence=enroll_on_existence,
            )
            raise ValueError("Host already enrolled")

        # Validate before deleting anything.
        if template.is_lighthouse and not public_ip:
            raise ValueError("Cannot enroll a lighthouse without public_ip")

        if existing_host:
            existing_host.delete()

        network = template.network
        # Computed after the deletion above, so a re-enrolling host's old IP
        # can be reused.
        try:
            assigned_ip = next(network_available_hosts_iterator(network))
        except StopIteration:
            logger.error("no available ip address in network", network=str(network))
            raise ValueError("No available IP address left in network") from None

        host = Host.objects.create(
            network=network,
            name=_unique_hostname(network, preferred_hostname),
            assigned_ip=assigned_ip,
            is_relay=template.is_relay,
            is_lighthouse=template.is_lighthouse,
            public_ip_or_hostname=public_ip,
            public_key=public_net_key,
            public_auth_key=public_auth_key,
            public_auth_kid=thumbprint,
            is_ephemeral=template.ephemeral_peers,
            interface=interface,
        )

        template.usage_count += 1
        template.save()

        for group in template.groups.all():
            host.groups.add(group)

        host.save()
    return host

471 

472 

def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create and return a ``Group`` in the network whose primary key is *network_pk*.

    Raises ``Network.DoesNotExist`` when no such network exists.
    """
    owning_network = Network.objects.get(pk=network_pk)
    new_group = Group.objects.create(
        network=owning_network,
        name=group_name,
        description=description,
    )
    return new_group