Coverage for src/meshadmin/server/networks/services.py: 94%
230 statements
Report generated by coverage.py v7.8.0 at 2025-04-25 08:49 +0200
1import copy
2import hashlib
3import ipaddress
4import json
5import os
6from datetime import datetime
7from pathlib import Path
8from typing import Optional
10import structlog
11import yaml
12from django.contrib.auth import get_user_model
13from django.utils.timezone import now
14from jwcrypto.jwk import JWK
15from jwcrypto.jwt import JWT
17from meshadmin.common.utils import create_ca, print_ca, sign_keys
18from meshadmin.server import assets
19from meshadmin.server.networks.models import (
20 CA,
21 Group,
22 Host,
23 HostCert,
24 HostConfig,
25 Network,
26 NetworkMembership,
27 Rule,
28 SigningCA,
29 Template,
30)
# Module-level structlog logger; calls below pass context as keyword arguments.
logger = structlog.get_logger(__name__)
# Active Django user model, used in the ``create_network`` type hint.
User = get_user_model()
def create_available_hosts_iterator(cidr, unavailable_ips):
    """Return a lazy iterator over the usable hosts of ``cidr`` not yet taken.

    Args:
        cidr: IPv4 network in CIDR notation, e.g. ``"10.0.0.0/24"``. Parsed
            strictly, so host bits must be zero.
        unavailable_ips: iterable of addresses already in use. Entries may be
            strings or ``ipaddress`` objects; they are compared by their
            string form.

    Returns:
        A generator of ``IPv4Address`` objects in ascending order.
    """
    network = ipaddress.IPv4Network(cidr)
    # Build the exclusion set once: testing membership against the caller's
    # list was O(len(list)) per candidate address. Normalising to ``str``
    # also matches entries that were passed as IPv4Address objects.
    taken = {str(ip) for ip in unavailable_ips}
    return (host for host in network.hosts() if str(host) not in taken)
def network_available_hosts_iterator(network):
    """Return a lazy iterator over the free host addresses of ``network``.

    Every ``assigned_ip`` already used by a Host of ``network`` is excluded.
    The reserved addresses are read once, when this function is called.
    """
    # Only the address column is needed; avoid materialising full Host rows.
    reserved_ips = list(
        Host.objects.filter(network=network).values_list("assigned_ip", flat=True)
    )
    return create_available_hosts_iterator(network.cidr, reserved_ips)
def create_network_ca(ca_name, network):
    """Generate a CA named ``ca_name``, store it for ``network`` and return it.

    The certificate and key come from ``create_ca``. ``cert_print`` holds the
    output of ``print_ca`` for that certificate.
    """
    certificate, private_key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=certificate,
        key=private_key,
        cert_print=print_ca(certificate),
    )
