1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
|
import os
from datetime import datetime, timedelta, timezone
from typing import Any, TypedDict
import jwt
from jwt import PyJWTError
from app.utils.logger_cfg import logger
# Signing algorithm handed to jwt.encode() and the only algorithm that
# decode_token() allows jwt.decode() to accept.
JWT_ALGORITHM = "HS256"
# Fallback access-token lifetime, in minutes, used by create_access_token().
ACCESS_TOKEN_EXPIRE_MINUTES = 60
# Fallback refresh-token lifetime, in days, used by create_refresh_token().
REFRESH_TOKEN_EXPIRE_DAYS = 7
class JWTClaims(TypedDict):
    """Normalized claim set returned by ``decode_token``."""

    # Subject claim, coerced to ``str`` by decode_token.
    sub: str
    # Token version; decode_token substitutes 0 when the claim is absent.
    token_version: int
    # Expiry claim as an integer timestamp.
    exp: int
def get_jwt_secret() -> str:
    """Return the signing secret read from the ``JWT_SECRET`` environment variable.

    Returns:
        The non-empty value of ``JWT_SECRET``.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is unset or is the empty string.
    """
    value = os.getenv("JWT_SECRET")
    # Happy path first: a present, non-empty string is returned unchanged.
    if isinstance(value, str) and value:
        return value
    raise RuntimeError("JWT_SECRET environment variable not set!")
def create_access_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed access token carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The mapping is shallow-copied, so the caller's
            dict is not modified; an existing ``exp`` key is overwritten in
            the copy.
        expires_delta: Token lifetime measured from the current UTC time.
            ``None`` selects ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The token produced by ``jwt.encode`` using ``JWT_ALGORITHM``.

    Raises:
        RuntimeError: Propagated from ``get_jwt_secret`` when ``JWT_SECRET``
            is not configured.
    """
    secret = get_jwt_secret()
    # Test for None explicitly: timedelta(0) is falsy, so an `or` fallback
    # would silently swap an explicit zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    payload = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    payload["exp"] = int(expire.timestamp())
    return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)
def create_refresh_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed refresh token carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The mapping is shallow-copied, so the caller's
            dict is not modified; an existing ``exp`` key is overwritten in
            the copy.
        expires_delta: Token lifetime measured from the current UTC time.
            ``None`` selects ``REFRESH_TOKEN_EXPIRE_DAYS`` days.

    Returns:
        The token produced by ``jwt.encode`` using ``JWT_ALGORITHM``.

    Raises:
        RuntimeError: Propagated from ``get_jwt_secret`` when ``JWT_SECRET``
            is not configured.
    """
    secret = get_jwt_secret()
    # Test for None explicitly: timedelta(0) is falsy, so an `or` fallback
    # would silently swap an explicit zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    payload = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    payload["exp"] = int(expire.timestamp())
    return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)
def decode_token(token: str) -> JWTClaims:
    """Verify ``token`` and return its normalized claims.

    The token is decoded with ``jwt.decode`` using the configured secret and
    restricted to ``JWT_ALGORITHM``. The ``sub`` and ``exp`` claims are
    required; ``token_version`` defaults to 0 when absent or falsy.

    Args:
        token: The encoded JWT string.

    Returns:
        A ``JWTClaims`` dict with ``sub`` as ``str`` and ``token_version`` /
        ``exp`` as ``int``.

    Raises:
        ValueError: With message "Token expired", "Invalid token" or
            "JWT decode error" when PyJWT rejects the token; "Missing sub" /
            "Missing exp" when a required claim is absent; "Invalid token"
            when a numeric claim cannot be converted to ``int``.
        RuntimeError: When ``JWT_SECRET`` is not configured, or when
            ``jwt.decode`` fails with a non-PyJWT exception.
    """
    secret = get_jwt_secret()
    logger.debug("Decoding JWT token...")

    # Guard only the library call. Claim validation happens after the try so
    # its ValueErrors reach the caller instead of being caught by the
    # catch-all below and reported as an "unexpected" RuntimeError.
    try:
        payload: dict[str, Any] = jwt.decode(
            token,
            secret,
            algorithms=[JWT_ALGORITHM],
        )
    except jwt.ExpiredSignatureError as err:
        raise ValueError("Token expired") from err
    except jwt.InvalidTokenError as err:
        raise ValueError("Invalid token") from err
    except PyJWTError as err:
        raise ValueError("JWT decode error") from err
    except Exception as err:
        logger.exception(f"Unexpected JWT error: {err}")
        raise RuntimeError("Unexpected JWT error") from err

    sub = payload.get("sub")
    if sub is None:
        raise ValueError("Missing sub")
    exp = payload.get("exp")
    if exp is None:
        raise ValueError("Missing exp")

    # A claim that is present but not int-convertible (e.g. a non-numeric
    # token_version) is a malformed token, not an internal failure.
    try:
        token_version = int(payload.get("token_version") or 0)
        expires_at = int(exp)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError("Invalid token") from err

    return {
        "sub": str(sub),
        "token_version": token_version,
        "exp": expires_at,
    }
|