Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6d8a0d4
feat-72: ai usage quota
sebampuerocursor May 4, 2025
0b58fd8
feat-72: remove unneeded reference
sebampuerocursor May 4, 2025
531eacc
feat-72: include new schema version for new table
sebampuerocursor May 4, 2025
9a3f24c
feat-72: minor fixes
sebampuerocursor May 4, 2025
048a10a
feat-72: include tests for services
sebampuerocursor May 4, 2025
904542a
feat-72: include tests for new service call
sebampuerocursor May 4, 2025
76a7eb1
feat-72: run ruff formatter and linter
sebampuerocursor May 4, 2025
d9518c7
feat-72: generate openapi for new route
sebampuerocursor May 4, 2025
69d7323
Merge branch 'main' into feat-72-usage-quota-ai
sebampuerocursor May 8, 2025
1685399
feat-72: add state for loading ai usage quota
sebampuerocursor May 8, 2025
96bbc39
feat-72: only allow creation if quota is not 100%
sebampuerocursor May 9, 2025
ae98d08
run formatting
sebampuerocursor May 9, 2025
94c4aaf
feat: apply suggested changes
sebampuerocursor May 12, 2025
11ac736
tests: update tests for user services
sebampuerocursor May 13, 2025
2097f5b
fix more tests
sebampuerocursor May 14, 2025
be5539c
fix linting
sebampuerocursor May 15, 2025
fd69629
run client and regenerate openapi spec
sebampuerocursor May 15, 2025
9ab4443
update cheking function
sebampuerocursor May 15, 2025
21d9fc7
fix linting
sebampuerocursor May 15, 2025
9e91cf5
run formatter
sebampuerocursor May 15, 2025
72f2408
remove unneeded hook for ai prompt dialog, only one query is needed
sebampuerocursor May 16, 2025
8ef3ed4
change UI so that usage left is shown
sebampuerocursor May 16, 2025
e18423c
Merge remote-tracking branch 'origin/main' into feat-72-usage-quota-ai
sebampuerocursor May 22, 2025
ba9109c
feat: apply suggested changes
sebampuerocursor May 22, 2025
373e0b0
fix: refactor check_and_increment_ai_usage_quota
0010aor May 27, 2025
9db02b1
refactor: format backend files
0010aor May 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ POSTGRES_USER=postgres
POSTGRES_PASSWORD=changethis

# AI
AI_MAX_USAGE_QUOTA=30
AI_QUOTA_TIME_RANGE_DAYS=30 # time in days
AI_MODEL="dummy_model"
AI_API_KEY="dummy_api_key"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Add AI Usage Quota tables

Revision ID: d1ea38d75310
Revises: cb16ae472c1e
Create Date: 2025-05-04 09:59:20.325131

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = 'd1ea38d75310'
down_revision = 'cb16ae472c1e'
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
'aiusagequota',
sa.Column('id', sa.UUID(), primary_key=True, nullable=False),
sa.Column('user_id', sa.UUID(), sa.ForeignKey('user.id', ondelete='CASCADE'), index=True, nullable=False),
sa.Column('usage_count', sa.Integer, default=0, nullable=False),
sa.Column('last_reset_time', sa.DateTime(timezone=True), nullable=False),
)


def downgrade():
op.drop_table('aiusagequota')
9 changes: 9 additions & 0 deletions backend/src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:

AI_API_KEY: str | None = None
AI_MODEL: str | None = None
AI_MAX_USAGE_QUOTA: int | None = None
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    AI_MAX_USAGE_QUOTA: int = 30
    AI_QUOTA_TIME_RANGE_DAYS: int = 1

so the project can run without requiring to define those variables in the .env

AI_QUOTA_TIME_RANGE_DAYS: int | None = None

COLLECTION_GENERATION_PROMPT: str | None = None
CARD_GENERATION_PROMPT: str | None = None
Expand Down Expand Up @@ -92,6 +94,13 @@ def _enforce_non_default_secrets(self) -> Self:
"FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD
)

if self.AI_MAX_USAGE_QUOTA is None or self.AI_MAX_USAGE_QUOTA <= 0:
raise ValueError("AI_MAX_USAGE_QUOTA must be set to a positive integer.")
if self.AI_QUOTA_TIME_RANGE_DAYS is None or self.AI_QUOTA_TIME_RANGE_DAYS <= 0:
raise ValueError(
"AI_QUOTA_TIME_RANGE_DAYS must be set to a positive integer."
)

return self


Expand Down
9 changes: 9 additions & 0 deletions backend/src/flashcards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.ai_models.gemini import GeminiProviderDep
from src.ai_models.gemini.exceptions import AIGenerationError
from src.auth.services import CurrentUser, SessionDep
from src.users.services import check_and_increment_ai_usage_quota