def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network with ``user`` as admin and an initial signing CA.

    Args:
        network_name: name of the new network.
        network_cidr: IPv4 CIDR the network hands addresses out from.
        user: user that becomes an ADMIN member of the network.
        update_interval: stored on the network as-is.

    Returns:
        The created ``Network``.

    Raises:
        ValueError: if ``network_cidr`` is not a valid IPv4 network. This is
            checked before anything is written to the database.
    """
    # Validate up front. Address allocation parses the CIDR the same strict
    # way, so a network with an invalid CIDR could never enroll any host.
    ipaddress.IPv4Network(network_cidr)

    logger.info("creating network")
    network = Network.objects.create(
        name=network_name, cidr=network_cidr, update_interval=update_interval
    )
    NetworkMembership.objects.create(
        network=network, user=user, role=NetworkMembership.Role.ADMIN
    )

    # Reuse the shared helper instead of duplicating the CA creation inline.
    ca = create_network_ca("auto created initial ca", network)
    SigningCA.objects.create(network=network, ca=ca)
    logger.info("created network", network=str(network))
    return network
def _coerce_override_value(key, raw_value, bool_suffixes, int_suffixes):
    """Convert an override's raw value to the type implied by its key suffix.

    - Keys ending in one of ``bool_suffixes`` become ``True`` only for the
      text ``"true"``. The comparison is case-insensitive and trims whitespace.
    - Keys ending in one of ``int_suffixes`` go through ``int()``.
    - All other values are returned unchanged.

    Conversion errors propagate to the caller.
    """
    if key.endswith(bool_suffixes):
        return str(raw_value).strip().lower() == "true"
    if key.endswith(int_suffixes):
        return int(raw_value)
    return raw_value


def apply_group_config_overrides(
    config_data: dict,
    groups: list[Group],
    bool_suffixes: tuple[str, ...] = (".serve_dns",),
    int_suffixes: tuple[str, ...] = (".port",),
) -> dict:
    """Return a deep copy of ``config_data`` with the groups' overrides applied.

    Each override's dotted ``key`` (e.g. ``"listen.port"``) addresses a
    nested mapping. Missing or null intermediate sections are created as
    empty dicts.

    Overrides are applied in group-name order, so when two groups set the
    same key, the group whose name sorts last wins. An override that cannot
    be applied (bad value, or a path running through a non-mapping) is
    logged and skipped. This function is deliberately best-effort.

    Args:
        config_data: parsed config; not mutated.
        groups: groups whose ``config_overrides`` are applied.
        bool_suffixes: key suffixes whose values are parsed as booleans.
        int_suffixes: key suffixes whose values are parsed as integers.
    """
    bool_suffixes = tuple(bool_suffixes)
    int_suffixes = tuple(int_suffixes)
    config = copy.deepcopy(config_data)

    overrides = [
        (group.name, override)
        for group in groups
        for override in group.config_overrides.all()
    ]
    # Sort by group name to ensure consistent ordering (stable sort keeps
    # the per-group query order).
    overrides.sort(key=lambda item: item[0])

    for group_name, override in overrides:
        try:
            # Convert before touching the config, so a bad value cannot leave
            # empty intermediate sections behind.
            value = _coerce_override_value(
                override.key, override.value, bool_suffixes, int_suffixes
            )
            *parents, leaf = override.key.split(".")
            node = config
            for part in parents:
                child = node.get(part)
                if child is None:
                    # Missing key, or a YAML section left empty (parsed as None).
                    child = node[part] = {}
                elif not isinstance(child, dict):
                    raise TypeError(
                        f"cannot set {override.key!r}: {part!r} is not a mapping"
                    )
                node = child
            node[leaf] = value
            logger.info(
                "applied group config override",
                group=group_name,
                path=override.key,
                value=value,
            )
        except Exception as e:
            logger.error(
                "failed to apply group config override",
                group=group_name,
                path=override.key,
                error=str(e),
            )

    return config
def _cert_cache_key(
    ca_cert: str, public_key: str, host_name: str, assigned_ip: str, host_groups
) -> int:
    """Return a process-independent fingerprint of a host certificate's inputs.

    The previous key was ``hash(tuple)``. ``hash()`` of ``str`` values is
    salted per interpreter process (PYTHONHASHSEED), so a value stored in
    ``HostCert.hash`` never matched after a restart or in another worker,
    and certificates were re-signed needlessly.

    This key is SHA-256 over a canonical JSON encoding, truncated to a
    signed 64-bit integer. That keeps it in the same numeric range as the old
    ``hash()`` values. This assumes ``HostCert.hash`` is an integer column;
    TODO confirm against the model.

    Group names are included explicitly because they are signed into the
    certificate.
    """
    group_names = sorted({group.name for group in host_groups})
    group_stamps = sorted({str(group.updated_at) for group in host_groups})
    rule_stamps = sorted(
        {str(rule.updated_at) for group in host_groups for rule in group.rules.all()}
    )
    override_stamps = sorted(
        {
            str(override.updated_at)
            for group in host_groups
            for override in group.config_overrides.all()
        }
    )
    canonical = json.dumps(
        [
            ca_cert,
            public_key,
            host_name,
            assigned_ip,
            group_names,
            group_stamps,
            rule_stamps,
            override_stamps,
        ]
    )
    digest = hashlib.sha256(canonical.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def _build_firewall_rules(host_groups):
    """Translate the Rule rows of ``host_groups`` into firewall rule dicts.

    Returns an ``(inbound, outbound)`` pair of lists. Each optional Rule field
    is emitted only when it is not None; ``groups`` is emitted only when the
    rule references at least one group.
    """
    inbound_rules = []
    outbound_rules = []
    for group in host_groups:
        for rule in group.rules.all():
            rule_data = {}
            if rule.local_cidr is not None:
                rule_data["local_cidr"] = rule.local_cidr
            if rule.cidr is not None:
                rule_data["cidr"] = rule.cidr
            rule_groups = rule.groups.all()
            if len(rule_groups) > 0:
                rule_data["groups"] = [rule_group.name for rule_group in rule_groups]
            if rule.group is not None:
                rule_data["group"] = rule.group.name
            if rule.proto is not None:
                rule_data["proto"] = rule.proto
            if rule.port is not None:
                rule_data["port"] = rule.port

            if rule.direction == Rule.Direction.INBOUND:
                inbound_rules.append(rule_data)
            else:
                outbound_rules.append(rule_data)
    return inbound_rules, outbound_rules


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Render the YAML config for the host with primary key ``host_id``.

    Frozen hosts: if ``config_freeze`` is set and ``ignore_freeze`` is false,
    the most recent stored ``HostConfig`` is returned unchanged, provided one
    exists.

    Otherwise the config is rebuilt from the ``config.yml`` asset:

    - The host certificate is re-signed when its cache key changed.
    - The lighthouse, relay, tun and firewall sections are filled in.
    - Group config overrides are applied.
    - A new ``HostConfig`` row is stored when no row with the same SHA-256
      exists yet.

    Returns:
        The YAML document as a string.

    Raises:
        Host.DoesNotExist: if no host has ``host_id``.
        AssertionError: if the host has no ``public_key`` (only when
            assertions are enabled).
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    assert host.public_key
    config_template = (assets.asset_path / "config.yml").read_text(encoding="utf-8")
    config_data = yaml.safe_load(config_template)
    # Trust every CA of the network, not only the current signing CA.
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    # Evaluate the group relation once; it was previously queried five times.
    host_groups = list(host.groups.all())
    group_names = frozenset(group.name for group in host_groups)
    # Use the network's own prefix length; it was hard-coded to /24, which
    # produced wrong certificate masks for any other network size.
    prefix_len = ipaddress.ip_network(network.cidr, strict=False).prefixlen
    assigned_ip = f"{host.assigned_ip}/{prefix_len}"

    cache_key = _cert_cache_key(
        ca.cert, host.public_key, host.name, assigned_ip, host_groups
    )
    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    if host_cert is None or host_cert.hash != cache_key:
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=group_names,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=cache_key
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouses = list(Host.objects.filter(network=network, is_lighthouse=True))
    config_data["lighthouse"]["am_lighthouse"] = bool(host.is_lighthouse)
    config_data["lighthouse"]["hosts"] = (
        [] if host.is_lighthouse else [lh.assigned_ip for lh in lighthouses]
    )
    config_data["static_host_map"] = {
        lh.assigned_ip: [f"{lh.public_ip_or_hostname}:4242"] for lh in lighthouses
    }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay

    config_data["tun"]["dev"] = host.interface

    inbound_rules, outbound_rules = _build_firewall_rules(host_groups)
    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    config_data = apply_group_config_overrides(config_data, host_groups)

    yaml_config = yaml.safe_dump(config_data, indent=4)
    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config
def create_template(
    name: str,
    network_name: str,
    is_lighthouse: bool = False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment template in ``network_name`` and attach ``groups``.

    Args:
        name: template name.
        network_name: name of an existing network.
        groups: names of groups in that network to assign to enrolled hosts.
        The remaining arguments are stored on the template unchanged.

    Returns:
        The created ``Template``.

    Raises:
        Network.DoesNotExist: if no network is named ``network_name``.
        LookupError: if a group name does not exist in the network. All
            groups are resolved first, so no template is created in that case.
    """
    network = Network.objects.get(name=network_name)

    # Resolve every group before creating the template. Previously an unknown
    # group raised only after the template had been saved, leaving an orphan.
    resolved_groups = []
    for group_name in groups:
        try:
            resolved_groups.append(Group.objects.get(name=group_name, network=network))
        except Group.DoesNotExist as exc:
            raise LookupError(
                f"Group does not exist in network {group_name}/{network_name}"
            ) from exc

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )
    if resolved_groups:
        template.groups.add(*resolved_groups)
    return template
