Skip to content

Commit 2937d51

Browse files
committed
Refactor qa/sim runtime to config-first task flow
1 parent a78b20b commit 2937d51

15 files changed

Lines changed: 666 additions & 174 deletions

File tree

psyflow/StimUnit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Callable, Optional, List, Dict, Any, Sequence, TypeAlias, Union
44
import importlib
55
import random
6-
from .qa.context import get_context
6+
from .sim.context import get_context
77
from .io.events import TriggerEvent
88
from .sim.adapter import ResponderAdapter, ResponderActionError
99
from .sim.contracts import Feedback, Observation
@@ -83,7 +83,7 @@ def _qa_scale_duration(self, nominal_s: float) -> tuple[float, int, bool]:
8383
"""Return (used_seconds, n_frames, scaled_flag) for QA mode.
8484
8585
This keeps default behavior unchanged. In QA mode, scaling is opt-in
86-
via QAContext.config.enable_scaling.
86+
via runtime context config.enable_scaling.
8787
"""
8888
used = float(nominal_s)
8989
n_frames = max(1, int(round(used / self.frame_time)))

psyflow/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@
3838
"initialize_triggers": ("psyflow.io", "initialize_triggers"),
3939
"initialize_exp": ("psyflow.utils", "initialize_exp"),
4040
"list_supported_voices": ("psyflow.utils", "list_supported_voices"),
41+
# Task runtime option parsing helpers
42+
"TaskRunOptions": ("psyflow.task_options", "TaskRunOptions"),
43+
"build_task_arg_parser": ("psyflow.task_options", "build_task_arg_parser"),
44+
"parse_task_run_options": ("psyflow.task_options", "parse_task_run_options"),
45+
"resolve_config_path": ("psyflow.task_options", "resolve_config_path"),
46+
"resolve_mode": ("psyflow.task_options", "resolve_mode"),
47+
# Sim/QA runtime context helpers
48+
"RuntimeContext": ("psyflow.sim", "RuntimeContext"),
49+
"context_from_config": ("psyflow.sim", "context_from_config"),
50+
"runtime_context": ("psyflow.sim", "runtime_context"),
51+
"set_trial_context": ("psyflow.sim", "set_trial_context"),
4152
}
4253

