Skip to content

Commit 5b63e06

Browse files
committed
[ENH] Implement tagging for OpenMLSetup (openml#1686)
- Make OpenMLSetup inherit from OpenMLBase to enable tagging
- Add OpenMLSetup to _get_rest_api_type_alias mapping
- Add test_tagging test for push_tag and remove_tag
1 parent 7feb2a3 commit 5b63e06

3 files changed

Lines changed: 52 additions & 29 deletions

File tree

openml/setups/setup.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4+
from collections.abc import Sequence
45
from dataclasses import asdict, dataclass
56
from typing import Any
67

7-
import openml.config
88
import openml.flows
9+
from openml.base import OpenMLBase
910

1011

11-
@dataclass
12-
class OpenMLSetup:
12+
class OpenMLSetup(OpenMLBase):
1313
"""Setup object (a.k.a. Configuration).
1414
1515
Parameters
@@ -22,33 +22,32 @@ class OpenMLSetup:
2222
The setting of the parameters
2323
"""
2424

25-
setup_id: int
26-
flow_id: int
27-
parameters: dict[int, Any] | None
28-
29-
def __post_init__(self) -> None:
30-
if not isinstance(self.setup_id, int):
25+
def __init__(
26+
self,
27+
setup_id: int,
28+
flow_id: int,
29+
parameters: dict[int, Any] | None,
30+
) -> None:
31+
if not isinstance(setup_id, int):
3132
raise ValueError("setup id should be int")
3233

33-
if not isinstance(self.flow_id, int):
34+
if not isinstance(flow_id, int):
3435
raise ValueError("flow id should be int")
3536

36-
if self.parameters is not None and not isinstance(self.parameters, dict):
37+
if parameters is not None and not isinstance(parameters, dict):
3738
raise ValueError("parameters should be dict")
3839

39-
def _to_dict(self) -> dict[str, Any]:
40-
return {
41-
"setup_id": self.setup_id,
42-
"flow_id": self.flow_id,
43-
"parameters": {p.id: p._to_dict() for p in self.parameters.values()}
44-
if self.parameters is not None
45-
else None,
46-
}
40+
self.setup_id = setup_id
41+
self.flow_id = flow_id
42+
self.parameters = parameters
4743

48-
def __repr__(self) -> str:
49-
header = "OpenML Setup"
50-
header = f"{header}\n{'=' * len(header)}\n"
44+
@property
45+
def id(self) -> int | None:
46+
"""The ID of the setup."""
47+
return self.setup_id
5148

49+
def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str] | None]]:
50+
"""Collect all information to display in the __repr__ body."""
5251
fields = {
5352
"Setup ID": self.setup_id,
5453
"Flow ID": self.flow_id,
@@ -57,15 +56,21 @@ def __repr__(self) -> str:
5756
len(self.parameters) if self.parameters is not None else float("nan")
5857
),
5958
}
60-
61-
# determines the order in which the information will be printed
6259
order = ["Setup ID", "Flow ID", "Flow URL", "# of Parameters"]
63-
_fields = [(key, fields[key]) for key in order if key in fields]
60+
return [(key, fields[key]) for key in order if key in fields]
6461

65-
longest_field_name_length = max(len(name) for name, _ in _fields)
66-
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
67-
body = "\n".join(field_line_format.format(name, value) for name, value in _fields)
68-
return header + body
62+
def _to_dict(self) -> dict[str, Any]:
63+
return {
64+
"setup_id": self.setup_id,
65+
"flow_id": self.flow_id,
66+
"parameters": {p.id: p._to_dict() for p in self.parameters.values()}
67+
if self.parameters is not None
68+
else None,
69+
}
70+
71+
def _parse_publish_response(self, xml_response: dict[str, str]) -> None:
72+
msg = "Setups cannot be published directly."
73+
raise NotImplementedError(msg)
6974

7075

7176
@dataclass

openml/utils/_openml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _get_rest_api_type_alias(oml_object: OpenMLBase) -> str:
9999
rest_api_mapping: list[tuple[type | tuple, str]] = [
100100
(openml.datasets.OpenMLDataset, "data"),
101101
(openml.flows.OpenMLFlow, "flow"),
102+
(openml.setups.OpenMLSetup, "setup"),
102103
(openml.tasks.OpenMLTask, "task"),
103104
(openml.runs.OpenMLRun, "run"),
104105
((openml.study.OpenMLStudy, openml.study.OpenMLBenchmarkSuite), "study"),

tests/test_setups/test_setup_functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,20 @@ def test_get_uncached_setup(self):
189189
openml.config.set_root_cache_directory(self.static_cache_dir)
190190
with pytest.raises(openml.exceptions.OpenMLCacheException):
191191
openml.setups.functions._get_cached_setup(10)
192+
193+
@pytest.mark.test_server()
194+
def test_tagging(self):
195+
setups = openml.setups.list_setups(size=1)
196+
setup_id = next(iter(setups.keys()))
197+
setup = openml.setups.get_setup(setup_id)
198+
unique_indicator = str(time.time()).replace(".", "")
199+
tag = f"test_tag_TestSetup_{unique_indicator}"
200+
tagged_setups = openml.setups.list_setups(tag=tag)
201+
assert len(tagged_setups) == 0
202+
setup.push_tag(tag)
203+
tagged_setups = openml.setups.list_setups(tag=tag)
204+
assert len(tagged_setups) == 1
205+
assert setup_id in tagged_setups
206+
setup.remove_tag(tag)
207+
tagged_setups = openml.setups.list_setups(tag=tag)
208+
assert len(tagged_setups) == 0

0 commit comments

Comments
 (0)