def get_server_signing_key(key_path=Path("enrollment_signing.key")):
    """Load the enrollment-token signing key, generating it on first use.

    Args:
        key_path: location of the JSON-encoded private JWK. The default is
            relative, so it resolves against the process's current working
            directory.

    Returns:
        The ``JWK``. A newly generated key is an EC P-256 key; it is written to
        ``key_path`` with owner-only (0o600) permissions.
    """
    key_path = Path(key_path)
    if key_path.exists():
        return JWK.from_json(key_path.read_text(encoding="utf-8"))

    key = JWK.generate(kty="EC", crv="P-256")
    # Create the file with 0o600 from the start. Writing with open() and
    # chmod-ing afterwards left the private key readable under the default
    # umask for a short window.
    fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(key.export_private())
    # umask can only remove bits; chmod pins the mode to exactly 0o600.
    os.chmod(key_path, 0o600)
    return key
def generate_enrollment_token(template: Template):
    """Return a compact ES256-signed JWT that lets a host enroll via ``template``.

    The claims mirror the template's enrollment settings. ``jti`` is the
    template's ``enrollment_key``. ``exp`` is added only when the template
    has an ``expires_at``.
    """
    key = get_server_signing_key()
    issued_at = int(now().timestamp())
    claims = dict(
        jti=str(template.enrollment_key),
        iss="meshadmin",
        sub=f"template:{template.id}",
        iat=issued_at,
        template_id=template.id,
        network_id=template.network_id,
        is_lighthouse=template.is_lighthouse,
        is_relay=template.is_relay,
        use_relay=template.use_relay,
        reusable=template.reusable,
        usage_limit=template.usage_limit,
        ephemeral_peers=template.ephemeral_peers,
    )
    expires_at = template.expires_at
    if expires_at:
        claims["exp"] = int(expires_at.timestamp())

    signed = JWT(header={"alg": "ES256"}, claims=claims)
    signed.make_signed_token(key)
    return signed.serialize()
