Coverage for agentos/tools/jwt.py: 0%

109 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 08:21 +0800

1""" 

2JWT — JSON Web Token encode/decode/verify (HS256/RS256/ES256). 

3 

4Supports: 

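5 - HS256/HS384/HS512 (HMAC) signing and verification; RS*/ES* algorithm names are accepted by the constructor, but signing/verification currently raises NotImplementedError 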

6 - Encode with claims (iss, sub, aud, exp, iat, nbf, jti, custom) 

7 - Decode with signature verification 

8 - Decode without verification (for inspection) 

9 - Token expiry checking 

10 - Claim validation 

11""" 

12 

13from __future__ import annotations 

14 

15import base64 

16import hashlib 

17import hmac 

18import json 

19import time 

20from typing import Any, Dict, List, Optional, Tuple, Union 

21 

22 

23# ============================================================================ 

24# JWTError 

25# ============================================================================ 

26 

class JWTError(Exception):
    """Base class for the JWT-specific exceptions defined in this module."""

29 

30 

class ExpiredTokenError(JWTError):
    """Raised by ``JWT.decode`` when the token's 'exp' claim lies in the past."""

33 

34 

class InvalidTokenError(JWTError):
    """Raised when a token is rejected for any reason other than expiry.

    Within this module that covers a wrong number of segments, an algorithm
    mismatch, a bad signature, a not-yet-valid 'nbf' claim, and audience or
    issuer mismatches.
    """

37 

38 

39# ============================================================================ 

40# Helpers 

41# ============================================================================ 

42 

43def _b64url_encode(data: bytes) -> str: 

44 return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") 

45 

46 

47def _b64url_decode(s: str) -> bytes: 

48 # Restore padding 

49 padding = 4 - len(s) % 4 

50 if padding != 4: 

51 s += "=" * padding 

52 return base64.urlsafe_b64decode(s) 

53 

54 

def _json_b64_decode(s: str) -> dict:
    """Base64url-decode *s* and parse the resulting bytes as JSON."""
    raw = _b64url_decode(s)
    return json.loads(raw)

57 

58 

59# ============================================================================ 

60# JWT 

61# ============================================================================ 

62 

# Algorithm names the JWT constructor accepts; anything else is rejected with
# ValueError. Only the HS* (HMAC) entries are actually signable/verifiable by
# this module — the RS*/ES* entries raise NotImplementedError when used.
ALGORITHMS = frozenset(
    (
        "HS256", "HS384", "HS512",
        "RS256", "RS384", "RS512",
        "ES256", "ES384", "ES512",
    )
)

64 

65 

class JWT:
    """JSON Web Token encoder/decoder.

    Only the HMAC family (HS256/HS384/HS512) can sign and verify. RS*/ES*
    algorithm names are accepted by the constructor, but signing and
    verification raise NotImplementedError.

    Usage:
        jwt = JWT(secret="my-secret")  # for HS256

        # Encode
        token = jwt.encode({"sub": "user123", "role": "admin"}, ttl=3600)

        # Decode & verify
        payload = jwt.decode(token)

        # Decode without verification (inspect only)
        payload = jwt.decode(token, verify=False)
    """

    def __init__(
        self,
        secret: Optional[str] = None,
        private_key: Optional[str] = None,
        public_key: Optional[str] = None,
        algorithm: str = "HS256",
    ) -> None:
        """Configure the signing algorithm and its key material.

        Args:
            secret: Shared secret; required for HS* algorithms.
            private_key: Private key; stored for RS*/ES* algorithms.
            public_key: Public key; stored for RS*/ES* algorithms.
            algorithm: One of ALGORITHMS.

        Raises:
            ValueError: If the algorithm is unsupported, an HS* algorithm is
                given no secret, or an RS*/ES* algorithm is given no key.
        """
        if algorithm not in ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}. Use one of {sorted(ALGORITHMS)}")

        self._algorithm = algorithm
        self._hash_func = {
            "HS256": hashlib.sha256,
            "HS384": hashlib.sha384,
            "HS512": hashlib.sha512,
        }

        # Define every key attribute up front so none is left missing
        # regardless of which algorithm family was selected.
        self._secret: Optional[bytes] = None
        self._private_key: Optional[str] = None
        self._public_key: Optional[str] = None

        if algorithm.startswith("HS"):
            if not secret:
                raise ValueError(f"{algorithm} requires a secret")
            self._secret = secret.encode("utf-8")
        elif algorithm.startswith(("RS", "ES")):
            if not private_key and not public_key:
                raise ValueError(f"{algorithm} requires at least one key")
            self._private_key = private_key
            self._public_key = public_key

    # ---------- Encode ----------

    def encode(
        self,
        payload: dict,
        ttl: Optional[int] = None,
        headers_extra: Optional[dict] = None,
    ) -> str:
        """Encode a JWT token.

        The input payload is copied, not mutated. 'iat' is set to the current
        time unless already present; 'exp' is set to now + ttl only when a ttl
        is given and the payload has no 'exp' of its own.

        Args:
            payload: Claims to include
            ttl: Time-to-live in seconds (sets 'exp' claim)
            headers_extra: Additional header parameters. These are merged over
                the default header, so they can replace 'alg' and 'typ'.

        Returns:
            The compact token ``header.payload.signature``.

        Raises:
            NotImplementedError: If the configured algorithm is not HS*.
        """
        header = {"alg": self._algorithm, "typ": "JWT"}
        if headers_extra:
            header.update(headers_extra)

        claims = dict(payload)
        now = int(time.time())

        # Standard claims
        if "iat" not in claims:
            claims["iat"] = now
        if ttl is not None and "exp" not in claims:
            claims["exp"] = now + ttl

        header_b64 = _b64url_encode(json.dumps(header).encode("utf-8"))
        payload_b64 = _b64url_encode(json.dumps(claims).encode("utf-8"))
        signing_input = f"{header_b64}.{payload_b64}"

        signature = self._sign(signing_input)
        return f"{signing_input}.{signature}"

    # ---------- Decode ----------

    def decode(
        self,
        token: str,
        verify: bool = True,
        audience: Optional[Union[str, List[str]]] = None,
        issuer: Optional[str] = None,
    ) -> dict:
        """Decode and optionally verify a JWT token.

        The 'exp', 'nbf', audience and issuer checks run regardless of
        ``verify``; that flag only controls the algorithm and signature checks.

        Args:
            token: The JWT string
            verify: Whether to verify the signature (default True)
            audience: Expected audience (if present, validates 'aud' claim)
            issuer: Expected issuer (if present, validates 'iss' claim)

        Returns:
            The decoded payload (claims) dictionary.

        Raises:
            ExpiredTokenError: If the 'exp' claim is in the past.
            InvalidTokenError: If the token is malformed, the algorithm or
                signature does not match, a time claim is not numeric, the
                token is not yet valid, or the audience/issuer does not match.
            NotImplementedError: If verifying with a non-HS* algorithm.
        """
        parts = token.split(".")
        if len(parts) != 3:
            raise InvalidTokenError("JWT must have 3 parts (header.payload.signature)")

        header_b64, payload_b64, signature_b64 = parts

        # Tokens are untrusted input: malformed segments surface as
        # InvalidTokenError rather than as raw base64/JSON errors.
        header = self._decode_segment(header_b64, "header")
        payload = self._decode_segment(payload_b64, "payload")

        if verify:
            # Verify algorithm
            alg = header.get("alg")
            if alg != self._algorithm:
                raise InvalidTokenError(f"Algorithm mismatch: expected {self._algorithm}, got {alg}")

            # Verify signature
            signing_input = f"{header_b64}.{payload_b64}"
            if not self._verify(signing_input, signature_b64):
                raise InvalidTokenError("Invalid signature")

        now = time.time()

        # Check expiry. Tested against None rather than truthiness so that
        # an 'exp' of 0 is treated as expired instead of being skipped.
        exp = self._timestamp_claim(payload, "exp")
        if exp is not None and exp < now:
            raise ExpiredTokenError(f"Token expired at {payload['exp']}")

        # Check not-before
        nbf = self._timestamp_claim(payload, "nbf")
        if nbf is not None and nbf > now:
            raise InvalidTokenError(f"Token not valid before {payload['nbf']}")

        # Check audience
        if audience is not None:
            aud = payload.get("aud")
            if aud is None:
                raise InvalidTokenError("Token missing 'aud' claim")
            expected = [audience] if isinstance(audience, str) else list(audience)
            if isinstance(aud, str):
                aud = [aud]
            if not isinstance(aud, list):
                raise InvalidTokenError("Invalid 'aud' claim: expected a string or a list")
            # Membership test instead of set intersection, so unhashable
            # entries in the token's 'aud' cannot raise TypeError.
            if not any(candidate in expected for candidate in aud):
                raise InvalidTokenError("Audience mismatch")

        # Check issuer
        if issuer is not None:
            iss = payload.get("iss")
            if iss != issuer:
                raise InvalidTokenError(f"Issuer mismatch: expected {issuer}, got {iss}")

        return payload

    # ---------- Signature ----------

    def _sign(self, data: str) -> str:
        """Return the base64url HMAC signature of *data* (HS* algorithms only)."""
        if self._algorithm.startswith("HS"):
            d = hmac.new(self._secret, data.encode("utf-8"), self._hash_func[self._algorithm]).digest()
            return _b64url_encode(d)
        raise NotImplementedError(f"Signing with {self._algorithm} requires cryptographic libraries (cryptography)")

    def _verify(self, data: str, signature_b64: str) -> bool:
        """Return True when *signature_b64* is the HMAC signature of *data*."""
        if self._algorithm.startswith("HS"):
            expected = self._sign(data)
            # Compare as bytes: compare_digest raises TypeError on str
            # arguments containing non-ASCII characters, and the signature
            # segment comes straight from the untrusted token.
            return hmac.compare_digest(expected.encode("ascii"), signature_b64.encode("utf-8"))
        raise NotImplementedError(f"Verification with {self._algorithm} requires cryptographic libraries (cryptography)")

    # ---------- Static helpers ----------

    @staticmethod
    def _decode_segment(segment: str, name: str) -> dict:
        """Decode one base64url JSON segment, raising InvalidTokenError on failure."""
        try:
            decoded = _json_b64_decode(segment)
        except ValueError as exc:
            # binascii.Error, json.JSONDecodeError and UnicodeDecodeError are
            # all ValueError subclasses.
            raise InvalidTokenError(f"Malformed JWT {name}: {exc}") from exc
        if not isinstance(decoded, dict):
            raise InvalidTokenError(f"JWT {name} must be a JSON object")
        return decoded

    @staticmethod
    def _timestamp_claim(payload: dict, name: str) -> Optional[int]:
        """Return claim *name* as an int, or None when the claim is absent/null."""
        value = payload.get(name)
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise InvalidTokenError(f"Invalid '{name}' claim: {value!r}") from exc

    @staticmethod
    def decode_unverified(token: str) -> dict:
        """Decode JWT without verifying signature (inspect only).

        Raises:
            InvalidTokenError: If the token does not have 3 parts or its
                payload is not a base64url-encoded JSON object.
        """
        parts = token.split(".")
        if len(parts) != 3:
            raise InvalidTokenError("JWT must have 3 parts")
        return JWT._decode_segment(parts[1], "payload")

    @staticmethod
    def get_header(token: str) -> dict:
        """Extract JWT header without verification.

        Raises:
            InvalidTokenError: If the token does not have 3 parts or its
                header is not a base64url-encoded JSON object.
        """
        parts = token.split(".")
        if len(parts) != 3:
            raise InvalidTokenError("JWT must have 3 parts")
        return JWT._decode_segment(parts[0], "header")