from . import services
from .exceptions import EmptyCollectionError
Expand Down Expand Up @@ -52,6 +53,10 @@ async def create_collection(

if collection_in.prompt:
try:
if not check_and_increment_ai_usage_quota(session, current_user):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should use if not await asyncio.to_thread here right?

raise HTTPException(
status_code=429, detail="Quota for AI usage is reached."
)
flashcard_collection = await services.generate_ai_collection(
provider, collection_in.prompt
)
Expand Down Expand Up @@ -145,6 +150,10 @@ async def create_card(
if not access_checked:
raise HTTPException(status_code=404, detail="Collection not found")
if card_in.prompt:
if not check_and_increment_ai_usage_quota(session, current_user):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here?

raise HTTPException(
status_code=429, detail="Quota for AI usage is reached."
)
card_base = await services.generate_ai_flashcard(card_in.prompt, provider)
card_in.front = card_base.front
card_in.back = card_base.back
Expand Down
7 changes: 6 additions & 1 deletion backend/src/users/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from src.auth.services import CurrentUser, SessionDep
from src.core.config import settings
from src.users.schemas import UserCreate, UserPublic, UserRegister
from src.users.schemas import AIUsageQuota, UserCreate, UserPublic, UserRegister

from . import services

Expand Down Expand Up @@ -38,3 +38,8 @@ def register_user(session: SessionDep, user_in: UserRegister) -> Any:
user_create = UserCreate.model_validate(user_in)
user = services.create_user(session=session, user_create=user_create)
return user


@router.get("/users/me/ai-usage-quota", response_model=AIUsageQuota)
def get_my_ai_usage_quota(current_user: CurrentUser):
return services.get_ai_usage_quota_for_user(current_user)
19 changes: 18 additions & 1 deletion backend/src/users/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING

from sqlmodel import Field, Relationship
from sqlmodel import Field, Relationship, SQLModel

from src.users.schemas import UserBase

Expand All @@ -22,3 +23,19 @@ class User(UserBase, table=True):
cascade_delete=True,
sa_relationship_kwargs={"lazy": "selectin"},
)
ai_usage_quota: "AIUsageQuota" = Relationship(
back_populates="user",
sa_relationship_kwargs={"uselist": False, "lazy": "selectin"},
)


class AIUsageQuota(SQLModel, table=True):
id: uuid.UUID | None = Field(default_factory=uuid.uuid4, primary_key=True)
user_id: uuid.UUID = Field(
foreign_key="user.id", index=True, unique=True, ondelete="CASCADE"
)
usage_count: int = Field(default=0)
last_reset_time: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)
user: "User" = Relationship(back_populates="ai_usage_quota")
7 changes: 7 additions & 0 deletions backend/src/users/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from datetime import datetime

from pydantic import EmailStr
from sqlmodel import Field, SQLModel
Expand Down Expand Up @@ -26,3 +27,9 @@ class UserRegister(SQLModel):

class UserPublic(UserBase):
id: uuid.UUID


class AIUsageQuota(SQLModel):
usage_count: int
max_usage_allowed: int
reset_date: datetime
53 changes: 51 additions & 2 deletions backend/src/users/services.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

from sqlmodel import Session, select
from sqlmodel import Session, select, update

from src.auth.services import get_password_hash
from src.core.config import settings
from src.users.models import AIUsageQuota as AIUsageQuotaModel
from src.users.models import User
from src.users.schemas import UserCreate, UserUpdate
from src.users.schemas import AIUsageQuota, UserCreate, UserUpdate


def create_user(*, session: Session, user_create: UserCreate) -> User:
Expand Down Expand Up @@ -42,3 +45,49 @@ def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user


def get_ai_usage_quota_for_user(user: User) -> AIUsageQuota:
quota = user.ai_usage_quota
if not quota:
return AIUsageQuota(
usage_count=0,
max_usage_allowed=settings.AI_MAX_USAGE_QUOTA,
reset_date=(
datetime.now(timezone.utc)
+ timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
),
)
return AIUsageQuota(
usage_count=quota.usage_count,
max_usage_allowed=settings.AI_MAX_USAGE_QUOTA,
reset_date=(
quota.last_reset_time + timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
),
)


def check_and_increment_ai_usage_quota(session: Session, user: User) -> bool:
quota = user.ai_usage_quota
now = datetime.now(timezone.utc)
if not quota:
quota = AIUsageQuotaModel(user_id=user.id, usage_count=1, last_reset_time=now)
session.add(quota)
session.commit()
return True

