Coverage for src/meshadmin/server/networks/services.py: 94%
232 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 19:26 +0200
1import copy
2import hashlib
3import ipaddress
4import json
5import os
6from datetime import datetime
7from pathlib import Path
8from typing import Optional
10import structlog
11import yaml
12from django.contrib.auth import get_user_model
13from django.utils.timezone import now
14from jwcrypto.jwk import JWK
15from jwcrypto.jwt import JWT
17from meshadmin.common.utils import create_ca, print_ca, sign_keys
18from meshadmin.server import assets
19from meshadmin.server.networks.models import (
20 CA,
21 Group,
22 Host,
23 HostCert,
24 HostConfig,
25 Network,
26 NetworkMembership,
27 Rule,
28 SigningCA,
29 Template,
30)
# Structured logger shared by every service function in this module.
logger = structlog.get_logger(__name__)
# The project's configured Django user model (used in signatures below).
User = get_user_model()
def create_available_hosts_iterator(cidr, unavailable_ips):
    """Lazily yield the host addresses of *cidr* that are not already taken.

    Args:
        cidr: IPv4 network in CIDR notation (e.g. ``"10.0.0.0/24"``). It is
            parsed strictly, so host bits must not be set.
        unavailable_ips: iterable of addresses already in use. Each candidate
            is compared as ``str(address)``, so entries are expected to be
            address strings.

    Returns:
        A generator of ``ipaddress.IPv4Address`` objects in ascending order,
        as produced by ``IPv4Network.hosts()`` (normally excluding the network
        and broadcast addresses).
    """
    network = ipaddress.IPv4Network(cidr)
    # Snapshot the reserved addresses into a set once. Membership checks are
    # then O(1) instead of a linear scan per candidate, and one-shot iterables
    # (generators, querysets' iterators) are honoured for every candidate
    # rather than being exhausted by the first check.
    taken = frozenset(unavailable_ips)
    return (host for host in network.hosts() if str(host) not in taken)
def network_available_hosts_iterator(network):
    """Return a lazy iterator over the free IPv4 addresses of *network*.

    The ``assigned_ip`` of every Host in *network* counts as reserved. The
    remaining addresses of ``network.cidr`` are yielded by
    ``create_available_hosts_iterator``.
    """
    taken_addresses = [
        member.assigned_ip for member in Host.objects.filter(network=network)
    ]
    return create_available_hosts_iterator(network.cidr, taken_addresses)
def create_network_ca(ca_name, network):
    """Create a CA named *ca_name* and persist it for *network*.

    The certificate and key come from ``create_ca``. The output of
    ``print_ca`` for that certificate is stored alongside them as
    ``cert_print``.

    Returns:
        The newly created ``CA`` model instance.
    """
    ca_cert, ca_key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=ca_cert,
        key=ca_key,
        cert_print=print_ca(ca_cert),
    )