4354
__all__ = ["__version__", *_LAZY_ATTRS.keys()]
@@ -51,6 +62,19 @@
5162
from .TaskSettings import TaskSettings as TaskSettings
5263
from .cli import main as cli_main
5364
from .io import initialize_triggers as initialize_triggers
65+
from .task_options import (
66+
TaskRunOptions as TaskRunOptions,
67+
build_task_arg_parser as build_task_arg_parser,
68+
parse_task_run_options as parse_task_run_options,
69+
resolve_config_path as resolve_config_path,
70+
resolve_mode as resolve_mode,
71+
)
72+
from .sim import (
73+
RuntimeContext as RuntimeContext,
74+
context_from_config as context_from_config,
75+
runtime_context as runtime_context,
76+
set_trial_context as set_trial_context,
77+
)
5478
from .utils import (
5579
count_down as count_down,
5680
initialize_exp as initialize_exp,

psyflow/io/runtime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _next_id(self) -> int:
6565
def _log(self, rec: dict[str, Any]) -> None:
6666
# 1) QA event sink (if active)
6767
try:
68-
from psyflow.qa.context import log_event
68+
from psyflow.sim.context import log_event
6969

7070
log_event(rec)
7171
except Exception:
@@ -90,7 +90,7 @@ def _effective_strict(self) -> bool:
9090
if self.strict:
9191
return True
9292
try:
93-
from psyflow.qa.context import get_context
93+
from psyflow.sim.context import get_context
9494

9595
ctx = get_context()
9696
return bool(ctx is not None and getattr(getattr(ctx, "config", None), "strict", False))

psyflow/qa/__init__.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,17 @@
44
used in lightweight static checks and CI.
55
"""
66

7-
from .context import QAConfig, QAContext, context_from_env, get_context, log_event, log_sim_event, qa_context
8-
from .responder import NullResponder, ScriptedResponder
9-
from .static import contract_lint, load_yaml, static_qa
10-
from .trace import validate_events, validate_trace_csv
117
from .artifacts import qa_artifact_paths
128
from .report import FailureCode, QAReport
9+
from .static import contract_lint, load_yaml, static_qa
10+
from .trace import validate_events, validate_trace_csv
1311

1412
__all__ = [
1513
"FailureCode",
16-
"QAConfig",
17-
"QAContext",
1814
"QAReport",
19-
"NullResponder",
20-
"ScriptedResponder",
21-
"context_from_env",
2215
"contract_lint",
23-
"get_context",
2416
"load_yaml",
25-
"log_event",
26-
"log_sim_event",
2717
"qa_artifact_paths",
28-
"qa_context",
2918
"static_qa",
3019
"validate_events",
3120
"validate_trace_csv",

psyflow/qa/responder.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

psyflow/sim/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
ScriptedResponder,
1818
SessionInfo,
1919
)
20+
from .context import (
21+
RuntimeConfig,
22+
RuntimeContext,
23+
context_from_config,
24+
get_context,
25+
log_event,
26+
log_sim_event,
27+
runtime_context,
28+
)
2029
from .context_helpers import set_trial_context
2130
from .loader import load_responder
2231
from .logging import iter_sim_events, make_sim_jsonl_logger
@@ -31,12 +40,19 @@
3140
"ResponderActionError",
3241
"ResponderAdapter",
3342
"ResponderProtocol",
43+
"RuntimeConfig",
44+
"RuntimeContext",
3445
"ScriptedResponder",
3546
"SessionInfo",
47+
"context_from_config",
48+
"get_context",
49+
"log_event",
50+
"log_sim_event",
3651
"load_responder",
3752
"iter_sim_events",
3853
"make_rng",
3954
"make_sim_jsonl_logger",
4055
"make_trial_seed",
56+
"runtime_context",
4157
"set_trial_context",
4258
]
Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@
33
import contextlib
44
import contextvars
55
import json
6-
import os
76
from dataclasses import dataclass, field
87
from datetime import datetime, timezone
98
from pathlib import Path
109
from typing import Any, Callable, Optional
1110

12-
from psyflow.sim.contracts import SessionInfo
13-
from psyflow.sim.loader import load_responder
14-
from psyflow.sim.logging import make_sim_jsonl_logger
15-
from psyflow.sim.rng import make_rng
11+
from .contracts import SessionInfo
12+
from .loader import load_responder
13+
from .logging import make_sim_jsonl_logger
14+
from .rng import make_rng
1615

1716

1817
def _cfg_get(mapping: dict[str, Any] | None, path: tuple[str, ...], default: Any = None) -> Any:
@@ -40,10 +39,13 @@ def _as_bool(value: Any, default: bool = False) -> bool:
4039
return bool(default)
4140

4241

43-
@dataclass(frozen=True)
44-
class QAConfig:
45-
"""QA/sim knobs used by the responder injection layer."""
42+
def _normalize_mode(value: Any, default: str = "human") -> str:
43+
mode = str(value or "").strip().lower() or str(default or "human").strip().lower() or "human"
44+
return mode if mode in ("human", "qa", "sim") else "human"
45+
4646

47+
@dataclass(frozen=True)
48+
class RuntimeConfig:
4749
enable_scaling: bool = False
4850
timing_scale: float = 1.0
4951
min_frames: int = 2
@@ -55,11 +57,11 @@ class QAConfig:
5557

5658

5759
@dataclass
58-
class QAContext:
60+
class RuntimeContext:
5961
mode: str = "human" # human | qa | sim
6062
responder: Any = None
6163
responder_meta: dict[str, Any] = field(default_factory=dict)
62-
config: QAConfig = field(default_factory=QAConfig)
64+
config: RuntimeConfig = field(default_factory=RuntimeConfig)
6365
event_logger: Optional[Callable[[dict[str, Any]], None]] = None
6466
sim_logger: Optional[Callable[[dict[str, Any]], None]] = None
6567
task_dir: Optional[Path] = None
@@ -68,15 +70,16 @@ class QAContext:
6870
rng: Any = None
6971

7072

71-
_CTX: contextvars.ContextVar[Optional[QAContext]] = contextvars.ContextVar("psyflow_qa_ctx", default=None)
73+
_CTX: contextvars.ContextVar[Optional[RuntimeContext]] = contextvars.ContextVar(
74+
"psyflow_runtime_ctx", default=None
75+
)
7276

7377

74-
def get_context() -> Optional[QAContext]:
78+
def get_context() -> Optional[RuntimeContext]:
7579
return _CTX.get()
7680

7781

7882
def log_event(event: dict[str, Any]) -> None:
79-
"""Best-effort QA event logging (never raise)."""
8083
ctx = get_context()
8184
if ctx is None or ctx.event_logger is None:
8285
return
@@ -87,7 +90,6 @@ def log_event(event: dict[str, Any]) -> None:
8790

8891

8992
def log_sim_event(event: dict[str, Any]) -> None:
90-
"""Best-effort simulation event logging (never raise)."""
9193
ctx = get_context()
9294
if ctx is None or ctx.sim_logger is None:
9395
return
@@ -115,8 +117,7 @@ def _log(ev: dict[str, Any]) -> None:
115117

116118

117119
@contextlib.contextmanager
118-
def qa_context(ctx: QAContext):
119-
"""Activate QA context for the current execution context."""
120+
def runtime_context(ctx: RuntimeContext):
120121
token = _CTX.set(ctx)
121122
try:
122123
yield ctx
@@ -130,54 +131,38 @@ def qa_context(ctx: QAContext):
130131
_CTX.reset(token)
131132

132133

133-
def context_from_env(
134+
def context_from_config(
134135
*,
135-
task_dir: str | os.PathLike | None = None,
136+
task_dir: str | Path | None = None,
136137
config: dict[str, Any] | None = None,
137-
) -> QAContext:
138-
"""Build QA/sim context from environment variables (+ optional config mapping)."""
138+
mode: str = "human",
139+
) -> RuntimeContext:
140+
"""Build runtime context from config with explicit mode selection."""
139141
raw_cfg = config
140142
if isinstance(config, dict) and isinstance(config.get("raw"), dict):
141143
raw_cfg = config.get("raw")
142144

143-
mode_cfg = str(_cfg_get(raw_cfg, ("sim", "mode"), "human") or "human").strip().lower()
144-
mode = os.getenv("PSYFLOW_MODE", mode_cfg).strip().lower() or "human"
145+
mode_cfg = _normalize_mode(_cfg_get(raw_cfg, ("sim", "mode"), mode))
146+
mode = _normalize_mode(mode, default=mode_cfg)
145147

146148
default_output_dir = "outputs/sim" if mode == "sim" else "outputs/qa"
147149
output_dir_cfg = _cfg_get(raw_cfg, ("sim", "output_dir"), None) or _cfg_get(raw_cfg, ("qa", "output_dir"), None)
148-
output_dir = os.getenv("PSYFLOW_QA_OUTPUT_DIR", str(output_dir_cfg or default_output_dir))
149-
150-
enable_scaling = _as_bool(
151-
os.getenv(
152-
"PSYFLOW_QA_ENABLE_SCALING",
153-
str(int(_as_bool(_cfg_get(raw_cfg, ("qa", "enable_scaling"), False)))),
154-
),
155-
False,
156-
)
157-
timing_scale = float(os.getenv("PSYFLOW_QA_TIMING_SCALE", str(_cfg_get(raw_cfg, ("qa", "timing_scale"), 1.0))))
158-
min_frames = int(os.getenv("PSYFLOW_QA_MIN_FRAMES", str(_cfg_get(raw_cfg, ("qa", "min_frames"), 2))))
159-
strict = _as_bool(
160-
os.getenv("PSYFLOW_QA_STRICT", str(int(_as_bool(_cfg_get(raw_cfg, ("qa", "strict"), False))))),
161-
False,
162-
)
163-
max_wait_s = float(os.getenv("PSYFLOW_QA_MAX_WAIT_S", str(_cfg_get(raw_cfg, ("qa", "max_wait_s"), 10.0))))
150+
output_dir = str(output_dir_cfg or default_output_dir)
164151

165-
sim_policy = str(
166-
os.getenv(
167-
"PSYFLOW_SIM_POLICY",
168-
str(_cfg_get(raw_cfg, ("sim", "policy"), "strict" if strict else "warn")),
169-
)
170-
).strip().lower()
152+
enable_scaling = _as_bool(_cfg_get(raw_cfg, ("qa", "enable_scaling"), False), False)
153+
timing_scale = float(_cfg_get(raw_cfg, ("qa", "timing_scale"), 1.0))
154+
min_frames = int(_cfg_get(raw_cfg, ("qa", "min_frames"), 2))
155+
strict = _as_bool(_cfg_get(raw_cfg, ("qa", "strict"), False), False)
156+
max_wait_s = float(_cfg_get(raw_cfg, ("qa", "max_wait_s"), 10.0))
157+
158+
sim_policy = str(_cfg_get(raw_cfg, ("sim", "policy"), "strict" if strict else "warn")).strip().lower()
171159
if sim_policy not in ("strict", "warn", "coerce"):
172160
sim_policy = "strict" if strict else "warn"
173161

174-
default_rt_s = float(os.getenv("PSYFLOW_SIM_DEFAULT_RT_S", str(_cfg_get(raw_cfg, ("sim", "default_rt_s"), 0.2))))
175-
clamp_rt = _as_bool(
176-
os.getenv("PSYFLOW_SIM_CLAMP_RT", str(int(_as_bool(_cfg_get(raw_cfg, ("sim", "clamp_rt"), False))))),
177-
False,
178-
)
162+
default_rt_s = float(_cfg_get(raw_cfg, ("sim", "default_rt_s"), 0.2))
163+
clamp_rt = _as_bool(_cfg_get(raw_cfg, ("sim", "clamp_rt"), False), False)
179164

180-
cfg = QAConfig(
165+
cfg = RuntimeConfig(
181166
enable_scaling=enable_scaling,
182167
timing_scale=timing_scale,
183168
min_frames=min_frames,
@@ -193,21 +178,15 @@ def context_from_env(
193178
or "unknown_task"
194179
)
195180
task_version = _cfg_get(raw_cfg, ("task", "task_version"), _cfg_get(config, ("task_config", "task_version"), None))
196-
seed = int(os.getenv("PSYFLOW_SIM_SEED", str(_cfg_get(raw_cfg, ("sim", "seed"), 0))))
197-
participant_id = str(
198-
os.getenv(
199-
"PSYFLOW_PARTICIPANT_ID",
200-
_cfg_get(raw_cfg, ("sim", "participant_id"), _cfg_get(raw_cfg, ("task", "participant_id"), "p000")),
201-
)
202-
or "p000"
203-
)
181+
seed = int(_cfg_get(raw_cfg, ("sim", "seed"), 0))
182+
participant_id = str(_cfg_get(raw_cfg, ("sim", "participant_id"), _cfg_get(raw_cfg, ("task", "participant_id"), "p000")) or "p000")
204183
default_session_id = f"{mode}-{participant_id}-seed{seed}"
205-
session_id = str(os.getenv("PSYFLOW_SESSION_ID", str(_cfg_get(raw_cfg, ("sim", "session_id"), default_session_id))))
184+
session_id = str(_cfg_get(raw_cfg, ("sim", "session_id"), default_session_id))
206185
session = SessionInfo(
207186
participant_id=participant_id,
208187
session_id=session_id,
209188
seed=seed,
210-
mode=mode if mode in ("human", "qa", "sim") else "human",
189+
mode=mode,
211190
task_name=task_name,
212191
task_version=task_version,
213192
)
@@ -219,9 +198,9 @@ def context_from_env(
219198
events_path = out / "qa_events.jsonl"
220199
event_logger = make_jsonl_logger(events_path) if mode in ("qa", "sim") else None
221200

222-
sim_log_default = "outputs/sim/sim_events.jsonl" if mode == "sim" else str(Path(output_dir) / "sim_events.jsonl")
201+
sim_log_default = str(Path(output_dir) / "sim_events.jsonl")
223202
sim_log_cfg = _cfg_get(raw_cfg, ("sim", "log_path"), sim_log_default)
224-
sim_log_path = os.getenv("PSYFLOW_SIM_LOG_PATH", str(sim_log_cfg))
203+
sim_log_path = str(sim_log_cfg)
225204
sim_log = (tdir / sim_log_path) if (tdir is not None and not Path(sim_log_path).is_absolute()) else Path(sim_log_path)
226205
sim_logger = make_sim_jsonl_logger(sim_log) if mode in ("qa", "sim") else None
227206

@@ -242,7 +221,7 @@ def context_from_env(
242221
allow_fallback=not strict,
243222
)
244223

245-
return QAContext(
224+
return RuntimeContext(
246225
mode=mode,
247226
responder=responder,
248227
responder_meta=responder_meta,

0 commit comments

Comments
 (0)