if now - quota.last_reset_time >= timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS):
quota.usage_count = 0
quota.last_reset_time = now
session.add(quota)
session.commit()
result = session.exec(
update(AIUsageQuotaModel)
.where(
(AIUsageQuotaModel.id == quota.id)
& (AIUsageQuotaModel.usage_count <= settings.AI_MAX_USAGE_QUOTA)
)
.values(usage_count=AIUsageQuotaModel.usage_count + 1)
)
session.commit()
return result.rowcount > 0
28 changes: 16 additions & 12 deletions backend/tests/flashcards/card/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,18 +394,22 @@ def test_create_card_with_prompt_ai(
with patch(
"src.flashcards.services.generate_ai_flashcard", new_callable=AsyncMock
) as mock_ai:
mock_ai.return_value = type("Card", (), ai_card)()
card_data = {"prompt": prompt, "front": "", "back": ""}
rsp = client.post(
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
json=card_data,
headers=normal_user_token_headers,
)
assert rsp.status_code == 200
content = rsp.json()
assert content["front"] == ai_card["front"]
assert content["back"] == ai_card["back"]
mock_ai.assert_called_once_with(prompt, ANY)
with patch(
"src.flashcards.api.check_and_increment_ai_usage_quota"
) as mock_quota_check:
mock_quota_check.return_value = True
mock_ai.return_value = type("Card", (), ai_card)()
card_data = {"prompt": prompt, "front": "", "back": ""}
rsp = client.post(
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
json=card_data,
headers=normal_user_token_headers,
)
assert rsp.status_code == 200
content = rsp.json()
assert content["front"] == ai_card["front"]
assert content["back"] == ai_card["back"]
mock_ai.assert_called_once_with(prompt, ANY)


def test_create_card_with_prompt_too_long(
Expand Down
72 changes: 40 additions & 32 deletions backend/tests/flashcards/collection/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,29 @@ def test_create_collection_with_prompt(
with patch(
"src.flashcards.services.generate_ai_collection", new_callable=AsyncMock
) as mock_ai_generate:
mock_ai_generate.return_value = mock_collection

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 200
content = rsp.json()
assert content["name"] == collection_data.name
assert "id" in content
assert isinstance(content["id"], str)
assert len(content["cards"]) == len(mock_collection.cards)
for i, card in enumerate(mock_collection.cards):
assert content["cards"][i]["front"] == card.front
assert content["cards"][i]["back"] == card.back

mock_ai_generate.assert_called_once()
with patch(
"src.flashcards.api.check_and_increment_ai_usage_quota"
) as mock_quota_check:
mock_ai_generate.return_value = mock_collection
mock_quota_check.return_value = True

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 200
content = rsp.json()
assert content["name"] == collection_data.name
assert "id" in content
assert isinstance(content["id"], str)
assert len(content["cards"]) == len(mock_collection.cards)
for i, card in enumerate(mock_collection.cards):
assert content["cards"][i]["front"] == card.front
assert content["cards"][i]["back"] == card.back

mock_ai_generate.assert_called_once()


def test_create_collection_with_ai_generation_error(
Expand All @@ -114,19 +118,23 @@ def test_create_collection_with_ai_generation_error(
with patch(
"src.flashcards.services.generate_ai_collection", new_callable=AsyncMock
) as mock_ai_generate:
err_msg = "AI service is unavailable"
mock_ai_generate.side_effect = AIGenerationError(err_msg)

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 500
content = rsp.json()
assert "detail" in content
assert err_msg in content["detail"]
with patch(
"src.flashcards.api.check_and_increment_ai_usage_quota"
) as mock_quota_check:
err_msg = "AI service is unavailable"
mock_ai_generate.side_effect = AIGenerationError(err_msg)
mock_quota_check.return_value = True

rsp = client.post(
f"{settings.API_V1_STR}/collections/",
json=collection_data.model_dump(),
headers=normal_user_token_headers,
)

assert rsp.status_code == 500
content = rsp.json()
assert "detail" in content
assert err_msg in content["detail"]


def test_read_collection(
Expand Down
20 changes: 20 additions & 0 deletions backend/tests/users/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any

import pytest
from sqlmodel import Session

from src.users.schemas import UserCreate
from src.users.services import create_user
from tests.utils.utils import random_email, random_lower_string


@pytest.fixture
def test_user(db: Session) -> dict[str, Any]:
email = random_email()
password = random_lower_string()
full_name = random_lower_string()

user_in = UserCreate(email=email, password=password, full_name=full_name)
user = create_user(session=db, user_create=user_in)

return user
Loading