diff options
| author | l3wdfut4pwr <l3wdfut4pwr@gmail.com> | 2026-04-27 13:45:09 +0300 |
|---|---|---|
| committer | l3wdfut4pwr <l3wdfut4pwr@gmail.com> | 2026-04-27 13:45:09 +0300 |
| commit | 4848a9e9394b283022085a6305d00f94b11cd703 (patch) | |
| tree | d7ba45885f110e8ded4af20bc98b9f88f75b1f4a /app/auth | |
| parent | f1842be3bfabe7850d33662da2da377676144c48 (diff) | |
add username change and logout
Diffstat (limited to 'app/auth')
| -rw-r--r-- | app/auth/dependencies.py | 48 | ||||
| -rw-r--r-- | app/auth/jwt.py | 134 |
2 files changed, 131 insertions, 51 deletions
diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py new file mode 100644 index 0000000..f482a50 --- /dev/null +++ b/app/auth/dependencies.py @@ -0,0 +1,48 @@ +from fastapi import Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.jwt import JWTClaims, decode_token +from app.models.user import User +from app.utils.db import get_async_session + + +async def get_optional_user( + request: Request, + session: AsyncSession = Depends(get_async_session), +) -> User | None: + token = request.cookies.get("access_token") + if not token: + return None + + try: + payload: JWTClaims = decode_token(token) + except Exception: + return None + + sub = payload["sub"] + token_version = payload["token_version"] + + if not sub.isdigit(): + return None + + user_id = int(sub) + + user = await User.get_user_by_id(user_id, session=session) + if not user: + return None + + if user.token_version != token_version: + return None + + return user + + +async def get_current_user( + user: User | None = Depends(get_optional_user), +) -> User: + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + ) + return user diff --git a/app/auth/jwt.py b/app/auth/jwt.py index cf8b732..ae795b5 100644 --- a/app/auth/jwt.py +++ b/app/auth/jwt.py @@ -1,71 +1,103 @@ import os -from datetime import datetime, timedelta -from typing import Dict, Optional +from datetime import datetime, timedelta, timezone +from typing import Any, TypedDict import jwt -from dotenv import load_dotenv +from jwt import PyJWTError from app.utils.logger_cfg import logger -load_dotenv() -JWT_SECRET = os.getenv("JWT_SECRET") -if not JWT_SECRET: - logger.critical("JWT_SECRET environment variable not set! Exiting.") - raise RuntimeError("JWT_SECRET environment variable not set!") - JWT_ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 60 -REFRESH_TOKEN_EXPIRE_DAYS = 30 +REFRESH_TOKEN_EXPIRE_DAYS = 7 -def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: - try: - to_encode = data.copy() - expire = datetime.utcnow() + ( - expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - ) - to_encode.update({"exp": expire}) - safe_payload = {k: v for k, v in to_encode.items() if k != "password"} - logger.debug(f"Creating access token with payload: {safe_payload}") - return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) - except Exception as e: - logger.exception(f"Failed to create access token: {e}") - raise RuntimeError("Failed to create access token") +class JWTClaims(TypedDict): + sub: str + token_version: int + exp: int -def create_refresh_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: - try: - to_encode = data.copy() - expire = datetime.utcnow() + ( - expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) - ) - to_encode.update({"exp": expire}) - logger.debug(f"Creating refresh token with payload: {to_encode}") - return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) - except Exception as e: - logger.exception(f"Failed to create refresh token: {e}") - raise RuntimeError("Failed to create refresh token") +def get_jwt_secret() -> str: + secret = os.getenv("JWT_SECRET") + + if not isinstance(secret, str) or not secret: + raise RuntimeError("JWT_SECRET environment variable not set!") + + return secret + + +def create_access_token( + data: dict[str, Any], + expires_delta: timedelta | None = None, +) -> str: + secret = get_jwt_secret() + + payload = data.copy() + + expire = datetime.now(timezone.utc) + ( + expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + ) + + payload["exp"] = int(expire.timestamp()) + + return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM) + +def create_refresh_token( + data: dict[str, Any], + expires_delta: timedelta | None = None, +) -> str: + secret = get_jwt_secret() + + payload = data.copy() + + expire = datetime.now(timezone.utc) + ( + expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + ) + + payload["exp"] = int(expire.timestamp()) + + return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM) + + +def decode_token(token: str) -> JWTClaims: + secret = get_jwt_secret() -def decode_token(token: str) -> Dict: try: logger.debug("Decoding JWT token...") - payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) - safe_payload = {k: v for k, v in payload.items() if k != "password"} - logger.info(f"JWT decoded successfully: {safe_payload}") - return payload + + payload: dict[str, Any] = jwt.decode( + token, + secret, + algorithms=[JWT_ALGORITHM], + ) + + sub = payload.get("sub") + exp = payload.get("exp") + token_version = payload.get("token_version", 0) + + if sub is None: + raise ValueError("Missing sub") + + if exp is None: + raise ValueError("Missing exp") + + return { + "sub": str(sub), + "token_version": int(token_version or 0), + "exp": int(exp), + } + except jwt.ExpiredSignatureError: - logger.warning("JWT token has expired") - raise ValueError("Token has expired") - except jwt.InvalidSignatureError: - logger.warning("JWT token signature invalid") - raise ValueError("Invalid token signature") - except jwt.DecodeError: - logger.warning("JWT token decode failed (possibly malformed)") - raise ValueError("Malformed token") + raise ValueError("Token expired") + except jwt.InvalidTokenError: - logger.warning("JWT token invalid for unknown reason") raise ValueError("Invalid token") + + except PyJWTError: + raise ValueError("JWT decode error") + except Exception as e: - logger.exception(f"Unexpected error decoding JWT: {e}") - raise RuntimeError("Unexpected error while decoding token") + logger.exception(f"Unexpected JWT error: {e}") + raise RuntimeError("Unexpected JWT error") |
