summaryrefslogtreecommitdiff
path: root/app/auth/jwt.py
blob: 2d2aac5937bfa12a72dc61b1c3e22bd5d5c6b2bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

import jwt
from dotenv import load_dotenv

from app.utils.logger_cfg import logger

# Run python-dotenv's loader before the environment lookup below, so a value
# supplied through a .env file is visible to os.getenv.
load_dotenv()
# Key handed to jwt.encode / jwt.decode by every function in this module.
JWT_SECRET = os.getenv("JWT_SECRET")
if not JWT_SECRET:
    # An unset or empty secret aborts the import of this module: log it, then raise.
    logger.critical("JWT_SECRET environment variable not set! Exiting.")
    raise RuntimeError("JWT_SECRET environment variable not set!")

# Algorithm name used both for signing and as the only accepted algorithm on decode.
JWT_ALGORITHM = "HS256"
# Default token lifetimes, used when the caller does not pass expires_delta.
ACCESS_TOKEN_EXPIRE_MINUTES = 60
REFRESH_TOKEN_EXPIRE_DAYS = 7


def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The mapping is copied with
            ``data.copy()``, so the caller's dict is not modified.
        expires_delta: Lifetime of the token. When ``None``, the token expires
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes from now. An explicit
            zero or negative delta is honoured as given.

    Returns:
        The encoded token as returned by ``jwt.encode`` using ``JWT_SECRET``
        and ``JWT_ALGORITHM``.

    Raises:
        RuntimeError: If building or encoding the token fails for any reason;
            the original exception is attached as ``__cause__``.
    """
    try:
        to_encode = data.copy()
        # Compare against None instead of relying on truthiness: timedelta(0)
        # is falsy and would otherwise be silently replaced by the default.
        if expires_delta is None:
            lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        else:
            lifetime = expires_delta
        # Timezone-aware "now"; datetime.utcnow() is deprecated since 3.12.
        expire = datetime.now(timezone.utc) + lifetime
        to_encode["exp"] = expire
        # Keep any "password" entry out of the log output.
        safe_payload = {k: v for k, v in to_encode.items() if k != "password"}
        logger.debug(f"Creating access token with payload: {safe_payload}")
        return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
    except Exception as e:
        logger.exception(f"Failed to create access token: {e}")
        raise RuntimeError("Failed to create access token") from e


def create_refresh_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT refresh token carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The mapping is copied with
            ``data.copy()``, so the caller's dict is not modified.
        expires_delta: Lifetime of the token. When ``None``, the token expires
            ``REFRESH_TOKEN_EXPIRE_DAYS`` days from now. An explicit zero or
            negative delta is honoured as given.

    Returns:
        The encoded token as returned by ``jwt.encode`` using ``JWT_SECRET``
        and ``JWT_ALGORITHM``.

    Raises:
        RuntimeError: If building or encoding the token fails for any reason;
            the original exception is attached as ``__cause__``.
    """
    try:
        to_encode = data.copy()
        # Compare against None instead of relying on truthiness: timedelta(0)
        # is falsy and would otherwise be silently replaced by the default.
        if expires_delta is None:
            lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        else:
            lifetime = expires_delta
        # Timezone-aware "now"; datetime.utcnow() is deprecated since 3.12.
        expire = datetime.now(timezone.utc) + lifetime
        to_encode["exp"] = expire
        # Redact "password" before logging, matching create_access_token;
        # the token itself is still built from the full payload.
        safe_payload = {k: v for k, v in to_encode.items() if k != "password"}
        logger.debug(f"Creating refresh token with payload: {safe_payload}")
        return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
    except Exception as e:
        logger.exception(f"Failed to create refresh token: {e}")
        raise RuntimeError("Failed to create refresh token") from e


def decode_token(token: str) -> Dict:
    """Verify and decode a JWT using ``JWT_SECRET`` and ``JWT_ALGORITHM``.

    Args:
        token: The encoded JWT to verify.

    Returns:
        The decoded claims exactly as returned by ``jwt.decode``.

    Raises:
        ValueError: If the token has expired, has an invalid signature, is
            malformed, or is rejected with any other ``jwt.InvalidTokenError``.
        RuntimeError: If any other unexpected exception occurs while decoding.

    Every re-raised error carries the original exception as ``__cause__``.
    """
    # Handler order matters: the specific PyJWT errors must be matched before
    # the jwt.InvalidTokenError catch-all, and that before the generic fallback.
    try:
        logger.debug("Decoding JWT token...")
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
        # Keep any "password" claim out of the log; the caller still gets it.
        safe_payload = {k: v for k, v in payload.items() if k != "password"}
        logger.info(f"JWT decoded successfully: {safe_payload}")
        return payload
    except jwt.ExpiredSignatureError as e:
        logger.warning("JWT token has expired")
        raise ValueError("Token has expired") from e
    except jwt.InvalidSignatureError as e:
        logger.warning("JWT token signature invalid")
        raise ValueError("Invalid token signature") from e
    except jwt.DecodeError as e:
        logger.warning("JWT token decode failed (possibly malformed)")
        raise ValueError("Malformed token") from e
    except jwt.InvalidTokenError as e:
        logger.warning("JWT token invalid for unknown reason")
        raise ValueError("Invalid token") from e
    except Exception as e:
        logger.exception(f"Unexpected error decoding JWT: {e}")
        raise RuntimeError("Unexpected error while decoding token") from e