-
-
Notifications
You must be signed in to change notification settings - Fork 50
[FEAT] Implement POST /run upload and server-side evaluation pipeline in Python #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
5f02cf3
363b9f2
0bd41fd
91e727b
9dd9847
e630804
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import math | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Individual metrics | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def accuracy(y_true: list[str | int], y_pred: list[str | int]) -> float: | ||
| """Fraction of predictions that exactly match the ground truth.""" | ||
| if len(y_true) != len(y_pred): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
| correct = sum(t == p for t, p in zip(y_true, y_pred, strict=True)) | ||
| return correct / len(y_true) | ||
|
|
||
|
|
||
| def rmse(y_true: list[float], y_pred: list[float]) -> float: | ||
| """Root Mean Squared Error.""" | ||
| if len(y_true) != len(y_pred): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
| mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) | ||
| return math.sqrt(mse) | ||
|
|
||
|
|
||
| def mean_absolute_error(y_true: list[float], y_pred: list[float]) -> float: | ||
| """Mean Absolute Error.""" | ||
| if len(y_true) != len(y_pred): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
| return sum(abs(t - p) for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) | ||
|
|
||
|
|
||
| def auc(y_true: list[int], y_score: list[float]) -> float: | ||
| """Binary ROC AUC via the Wilcoxon-Mann-Whitney U statistic. | ||
|
|
||
| Mathematically equivalent to the area under the ROC curve. | ||
| Counts concordant pairs: for each (positive, negative) pair, score 1 if | ||
| y_score[pos] > y_score[neg], 0.5 if tied, 0 otherwise, then normalise. | ||
|
|
||
| y_true: list of 0/1 ground-truth labels. | ||
| y_score: list of predicted probabilities for the positive class (label=1). | ||
| """ | ||
| if len(y_true) != len(y_score): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_score)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
|
|
||
| n_pos = sum(y_true) | ||
| n_neg = len(y_true) - n_pos | ||
| if n_pos == 0 or n_neg == 0: | ||
| return 0.0 | ||
|
|
||
| pos_scores = [s for t, s in zip(y_true, y_score, strict=True) if t == 1] | ||
| neg_scores = [s for t, s in zip(y_true, y_score, strict=True) if t == 0] | ||
|
|
||
| concordant = 0.0 | ||
| for ps in pos_scores: | ||
| for ns in neg_scores: | ||
| if ps > ns: | ||
| concordant += 1.0 | ||
| elif ps == ns: | ||
| concordant += 0.5 | ||
|
|
||
| return concordant / (n_pos * n_neg) | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------

#: Task type IDs from the OpenML schema
TASK_TYPE_SUPERVISED_CLASSIFICATION = 1
TASK_TYPE_SUPERVISED_REGRESSION = 2


