Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions src/routers/openml/tasktype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import logging
from collections.abc import Mapping
from typing import Annotated, Any, Literal, cast

from fastapi import APIRouter, Depends
Expand All @@ -11,6 +13,7 @@
from routers.dependencies import expdb_connection

router = APIRouter(prefix="/tasktype", tags=["tasks"])
logger = logging.getLogger(__name__)


def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any]]:
Expand All @@ -26,6 +29,36 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any
return ttype


def _extract_data_type_from_api_constraints(
api_constraints: Mapping[str, Any] | str | None,
input_name: str,
) -> str | None:
"""Extract string data_type from api_constraints safely."""
constraint: Mapping[str, Any] | None = None

if isinstance(api_constraints, str):
try:
loaded = json.loads(api_constraints)
except json.JSONDecodeError:
logger.warning(
"Failed to decode legacy api_constraints JSON for task_type_input '%s'; value=%r",
input_name,
api_constraints,
exc_info=True,
)
return None
if isinstance(loaded, Mapping):
constraint = loaded
elif isinstance(api_constraints, Mapping):
constraint = api_constraints

if not constraint:
return None

data_type = constraint.get("data_type")
return data_type if isinstance(data_type, str) else None


@router.get(path="/list")
async def list_task_types(
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
Expand Down Expand Up @@ -59,17 +92,23 @@ async def get_task_type(
creator.strip(' "') for creator in cast("str", contributors).split(",")
]
task_type["creation_date"] = task_type.pop("creationDate")

task_type_inputs = await get_input_for_task_type(task_type_id, expdb)
input_types = []
for task_type_input in task_type_inputs:
input_ = {}
if task_type_input.requirement == "required":
input_["requirement"] = task_type_input.requirement
input_["name"] = task_type_input.name
# api_constraints is for one input only in the test database (TODO: patch db)
if isinstance(task_type_input.api_constraints, str):
constraint = json.loads(task_type_input.api_constraints)
input_["data_type"] = constraint["data_type"]

data_type = _extract_data_type_from_api_constraints(
task_type_input.api_constraints,
task_type_input.name,
)
if data_type is not None:
input_["data_type"] = data_type

input_types.append(input_)

task_type["input"] = input_types
return {"task_type": task_type}