def verify_enrollment_token(token_string):
    """Validate a signed enrollment token and return its ``template_id`` claim.

    Raises:
        ValueError: one of
            - "Enrollment token has expired"
            - "Invalid enrollment token signature"
            - "Invalid enrollment token: <detail>" for any other failure,
              including a missing ``template_id``.
            The original exception is chained as ``__cause__``.
    """
    # Local import keeps the module's import block untouched; jwcrypto is
    # already a dependency of this file.
    from jwcrypto.jws import InvalidJWSSignature
    from jwcrypto.jwt import JWTExpired

    signing_key = get_server_signing_key()
    try:
        token = JWT(jwt=token_string)
        token.validate(signing_key)
        payload = json.loads(token.token.objects["payload"])
        template_id = payload.get("template_id")
        if not template_id:
            raise ValueError("Invalid token: missing template_id")
        return template_id
    except Exception as e:
        error_message = str(e)
        # Classify by exception type. jwcrypto's signature errors read
        # "Verification failed ...", so the substring test alone never
        # matched. The substring checks are kept as a fallback.
        if isinstance(e, JWTExpired) or "Expired" in error_message:
            raise ValueError("Enrollment token has expired") from e
        if isinstance(e, InvalidJWSSignature) or "Invalid signature" in error_message:
            raise ValueError("Invalid enrollment token signature") from e
        # Never log the raw token: it is a bearer credential. A short digest
        # is enough to correlate log lines.
        logger.error(
            "invalid enrollment token",
            token_sha256=hashlib.sha256(str(token_string).encode()).hexdigest()[:16],
            error=error_message,
        )
        raise ValueError(f"Invalid enrollment token: {error_message}") from e
