summaryrefslogtreecommitdiff
path: root/app/utils/db.py
blob: d3c7318704b485d596911125cc8e9da3041909cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
from typing import AsyncGenerator, Optional

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

# Module-wide singletons. Both stay None until init_db_engine() assigns them.
engine: Optional[AsyncEngine]
async_session: Optional[async_sessionmaker[AsyncSession]]
engine = async_session = None


class Base(DeclarativeBase):
    """Shared SQLAlchemy declarative base, intended as the parent of ORM models."""


def init_db_engine(
    database_url: Optional[str] = None,
    *,
    echo: bool = True,
) -> None:
    """Create the module-wide async engine and session factory.

    Populates the module globals ``engine`` and ``async_session``. Call this
    before ``get_engine()`` or ``get_async_session()``.

    Args:
        database_url: SQLAlchemy database URL. When omitted (``None``), the
            ``DATABASE_URL`` environment variable is used instead.
        echo: Forwarded to ``create_async_engine``. When True, SQLAlchemy
            logs every SQL statement it emits. Defaults to True to keep the
            previous hard-coded behaviour; pass False to silence it.

    Raises:
        RuntimeError: if no non-empty database URL is available.

    Note:
        Calling this again overwrites the globals without disposing the
        previous engine. Callers that re-initialise should
        ``await engine.dispose()`` on the old engine first.
    """
    global engine, async_session

    url = database_url
    if url is None:
        url = os.getenv("DATABASE_URL")

    # Rejects both a missing variable and an empty string.
    if not url:
        raise RuntimeError("DATABASE_URL not found in environment")

    new_engine = create_async_engine(
        url,
        echo=echo,
        # Test pooled connections for liveness on checkout, so dead
        # connections are replaced instead of failing the first query.
        pool_pre_ping=True,
    )

    new_session_factory = async_sessionmaker(
        bind=new_engine,
        # Keep loaded attribute values after commit, so objects can still be
        # read without an implicit refresh (which would need awaited I/O).
        expire_on_commit=False,
    )

    # Publish the globals only after both objects were built. A failure
    # above must not leave ``engine`` updated while ``async_session`` is
    # stale.
    engine = new_engine
    async_session = new_session_factory


def get_engine() -> AsyncEngine:
    """Return the module-wide async engine.

    Raises:
        RuntimeError: if ``init_db_engine()`` has not populated it yet.
    """
    current = engine
    if current is not None:
        return current
    raise RuntimeError("DB engine not initialized. Call init_db_engine() first.")


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a single session from the module-wide session factory.

    The session is opened with ``async with``, so it is closed once the
    generator resumes past the ``yield`` or is finalised.

    Raises:
        RuntimeError: if ``init_db_engine()`` has not been called. Because
            this is an async generator, the error surfaces on first
            iteration rather than at call time.
    """
    factory = async_session
    if factory is None:
        raise RuntimeError("DB not initialized. Call init_db_engine() first.")

    async with factory() as session:
        yield session