Coverage for src/meshadmin/server/networks/services.py: 94%
236 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:09 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:09 +0200
1import copy
2import hashlib
3import ipaddress
4import json
5import os
6from datetime import datetime
7from pathlib import Path
8from typing import Optional
10import structlog
11import yaml
12from django.contrib.auth import get_user_model
13from django.utils.timezone import now
14from jwcrypto.jwk import JWK
15from jwcrypto.jwt import JWT
17from meshadmin.common.utils import create_ca, print_ca, sign_keys
18from meshadmin.server import assets
19from meshadmin.server.networks.models import (
20 CA,
21 Group,
22 Host,
23 HostCert,
24 HostConfig,
25 Network,
26 NetworkMembership,
27 Rule,
28 SigningCA,
29 Template,
30)
# Module-wide structured logger used by every service function below.
logger = structlog.get_logger(__name__)
# The active Django user model class, resolved once at import time; used as the
# type of the ``user`` argument in ``create_network``.
User = get_user_model()
def create_available_hosts_iterator(cidr, unavailable_ips):
    """Lazily yield the usable host addresses of ``cidr`` that are not taken.

    Args:
        cidr: IPv4 network in CIDR notation, e.g. ``"10.0.0.0/24"``. Parsed
            strictly by ``ipaddress.IPv4Network`` (raises ``ValueError`` if invalid).
        unavailable_ips: Iterable of addresses to skip. Entries are compared by
            their string form, so both ``str`` and ``IPv4Address`` values work.

    Returns:
        A generator of ``ipaddress.IPv4Address`` objects in ascending order,
        drawn from ``IPv4Network.hosts()`` and excluding ``unavailable_ips``.
    """
    network = ipaddress.IPv4Network(cidr)
    # Build the exclusion set once: testing membership against a list would make
    # scanning a large network cost O(hosts * reserved).
    taken = {str(ip) for ip in unavailable_ips}
    return (host for host in network.hosts() if str(host) not in taken)
def network_available_hosts_iterator(network):
    """Return a lazy iterator over the free host IPs of ``network``.

    Every ``assigned_ip`` already held by a Host in ``network`` is treated as
    reserved; the remaining addresses of ``network.cidr`` are yielded in order.
    """
    taken_ips = []
    for existing_host in Host.objects.filter(network=network).all():
        taken_ips.append(existing_host.assigned_ip)
    return create_available_hosts_iterator(network.cidr, taken_ips)
def create_network_ca(ca_name, network):
    """Create a CA called ``ca_name`` for ``network``, persist it and return it.

    The certificate and key come from ``create_ca``; the output of
    ``print_ca(cert)`` is stored in the ``cert_print`` field.
    """
    ca_cert, ca_key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=ca_cert,
        key=ca_key,
        cert_print=print_ca(ca_cert),
    )