def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a host into the network of the template behind ``enrollment_key``.

    If a host with the same auth-key thumbprint already exists:
    - with ``enroll_on_existence`` true, enrollment is refused;
    - otherwise the old host is deleted and re-created.

    The hostname is ``preferred_hostname``, suffixed ``-1``, ``-2`` ... until
    unique in the network. The address is the next free one in the network.

    Returns:
        The created ``Host``.

    Raises:
        ValueError: for an invalid or expired token, an unknown template, an
            exhausted usage limit, a lighthouse template without
            ``public_ip``, an already enrolled host (see above), or when the
            network has no free address left.
    """
    try:
        template_id = verify_enrollment_token(enrollment_key)
        template = Template.objects.get(id=template_id)
    except ValueError as e:
        logger.error("invalid enrollment token", error=str(e))
        raise ValueError(f"Invalid enrollment token: {str(e)}") from e
    except Template.DoesNotExist:
        logger.error("template not found", template_id=template_id)
        raise ValueError("Template not found") from None

    # Check usage limit
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit:
        if template.usage_count >= template.usage_limit:
            logger.error("enrollment key usage limit exceeded", template_id=template.id)
            raise ValueError("Enrollment key usage limit exceeded")

    # Validate before touching any existing host. Previously this ran after
    # the delete below, so a rejected request destroyed the prior enrollment.
    if template.is_lighthouse and not public_ip:
        raise ValueError("Cannot enroll a lighthouse without public_ip")

    # Parse the auth key once; its thumbprint identifies the host.
    jwk_public_auth = JWK.from_json(public_auth_key)
    thumbprint = jwk_public_auth.thumbprint()
    existing_host: Optional[Host] = Host.objects.filter(
        public_auth_kid=thumbprint
    ).first()
    if existing_host:
        if enroll_on_existence:
            logger.info(
                "host already exists, aborting enrollment",
                host_id=existing_host.id,
                enroll_on_existence=enroll_on_existence,
            )
            raise ValueError("Host already enrolled")
        existing_host.delete()

    network = template.network
    # Computed after the delete so a re-enrolling host can reuse its address.
    ipv4_iterator = network_available_hosts_iterator(network)

    existing_hostnames = set(
        Host.objects.filter(
            network=network, name__startswith=preferred_hostname
        ).values_list("name", flat=True)
    )
    final_hostname = preferred_hostname
    suffix = 1
    while final_hostname in existing_hostnames:
        final_hostname = f"{preferred_hostname}-{suffix}"
        suffix += 1

    assigned_ip = next(ipv4_iterator, None)
    if assigned_ip is None:
        # Previously a bare StopIteration escaped to the caller.
        logger.error("no free ip address left in network", network=str(network))
        raise ValueError("No available IP addresses left in network")

    host = Host.objects.create(
        network=network,
        name=final_hostname,
        assigned_ip=assigned_ip,
        is_relay=template.is_relay,
        is_lighthouse=template.is_lighthouse,
        # The template's use_relay setting was previously never applied.
        use_relay=template.use_relay,
        public_ip_or_hostname=public_ip,
        public_key=public_net_key,
        public_auth_key=public_auth_key,
        public_auth_kid=thumbprint,
        is_ephemeral=template.ephemeral_peers,
        interface=interface,
    )

    # Increment usage count
    template.usage_count += 1
    template.save()

    host.groups.add(*template.groups.all())
    host.save()
    return host
def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create and return a group called ``group_name`` in network ``network_pk``.

    Raises ``Network.DoesNotExist`` when no network has that primary key.
    """
    return Group.objects.create(
        network=Network.objects.get(pk=network_pk),
        name=group_name,
        description=description,
    )