diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5355e45..44228fa 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,4 +1,6 @@ import json +import logging +from collections.abc import Mapping from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends @@ -11,6 +13,7 @@ from routers.dependencies import expdb_connection router = APIRouter(prefix="/tasktype", tags=["tasks"]) +logger = logging.getLogger(__name__) def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any]]: @@ -26,6 +29,36 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any return ttype +def _extract_data_type_from_api_constraints( + api_constraints: Mapping[str, Any] | str | None, + input_name: str, +) -> str | None: + """Extract string data_type from api_constraints safely.""" + constraint: Mapping[str, Any] | None = None + + if isinstance(api_constraints, str): + try: + loaded = json.loads(api_constraints) + except json.JSONDecodeError: + logger.warning( + "Failed to decode legacy api_constraints JSON for task_type_input '%s'; value=%r", + input_name, + api_constraints, + exc_info=True, + ) + return None + if isinstance(loaded, Mapping): + constraint = loaded + elif isinstance(api_constraints, Mapping): + constraint = api_constraints + + if not constraint: + return None + + data_type = constraint.get("data_type") + return data_type if isinstance(data_type, str) else None + + @router.get(path="/list") async def list_task_types( expdb: Annotated[AsyncConnection, Depends(expdb_connection)], @@ -59,6 +92,7 @@ async def get_task_type( creator.strip(' "') for creator in cast("str", contributors).split(",") ] task_type["creation_date"] = task_type.pop("creationDate") + task_type_inputs = await get_input_for_task_type(task_type_id, expdb) input_types = [] for task_type_input in task_type_inputs: @@ -66,10 +100,15 @@ async def get_task_type( if task_type_input.requirement == "required": input_["requirement"] = task_type_input.requirement input_["name"] = task_type_input.name - # api_constraints is for one input only in the test database (TODO: patch db) - if isinstance(task_type_input.api_constraints, str): - constraint = json.loads(task_type_input.api_constraints) - input_["data_type"] = constraint["data_type"] + + data_type = _extract_data_type_from_api_constraints( + task_type_input.api_constraints, + task_type_input.name, + ) + if data_type is not None: + input_["data_type"] = data_type + input_types.append(input_) + task_type["input"] = input_types return {"task_type": task_type}