Coverage for agentos/tools/jwt.py: 0%
109 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 08:21 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 08:21 +0800
"""
JWT — JSON Web Token encode/decode/verify (HS256/RS256/ES256).

Supports:
  - HS256/HS384/HS512 (HMAC-SHA2) signing and verification
  - RS* (RSA) and ES* (ECDSA) algorithm names are accepted by the constructor,
    but signing/verifying with them raises NotImplementedError
  - Encode with claims (iss, sub, aud, exp, iat, nbf, jti, custom)
  - Decode with signature verification
  - Decode without verification (for inspection)
  - Token expiry checking
  - Claim validation
"""
13from __future__ import annotations
15import base64
16import hashlib
17import hmac
18import json
19import time
20from typing import Any, Dict, List, Optional, Tuple, Union
# ============================================================================
# Exceptions (JWTError hierarchy)
# ============================================================================
class JWTError(Exception):
    """Root of every JWT-specific exception defined in this module."""
class ExpiredTokenError(JWTError):
    """Signals that a token's 'exp' claim lies in the past."""
class InvalidTokenError(JWTError):
    """Signals a token that is structurally broken or fails a validation check."""
# ============================================================================
# Helpers
# ============================================================================
43def _b64url_encode(data: bytes) -> str:
44 return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
47def _b64url_decode(s: str) -> bytes:
48 # Restore padding
49 padding = 4 - len(s) % 4
50 if padding != 4:
51 s += "=" * padding
52 return base64.urlsafe_b64decode(s)
def _json_b64_decode(s: str) -> dict:
    """Base64url-decode *s* and parse the resulting bytes as JSON."""
    raw = _b64url_decode(s)
    return json.loads(raw)
# ============================================================================
# JWT
# ============================================================================
# Every algorithm name the JWT constructor accepts: three signature families
# (HMAC, RSA, ECDSA), each paired with the three SHA-2 digest sizes.
ALGORITHMS = frozenset(
    f"{family}{bits}"
    for family in ("HS", "RS", "ES")
    for bits in (256, 384, 512)
)
class JWT:
    """JSON Web Token encoder/decoder.

    Signing and verification are implemented for the HMAC family only
    (HS256/HS384/HS512). RS*/ES* algorithm names are accepted by the
    constructor, but ``encode`` and a verifying ``decode`` raise
    ``NotImplementedError`` for them.

    Usage:
        jwt = JWT(secret="my-secret")  # for HS256

        # Encode
        token = jwt.encode({"sub": "user123", "role": "admin"}, ttl=3600)

        # Decode & verify
        payload = jwt.decode(token)

        # Decode without verification (inspect only)
        payload = jwt.decode(token, verify=False)
    """

    def __init__(
        self,
        secret: Optional[str] = None,
        private_key: Optional[str] = None,
        public_key: Optional[str] = None,
        algorithm: str = "HS256",
    ):
        """Configure the codec for one algorithm and its key material.

        Args:
            secret: Shared secret; required for HS* algorithms.
            private_key: Private key text for RS*/ES* algorithms.
            public_key: Public key text for RS*/ES* algorithms.
            algorithm: One of the names in ``ALGORITHMS``.

        Raises:
            ValueError: If the algorithm is unknown, or the key material the
                algorithm family needs was not supplied.
        """
        if algorithm not in ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}. Use one of {sorted(ALGORITHMS)}")

        self._algorithm = algorithm
        self._hash_func = {
            "HS256": hashlib.sha256,
            "HS384": hashlib.sha384,
            "HS512": hashlib.sha512,
        }

        # Define every key attribute up front so they exist whichever
        # algorithm family was selected.
        self._secret: Optional[bytes] = None
        self._private_key = private_key
        self._public_key = public_key

        if algorithm.startswith("HS"):
            if not secret:
                raise ValueError(f"{algorithm} requires a secret")
            self._secret = secret.encode("utf-8")
        elif algorithm.startswith(("RS", "ES")):
            if not private_key and not public_key:
                raise ValueError(f"{algorithm} requires at least one key")

    # ---------- Encode ----------

    def encode(
        self,
        payload: dict,
        ttl: Optional[int] = None,
        headers_extra: Optional[dict] = None,
    ) -> str:
        """Encode a JWT token.

        Args:
            payload: Claims to include. The mapping is copied, not mutated.
            ttl: Time-to-live in seconds (sets 'exp' claim unless the payload
                already carries one).
            headers_extra: Additional header parameters. An 'alg' entry is
                ignored: the header always names the algorithm that actually
                signs the token.

        Returns:
            The compact ``header.payload.signature`` string.

        Raises:
            NotImplementedError: If the configured algorithm is not HS*.
        """
        header = {"alg": self._algorithm, "typ": "JWT"}
        if headers_extra:
            header.update(headers_extra)
            # Never let extra headers misreport the signing algorithm; a
            # token whose 'alg' disagrees with its signature cannot verify.
            header["alg"] = self._algorithm

        claims = dict(payload)
        now = int(time.time())

        # Standard claims
        if "iat" not in claims:
            claims["iat"] = now
        if ttl is not None and "exp" not in claims:
            claims["exp"] = now + ttl

        header_b64 = _b64url_encode(json.dumps(header).encode("utf-8"))
        payload_b64 = _b64url_encode(json.dumps(claims).encode("utf-8"))
        signing_input = f"{header_b64}.{payload_b64}"

        signature = self._sign(signing_input)
        return f"{signing_input}.{signature}"

    # ---------- Decode ----------

    def decode(
        self,
        token: str,
        verify: bool = True,
        audience: Optional[Union[str, List[str]]] = None,
        issuer: Optional[str] = None,
    ) -> dict:
        """Decode and optionally verify a JWT token.

        The 'exp' and 'nbf' claims, and the audience/issuer expectations when
        given, are checked whether or not ``verify`` is set; ``verify`` only
        controls the algorithm and signature checks.

        Args:
            token: The JWT string
            verify: Whether to verify the signature (default True)
            audience: Expected audience (if present, validates 'aud' claim)
            issuer: Expected issuer (if present, validates 'iss' claim)

        Returns:
            The decoded payload (claims) dictionary.

        Raises:
            InvalidTokenError: If the token is malformed, the algorithm or
                signature does not match, or a claim check fails.
            ExpiredTokenError: If the 'exp' claim is in the past.
            NotImplementedError: If ``verify`` is set and the configured
                algorithm is not HS*.
        """
        parts = token.split(".")
        if len(parts) != 3:
            raise InvalidTokenError("JWT must have 3 parts (header.payload.signature)")

        header_b64, payload_b64, signature_b64 = parts

        # Decoding failures are reported as InvalidTokenError rather than
        # leaking base64/JSON/Unicode errors to the caller.
        header = self._decode_segment(header_b64, "header")
        payload = self._decode_segment(payload_b64, "payload")

        # Verify algorithm
        alg = header.get("alg")
        if verify and alg != self._algorithm:
            raise InvalidTokenError(f"Algorithm mismatch: expected {self._algorithm}, got {alg}")

        # Verify signature
        if verify:
            signing_input = f"{header_b64}.{payload_b64}"
            if not self._verify(signing_input, signature_b64):
                raise InvalidTokenError("Invalid signature")

        now = time.time()

        # Check expiry. Compare against None so that exp == 0 (the epoch) is
        # treated as an expired token instead of "no expiry".
        exp = payload.get("exp")
        if exp is not None and self._as_timestamp("exp", exp) < now:
            raise ExpiredTokenError(f"Token expired at {exp}")

        # Check not-before
        nbf = payload.get("nbf")
        if nbf is not None and self._as_timestamp("nbf", nbf) > now:
            raise InvalidTokenError(f"Token not valid before {nbf}")

        # Check audience
        if audience is not None:
            aud = payload.get("aud")
            if aud is None:
                raise InvalidTokenError("Token missing 'aud' claim")
            expected = [audience] if isinstance(audience, str) else list(audience)
            token_aud = [aud] if isinstance(aud, str) else aud
            if not isinstance(token_aud, (list, tuple)):
                raise InvalidTokenError("Claim 'aud' must be a string or a list of strings")
            # Membership by equality: unhashable junk inside 'aud' cannot crash.
            if not any(candidate in expected for candidate in token_aud):
                raise InvalidTokenError("Audience mismatch")

        # Check issuer
        if issuer is not None:
            iss = payload.get("iss")
            if iss != issuer:
                raise InvalidTokenError(f"Issuer mismatch: expected {issuer}, got {iss}")

        return payload

    # ---------- Signature ----------

    def _sign(self, data: str) -> str:
        """Return the base64url signature of *data* under the configured algorithm.

        Raises:
            NotImplementedError: For RS*/ES* algorithms.
        """
        if self._algorithm.startswith("HS"):
            d = hmac.new(self._secret, data.encode("utf-8"), self._hash_func[self._algorithm]).digest()
            return _b64url_encode(d)
        else:
            raise NotImplementedError(f"Signing with {self._algorithm} requires cryptographic libraries (cryptography)")

    def _verify(self, data: str, signature_b64: str) -> bool:
        """Return True when *signature_b64* is the valid signature of *data*.

        Raises:
            NotImplementedError: For RS*/ES* algorithms.
        """
        if self._algorithm.startswith("HS"):
            expected = self._sign(data)
            # Compare as bytes: compare_digest rejects str arguments holding
            # non-ASCII characters with TypeError, and the signature segment
            # is attacker-controlled text.
            return hmac.compare_digest(expected.encode("ascii"), signature_b64.encode("utf-8"))
        else:
            raise NotImplementedError(f"Verification with {self._algorithm} requires cryptographic libraries (cryptography)")

    # ---------- Static helpers ----------

    @staticmethod
    def _decode_segment(segment: str, name: str) -> dict:
        """Decode one base64url JSON segment, reporting failures as InvalidTokenError."""
        try:
            value = _json_b64_decode(segment)
        except ValueError as err:
            # binascii.Error, json.JSONDecodeError and UnicodeDecodeError are
            # all ValueError subclasses.
            raise InvalidTokenError(f"Malformed JWT {name}: {err}") from err
        if not isinstance(value, dict):
            raise InvalidTokenError(f"JWT {name} must be a JSON object")
        return value

    @staticmethod
    def _as_timestamp(name: str, value: Any) -> int:
        """Coerce a time claim to int, reporting failures as InvalidTokenError."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as err:
            raise InvalidTokenError(f"Claim '{name}' is not a valid timestamp: {value!r}") from err

    @staticmethod
    def decode_unverified(token: str) -> dict:
        """Decode JWT without verifying signature (inspect only).

        Raises:
            InvalidTokenError: If the token does not have three parts or its
                payload is not a base64url-encoded JSON object.
        """
        parts = token.split(".")
        if len(parts) != 3:
            raise InvalidTokenError("JWT must have 3 parts")
        return JWT._decode_segment(parts[1], "payload")

    @staticmethod
    def get_header(token: str) -> dict:
        """Extract JWT header without verification.

        Raises:
            InvalidTokenError: If the token does not have three parts or its
                header is not a base64url-encoded JSON object.
        """
        parts = token.split(".")
        if len(parts) != 3:
            raise InvalidTokenError("JWT must have 3 parts")
        return JWT._decode_segment(parts[0], "header")