def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network, make *user* its admin, and set up its signing CA.

    An initial CA named ``"auto created initial ca"`` is created for the new
    network and registered as its ``SigningCA``.

    All rows are written in a single transaction. A failure part-way (for
    example while creating the CA) therefore does not leave behind a network
    without a signing CA, which ``generate_config_yaml`` dereferences via
    ``network.signingca``.

    Args:
        network_name: name of the new network.
        network_cidr: address range of the network in CIDR notation.
        user: user who becomes the network's admin member.
        update_interval: stored on the network as ``update_interval``.

    Returns:
        The created ``Network`` instance.
    """
    from django.db import transaction

    logger.info("creating network")
    with transaction.atomic():
        network = Network.objects.create(
            name=network_name, cidr=network_cidr, update_interval=update_interval
        )
        NetworkMembership.objects.create(
            network=network, user=user, role=NetworkMembership.Role.ADMIN
        )
        # Reuse the shared helper instead of duplicating the CA creation steps.
        ca = create_network_ca("auto created initial ca", network)
        SigningCA.objects.create(network=network, ca=ca)
    logger.info("created network", network=str(network))
    return network
def apply_group_config_overrides(config_data: dict, groups: list[Group]) -> dict:
    """Return a deep copy of *config_data* with the groups' overrides applied.

    Each override has a dotted ``key`` naming a nested path in the config
    (e.g. ``"punchy.respond"``) and a string ``value``. The value is converted
    by the key's suffix:

    * known boolean suffixes: ``True`` only for ``"true"`` (case-insensitive);
    * known numeric suffixes: parsed with ``int``;
    * anything else: stored as given.

    Overrides are applied in group-name order (stable within a group), so when
    several groups set the same path, the group whose name sorts last wins.

    An override that cannot be applied (e.g. a non-integer value for a numeric
    key, or a path running through a non-mapping value) is logged and skipped.
    *config_data* itself is never mutated.

    Args:
        config_data: parsed config mapping to start from.
        groups: groups whose ``config_overrides`` should be applied.

    Returns:
        The new config mapping with all applicable overrides set.
    """
    bool_suffixes = (
        ".serve_dns",
        ".punch",
        ".respond",
        ".message_metrics",
        ".lighthouse_metrics",
    )
    int_suffixes = (".port", ".delay", ".respond_delay", ".interval")

    config = copy.deepcopy(config_data)
    overrides = [
        (group.name, override)
        for group in groups
        for override in group.config_overrides.all()
    ]
    # Sort overrides by group name to ensure consistent ordering
    overrides.sort(key=lambda item: item[0])

    for group_name, override in overrides:
        try:
            # Convert the value before touching the config. A bad value then
            # cannot leave empty intermediate sections behind.
            value = override.value
            if override.key.endswith(bool_suffixes):
                value = value.lower() == "true"
            elif override.key.endswith(int_suffixes):
                value = int(value)

            *parents, leaf = override.key.split(".")
            current = config
            for part in parents:
                # Create missing sections. A section that is present but null
                # (a YAML key with no body) is treated as empty.
                if current.get(part) is None:
                    current[part] = {}
                current = current[part]

            current[leaf] = value
            logger.info(
                "applied group config override",
                group=group_name,
                path=override.key,
                value=value,
            )
        except Exception as e:
            # Best-effort: one broken override must not abort config generation.
            logger.error(
                "failed to apply group config override",
                group=group_name,
                path=override.key,
                error=str(e),
            )

    return config
def _cert_cache_key(*parts) -> int:
    """Return a deterministic signed 64-bit digest of *parts*.

    This replaces the builtin ``hash()``. Hashes of ``str``/``bytes`` values,
    and of tuples and datetimes built on them, are salted per interpreter
    process. A value stored by one worker, or before a restart, therefore
    never matched a freshly computed one, and the host certificate was
    re-signed needlessly.

    The result stays in the signed 64-bit range the builtin produced on 64-bit
    platforms. NOTE(review): this presumably matches how ``HostCert.hash`` is
    sized — confirm against the model.
    """
    payload = json.dumps(parts, default=str, separators=(",", ":"))
    digest = hashlib.sha256(payload.encode()).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def _firewall_rule_data(rule) -> dict:
    """Build the firewall rule mapping for *rule*, omitting fields that are unset."""
    rule_data = {}
    if rule.local_cidr is not None:
        rule_data["local_cidr"] = rule.local_cidr
    if rule.cidr is not None:
        rule_data["cidr"] = rule.cidr
    peer_groups = rule.groups.all()
    if len(peer_groups) > 0:
        rule_data["groups"] = [peer.name for peer in peer_groups]
    if rule.group is not None:
        rule_data["group"] = rule.group.name
    if rule.proto is not None:
        rule_data["proto"] = rule.proto
    if rule.port is not None:
        rule_data["port"] = rule.port
    return rule_data


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Render the YAML config for a host and record it as a ``HostConfig``.

    Frozen hosts: if ``config_freeze`` is set, *ignore_freeze* is false and a
    stored config exists, the latest stored config is returned unchanged.

    Otherwise the config is built from the ``config.yml`` asset:

    * every CA cert of the network is trusted;
    * the host certificate is (re-)signed by the network's signing CA when its
      inputs changed;
    * lighthouse, relay and tun settings are filled in;
    * the firewall rules of the host's groups are added;
    * group config overrides are applied last.

    A new ``HostConfig`` row is stored unless one with the same SHA-256 already
    exists.

    Args:
        host_id: primary key of the ``Host``.
        ignore_freeze: render a fresh config even if the host is frozen.

    Returns:
        The YAML config text.

    Raises:
        ValueError: if the host has no ``public_key`` to sign.
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    # Explicit check rather than ``assert``, which ``python -O`` strips.
    if not host.public_key:
        raise ValueError(f"host {host.id} has no public key to sign")

    config_template = (assets.asset_path / "config.yml").read_text()
    config_data = yaml.safe_load(config_template)
    # Trust every CA of the network, not just the current signing CA.
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    # Load the host's groups and their rules once; both are reused below.
    host_groups = list(host.groups.all())
    host_rules = [rule for group in host_groups for rule in group.rules.all()]
    group_names = frozenset(group.name for group in host_groups)

    # Use the network's real prefix length. A hard-coded /24 gave hosts of
    # networks larger than /24 a certificate network covering only part of
    # the mesh.
    prefix_len = ipaddress.ip_network(network.cidr, strict=False).prefixlen
    assigned_ip = f"{host.assigned_ip}/{prefix_len}"

    # A change to any of these inputs re-signs the host certificate.
    cert_key = _cert_cache_key(
        ca.cert,
        host.public_key,
        host.name,
        assigned_ip,
        sorted(group_names),
        sorted({str(group.updated_at) for group in host_groups}),
        sorted({str(rule.updated_at) for rule in host_rules}),
        sorted(
            {
                str(override.updated_at)
                for group in host_groups
                for override in group.config_overrides.all()
            }
        ),
    )
    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    if host_cert is None or host_cert.hash != cert_key:
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=group_names,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=cert_key
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouses = Host.objects.filter(network=network, is_lighthouse=True).all()
    if host.is_lighthouse:
        config_data["lighthouse"]["am_lighthouse"] = True
        config_data["lighthouse"]["hosts"] = []
    else:
        config_data["lighthouse"]["am_lighthouse"] = False
        config_data["lighthouse"]["hosts"] = [
            lighthouse.assigned_ip for lighthouse in lighthouses
        ]
    config_data["static_host_map"] = {
        lighthouse.assigned_ip: [f"{lighthouse.public_ip_or_hostname}:4242"]
        for lighthouse in lighthouses
    }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay

    config_data["tun"]["dev"] = host.interface

    inbound_rules = []
    outbound_rules = []
    for rule in host_rules:
        if rule.direction == Rule.Direction.INBOUND:
            inbound_rules.append(_firewall_rule_data(rule))
        else:
            outbound_rules.append(_firewall_rule_data(rule))
    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    config_data = apply_group_config_overrides(config_data, host_groups)

    yaml_config = yaml.safe_dump(config_data, indent=4)
    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    # NOTE(review): this lookup is not scoped to *host*. A config that reverts
    # to an earlier version is therefore not re-recorded, and the frozen-config
    # lookup above may then return a newer row. Confirm whether ``sha256`` is
    # unique before changing this.
    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config
def create_template(
    name: str,
    network_name: str,
    is_lighthouse=False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment template in the network named *network_name*.

    Args:
        name: template name.
        network_name: name of the network the template belongs to.
        is_lighthouse, is_relay, use_relay, reusable, usage_limit,
        ephemeral_peers, expires_at: stored on the template as given.
        groups: names of groups (in the same network) attached to the template.

    Returns:
        The created ``Template``.

    Raises:
        Network.DoesNotExist: if no network has that name.
        LookupError: if one of *groups* does not exist in that network. In this
            case no template is created.
    """
    network = Network.objects.get(name=network_name)

    # Resolve every group before creating the template. An unknown group name
    # then cannot leave a half-configured template behind.
    resolved_groups = []
    for group in groups:
        try:
            resolved_groups.append(Group.objects.get(name=group, network=network))
        except Group.DoesNotExist as err:
            raise LookupError(
                f"Group does not exist in network {group}/{network_name}"
            ) from err

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )
    if resolved_groups:
        template.groups.add(*resolved_groups)

    return template
