Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ POSTGRES_USER=postgres
POSTGRES_PASSWORD=changethis

# AI
AI_MAX_USAGE_QUOTA=30
AI_QUOTA_TIME_RANGE_DAYS=30 # time in days
AI_MODEL="dummy_model"
AI_API_KEY="dummy_api_key"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Add AI Usage Quota tables

Revision ID: d1ea38d75310
Revises: cb16ae472c1e
Create Date: 2025-05-04 09:59:20.325131

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = 'd1ea38d75310'
down_revision = 'cb16ae472c1e'
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
'aiusagequota',
sa.Column('id', sa.UUID(), primary_key=True, nullable=False),
sa.Column('user_id', sa.UUID(), sa.ForeignKey('user.id', ondelete='CASCADE'), index=True, nullable=False),
sa.Column('usage_count', sa.Integer, default=0, nullable=False),
sa.Column('last_reset_time', sa.DateTime(timezone=True), nullable=False),
)


def downgrade():
op.drop_table('aiusagequota')
9 changes: 9 additions & 0 deletions backend/src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:

AI_API_KEY: str | None = None
AI_MODEL: str | None = None
AI_MAX_USAGE_QUOTA: int = 30
AI_QUOTA_TIME_RANGE_DAYS: int = 1

COLLECTION_GENERATION_PROMPT: str | None = None
CARD_GENERATION_PROMPT: str | None = None
Expand Down Expand Up @@ -92,6 +94,13 @@ def _enforce_non_default_secrets(self) -> Self:
"FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD
)

if self.AI_MAX_USAGE_QUOTA is None or self.AI_MAX_USAGE_QUOTA <= 0:
raise ValueError("AI_MAX_USAGE_QUOTA must be set to a positive integer.")
if self.AI_QUOTA_TIME_RANGE_DAYS is None or self.AI_QUOTA_TIME_RANGE_DAYS <= 0:
raise ValueError(
"AI_QUOTA_TIME_RANGE_DAYS must be set to a positive integer."
)

return self


Expand Down
13 changes: 13 additions & 0 deletions backend/src/flashcards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.ai_models.gemini import GeminiProviderDep
from src.ai_models.gemini.exceptions import AIGenerationError
from src.auth.services import CurrentUser, SessionDep
from src.users.services import check_and_increment_ai_usage_quota

from . import services
from .exceptions import EmptyCollectionError
Expand Down Expand Up @@ -52,6 +53,12 @@ async def create_collection(

if collection_in.prompt:
try:
if not await asyncio.to_thread(
lambda: check_and_increment_ai_usage_quota(session, current_user)
):
raise HTTPException(
status_code=429, detail="Quota for AI usage is reached."
)
flashcard_collection = await services.generate_ai_collection(
provider, collection_in.prompt
)
Expand Down Expand Up @@ -145,6 +152,12 @@ async def create_card(
if not access_checked:
raise HTTPException(status_code=404, detail="Collection not found")
if card_in.prompt:
if not await asyncio.to_thread(
lambda: check_and_increment_ai_usage_quota(session, current_user)
):
raise HTTPException(
status_code=429, detail="Quota for AI usage is reached."
)
card_base = await services.generate_ai_flashcard(card_in.prompt, provider)
card_in.front = card_base.front
card_in.back = card_base.back
Expand Down
7 changes: 6 additions & 1 deletion backend/src/users/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from src.auth.services import CurrentUser, SessionDep
from src.core.config import settings
from src.users.schemas import UserCreate, UserPublic, UserRegister
from src.users.schemas import AIUsageQuota, UserCreate, UserPublic, UserRegister

from . import services

Expand Down Expand Up @@ -38,3 +38,8 @@ def register_user(session: SessionDep, user_in: UserRegister) -> Any:
user_create = UserCreate.model_validate(user_in)
user = services.create_user(session=session, user_create=user_create)
return user


@router.get("/users/me/ai-usage-quota", response_model=AIUsageQuota)
def get_my_ai_usage_quota(current_user: CurrentUser):
return services.get_ai_usage_quota_for_user(current_user)
19 changes: 18 additions & 1 deletion backend/src/users/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING

from sqlmodel import Field, Relationship
from sqlmodel import Field, Relationship, SQLModel

from src.users.schemas import UserBase

Expand All @@ -22,3 +23,19 @@ class User(UserBase, table=True):
cascade_delete=True,
sa_relationship_kwargs={"lazy": "selectin"},
)
ai_usage_quota: "AIUsageQuota" = Relationship(
back_populates="user",
sa_relationship_kwargs={"uselist": False, "lazy": "selectin"},
)


class AIUsageQuota(SQLModel, table=True):
id: uuid.UUID | None = Field(default_factory=uuid.uuid4, primary_key=True)
user_id: uuid.UUID = Field(
foreign_key="user.id", index=True, unique=True, ondelete="CASCADE"
)
usage_count: int = Field(default=0)
last_reset_time: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)
user: "User" = Relationship(back_populates="ai_usage_quota")
7 changes: 7 additions & 0 deletions backend/src/users/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from datetime import datetime

from pydantic import EmailStr
from sqlmodel import Field, SQLModel
Expand Down Expand Up @@ -26,3 +27,9 @@ class UserRegister(SQLModel):

class UserPublic(UserBase):
id: uuid.UUID


class AIUsageQuota(SQLModel):
usage_count: int
max_usage_allowed: int
reset_date: datetime
72 changes: 70 additions & 2 deletions backend/src/users/services.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

from sqlmodel import Session, select
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select, update

from src.auth.services import get_password_hash
from src.core.config import settings
from src.users.models import AIUsageQuota as AIUsageQuotaModel
from src.users.models import User
from src.users.schemas import UserCreate, UserUpdate
from src.users.schemas import AIUsageQuota, UserCreate, UserUpdate


