From 1a507f8d16432a81bb9593850583ad1f47143460 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Tue, 17 Mar 2026 14:06:44 +0530 Subject: [PATCH 1/6] feat: add GET/POST /task/list endpoint (#23) --- src/routers/openml/tasks.py | 241 +++++++++++++++++++++++++++++- tests/routers/openml/task_test.py | 85 +++++++++++ 2 files changed, 322 insertions(+), 4 deletions(-) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 788cd804..039e272e 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -1,17 +1,19 @@ import json import re -from typing import Annotated, cast +from enum import StrEnum +from typing import Annotated, Any, cast import xmltodict -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Body, Depends from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncConnection import config import database.datasets import database.tasks -from core.errors import InternalError, TaskNotFoundError -from routers.dependencies import expdb_connection +from core.errors import InternalError, NoResultsError, TaskNotFoundError +from routers.dependencies import Pagination, expdb_connection +from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex from schemas.datasets.openml import Task router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -157,6 +159,237 @@ async def _fill_json_template( # noqa: C901 return template.replace("[CONSTANT:base_url]", server_url) +class TaskStatusFilter(StrEnum): + """Valid values for the status filter.""" + + ACTIVE = "active" + DEACTIVATED = "deactivated" + IN_PREPARATION = "in_preparation" + ALL = "all" + + +QUALITIES_TO_SHOW = [ + "MajorityClassSize", + "MaxNominalAttDistinctValues", + "MinorityClassSize", + "NumberOfClasses", + "NumberOfFeatures", + "NumberOfInstances", + "NumberOfInstancesWithMissingValues", + "NumberOfMissingValues", + "NumberOfNumericFeatures", + "NumberOfSymbolicFeatures", +] + +BASIC_TASK_INPUTS = [ + "source_data", + "target_feature", + "estimation_procedure", + "evaluation_measures", +] + + +def _quality_clause(quality: str, range_: str | None) -> str: + """Return a SQL WHERE clause fragment filtering tasks by a dataset quality range. + + Looks up tasks whose source dataset has the given quality within the range. + Range can be exact ('100') or a range ('50..200'). + """ + if not range_: + return "" + if not (match := re.match(integer_range_regex, range_)): + msg = f"`range_` not a valid range: {range_}" + raise ValueError(msg) + start, end = match.groups() + # end group looks like "..200", strip the ".." prefix to get just the number + value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" + # nested subquery: find datasets with matching quality, then find tasks using those datasets + return f""" + AND t.`task_id` IN ( + SELECT ti.`task_id` FROM task_inputs ti + WHERE ti.`input`='source_data' AND ti.`value` IN ( + SELECT `data` FROM data_quality + WHERE `quality`='{quality}' AND {value} + ) + ) + """ # noqa: S608 + + +@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") +@router.get(path="/list") +async def list_tasks( # noqa: PLR0913 + pagination: Annotated[Pagination, Body(default_factory=Pagination)], + task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None, + tag: Annotated[str | None, SystemString64] = None, + data_tag: Annotated[str | None, SystemString64] = None, + status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE, + task_id: Annotated[list[int] | None, Body(description="Filter by task id(s).")] = None, + data_id: Annotated[list[int] | None, Body(description="Filter by dataset id(s).")] = None, + data_name: Annotated[str | None, CasualString128] = None, + number_instances: Annotated[str | None, IntegerRange] = None, + number_features: Annotated[str | None, IntegerRange] = None, + number_classes: Annotated[str | None, IntegerRange] = None, + number_missing_values: Annotated[str | None, IntegerRange] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, +) -> list[dict[str, Any]]: + """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" + assert expdb is not None # noqa: S101 + + # --- WHERE clauses --- + if status == TaskStatusFilter.ALL: + where_status = "" + else: + where_status = f"AND IFNULL(ds.`status`, 'in_preparation') = '{status}'" + + where_type = "" if task_type_id is None else "AND t.`ttid` = :task_type_id" + where_tag = ( + "" if tag is None else "AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)" + ) + where_data_tag = ( + "" + if data_tag is None + else "AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)" + ) + task_id_str = ",".join(str(tid) for tid in task_id) if task_id else "" + where_task_id = "" if not task_id else f"AND t.`task_id` IN ({task_id_str})" + data_id_str = ",".join(str(did) for did in data_id) if data_id else "" + where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})" + where_data_name = "" if data_name is None else "AND d.`name` = :data_name" + + where_number_instances = _quality_clause("NumberOfInstances", number_instances) + where_number_features = _quality_clause("NumberOfFeatures", number_features) + where_number_classes = _quality_clause("NumberOfClasses", number_classes) + where_number_missing_values = _quality_clause("NumberOfMissingValues", number_missing_values) + + basic_inputs_str = ", ".join(f"'{i}'" for i in BASIC_TASK_INPUTS) + + # subquery to get the latest status per dataset + # dataset_status has multiple rows per dataset (history), we want only the most recent + status_subquery = """ + SELECT ds1.did, ds1.status + FROM dataset_status ds1 + WHERE ds1.status_date = ( + SELECT MAX(ds2.status_date) FROM dataset_status ds2 + WHERE ds1.did = ds2.did + ) + """ + + query = text( + f""" + SELECT + t.`task_id`, + t.`ttid` AS task_type_id, + tt.`name` AS task_type, + d.`did`, + d.`name`, + d.`format`, + IFNULL(ds.`status`, 'in_preparation') AS status + FROM task t + JOIN task_type tt ON tt.`ttid` = t.`ttid` + JOIN task_inputs ti_source ON ti_source.`task_id` = t.`task_id` + AND ti_source.`input` = 'source_data' + JOIN dataset d ON d.`did` = ti_source.`value` + LEFT JOIN ({status_subquery}) ds ON ds.`did` = d.`did` + WHERE 1=1 + {where_status} + {where_type} + {where_tag} + {where_data_tag} + {where_task_id} + {where_data_id} + {where_data_name} + {where_number_instances} + {where_number_features} + {where_number_classes} + {where_number_missing_values} + GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status` + LIMIT {pagination.limit} OFFSET {pagination.offset} + """, # noqa: S608 + ) + + result = await expdb.execute( + query, + parameters={ + "task_type_id": task_type_id, + "tag": tag, + "data_tag": data_tag, + "data_name": data_name, + }, + ) + rows = result.mappings().all() + + if not rows: + msg = "No tasks match the search criteria." + raise NoResultsError(msg) + + columns = ["task_id", "task_type_id", "task_type", "did", "name", "format", "status"] + tasks: dict[int, dict[str, Any]] = { + row["task_id"]: {col: row[col] for col in columns} for row in rows + } + + # fetch inputs for all tasks in one query + task_ids_str = ",".join(str(tid) for tid in tasks) + inputs_result = await expdb.execute( + text( + f""" + SELECT `task_id`, `input`, `value` + FROM task_inputs + WHERE `task_id` IN ({task_ids_str}) + AND `input` IN ({basic_inputs_str}) + """, # noqa: S608 + ), + ) + for row in inputs_result.all(): + tasks[row.task_id].setdefault("input", []).append( + {"name": row.input, "value": row.value}, + ) + + # fetch qualities for all datasets in one query + did_list = ",".join(str(t["did"]) for t in tasks.values()) + qualities_str = ", ".join(f"'{q}'" for q in QUALITIES_TO_SHOW) + qualities_result = await expdb.execute( + text( + f""" + SELECT `data`, `quality`, `value` + FROM data_quality + WHERE `data` IN ({did_list}) + AND `quality` IN ({qualities_str}) + """, # noqa: S608 + ), + ) + # build a reverse map: dataset_id -> task_id + # needed because quality rows come back keyed by did, but our tasks dict is keyed by task_id + did_to_task_id = {t["did"]: tid for tid, t in tasks.items()} + for row in qualities_result.all(): + tid = did_to_task_id.get(row.data) + if tid is not None: + tasks[tid].setdefault("quality", []).append( + {"name": row.quality, "value": str(row.value)}, + ) + + # fetch tags for all tasks in one query + tags_result = await expdb.execute( + text( + f""" + SELECT `id`, `tag` + FROM task_tag + WHERE `id` IN ({task_ids_str}) + """, # noqa: S608 + ), + ) + for row in tags_result.all(): + tasks[row.id].setdefault("tag", []).append(row.tag) + + # ensure every task has all expected keys(input/quality/tag) even if no rows were found for them + # e.g. a task with no tags should return "tag": [] not missing key + for task in tasks.values(): + task.setdefault("input", []) + task.setdefault("quality", []) + task.setdefault("tag", []) + + return list(tasks.values()) + + @router.get("/{task_id}") async def get_task( task_id: int, diff --git a/tests/routers/openml/task_test.py b/tests/routers/openml/task_test.py index e78bba8b..3f2b12d1 100644 --- a/tests/routers/openml/task_test.py +++ b/tests/routers/openml/task_test.py @@ -4,6 +4,91 @@ import httpx +async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: + """Default call returns active tasks with correct shape.""" + response = await py_api.post("/tasks/list", json={}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert isinstance(tasks, list) + assert len(tasks) > 0 + # verify shape of first task + task = tasks[0] + assert "task_id" in task + assert "task_type_id" in task + assert "task_type" in task + assert "did" in task + assert "name" in task + assert "format" in task + assert "status" in task + assert "input" in task + assert "quality" in task + assert "tag" in task + + +async def test_list_tasks_filter_type(py_api: httpx.AsyncClient) -> None: + """Filter by task_type_id returns only tasks of that type.""" + response = await py_api.post("/tasks/list", json={"task_type_id": 1}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert all(t["task_type_id"] == 1 for t in tasks) + + +async def test_list_tasks_filter_tag(py_api: httpx.AsyncClient) -> None: + """Filter by tag returns only tasks with that tag.""" + response = await py_api.post("/tasks/list", json={"tag": "OpenML100"}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all("OpenML100" in t["tag"] for t in tasks) + + +async def test_list_tasks_pagination(py_api: httpx.AsyncClient) -> None: + """Pagination returns correct number of results.""" + limit = 5 + response = await py_api.post( + "/tasks/list", + json={"pagination": {"limit": limit, "offset": 0}}, + ) + assert response.status_code == HTTPStatus.OK + assert len(response.json()) == limit + + +async def test_list_tasks_pagination_offset(py_api: httpx.AsyncClient) -> None: + """Offset returns different results than no offset.""" + r1 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) + r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}}) + ids1 = [t["task_id"] for t in r1.json()] + ids2 = [t["task_id"] for t in r2.json()] + assert ids1 != ids2 + + +async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None: + """number_instances range filter returns tasks whose dataset matches.""" + response = await py_api.post( + "/tasks/list", + json={"number_instances": "100..1000"}, + ) + assert response.status_code == HTTPStatus.OK + assert len(response.json()) > 0 + + +async def test_list_tasks_no_results(py_api: httpx.AsyncClient) -> None: + """Nonexistent tag returns 404 NoResultsError.""" + response = await py_api.post("/tasks/list", json={"tag": "nonexistent_tag_xyz"}) + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.headers["content-type"] == "application/problem+json" + error = response.json() + assert error["status"] == HTTPStatus.NOT_FOUND + assert "372" in error["code"] + + +async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: + """GET /tasks/list with no body also works.""" + response = await py_api.get("/tasks/list") + assert response.status_code == HTTPStatus.OK + assert isinstance(response.json(), list) + + async def test_get_task(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/tasks/59") assert response.status_code == HTTPStatus.OK From 73817eef6ecf0139e44ab08ba5355c7545f41c90 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Tue, 17 Mar 2026 18:16:40 +0530 Subject: [PATCH 2/6] fix: address review comments on task list endpoint --- src/routers/openml/tasks.py | 8 +++++--- tests/routers/openml/task_test.py | 11 +++++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 039e272e..2d2772c5 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -303,6 +303,7 @@ async def list_tasks( # noqa: PLR0913 {where_number_classes} {where_number_missing_values} GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status` + ORDER BY t.`task_id` LIMIT {pagination.limit} OFFSET {pagination.offset} """, # noqa: S608 ) @@ -359,10 +360,11 @@ async def list_tasks( # noqa: PLR0913 ) # build a reverse map: dataset_id -> task_id # needed because quality rows come back keyed by did, but our tasks dict is keyed by task_id - did_to_task_id = {t["did"]: tid for tid, t in tasks.items()} + did_to_task_ids: dict[int, list[int]] = {} + for tid, t in tasks.items(): + did_to_task_ids.setdefault(t["did"], []).append(tid) for row in qualities_result.all(): - tid = did_to_task_id.get(row.data) - if tid is not None: + for tid in did_to_task_ids.get(row.data, []): tasks[tid].setdefault("quality", []).append( {"name": row.quality, "value": str(row.value)}, ) diff --git a/tests/routers/openml/task_test.py b/tests/routers/openml/task_test.py index 3f2b12d1..332c5f19 100644 --- a/tests/routers/openml/task_test.py +++ b/tests/routers/openml/task_test.py @@ -11,6 +11,7 @@ async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: tasks = response.json() assert isinstance(tasks, list) assert len(tasks) > 0 + assert all(task["status"] == "active" for task in tasks) # verify shape of first task task = tasks[0] assert "task_id" in task @@ -30,6 +31,7 @@ async def test_list_tasks_filter_type(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/tasks/list", json={"task_type_id": 1}) assert response.status_code == HTTPStatus.OK tasks = response.json() + assert len(tasks) > 0 assert all(t["task_type_id"] == 1 for t in tasks) @@ -64,12 +66,17 @@ async def test_list_tasks_pagination_offset(py_api: httpx.AsyncClient) -> None: async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None: """number_instances range filter returns tasks whose dataset matches.""" + min_instances, max_instances = 100, 1000 response = await py_api.post( "/tasks/list", - json={"number_instances": "100..1000"}, + json={"number_instances": f"{min_instances}..{max_instances}"}, ) assert response.status_code == HTTPStatus.OK - assert len(response.json()) > 0 + tasks = response.json() + assert len(tasks) > 0 + for task in tasks: + qualities = {q["name"]: q["value"] for q in task["quality"]} + assert min_instances <= float(qualities["NumberOfInstances"]) <= max_instances async def test_list_tasks_no_results(py_api: httpx.AsyncClient) -> None: From 4e9ee5d5df9b6d8674f767bd078d19f28d878c29 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Tue, 17 Mar 2026 18:44:47 +0530 Subject: [PATCH 3/6] fix: address bot review comments on tests of task list endpoint --- tests/routers/openml/task_test.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/routers/openml/task_test.py b/tests/routers/openml/task_test.py index 332c5f19..c2f0094b 100644 --- a/tests/routers/openml/task_test.py +++ b/tests/routers/openml/task_test.py @@ -86,7 +86,7 @@ async def test_list_tasks_no_results(py_api: httpx.AsyncClient) -> None: assert response.headers["content-type"] == "application/problem+json" error = response.json() assert error["status"] == HTTPStatus.NOT_FOUND - assert "372" in error["code"] + assert error["code"] == "372" async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: @@ -96,6 +96,12 @@ async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: assert isinstance(response.json(), list) +async def test_list_tasks_invalid_range_format(py_api: httpx.AsyncClient) -> None: + """Invalid number_instances range returns 422 validation error.""" + response = await py_api.post("/tasks/list", json={"number_instances": "1...2"}) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + async def test_get_task(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/tasks/59") assert response.status_code == HTTPStatus.OK From 4affe398a0094fa213fa99c2bd4b3a4d4d5ae79e Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Fri, 27 Mar 2026 19:54:59 +0530 Subject: [PATCH 4/6] refactor list_tasks and add migration tests --- src/routers/openml/tasks.py | 179 ++++++++-------- .../openml/migration/tasks_migration_test.py | 159 +++++++++++++++ tests/routers/openml/task_test.py | 193 +++++++++++++++--- 3 files changed, 422 insertions(+), 109 deletions(-) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 2d2772c5..34dde3d4 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -5,7 +5,8 @@ import xmltodict from fastapi import APIRouter, Body, Depends -from sqlalchemy import RowMapping, text +from sqlalchemy import bindparam, text +from sqlalchemy.engine import RowMapping from sqlalchemy.ext.asyncio import AsyncConnection import config @@ -217,7 +218,7 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") -async def list_tasks( # noqa: PLR0913 +async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 pagination: Annotated[Pagination, Body(default_factory=Pagination)], task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None, tag: Annotated[str | None, SystemString64] = None, @@ -235,36 +236,46 @@ async def list_tasks( # noqa: PLR0913 """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" assert expdb is not None # noqa: S101 - # --- WHERE clauses --- - if status == TaskStatusFilter.ALL: - where_status = "" - else: - where_status = f"AND IFNULL(ds.`status`, 'in_preparation') = '{status}'" + clauses: list[str] = [] + parameters: dict[str, Any] = { + "offset": max(0, pagination.offset), + "limit": max(0, pagination.limit), + } - where_type = "" if task_type_id is None else "AND t.`ttid` = :task_type_id" - where_tag = ( - "" if tag is None else "AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)" - ) - where_data_tag = ( - "" - if data_tag is None - else "AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)" - ) - task_id_str = ",".join(str(tid) for tid in task_id) if task_id else "" - where_task_id = "" if not task_id else f"AND t.`task_id` IN ({task_id_str})" - data_id_str = ",".join(str(did) for did in data_id) if data_id else "" - where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})" - where_data_name = "" if data_name is None else "AND d.`name` = :data_name" + if status != TaskStatusFilter.ALL: + clauses.append("AND IFNULL(ds.`status`, 'in_preparation') = :status") + parameters["status"] = status + + if task_type_id is not None: + clauses.append("AND t.`ttid` = :task_type_id") + parameters["task_type_id"] = task_type_id + + if tag is not None: + clauses.append("AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)") + parameters["tag"] = tag + + if data_tag is not None: + clauses.append("AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)") + parameters["data_tag"] = data_tag + + if data_name is not None: + clauses.append("AND d.`name` = :data_name") + parameters["data_name"] = data_name + + if task_id: + clauses.append("AND t.`task_id` IN :task_ids") + parameters["task_ids"] = task_id + + if data_id: + clauses.append("AND d.`did` IN :data_ids") + parameters["data_ids"] = data_id where_number_instances = _quality_clause("NumberOfInstances", number_instances) where_number_features = _quality_clause("NumberOfFeatures", number_features) where_number_classes = _quality_clause("NumberOfClasses", number_classes) where_number_missing_values = _quality_clause("NumberOfMissingValues", number_missing_values) - basic_inputs_str = ", ".join(f"'{i}'" for i in BASIC_TASK_INPUTS) - - # subquery to get the latest status per dataset - # dataset_status has multiple rows per dataset (history), we want only the most recent + # subquery to get the latest status per dataset (dataset_status is a history table) status_subquery = """ SELECT ds1.did, ds1.status FROM dataset_status ds1 @@ -274,49 +285,44 @@ async def list_tasks( # noqa: PLR0913 ) """ - query = text( + main_query = text( f""" SELECT t.`task_id`, - t.`ttid` AS task_type_id, - tt.`name` AS task_type, + t.`ttid` AS task_type_id, + tt.`name` AS task_type, d.`did`, d.`name`, d.`format`, IFNULL(ds.`status`, 'in_preparation') AS status FROM task t - JOIN task_type tt ON tt.`ttid` = t.`ttid` - JOIN task_inputs ti_source ON ti_source.`task_id` = t.`task_id` + JOIN task_type tt + ON tt.`ttid` = t.`ttid` + JOIN task_inputs ti_source + ON ti_source.`task_id` = t.`task_id` AND ti_source.`input` = 'source_data' - JOIN dataset d ON d.`did` = ti_source.`value` - LEFT JOIN ({status_subquery}) ds ON ds.`did` = d.`did` + JOIN dataset d + ON d.`did` = ti_source.`value` + LEFT JOIN ({status_subquery}) ds + ON ds.`did` = d.`did` WHERE 1=1 - {where_status} - {where_type} - {where_tag} - {where_data_tag} - {where_task_id} - {where_data_id} - {where_data_name} {where_number_instances} {where_number_features} {where_number_classes} {where_number_missing_values} + {" ".join(clauses)} GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status` ORDER BY t.`task_id` - LIMIT {pagination.limit} OFFSET {pagination.offset} + LIMIT :limit OFFSET :offset """, # noqa: S608 ) - result = await expdb.execute( - query, - parameters={ - "task_type_id": task_type_id, - "tag": tag, - "data_tag": data_tag, - "data_name": data_name, - }, - ) + if task_id: + main_query = main_query.bindparams(bindparam("task_ids", expanding=True)) + if data_id: + main_query = main_query.bindparams(bindparam("data_ids", expanding=True)) + + result = await expdb.execute(main_query, parameters=parameters) rows = result.mappings().all() if not rows: @@ -327,39 +333,45 @@ async def list_tasks( # noqa: PLR0913 tasks: dict[int, dict[str, Any]] = { row["task_id"]: {col: row[col] for col in columns} for row in rows } - - # fetch inputs for all tasks in one query - task_ids_str = ",".join(str(tid) for tid in tasks) + task_ids: list[int] = list(tasks.keys()) + dataset_ids: list[int] = list({t["did"] for t in tasks.values()}) + + inputs_query = text( + """ + SELECT `task_id`, `input`, `value` + FROM task_inputs + WHERE `task_id` IN :task_ids + AND `input` IN :basic_inputs + """, + ).bindparams( + bindparam("task_ids", expanding=True), + bindparam("basic_inputs", expanding=True), + ) inputs_result = await expdb.execute( - text( - f""" - SELECT `task_id`, `input`, `value` - FROM task_inputs - WHERE `task_id` IN ({task_ids_str}) - AND `input` IN ({basic_inputs_str}) - """, # noqa: S608 - ), + inputs_query, + parameters={"task_ids": task_ids, "basic_inputs": BASIC_TASK_INPUTS}, ) for row in inputs_result.all(): tasks[row.task_id].setdefault("input", []).append( {"name": row.input, "value": row.value}, ) - # fetch qualities for all datasets in one query - did_list = ",".join(str(t["did"]) for t in tasks.values()) - qualities_str = ", ".join(f"'{q}'" for q in QUALITIES_TO_SHOW) + qualities_query = text( + """ + SELECT `data`, `quality`, `value` + FROM data_quality + WHERE `data` IN :dataset_ids + AND `quality` IN :quality_names + """, + ).bindparams( + bindparam("dataset_ids", expanding=True), + bindparam("quality_names", expanding=True), + ) qualities_result = await expdb.execute( - text( - f""" - SELECT `data`, `quality`, `value` - FROM data_quality - WHERE `data` IN ({did_list}) - AND `quality` IN ({qualities_str}) - """, # noqa: S608 - ), + qualities_query, + parameters={"dataset_ids": dataset_ids, "quality_names": QUALITIES_TO_SHOW}, ) - # build a reverse map: dataset_id -> task_id - # needed because quality rows come back keyed by did, but our tasks dict is keyed by task_id + # multiple tasks can reference the same dataset; map dataset_id -> [task_id, ...] did_to_task_ids: dict[int, list[int]] = {} for tid, t in tasks.items(): did_to_task_ids.setdefault(t["did"], []).append(tid) @@ -369,21 +381,18 @@ async def list_tasks( # noqa: PLR0913 {"name": row.quality, "value": str(row.value)}, ) - # fetch tags for all tasks in one query - tags_result = await expdb.execute( - text( - f""" - SELECT `id`, `tag` - FROM task_tag - WHERE `id` IN ({task_ids_str}) - """, # noqa: S608 - ), - ) + tags_query = text( + """ + SELECT `id`, `tag` + FROM task_tag + WHERE `id` IN :task_ids + """, + ).bindparams(bindparam("task_ids", expanding=True)) + tags_result = await expdb.execute(tags_query, parameters={"task_ids": task_ids}) for row in tags_result.all(): tasks[row.id].setdefault("tag", []).append(row.tag) - # ensure every task has all expected keys(input/quality/tag) even if no rows were found for them - # e.g. a task with no tags should return "tag": [] not missing key + # ensure every task has all expected keys even if no related rows were found for task in tasks.values(): task.setdefault("input", []) task.setdefault("quality", []) diff --git a/tests/routers/openml/migration/tasks_migration_test.py b/tests/routers/openml/migration/tasks_migration_test.py index f71a1e2c..1dc2f5c7 100644 --- a/tests/routers/openml/migration/tasks_migration_test.py +++ b/tests/routers/openml/migration/tasks_migration_test.py @@ -1,5 +1,6 @@ import asyncio from http import HTTPStatus +from typing import Any, cast import deepdiff import httpx @@ -59,3 +60,161 @@ async def test_get_task_equal( ignore_order=True, ) assert not differences + + +# PHP task list no-results error code is 482 (unlike datasets which uses 372). +# Python uses 372 (NoResultsError). This difference is documented in the tests below. +_PHP_TASK_LIST_NO_RESULTS_CODE = "482" +_PY_TASK_LIST_NO_RESULTS_CODE = "372" + + +def _build_php_task_list_path(php_params: dict[str, Any]) -> str: + """Build a PHP-style path for /task/list with path-encoded filter parameters.""" + if not php_params: + return "/task/list" + parts = "/".join(f"{k}/{v}" for k, v in php_params.items()) + return f"/task/list/{parts}" + + +def _normalize_py_task(task: dict[str, Any]) -> dict[str, Any]: + """Normalize a single Python task list entry to match PHP format. + + PHP (XML-to-JSON) returns single-element arrays as plain values, not lists. + PHP also returns IDs as ints + and completely omits the "tag" field for all tasks in the list endpoint. + """ + t = nested_num_to_str(task) + t = nested_remove_single_element_list(t) + + # PHP's list endpoint does not return tags AT ALL + t.pop("tag", None) + + # PHP omits qualities where value is None string + if "quality" in t: + t["quality"] = [q for q in t["quality"] if q.get("value") != "None"] + + # PHP's list endpoint does not return these arrays when empty + for opt_key in ("quality", "input"): + if t.get(opt_key) == []: + t.pop(opt_key) + + # PHP's list endpoint returns these specific fields as ints + t["task_id"] = int(t["task_id"]) + t["task_type_id"] = int(t["task_type_id"]) + t["did"] = int(t["did"]) + + return cast("dict[str, Any]", t) + + +# Filter combos: (php_path_params, python_body_extras) +# PHP uses path-based filter keys (e.g. "type"), Python uses JSON body keys (e.g. "task_type_id") +_FILTER_COMBOS: list[tuple[dict[str, Any], dict[str, Any]]] = [ + ({"type": 1}, {"task_type_id": 1}), # by task type + ({"tag": "OpenML100"}, {"tag": "OpenML100"}), # by tag + ({"type": 1, "tag": "OpenML100"}, {"task_type_id": 1, "tag": "OpenML100"}), # combined +] + +_FILTER_IDS = ["type", "tag", "type_and_tag"] + + +@pytest.mark.parametrize( + ("php_params", "py_extra"), + _FILTER_COMBOS, + ids=_FILTER_IDS, +) +async def test_list_tasks_equal( + php_params: dict[str, Any], + py_extra: dict[str, Any], + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + """Python and PHP task list responses contain the same tasks for the same filters. + + Known differences documented here: + - PHP wraps response in {"tasks": {"task": [...]}}, Python returns a flat list. + - PHP uses string values for all fields; Python is typed (handled via nested_num_to_str). + - PHP omits the "tag" key when a task has no tags; Python returns "tag": []. + - PHP error status is 412 PRECONDITION_FAILED; Python uses 404 NOT_FOUND. + - PHP no-results error code is 482; Python is 372. + """ + php_path = _build_php_task_list_path(php_params) + # Use a very large limit on Python side to match PHP's unbounded default result count + py_body = {**py_extra, "pagination": {"limit": 1_000_000, "offset": 0}} + py_response, php_response = await asyncio.gather( + py_api.post("/tasks/list", json=py_body), + php_api.get(php_path), + ) + + # Error case: no results — PHP returns 412, Python returns 404 + if php_response.status_code == HTTPStatus.PRECONDITION_FAILED: + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert py_response.headers["content-type"] == "application/problem+json" + assert php_response.json()["error"]["code"] == _PHP_TASK_LIST_NO_RESULTS_CODE + assert py_response.json()["code"] == _PY_TASK_LIST_NO_RESULTS_CODE + return + + assert php_response.status_code == HTTPStatus.OK + assert py_response.status_code == HTTPStatus.OK + + php_tasks: list[dict[str, Any]] = php_response.json()["tasks"]["task"] + py_tasks: list[dict[str, Any]] = [_normalize_py_task(t) for t in py_response.json()] + + php_ids = {int(t["task_id"]) for t in php_tasks} + py_ids = {int(t["task_id"]) for t in py_tasks} + + # Python may return more tasks than PHP (PHP may apply visibility/limit rules server-side). + # Assert that every task PHP returns is also present in Python — not the reverse. + assert php_ids.issubset(py_ids), f"PHP has task IDs not in Python: {php_ids - py_ids}" + + # Compare only the tasks PHP returned — per-task deepdiff for clear error messages + py_by_id = {int(t["task_id"]): t for t in py_tasks} + php_by_id = {int(t["task_id"]): t for t in php_tasks} + for task_id in php_ids: + differences = deepdiff.diff.DeepDiff( + py_by_id[task_id], + php_by_id[task_id], + ignore_order=True, + ) + assert not differences, f"Differences for task {task_id}: {differences}" + + +@pytest.mark.parametrize( + ("php_params", "py_extra"), + [ + ({"tag": "nonexistent_tag_xyz_abc"}, {"tag": "nonexistent_tag_xyz_abc"}), + ({"type": 9999}, {"task_type_id": 9999}), + ({"data_name": "nonexistent_dataset_xyz"}, {"data_name": "nonexistent_dataset_xyz"}), + ], + ids=["bad_tag", "bad_type", "bad_data_name"], +) +async def test_list_tasks_no_results_matches_php( + php_params: dict[str, Any], + py_extra: dict[str, Any], + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + """Both APIs return a "no results" error for filters matching nothing. + + Documented differences: + - PHP returns 412 PRECONDITION_FAILED; Python returns 404 NOT_FOUND. + - PHP error code is 482; Python error code is 372. + - PHP message: "No results"; Python detail: "No tasks match the search criteria." + """ + php_path = _build_php_task_list_path(php_params) + py_response, php_response = await asyncio.gather( + py_api.post("/tasks/list", json=py_extra), + php_api.get(php_path), + ) + + assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED + assert py_response.status_code == HTTPStatus.NOT_FOUND + + php_error = php_response.json()["error"] + py_error = py_response.json() + + # Error codes differ between PHP and Python for task list no-results + assert php_error["code"] == _PHP_TASK_LIST_NO_RESULTS_CODE + assert py_error["code"] == _PY_TASK_LIST_NO_RESULTS_CODE + assert php_error["message"] == "No results" + assert py_error["detail"] == "No tasks match the search criteria." + assert py_response.headers["content-type"] == "application/problem+json" diff --git a/tests/routers/openml/task_test.py b/tests/routers/openml/task_test.py index c2f0094b..f8715d30 100644 --- a/tests/routers/openml/task_test.py +++ b/tests/routers/openml/task_test.py @@ -1,7 +1,11 @@ from http import HTTPStatus +from typing import Any import deepdiff import httpx +import pytest + +from core.errors import NoResultsError async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: @@ -26,6 +30,13 @@ async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: assert "tag" in task +async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: + """GET /tasks/list with no body also works.""" + response = await py_api.get("/tasks/list") + assert response.status_code == HTTPStatus.OK + assert isinstance(response.json(), list) + + async def test_list_tasks_filter_type(py_api: httpx.AsyncClient) -> None: """Filter by task_type_id returns only tasks of that type.""" response = await py_api.post("/tasks/list", json={"task_type_id": 1}) @@ -44,24 +55,67 @@ async def test_list_tasks_filter_tag(py_api: httpx.AsyncClient) -> None: assert all("OpenML100" in t["tag"] for t in tasks) -async def test_list_tasks_pagination(py_api: httpx.AsyncClient) -> None: - """Pagination returns correct number of results.""" - limit = 5 +@pytest.mark.parametrize("task_id", [1, 59, [1, 2, 3]]) +async def test_list_tasks_filter_task_id( + task_id: int | list[int], py_api: httpx.AsyncClient +) -> None: + """Filter by task_id returns only those tasks.""" + ids = [task_id] if isinstance(task_id, int) else task_id + response = await py_api.post("/tasks/list", json={"task_id": ids}) + assert response.status_code == HTTPStatus.OK + returned_ids = {t["task_id"] for t in response.json()} + assert returned_ids == set(ids) + + +async def test_list_tasks_filter_data_id(py_api: httpx.AsyncClient) -> None: + """Filter by data_id returns only tasks that use that dataset.""" + data_id = 10 + response = await py_api.post("/tasks/list", json={"data_id": [data_id]}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all(t["did"] == data_id for t in tasks) + + +async def test_list_tasks_filter_data_name(py_api: httpx.AsyncClient) -> None: + """Filter by data_name returns only tasks whose dataset matches.""" + response = await py_api.post("/tasks/list", json={"data_name": "mfeat-pixel"}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all(t["name"] == "mfeat-pixel" for t in tasks) + + +async def test_list_tasks_filter_status_all(py_api: httpx.AsyncClient) -> None: + """Status='all' returns >= results compared to default active-only.""" + active_resp = await py_api.post("/tasks/list", json={}) + all_resp = await py_api.post("/tasks/list", json={"status": "all"}) + assert active_resp.status_code == HTTPStatus.OK + assert all_resp.status_code == HTTPStatus.OK + assert len(all_resp.json()) >= len(active_resp.json()) + + +@pytest.mark.parametrize( + ("limit", "offset"), + [(5, 0), (10, 0), (5, 5)], +) +async def test_list_tasks_pagination(limit: int, offset: int, py_api: httpx.AsyncClient) -> None: + """Pagination limit and offset are respected.""" response = await py_api.post( "/tasks/list", - json={"pagination": {"limit": limit, "offset": 0}}, + json={"pagination": {"limit": limit, "offset": offset}}, ) assert response.status_code == HTTPStatus.OK - assert len(response.json()) == limit + assert len(response.json()) <= limit -async def test_list_tasks_pagination_offset(py_api: httpx.AsyncClient) -> None: - """Offset returns different results than no offset.""" +async def test_list_tasks_pagination_order_stable(py_api: httpx.AsyncClient) -> None: + """Results are ordered by task_id — consecutive pages are in ascending order.""" r1 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}}) ids1 = [t["task_id"] for t in r1.json()] ids2 = [t["task_id"] for t in r2.json()] - assert ids1 != ids2 + assert max(ids1) < min(ids2) async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None: @@ -76,29 +130,120 @@ async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> N assert len(tasks) > 0 for task in tasks: qualities = {q["name"]: q["value"] for q in task["quality"]} - assert min_instances <= float(qualities["NumberOfInstances"]) <= max_instances + if "NumberOfInstances" in qualities: + assert min_instances <= float(qualities["NumberOfInstances"]) <= max_instances -async def test_list_tasks_no_results(py_api: httpx.AsyncClient) -> None: - """Nonexistent tag returns 404 NoResultsError.""" - response = await py_api.post("/tasks/list", json={"tag": "nonexistent_tag_xyz"}) - assert response.status_code == HTTPStatus.NOT_FOUND - assert response.headers["content-type"] == "application/problem+json" - error = response.json() - assert error["status"] == HTTPStatus.NOT_FOUND - assert error["code"] == "372" +async def test_list_tasks_inputs_are_basic_subset(py_api: httpx.AsyncClient) -> None: + """Input entries only contain the expected basic input names.""" + basic_inputs = {"source_data", "target_feature", "estimation_procedure", "evaluation_measures"} + response = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) + assert response.status_code == HTTPStatus.OK + for task in response.json(): + for inp in task["input"]: + assert inp["name"] in basic_inputs -async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: - """GET /tasks/list with no body also works.""" - response = await py_api.get("/tasks/list") +async def test_list_tasks_quality_values_are_strings(py_api: httpx.AsyncClient) -> None: + """Quality values must be strings (to match PHP API behaviour).""" + response = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) assert response.status_code == HTTPStatus.OK - assert isinstance(response.json(), list) + for task in response.json(): + for quality in task["quality"]: + assert isinstance(quality["value"], str) + + +async def test_list_tasks_all_keys_present_even_with_empty_values( + py_api: httpx.AsyncClient, +) -> None: + """Every task has input/quality/tag keys even if they are empty lists.""" + response = await py_api.post("/tasks/list", json={"task_id": [1, 2, 3]}) + assert response.status_code == HTTPStatus.OK + for task in response.json(): + assert "input" in task + assert "quality" in task + assert "tag" in task + + +@pytest.mark.parametrize( + "pagination_override", + [ + {"limit": "abc", "offset": 0}, # Invalid type + {"limit": 5, "offset": "xyz"}, # Invalid type + ], + ids=["bad_limit_type", "bad_offset_type"], +) +async def test_list_tasks_invalid_pagination_type( + pagination_override: dict[str, Any], py_api: httpx.AsyncClient +) -> None: + """Invalid pagination types return 422 Unprocessable Entity.""" + response = await py_api.post( + "/tasks/list", + json={"pagination": pagination_override}, + ) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +@pytest.mark.parametrize( + ("limit", "offset", "expected_status", "expected_max_results"), + [ + (-10, 0, HTTPStatus.NOT_FOUND, 0), # negative limit clamped to 0 -> No results + (5, -10, HTTPStatus.OK, 5), # negative offset clamped to 0 -> First 5 results + ], + ids=["negative_limit", "negative_offset"], +) +async def test_list_tasks_negative_pagination_safely_clamped( + limit: int, + offset: int, + expected_status: int, + expected_max_results: int, + py_api: httpx.AsyncClient, +) -> None: + """Negative pagination values are safely clamped to 0 instead of causing 500 errors. + + A limit clamped to 0 returns a 372 NoResultsError (404 Not Found). + An offset clamped to 0 simply returns the first page of results (200 OK). + """ + response = await py_api.post( + "/tasks/list", + json={"pagination": {"limit": limit, "offset": offset}}, + ) + assert response.status_code == expected_status + if expected_status == HTTPStatus.OK: + assert len(response.json()) <= expected_max_results + else: + error = response.json() + assert error["type"] == NoResultsError.uri + assert error["code"] == "372" + + +@pytest.mark.parametrize( + "payload", + [ + {"tag": "nonexistent_tag_xyz"}, + {"task_id": [999_999_999]}, + {"data_name": "nonexistent_dataset_xyz"}, + ], + ids=["bad_tag", "bad_task_id", "bad_data_name"], +) +async def test_list_tasks_no_results(payload: dict[str, Any], py_api: httpx.AsyncClient) -> None: + """Filters matching nothing return 404 NoResultsError.""" + response = await py_api.post("/tasks/list", json=payload) + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.headers["content-type"] == "application/problem+json" + error = response.json() + assert error["type"] == NoResultsError.uri + assert error["code"] == "372" # NoResultsError code -async def test_list_tasks_invalid_range_format(py_api: httpx.AsyncClient) -> None: - """Invalid number_instances range returns 422 validation error.""" - response = await py_api.post("/tasks/list", json={"number_instances": "1...2"}) +@pytest.mark.parametrize( + "value", + ["1...2", "abc"], + ids=["triple_dot", "non_numeric"], +) +async def test_list_tasks_invalid_range(value: str, py_api: httpx.AsyncClient) -> None: + """Invalid number_instances format returns 422 Unprocessable Entity.""" + response = await py_api.post("/tasks/list", json={"number_instances": value}) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY From 55a0d60c6568c54d6d66fa9b2aad98434a4f6716 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Fri, 27 Mar 2026 20:20:50 +0530 Subject: [PATCH 5/6] chore: resolve merge conflicts --- tests/routers/openml/task_get_test.py | 243 ------------------------ tests/routers/openml/task_list_test.py | 246 +++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 243 deletions(-) create mode 100644 tests/routers/openml/task_list_test.py diff --git a/tests/routers/openml/task_get_test.py b/tests/routers/openml/task_get_test.py index f8715d30..e78bba8b 100644 --- a/tests/routers/openml/task_get_test.py +++ b/tests/routers/openml/task_get_test.py @@ -1,250 +1,7 @@ from http import HTTPStatus -from typing import Any import deepdiff import httpx -import pytest - -from core.errors import NoResultsError - - -async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: - """Default call returns active tasks with correct shape.""" - response = await py_api.post("/tasks/list", json={}) - assert response.status_code == HTTPStatus.OK - tasks = response.json() - assert isinstance(tasks, list) - assert len(tasks) > 0 - assert all(task["status"] == "active" for task in tasks) - # verify shape of first task - task = tasks[0] - assert "task_id" in task - assert "task_type_id" in task - assert "task_type" in task - assert "did" in task - assert "name" in task - assert "format" in task - assert "status" in task - assert "input" in task - assert "quality" in task - assert "tag" in task - - -async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: - """GET /tasks/list with no body also works.""" - response = await py_api.get("/tasks/list") - assert response.status_code == HTTPStatus.OK - assert isinstance(response.json(), list) - - -async def test_list_tasks_filter_type(py_api: httpx.AsyncClient) -> None: - """Filter by task_type_id returns only tasks of that type.""" - response = await py_api.post("/tasks/list", json={"task_type_id": 1}) - assert response.status_code == HTTPStatus.OK - tasks = response.json() - assert len(tasks) > 0 - assert all(t["task_type_id"] == 1 for t in tasks) - - -async def test_list_tasks_filter_tag(py_api: httpx.AsyncClient) -> None: - """Filter by tag returns only tasks with that tag.""" - response = await py_api.post("/tasks/list", json={"tag": "OpenML100"}) - assert response.status_code == HTTPStatus.OK - tasks = response.json() - assert len(tasks) > 0 - assert all("OpenML100" in t["tag"] for t in tasks) - - -@pytest.mark.parametrize("task_id", [1, 59, [1, 2, 3]]) -async def test_list_tasks_filter_task_id( - task_id: int | list[int], py_api: httpx.AsyncClient -) -> None: - """Filter by task_id returns only those tasks.""" - ids = [task_id] if isinstance(task_id, int) else task_id - response = await py_api.post("/tasks/list", json={"task_id": ids}) - assert response.status_code == HTTPStatus.OK - returned_ids = {t["task_id"] for t in response.json()} - assert returned_ids == set(ids) - - -async def test_list_tasks_filter_data_id(py_api: httpx.AsyncClient) -> None: - """Filter by data_id returns only tasks that use that dataset.""" - data_id = 10 - response = await py_api.post("/tasks/list", json={"data_id": [data_id]}) - assert response.status_code == HTTPStatus.OK - tasks = response.json() - assert len(tasks) > 0 - assert all(t["did"] == data_id for t in tasks) - - -async def test_list_tasks_filter_data_name(py_api: httpx.AsyncClient) -> None: - """Filter by data_name returns only tasks whose dataset matches.""" - response = await py_api.post("/tasks/list", json={"data_name": "mfeat-pixel"}) - assert response.status_code == HTTPStatus.OK - tasks = response.json() - assert len(tasks) > 0 - assert all(t["name"] == "mfeat-pixel" for t in tasks) - - -async def test_list_tasks_filter_status_all(py_api: httpx.AsyncClient) -> None: - """Status='all' returns >= results compared to default active-only.""" - active_resp = await py_api.post("/tasks/list", json={}) - all_resp = await py_api.post("/tasks/list", json={"status": "all"}) - assert active_resp.status_code == HTTPStatus.OK - assert all_resp.status_code == HTTPStatus.OK - assert len(all_resp.json()) >= len(active_resp.json()) - - -@pytest.mark.parametrize( - ("limit", "offset"), - [(5, 0), (10, 0), (5, 5)], -) -async def test_list_tasks_pagination(limit: int, offset: int, py_api: httpx.AsyncClient) -> None: - """Pagination limit and offset are respected.""" - response = await py_api.post( - "/tasks/list", - json={"pagination": {"limit": limit, "offset": offset}}, - ) - assert response.status_code == HTTPStatus.OK - assert len(response.json()) <= limit - - -async def test_list_tasks_pagination_order_stable(py_api: httpx.AsyncClient) -> None: - """Results are ordered by task_id — consecutive pages are in ascending order.""" - r1 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) - r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}}) - ids1 = [t["task_id"] for t in r1.json()] - ids2 = [t["task_id"] for t in r2.json()] - assert max(ids1) < min(ids2) - - -async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None: - """number_instances range filter returns tasks whose dataset matches.""" - min_instances, max_instances = 100, 1000 - response = await py_api.post( - "/tasks/list", - json={"number_instances": f"{min_instances}..{max_instances}"}, - ) - assert response.status_code == HTTPStatus.OK - tasks = response.json() - assert len(tasks) > 0 - for task in tasks: - qualities = {q["name"]: q["value"] for q in task["quality"]} - if "NumberOfInstances" in qualities: - assert min_instances <= float(qualities["NumberOfInstances"]) <= max_instances - - -async def test_list_tasks_inputs_are_basic_subset(py_api: httpx.AsyncClient) -> None: - """Input entries only contain the expected basic input names.""" - basic_inputs = {"source_data", "target_feature", "estimation_procedure", "evaluation_measures"} - response = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) - assert response.status_code == HTTPStatus.OK - for task in response.json(): - for inp in task["input"]: - assert inp["name"] in basic_inputs - - -async def test_list_tasks_quality_values_are_strings(py_api: httpx.AsyncClient) -> None: - """Quality values must be strings (to match PHP API behaviour).""" - response = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) - assert response.status_code == HTTPStatus.OK - for task in response.json(): - for quality in task["quality"]: - assert isinstance(quality["value"], str) - - -async def test_list_tasks_all_keys_present_even_with_empty_values( - py_api: httpx.AsyncClient, -) -> None: - """Every task has input/quality/tag keys even if they are empty lists.""" - response = await py_api.post("/tasks/list", json={"task_id": [1, 2, 3]}) - assert response.status_code == HTTPStatus.OK - for task in response.json(): - assert "input" in task - assert "quality" in task - assert "tag" in task - - -@pytest.mark.parametrize( - "pagination_override", - [ - {"limit": "abc", "offset": 0}, # Invalid type - {"limit": 5, "offset": "xyz"}, # Invalid type - ], - ids=["bad_limit_type", "bad_offset_type"], -) -async def test_list_tasks_invalid_pagination_type( - pagination_override: dict[str, Any], py_api: httpx.AsyncClient -) -> None: - """Invalid pagination types return 422 Unprocessable Entity.""" - response = await py_api.post( - "/tasks/list", - json={"pagination": pagination_override}, - ) - assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY - - -@pytest.mark.parametrize( - ("limit", "offset", "expected_status", "expected_max_results"), - [ - (-10, 0, HTTPStatus.NOT_FOUND, 0), # negative limit clamped to 0 -> No results - (5, -10, HTTPStatus.OK, 5), # negative offset clamped to 0 -> First 5 results - ], - ids=["negative_limit", "negative_offset"], -) -async def test_list_tasks_negative_pagination_safely_clamped( - limit: int, - offset: int, - expected_status: int, - expected_max_results: int, - py_api: httpx.AsyncClient, -) -> None: - """Negative pagination values are safely clamped to 0 instead of causing 500 errors. - - A limit clamped to 0 returns a 372 NoResultsError (404 Not Found). - An offset clamped to 0 simply returns the first page of results (200 OK). - """ - response = await py_api.post( - "/tasks/list", - json={"pagination": {"limit": limit, "offset": offset}}, - ) - assert response.status_code == expected_status - if expected_status == HTTPStatus.OK: - assert len(response.json()) <= expected_max_results - else: - error = response.json() - assert error["type"] == NoResultsError.uri - assert error["code"] == "372" - - -@pytest.mark.parametrize( - "payload", - [ - {"tag": "nonexistent_tag_xyz"}, - {"task_id": [999_999_999]}, - {"data_name": "nonexistent_dataset_xyz"}, - ], - ids=["bad_tag", "bad_task_id", "bad_data_name"], -) -async def test_list_tasks_no_results(payload: dict[str, Any], py_api: httpx.AsyncClient) -> None: - """Filters matching nothing return 404 NoResultsError.""" - response = await py_api.post("/tasks/list", json=payload) - assert response.status_code == HTTPStatus.NOT_FOUND - assert response.headers["content-type"] == "application/problem+json" - error = response.json() - assert error["type"] == NoResultsError.uri - assert error["code"] == "372" # NoResultsError code - - -@pytest.mark.parametrize( - "value", - ["1...2", "abc"], - ids=["triple_dot", "non_numeric"], -) -async def test_list_tasks_invalid_range(value: str, py_api: httpx.AsyncClient) -> None: - """Invalid number_instances format returns 422 Unprocessable Entity.""" - response = await py_api.post("/tasks/list", json={"number_instances": value}) - assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY async def test_get_task(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/task_list_test.py b/tests/routers/openml/task_list_test.py new file mode 100644 index 00000000..828147ed --- /dev/null +++ b/tests/routers/openml/task_list_test.py @@ -0,0 +1,246 @@ +from http import HTTPStatus +from typing import Any + +import httpx +import pytest + +from core.errors import NoResultsError + + +async def test_list_tasks_default(py_api: httpx.AsyncClient) -> None: + """Default call returns active tasks with correct shape.""" + response = await py_api.post("/tasks/list", json={}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert isinstance(tasks, list) + assert len(tasks) > 0 + assert all(task["status"] == "active" for task in tasks) + # verify shape of first task + task = tasks[0] + assert "task_id" in task + assert "task_type_id" in task + assert "task_type" in task + assert "did" in task + assert "name" in task + assert "format" in task + assert "status" in task + assert "input" in task + assert "quality" in task + assert "tag" in task + + +async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None: + """GET /tasks/list with no body also works.""" + response = await py_api.get("/tasks/list") + assert response.status_code == HTTPStatus.OK + assert isinstance(response.json(), list) + + +async def test_list_tasks_filter_type(py_api: httpx.AsyncClient) -> None: + """Filter by task_type_id returns only tasks of that type.""" + response = await py_api.post("/tasks/list", json={"task_type_id": 1}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all(t["task_type_id"] == 1 for t in tasks) + + +async def test_list_tasks_filter_tag(py_api: httpx.AsyncClient) -> None: + """Filter by tag returns only tasks with that tag.""" + response = await py_api.post("/tasks/list", json={"tag": "OpenML100"}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all("OpenML100" in t["tag"] for t in tasks) + + +@pytest.mark.parametrize("task_id", [1, 59, [1, 2, 3]]) +async def test_list_tasks_filter_task_id( + task_id: int | list[int], py_api: httpx.AsyncClient +) -> None: + """Filter by task_id returns only those tasks.""" + ids = [task_id] if isinstance(task_id, int) else task_id + response = await py_api.post("/tasks/list", json={"task_id": ids}) + assert response.status_code == HTTPStatus.OK + returned_ids = {t["task_id"] for t in response.json()} + assert returned_ids == set(ids) + + +async def test_list_tasks_filter_data_id(py_api: httpx.AsyncClient) -> None: + """Filter by data_id returns only tasks that use that dataset.""" + data_id = 10 + response = await py_api.post("/tasks/list", json={"data_id": [data_id]}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all(t["did"] == data_id for t in tasks) + + +async def test_list_tasks_filter_data_name(py_api: httpx.AsyncClient) -> None: + """Filter by data_name returns only tasks whose dataset matches.""" + response = await py_api.post("/tasks/list", json={"data_name": "mfeat-pixel"}) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + assert all(t["name"] == "mfeat-pixel" for t in tasks) + + +async def test_list_tasks_filter_status_all(py_api: httpx.AsyncClient) -> None: + """Status='all' returns >= results compared to default active-only.""" + active_resp = await py_api.post("/tasks/list", json={}) + all_resp = await py_api.post("/tasks/list", json={"status": "all"}) + assert active_resp.status_code == HTTPStatus.OK + assert all_resp.status_code == HTTPStatus.OK + assert len(all_resp.json()) >= len(active_resp.json()) + + +@pytest.mark.parametrize( + ("limit", "offset"), + [(5, 0), (10, 0), (5, 5)], +) +async def test_list_tasks_pagination(limit: int, offset: int, py_api: httpx.AsyncClient) -> None: + """Pagination limit and offset are respected.""" + response = await py_api.post( + "/tasks/list", + json={"pagination": {"limit": limit, "offset": offset}}, + ) + assert response.status_code == HTTPStatus.OK + assert len(response.json()) <= limit + + +async def test_list_tasks_pagination_order_stable(py_api: httpx.AsyncClient) -> None: + """Results are ordered by task_id — consecutive pages are in ascending order.""" + r1 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) + r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}}) + ids1 = [t["task_id"] for t in r1.json()] + ids2 = [t["task_id"] for t in r2.json()] + assert max(ids1) < min(ids2) + + +async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None: + """number_instances range filter returns tasks whose dataset matches.""" + min_instances, max_instances = 100, 1000 + response = await py_api.post( + "/tasks/list", + json={"number_instances": f"{min_instances}..{max_instances}"}, + ) + assert response.status_code == HTTPStatus.OK + tasks = response.json() + assert len(tasks) > 0 + for task in tasks: + qualities = {q["name"]: q["value"] for q in task["quality"]} + if "NumberOfInstances" in qualities: + assert min_instances <= float(qualities["NumberOfInstances"]) <= max_instances + + +async def test_list_tasks_inputs_are_basic_subset(py_api: httpx.AsyncClient) -> None: + """Input entries only contain the expected basic input names.""" + basic_inputs = {"source_data", "target_feature", "estimation_procedure", "evaluation_measures"} + response = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) + assert response.status_code == HTTPStatus.OK + for task in response.json(): + for inp in task["input"]: + assert inp["name"] in basic_inputs + + +async def test_list_tasks_quality_values_are_strings(py_api: httpx.AsyncClient) -> None: + """Quality values must be strings (to match PHP API behaviour).""" + response = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}}) + assert response.status_code == HTTPStatus.OK + for task in response.json(): + for quality in task["quality"]: + assert isinstance(quality["value"], str) + + +async def test_list_tasks_all_keys_present_even_with_empty_values( + py_api: httpx.AsyncClient, +) -> None: + """Every task has input/quality/tag keys even if they are empty lists.""" + response = await py_api.post("/tasks/list", json={"task_id": [1, 2, 3]}) + assert response.status_code == HTTPStatus.OK + for task in response.json(): + assert "input" in task + assert "quality" in task + assert "tag" in task + + +@pytest.mark.parametrize( + "pagination_override", + [ + {"limit": "abc", "offset": 0}, # Invalid type + {"limit": 5, "offset": "xyz"}, # Invalid type + ], + ids=["bad_limit_type", "bad_offset_type"], +) +async def test_list_tasks_invalid_pagination_type( + pagination_override: dict[str, Any], py_api: httpx.AsyncClient +) -> None: + """Invalid pagination types return 422 Unprocessable Entity.""" + response = await py_api.post( + "/tasks/list", + json={"pagination": pagination_override}, + ) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +@pytest.mark.parametrize( + ("limit", "offset", "expected_status", "expected_max_results"), + [ + (-10, 0, HTTPStatus.NOT_FOUND, 0), # negative limit clamped to 0 -> No results + (5, -10, HTTPStatus.OK, 5), # negative offset clamped to 0 -> First 5 results + ], + ids=["negative_limit", "negative_offset"], +) +async def test_list_tasks_negative_pagination_safely_clamped( + limit: int, + offset: int, + expected_status: int, + expected_max_results: int, + py_api: httpx.AsyncClient, +) -> None: + """Negative pagination values are safely clamped to 0 instead of causing 500 errors. + + A limit clamped to 0 returns a 372 NoResultsError (404 Not Found). + An offset clamped to 0 simply returns the first page of results (200 OK). + """ + response = await py_api.post( + "/tasks/list", + json={"pagination": {"limit": limit, "offset": offset}}, + ) + assert response.status_code == expected_status + if expected_status == HTTPStatus.OK: + assert len(response.json()) <= expected_max_results + else: + error = response.json() + assert error["type"] == NoResultsError.uri + assert error["code"] == "372" + + +@pytest.mark.parametrize( + "payload", + [ + {"tag": "nonexistent_tag_xyz"}, + {"task_id": [999_999_999]}, + {"data_name": "nonexistent_dataset_xyz"}, + ], + ids=["bad_tag", "bad_task_id", "bad_data_name"], +) +async def test_list_tasks_no_results(payload: dict[str, Any], py_api: httpx.AsyncClient) -> None: + """Filters matching nothing return 404 NoResultsError.""" + response = await py_api.post("/tasks/list", json=payload) + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.headers["content-type"] == "application/problem+json" + error = response.json() + assert error["type"] == NoResultsError.uri + assert error["code"] == "372" # NoResultsError code + + +@pytest.mark.parametrize( + "value", + ["1...2", "abc"], + ids=["triple_dot", "non_numeric"], +) +async def test_list_tasks_invalid_range(value: str, py_api: httpx.AsyncClient) -> None: + """Invalid number_instances format returns 422 Unprocessable Entity.""" + response = await py_api.post("/tasks/list", json={"number_instances": value}) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY From cb74d9321dfd4e8f4ec373edc4723bccfede6cb8 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Fri, 27 Mar 2026 21:07:33 +0530 Subject: [PATCH 6/6] bot suggestions: improve tests, empty list check --- src/routers/openml/tasks.py | 14 +++++++---- .../openml/migration/tasks_migration_test.py | 24 +++++++++++++++++-- tests/routers/openml/task_list_test.py | 5 +++- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 2d8a7610..2c8e4a0c 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -263,11 +263,17 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 clauses.append("AND d.`name` = :data_name") parameters["data_name"] = data_name - if task_id: + if task_id is not None: + if not task_id: + msg = "No tasks match the search criteria." + raise NoResultsError(msg) clauses.append("AND t.`task_id` IN :task_ids") parameters["task_ids"] = task_id - if data_id: + if data_id is not None: + if not data_id: + msg = "No tasks match the search criteria." + raise NoResultsError(msg) clauses.append("AND d.`did` IN :data_ids") parameters["data_ids"] = data_id @@ -318,9 +324,9 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 """, # noqa: S608 ) - if task_id: + if task_id is not None: main_query = main_query.bindparams(bindparam("task_ids", expanding=True)) - if data_id: + if data_id is not None: main_query = main_query.bindparams(bindparam("data_ids", expanding=True)) result = await expdb.execute(main_query, parameters=parameters) diff --git a/tests/routers/openml/migration/tasks_migration_test.py b/tests/routers/openml/migration/tasks_migration_test.py index b079fc98..0b031b32 100644 --- a/tests/routers/openml/migration/tasks_migration_test.py +++ b/tests/routers/openml/migration/tasks_migration_test.py @@ -112,9 +112,26 @@ def _normalize_py_task(task: dict[str, Any]) -> dict[str, Any]: ({"type": 1}, {"task_type_id": 1}), # by task type ({"tag": "OpenML100"}, {"tag": "OpenML100"}), # by tag ({"type": 1, "tag": "OpenML100"}, {"task_type_id": 1, "tag": "OpenML100"}), # combined + ({"data_name": "iris"}, {"data_name": "iris"}), # by dataset name + ({"data_id": 61}, {"data_id": [61]}), # by dataset id + ({"data_tag": "study_14"}, {"data_tag": "study_14"}), # by dataset tag + ({"number_instances": "150"}, {"number_instances": "150"}), # quality filter + ( + {"data_id": 61, "number_instances": "150"}, + {"data_id": [61], "number_instances": "150"}, + ), ] -_FILTER_IDS = ["type", "tag", "type_and_tag"] +_FILTER_IDS = [ + "type", + "tag", + "type_and_tag", + "data_name", + "data_id", + "data_tag", + "number_instances", + "data_and_quality", +] @pytest.mark.parametrize( @@ -156,7 +173,10 @@ async def test_list_tasks_equal( assert php_response.status_code == HTTPStatus.OK assert py_response.status_code == HTTPStatus.OK - php_tasks: list[dict[str, Any]] = php_response.json()["tasks"]["task"] + php_tasks_raw = php_response.json()["tasks"]["task"] + php_tasks: list[dict[str, Any]] = ( + php_tasks_raw if isinstance(php_tasks_raw, list) else [php_tasks_raw] + ) py_tasks: list[dict[str, Any]] = [_normalize_py_task(t) for t in py_response.json()] php_ids = {int(t["task_id"]) for t in php_tasks} diff --git a/tests/routers/openml/task_list_test.py b/tests/routers/openml/task_list_test.py index 828147ed..4e488ee0 100644 --- a/tests/routers/openml/task_list_test.py +++ b/tests/routers/openml/task_list_test.py @@ -114,7 +114,10 @@ async def test_list_tasks_pagination_order_stable(py_api: httpx.AsyncClient) -> r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}}) ids1 = [t["task_id"] for t in r1.json()] ids2 = [t["task_id"] for t in r2.json()] - assert max(ids1) < min(ids2) + assert ids1 == sorted(ids1) + assert ids2 == sorted(ids2) + if ids1 and ids2: + assert max(ids1) < min(ids2) async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None: