summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--app/auth/dependencies.py48
-rw-r--r--app/auth/jwt.py134
-rw-r--r--app/main.py10
-rw-r--r--app/routes/__init__.py19
-rw-r--r--app/routes/auth/__init__.py0
-rw-r--r--app/routes/auth/auth.py (renamed from app/routes/auth.py)0
-rw-r--r--app/routes/auth/logout.py35
-rw-r--r--app/routes/auth/register.py (renamed from app/routes/register.py)0
-rw-r--r--app/routes/me.py51
-rw-r--r--app/routes/users/__init__.py0
-rw-r--r--app/routes/users/changeusername.py40
-rw-r--r--app/routes/users/me.py17
-rw-r--r--app/routes/users/security.py0
-rw-r--r--app/routes/users/user.py (renamed from app/routes/user.py)0
-rw-r--r--app/schemas/user.py5
-rw-r--r--app/utils/__init__.py17
-rw-r--r--app/utils/cors.py3
-rw-r--r--app/utils/create_tables.py7
-rw-r--r--app/utils/db.py49
-rw-r--r--app/utils/env.py5
20 files changed, 291 insertions, 149 deletions
diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py
new file mode 100644
index 0000000..f482a50
--- /dev/null
+++ b/app/auth/dependencies.py
@@ -0,0 +1,48 @@
+from fastapi import Depends, HTTPException, Request, status
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.auth.jwt import JWTClaims, decode_token
+from app.models.user import User
+from app.utils.db import get_async_session
+
+
+async def get_optional_user(
+ request: Request,
+ session: AsyncSession = Depends(get_async_session),
+) -> User | None:
+ token = request.cookies.get("access_token")
+ if not token:
+ return None
+
+ try:
+ payload: JWTClaims = decode_token(token)
+ except Exception:
+ return None
+
+ sub = payload["sub"]
+ token_version = payload["token_version"]
+
+ if not sub.isdigit():
+ return None
+
+ user_id = int(sub)
+
+ user = await User.get_user_by_id(user_id, session=session)
+ if not user:
+ return None
+
+ if user.token_version != token_version:
+ return None
+
+ return user
+
+
+async def get_current_user(
+ user: User | None = Depends(get_optional_user),
+) -> User:
+ if user is None:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Not authenticated",
+ )
+ return user
diff --git a/app/auth/jwt.py b/app/auth/jwt.py
index cf8b732..ae795b5 100644
--- a/app/auth/jwt.py
+++ b/app/auth/jwt.py
@@ -1,71 +1,103 @@
import os
-from datetime import datetime, timedelta
-from typing import Dict, Optional
+from datetime import datetime, timedelta, timezone
+from typing import Any, TypedDict
import jwt
-from dotenv import load_dotenv
+from jwt import PyJWTError
from app.utils.logger_cfg import logger
-load_dotenv()
-JWT_SECRET = os.getenv("JWT_SECRET")
-if not JWT_SECRET:
- logger.critical("JWT_SECRET environment variable not set! Exiting.")
- raise RuntimeError("JWT_SECRET environment variable not set!")
-
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
-REFRESH_TOKEN_EXPIRE_DAYS = 30
+REFRESH_TOKEN_EXPIRE_DAYS = 7
-def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str:
- try:
- to_encode = data.copy()
- expire = datetime.utcnow() + (
- expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
- )
- to_encode.update({"exp": expire})
- safe_payload = {k: v for k, v in to_encode.items() if k != "password"}
- logger.debug(f"Creating access token with payload: {safe_payload}")
- return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
- except Exception as e:
- logger.exception(f"Failed to create access token: {e}")
- raise RuntimeError("Failed to create access token")
+class JWTClaims(TypedDict):
+ sub: str
+ token_version: int
+ exp: int
-def create_refresh_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str:
- try:
- to_encode = data.copy()
- expire = datetime.utcnow() + (
- expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
- )
- to_encode.update({"exp": expire})
- logger.debug(f"Creating refresh token with payload: {to_encode}")
- return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
- except Exception as e:
- logger.exception(f"Failed to create refresh token: {e}")
- raise RuntimeError("Failed to create refresh token")
+def get_jwt_secret() -> str:
+ secret = os.getenv("JWT_SECRET")
+
+ if not isinstance(secret, str) or not secret:
+ raise RuntimeError("JWT_SECRET environment variable not set!")
+
+ return secret
+
+
+def create_access_token(
+ data: dict[str, Any],
+ expires_delta: timedelta | None = None,
+) -> str:
+ secret = get_jwt_secret()
+
+ payload = data.copy()
+
+ expire = datetime.now(timezone.utc) + (
+ expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+ )
+
+ payload["exp"] = int(expire.timestamp())
+
+ return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)
+
+def create_refresh_token(
+ data: dict[str, Any],
+ expires_delta: timedelta | None = None,
+) -> str:
+ secret = get_jwt_secret()
+
+ payload = data.copy()
+
+ expire = datetime.now(timezone.utc) + (
+ expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
+ )
+
+ payload["exp"] = int(expire.timestamp())
+
+ return jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)
+
+
+def decode_token(token: str) -> JWTClaims:
+ secret = get_jwt_secret()
-def decode_token(token: str) -> Dict:
try:
logger.debug("Decoding JWT token...")
- payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
- safe_payload = {k: v for k, v in payload.items() if k != "password"}
- logger.info(f"JWT decoded successfully: {safe_payload}")
- return payload
+
+ payload: dict[str, Any] = jwt.decode(
+ token,
+ secret,
+ algorithms=[JWT_ALGORITHM],
+ )
+
+ sub = payload.get("sub")
+ exp = payload.get("exp")
+ token_version = payload.get("token_version", 0)
+
+ if sub is None:
+ raise ValueError("Missing sub")
+
+ if exp is None:
+ raise ValueError("Missing exp")
+
+ return {
+ "sub": str(sub),
+ "token_version": int(token_version or 0),
+ "exp": int(exp),
+ }
+
except jwt.ExpiredSignatureError:
- logger.warning("JWT token has expired")
- raise ValueError("Token has expired")
- except jwt.InvalidSignatureError:
- logger.warning("JWT token signature invalid")
- raise ValueError("Invalid token signature")
- except jwt.DecodeError:
- logger.warning("JWT token decode failed (possibly malformed)")
- raise ValueError("Malformed token")
+ raise ValueError("Token expired")
+
except jwt.InvalidTokenError:
- logger.warning("JWT token invalid for unknown reason")
raise ValueError("Invalid token")
+
+ except PyJWTError:
+ raise ValueError("JWT decode error")
+
except Exception as e:
- logger.exception(f"Unexpected error decoding JWT: {e}")
- raise RuntimeError("Unexpected error while decoding token")
+ logger.exception(f"Unexpected JWT error: {e}")
+ raise RuntimeError("Unexpected JWT error")
diff --git a/app/main.py b/app/main.py
index f1c556a..49e7a1f 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,15 +1,20 @@
import time
from contextlib import asynccontextmanager
+from pathlib import Path
+from dotenv import load_dotenv
from fastapi import FastAPI
from sqlalchemy import text
from app.routes import router as api_router
from app.utils.cors import setup_cors
from app.utils.create_tables import init_db
-from app.utils.db import engine
+from app.utils.db import get_engine, init_db_engine
from app.utils.logger_cfg import logger
+load_dotenv(Path(__file__).resolve().parent.parent / ".env")
+init_db_engine()
+
app_start_time = time.perf_counter()
logger.debug("App start timestamp recorded")
@@ -18,6 +23,8 @@ logger.debug("App start timestamp recorded")
async def lifespan(app: FastAPI):
logger.info("Application startup initiated")
+ engine = get_engine()
+
try:
async with engine.begin() as conn:
logger.debug("Executing test query: SELECT 1")
@@ -46,6 +53,7 @@ async def lifespan(app: FastAPI):
logger.info("Application shutdown initiated")
try:
+ engine = get_engine()
await engine.dispose()
logger.info("Database engine disposed successfully")
except Exception:
diff --git a/app/routes/__init__.py b/app/routes/__init__.py
index a57869a..65d9c36 100644
--- a/app/routes/__init__.py
+++ b/app/routes/__init__.py
@@ -1,13 +1,18 @@
from fastapi import APIRouter
-from .auth import router as auth_router
-from .me import router as me_router
-from .register import router as register_router
-from .user import router as user_router
+from app.routes.auth.auth import router as auth_router
+from app.routes.auth.logout import router as logout_router
+from app.routes.auth.register import router as register_router
+from app.routes.users.changeusername import router as changeusername_router
+from app.routes.users.me import router as me_router
+from app.routes.users.user import router as user_router
router = APIRouter()
-router.include_router(register_router, prefix="/auth")
-router.include_router(auth_router, prefix="/auth")
-router.include_router(user_router)
+router.include_router(changeusername_router)
router.include_router(me_router)
+router.include_router(user_router)
+
+router.include_router(auth_router, prefix="/auth")
+router.include_router(register_router, prefix="/auth")
+router.include_router(logout_router, prefix="/auth")
diff --git a/app/routes/auth/__init__.py b/app/routes/auth/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/app/routes/auth/__init__.py
diff --git a/app/routes/auth.py b/app/routes/auth/auth.py
index 6e0d410..6e0d410 100644
--- a/app/routes/auth.py
+++ b/app/routes/auth/auth.py
diff --git a/app/routes/auth/logout.py b/app/routes/auth/logout.py
new file mode 100644
index 0000000..a55ea9e
--- /dev/null
+++ b/app/routes/auth/logout.py
@@ -0,0 +1,35 @@
+from fastapi import APIRouter, Depends, Response
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.auth.dependencies import get_current_user
+from app.models.user import User
+from app.utils.db import get_async_session
+from app.utils.logger_cfg import logger
+
+router = APIRouter(tags=["auth"])
+
+
+COOKIE_KWARGS = {
+ "httponly": True,
+ "secure": False,
+ "samesite": "lax",
+ "path": "/",
+}
+
+
+@router.post("/logout")
+async def logout(
+ response: Response,
+ session: AsyncSession = Depends(get_async_session),
+ user: User = Depends(get_current_user),
+):
+ response.delete_cookie("access_token", **COOKIE_KWARGS)
+ response.delete_cookie("refresh_token", **COOKIE_KWARGS)
+
+ user.token_version += 1
+ session.add(user)
+ await session.commit()
+
+ logger.info("User logged out everywhere | user_id={}", user.id)
+
+ return {"message": "Logged out successfully"}
diff --git a/app/routes/register.py b/app/routes/auth/register.py
index f0b36ed..f0b36ed 100644
--- a/app/routes/register.py
+++ b/app/routes/auth/register.py
diff --git a/app/routes/me.py b/app/routes/me.py
deleted file mode 100644
index 6d28a80..0000000
--- a/app/routes/me.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from fastapi import APIRouter, Depends, Request
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from app.auth.jwt import decode_token
-from app.models.user import User
-from app.utils.db import get_async_session
-
-router = APIRouter(tags=["auth"])
-
-
-async def get_current_user_from_cookie(
- request: Request,
- session: AsyncSession = Depends(get_async_session),
-) -> dict:
- token = request.cookies.get("access_token")
- if not token:
- return {"authenticated": False, "user": None}
-
- try:
- payload = decode_token(token)
- sub = payload.get("sub")
- if sub is None:
- return {"authenticated": False, "user": None}
- user_id = int(sub)
- except ValueError, TypeError:
- return {"authenticated": False, "user": None}
-
- user = await User.get_user_by_id(user_id, session=session)
- if not user or user.token_version != payload.get("token_version"):
- return {"authenticated": False, "user": None}
-
- return {
- "authenticated": True,
- "user": {
- "id": user.id,
- "username": user.username,
- "password": user.has_password,
- "google_id": user.google_id,
- "email": user.email,
- "premium": user.premium,
- "is_banned": user.is_banned,
- "is_moderator": user.is_moderator,
- },
- }
-
-
-@router.get("/me")
-async def read_current_user(
- user_info: dict = Depends(get_current_user_from_cookie),
-):
- return user_info
diff --git a/app/routes/users/__init__.py b/app/routes/users/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/app/routes/users/__init__.py
diff --git a/app/routes/users/changeusername.py b/app/routes/users/changeusername.py
new file mode 100644
index 0000000..66ba8da
--- /dev/null
+++ b/app/routes/users/changeusername.py
@@ -0,0 +1,40 @@
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.auth.dependencies import get_current_user
+from app.models.user import User
+from app.utils.db import get_async_session
+
+router = APIRouter()
+
+
+class ChangeUsernameRequest(BaseModel):
+ username: str
+
+
+@router.patch("/users/change-username")
+async def change_username(
+ data: ChangeUsernameRequest,
+ user_info: dict = Depends(get_current_user),
+ session: AsyncSession = Depends(get_async_session),
+):
+ if not user_info["authenticated"]:
+ raise HTTPException(status_code=401, detail="Not authenticated")
+
+ user = user_info["user"]
+
+ if len(data.username) < 3:
+ raise HTTPException(status_code=400, detail="Username too short")
+
+ db_user = await session.get(User, user["id"])
+
+ if not db_user:
+ raise HTTPException(status_code=404, detail="User not found")
+
+ db_user.username = data.username
+
+ await session.commit()
+ await session.refresh(db_user)
+
+ return {"success": True, "username": db_user.username}
diff --git a/app/routes/users/me.py b/app/routes/users/me.py
new file mode 100644
index 0000000..a54fbfe
--- /dev/null
+++ b/app/routes/users/me.py
@@ -0,0 +1,17 @@
+from fastapi import APIRouter, Depends
+
+from app.auth.dependencies import get_optional_user
+from app.models.user import User
+from app.schemas.user import MeResponse, UserRead
+
+router = APIRouter(tags=["auth"])
+
+
+@router.get("/me", response_model=MeResponse)
+async def me(
+ user: User | None = Depends(get_optional_user),
+):
+ return MeResponse(
+ authenticated=user is not None,
+ user=UserRead.model_validate(user) if user else None,
+ )
diff --git a/app/routes/users/security.py b/app/routes/users/security.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/app/routes/users/security.py
diff --git a/app/routes/user.py b/app/routes/users/user.py
index 1eb096d..1eb096d 100644
--- a/app/routes/user.py
+++ b/app/routes/users/user.py
diff --git a/app/schemas/user.py b/app/schemas/user.py
index d809196..5d60e2c 100644
--- a/app/schemas/user.py
+++ b/app/schemas/user.py
@@ -33,3 +33,8 @@ class UserRead(BaseModel):
model_config = {
"from_attributes": True,
}
+
+
+class MeResponse(BaseModel):
+ authenticated: bool
+ user: Optional[UserRead] = None
diff --git a/app/utils/__init__.py b/app/utils/__init__.py
index a05c4d1..e69de29 100644
--- a/app/utils/__init__.py
+++ b/app/utils/__init__.py
@@ -1,17 +0,0 @@
-from .cors import setup_cors
-from .create_tables import init_db
-from .db import Base, async_session, engine, get_async_session
-from .hash_cfg import hash_password, verify_password
-from .logger_cfg import logger
-
-__all__ = [
- "engine",
- "async_session",
- "get_async_session",
- "Base",
- "setup_cors",
- "init_db",
- "hash_password",
- "verify_password",
- "logger",
-]
diff --git a/app/utils/cors.py b/app/utils/cors.py
index e7b54e8..71dadbb 100644
--- a/app/utils/cors.py
+++ b/app/utils/cors.py
@@ -1,10 +1,7 @@
import os
-from dotenv import load_dotenv
from fastapi.middleware.cors import CORSMiddleware
-load_dotenv()
-
def setup_cors(app):
diff --git a/app/utils/create_tables.py b/app/utils/create_tables.py
index e438a03..0d2eaa2 100644
--- a/app/utils/create_tables.py
+++ b/app/utils/create_tables.py
@@ -1,7 +1,10 @@
from app.models.user import Base
-from app.utils.db import engine
+from app.utils.db import get_engine
-async def init_db():
+async def init_db() -> None:
+
+ engine = get_engine()
+
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
diff --git a/app/utils/db.py b/app/utils/db.py
index 5531998..d3c7318 100644
--- a/app/utils/db.py
+++ b/app/utils/db.py
@@ -1,7 +1,6 @@
import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional
-from dotenv import load_dotenv
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
@@ -10,29 +9,45 @@ from sqlalchemy.ext.asyncio import (
)
from sqlalchemy.orm import DeclarativeBase
-load_dotenv()
+engine: Optional[AsyncEngine] = None
+async_session: Optional[async_sessionmaker[AsyncSession]] = None
-DATABASE_URL = os.getenv("DATABASE_URL")
-if not DATABASE_URL:
- raise ValueError("DATABASE_URL not found in .env")
+class Base(DeclarativeBase):
+ pass
-engine: AsyncEngine = create_async_engine(
- DATABASE_URL,
- echo=True,
- pool_pre_ping=True,
-)
-async_session = async_sessionmaker(
- bind=engine,
- expire_on_commit=False,
-)
+def init_db_engine() -> None:
+ global engine, async_session
+ database_url = os.getenv("DATABASE_URL")
-class Base(DeclarativeBase):
- pass
+ if not database_url:
+ raise RuntimeError("DATABASE_URL not found in environment")
+
+ engine = create_async_engine(
+ database_url,
+ echo=True,
+ pool_pre_ping=True,
+ )
+
+ async_session = async_sessionmaker(
+ bind=engine,
+ expire_on_commit=False,
+ )
+
+
+def get_engine() -> AsyncEngine:
+ if engine is None:
+ raise RuntimeError(
+ "DB engine not initialized. Call init_db_engine() first."
+ )
+ return engine
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
+ if async_session is None:
+ raise RuntimeError("DB not initialized. Call init_db_engine() first.")
+
async with async_session() as session:
yield session
diff --git a/app/utils/env.py b/app/utils/env.py
index e69de29..76b61a7 100644
--- a/app/utils/env.py
+++ b/app/utils/env.py
@@ -0,0 +1,5 @@
+from pathlib import Path
+
+from dotenv import load_dotenv
+
+load_dotenv(Path(__file__).resolve().parents[2] / ".env")