def create_user(*, session: Session, user_create: UserCreate) -> User:
Expand Down Expand Up @@ -42,3 +46,67 @@ def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user


def get_ai_usage_quota_for_user(user: User) -> AIUsageQuota:
quota = user.ai_usage_quota
if not quota:
return AIUsageQuota(
usage_count=0,
max_usage_allowed=settings.AI_MAX_USAGE_QUOTA,
reset_date=(
datetime.now(timezone.utc)
+ timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
),
)
return AIUsageQuota(
usage_count=quota.usage_count,
max_usage_allowed=settings.AI_MAX_USAGE_QUOTA,
reset_date=(
quota.last_reset_time + timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
),
)


def check_and_increment_ai_usage_quota(session: Session, user: User) -> bool:
now = datetime.now(timezone.utc)
reset_threshold = now - timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)

if not user.ai_usage_quota:
try:
quota = AIUsageQuotaModel(
user_id=user.id, usage_count=1, last_reset_time=now
)
session.add(quota)
session.commit()
return True
except IntegrityError:
session.rollback()

session.refresh(user)

result_reset = session.exec(
update(AIUsageQuotaModel)
.where(
(AIUsageQuotaModel.user_id == user.id)
& (AIUsageQuotaModel.last_reset_time <= reset_threshold)
)
.values(usage_count=1, last_reset_time=now)
)

if result_reset.rowcount > 0:
session.commit()
return True

result_increment = session.exec(
update(AIUsageQuotaModel)
.where(
(AIUsageQuotaModel.user_id == user.id)
& (AIUsageQuotaModel.last_reset_time > reset_threshold)
& (AIUsageQuotaModel.usage_count < settings.AI_MAX_USAGE_QUOTA)
)
.values(usage_count=AIUsageQuotaModel.usage_count + 1)
)

session.commit()
return result_increment.rowcount > 0
28 changes: 16 additions & 12 deletions backend/tests/flashcards/card/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,18 +394,22 @@ def test_create_card_with_prompt_ai(
with patch(
"src.flashcards.services.generate_ai_flashcard", new_callable=AsyncMock
) as mock_ai:
mock_ai.return_value = type("Card", (), ai_card)()
card_data = {"prompt": prompt, "front": "", "back": ""}
rsp = client.post(
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
json=card_data,
headers=normal_user_token_headers,
)
assert rsp.status_code == 200
content = rsp.json()
assert content["front"] == ai_card["front"]
assert content["back"] == ai_card["back"]
mock_ai.assert_called_once_with(prompt, ANY)
with patch(
"src.flashcards.api.check_and_increment_ai_usage_quota"
) as mock_quota_check:
mock_quota_check.return_value = True
mock_ai.return_value = type("Card", (), ai_card)()
card_data = {"prompt": prompt, "front": "", "back": ""}
rsp = client.post(
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
json=card_data,
headers=normal_user_token_headers,
)
assert rsp.status_code == 200
content = rsp.json()
assert content["front"] == ai_card["front"]
assert content["back"] == ai_card["back"]
mock_ai.assert_called_once_with(prompt, ANY)


def test_create_card_with_prompt_too_long(
Expand Down
72 changes: 40 additions & 32 deletions backend/tests/flashcards/collection/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,29 @@ def test_create_collection_with_prompt(
with patch(
"src.flashcards.services.generate_ai_collection", new_callable=AsyncMock
) as mock_ai_generate:
mock_ai_generate.return_value = mock_collection

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 200
content = rsp.json()
assert content["name"] == collection_data.name
assert "id" in content
assert isinstance(content["id"], str)
assert len(content["cards"]) == len(mock_collection.cards)
for i, card in enumerate(mock_collection.cards):
assert content["cards"][i]["front"] == card.front
assert content["cards"][i]["back"] == card.back

mock_ai_generate.assert_called_once()
with patch(
"src.flashcards.api.check_and_increment_ai_usage_quota"
) as mock_quota_check:
mock_ai_generate.return_value = mock_collection
mock_quota_check.return_value = True

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 200
content = rsp.json()
assert content["name"] == collection_data.name
assert "id" in content
assert isinstance(content["id"], str)
assert len(content["cards"]) == len(mock_collection.cards)
for i, card in enumerate(mock_collection.cards):
assert content["cards"][i]["front"] == card.front
assert content["cards"][i]["back"] == card.back

mock_ai_generate.assert_called_once()


def test_create_collection_with_ai_generation_error(
Expand All @@ -114,19 +118,23 @@ def test_create_collection_with_ai_generation_error(
with patch(
"src.flashcards.services.generate_ai_collection", new_callable=AsyncMock
) as mock_ai_generate:
err_msg = "AI service is unavailable"
mock_ai_generate.side_effect = AIGenerationError(err_msg)

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 500
content = rsp.json()
assert "detail" in content
assert err_msg in content["detail"]
with patch(
"src.flashcards.api.check_and_increment_ai_usage_quota"
) as mock_quota_check:
err_msg = "AI service is unavailable"
mock_ai_generate.side_effect = AIGenerationError(err_msg)
mock_quota_check.return_value = True

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 500
content = rsp.json()
assert "detail" in content
assert err_msg in content["detail"]


def test_read_collection(
Expand Down
20 changes: 20 additions & 0 deletions backend/tests/users/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any

import pytest
from sqlmodel import Session

from src.users.schemas import UserCreate
from src.users.services import create_user
from tests.utils.utils import random_email, random_lower_string


@pytest.fixture
def test_user(db: Session) -> dict[str, Any]:
email = random_email()
password = random_lower_string()
full_name = random_lower_string()

user_in = UserCreate(email=email, password=password, full_name=full_name)
user = create_user(session=db, user_create=user_in)

return user
Loading