def get_server_signing_key():
    """Load the server's enrollment signing key, creating it on first use.

    The key is stored as private JWK JSON in ``enrollment_signing.key``,
    relative to the current working directory. If the file is missing, a new
    EC P-256 key is generated and written there with mode ``0o600``.

    Returns:
        The ``JWK`` used to sign and verify enrollment tokens.
    """
    key_path = Path("enrollment_signing.key")
    if key_path.exists():
        return JWK.from_json(key_path.read_text())

    key = JWK.generate(kty="EC", crv="P-256")
    try:
        # Create the file with owner-only permissions from the start, so the
        # private key is never readable by others, not even briefly. O_EXCL
        # also prevents overwriting a key that a concurrent process just
        # created.
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Another process won the race; use its key so tokens stay verifiable.
        # NOTE(review): a reader racing the writer could still see a partially
        # written file.
        return JWK.from_json(key_path.read_text())
    with os.fdopen(fd, "w") as f:
        f.write(key.export_private())
    # The os.open mode is filtered by the umask; pin the final mode explicitly.
    os.chmod(key_path, 0o600)
    return key
def generate_enrollment_token(template: Template, ttl: Optional[int] = None):
    """Return a signed ES256 JWT that lets a host enroll using *template*.

    The claims mirror the template's settings. Expiry (``exp``) is set as
    follows:

    * *ttl* truthy: ``ttl`` seconds after issue;
    * otherwise, if ``template.expires_at`` is set: at that time;
    * otherwise: no ``exp`` claim.

    Args:
        template: the enrollment template the token refers to.
        ttl: optional lifetime in seconds.

    Returns:
        The serialized token string.
    """
    signing_key = get_server_signing_key()
    # Read the clock once so "iat" and a ttl-based "exp" share the same instant.
    issued_at = now().timestamp()
    claims = {
        "jti": str(template.enrollment_key),
        "iss": "meshadmin",
        "sub": f"template:{template.id}",
        "iat": int(issued_at),
        "template_id": template.id,
        "network_id": template.network_id,
        "is_lighthouse": template.is_lighthouse,
        "is_relay": template.is_relay,
        "use_relay": template.use_relay,
        "reusable": template.reusable,
        "usage_limit": template.usage_limit,
        "ephemeral_peers": template.ephemeral_peers,
    }
    if ttl:
        claims["exp"] = int(issued_at + ttl)
    elif template.expires_at:
        claims["exp"] = int(template.expires_at.timestamp())
    token = JWT(header={"alg": "ES256"}, claims=claims)
    token.make_signed_token(signing_key)
    return token.serialize()
