diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..55fc3b6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,107 @@ +name: CI + +on: + push: + branches: ["main", "main-*", "claude/*"] + pull_request: + branches: ["main", "main-*"] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install lint tools + run: pip install black==24.4.2 isort==5.13.2 + + - name: black (check) + run: black --check --diff src/ scripts/ tests/ + + - name: isort (check) + run: isort --check-only --diff src/ scripts/ tests/ + + test: + name: Tests (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run unit tests + run: | + pytest tests/unit/ -v --tb=short \ + --cov=src/f1_predictor \ + --cov-report=term-missing \ + --cov-report=xml:coverage.xml \ + -q + + - name: Run integration tests + run: | + pytest tests/integration/ -v --tb=short -q + + - name: Run system tests + run: | + pytest tests/system/ -v --tb=short -q + + - name: Upload coverage report + if: matrix.python-version == '3.11' + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.xml + retention-days: 7 + + api-smoke: + name: API smoke test + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install dependencies + run: | + python -m pip install 
--upgrade pip + pip install -r requirements.txt + + - name: Start API server in background + run: | + python -m uvicorn src.f1_predictor.api:app --host 127.0.0.1 --port 8000 & + sleep 3 + + - name: Health check + run: | + curl -f http://127.0.0.1:8000/health + + - name: Docs endpoint + run: | + curl -f http://127.0.0.1:8000/openapi.json | python -c "import sys,json; d=json.load(sys.stdin); print(d['info']['title'])" diff --git a/dashboard/app.py b/dashboard/app.py new file mode 100644 index 0000000..3ec4d0e --- /dev/null +++ b/dashboard/app.py @@ -0,0 +1,268 @@ +"""F1 Prediction Dashboard β€” Streamlit application. + +Run from the repository root: + streamlit run dashboard/app.py + +The dashboard calls the local FastAPI server. Start it separately with: + uvicorn src.f1_predictor.api:app --host 127.0.0.1 --port 8000 + +Or configure API_BASE_URL in the sidebar to point at a remote deployment. +""" + +from __future__ import annotations + +import json +from datetime import datetime +from typing import Optional + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import requests +import streamlit as st + +# --------------------------------------------------------------------------- +# Page config +# --------------------------------------------------------------------------- + +st.set_page_config( + page_title="F1 Prediction Dashboard", + page_icon="🏎", + layout="wide", + initial_sidebar_state="expanded", +) + +# --------------------------------------------------------------------------- +# Sidebar β€” configuration +# --------------------------------------------------------------------------- + +with st.sidebar: + st.title("βš™οΈ Settings") + api_base = st.text_input( + "API base URL", + value="http://127.0.0.1:8000", + help="Base URL of the running FastAPI server.", + ) + current_year = datetime.now().year + year = st.number_input( + "Season year", min_value=2018, max_value=2030, value=current_year, step=1 + ) + n_sims = st.slider( 
def _get(path: str) -> Optional[dict]:
    """GET *path* from the configured API base.

    Shows a Streamlit error banner and returns ``None`` on any failure so
    callers can simply truth-test the result.
    """
    try:
        r = requests.get(f"{api_base}{path}", timeout=30)
        r.raise_for_status()
        return r.json()
    except requests.ConnectionError:
        st.error(f"Cannot connect to API at **{api_base}**. Is the server running?")
        return None
    except Exception as exc:
        st.error(f"API error: {exc}")
        return None


def _post(path: str, payload: dict) -> Optional[dict]:
    """POST *payload* as JSON to *path*; return the decoded body or ``None``.

    Bug fix: the previous generic handler formatted ``r.status_code`` and
    ``r.text``, but ``r`` is unbound whenever the request call itself raises
    (e.g. a timeout), so the error report crashed with a NameError.  HTTP
    errors now read the response off the ``HTTPError``; anything else falls
    back to the exception message.
    """
    try:
        r = requests.post(f"{api_base}{path}", json=payload, timeout=120)
        r.raise_for_status()
        return r.json()
    except requests.ConnectionError:
        st.error(f"Cannot connect to API at **{api_base}**. Is the server running?")
        return None
    except requests.HTTPError as exc:
        # raise_for_status() attaches the response to the exception.
        resp = exc.response
        st.error(f"API error ({resp.status_code}): {resp.text[:300]}")
        return None
    except Exception as exc:
        st.error(f"API error: {exc}")
        return None


def _bar_predictions(df: pd.DataFrame, pos_col: str, title: str) -> go.Figure:
    """Horizontal bar chart of predicted positions (lower = better)."""
    df = df.sort_values(pos_col)
    fig = px.bar(
        df,
        x=pos_col,
        y="Driver",
        orientation="h",
        color="Team",
        title=title,
        labels={pos_col: "Predicted position", "Driver": ""},
        height=max(400, len(df) * 28),
    )
    fig.update_layout(yaxis={"categoryorder": "total ascending"}, showlegend=True)
    return fig


def _podium_bar(sim_df: pd.DataFrame) -> go.Figure:
    """Grouped bar chart: win / podium / top-10 % for the top 20 drivers by win %."""
    df = sim_df.sort_values("Win_Pct", ascending=False).head(20).copy()
    fig = go.Figure()
    for col, label, colour in [
        ("Win_Pct", "Win %", "#FFD700"),
        ("Podium_Pct", "Podium %", "#C0C0C0"),
        ("Top10_Pct", "Top-10 %", "#CD7F32"),
    ]:
        # Columns are optional: older simulation payloads may omit some.
        if col in df.columns:
            fig.add_trace(
                go.Bar(
                    name=label,
                    x=df["Driver"],
                    y=(df[col] * 100).round(1),
                    marker_color=colour,
                )
            )
    fig.update_layout(
        barmode="group",
        title="Win / Podium / Top-10 probability (%)",
        yaxis_title="Probability (%)",
        xaxis_title="",
        height=450,
    )
    return fig
def _position_heatmap(pos_matrix_data: dict, drivers: list[str]) -> go.Figure:
    """Heatmap of finishing-position distributions.

    *pos_matrix_data* is expected to be the API's
    ``DataFrame.to_dict(orient="split")`` payload
    (``{"index": [...], "columns": [...], "data": [...]}``); *drivers* is a
    fallback row index when the payload omits ``index``.

    Bug fix: the old code first called ``DataFrame.from_dict`` with
    ``orient="tight"`` whenever an ``index`` key was present, but a "split"
    payload lacks the ``index_names``/``column_names`` keys that the tight
    format requires, so that call raised before the ``"data"`` fallback
    branch could run — and the ``orient="dict"`` else-arm is not a valid
    orient at all.
    """
    if "data" in pos_matrix_data:
        # "split"-style payload from the API.
        df = pd.DataFrame(
            pos_matrix_data["data"],
            index=pos_matrix_data.get("index", drivers),
            columns=pos_matrix_data.get("columns", list(range(1, 21))),
        )
    else:
        # Plain {column: {row: value}} mapping.
        df = pd.DataFrame.from_dict(pos_matrix_data)
    # Order rows by expected (probability-weighted mean) finishing position,
    # best first.  NOTE: this is a mean, not a median, despite what the old
    # comment said.
    expected_pos = (df * df.columns.astype(float)).sum(axis=1)
    df = df.loc[expected_pos.sort_values().index]
    fig = px.imshow(
        df * 100,
        labels={"x": "Finishing position", "y": "Driver", "color": "Probability (%)"},
        title="Finishing-position distribution (% of simulations)",
        color_continuous_scale="Blues",
        aspect="auto",
        height=max(400, len(df) * 28),
    )
    return fig
Start the server and refresh this page.") + +# --------------------------------------------------------------------------- +# Race selector +# --------------------------------------------------------------------------- + +sched_data = _get(f"/schedule/{year}") +race_names: list[str] = [] +if sched_data and sched_data.get("schedule"): + race_names = [r["EventName"] for r in sched_data["schedule"]] + +race = st.selectbox( + "Select race", + options=race_names or ["(no schedule loaded)"], + help="Races pulled from FastF1 via the API.", +) + +# --------------------------------------------------------------------------- +# Tabs +# --------------------------------------------------------------------------- + +tab_race, tab_quali, tab_sim = st.tabs(["🏁 Race prediction", "⏱ Qualifying prediction", "🎲 Simulation"]) + +# ── Race prediction ────────────────────────────────────────────────────────── +with tab_race: + mode = st.selectbox( + "Prediction mode", + ["auto", "pre_weekend", "pre_quali", "post_quali"], + index=0, + help="'auto' lets the model decide based on available data.", + ) + if st.button("Predict race", key="btn_race", disabled=not race_names): + with st.spinner("Running race prediction…"): + data = _post("/predict/race", {"year": int(year), "race": race, "mode": mode}) + if data and data.get("predictions"): + df = pd.DataFrame(data["predictions"]) + st.dataframe(df, use_container_width=True) + pos_col = next( + (c for c in ["Predicted_Race_Pos", "Predicted_Pos", "Position"] if c in df.columns), + df.columns[0], + ) + st.plotly_chart(_bar_predictions(df, pos_col, f"{year} {race} β€” Race prediction"), use_container_width=True) + else: + st.info("No predictions available. 
Ensure models are trained (`python scripts/predict.py train`).") + +# ── Qualifying prediction ──────────────────────────────────────────────────── +with tab_quali: + if st.button("Predict qualifying", key="btn_quali", disabled=not race_names): + with st.spinner("Running qualifying prediction…"): + data = _post("/predict/qualifying", {"year": int(year), "race": race}) + if data and data.get("predictions"): + df = pd.DataFrame(data["predictions"]) + st.dataframe(df, use_container_width=True) + pos_col = next( + (c for c in ["Predicted_Quali_Pos", "Predicted_Pos", "Quali_Pos"] if c in df.columns), + df.columns[0], + ) + st.plotly_chart( + _bar_predictions(df, pos_col, f"{year} {race} β€” Qualifying prediction"), + use_container_width=True, + ) + else: + st.info("No qualifying predictions available. Ensure qualifying model is trained.") + +# ── Simulation ─────────────────────────────────────────────────────────────── +with tab_sim: + st.markdown( + f"Run **{n_sims:,}** Monte Carlo simulations with a **{sc_prob:.0%}** safety-car probability." + ) + if st.button("Run simulation", key="btn_sim", disabled=not race_names): + with st.spinner(f"Simulating {n_sims} races…"): + data = _post( + "/simulate", + { + "year": int(year), + "race": race, + "n_simulations": n_sims, + "sc_probability": sc_prob, + }, + ) + if data and data.get("summary"): + sim_df = pd.DataFrame(data["summary"]) + st.subheader("Summary") + st.dataframe(sim_df, use_container_width=True) + + col1, col2 = st.columns(2) + with col1: + st.plotly_chart(_podium_bar(sim_df), use_container_width=True) + with col2: + if data.get("position_matrix"): + drivers = sim_df["Driver"].tolist() if "Driver" in sim_df.columns else [] + try: + st.plotly_chart( + _position_heatmap(data["position_matrix"], drivers), + use_container_width=True, + ) + except Exception: + st.info("Position matrix chart unavailable.") + else: + st.info("Simulation returned no results. 
Ensure models are trained.") diff --git a/requirements.txt b/requirements.txt index 4f46bf0..ca8abe6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,13 @@ matplotlib==3.8.4 seaborn==0.13.2 plotly==5.22.0 +# REST API +fastapi==0.115.0 +uvicorn[standard]==0.30.6 + +# Dashboard +streamlit==1.36.0 + # Configuration Management pyyaml==6.0.2 diff --git a/src/f1_predictor/api.py b/src/f1_predictor/api.py new file mode 100644 index 0000000..99e6064 --- /dev/null +++ b/src/f1_predictor/api.py @@ -0,0 +1,249 @@ +"""FastAPI REST layer for the F1 prediction system. + +Start with: + uvicorn src.f1_predictor.api:app --reload + +Endpoints +--------- +GET /health β†’ liveness check +GET /schedule/{year} β†’ race schedule for a season +POST /predict/qualifying β†’ qualifying-order predictions +POST /predict/race β†’ race-order predictions +POST /simulate β†’ Monte Carlo race simulation +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from fastapi import FastAPI, HTTPException, Query +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Application +# --------------------------------------------------------------------------- + +app = FastAPI( + title="F1 Prediction API", + description="REST interface for the F1 race and qualifying prediction system.", + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", +) + +# --------------------------------------------------------------------------- +# Request / Response models +# --------------------------------------------------------------------------- + + +class PredictRequest(BaseModel): + year: int = Field(..., ge=2018, le=2030, description="Season year") + race: str = Field(..., description="Official FastF1 EventName, e.g. 
'Italian Grand Prix'") + mode: Optional[str] = Field( + "auto", + description="Prediction mode: auto | pre_weekend | pre_quali | post_quali | live", + ) + + +class SimulateRequest(BaseModel): + year: int = Field(..., ge=2018, le=2030) + race: str = Field(..., description="Official FastF1 EventName") + mode: Optional[str] = Field("auto") + n_simulations: int = Field(2000, ge=100, le=20000) + sc_probability: float = Field(0.3, ge=0.0, le=1.0) + seed: Optional[int] = None + + +class HealthResponse(BaseModel): + status: str + timestamp: str + version: str + + +# --------------------------------------------------------------------------- +# Lazy singletons +# --------------------------------------------------------------------------- + +_predictor: Any = None + + +def _get_predictor() -> Any: + """Lazily initialise F1Predictor (expensive due to model loading).""" + global _predictor + if _predictor is None: + from .prediction import F1Predictor # noqa: PLC0415 + + _predictor = F1Predictor() + return _predictor + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _df_to_records(df: Any) -> List[Dict]: + """Convert a DataFrame to a list of dicts, handling NaN/NaT safely.""" + import math + + records = df.to_dict(orient="records") + clean = [] + for row in records: + clean.append( + { + k: (None if isinstance(v, float) and math.isnan(v) else v) + for k, v in row.items() + } + ) + return clean + + +def _get_schedule(year: int) -> List[Dict]: + try: + import fastf1 # noqa: PLC0415 + + sched = fastf1.get_event_schedule(year, include_testing=False) + sched = sched.sort_values("RoundNumber") + cols = ["RoundNumber", "EventName", "EventDate", "Country", "Location"] + existing = [c for c in cols if c in sched.columns] + out = sched[existing].copy() + out["EventDate"] = out["EventDate"].astype(str) + return out.to_dict(orient="records") + except 
@app.get("/health", response_model=HealthResponse, tags=["Meta"])
def health() -> HealthResponse:
    """Liveness / readiness probe.

    Returns service status, a UTC timestamp and the API version.
    """
    from datetime import timezone  # noqa: PLC0415  — local to avoid touching module imports

    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # timestamp; use an aware UTC time and keep the same "...Z" wire format.
    now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return HealthResponse(
        status="ok",
        timestamp=now,
        version=app.version,
    )


@app.get("/schedule/{year}", tags=["Schedule"])
def schedule(year: int) -> JSONResponse:
    """Return the official race schedule for *year*.

    Uses FastF1's event schedule; results are **not** cached across requests.

    Raises:
        HTTPException(400): if *year* is outside the supported 2018-2030 range.
    """
    if not 2018 <= year <= 2030:
        raise HTTPException(status_code=400, detail="year must be between 2018 and 2030")
    data = _get_schedule(year)
    return JSONResponse({"year": year, "rounds": len(data), "schedule": data})
+ ), + ) + return JSONResponse( + { + "year": req.year, + "race": req.race, + "session": "qualifying", + "predictions": _df_to_records(result), + } + ) + + +@app.post("/predict/race", tags=["Predictions"]) +def predict_race(req: PredictRequest) -> JSONResponse: + """Predict race finishing order for the requested event. + + Returns a ranked list of drivers with predicted race positions. + Requires trained race model artifacts. + """ + predictor = _get_predictor() + try: + result = predictor.predict_race(req.year, req.race, mode=req.mode or "auto") + except Exception as exc: + logger.exception("predict_race failed") + raise HTTPException(status_code=500, detail=str(exc)) from exc + + if result is None or result.empty: + raise HTTPException( + status_code=422, + detail=( + "No predictions returned β€” check that models are trained and " + "the race name matches the FastF1 EventName exactly." + ), + ) + return JSONResponse( + { + "year": req.year, + "race": req.race, + "session": "race", + "mode": req.mode, + "predictions": _df_to_records(result), + } + ) + + +@app.post("/simulate", tags=["Simulation"]) +def simulate(req: SimulateRequest) -> JSONResponse: + """Run a Monte Carlo race simulation. + + Runs *n_simulations* stochastic races and returns per-driver win / podium / + top-10 probabilities and expected points, plus the full position matrix. 
+ """ + predictor = _get_predictor() + try: + result = predictor.simulate( + req.year, + req.race, + mode=req.mode or "auto", + n_simulations=req.n_simulations, + sc_probability=req.sc_probability, + seed=req.seed, + ) + except Exception as exc: + logger.exception("simulate failed") + raise HTTPException(status_code=500, detail=str(exc)) from exc + + if result is None: + raise HTTPException( + status_code=422, + detail="Simulation returned no results β€” verify models are trained and event name is valid.", + ) + + return JSONResponse( + { + "year": req.year, + "race": req.race, + "n_simulations": result.n_simulations, + "seed": result.seed, + "summary": _df_to_records(result.summary), + "position_matrix": result.position_matrix.to_dict(orient="split"), + } + ) diff --git a/tests/integration/test_pipeline.py b/tests/integration/test_pipeline.py index 449ca6a..b6822dc 100644 --- a/tests/integration/test_pipeline.py +++ b/tests/integration/test_pipeline.py @@ -1,5 +1,7 @@ """Integration tests: full feature engineering pipeline + store round-trip.""" +import os +import pathlib import numpy as np import pandas as pd import pytest @@ -7,6 +9,9 @@ from src.f1_predictor.feature_engineering_pipeline import FeatureEngineeringPipeline from src.f1_predictor.store import F1Store +# Repository root β€” used for CLI smoke tests +_REPO_ROOT = str(pathlib.Path(__file__).resolve().parents[2]) + # --------------------------------------------------------------------------- # Feature Engineering Pipeline β€” integration @@ -137,7 +142,7 @@ def test_predict_help(self): result = subprocess.run( [sys.executable, "scripts/predict.py", "--help"], capture_output=True, text=True, - cwd="E:\\f1\\f1_prediction_project\\.claude\\worktrees\\naughty-easley", + cwd=_REPO_ROOT, ) assert result.returncode == 0 assert "F1 Prediction System" in result.stdout @@ -150,7 +155,7 @@ def test_fetch_data_help(self): result = subprocess.run( [sys.executable, "scripts/predict.py", "fetch-data", "--help"], 
capture_output=True, text=True, - cwd="E:\\f1\\f1_prediction_project\\.claude\\worktrees\\naughty-easley", + cwd=_REPO_ROOT, ) assert result.returncode == 0 assert "--force" in result.stdout @@ -160,6 +165,6 @@ def test_train_help(self): result = subprocess.run( [sys.executable, "scripts/predict.py", "train", "--help"], capture_output=True, text=True, - cwd="E:\\f1\\f1_prediction_project\\.claude\\worktrees\\naughty-easley", + cwd=_REPO_ROOT, ) assert result.returncode == 0 diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py new file mode 100644 index 0000000..92a565a --- /dev/null +++ b/tests/unit/test_api.py @@ -0,0 +1,260 @@ +"""Unit tests for the FastAPI layer (no trained models required). + +Uses FastAPI's TestClient so the tests run entirely in-process without +needing a live server or real model artefacts. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_pred_df(session: str = "race") -> pd.DataFrame: + drivers = ["VER", "HAM", "LEC", "NOR", "ALO"] + if session == "qualifying": + return pd.DataFrame( + { + "Driver": drivers, + "Team": ["Red Bull", "Mercedes", "Ferrari", "McLaren", "Aston"], + "Predicted_Quali_Pos": [1.0, 2.0, 3.0, 4.0, 5.0], + } + ) + return pd.DataFrame( + { + "Driver": drivers, + "Team": ["Red Bull", "Mercedes", "Ferrari", "McLaren", "Aston"], + "Predicted_Race_Pos": [1.0, 2.0, 3.0, 4.0, 5.0], + } + ) + + +def _make_sim_result() -> MagicMock: + from src.f1_predictor.simulation import SimulationResult # noqa: PLC0415 + + summary = pd.DataFrame( + { + "Driver": ["VER", "HAM"], + "Team": ["Red Bull", "Mercedes"], + "Win_Pct": [0.7, 0.3], + "Podium_Pct": [0.9, 0.6], + "Top10_Pct": [1.0, 1.0], + "Exp_Points": [22.0, 15.0], + } + ) + pos_matrix = pd.DataFrame( + [[0.7, 0.3], 
[0.3, 0.7]], index=["VER", "HAM"], columns=[1, 2] + ) + return SimulationResult( + summary=summary, + position_matrix=pos_matrix, + n_simulations=200, + seed=42, + ) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def client(): + """TestClient with F1Predictor mocked out.""" + try: + from fastapi.testclient import TestClient + except ImportError: + pytest.skip("fastapi not installed") + + from src.f1_predictor import api as api_module # noqa: PLC0415 + + # Reset the cached singleton so mock is used fresh + api_module._predictor = None + + mock_predictor = MagicMock() + mock_predictor.predict_qualifying.return_value = _make_pred_df("qualifying") + mock_predictor.predict_race.return_value = _make_pred_df("race") + mock_predictor.simulate.return_value = _make_sim_result() + + with patch.object(api_module, "_get_predictor", return_value=mock_predictor): + yield TestClient(api_module.app) + + # Clean up singleton after test + api_module._predictor = None + + +# --------------------------------------------------------------------------- +# Tests β€” /health +# --------------------------------------------------------------------------- + +class TestHealth: + def test_health_ok(self, client): + resp = client.get("/health") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert "timestamp" in body + assert "version" in body + + +# --------------------------------------------------------------------------- +# Tests β€” /schedule/{year} +# --------------------------------------------------------------------------- + +class TestSchedule: + def test_invalid_year_low(self, client): + resp = client.get("/schedule/2010") + assert resp.status_code == 400 + + def test_invalid_year_high(self, client): + resp = client.get("/schedule/2050") + assert resp.status_code == 400 + + def 
test_schedule_fastf1_error_raises_502(self): + """If FastF1 raises inside the endpoint, the server returns 500.""" + try: + from fastapi.testclient import TestClient + except ImportError: + pytest.skip("fastapi not installed") + + from src.f1_predictor import api as api_module # noqa: PLC0415 + + # raise_server_exceptions=False so TestClient returns a 500 response + # rather than re-raising the exception in the test process. + no_raise_client = TestClient(api_module.app, raise_server_exceptions=False) + with patch.object(api_module, "_get_schedule", side_effect=Exception("network")): + resp = no_raise_client.get("/schedule/2024") + assert resp.status_code in (500, 502) + + +# --------------------------------------------------------------------------- +# Tests β€” /predict/qualifying +# --------------------------------------------------------------------------- + +class TestPredictQualifying: + def test_success(self, client): + payload = {"year": 2024, "race": "Italian Grand Prix"} + resp = client.post("/predict/qualifying", json=payload) + assert resp.status_code == 200 + body = resp.json() + assert body["session"] == "qualifying" + assert body["year"] == 2024 + assert body["race"] == "Italian Grand Prix" + assert len(body["predictions"]) == 5 + assert "Driver" in body["predictions"][0] + + def test_no_predictions_returns_422(self, client): + try: + from fastapi.testclient import TestClient + except ImportError: + pytest.skip("fastapi not installed") + + from src.f1_predictor import api as api_module # noqa: PLC0415 + + empty_mock = MagicMock() + empty_mock.predict_qualifying.return_value = None + with patch.object(api_module, "_get_predictor", return_value=empty_mock): + resp = client.post( + "/predict/qualifying", json={"year": 2024, "race": "Italian Grand Prix"} + ) + assert resp.status_code == 422 + + def test_missing_year_returns_422(self, client): + resp = client.post("/predict/qualifying", json={"race": "Italian Grand Prix"}) + assert resp.status_code == 422 
+ + def test_missing_race_returns_422(self, client): + resp = client.post("/predict/qualifying", json={"year": 2024}) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# Tests β€” /predict/race +# --------------------------------------------------------------------------- + +class TestPredictRace: + def test_success(self, client): + payload = {"year": 2024, "race": "Italian Grand Prix", "mode": "auto"} + resp = client.post("/predict/race", json=payload) + assert resp.status_code == 200 + body = resp.json() + assert body["session"] == "race" + assert len(body["predictions"]) == 5 + + def test_default_mode(self, client): + payload = {"year": 2024, "race": "Italian Grand Prix"} + resp = client.post("/predict/race", json=payload) + assert resp.status_code == 200 + + def test_no_predictions_returns_422(self, client): + try: + from fastapi.testclient import TestClient + except ImportError: + pytest.skip("fastapi not installed") + + from src.f1_predictor import api as api_module # noqa: PLC0415 + + empty_mock = MagicMock() + empty_mock.predict_race.return_value = pd.DataFrame() + with patch.object(api_module, "_get_predictor", return_value=empty_mock): + resp = client.post( + "/predict/race", json={"year": 2024, "race": "Italian Grand Prix"} + ) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# Tests β€” /simulate +# --------------------------------------------------------------------------- + +class TestSimulate: + def test_success(self, client): + payload = {"year": 2024, "race": "Italian Grand Prix", "n_simulations": 200} + resp = client.post("/simulate", json=payload) + assert resp.status_code == 200 + body = resp.json() + assert body["n_simulations"] == 200 + assert "summary" in body + assert "position_matrix" in body + assert len(body["summary"]) == 2 + + def test_no_result_returns_422(self, client): + try: + from fastapi.testclient 
import TestClient + except ImportError: + pytest.skip("fastapi not installed") + + from src.f1_predictor import api as api_module # noqa: PLC0415 + + none_mock = MagicMock() + none_mock.simulate.return_value = None + with patch.object(api_module, "_get_predictor", return_value=none_mock): + resp = client.post( + "/simulate", json={"year": 2024, "race": "Italian Grand Prix"} + ) + assert resp.status_code == 422 + + def test_n_simulations_too_low(self, client): + payload = {"year": 2024, "race": "Italian Grand Prix", "n_simulations": 50} + resp = client.post("/simulate", json=payload) + assert resp.status_code == 422 + + def test_sc_probability_out_of_range(self, client): + payload = { + "year": 2024, + "race": "Italian Grand Prix", + "sc_probability": 1.5, + } + resp = client.post("/simulate", json=payload) + assert resp.status_code == 422 + + def test_openapi_json_served(self, client): + resp = client.get("/openapi.json") + assert resp.status_code == 200 + data = resp.json() + assert data["info"]["title"] == "F1 Prediction API"