def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network, make ``user`` its admin and give it a signing CA.

    Steps:
      1. validate ``network_cidr`` (before anything is written),
      2. create the Network row,
      3. add ``user`` as a NetworkMembership with the ADMIN role,
      4. create an initial CA via ``create_network_ca``,
      5. register that CA as the network's SigningCA.

    Args:
        network_name: Name of the new network.
        network_cidr: IPv4 CIDR the host addresses are allocated from.
        user: User that becomes the network's first admin.
        update_interval: Stored on the Network as ``update_interval``.

    Returns:
        The newly created Network.

    Raises:
        ValueError: if ``network_cidr`` is not a valid IPv4 network. Host IP
            allocation parses the CIDR the same strict way, so an invalid value
            would otherwise only fail later, at enrollment time.
    """
    ipaddress.IPv4Network(network_cidr)

    logger.info("creating network")
    network = Network.objects.create(
        name=network_name, cidr=network_cidr, update_interval=update_interval
    )
    NetworkMembership.objects.create(
        network=network, user=user, role=NetworkMembership.Role.ADMIN
    )

    # Reuse the shared helper instead of duplicating the CA creation logic.
    ca = create_network_ca("auto created initial ca", network)
    SigningCA.objects.create(network=network, ca=ca)

    logger.info("created network", network=str(network))
    return network
# Override keys (matched as suffixes of the dotted key) whose textual values are
# coerced before being written into the config.
_BOOLEAN_OVERRIDE_SUFFIXES = (
    ".serve_dns",
    ".punch",
    ".respond",
    ".message_metrics",
    ".lighthouse_metrics",
)
# Always integers: a non-integer value is an error and the override is skipped.
_INTEGER_OVERRIDE_SUFFIXES = (".port",)
# Integers when the value is an integer literal; any other text (e.g. a
# duration string such as "1s") is passed through unchanged.
_INTEGER_OR_TEXT_OVERRIDE_SUFFIXES = (
    ".delay",
    ".respond_delay",
    ".interval",
)
_TRUE_STRINGS = frozenset({"true", "yes", "on", "1"})


def _coerce_override_value(key: str, value):
    """Convert an override value to the type the config expects for ``key``.

    Raises:
        ValueError: for an ``_INTEGER_OVERRIDE_SUFFIXES`` key whose value is not
            an integer literal.
    """
    if key.endswith(_BOOLEAN_OVERRIDE_SUFFIXES):
        return str(value).strip().lower() in _TRUE_STRINGS
    if key.endswith(_INTEGER_OVERRIDE_SUFFIXES):
        return int(value)
    if key.endswith(_INTEGER_OR_TEXT_OVERRIDE_SUFFIXES):
        try:
            return int(value)
        except (TypeError, ValueError):
            return value
    return value


def apply_group_config_overrides(config_data: dict, groups: list[Group]) -> dict:
    """Return a deep copy of ``config_data`` with the groups' config overrides applied.

    Each override's ``key`` is a dotted path (e.g. ``"a.b.c"``) into the config
    dict. Missing or null intermediate sections are created as empty dicts, and
    the leaf is set to the override value after ``_coerce_override_value``.
    Overrides are applied in group-name order, so when two groups set the same
    key the alphabetically later group wins.

    Args:
        config_data: Parsed config; it is not mutated.
        groups: Groups (or a queryset of them) whose ``config_overrides`` to apply.

    Returns:
        The new config dict.

    Overrides that cannot be applied are logged and skipped. Examples are a path
    that runs through a non-dict value, or a non-integer port. Config generation
    still proceeds in those cases.
    """
    config = copy.deepcopy(config_data)
    overrides = [
        (group.name, override)
        for group in groups
        for override in group.config_overrides.all()
    ]
    # Sort by group name to ensure consistent ordering. The sort is stable, so
    # overrides within one group keep their query order.
    overrides.sort(key=lambda item: item[0])

    for group_name, override in overrides:
        try:
            *parents, leaf = override.key.split(".")
            current = config
            for part in parents:
                # Create missing intermediate sections; YAML "section:" with no
                # body loads as None, which is treated as missing too.
                if current.get(part) is None:
                    current[part] = {}
                current = current[part]

            value = _coerce_override_value(override.key, override.value)
            current[leaf] = value
            logger.info(
                "applied group config override",
                group=group_name,
                path=override.key,
                value=value,
            )
        except Exception as e:
            # Deliberate best effort: one bad override must not block the rest.
            logger.error(
                "failed to apply group config override",
                group=group_name,
                path=override.key,
                error=str(e),
            )

    return config
def _stable_cert_hash(*parts) -> int:
    """Return a process-independent signed 64-bit digest of ``parts``.

    This replaces the builtin ``hash()``. For ``str``/``bytes``/``datetime``,
    builtin ``hash()`` is salted per interpreter process when hash randomization
    is enabled (the default). A stored value therefore never matched after a
    restart or across worker processes, and the host certificate was re-signed
    on every call. The digest is truncated to the signed 64-bit range that
    ``hash()`` produced on 64-bit builds.
    """
    payload = json.dumps(parts, sort_keys=True, default=str).encode()
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def _firewall_rule_data(rule) -> dict:
    """Translate one Rule into the dict written under ``firewall.inbound/outbound``.

    Only attributes that are set (not ``None`` / non-empty ``groups``) are included.
    """
    rule_data = {}
    if rule.local_cidr is not None:
        rule_data["local_cidr"] = rule.local_cidr
    if rule.cidr is not None:
        rule_data["cidr"] = rule.cidr
    rule_groups = rule.groups.all()
    if len(rule_groups) > 0:
        rule_data["groups"] = [rule_group.name for rule_group in rule_groups]
    if rule.group is not None:
        rule_data["group"] = rule.group.name
    if rule.proto is not None:
        rule_data["proto"] = rule.proto
    if rule.port is not None:
        rule_data["port"] = rule.port
    return rule_data


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Build the YAML config for the Host with ``host_id`` and return it as text.

    If the host has ``config_freeze`` set (and ``ignore_freeze`` is false), the
    most recently stored HostConfig is returned unchanged when one exists.

    Otherwise the ``config.yml`` asset is used as a template and the following
    sections are filled in:
      - ``pki``: all network CA certs concatenated, this host's cert and the key
        path ``host.key``. The cert is signed by the network's signing CA and is
        re-signed only when the inputs captured by the cache key change.
      - ``lighthouse`` and ``static_host_map``: built from the network's
        lighthouses.
      - ``relay``: built from the network's relays.
      - ``tun.dev``: the host's interface.
      - ``firewall``: the rules of the host's groups.
      - group config overrides.
    The result is dumped with ``yaml.safe_dump``. A HostConfig row is stored when
    no HostConfig with the same SHA-256 exists yet.

    Raises:
        Host.DoesNotExist: if no such host exists.
        AssertionError: if the host has no ``public_key``.
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    assert host.public_key
    config_template = (assets.asset_path / "config.yml").read_text()
    config_data = yaml.safe_load(config_template)
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    # Evaluate the host's groups once; they are reused several times below.
    host_groups = list(host.groups.all())
    group_names = frozenset(group.name for group in host_groups)

    # Use the network's own prefix length instead of a fixed /24, so hosts in
    # networks larger than a /24 get a netmask that covers the whole network.
    prefix_len = ipaddress.ip_network(network.cidr, strict=False).prefixlen
    assigned_ip = f"{host.assigned_ip}/{prefix_len}"

    # Everything that affects the signed cert, or that signals a group change.
    query_key_hash = _stable_cert_hash(
        ca.cert,
        host.public_key,
        host.name,
        assigned_ip,
        sorted(group_names),
        sorted({str(group.updated_at) for group in host_groups}),
        sorted(
            {
                str(rule.updated_at)
                for group in host_groups
                for rule in group.rules.all()
            }
        ),
        sorted(
            {
                str(override.updated_at)
                for group in host_groups
                for override in group.config_overrides.all()
            }
        ),
    )
    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    if host_cert is None or host_cert.hash != query_key_hash:
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=group_names,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=query_key_hash
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouses = Host.objects.filter(network=network, is_lighthouse=True).all()
    if host.is_lighthouse:
        config_data["lighthouse"]["am_lighthouse"] = True
        config_data["lighthouse"]["hosts"] = []
    else:
        config_data["lighthouse"]["am_lighthouse"] = False
        config_data["lighthouse"]["hosts"] = [
            lighthouse.assigned_ip for lighthouse in lighthouses
        ]
    config_data["static_host_map"] = {
        lighthouse.assigned_ip: [f"{lighthouse.public_ip_or_hostname}:4242"]
        for lighthouse in lighthouses
    }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay
    if not host.is_relay:
        relay_hosts = (
            Host.objects.filter(network=network, is_relay=True)
            .exclude(id=host.id)
            .all()
        )
        if relay_hosts:
            config_data["relay"]["relays"] = [
                relay.assigned_ip for relay in relay_hosts
            ]

    config_data["tun"]["dev"] = host.interface

    inbound_rules = []
    outbound_rules = []
    for group in host_groups:
        for rule in group.rules.all():
            if rule.direction == Rule.Direction.INBOUND:
                inbound_rules.append(_firewall_rule_data(rule))
            else:
                outbound_rules.append(_firewall_rule_data(rule))
    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    config_data = apply_group_config_overrides(config_data, host_groups)

    yaml_config = yaml.safe_dump(config_data, indent=4)
    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    # NOTE(review): this existence check is global (not per host) and ignores
    # recency. A config that reverts to an earlier version is therefore not
    # re-saved, and the "latest" HostConfig used for frozen hosts can lag
    # behind. Confirm whether HostConfig.sha256 is unique before changing this.
    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config
def create_template(
    name: str,
    network_name: str,
    is_lighthouse: bool = False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment Template in the network named ``network_name``.

    Every name in ``groups`` must be a Group of that network. All groups are
    resolved *before* the template is created, so an unknown group no longer
    leaves an orphaned Template row behind.

    Returns:
        The created Template, with ``groups`` attached.

    Raises:
        Network.DoesNotExist: if no network is named ``network_name``.
        LookupError: if one of ``groups`` does not exist in that network.
    """
    network = Network.objects.get(name=network_name)

    resolved_groups = []
    for group in groups:
        try:
            resolved_groups.append(Group.objects.get(name=group, network=network))
        except Group.DoesNotExist:
            raise LookupError(
                f"Group does not exist in network {group}/{network_name}"
            ) from None

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )
    for group in resolved_groups:
        template.groups.add(group)

    return template