def compute_metrics(
    task_type_id: int,
    y_true: list[str | int | float],
    y_pred: list[str | int | float],
    y_score: list[float] | None = None,
) -> dict[str, float]:
    """Compute all applicable metrics for the given task type.

    Returns a dict of {measure_name: value} using the same names found in
    the OpenML `math_function` table (e.g. 'predictive_accuracy',
    'area_under_roc_curve').  Unrecognised task types produce an empty dict.
    """
    results: dict[str, float] = {}

    if task_type_id == TASK_TYPE_SUPERVISED_REGRESSION:
        truth = [float(v) for v in y_true]
        preds = [float(v) for v in y_pred]
        results["root_mean_squared_error"] = rmse(truth, preds)
        results["mean_absolute_error"] = mean_absolute_error(truth, preds)
    elif task_type_id == TASK_TYPE_SUPERVISED_CLASSIFICATION:
        labels = [str(v) for v in y_true]
        predictions = [str(v) for v in y_pred]
        results["predictive_accuracy"] = accuracy(labels, predictions)

        # AUC only when the problem is binary and scores were provided.
        distinct = set(labels)
        if y_score is not None and len(distinct) == 2:  # noqa: PLR2004
            # Positive class = lexicographically larger label (OpenML convention).
            positive = max(distinct)
            binary_truth = [1 if lab == positive else 0 for lab in labels]
            results["area_under_roc_curve"] = auc(binary_truth, y_score)

    return results
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import random | ||
| import re | ||
|
|
||
| SplitEntry = dict[str, int | str] | ||
|
|
||
|
|
||
| def generate_splits( | ||
| n_samples: int, | ||
| n_folds: int, | ||
| n_repeats: int, | ||
| *, | ||
| seed: int = 0, | ||
| ) -> list[SplitEntry]: | ||
| """Generate cross-validation splits deterministically. | ||
|
|
||
| Returns a flat list of dicts with keys: | ||
| repeat, fold, rowid, type ('TRAIN' or 'TEST') | ||
| """ | ||
| entries: list[SplitEntry] = [] | ||
| rng = random.Random(seed) # noqa: S311 | ||
|
|
||
| for repeat in range(n_repeats): | ||
| indices = list(range(n_samples)) | ||
| rng.shuffle(indices) | ||
|
|
||
| for fold in range(n_folds): | ||
| for row_pos, rowid in enumerate(indices): | ||
| split_type = "TEST" if row_pos % n_folds == fold else "TRAIN" | ||
| entries.append( | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| { | ||
| "repeat": repeat, | ||
| "fold": fold, | ||
| "rowid": rowid, | ||
| "type": split_type, | ||
| }, | ||
| ) | ||
|
|
||
| return entries | ||
|
|
||
|
|
||
| _ARFF_DATA_SECTION = re.compile(r"@[Dd][Aa][Tt][Aa]") | ||
|
|
||
|
|
||
def parse_arff_splits(arff_content: str) -> "list[SplitEntry]":
    """Parse an OpenML splits ARFF file into the same list-of-dict format.

    Expected ARFF columns (in order): type, rowid, repeat, fold
    (This is the column order used by OpenML's split ARFF files.)
    Blank lines, '%' comments, malformed and non-numeric rows are skipped.
    """
    entries = []
    data_section = False

    for raw_line in arff_content.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("%"):
            continue  # blank line or ARFF comment
        if line.lower().startswith("@data"):
            data_section = True
            continue
        if not data_section:
            continue  # still inside the header

        fields = [f.strip() for f in line.split(",")]
        if len(fields) < 4:  # noqa: PLR2004
            continue
        try:
            entry = {
                "repeat": int(fields[2]),
                "fold": int(fields[3]),
                "rowid": int(fields[1]),
                "type": fields[0].strip("'\""),
            }
        except ValueError:
            continue
        entries.append(entry)

    return entries
|
|
||
|
|
||
def build_fold_index(
    splits: "list[SplitEntry]",
    repeat: int = 0,
) -> dict[int, tuple[list[int], list[int]]]:
    """Build a dict of fold -> (train_indices, test_indices) for a given repeat.

    Entries for other repeats are ignored; any entry whose type is not
    'TRAIN' is placed in the test list.
    """
    by_fold: dict[int, tuple[list[int], list[int]]] = {}
    for item in splits:
        if item["repeat"] != repeat:
            continue
        train_rows, test_rows = by_fold.setdefault(int(item["fold"]), ([], []))
        bucket = train_rows if item["type"] == "TRAIN" else test_rows
        bucket.append(int(item["rowid"]))
    return by_fold
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import datetime | ||
| from collections.abc import Sequence | ||
| from typing import cast | ||
|
|
||
| from sqlalchemy import Connection, Row, text | ||
|
|
||
|
|
||
def enqueue(run_id: int, expdb: Connection) -> None:
    """Insert a new pending processing entry for the given run."""
    statement = text(
        """
        INSERT INTO processing_run(`run_id`, `status`, `date`)
        VALUES (:run_id, 'pending', :date)
        """,
    )
    # NOTE(review): naive local datetime — presumably the schema stores local
    # server time like the legacy backend; confirm before switching to UTC.
    expdb.execute(statement, parameters={"run_id": run_id, "date": datetime.datetime.now()})
|
|
||
|
|
||
def get_pending(expdb: Connection) -> Sequence[Row]:
    """Return all processing_run rows whose status is 'pending', oldest first."""
    statement = text(
        """
        SELECT `run_id`, `status`, `date`
        FROM processing_run
        WHERE `status` = 'pending'
        ORDER BY `date` ASC
        """,
    )
    pending = expdb.execute(statement).all()
    # cast: .all() is typed loosely by SQLAlchemy; the rows are plain Row objects.
    return cast("Sequence[Row]", pending)
