summaryrefslogtreecommitdiff
path: root/app/models/profile.py
diff options
context:
space:
mode:
Diffstat (limited to 'app/models/profile.py')
-rw-r--r--app/models/profile.py38
1 files changed, 35 insertions, 3 deletions
diff --git a/app/models/profile.py b/app/models/profile.py
index b19d796..93ce6fd 100644
--- a/app/models/profile.py
+++ b/app/models/profile.py
@@ -1,14 +1,28 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from fastapi import Depends
from sqlalchemy import ForeignKey, Integer, String
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from app.utils.db import Base, get_async_session
-Base = DeclarativeBase()
+if TYPE_CHECKING:
+ from .user import (
+ User,
+ )
class Profile(Base):
__tablename__ = "profiles"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
- user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True)
+ user_id: Mapped[int] = mapped_column(
+ ForeignKey("users.id"), unique=True, nullable=False
+ )
avatar_file: Mapped[str | None] = mapped_column(String(255), nullable=True)
banner_file: Mapped[str | None] = mapped_column(String(255), nullable=True)
@@ -19,3 +33,21 @@ class Profile(Base):
subscriptions_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
followers_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
following_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
+
+ user: Mapped["User"] = relationship(
+ "User", back_populates="profile", uselist=False, lazy="selectin"
+ )
+
+ @classmethod
+ async def get_by_user_id(
+ cls, user_id: int, session: AsyncSession = Depends(get_async_session)
+ ):
+ result = await session.execute(select(cls).where(cls.user_id == user_id))
+ return result.scalars().first()
+
+ @classmethod
+ async def get_by_id(
+ cls, profile_id: int, session: AsyncSession = Depends(get_async_session)
+ ):
+ result = await session.execute(select(cls).where(cls.id == profile_id))
+ return result.scalars().first()