"""JWT creation and validation helpers (HS256, secret read from ``JWT_SECRET``)."""

import os
from datetime import datetime, timedelta, timezone
from typing import Any, TypedDict

import jwt
from jwt import PyJWTError

from app.utils.logger_cfg import logger

JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
REFRESH_TOKEN_EXPIRE_DAYS = 7


class JWTClaims(TypedDict):
    """Claims returned by :func:`decode_token` after validation."""

    sub: str  # subject claim, coerced to str
    token_version: int  # 0 when the claim is absent or falsy
    exp: int  # expiry as a Unix timestamp in seconds


def get_jwt_secret() -> str:
    """Return the signing secret from the ``JWT_SECRET`` environment variable.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is unset or empty.
    """
    secret = os.getenv("JWT_SECRET")
    if not secret:
        raise RuntimeError("JWT_SECRET environment variable not set!")
    return secret


def _encode_token(data: dict[str, Any], lifetime: timedelta) -> str:
    """Sign a copy of *data* with an ``exp`` claim set *lifetime* from now (UTC)."""
    secret = get_jwt_secret()
    payload = data.copy()  # shallow copy so the caller's dict is not mutated
    expire = datetime.now(timezone.utc) + lifetime
    payload["exp"] = int(expire.timestamp())
    return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)


def create_access_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed access token carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; any ``exp`` key is overwritten.
        expires_delta: Token lifetime. Defaults to
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` when ``None``.

    Returns:
        The encoded JWT string.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is not configured.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and an explicit
    # zero lifetime must not silently fall back to the default.
    lifetime = (
        timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        if expires_delta is None
        else expires_delta
    )
    return _encode_token(data, lifetime)


def create_refresh_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed refresh token carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; any ``exp`` key is overwritten.
        expires_delta: Token lifetime. Defaults to
            ``REFRESH_TOKEN_EXPIRE_DAYS`` when ``None``.

    Returns:
        The encoded JWT string.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is not configured.
    """
    # NOTE(review): access and refresh tokens are signed with the same secret
    # and carry no distinguishing claim, so decode_token accepts either kind.
    # Confirm whether callers rely on something else to tell them apart.
    lifetime = (
        timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        if expires_delta is None
        else expires_delta
    )
    return _encode_token(data, lifetime)


def decode_token(token: str) -> JWTClaims:
    """Verify *token* and return its ``sub``, ``token_version`` and ``exp`` claims.

    Args:
        token: Encoded JWT string.

    Returns:
        A :class:`JWTClaims` dict.

    Raises:
        ValueError: If the token is expired, fails verification, or is missing
            or has malformed required claims.
        RuntimeError: If ``JWT_SECRET`` is not configured, or decoding fails
            with an unexpected, non-JWT error.
    """
    secret = get_jwt_secret()
    logger.debug("Decoding JWT token...")
    try:
        payload: dict[str, Any] = jwt.decode(
            token,
            secret,
            algorithms=[JWT_ALGORITHM],
        )
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("Token expired") from exc
    except jwt.InvalidTokenError as exc:
        raise ValueError("Invalid token") from exc
    except PyJWTError as exc:
        raise ValueError("JWT decode error") from exc
    except Exception as exc:
        logger.exception("Unexpected JWT error: %s", exc)
        raise RuntimeError("Unexpected JWT error") from exc

    # Claim validation lives outside the try above so these ValueErrors are
    # not swallowed by the catch-all and re-raised as RuntimeError.
    sub = payload.get("sub")
    exp = payload.get("exp")
    if sub is None:
        raise ValueError("Missing sub")
    if exp is None:
        raise ValueError("Missing exp")

    try:
        token_version = int(payload.get("token_version", 0) or 0)
        exp_ts = int(exp)
    except (TypeError, ValueError) as exc:
        raise ValueError("Invalid token") from exc

    return {
        "sub": str(sub),
        "token_version": token_version,
        "exp": exp_ts,
    }