diff options
Diffstat (limited to 'app')
| -rw-r--r-- | app/auth/dependencies.py | 48 | ||||
| -rw-r--r-- | app/auth/jwt.py | 134 | ||||
| -rw-r--r-- | app/main.py | 10 | ||||
| -rw-r--r-- | app/routes/__init__.py | 19 | ||||
| -rw-r--r-- | app/routes/auth/__init__.py | 0 | ||||
| -rw-r--r-- | app/routes/auth/auth.py (renamed from app/routes/auth.py) | 0 | ||||
| -rw-r--r-- | app/routes/auth/logout.py | 35 | ||||
| -rw-r--r-- | app/routes/auth/register.py (renamed from app/routes/register.py) | 0 | ||||
| -rw-r--r-- | app/routes/me.py | 51 | ||||
| -rw-r--r-- | app/routes/users/__init__.py | 0 | ||||
| -rw-r--r-- | app/routes/users/changeusername.py | 40 | ||||
| -rw-r--r-- | app/routes/users/me.py | 17 | ||||
| -rw-r--r-- | app/routes/users/security.py | 0 | ||||
| -rw-r--r-- | app/routes/users/user.py (renamed from app/routes/user.py) | 0 | ||||
| -rw-r--r-- | app/schemas/user.py | 5 | ||||
| -rw-r--r-- | app/utils/__init__.py | 17 | ||||
| -rw-r--r-- | app/utils/cors.py | 3 | ||||
| -rw-r--r-- | app/utils/create_tables.py | 7 | ||||
| -rw-r--r-- | app/utils/db.py | 49 | ||||
| -rw-r--r-- | app/utils/env.py | 5 |
20 files changed, 291 insertions, 149 deletions
diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py new file mode 100644 index 0000000..f482a50 --- /dev/null +++ b/app/auth/dependencies.py @@ -0,0 +1,48 @@ +from fastapi import Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.jwt import JWTClaims, decode_token +from app.models.user import User +from app.utils.db import get_async_session + + +async def get_optional_user( + request: Request, + session: AsyncSession = Depends(get_async_session), +) -> User | None: + token = request.cookies.get("access_token") + if not token: + return None + + try: + payload: JWTClaims = decode_token(token) + except Exception: + return None + + sub = payload["sub"] + token_version = payload["token_version"] + + if not sub.isdigit(): + return None + + user_id = int(sub) + + user = await User.get_user_by_id(user_id, session=session) + if not user: + return None + + if user.token_version != token_version: + return None + + return user + + +async def get_current_user( + user: User | None = Depends(get_optional_user), +) -> User: + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + ) + return user diff --git a/app/auth/jwt.py b/app/auth/jwt.py index cf8b732..ae795b5 100644 --- a/app/auth/jwt.py +++ b/app/auth/jwt.py @@ -1,71 +1,103 @@ import os -from datetime import datetime, timedelta -from typing import Dict, Optional +from datetime import datetime, timedelta, timezone +from typing import Any, TypedDict import jwt -from dotenv import load_dotenv +from jwt import PyJWTError from app.utils.logger_cfg import logger -load_dotenv() -JWT_SECRET = os.getenv("JWT_SECRET") -if not JWT_SECRET: - logger.critical("JWT_SECRET environment variable not set! Exiting.") - raise RuntimeError("JWT_SECRET environment variable not set!") - JWT_ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 60 -REFRESH_TOKEN_EXPIRE_DAYS = 30 +REFRESH_TOKEN_EXPIRE_DAYS = 7 -def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: - try: - to_encode = data.copy() - expire = datetime.utcnow() + ( - expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - ) - to_encode.update({"exp": expire}) - safe_payload = {k: v for k, v in to_encode.items() if k != "password"} - logger.debug(f"Creating access token with payload: {safe_payload}") - return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) - except Exception as e: - logger.exception(f"Failed to create access token: {e}") - raise RuntimeError("Failed to create access token") +class JWTClaims(TypedDict): + sub: str + token_version: int + exp: int -def create_refresh_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: - try: - to_encode = data.copy() - expire = datetime.utcnow() + ( - expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) - ) - to_encode.update({"exp": expire}) - logger.debug(f"Creating refresh token with payload: {to_encode}") - return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) - except Exception as e: - logger.exception(f"Failed to create refresh token: {e}") - raise RuntimeError("Failed to create refresh token") +def get_jwt_secret() -> str: + secret = os.getenv("JWT_SECRET") + + if not isinstance(secret, str) or not secret: + raise RuntimeError("JWT_SECRET environment variable not set!") + + return secret + + +def create_access_token( + data: dict[str, Any], + expires_delta: timedelta | None = None, +) -> str: + secret = get_jwt_secret() + + payload = data.copy() + + expire = datetime.now(timezone.utc) + ( + expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + ) + + payload["exp"] = int(expire.timestamp()) + + return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM) + +def create_refresh_token( + data: dict[str, Any], + expires_delta: timedelta | None = None, +) -> str: + secret = get_jwt_secret() + + payload = data.copy() + + expire = datetime.now(timezone.utc) + ( + expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + ) + + payload["exp"] = int(expire.timestamp()) + + return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM) + + +def decode_token(token: str) -> JWTClaims: + secret = get_jwt_secret() -def decode_token(token: str) -> Dict: try: logger.debug("Decoding JWT token...") - payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) - safe_payload = {k: v for k, v in payload.items() if k != "password"} - logger.info(f"JWT decoded successfully: {safe_payload}") - return payload + + payload: dict[str, Any] = jwt.decode( + token, + secret, + algorithms=[JWT_ALGORITHM], + ) + + sub = payload.get("sub") + exp = payload.get("exp") + token_version = payload.get("token_version", 0) + + if sub is None: + raise ValueError("Missing sub") + + if exp is None: + raise ValueError("Missing exp") + + return { + "sub": str(sub), + "token_version": int(token_version or 0), + "exp": int(exp), + } + except jwt.ExpiredSignatureError: - logger.warning("JWT token has expired") - raise ValueError("Token has expired") - except jwt.InvalidSignatureError: - logger.warning("JWT token signature invalid") - raise ValueError("Invalid token signature") - except jwt.DecodeError: - logger.warning("JWT token decode failed (possibly malformed)") - raise ValueError("Malformed token") + raise ValueError("Token expired") + except jwt.InvalidTokenError: - logger.warning("JWT token invalid for unknown reason") raise ValueError("Invalid token") + + except PyJWTError: + raise ValueError("JWT decode error") + except Exception as e: - logger.exception(f"Unexpected error decoding JWT: {e}") - raise RuntimeError("Unexpected error while decoding token") + logger.exception(f"Unexpected JWT error: {e}") + raise RuntimeError("Unexpected JWT error") diff --git a/app/main.py b/app/main.py index f1c556a..49e7a1f 100644 --- a/app/main.py +++ b/app/main.py @@ -1,15 +1,20 @@ import time from contextlib import asynccontextmanager +from pathlib import Path +from dotenv import load_dotenv from fastapi import FastAPI from sqlalchemy import text from app.routes import router as api_router from app.utils.cors import setup_cors from app.utils.create_tables import init_db -from app.utils.db import engine +from app.utils.db import get_engine, init_db_engine from app.utils.logger_cfg import logger +load_dotenv(Path(__file__).resolve().parent.parent / ".env") +init_db_engine() + app_start_time = time.perf_counter() logger.debug("App start timestamp recorded") @@ -18,6 +23,8 @@ logger.debug("App start timestamp recorded") async def lifespan(app: FastAPI): logger.info("Application startup initiated") + engine = get_engine() + try: async with engine.begin() as conn: logger.debug("Executing test query: SELECT 1") @@ -46,6 +53,7 @@ async def lifespan(app: FastAPI): logger.info("Application shutdown initiated") try: + engine = get_engine() await engine.dispose() logger.info("Database engine disposed successfully") except Exception: diff --git a/app/routes/__init__.py b/app/routes/__init__.py index a57869a..65d9c36 100644 --- a/app/routes/__init__.py +++ b/app/routes/__init__.py @@ -1,13 +1,18 @@ from fastapi import APIRouter -from .auth import router as auth_router -from .me import router as me_router -from .register import router as register_router -from .user import router as user_router +from app.routes.auth.auth import router as auth_router +from app.routes.auth.logout import router as logout_router +from app.routes.auth.register import router as register_router +from app.routes.users.changeusername import router as changeusername_router +from app.routes.users.me import router as me_router +from app.routes.users.user import router as user_router router = APIRouter() -router.include_router(register_router, prefix="/auth") -router.include_router(auth_router, prefix="/auth") -router.include_router(user_router) +router.include_router(changeusername_router) router.include_router(me_router) +router.include_router(user_router) + +router.include_router(auth_router, prefix="/auth") +router.include_router(register_router, prefix="/auth") +router.include_router(logout_router, prefix="/auth") diff --git a/app/routes/auth/__init__.py b/app/routes/auth/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/app/routes/auth/__init__.py diff --git a/app/routes/auth.py b/app/routes/auth/auth.py index 6e0d410..6e0d410 100644 --- a/app/routes/auth.py +++ b/app/routes/auth/auth.py diff --git a/app/routes/auth/logout.py b/app/routes/auth/logout.py new file mode 100644 index 0000000..a55ea9e --- /dev/null +++ b/app/routes/auth/logout.py @@ -0,0 +1,35 @@ +from fastapi import APIRouter, Depends, Response +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.dependencies import get_current_user +from app.models.user import User +from app.utils.db import get_async_session +from app.utils.logger_cfg import logger + +router = APIRouter(tags=["auth"]) + + +COOKIE_KWARGS = { + "httponly": True, + "secure": False, + "samesite": "lax", + "path": "/", +} + + +@router.post("/logout") +async def logout( + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(get_current_user), +): + response.delete_cookie("access_token", **COOKIE_KWARGS) + response.delete_cookie("refresh_token", **COOKIE_KWARGS) + + user.token_version += 1 + session.add(user) + await session.commit() + + logger.info("User logged out everywhere | user_id={}", user.id) + + return {"message": "Logged out successfully"} diff --git a/app/routes/register.py b/app/routes/auth/register.py index f0b36ed..f0b36ed 100644 --- a/app/routes/register.py +++ b/app/routes/auth/register.py diff --git a/app/routes/me.py b/app/routes/me.py deleted file mode 100644 index 6d28a80..0000000 --- a/app/routes/me.py +++ /dev/null @@ -1,51 +0,0 @@ -from fastapi import APIRouter, Depends, Request -from sqlalchemy.ext.asyncio import AsyncSession - -from app.auth.jwt import decode_token -from app.models.user import User -from app.utils.db import get_async_session - -router = APIRouter(tags=["auth"]) - - -async def get_current_user_from_cookie( - request: Request, - session: AsyncSession = Depends(get_async_session), -) -> dict: - token = request.cookies.get("access_token") - if not token: - return {"authenticated": False, "user": None} - - try: - payload = decode_token(token) - sub = payload.get("sub") - if sub is None: - return {"authenticated": False, "user": None} - user_id = int(sub) - except ValueError, TypeError: - return {"authenticated": False, "user": None} - - user = await User.get_user_by_id(user_id, session=session) - if not user or user.token_version != payload.get("token_version"): - return {"authenticated": False, "user": None} - - return { - "authenticated": True, - "user": { - "id": user.id, - "username": user.username, - "password": user.has_password, - "google_id": user.google_id, - "email": user.email, - "premium": user.premium, - "is_banned": user.is_banned, - "is_moderator": user.is_moderator, - }, - } - - -@router.get("/me") -async def read_current_user( - user_info: dict = Depends(get_current_user_from_cookie), -): - return user_info diff --git a/app/routes/users/__init__.py b/app/routes/users/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/app/routes/users/__init__.py diff --git a/app/routes/users/changeusername.py b/app/routes/users/changeusername.py new file mode 100644 index 0000000..66ba8da --- /dev/null +++ b/app/routes/users/changeusername.py @@ -0,0 +1,40 @@ +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.dependencies import get_current_user +from app.models.user import User +from app.utils.db import get_async_session + +router = APIRouter() + + +class ChangeUsernameRequest(BaseModel): + username: str + + +@router.patch("/users/change-username") +async def change_username( + data: ChangeUsernameRequest, + user_info: dict = Depends(get_current_user), + session: AsyncSession = Depends(get_async_session), +): + if not user_info["authenticated"]: + raise HTTPException(status_code=401, detail="Not authenticated") + + user = user_info["user"] + + if len(data.username) < 3: + raise HTTPException(status_code=400, detail="Username too short") + + db_user = await session.get(User, user["id"]) + + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + + db_user.username = data.username + + await session.commit() + await session.refresh(db_user) + + return {"success": True, "username": db_user.username} diff --git a/app/routes/users/me.py b/app/routes/users/me.py new file mode 100644 index 0000000..a54fbfe --- /dev/null +++ b/app/routes/users/me.py @@ -0,0 +1,17 @@ +from fastapi import APIRouter, Depends + +from app.auth.dependencies import get_optional_user +from app.models.user import User +from app.schemas.user import MeResponse, UserRead + +router = APIRouter(tags=["auth"]) + + +@router.get("/me", response_model=MeResponse) +async def me( + user: User | None = Depends(get_optional_user), +): + return MeResponse( + authenticated=user is not None, + user=UserRead.model_validate(user) if user else None, + ) diff --git a/app/routes/users/security.py b/app/routes/users/security.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/app/routes/users/security.py diff --git a/app/routes/user.py b/app/routes/users/user.py index 1eb096d..1eb096d 100644 --- a/app/routes/user.py +++ b/app/routes/users/user.py diff --git a/app/schemas/user.py b/app/schemas/user.py index d809196..5d60e2c 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -33,3 +33,8 @@ class UserRead(BaseModel): model_config = { "from_attributes": True, } + + +class MeResponse(BaseModel): + authenticated: bool + user: Optional[UserRead] = None diff --git a/app/utils/__init__.py b/app/utils/__init__.py index a05c4d1..e69de29 100644 --- a/app/utils/__init__.py +++ b/app/utils/__init__.py @@ -1,17 +0,0 @@ -from .cors import setup_cors -from .create_tables import init_db -from .db import Base, async_session, engine, get_async_session -from .hash_cfg import hash_password, verify_password -from .logger_cfg import logger - -__all__ = [ - "engine", - "async_session", - "get_async_session", - "Base", - "setup_cors", - "init_db", - "hash_password", - "verify_password", - "logger", -] diff --git a/app/utils/cors.py b/app/utils/cors.py index e7b54e8..71dadbb 100644 --- a/app/utils/cors.py +++ b/app/utils/cors.py @@ -1,10 +1,7 @@ import os -from dotenv import load_dotenv from fastapi.middleware.cors import CORSMiddleware -load_dotenv() - def setup_cors(app): diff --git a/app/utils/create_tables.py b/app/utils/create_tables.py index e438a03..0d2eaa2 100644 --- a/app/utils/create_tables.py +++ b/app/utils/create_tables.py @@ -1,7 +1,10 @@ from app.models.user import Base -from app.utils.db import engine +from app.utils.db import get_engine -async def init_db(): +async def init_db() -> None: + + engine = get_engine() + async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) diff --git a/app/utils/db.py b/app/utils/db.py index 5531998..d3c7318 100644 --- a/app/utils/db.py +++ b/app/utils/db.py @@ -1,7 +1,6 @@ import os -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional -from dotenv import load_dotenv from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -10,29 +9,45 @@ from sqlalchemy.ext.asyncio import ( ) from sqlalchemy.orm import DeclarativeBase -load_dotenv() +engine: Optional[AsyncEngine] = None +async_session: Optional[async_sessionmaker[AsyncSession]] = None -DATABASE_URL = os.getenv("DATABASE_URL") -if not DATABASE_URL: - raise ValueError("DATABASE_URL not found in .env") +class Base(DeclarativeBase): + pass -engine: AsyncEngine = create_async_engine( - DATABASE_URL, - echo=True, - pool_pre_ping=True, -) -async_session = async_sessionmaker( - bind=engine, - expire_on_commit=False, -) +def init_db_engine() -> None: + global engine, async_session + database_url = os.getenv("DATABASE_URL") -class Base(DeclarativeBase): - pass + if not database_url: + raise RuntimeError("DATABASE_URL not found in environment") + + engine = create_async_engine( + database_url, + echo=True, + pool_pre_ping=True, + ) + + async_session = async_sessionmaker( + bind=engine, + expire_on_commit=False, + ) + + +def get_engine() -> AsyncEngine: + if engine is None: + raise RuntimeError( + "DB engine not initialized. Call init_db_engine() first." + ) + return engine async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + if async_session is None: + raise RuntimeError("DB not initialized. Call init_db_engine() first.") + async with async_session() as session: yield session diff --git a/app/utils/env.py b/app/utils/env.py index e69de29..76b61a7 100644 --- a/app/utils/env.py +++ b/app/utils/env.py @@ -0,0 +1,5 @@ +from pathlib import Path + +from dotenv import load_dotenv + +load_dotenv(Path(__file__).resolve().parents[2] / ".env") |
