diff --git a/backend/app/core/startup.py b/backend/app/core/startup.py new file mode 100644 index 0000000..2b2e756 --- /dev/null +++ b/backend/app/core/startup.py @@ -0,0 +1,52 @@ +import logging +from typing import Any + +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from app.modules.admin.seed.service import seed_all +from app.modules.auth.schemas import UserCreate +from app.modules.auth.service import create_user_if_not_exists +from app.settings import Settings + +logger = logging.getLogger(__name__) + + +def setup_test_users(db: Session, users: list[dict[str, Any]], default_password: str): + """Create initial test users if they don't exist. Easily extendable.""" + for user_info in users: + # Use default password if not provided in user_info + data = user_info.copy() + if "password" not in data: + data["password"] = default_password + + logger.info(f"Ensuring test user exists: {data['email']}") + create_user_if_not_exists(db, UserCreate(**data)) + + +def auto_seed_data(db: Session): + """Seed the database with initial data if it's empty.""" + try: + seed_all(db, n_tags=7, n_fields=12, n_events=30) + logger.info("Auto-seeding completed successfully.") + except HTTPException as e: + if e.status_code == 405: + logger.info("Database already contains data. Skipping auto-seeding.") + else: + logger.exception(f"Auto-seeding failed with unexpected error: {e.detail}") + except Exception: + logger.exception("Auto-seeding failed") + + +def run_startup_tasks(db: Session, settings: Settings): + """Run all necessary startup tasks for development environment.""" + if settings.is_dev: + setup_test_users(db, settings.dev_users, settings.dev_users_password) + auto_seed_data(db) + elif settings.is_demo: + create_user_if_not_exists( + db, + UserCreate( + email=settings.demo_user_email, password=settings.demo_user_password + ), + ) diff --git a/backend/app/factory.py b/backend/app/factory.py index 58173ba..c97472a 100644 --- a/backend/app/factory.py +++ b/backend/app/factory.py @@ -10,8 +10,7 @@ from app.api.v1.routes import admin, auth, events, fields, generic, tags from app.core.handlers import http_exception_handler, validation_exception_handler -from app.modules.auth.schemas import UserCreate -from app.modules.auth.service import create_user_if_not_exists +from app.core.startup import run_startup_tasks from app.settings import Settings logger = logging.getLogger(__name__) @@ -24,11 +23,8 @@ def create_app( @asynccontextmanager async def lifespan(app: FastAPI): logger.info(f"Starting application in {settings.env} mode") - if settings.is_demo: - with SessionLocal() as db: - create_user_if_not_exists( - db, UserCreate(email="demo@evsy.dev", password="bestructured") - ) + with SessionLocal() as db: + run_startup_tasks(db, settings) yield logger.info("Shutting down application") engine.dispose() diff --git a/backend/app/modules/auth/token.py b/backend/app/modules/auth/token.py index 3377e64..2dc349d 100644 --- a/backend/app/modules/auth/token.py +++ b/backend/app/modules/auth/token.py @@ -17,9 +17,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: settings = get_settings() to_encode = data.copy() - expire = datetime.now(UTC) + ( - expires_delta or timedelta(minutes=settings.access_token_expire_minutes) - ) + if expires_delta: + expire = datetime.now(UTC) + expires_delta + elif settings.is_dev: + # 100 years for dev mode + expire = datetime.now(UTC) + timedelta(days=365 * 100) + else: + expire = datetime.now(UTC) + timedelta( + minutes=settings.access_token_expire_minutes + ) to_encode.update({"exp": expire}) return jwt.encode(to_encode, settings.secret_key, algorithm=settings.jwt_algorithm) diff --git a/backend/app/settings.py b/backend/app/settings.py index 5cd5244..1179b7c 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -57,6 +57,15 @@ def __init__(self, _env_file: Optional[str] = None, **kwargs: Any): default=None, alias="GOOGLE_CLIENT_SECRET" ) + dev_users: list[dict[str, Any]] = Field( + default=[{"email": "user@example.com"}], + alias="DEV_USERS", + ) + dev_users_password: str = Field(default="12345678", alias="DEV_USERS_PASSWORD") + + demo_user_email: str = Field(default="demo@evsy.dev", alias="DEMO_USER_EMAIL") + demo_user_password: str = Field(default="bestructured", alias="DEMO_USER_PASSWORD") + model_config = SettingsConfigDict( env_file_encoding="utf-8", case_sensitive=False, diff --git a/backend/tests/test_startup.py b/backend/tests/test_startup.py new file mode 100644 index 0000000..e84fa1e --- /dev/null +++ b/backend/tests/test_startup.py @@ -0,0 +1,98 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch + +from fastapi import HTTPException +from jose import jwt + +from app.core.startup import auto_seed_data, run_startup_tasks, setup_test_users +from app.modules.auth.models import User +from app.modules.auth.token import create_access_token + + +def test_create_access_token_dev_long_expiry(): + """Test that tokens in dev mode have a very long expiry.""" + mock_settings = MagicMock() + mock_settings.is_dev = True + mock_settings.secret_key = "test_secret" + mock_settings.jwt_algorithm = "HS256" + + with patch("app.modules.auth.token.get_settings", return_value=mock_settings): + token = create_access_token({"sub": "user@example.com"}) + payload = jwt.decode(token, "test_secret", algorithms=["HS256"]) + + exp = payload["exp"] + expected_min_exp = (datetime.now(UTC) + timedelta(days=365 * 99)).timestamp() + assert exp > expected_min_exp + + +def test_create_access_token_prod_normal_expiry(): + """Test that tokens in prod mode have normal expiry.""" + mock_settings = MagicMock() + mock_settings.is_dev = False + mock_settings.access_token_expire_minutes = 60 + mock_settings.secret_key = "test_secret" + mock_settings.jwt_algorithm = "HS256" + + with patch("app.modules.auth.token.get_settings", return_value=mock_settings): + token = create_access_token({"sub": "user@example.com"}) + payload = jwt.decode(token, "test_secret", algorithms=["HS256"]) + + exp = payload["exp"] + # Should be roughly 60 minutes from now + expected_exp = (datetime.now(UTC) + timedelta(minutes=60)).timestamp() + assert abs(exp - expected_exp) < 10 # Allow 10s difference + + +def test_setup_test_users(db, test_settings): + """Test that test users are created if they don't exist.""" + # Ensure user doesn't exist in the current transaction + primary_dev_user = test_settings.dev_users[0] + user = db.query(User).filter(User.email == primary_dev_user["email"]).first() + if user: + db.delete(user) + db.flush() + + setup_test_users(db, test_settings.dev_users, test_settings.dev_users_password) + + user = db.query(User).filter(User.email == primary_dev_user["email"]).first() + assert user is not None + assert user.email == primary_dev_user["email"] + + +@patch("app.core.startup.seed_all") +def test_auto_seed_data_empty_db(mock_seed_all, db): + """Test that seeding is called when DB is empty.""" + mock_seed_all.return_value = None + auto_seed_data(db) + mock_seed_all.assert_called_once() + + +@patch("app.core.startup.seed_all") +def test_auto_seed_data_already_seeded(mock_seed_all, db): + """Test that seeding is skipped if DB already has data (simulated by HTTPException 405).""" + mock_seed_all.side_effect = HTTPException( + status_code=405, detail="Action is only allowed on empty database" + ) + + # This should not raise an exception, just log and return + auto_seed_data(db) + mock_seed_all.assert_called_once() + + +def test_run_startup_tasks_dev_calls_subtasks(db): + """Test that all dev startup tasks are triggered in dev mode.""" + mock_settings = MagicMock() + mock_settings.is_dev = True + mock_settings.is_demo = False + mock_settings.dev_users = [{"email": "user@example.com"}] + mock_settings.dev_users_password = "password" + + with ( + patch("app.core.startup.setup_test_users") as mock_setup_users, + patch("app.core.startup.auto_seed_data") as mock_auto_seed, + ): + run_startup_tasks(db, mock_settings) + mock_setup_users.assert_called_once_with( + db, mock_settings.dev_users, mock_settings.dev_users_password + ) + mock_auto_seed.assert_called_once_with(db)