def verify_enrollment_token(token_string):
    """Validate an enrollment JWT and return the template id it grants.

    Args:
        token_string: serialized JWT as produced by ``generate_enrollment_token``.

    Returns:
        The ``template_id`` claim.

    Raises:
        ValueError: one of the following messages:

            * "Enrollment token has expired" for an expired token;
            * "Invalid enrollment token signature" for a bad signature;
            * "Invalid enrollment token: ..." for any other failure, including
              a missing ``template_id`` claim.
    """
    signing_key = get_server_signing_key()
    try:
        token = JWT(jwt=token_string)
        token.validate(signing_key)
        payload = json.loads(token.token.objects["payload"])
        template_id = payload.get("template_id")
        if not template_id:
            raise ValueError("Invalid token: missing template_id")
        return template_id
    except Exception as e:
        error_message = str(e)
        if "Expired" in error_message:
            raise ValueError("Enrollment token has expired") from e
        if "Invalid signature" in error_message:
            raise ValueError("Invalid enrollment token signature") from e
        # The token is a bearer credential. Log only a short fingerprint that
        # can be correlated, never the usable token itself.
        token_fingerprint = hashlib.sha256(str(token_string).encode()).hexdigest()[:16]
        logger.error(
            "invalid enrollment token",
            token_sha256=token_fingerprint,
            error=error_message,
        )
        raise ValueError(f"Invalid enrollment token: {error_message}") from e