|
|
||
|
|
||
def mark_done(run_id: int, expdb: Connection) -> None:
    """Mark a processing_run entry as successfully completed."""
    statement = text(
        """
        UPDATE processing_run
        SET `status` = 'done'
        WHERE `run_id` = :run_id
        """,
    )
    expdb.execute(statement, parameters={"run_id": run_id})
|
|
||
|
|
||
def mark_error(run_id: int, error_message: str, expdb: Connection) -> None:
    """Mark a processing_run entry as failed and store the error message."""
    statement = text(
        """
        UPDATE processing_run
        SET `status` = 'error', `error` = :error_message
        WHERE `run_id` = :run_id
        """,
    )
    expdb.execute(statement, parameters={"run_id": run_id, "error_message": error_message})
|
Comment on lines
+23
to
+80
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pending work items are not atomically claimed before execution.
🤖 Prompt for AI Agents |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import datetime | ||
| from collections.abc import Sequence | ||
| from typing import cast | ||
|
|
||
| from sqlalchemy import Connection, Row, text | ||
|
|
||
|
|
||
def get(run_id: int, expdb: Connection) -> Row | None:
    """Fetch a single run row by its primary key, or None if it does not exist."""
    statement = text(
        """
        SELECT `rid`, `task_id`, `implementation_id` AS `flow_id`,
            `uploader`, `upload_time`, `setup_string`
        FROM run
        WHERE `rid` = :run_id
        """,
    )
    return expdb.execute(statement, parameters={"run_id": run_id}).one_or_none()
|
|
||
|
|
||
def create(
    *,
    task_id: int,
    flow_id: int,
    uploader_id: int,
    setup_string: str | None,
    expdb: Connection,
) -> int:
    """Insert a new run row and return the generated run_id."""
    insert = text(
        """
        INSERT INTO run(
            `task_id`, `implementation_id`, `uploader`,
            `upload_time`, `setup_string`
        )
        VALUES (:task_id, :flow_id, :uploader_id, :upload_time, :setup_string)
        """,
    )
    values = {
        "task_id": task_id,
        "flow_id": flow_id,
        "uploader_id": uploader_id,
        "upload_time": datetime.datetime.now(),
        "setup_string": setup_string,
    }
    expdb.execute(insert, parameters=values)
    # LAST_INSERT_ID() is scoped to this connection's session, so this is safe
    # even with concurrent inserts on other connections.
    (new_id,) = expdb.execute(text("SELECT LAST_INSERT_ID()")).one()
    return int(new_id)
|
|
||
|
|
||
def get_tags(run_id: int, expdb: Connection) -> list[str]:
    """Return all tags attached to the given run."""
    statement = text(
        """
        SELECT `tag`
        FROM run_tag
        WHERE `id` = :run_id
        """,
    )
    result = expdb.execute(statement, parameters={"run_id": run_id})
    return [row.tag for row in result]
|
|
||
|
|
||
def get_evaluations(run_id: int, expdb: Connection) -> Sequence[Row]:
    """Return all evaluation measure rows for a given run."""
    statement = text(
        """
        SELECT `function`, `value`, `array_data`
        FROM run_measure
        WHERE `run_id` = :run_id
        """,
    )
    measures = expdb.execute(statement, parameters={"run_id": run_id}).all()
    # cast: .all() is typed loosely by SQLAlchemy; the rows are plain Row objects.
    return cast("Sequence[Row]", measures)
|
|
||
|
|
||
def store_evaluation(
    *,
    run_id: int,
    function: str,
    value: float | None,
    array_data: str | None = None,
    expdb: Connection,
) -> None:
    """Insert or update a single evaluation measure for a run.

    Uses a MySQL upsert — NOTE(review): presumably run_measure has a unique
    key over (run_id, function); confirm against the schema.
    """
    upsert = text(
        """
        INSERT INTO run_measure(`run_id`, `function`, `value`, `array_data`)
        VALUES (:run_id, :function, :value, :array_data)
        ON DUPLICATE KEY UPDATE `value` = :value, `array_data` = :array_data
        """,
    )
    values = {
        "run_id": run_id,
        "function": function,
        "value": value,
        "array_data": array_data,
    }
    expdb.execute(upsert, parameters=values)
Uh oh!
There was an error while loading. Please reload this page.