summaryrefslogtreecommitdiff
path: root/app/auth/jwt.py
blob: ae795b5d62de10425e290751ffcf587816c4ccbf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from datetime import datetime, timedelta, timezone
from typing import Any, TypedDict

import jwt
from jwt import PyJWTError

from app.utils.logger_cfg import logger

# Signing algorithm used both when encoding and when verifying tokens.
JWT_ALGORITHM: str = "HS256"
# Default access-token lifetime, in minutes.
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
# Default refresh-token lifetime, in days.
REFRESH_TOKEN_EXPIRE_DAYS: int = 7


class JWTClaims(TypedDict):
    """Normalized claim set produced by ``decode_token``."""

    # Token subject, coerced to ``str`` during decoding.
    sub: str
    # Version counter carried in the token; decoding reports 0 when absent.
    token_version: int
    # Expiry as an integer Unix timestamp (seconds).
    exp: int


def get_jwt_secret() -> str:
    """Return the token signing secret read from ``JWT_SECRET``.

    Returns:
        The non-empty value of the ``JWT_SECRET`` environment variable.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is unset or set to an empty string.
    """
    configured = os.getenv("JWT_SECRET")

    # Happy path first: any non-empty string is accepted as-is.
    if isinstance(configured, str) and configured:
        return configured

    raise RuntimeError("JWT_SECRET environment variable not set!")


def create_access_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed access token carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The top-level mapping is copied
            (shallow copy), so the caller's dict itself is not modified.
        expires_delta: Token lifetime measured from now. When ``None``, the
            lifetime defaults to ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded JWT, signed with the ``JWT_SECRET`` value using
        ``JWT_ALGORITHM``.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is not configured.
    """
    secret = get_jwt_secret()

    # Test against None explicitly: ``timedelta(0)`` is falsy, so the former
    # ``expires_delta or default`` silently replaced an explicit zero lifetime
    # with the default one.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    payload = data.copy()

    expire = datetime.now(timezone.utc) + expires_delta

    # Stored as whole Unix seconds; any caller-supplied "exp" is overwritten.
    payload["exp"] = int(expire.timestamp())

    return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)


def create_refresh_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed refresh token carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The top-level mapping is copied
            (shallow copy), so the caller's dict itself is not modified.
        expires_delta: Token lifetime measured from now. When ``None``, the
            lifetime defaults to ``REFRESH_TOKEN_EXPIRE_DAYS`` days.

    Returns:
        The encoded JWT, signed with the ``JWT_SECRET`` value using
        ``JWT_ALGORITHM``.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is not configured.
    """
    # NOTE(review): apart from the longer default lifetime, nothing here marks
    # the token as a refresh token, and this module's decoder does not check
    # any token-type claim. Confirm that callers keep refresh and access
    # tokens apart (e.g. via a claim in ``data``).
    secret = get_jwt_secret()

    # Test against None explicitly: ``timedelta(0)`` is falsy, so the former
    # ``expires_delta or default`` silently replaced an explicit zero lifetime
    # with the default one.
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    payload = data.copy()

    expire = datetime.now(timezone.utc) + expires_delta

    # Stored as whole Unix seconds; any caller-supplied "exp" is overwritten.
    payload["exp"] = int(expire.timestamp())

    return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)


def decode_token(token: str) -> JWTClaims:
    """Verify ``token`` and return its normalized claims.

    The signature is checked against the ``JWT_SECRET`` value, and only
    ``JWT_ALGORITHM`` is accepted.

    Args:
        token: The encoded JWT string.

    Returns:
        A ``JWTClaims`` dict. ``sub`` is coerced to ``str``, and
        ``token_version`` and ``exp`` are coerced to ``int``. A missing or
        falsy ``token_version`` is reported as ``0``.

    Raises:
        ValueError: If the token is expired ("Token expired"), fails PyJWT
            validation ("Invalid token" / "JWT decode error"), lacks a
            ``sub`` or ``exp`` claim ("Missing sub" / "Missing exp"), or
            carries a claim value that cannot be converted ("Invalid token").
        RuntimeError: If ``JWT_SECRET`` is not configured, or if decoding
            fails with an exception that is not a PyJWT error.
    """
    secret = get_jwt_secret()

    logger.debug("Decoding JWT token...")

    # Only the PyJWT call sits inside this try. Previously the claim checks
    # below were inside it too, so their ValueErrors were swallowed by the
    # catch-all ``except Exception`` and resurfaced as RuntimeError.
    try:
        payload: dict[str, Any] = jwt.decode(
            token,
            secret,
            algorithms=[JWT_ALGORITHM],
        )

    # Most specific first: ExpiredSignatureError derives from
    # InvalidTokenError, which derives from PyJWTError.
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("Token expired") from exc

    except jwt.InvalidTokenError as exc:
        raise ValueError("Invalid token") from exc

    except PyJWTError as exc:
        raise ValueError("JWT decode error") from exc

    except Exception as exc:
        logger.exception("Unexpected JWT error: %s", exc)
        raise RuntimeError("Unexpected JWT error") from exc

    sub = payload.get("sub")
    exp = payload.get("exp")
    token_version = payload.get("token_version", 0)

    if sub is None:
        raise ValueError("Missing sub")

    if exp is None:
        raise ValueError("Missing exp")

    # A validly signed token can still carry claim values that int() rejects
    # (e.g. a non-numeric token_version); report that as an invalid token.
    try:
        return {
            "sub": str(sub),
            "token_version": int(token_version or 0),
            "exp": int(exp),
        }
    except (TypeError, ValueError) as exc:
        raise ValueError("Invalid token") from exc