def get_server_signing_key():
    """Load the server's enrollment-token signing key, creating it on first use.

    The key is kept in ``enrollment_signing.key``, relative to the current
    working directory. If that file does not exist, a new EC P-256 JWK is
    generated and its private export is written there with mode 0600.

    Returns:
        The jwcrypto ``JWK`` used to sign and verify enrollment tokens.
    """
    key_path = Path("enrollment_signing.key")
    if key_path.exists():
        return JWK.from_json(key_path.read_text())

    key = JWK.generate(kty="EC", crv="P-256")
    try:
        # O_EXCL + mode 0o600 creates the file already restricted. The previous
        # write-then-chmod left the private key readable (subject to umask) until
        # chmod ran, and could overwrite a key written concurrently by another
        # process, which would invalidate tokens signed with that key.
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Another process created the key first; use it so all tokens share one key.
        return JWK.from_json(key_path.read_text())
    # umask may only clear bits; pin the mode explicitly as before.
    os.chmod(key_path, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(key.export_private())
    return key
def generate_enrollment_token(template: Template, ttl: Optional[int] = None):
    """Return a signed (ES256) JWT that enrolls hosts using ``template``.

    The claims carry the template id, network id and the template's enrollment
    options. ``jti`` is set from ``template.enrollment_key``.

    Expiry (``exp``):
      - a truthy ``ttl`` (seconds) gives ``now + ttl``;
      - ``template.expires_at``, when set, is an upper bound;
      - when both apply, the earlier one is used, so a caller-supplied TTL can
        never make a token outlive its template;
      - when neither applies, the token has no ``exp``.

    Returns:
        The compact-serialized token string.
    """
    signing_key = get_server_signing_key()
    issued_at = now().timestamp()
    claims = {
        "jti": str(template.enrollment_key),
        "iss": "meshadmin",
        "sub": f"template:{template.id}",
        "iat": int(issued_at),
        "template_id": template.id,
        "network_id": template.network_id,
        "is_lighthouse": template.is_lighthouse,
        "is_relay": template.is_relay,
        "use_relay": template.use_relay,
        "reusable": template.reusable,
        "usage_limit": template.usage_limit,
        "ephemeral_peers": template.ephemeral_peers,
    }

    expiries = []
    if ttl:
        expiries.append(int(issued_at + ttl))
    if template.expires_at:
        expiries.append(int(template.expires_at.timestamp()))
    if expiries:
        claims["exp"] = min(expiries)

    token = JWT(header={"alg": "ES256"}, claims=claims)
    token.make_signed_token(signing_key)
    return token.serialize()
def verify_enrollment_token(token_string):
    """Validate an enrollment JWT against the server key and return its template id.

    Raises:
        ValueError: if the token is expired, has a bad signature, cannot be
            parsed or validated, or carries no ``template_id`` claim. The
            original error is chained as ``__cause__``.
    """
    signing_key = get_server_signing_key()
    try:
        token = JWT(jwt=token_string)
        token.validate(signing_key)
        payload = json.loads(token.token.objects["payload"])
        template_id = payload.get("template_id")
        if not template_id:
            raise ValueError("Invalid token: missing template_id")
        return template_id
    except Exception as e:
        error_message = str(e)
        if "Expired" in error_message:
            raise ValueError("Enrollment token has expired") from e
        if "Invalid signature" in error_message:
            raise ValueError("Invalid enrollment token signature") from e
        # Never log the raw token: `enrollment` accepts it as a bearer credential
        # to create hosts. A short digest is enough to correlate log lines.
        token_fingerprint = hashlib.sha256(str(token_string).encode()).hexdigest()[:16]
        logger.error(
            "invalid enrollment token",
            token_fingerprint=token_fingerprint,
            error=error_message,
        )
        raise ValueError(f"Invalid enrollment token: {error_message}") from e
def _check_template_usage(template: Template) -> None:
    """Raise ValueError if ``template`` may not be used for another enrollment."""
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit:
        if template.usage_count >= template.usage_limit:
            logger.error("enrollment key usage limit exceeded", template_id=template.id)
            raise ValueError("Enrollment key usage limit exceeded")


def _unique_hostname(network, preferred_hostname) -> str:
    """Return ``preferred_hostname``, or the first free ``<name>-<n>`` (n >= 1) in ``network``."""
    existing_hostnames = set(
        Host.objects.filter(
            network=network, name__startswith=preferred_hostname
        ).values_list("name", flat=True)
    )
    final_hostname = preferred_hostname
    suffix = 1
    while final_hostname in existing_hostnames:
        final_hostname = f"{preferred_hostname}-{suffix}"
        suffix += 1
    return final_hostname


def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a new Host using the enrollment token ``enrollment_key``.

    Flow:
      1. verify the token and load its Template;
      2. enforce the template's single-use / usage-limit rules;
      3. look for a host with the same ``public_auth_key`` thumbprint. If one
         exists and ``enroll_on_existence`` is true, abort; otherwise it is
         replaced;
      4. create the host with a free IP and a unique hostname, bump the
         template's usage count and attach the template's groups.

    All validation runs before the existing host is deleted, so a failing
    re-enrollment no longer loses that host.

    Returns:
        The created Host.

    Raises:
        ValueError: for an invalid/expired token, a missing template, an
            exhausted usage limit, an already-enrolled host (with
            ``enroll_on_existence``), a lighthouse template without
            ``public_ip``, or a network with no free IP address.
    """
    try:
        template_id = verify_enrollment_token(enrollment_key)
        template = Template.objects.get(id=template_id)
    except ValueError as e:
        logger.error("invalid enrollment token", error=str(e))
        raise ValueError(f"Invalid enrollment token: {str(e)}")
    except Template.DoesNotExist:
        logger.error("template not found", template_id=template_id)
        raise ValueError("Template not found")

    _check_template_usage(template)

    # Parse the auth key once; its thumbprint identifies an existing host.
    jwk_public_auth = JWK.from_json(public_auth_key)
    thumbprint = jwk_public_auth.thumbprint()
    existing_host: Optional[Host] = Host.objects.filter(
        public_auth_kid=thumbprint
    ).first()

    if existing_host and enroll_on_existence:
        logger.info(
            "host already exists, aborting enrollment",
            host_id=existing_host.id,
            enroll_on_existence=enroll_on_existence,
        )
        raise ValueError("Host already enrolled")

    # Validate before deleting anything.
    if template.is_lighthouse and not public_ip:
        raise ValueError("Cannot enroll a lighthouse without public_ip")

    if existing_host:
        existing_host.delete()

    network = template.network
    # Allocated after the deletion so a re-enrolling host's old IP can be reused.
    assigned_ip = next(network_available_hosts_iterator(network), None)
    if assigned_ip is None:
        # Previously a bare StopIteration escaped here.
        raise ValueError("No free IP address left in network")

    host = Host.objects.create(
        network=network,
        name=_unique_hostname(network, preferred_hostname),
        assigned_ip=assigned_ip,
        is_relay=template.is_relay,
        is_lighthouse=template.is_lighthouse,
        public_ip_or_hostname=public_ip,
        public_key=public_net_key,
        public_auth_key=public_auth_key,
        public_auth_kid=thumbprint,
        is_ephemeral=template.ephemeral_peers,
        interface=interface,
    )

    template.usage_count += 1
    template.save()

    for group in template.groups.all():
        host.groups.add(group)

    host.save()
    return host
def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create and return a Group named ``group_name`` in the network with pk ``network_pk``.

    Raises ``Network.DoesNotExist`` when no network has that primary key.
    """
    target_network = Network.objects.get(pk=network_pk)
    new_group = Group.objects.create(
        name=group_name,
        description=description,
        network=target_network,
    )
    return new_group