def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a new host into the network of the template the token refers to.

    Steps:

    1. Verify the enrollment token and load its template.
    2. Enforce the template's single-use and usage-limit rules.
    3. Handle a host already registered with the same auth key: abort if
       *enroll_on_existence* is true, otherwise replace that host.
    4. Pick a free IP and a unique hostname.
    5. Create the host with the template's settings and groups, and count the
       usage against the template.

    Args:
        enrollment_key: signed enrollment token (see ``verify_enrollment_token``).
        public_auth_key: host's public auth key as JWK JSON. Its thumbprint
            identifies the host across enrollments.
        enroll_on_existence: if true, refuse to enroll when a host with the
            same auth key exists. If false, that host is deleted and re-created.
        public_ip: public IP or hostname; required for lighthouse templates.
        preferred_hostname: desired host name. ``-1``, ``-2``, ... is appended
            on collision within the network.
        public_net_key: host public key, stored as ``public_key``.
        interface: tun device name stored on the host.

    Returns:
        The created ``Host``.

    Raises:
        ValueError: on an invalid token, a missing template, exhausted usage,
            an already-enrolled host, a lighthouse without *public_ip*, or when
            the network has no free IP address left.
    """
    try:
        template_id = verify_enrollment_token(enrollment_key)
        template = Template.objects.get(id=template_id)
    except ValueError as e:
        logger.error("invalid enrollment token", error=str(e))
        raise ValueError(f"Invalid enrollment token: {str(e)}") from e
    except Template.DoesNotExist as e:
        logger.error("template not found", template_id=template_id)
        raise ValueError("Template not found") from e

    # Check usage limit
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit:
        if template.usage_count >= template.usage_limit:
            logger.error("enrollment key usage limit exceeded", template_id=template.id)
            raise ValueError("Enrollment key usage limit exceeded")

    # Parse the auth key once; its thumbprint identifies an existing host.
    jwk_public_auth = JWK.from_json(public_auth_key)
    thumbprint = jwk_public_auth.thumbprint()
    host: Optional[Host] = Host.objects.filter(public_auth_kid=thumbprint).first()

    if host and enroll_on_existence:
        logger.info(
            "host already exists, aborting enrollment",
            host_id=host.id,
            enroll_on_existence=enroll_on_existence,
        )
        raise ValueError("Host already enrolled")

    # Validate before deleting an existing registration. A rejected
    # re-enrollment must not destroy the host it was meant to replace.
    if template.is_lighthouse and not public_ip:
        raise ValueError("Cannot enroll a lighthouse without public_ip")

    if host:
        host.delete()

    network = template.network

    existing_hostnames = set(
        Host.objects.filter(
            network=network, name__startswith=preferred_hostname
        ).values_list("name", flat=True)
    )
    final_hostname = preferred_hostname
    suffix = 1
    while final_hostname in existing_hostnames:
        final_hostname = f"{preferred_hostname}-{suffix}"
        suffix += 1

    # Take the first free address. Report exhaustion as a ValueError rather
    # than letting a bare StopIteration escape.
    assigned_ip = next(network_available_hosts_iterator(network), None)
    if assigned_ip is None:
        raise ValueError(f"No free IP address left in network {network.name}")

    host = Host.objects.create(
        network=network,
        name=final_hostname,
        assigned_ip=assigned_ip,
        is_relay=template.is_relay,
        is_lighthouse=template.is_lighthouse,
        public_ip_or_hostname=public_ip,
        public_key=public_net_key,
        public_auth_key=public_auth_key,
        public_auth_kid=thumbprint,
        is_ephemeral=template.ephemeral_peers,
        interface=interface,
    )

    # Increment usage count
    template.usage_count += 1
    template.save()

    host.groups.add(*template.groups.all())

    host.save()
    return host
def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create a group named *group_name* in the network with key *network_pk*.

    Raises the model's ``DoesNotExist`` if no network has that primary key.

    Returns:
        The created ``Group``.
    """
    group = Group.objects.create(
        network=Network.objects.get(pk=network_pk),
        name=group_name,
        description=description,
    )
    return group