-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_edge_config.py
More file actions
309 lines (239 loc) · 11.6 KB
/
test_edge_config.py
File metadata and controls
309 lines (239 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
from datetime import datetime, timezone
import pytest
from groundlight.edge import (
DEFAULT,
DISABLED,
EDGE_ANSWERS_WITH_ESCALATION,
NO_CLOUD,
DetectorsConfig,
EdgeEndpointConfig,
GlobalConfig,
InferenceConfig,
)
from model import Detector, DetectorTypeEnum
# Non-default refresh_rate passed to GlobalConfig / payloads in several tests.
CUSTOM_REFRESH_RATE: float = 10.0
# confident_audit_rate passed explicitly to GlobalConfig in the wire-shape test.
CUSTOM_AUDIT_RATE: float = 0.0
# refresh_rate used by the payload and YAML parsing tests.
REFRESH_RATE_SECONDS: float = 15.0
def _make_detector(detector_id: str) -> Detector:
    """Build a minimal BINARY-mode ``Detector`` with the given ID for use in tests.

    ``created_at`` is stamped with the current UTC time; every other field is a
    fixed placeholder value.
    """
    detector_fields = {
        "id": detector_id,
        "type": DetectorTypeEnum.detector,
        "created_at": datetime.now(timezone.utc),
        "name": "test detector",
        "query": "Is there a dog?",
        "group_name": "default",
        "metadata": None,
        "mode": "BINARY",
        "mode_configuration": None,
    }
    return Detector(**detector_fields)
def test_add_detector_allows_equivalent_named_inference_config():
    """Two detectors may share one named inference config when its values agree."""
    detectors_config = DetectorsConfig()

    def _equivalent_config() -> InferenceConfig:
        # A fresh instance each call, always carrying the same values.
        return InferenceConfig(
            name="custom_config",
            always_return_edge_prediction=True,
            min_time_between_escalations=0.5,
        )

    for detector_id in ("det_1", "det_2"):
        detectors_config.add_detector(detector_id, _equivalent_config())

    assert len(detectors_config.detectors) == 2  # noqa: PLR2004
    assert list(detectors_config.edge_inference_configs.keys()) == ["custom_config"]
def test_add_detector_rejects_different_named_inference_config():
    """Reusing an inference config name with different values raises ValueError."""
    detectors_config = DetectorsConfig()
    detectors_config.add_detector("det_1", InferenceConfig(name="custom_config"))

    # Same name as the config registered above, but a different flag value.
    conflicting = InferenceConfig(name="custom_config", always_return_edge_prediction=True)
    with pytest.raises(ValueError, match="different inference config named 'custom_config'"):
        detectors_config.add_detector("det_2", conflicting)
def test_add_detector_rejects_duplicate_detector_id():
    """A detector ID can be registered only once via add_detector."""
    detectors_config = DetectorsConfig()
    repeated_id = "det_1"
    detectors_config.add_detector(repeated_id, DEFAULT)
    with pytest.raises(ValueError, match="already exists"):
        detectors_config.add_detector(repeated_id, DEFAULT)
def test_constructor_rejects_duplicate_detector_ids():
    """Constructor input that lists the same detector ID twice is rejected."""
    entry = {"detector_id": "det_1", "edge_inference_config": "default"}
    with pytest.raises(ValueError, match="Duplicate detector IDs"):
        DetectorsConfig(
            edge_inference_configs={"default": DEFAULT},
            # Two independent but identical detector entries.
            detectors=[dict(entry), dict(entry)],
        )
def test_constructor_rejects_mismatched_inference_config_key_and_name():
    """Rejects inference config dict keys that do not match config names."""
    # pytest.raises(match=...) applies re.search, so the dot is escaped to match a
    # literal "." instead of any character.
    with pytest.raises(ValueError, match=r"must match InferenceConfig\.name"):
        DetectorsConfig(
            edge_inference_configs={"default": InferenceConfig(name="not_default")},
            detectors=[],
        )
def test_constructor_accepts_matching_inference_config_key_and_name():
    """Constructor input is accepted when each dict key equals its config's name."""
    config = DetectorsConfig(
        edge_inference_configs={"default": InferenceConfig(name="default")},
        detectors=[{"detector_id": "det_1", "edge_inference_config": "default"}],
    )
    config_names = list(config.edge_inference_configs.keys())
    assert config_names == ["default"]
    detector_ids = [entry.detector_id for entry in config.detectors]
    assert detector_ids == ["det_1"]
def test_constructor_hydrates_inference_config_name_from_dict_key():
    """An inference config given without a name takes its name from the dict key."""
    config = DetectorsConfig(
        edge_inference_configs={"default": {"enabled": True}},
        detectors=[{"detector_id": "det_1", "edge_inference_config": "default"}],
    )
    hydrated = config.edge_inference_configs["default"]
    assert hydrated.name == "default"
def test_constructor_rejects_detector_map_input():
    """Detectors must be supplied as a list; a mapping keyed by ID is rejected."""
    detectors_by_id = {"det_1": {"detector_id": "det_1", "edge_inference_config": "default"}}
    with pytest.raises(ValueError):
        DetectorsConfig(
            edge_inference_configs={"default": {"enabled": True}},
            detectors=detectors_by_id,
        )
def test_constructor_rejects_undefined_inference_config_reference():
    """A detector that names an unknown inference config is rejected."""
    dangling_entry = {"detector_id": "det_1", "edge_inference_config": "does_not_exist"}
    with pytest.raises(ValueError, match="not defined"):
        DetectorsConfig(edge_inference_configs={}, detectors=[dangling_entry])
def test_edge_endpoint_config_add_detector_uses_shared_config_logic():
    """Adds detectors via EdgeEndpointConfig and preserves inferred config mapping.

    Checks the detector order, that each detector references the preset it was
    added with, and that every preset is registered under its name.
    """
    config = EdgeEndpointConfig()
    config.add_detector("det_1", NO_CLOUD)
    config.add_detector("det_2", EDGE_ANSWERS_WITH_ESCALATION)
    config.add_detector("det_3", DEFAULT)
    assert [detector.detector_id for detector in config.detectors] == ["det_1", "det_2", "det_3"]
    # The per-detector mapping is what the docstring promises. Checking only the
    # set of registered keys would miss a detector pointing at the wrong preset.
    assert [detector.edge_inference_config for detector in config.detectors] == [
        NO_CLOUD.name,
        EDGE_ANSWERS_WITH_ESCALATION.name,
        DEFAULT.name,
    ]
    assert set(config.edge_inference_configs.keys()) == {"no_cloud", "edge_answers_with_escalation", "default"}
def test_add_detector_accepts_detector_object():
    """add_detector takes a Detector model as well as a bare detector ID."""
    config = EdgeEndpointConfig()
    detector = _make_detector("det_1")
    config.add_detector(detector, DEFAULT)
    registered_ids = [entry.detector_id for entry in config.detectors]
    assert registered_ids == ["det_1"]
def test_disabled_preset_can_be_used():
    """The DISABLED preset can be assigned to a detector and is stored by its name."""
    config = EdgeEndpointConfig()
    config.add_detector("det_1", DISABLED)
    assigned_configs = [entry.edge_inference_config for entry in config.detectors]
    assert assigned_configs == ["disabled"]
    assert config.edge_inference_configs["disabled"] == DISABLED
def test_detectors_config_to_payload_shape():
    """to_payload() emits the detector list plus the inference configs it uses."""
    detectors_config = DetectorsConfig()
    for detector_id, preset in (("det_1", DEFAULT), ("det_2", NO_CLOUD)):
        detectors_config.add_detector(detector_id, preset)

    payload = detectors_config.to_payload()
    assert len(payload["detectors"]) == 2  # noqa: PLR2004
    assert set(payload["edge_inference_configs"].keys()) == {"default", "no_cloud"}
def test_edge_endpoint_config_accepts_top_level_payload_shape():
    """model_validate accepts the top-level edge endpoint payload used by the APIs."""
    raw_payload = {
        "global_config": {"refresh_rate": CUSTOM_REFRESH_RATE},
        "edge_inference_configs": {"default": {"enabled": True}},
        "detectors": [{"detector_id": "det_1", "edge_inference_config": "default"}],
    }
    config = EdgeEndpointConfig.model_validate(raw_payload)
    assert config.global_config.refresh_rate == CUSTOM_REFRESH_RATE
    assert [entry.detector_id for entry in config.detectors] == ["det_1"]
def test_edge_endpoint_config_from_yaml_accepts_yaml_text():
    """Parses edge-endpoint YAML text using EdgeEndpointConfig.from_yaml."""
    # Build the document from explicit lines so its significant indentation is
    # part of each string literal and survives reformatting of the surrounding
    # code. A flush-left document would leave ``global_config`` null and turn
    # ``refresh_rate`` into a stray top-level key.
    yaml_str = (
        "global_config:\n"
        f"  refresh_rate: {REFRESH_RATE_SECONDS}\n"
        "edge_inference_configs:\n"
        "  default:\n"
        "    enabled: true\n"
        "detectors:\n"
        "  - detector_id: det_1\n"
        "    edge_inference_config: default\n"
    )
    config = EdgeEndpointConfig.from_yaml(yaml_str=yaml_str)
    assert config.global_config.refresh_rate == REFRESH_RATE_SECONDS
    assert [detector.detector_id for detector in config.detectors] == ["det_1"]
def test_edge_endpoint_config_from_yaml_accepts_filename(tmp_path):
    """Parses edge-endpoint YAML from a file path."""
    config_file = tmp_path / "edge-config.yaml"
    # The leading spaces inside these literals are YAML structure:
    # ``default`` nests under ``edge_inference_configs``, ``enabled`` nests
    # under ``default``, and both detector keys belong to one list item.
    config_file.write_text(
        "global_config: {}\n"
        "edge_inference_configs:\n"
        "  default:\n"
        "    enabled: true\n"
        "detectors:\n"
        "  - detector_id: det_1\n"
        "    edge_inference_config: default\n",
        encoding="utf-8",
    )
    config = EdgeEndpointConfig.from_yaml(filename=str(config_file))
    assert [detector.detector_id for detector in config.detectors] == ["det_1"]
def test_edge_endpoint_config_from_yaml_requires_exactly_one_input():
    """Rejects missing input, mixed filename/yaml_str input, and a blank filename."""
    # Neither source given.
    with pytest.raises(ValueError, match="Either filename or yaml_str must be provided"):
        EdgeEndpointConfig.from_yaml()
    # Both sources given at once.
    with pytest.raises(ValueError, match="Only one of filename or yaml_str can be provided"):
        EdgeEndpointConfig.from_yaml(filename="a.yaml", yaml_str="global_config: {}")
    # A whitespace-only filename is treated as empty.
    with pytest.raises(ValueError, match="filename must be a non-empty path"):
        EdgeEndpointConfig.from_yaml(filename=" ")
def test_edge_endpoint_config_ignores_extra_fields_at_all_levels():
    """Unknown keys are dropped silently at every nesting level (forward compatibility)."""
    global_section = {"refresh_rate": REFRESH_RATE_SECONDS, "unknown_global_field": "ignored"}
    inference_section = {
        "default": {"enabled": True, "unknown_inference_field": 42},
    }
    detector_section = [
        {"detector_id": "det_1", "edge_inference_config": "default", "unknown_detector_field": [1, 2]},
    ]
    config = EdgeEndpointConfig.model_validate({
        "global_config": global_section,
        "edge_inference_configs": inference_section,
        "detectors": detector_section,
        "unknown_top_level_field": True,
    })
    assert config.global_config.refresh_rate == REFRESH_RATE_SECONDS
    assert config.edge_inference_configs["default"].enabled is True
    assert config.detectors[0].detector_id == "det_1"
def test_model_dump_shape_for_edge_endpoint_config():
    """to_payload() on a fully populated EdgeEndpointConfig yields the wire shape."""
    global_config = GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
    config = EdgeEndpointConfig(global_config=global_config)
    presets_by_detector = {
        "det_1": DEFAULT,
        "det_2": EDGE_ANSWERS_WITH_ESCALATION,
        "det_3": NO_CLOUD,
    }
    for detector_id, preset in presets_by_detector.items():
        config.add_detector(detector_id, preset)

    payload = config.to_payload()
    global_payload = payload["global_config"]
    assert global_payload["refresh_rate"] == CUSTOM_REFRESH_RATE
    assert global_payload["confident_audit_rate"] == CUSTOM_AUDIT_RATE
    assert len(payload["detectors"]) == 3  # noqa: PLR2004
    assert set(payload["edge_inference_configs"].keys()) == {"default", "edge_answers_with_escalation", "no_cloud"}
def test_edge_endpoint_config_from_payload_round_trip():
    """to_payload() followed by from_payload() reproduces an equal config."""
    original = EdgeEndpointConfig()
    original.add_detector("det_1", DEFAULT)
    original.add_detector("det_2", NO_CLOUD)
    assert EdgeEndpointConfig.from_payload(original.to_payload()) == original
def test_edge_endpoint_config_from_payload_accepts_literal_payload():
    """from_payload() builds an EdgeEndpointConfig from a hand-written dictionary."""
    config = EdgeEndpointConfig.from_payload({
        "global_config": {"refresh_rate": REFRESH_RATE_SECONDS},
        "edge_inference_configs": {"default": {"enabled": True}},
        "detectors": [{"detector_id": "det_1", "edge_inference_config": "default"}],
    })
    assert config.global_config.refresh_rate == REFRESH_RATE_SECONDS
    # The config's name is filled in from its dict key.
    assert config.edge_inference_configs["default"].name == "default"
    assert [entry.detector_id for entry in config.detectors] == ["det_1"]
def test_inference_config_validation_errors():
    """Raises on invalid inference config flag combinations and values."""
    # disable_cloud_escalation=True with every other flag left at its default is
    # rejected. The exact prerequisite flag is not shown here — confirm against
    # InferenceConfig's validator.
    with pytest.raises(ValueError, match="disable_cloud_escalation"):
        InferenceConfig(name="bad", disable_cloud_escalation=True)
    # A negative escalation interval is rejected. pytest.raises(match=...) uses
    # re.search, so the dot is escaped to match the literal "0.0".
    with pytest.raises(ValueError, match=r"cannot be less than 0\.0"):
        InferenceConfig(
            name="bad_escalation_interval",
            always_return_edge_prediction=True,
            min_time_between_escalations=-1.0,
        )