diff --git a/mypy.ini b/mypy.ini index 6f189e2e9..7c9443a59 100644 --- a/mypy.ini +++ b/mypy.ini @@ -18,6 +18,26 @@ warn_unreachable = False [mypy-scripts.local.launch_simp] warn_unreachable = False +# Domino debug/analysis scripts (init-state rendering, sketch replay, failure +# reproduction): exploratory tooling that is heavy on untyped third-party calls +# (PIL drawing etc.), so the strict def/call typing required of library code is +# relaxed here, mirroring the per-script carve-outs above. +[mypy-scripts.render_unsolved_domino_states] +disallow_untyped_defs = False +disallow_untyped_calls = False + +[mypy-scripts.render_domino_initial_states] +disallow_untyped_defs = False +disallow_untyped_calls = False + +[mypy-scripts.replay_domino_sketches] +disallow_untyped_defs = False +disallow_untyped_calls = False + +[mypy-scripts.reproduce_domino_failures] +disallow_untyped_defs = False +disallow_untyped_calls = False + [mypy-predicators.tests.*] ignore_missing_imports = True diff --git a/predicators/agent_sdk/agent_session_mixin.py b/predicators/agent_sdk/agent_session_mixin.py index c4e8a8396..be0d65134 100644 --- a/predicators/agent_sdk/agent_session_mixin.py +++ b/predicators/agent_sdk/agent_session_mixin.py @@ -2,7 +2,7 @@ Extracts common code for ToolContext initialization, lazy AgentSessionManager creation, async-to-sync bridging, and agent explorer -creation from AgentPlannerApproach and AgentAbstractionLearningApproach. +creation shared by AgentPlannerApproach and its subclasses. 
""" import asyncio import logging diff --git a/predicators/agent_sdk/bilevel_sketch.py b/predicators/agent_sdk/bilevel_sketch.py index 238072d1c..dcf435304 100644 --- a/predicators/agent_sdk/bilevel_sketch.py +++ b/predicators/agent_sdk/bilevel_sketch.py @@ -12,7 +12,7 @@ import dataclasses import logging import re -from typing import Callable, Collection, List, Optional, Sequence, Set, \ +from typing import Callable, Collection, Dict, List, Optional, Sequence, Set, \ Tuple, cast import numpy as np @@ -20,8 +20,8 @@ from predicators import utils from predicators.option_model import _OptionModelBase from predicators.planning import run_backtracking_refinement -from predicators.structs import GroundAtom, Object, ParameterizedOption, \ - Predicate, State, Task, Type, _Option +from predicators.structs import GroundAtom, Object, OptionSampler, \ + ParameterizedOption, Predicate, State, Task, Type, _Option # Signature of an info-gain scorer: given a candidate post-state and the # atoms whose truth the step is meant to establish, return a scalar where @@ -101,11 +101,18 @@ def build_solve_prompt( trajectory_summary: str = "", tool_names: Optional[Sequence[str]] = None, experiment_guidance: str = "", + prior_failures: str = "", ) -> str: """Build the bilevel solve/explore prompt asking for a plan sketch. Mirrors ``AgentBilevelApproach._build_solve_prompt`` but takes dependencies explicitly so explorers can reuse it. + + ``prior_failures`` is a pre-formatted block summarizing earlier + sketch attempts that the backtracking search could not refine (with a + pointer to the full per-step log in the sandbox). Injected so a + re-query produces a *different* skeleton instead of re-emitting the + dead one. 
""" init_state = task.init objects = list(init_state) @@ -157,6 +164,18 @@ def build_solve_prompt( experiment_section = (f"\n## Experiment Guidance\n" f"{experiment_guidance}\n") + prior_failures_section = "" + if prior_failures: + prior_failures_section = ( + "\n## Previous Sketch Attempts (FAILED — do NOT repeat them)\n" + "Each block below is a sketch you already tried and the " + "backtracking search could NOT refine, with where it got stuck " + "and a pointer to the full per-step refinement log (read it with " + "`Read` for details). Produce a DIFFERENT skeleton that avoids " + "the failure — change the step that got stuck (object choice, " + "ordering, an intermediate step, or its subgoal annotation).\n" + f"{prior_failures}\n") + goal_nl_section = "" if task.goal_nl: goal_nl_section = f"\n## Goal Description\n{task.goal_nl}\n" @@ -168,7 +187,11 @@ def build_solve_prompt( pred_strs = [] for pred in sorted(all_predicates, key=lambda p: p.name): type_sig = ", ".join(t.name for t in pred.types) - pred_strs.append(f" {pred.name}({type_sig})") + line = f" {pred.name}({type_sig})" + if pred.natural_language_assertion is not None: + names = [t.name for t in pred.types] + line += f" — {pred.natural_language_assertion(names)}" + pred_strs.append(line) prompt = f"""You are solving a task. \ Generate a plan sketch to achieve the goal. @@ -187,7 +210,7 @@ def build_solve_prompt( ## Available Predicates (for subgoal annotations) {chr(10).join(pred_strs)} -{trajectory_summary}{tools_str} +{trajectory_summary}{tools_str}{prior_failures_section} ## Instructions Use your available tools to inspect the environment before producing the plan. 
@@ -246,7 +269,11 @@ def parse_subgoal_annotations( results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] for line in text.split('\n'): - stripped = line.strip() + # Mirror the enumeration-prefix tolerance in the option-plan + # parser so the per-line subgoal results stay index-parallel with + # the parsed options (a numbered "0: Pick(...)" line must be seen + # as an option line here too, else annotations misalign). + stripped = utils.strip_enumeration_prefix(line.strip()) if not stripped: continue first_token = stripped.split('(')[0] @@ -368,6 +395,7 @@ def refine_sketch( elapsed_holder: Optional[List[float]] = None, info_scorer: Optional[InfoScorer] = None, info_n_feasible_target: int = 1, + option_samplers: Optional[Dict[str, OptionSampler]] = None, ) -> Tuple[List[_Option], bool, int]: """Backtracking search over continuous parameters for a plan sketch. @@ -415,6 +443,14 @@ def refine_sketch( from the sketch's subgoal annotations into ``grounded.memory`` so that ``WaitOption`` terminates on the intended atom change rather than the first incidental one. + + ``option_samplers`` maps an option name to a per-skill sampler + ``(state, subgoal_atoms, rng, objects) -> params`` (the NSRTSampler + signature, with the step subgoal in the atoms slot), used on both + plain and info-seeking draws to aim that option's parameters at the + subgoal instead of drawing uniformly. The return is clipped to the + option's box; a missing or misbehaving sampler falls back to uniform + sampling. """ if not sketch: return [], False, 0 @@ -431,6 +467,42 @@ def refine_sketch( deepest_fail_idx: List[int] = [-1] deepest_fail_prefix: List[List[Optional[_Option]]] = [[]] + # Options whose synthesized sampler already misbehaved once — so the + # per-draw fallback warning fires at most once per option, not on every + # one of the (potentially thousands of) draws during backtracking. 
+ _sampler_warned: Set[str] = set() + + def _draw_params(step: SketchStep, state: State, + rng_: np.random.Generator) -> np.ndarray: + """Draw continuous params for a step's option. + + Uses a registered per-skill sampler (keyed by option name) when + present, else falls back to uniform ``sample_params`` — also on + a sampler error or wrong-shaped return. + """ + sampler = (option_samplers.get(step.option.name) + if option_samplers else None) + if sampler is not None: + box = step.option.params_space + expected = box.shape[0] + try: + raw = sampler(state, step.subgoal_atoms or set(), rng_, + list(step.objects)) + params = np.asarray(raw, dtype=np.float32).reshape(-1) + if params.shape == (expected, ): + return np.clip(params, box.low, box.high) + reason = (f"returned shape {params.shape}, " + f"expected ({expected},)") + except Exception as e: # pylint: disable=broad-except + reason = f"raised {type(e).__name__}: {e}" + if step.option.name not in _sampler_warned: + _sampler_warned.add(step.option.name) + logging.warning( + "[%s] synthesized sampler for %s %s; falling back to " + "uniform sampling for this option.", run_id, + step.option.name, reason) + return sample_params(step.option, rng_) + def _ground(step: SketchStep, params: np.ndarray) -> _Option: grounded = step.option.ground(list(step.objects), params) if grounded.name == "Wait": @@ -458,10 +530,21 @@ def _info_seeking_applies(step: SketchStep) -> bool: # step exhausts precisely when every pooled candidate has been tried # (with 1-draw fillers for attempts left over when the pool came up # short of the target). + def _is_deterministic(step: SketchStep) -> bool: + # A sampler may flag itself as returning constant params (ignoring + # state/rng); re-drawing it yields the identical option, so its step + # gets a single attempt -- backtracking then skips straight past it + # instead of wasting the full budget re-descending through it. 
+ sampler = (option_samplers.get(step.option.name) + if option_samplers else None) + return bool(getattr(sampler, "deterministic", False)) + max_tries = [] for _step in sketch: if _step.option.params_space.shape[0] == 0: max_tries.append(1) + elif _is_deterministic(_step): + max_tries.append(1) elif _info_seeking_applies(_step): max_tries.append(info_n_feasible_target) else: @@ -538,7 +621,7 @@ def _sample_info_seeking(step: SketchStep, state: State, first_candidate: Optional[_Option] = None n_draws = 0 while len(scored) < info_n_feasible_target and n_draws < draw_cap: - grounded = _ground(step, sample_params(step.option, rng_)) + grounded = _ground(step, _draw_params(step, state, rng_)) n_draws += 1 if first_candidate is None: first_candidate = grounded @@ -610,7 +693,7 @@ def sample_fn(idx: int, state: State, f"{state.pretty_str()}") if _info_seeking_applies(step): return _sample_info_seeking(step, state, rng_, idx) - return _ground(step, sample_params(step.option, rng_)) + return _ground(step, _draw_params(step, state, rng_)) def validate_fn(idx: int, _pre_state: State, _option: _Option, post_state: State, _num_actions: int) -> Tuple[bool, str]: @@ -861,3 +944,146 @@ def validate_fn(i: int, _pre: State, _opt: _Option, post: State, completed, opt_str, last_err or "unknown reason") return False, diagnosis_holder[0] or "validation failed" + + +def resolve_refine_timeout( + timeout: Optional[float], + n_steps: int, + *, + per_step: float, + minimum: float, +) -> Tuple[float, str]: + """Resolve a refinement timeout, auto-scaling by sketch length. + + When ``timeout`` is None it auto-scales as + ``max(minimum, per_step * n_steps)`` so longer sketches get more + budget. Returns ``(timeout_seconds, source)`` where ``source`` is + ``"auto"`` or ``"explicit"``. Config defaults are passed in (not read + from ``CFG``) to keep this module settings-free. 
+ """ + if timeout is None: + return float(max(minimum, per_step * n_steps)), "auto" + return float(timeout), "explicit" + + +def refine_and_validate_report( + task: Task, + sketch: List[SketchStep], + option_model: _OptionModelBase, + *, + predicates: Set[Predicate], + timeout: float, + rng: np.random.Generator, + max_samples_per_step: int, + check_subgoals: bool, + log_state: bool = False, + option_samplers: Optional[Dict[str, OptionSampler]] = None, + run_id: str = "refine", + timeout_source: str = "explicit", + extra_summary_lines: Optional[List[str]] = None, +) -> Tuple[bool, str]: + """Refine a sketch, forward-validate on success, return a report. + + Runs ``refine_sketch`` (backtracking search over continuous params) + and, when refinement succeeds, ``validate_plan_forward`` (continuous + re-execution). Returns ``(overall_success, human_readable_report)`` + where ``overall_success`` is True only if both refinement and forward + validation pass. The report names the verdict (SUCCESS / TIMEOUT / + SAMPLE_EXHAUSTED / FORWARD_VALIDATION_FAILED), per-step sample counts, + the stuck step on failure, and the forward-validation outcome. + + ``extra_summary_lines`` are appended verbatim after the time line + (e.g. a caller-specific ``Post-fit SSE`` line). Config-derived knobs + (``timeout``, ``max_samples_per_step``, ``check_subgoals``, + ``log_state``) are passed explicitly so this module stays free of + ``CFG``; callers read them from settings. 
+ """ + step_samples_cumulative: List[int] = [0] * len(sketch) + termination_reason: List[str] = [] + elapsed_holder: List[float] = [] + plan, success, n_samples = refine_sketch( + task, + sketch, + option_model, + predicates=predicates, + timeout=timeout, + rng=rng, + max_samples_per_step=max_samples_per_step, + check_subgoals=check_subgoals, + log_state=log_state, + run_id=run_id, + step_samples_cumulative=step_samples_cumulative, + termination_reason=termination_reason, + elapsed_holder=elapsed_holder, + option_samplers=option_samplers, + ) + + reason = termination_reason[0] if termination_reason else ( + "success" if success else "exhausted") + elapsed = elapsed_holder[0] if elapsed_holder else 0.0 + if success: + verdict = "SUCCESS" + elif reason == "timeout": + verdict = "FAILURE: TIMEOUT" + elif reason == "exhausted": + verdict = "FAILURE: SAMPLE_EXHAUSTED" + else: + verdict = "FAILURE" + + lines = [ + verdict, + f" Sketch: {len(sketch)} steps Refined: {len(plan)} steps " + f"Samples: {n_samples} total", + f" Per-step samples: {step_samples_cumulative} " + f"(cap {max_samples_per_step}/step)", + f" Time: {elapsed:.1f}s used / {timeout:.1f}s allotted " + f"(timeout source: {timeout_source})", + ] + if extra_summary_lines: + lines.extend(extra_summary_lines) + if not success and len(plan) < len(sketch): + stuck_idx = len(plan) + stuck = sketch[stuck_idx] + objs = ", ".join(f"{o.name}:{o.type.name}" for o in stuck.objects) + lines.append(f" Stuck at step {stuck_idx}: " + f"{stuck.option.name}({objs})") + if stuck.subgoal_atoms: + atoms = ", ".join(str(a) for a in stuck.subgoal_atoms) + lines.append(f" subgoals: {atoms}") + + # Forward validation: re-execute the refined plan continuously (state + # carries forward across all options). Refinement's per-step resets + # and resampling can mask drift the real env will hit at test time. 
+ if success: + try: + fv_ok, fv_reason = validate_plan_forward( + task, + plan, + option_model, + predicates=predicates, + sketch=sketch, + run_id=run_id, + ) + except Exception as e: # pylint: disable=broad-except + fv_ok = False + fv_reason = f"forward validation raised: {e}" + if fv_ok: + lines.append(" Forward validation: SUCCESS") + else: + # Demote the headline verdict: refinement passed but the plan + # does not survive continuous execution, which is what the + # real env will see at test time. + success = False + lines[0] = "FAILURE: FORWARD_VALIDATION_FAILED" + lines.append(f" Forward validation: FAIL — {fv_reason}") + lines.append( + " (Refinement resets state between options and " + "resamples up to the per-step cap; forward validation " + "runs the same plan once continuously. A divergence here " + "means the refined plan does not survive continuous " + "execution — accumulated drift, or (when the model is " + "learned) a rule/threshold more permissive than the env's " + "effective behavior. See the INFO log for the step-by-step " + "divergence.)") + + return success, "\n".join(lines) diff --git a/predicators/agent_sdk/docker_sandbox.py b/predicators/agent_sdk/docker_sandbox.py index 2491e1cf9..888de6c00 100644 --- a/predicators/agent_sdk/docker_sandbox.py +++ b/predicators/agent_sdk/docker_sandbox.py @@ -47,7 +47,7 @@ from predicators.agent_sdk.sandbox_prompts import build_claude_md, \ build_sandbox_system_prompt, find_repo_root, setup_sandbox_directory -from predicators.agent_sdk.tools import ToolContext +from predicators.agent_sdk.tools import ToolContext, session_log_filename from predicators.settings import CFG logger = logging.getLogger(__name__) @@ -234,7 +234,9 @@ async def query(self, # Counter-first layout: alphabetical sort matches chronological # order across mixed ``learn``/``test``/``explore`` phases. 
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - log_filename = f"{self._query_count:03d}_{kind}_{timestamp}.md" + log_filename = session_log_filename( + self._query_count, kind, timestamp, + getattr(self._tool_context, "test_task_idx", None)) if self._log_dir: os.makedirs(self._log_dir, exist_ok=True) incremental_log_path = os.path.join(self._log_dir, log_filename) @@ -540,7 +542,9 @@ def _save_query_response_log(self, query: str, timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") kind = getattr(self, "_last_kind", "query") - filename = f"{self._query_count:03d}_{kind}_{timestamp}.md" + filename = session_log_filename( + self._query_count, kind, timestamp, + getattr(self._tool_context, "test_task_idx", None)) filepath = os.path.join(self._log_dir, filename) lines = [ diff --git a/predicators/agent_sdk/local_sandbox.py b/predicators/agent_sdk/local_sandbox.py index eb6fc8863..84e5450fd 100644 --- a/predicators/agent_sdk/local_sandbox.py +++ b/predicators/agent_sdk/local_sandbox.py @@ -33,7 +33,8 @@ from predicators.agent_sdk.sandbox_prompts import build_claude_md, \ build_sandbox_system_prompt, find_repo_root, setup_sandbox_directory, \ truncate -from predicators.agent_sdk.tools import BUILTIN_TOOLS, ToolContext +from predicators.agent_sdk.tools import BUILTIN_TOOLS, ToolContext, \ + session_log_filename from predicators.settings import CFG logger = logging.getLogger(__name__) @@ -75,6 +76,10 @@ def __init__( self._phase = phase self._total_cost_usd: float = 0.0 + # The SDK reports total_cost_usd as the cumulative cost of the + # reused session, so we track the last value seen to derive the + # per-solve (marginal) cost of each query. 
+ self._last_cost_usd: float = 0.0 self._total_turns: int = 0 self._query_count: int = 0 self._session_id: Optional[str] = None @@ -244,15 +249,30 @@ async def query(self, logging.debug("Agent tool call: %s(%s)", block["name"], param_summary) elif entry["type"] == "result": - cost = entry.get("total_cost_usd") + cost: Optional[float] = entry.get("total_cost_usd") turns = entry.get("num_turns") + solve_cost: Optional[float] = None if cost is not None: - self._total_cost_usd += cost + # cost is the session's cumulative total; the + # per-solve cost is the delta since the last result. + # A drop below the last value means the session was + # reset (e.g. recovery), so the new cumulative is + # itself the delta. + solve_cost = (cost - self._last_cost_usd + if cost >= self._last_cost_usd else cost) + self._last_cost_usd = cost + self._total_cost_usd += solve_cost + self._current_log_meta["solve_cost_usd"] = solve_cost + self._current_log_meta["total_cost_usd"] = \ + self._total_cost_usd if turns is not None: self._total_turns += turns logging.info( - "Local sandbox iteration complete. " - "Turns: %s, Cost: $%s", turns or '?', cost or '?') + "Local sandbox iteration complete. Turns: %s, " + "Cost this solve: $%s, Total cost so far: $%s", turns + or '?', + f"{solve_cost:.4f}" if solve_cost is not None else '?', + f"{self._total_cost_usd:.4f}") # Flush log after each message if log_path: @@ -333,11 +353,13 @@ def save_session_info(self) -> None: # -- Logging helpers -- - # Matches both the new ``NNN_kind_ts.md`` layout and the legacy + # Matches the new ``NNN_kind[_taskN]_ts.md`` layout and the legacy # ``kind_NNN_ts.md`` layout so resuming across the migration is - # lossless. The counter is always captured in group 1 or 2. + # lossless. The counter is always captured in group 1 or 2; the + # optional ``_task`` segment tags test queries with their task. 
_LOG_FILENAME_RE = re.compile( - r"^(?:(\d{3})_[a-z][a-z_]*|[a-z][a-z_]*_(\d{3}))_\d{8}_\d{6}\.md$") + r"^(?:(\d{3})_[a-z][a-z_]*(?:_task\d+)?|[a-z][a-z_]*_(\d{3}))" + r"_\d{8}_\d{6}\.md$") def _seed_query_count_from_log_dir(self) -> None: """Make the per-session counter continuous across the run. @@ -375,8 +397,11 @@ def _init_incremental_log(self, timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") # Counter-first layout: alphabetical sort matches chronological - # order across mixed ``learn``/``test``/``explore`` phases. - filename = f"{self._query_count:03d}_{kind}_{timestamp}.md" + # order across mixed ``learn``/``test``/``explore`` phases. Test + # queries also carry a ``_task`` segment for attribution. + filename = session_log_filename( + self._query_count, kind, timestamp, + getattr(self._tool_context, "test_task_idx", None)) # Primary: main log dir (host-visible) filepath = os.path.join(self._log_dir, filename) os.makedirs(self._log_dir, exist_ok=True) diff --git a/predicators/agent_sdk/log_formatter.py b/predicators/agent_sdk/log_formatter.py index c1eac0451..65f904d16 100644 --- a/predicators/agent_sdk/log_formatter.py +++ b/predicators/agent_sdk/log_formatter.py @@ -59,8 +59,17 @@ def format_conversation_markdown( elif etype == "result": turns = entry.get("num_turns", "?") - cost = entry.get("total_cost_usd") - cost_str = f"${cost:.2f}" if cost is not None else "?" + # Prefer the per-solve/total split the sandbox derives (the + # raw total_cost_usd is the cumulative session cost); fall + # back to the raw cumulative value when it isn't supplied. + solve_cost = meta.get("solve_cost_usd") if meta else None + total_cost = meta.get("total_cost_usd") if meta else None + if solve_cost is not None and total_cost is not None: + cost_str = (f"${solve_cost:.2f} this solve, " + f"${total_cost:.2f} total") + else: + cost = entry.get("total_cost_usd") + cost_str = f"${cost:.2f}" if cost is not None else "?" 
lines.append(f"---\n\n**Result:** {turns} turns, {cost_str}\n") elif etype == "error": diff --git a/predicators/agent_sdk/proposal_parser.py b/predicators/agent_sdk/proposal_parser.py index b65dd8005..d53082092 100644 --- a/predicators/agent_sdk/proposal_parser.py +++ b/predicators/agent_sdk/proposal_parser.py @@ -21,7 +21,7 @@ class ProposalBundle: # Retractions: names of previously-proposed abstractions to remove retract_type_names: Set[str] = field(default_factory=set) retract_predicate_names: Set[str] = field(default_factory=set) - retract_object_augmentor: bool = False + retract_task_augmentor: bool = False retract_process_names: Set[str] = field(default_factory=set) retract_option_names: Set[str] = field(default_factory=set) diff --git a/predicators/agent_sdk/sandbox_prompts.py b/predicators/agent_sdk/sandbox_prompts.py index 543b25449..a6bdd6e51 100644 --- a/predicators/agent_sdk/sandbox_prompts.py +++ b/predicators/agent_sdk/sandbox_prompts.py @@ -174,7 +174,7 @@ def deny(reason): Read ./session_logs/001_learn_*.md ## Scene Images -`test_option_plan` automatically saves scene images to ./test_images/ +`evaluate_option_plan` automatically saves scene images to ./test_images/ after each step. You can Read them to inspect the spatial state of the environment. @@ -207,7 +207,7 @@ def deny(reason): - **Use visualize_state liberally** — it's free (no physics, no failure modes). When stuck on a step, STOP testing and visualize the object at several candidate positions and orientations to find the right region - before spending more test_option_plan calls. + before spending more evaluate_option_plan calls. - **Vary all parameters** — orientation and other non-position params affect both the outcome and whether the action succeeds. 
- **Search coarse-to-fine** — spread initial attempts across the full @@ -332,7 +332,7 @@ def build_sandbox_system_prompt( ``` ### Scene Images -`test_option_plan` automatically saves scene images to ./test_images/ +`evaluate_option_plan` automatically saves scene images to ./test_images/ after each plan step for later review. ### Proposed Code diff --git a/predicators/agent_sdk/session_manager.py b/predicators/agent_sdk/session_manager.py index bff8331b1..073b6dd18 100644 --- a/predicators/agent_sdk/session_manager.py +++ b/predicators/agent_sdk/session_manager.py @@ -4,6 +4,7 @@ import json import logging import os +import time from typing import Any, Dict, List, Optional from predicators.agent_sdk.response_parser import parse_message @@ -33,6 +34,9 @@ def __init__(self, self._client: Any = None self._session_id: Optional[str] = None self._total_cost_usd: float = 0.0 + # total_cost_usd from the SDK is the cumulative session cost; track + # the last value to derive each query's per-solve (marginal) cost. + self._last_cost_usd: float = 0.0 self._total_turns: int = 0 self._started = False self._query_count: int = 0 @@ -132,6 +136,11 @@ async def query(self, collected: List[Dict[str, Any]] = [] log_path = self._init_incremental_log(message, kind=kind) + start = time.perf_counter() + # Wall-clock of the previous response message, so each logged step + # can report how long it took (model thinking before a tool call, + # tool execution before the next message, etc.). 
+ prev_t = start try: await self._client.query(message) @@ -140,29 +149,49 @@ async def query(self, if entry is None: continue collected.append(entry) + now = time.perf_counter() + dt = now - prev_t + prev_t = now # Log side-effects if entry["type"] == "assistant": for block in entry.get("content", []): if block.get("type") == "text": - logging.debug("Agent: %s...", block["text"][:200]) + logging.debug("[+%.2fs] Agent: %s...", dt, + block["text"][:200]) + elif block.get("thinking") is not None: + logging.debug("[+%.2fs] Agent [thinking]: %s...", + dt, block["thinking"][:200]) elif block.get("type") == "tool_use": params = block.get("input") or {} param_summary = ", ".join( f"{k}={truncate(v)}" for k, v in params.items()) - logging.debug("Agent tool call: %s(%s)", - block["name"], param_summary) + logging.debug("[+%.2fs] Agent tool call: %s(%s)", + dt, block["name"], param_summary) elif entry["type"] == "result": - cost = entry.get("total_cost_usd") + cost: Optional[float] = entry.get("total_cost_usd") turns = entry.get("num_turns") + solve_cost: Optional[float] = None if cost is not None: - self._total_cost_usd += cost + # cost is cumulative; the per-solve cost is the + # delta since the last result (a drop means the + # session reset, so the new total is the delta). + solve_cost = (cost - self._last_cost_usd + if cost >= self._last_cost_usd else cost) + self._last_cost_usd = cost + self._total_cost_usd += solve_cost + self._current_log_meta["solve_cost_usd"] = solve_cost + self._current_log_meta["total_cost_usd"] = \ + self._total_cost_usd if turns is not None: self._total_turns += turns logging.info( - "Agent iteration complete. Turns: %s, Cost: $%s", turns - or '?', cost or '?') + "Agent iteration complete. 
Turns: %s, " + "Cost this solve: $%s, Total cost so far: $%s", turns + or '?', + f"{solve_cost:.4f}" if solve_cost is not None else '?', + f"{self._total_cost_usd:.4f}") # Flush log after each message if log_path: @@ -173,6 +202,10 @@ async def query(self, collected.append({"type": "error", "error": str(e)}) await self._recover_session(message) + elapsed = time.perf_counter() - start + logging.info("[agent-interaction] kind=%s took %.2fs (%d messages)", + kind, elapsed, len(collected)) + # Final flush to ensure everything is saved if log_path: self._flush_log(log_path, collected) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index 19b375011..1747c0e2e 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -14,8 +14,8 @@ build_exec_context, exec_code_safely, validate_predicate from predicators.option_model import _OptionModelBase from predicators.settings import CFG -from predicators.structs import CausalProcess, LowLevelTrajectory, \ - ParameterizedOption, Predicate, State, Task, Type +from predicators.structs import CausalProcess, LowLevelTrajectory, Object, \ + OptionSampler, ParameterizedOption, Predicate, State, Task, Type MCP_SERVER_NAME = "predicator_tools" @@ -49,7 +49,7 @@ PROPOSAL_TOOL_NAMES = [ "propose_types", "propose_predicates", - "propose_object_augmentor", + "propose_task_augmentor", "propose_processes", "propose_options", ] @@ -57,13 +57,13 @@ "retract_abstractions", ] TESTING_TOOL_NAMES = [ - "test_predicate_on_states", - "test_planning", - "test_option_plan", + "evaluate_predicate_on_trajectory", + "evaluate_option_plan", ] PLANNING_TOOL_NAMES = [ "generate_bilevel_plan", "generate_abstract_plan", + "refine_plan_sketch", ] SCENE_TOOL_NAMES = [ "annotate_scene", @@ -87,6 +87,7 @@ "evaluate_plan_refinement", ) PREDICATE_SYNTHESIS_TOOL_NAMES = ("evaluate_predicate_quality", ) +SAMPLER_SYNTHESIS_TOOL_NAMES = ("evaluate_sampler", ) def get_allowed_tool_list(tool_names: Optional[List[str]] = 
None) -> List[str]: @@ -163,6 +164,12 @@ class ToolContext: # candidates that straddle the learned model's decision boundaries. # None ⇒ plain feasibility search (default). atom_disagreement_fn: Optional[Callable[[State, Any], float]] = None + # Synthesized per-skill samplers (option name -> sampler), synced from + # the learning approach when agent_sim_learn_synthesize_samplers is on. + # The agent_bilevel explorer and synthesis tools pass these into + # refinement so continuous-parameter search aims at each step's subgoal + # instead of drawing uniformly. Empty ⇒ uniform sampling (default). + option_samplers: Dict[str, OptionSampler] = field(default_factory=dict) current_task: Optional[Task] = None iteration_proposals: ProposalBundle = field(default_factory=ProposalBundle) planning_results: Dict[str, Any] = field(default_factory=dict) @@ -177,7 +184,11 @@ class ToolContext: show_option_source: bool = True # set False when using GT options iteration_id: int = 0 # current learning iteration (outer loop) turn_id: int = 0 # current query/turn within the session - test_call_id: int = 0 # incremented per test_option_plan call + # Index of the test task currently being solved (0-based), mirroring + # main.py's ``test_task_idx``. None outside the test phase. Threaded into + # the saved session-log filename so test queries are attributable to a task. + test_task_idx: Optional[int] = None + test_call_id: int = 0 # incremented per evaluate_option_plan call visualized_state: Optional[State] = None # last state from visualize_state # Managed by AgentSessionMixin: populated from # `_build_synthesis_mcp_tools` at session-open, reset to [] for @@ -200,6 +211,23 @@ class ToolContext: last_mental_model_solved: Optional[bool] = None +def session_log_filename(query_count: int, + kind: str, + timestamp: str, + test_task_idx: Optional[int] = None, + ext: str = "md") -> str: + """Build the session-log filename shared by the sandbox backends. + + Layout: ``NNN_{kind}[_task{test_task_idx}]_{timestamp}.{ext}``. 
The counter comes + first so alphabetical sort matches chronological order; for test queries + the ``_task`` segment ties the file to ``main.py``'s test task index. + """ + suffix = "" + if kind == "test" and test_task_idx is not None: + suffix = f"_task{test_task_idx}" + return f"{query_count:03d}_{kind}{suffix}_{timestamp}.{ext}" + + def _text_result(text: str) -> Dict[str, Any]: """Helper to format a successful text result.""" return {"content": [{"type": "text", "text": text}]} @@ -549,46 +577,9 @@ def _save_option_to_sandbox(ctx: ToolContext, option_name: str, return f"./proposed_code/{filename}" -def create_mcp_tools(ctx: ToolContext, - tool_names: Optional[List[str]] = None) -> list: - """Create MCP tools with the given ToolContext via closures. - - Args: - ctx: Shared mutable state between the approach and MCP tools. - tool_names: If provided, only return tools with these names. - If None, return all tools. - - Returns a list of SdkMcpTool objects to pass to create_sdk_mcp_server. - """ - from claude_agent_sdk import \ - tool # pylint: disable=import-outside-toplevel - - # Spill oversize tool output into the sandbox (``./tool_outputs/``) - # instead of returning it inline, where the agent SDK would truncate it - # and dump the full text to ``~/.claude/projects/.../tool-results/`` — - # outside the sandbox. Shadowing the module-level ``_text_result`` here - # routes every nested tool's ``_text_result(...)`` call (e.g. - # ``inspect_trajectories``) through the spiller, with no call-site edits. 
- _text_result = _make_spilling_text_result(ctx.sandbox_dir) - - _propose_count = [0] # mutable counter in closure - - def _save_proposal_code(tool_name: str, code: str, names: List[str], - description: str) -> None: - if not ctx.sandbox_dir: - return - _propose_count[0] += 1 - subdir = os.path.join(ctx.sandbox_dir, "proposed_code") - os.makedirs(subdir, exist_ok=True) - names_slug = "_".join(names)[:80] - filename = f"{_propose_count[0]:03d}_{tool_name}_{names_slug}.py" - filepath = os.path.join(subdir, filename) - header = f'"""{tool_name}: {description}"""\n\n' - with open(filepath, "w", encoding="utf-8") as f: - f.write(header + code) - logging.info(f"Saved proposal code to {filepath}") - - # ===== INSPECTION TOOLS ===== +def _build_inspection_tools(ctx: ToolContext, _text_result: Callable, + tool: Callable) -> Dict[str, Any]: + """Read-only inspection tools (views over ToolContext state).""" @tool("inspect_types", "List all object types and their features", {}) async def inspect_types(_args: Dict[str, Any]) -> Dict[str, Any]: @@ -916,7 +907,37 @@ async def inspect_past_proposals(_args: Dict[str, Any]) -> Dict[str, Any]: lines.append(json.dumps(entry, indent=2, default=str)) return _text_result("\n---\n".join(lines)) - # ===== PROPOSAL TOOLS ===== + return { + "inspect_types": inspect_types, + "inspect_predicates": inspect_predicates, + "inspect_processes": inspect_processes, + "inspect_options": inspect_options, + "inspect_trajectories": inspect_trajectories, + "inspect_train_tasks": inspect_train_tasks, + "inspect_planning_results": inspect_planning_results, + "inspect_past_proposals": inspect_past_proposals, + } + + +def _build_proposal_tools(ctx: ToolContext, _text_result: Callable, + tool: Callable) -> Dict[str, Any]: + """Proposal tools (agent authors new types/predicates/options/etc.).""" + _propose_count = [0] # mutable counter in closure + + def _save_proposal_code(tool_name: str, code: str, names: List[str], + description: str) -> None: + if not 
ctx.sandbox_dir: + return + _propose_count[0] += 1 + subdir = os.path.join(ctx.sandbox_dir, "proposed_code") + os.makedirs(subdir, exist_ok=True) + names_slug = "_".join(names)[:80] + filename = f"{_propose_count[0]:03d}_{tool_name}_{names_slug}.py" + filepath = os.path.join(subdir, filename) + header = f'"""{tool_name}: {description}"""\n\n' + with open(filepath, "w", encoding="utf-8") as f: + f.write(header + code) + logging.info(f"Saved proposal code to {filepath}") @tool( "propose_types", @@ -1019,7 +1040,7 @@ async def propose_predicates(args: Dict[str, Any]) -> Dict[str, Any]: return _text_result(msg) @tool( - "propose_object_augmentor", + "propose_task_augmentor", "Propose a task augmentation function. Code must define " "`augment_task(task) -> Task`.", { @@ -1039,7 +1060,7 @@ async def propose_predicates(args: Dict[str, Any]) -> Dict[str, Any]: "required": ["code", "description"], }, ) - async def propose_object_augmentor(args: Dict[str, Any]) -> Dict[str, Any]: + async def propose_task_augmentor(args: Dict[str, Any]) -> Dict[str, Any]: if not CFG.agent_sdk_propose_objects: return _error_result("Object augmentor proposals are disabled.") code = args["code"] @@ -1068,7 +1089,7 @@ async def propose_object_augmentor(args: Dict[str, Any]) -> Dict[str, Any]: ctx.iteration_proposals.augment_task_fn = result ctx.iteration_proposals.augment_task_code = code logging.info(f"Agent proposed augmentor adding objects: {obj_names}") - _save_proposal_code("propose_object_augmentor", code, obj_names, + _save_proposal_code("propose_task_augmentor", code, obj_names, args.get("description", "")) return _text_result( f"Successfully proposed augmentor. 
Test added objects: {obj_names}" @@ -1172,13 +1193,24 @@ async def propose_options(args: Dict[str, Any]) -> Dict[str, Any]: return _text_result( f"Successfully proposed {len(proposed)} options: {names}") - # ===== RETRACTION TOOLS ===== + return { + "propose_types": propose_types, + "propose_predicates": propose_predicates, + "propose_task_augmentor": propose_task_augmentor, + "propose_processes": propose_processes, + "propose_options": propose_options, + } + + +def _build_retraction_tools(ctx: ToolContext, _text_result: Callable, + tool: Callable) -> Dict[str, Any]: + """Retraction tools (remove agent-proposed abstractions).""" @tool( "retract_abstractions", "Remove previously proposed abstractions that are no longer needed. " "Specify names of predicates, processes, options, or helper types to " - "remove, and/or set clear_object_augmentor to remove the augmentor.", + "remove, and/or set clear_task_augmentor to remove the augmentor.", { "type": "object", "properties": { @@ -1210,7 +1242,7 @@ async def propose_options(args: Dict[str, Any]) -> Dict[str, Any]: }, "description": "Names of helper types to remove", }, - "clear_object_augmentor": { + "clear_task_augmentor": { "type": "boolean", "description": "Set to true to remove the object augmentor", @@ -1232,7 +1264,7 @@ async def retract_abstractions(args: Dict[str, Any]) -> Dict[str, Any]: proc_names = set(args.get("process_names") or []) opt_names = set(args.get("option_names") or []) type_names = set(args.get("type_names") or []) - clear_augmentor = bool(args.get("clear_object_augmentor", False)) + clear_augmentor = bool(args.get("clear_task_augmentor", False)) if not any( [pred_names, proc_names, opt_names, type_names, clear_augmentor]): @@ -1278,17 +1310,24 @@ async def retract_abstractions(args: Dict[str, Any]) -> Dict[str, Any]: lines.append(f" (unknown, ignored: {sorted(unknown)})") if clear_augmentor: - ctx.iteration_proposals.retract_object_augmentor = True + 
ctx.iteration_proposals.retract_task_augmentor = True lines.append("Object augmentor will be cleared.") logging.info(f"Agent retraction request: {args}") return _text_result("\n".join(lines)) - # ===== TESTING TOOLS ===== + return { + "retract_abstractions": retract_abstractions, + } + + +def _build_testing_tools(ctx: ToolContext, _text_result: Callable, + tool: Callable) -> Dict[str, Any]: + """Evaluation tools (run predicates / option plans against tasks).""" @tool( - "test_predicate_on_states", - "Test a predicate's truth value across timesteps in a trajectory", + "evaluate_predicate_on_trajectory", + "Evaluate a predicate's truth value across timesteps in a trajectory", { "type": "object", "properties": { @@ -1311,7 +1350,8 @@ async def retract_abstractions(args: Dict[str, Any]) -> Dict[str, Any]: "required": ["predicate_name", "traj_idx", "object_names"], }, ) - async def test_predicate_on_states(args: Dict[str, Any]) -> Dict[str, Any]: + async def evaluate_predicate_on_trajectory( + args: Dict[str, Any]) -> Dict[str, Any]: pred_name = args["predicate_name"] traj_idx = args["traj_idx"] object_names = args["object_names"] @@ -1363,62 +1403,7 @@ async def test_predicate_on_states(args: Dict[str, Any]) -> Dict[str, Any]: f"over trajectory {traj_idx}:\n" + "\n".join(results)) @tool( - "test_planning", - "Run the task planner on a specific task and report results", - { - "type": "object", - "properties": { - "task_idx": { - "type": "integer", - "description": "Task index to plan for" - }, - "timeout": { - "type": "integer", - "description": "Planning timeout in seconds", - "default": 30 - }, - }, - "required": ["task_idx"], - }, - ) - async def test_planning(args: Dict[str, Any]) -> Dict[str, Any]: - # pylint: disable=import-outside-toplevel - from predicators.approaches import ApproachFailure, ApproachTimeout - from predicators.planning_with_processes import \ - run_task_plan_with_processes_once - - task_idx = args["task_idx"] - timeout = args.get("timeout", 30) - 
- if task_idx < 0 or task_idx >= len(ctx.train_tasks): - return _error_result(f"Invalid task_idx {task_idx}. " - f"Available: 0-{len(ctx.train_tasks)-1}") - - task = ctx.train_tasks[task_idx] - all_preds = ctx.predicates | ctx.iteration_proposals.proposed_predicates - - try: - plan, _atoms_seq, metrics = run_task_plan_with_processes_once( - task, - ctx.processes | ctx.iteration_proposals.proposed_processes, - all_preds, - ctx.types | ctx.iteration_proposals.proposed_types, - timeout, - seed=CFG.seed, - _task_planning_heuristic=CFG.process_task_planning_heuristic, - max_horizon=float(CFG.horizon)) - plan_desc = " -> ".join(p.name for p in plan) - return _text_result( - f"Planning succeeded for task {task_idx}!\n" - f"Plan length: {len(plan)}\n" - f"Nodes expanded: {metrics.get('num_nodes_expanded', '?')}\n" - f"Plan: {plan_desc}") - except (ApproachFailure, ApproachTimeout, Exception) as e: # pylint: disable=broad-except - return _text_result(f"Planning failed for task {task_idx}.\n" - f"Reason: {type(e).__name__}: {e}") - - @tool( - "test_option_plan", + "evaluate_option_plan", "Execute a sequence of grounded options on a task via the option model " "and report the result at each step. 
Use include_states and/or " "include_atoms to control what is shown at each step.", @@ -1494,7 +1479,7 @@ async def test_planning(args: Dict[str, Any]) -> Dict[str, Any]: "required": ["option_plan"], }, ) - async def test_option_plan(args: Dict[str, Any]) -> Dict[str, Any]: + async def evaluate_option_plan(args: Dict[str, Any]) -> Dict[str, Any]: import numpy as np # pylint: disable=reimported,redefined-outer-name,import-outside-toplevel from predicators import \ @@ -1671,7 +1656,15 @@ async def test_option_plan(args: Dict[str, Any]) -> Dict[str, Any]: # Build result with text only (images are saved to disk) return _text_result("\n".join(lines)) - # ===== PLANNING TOOLS ===== + return { + "evaluate_predicate_on_trajectory": evaluate_predicate_on_trajectory, + "evaluate_option_plan": evaluate_option_plan, + } + + +def _build_planning_tools(ctx: ToolContext, _text_result: Callable, + tool: Callable) -> Dict[str, Any]: + """Planning tools (generate bilevel / abstract plans).""" @tool( "generate_bilevel_plan", @@ -1913,16 +1906,163 @@ async def generate_abstract_plan(args: Dict[str, Any]) -> Dict[str, Any]: return _text_result("\n".join(lines)) + @tool( + "refine_plan_sketch", + "Test whether a plan SKETCH is refinable: run backtracking search " + "for continuous parameters over the option model, then — on success " + "— forward-validate the refined plan by re-executing it continuously. " + "Unlike evaluate_option_plan (which runs a fully-specified plan whose " + "params you supply), this takes a sketch WITHOUT continuous params " + "and lets the search find them, exactly as the bilevel planner does " + "at solve time. `plan` is one option call per line with typed object " + "references (`obj:type`) and every argument supplied; add optional " + "`-> {Atom(obj:type, ...)}` subgoal annotations (effectively required " + "after open-ended skills like Place, and for Wait to say when it " + "should end — prefix an atom with NOT to require it become false). 
" + "Reports the verdict (SUCCESS / TIMEOUT / SAMPLE_EXHAUSTED with the " + "stuck step / FORWARD_VALIDATION_FAILED), per-step sample counts, and " + "time used. Requires a simulator (option model). Slow — use to vet a " + "skeleton before committing.", + { + "type": "object", + "properties": { + "plan": { + "type": + "string", + "description": + "Option-skeleton plan text, one option call per " + "line, typed `obj:type` references, every argument " + "supplied; optional `-> {Atom(...)}` subgoal per step.", + }, + "task_idx": { + "type": + "integer", + "description": + "Train task index. Omit to use the current " + "solve-time task (if available).", + }, + "timeout": { + "type": + "number", + "description": + "Refinement timeout in seconds. Omit for an auto " + "value that scales with sketch length; the value " + "used is reported back.", + }, + }, + "required": ["plan"], + }, + ) + async def refine_plan_sketch(args: Dict[str, Any]) -> Dict[str, Any]: + # pylint: disable=import-outside-toplevel,reimported,redefined-outer-name + import numpy as np + + from predicators.agent_sdk import bilevel_sketch + + if ctx.option_model is None: + return _error_result( + "refine_plan_sketch requires a simulator (no option model " + "in ToolContext).") + + # Resolve the task (mirrors evaluate_option_plan). + task_idx = args.get("task_idx") + if task_idx is not None: + if task_idx < 0 or task_idx >= len(ctx.train_tasks): + return _error_result(f"Invalid task_idx {task_idx}. 
" + f"Available: 0-{len(ctx.train_tasks)-1}") + task = ctx.train_tasks[task_idx] + elif ctx.current_task is not None: + task = ctx.current_task + task_idx = "current" + else: + return _error_result( + "No task_idx provided and no current_task set.") + + all_options = ctx.options | ctx.iteration_proposals.proposed_options + all_predicates = (ctx.predicates + | ctx.iteration_proposals.proposed_predicates) + # Keep the option model's name map in sync with proposed options so + # refinement can ground them (matches evaluate_option_plan). + model = ctx.option_model + model._name_to_parameterized_option = ( # type: ignore[attr-defined] # pylint: disable=protected-access + {o.name: o + for o in all_options}) + # Union declared types with those reachable from options/predicates/ + # objects so typed `obj:type` references in the sketch resolve. + types = set(ctx.types) + for opt in all_options: + types.update(opt.types) + for pred in all_predicates: + types.update(pred.types) + types.update(o.type for o in task.init) + + plan_text = (args.get("plan") or "").strip() + if not plan_text: + return _error_result("`plan` is required (option-skeleton text).") + try: + sketch = bilevel_sketch.parse_sketch_from_text( + plan_text, + task, + predicates=all_predicates, + options=all_options, + types=types, + ) + except Exception as e: # pylint: disable=broad-except + return _error_result(f"Could not parse plan sketch: {e}") + if not sketch: + return _error_result( + "Parsed empty plan sketch. 
Check that every line names a " + "known option with typed `obj:type` arguments matching what " + "the inspect tools report.") + + timeout, timeout_source = bilevel_sketch.resolve_refine_timeout( + args.get("timeout"), + len(sketch), + per_step=CFG.agent_bilevel_refinement_timeout_per_step, + minimum=CFG.agent_bilevel_refinement_timeout_min) + + try: + _, report = bilevel_sketch.refine_and_validate_report( + task, + sketch, + ctx.option_model, + predicates=all_predicates, + timeout=timeout, + rng=np.random.default_rng(CFG.seed), + max_samples_per_step=CFG.agent_bilevel_max_samples_per_step, + check_subgoals=CFG.agent_bilevel_check_subgoals, + log_state=CFG.agent_bilevel_log_state, + option_samplers=ctx.option_samplers or None, + run_id="planner_refine", + timeout_source=timeout_source, + ) + except Exception: # pylint: disable=broad-except + tb = traceback.format_exc() + return _error_result(f"Refinement raised:\n{tb}") + + return _text_result(f"Task {task_idx}:\n{report}") + # ------------------------------------------------------------------ # # Scene annotation # ------------------------------------------------------------------ # + return { + "generate_bilevel_plan": generate_bilevel_plan, + "generate_abstract_plan": generate_abstract_plan, + "refine_plan_sketch": refine_plan_sketch, + } + + +def _build_scene_tools(ctx: ToolContext, _text_result: Callable, + tool: Callable) -> Dict[str, Any]: + """Scene tools (render / annotate / mutate env states).""" + @tool( "annotate_scene", "Draw annotations (markers, lines, rectangles) at world " "coordinates in the 3D scene, render an image, and save it. " "Use this to visualize candidate placement positions or spatial " - "relationships before committing to test_option_plan. Annotations " + "relationships before committing to evaluate_option_plan. 
Annotations " "are temporary and cleaned up after rendering.", { "type": "object", @@ -2177,29 +2317,40 @@ async def visualize_state(args: Dict[str, Any]) -> Dict[str, Any]: "against this modified state.") return _text_result(text) - _all = { - "inspect_types": inspect_types, - "inspect_predicates": inspect_predicates, - "inspect_processes": inspect_processes, - "inspect_options": inspect_options, - "inspect_trajectories": inspect_trajectories, - "inspect_train_tasks": inspect_train_tasks, - "inspect_planning_results": inspect_planning_results, - "inspect_past_proposals": inspect_past_proposals, - "propose_types": propose_types, - "propose_predicates": propose_predicates, - "propose_object_augmentor": propose_object_augmentor, - "propose_processes": propose_processes, - "propose_options": propose_options, - "retract_abstractions": retract_abstractions, - "test_predicate_on_states": test_predicate_on_states, - "test_planning": test_planning, - "test_option_plan": test_option_plan, - "generate_bilevel_plan": generate_bilevel_plan, - "generate_abstract_plan": generate_abstract_plan, + return { "annotate_scene": annotate_scene, "visualize_state": visualize_state, } + + +def create_mcp_tools(ctx: ToolContext, + tool_names: Optional[List[str]] = None) -> list: + """Create MCP tools with the given ToolContext via closures. + + Args: + ctx: Shared mutable state between the approach and MCP tools. + tool_names: If provided, only return tools with these names. + If None, return all tools. + + Returns a list of SdkMcpTool objects to pass to create_sdk_mcp_server. + """ + from claude_agent_sdk import \ + tool # pylint: disable=import-outside-toplevel + + # Spill oversize tool output into the sandbox (``./tool_outputs/``) + # instead of returning it inline. Each builder names its parameter + # ``_text_result`` so every nested tool's ``_text_result(...)`` call + # routes through the spiller with no call-site edits. 
+ _text_result = _make_spilling_text_result(ctx.sandbox_dir) + + _all = { + **_build_inspection_tools(ctx, _text_result, tool), + **_build_proposal_tools(ctx, _text_result, tool), + **_build_retraction_tools(ctx, _text_result, tool), + **_build_testing_tools(ctx, _text_result, tool), + **_build_planning_tools(ctx, _text_result, tool), + **_build_scene_tools(ctx, _text_result, tool), + } if tool_names is None: tools = list(_all.values()) else: @@ -3485,3 +3636,202 @@ async def evaluate_predicate_quality( return _text("\n".join(lines)) return [evaluate_predicate_quality] + + +def create_sampler_synthesis_tools( + samplers_file: str, + samplers_versions_dir: str, + approach: Any, + cycle_index_provider: Optional[Callable[[], int]] = None, +) -> list: + """Create the per-skill sampler-synthesis tool. + + Returns ``[evaluate_sampler]``. On each call the tool loads + ``samplers.py`` fresh (snapshotting into ``samplers_versions_dir``), + validates the ``LEARNED_SAMPLERS`` dict (option name -> callable), + installs it into ``approach._synthesized_samplers`` so refinement + uses it, and reports a per-option shape/in-box sanity check. + + Args: + samplers_file: Host path to the agent-edited ``samplers.py``. + samplers_versions_dir: Directory for per-call snapshots. + approach: The ``AgentSimLearningApproach`` instance. + cycle_index_provider: Returns the current 1-indexed cycle. 
+ """ + # pylint: disable=import-outside-toplevel + import traceback # pylint: disable=redefined-outer-name,reimported + + from claude_agent_sdk import tool + + from predicators.code_sim_learning.training import ParamSpec + + # pylint: enable=import-outside-toplevel + _text = _make_spilling_text_result(os.path.dirname(samplers_file)) + _snapshotter = _ArtifactSnapshotter( + live_file=samplers_file, + versions_dir=samplers_versions_dir, + artifact_name="samplers", + cycle_index_provider=cycle_index_provider, + missing_file_hint=("Use Write to create it with " + "LEARNED_SAMPLERS = {\"OptionName\": fn, ...}."), + ) + params_view = _ParamsView(approach._fitted_params) # pylint: disable=protected-access + + def _snapshot_and_load_samplers( + path: str, + ) -> Tuple[Dict[str, Any], Optional[str], Optional[str], List[str]]: + """Snapshot ``path`` then exec it into a fresh namespace. + + Returns ``(samplers, version_tag, error_msg, warnings)``. + Entries keyed by an unknown option name, or whose value is not + callable, are skipped and described in ``warnings``. On success, + mutates ``approach._synthesized_samplers`` to the validated + dict. 
+ """ + raw, version_tag, err = _snapshotter.snapshot(path) + if err is not None: + return {}, None, err, [] + assert raw is not None and version_tag is not None + + ctx = build_exec_context( + types=approach._types, # pylint: disable=protected-access + predicates=approach._get_all_predicates(), # pylint: disable=protected-access + options=approach._get_all_options(), # pylint: disable=protected-access + extra_context={ + "params": params_view, + "ParamSpec": ParamSpec, + }) + result, err = exec_code_safely(raw.decode("utf-8"), ctx, + "LEARNED_SAMPLERS") + if err is not None: + return {}, version_tag, (f"[{version_tag}] Error executing " + f"{path}:\n{err}"), [] + if not isinstance(result, dict): + return {}, version_tag, ( + f"[{version_tag}] LEARNED_SAMPLERS must be a dict " + f"{{option_name: sampler_fn}}, got " + f"{type(result).__name__}."), [] + + option_names = {o.name for o in approach._get_all_options()} # pylint: disable=protected-access + valid: Dict[str, Any] = {} + warnings: List[str] = [] + for name, fn in result.items(): + if name not in option_names: + warnings.append( + f"Skipped '{name}' (not a known option name; known: " + f"{', '.join(sorted(option_names))}).") + continue + if not callable(fn): + warnings.append( + f"Skipped '{name}' (value is not callable, got " + f"{type(fn).__name__}).") + continue + valid[name] = fn + + # Mutate approach state so evaluate_plan_refinement / test-time + # refinement draw from the agent's draft samplers. 
+ approach._synthesized_samplers = valid # pylint: disable=protected-access + return valid, version_tag, None, warnings + + def _sanity_check(name: str, fn: Any) -> str: + """Draw a few params from a representative state; report shape/box.""" + # pylint: disable=protected-access,import-outside-toplevel + import numpy as np # pylint: disable=redefined-outer-name,reimported + + from predicators.settings import \ + CFG # pylint: disable=redefined-outer-name,reimported + options_by_name = {o.name: o for o in approach._get_all_options()} + opt = options_by_name[name] + train_tasks = approach._train_tasks + if not train_tasks: + return f" {name}: no train task to sanity-check against." + state = train_tasks[0].init + # Pick the first object of each option-arg type present in the state. + objs: List[Object] = [] + for t in opt.types: + match = next((o for o in state if o.type.name == t.name), None) + if match is None: + return (f" {name}: no object of type '{t.name}' in the " + "train-task state to sanity-check against.") + objs.append(match) + box = opt.params_space + expected = box.shape[0] + rng = np.random.default_rng(CFG.seed) + in_box = 0 + n_draws = 3 + for _ in range(n_draws): + try: + raw = fn(state, set(), rng, objs) + arr = np.asarray(raw, dtype=np.float32).reshape(-1) + except Exception: # pylint: disable=broad-except + last = traceback.format_exc().strip().splitlines()[-1] + return f" {name}: ERROR — sampler raised: {last}" + if arr.shape != (expected, ): + return (f" {name}: ERROR — returned shape {arr.shape}, " + f"expected ({expected},).") + if bool(np.all(arr >= box.low - 1e-6)) and \ + bool(np.all(arr <= box.high + 1e-6)): + in_box += 1 + return (f" {name}: OK — {n_draws} draws, {in_box}/{n_draws} " + f"within the params box.") + + @tool( + "evaluate_sampler", + "Load LEARNED_SAMPLERS (fresh from `samplers.py`) and install " + "them as the per-skill samplers used by refinement. 
Each entry " + "maps an option name to a function " + "(state, subgoal_atoms, rng, objects) -> params array (the same " + "signature as the env's NSRT samplers); refinement calls it " + "instead of drawing uniformly so the sampler can aim continuous " + "params at the step's subgoal, then clips the result to the box. " + "Reports a per-option sanity check (return shape + within-box) " + "over a representative train-task state. After loading, the " + "samplers used by evaluate_plan_refinement are updated — so call " + "this any time you edit samplers.py before re-running " + "refinement. Snapshots samplers.py into samplers_versions/; " + "output tagged [cycle_XXX_vers_YYY].", + { + "type": "object", + "properties": {}, + }, + ) + async def evaluate_sampler(args: Dict[str, Any]) -> Dict[str, Any]: + del args + try: + samplers, version_tag, err, warnings = ( + _snapshot_and_load_samplers(samplers_file)) + except Exception: # pylint: disable=broad-except + return _text( + f"Error loading samplers.py:\n{traceback.format_exc()}") + + if err is not None: + return _text(err) + + prefix = f"[{version_tag}]" + lines = [ + f"{prefix} Sampler report — {len(samplers)} per-skill " + f"sampler(s) installed.", + ] + if warnings: + lines.append("") + lines.append("Warnings (entries skipped during load):") + for w in warnings: + lines.append(f" - {w}") + + if not samplers: + lines.append("") + lines.append("LEARNED_SAMPLERS is empty — add " + "{\"OptionName\": fn} entries to samplers.py.") + return _text("\n".join(lines)) + + lines.append("") + lines.append("Sanity check (representative train-task state):") + for name in sorted(samplers): + lines.append(_sanity_check(name, samplers[name])) + lines.append("") + lines.append("Now call evaluate_plan_refinement with a sketch that " + "uses these options to measure the samples-to-refine " + "improvement.") + return _text("\n".join(lines)) + + return [evaluate_sampler] diff --git 
a/predicators/approaches/agent_abstraction_learning_approach.py b/predicators/approaches/agent_abstraction_learning_approach.py deleted file mode 100644 index 363a038f0..000000000 --- a/predicators/approaches/agent_abstraction_learning_approach.py +++ /dev/null @@ -1,853 +0,0 @@ -"""Agent abstraction learning approach: online process and predicate invention. - -Uses a persistent Claude Agent SDK session to iteratively propose -abstractions (types, predicates, helper objects, processes, options) -based on observed trajectory data and planning feedback. -""" -import json -import logging -import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Set - -import dill as pkl -from gym.spaces import Box - -from predicators import utils -from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin -from predicators.agent_sdk.proposal_parser import ProposalBundle, \ - build_exec_context, exec_code_safely -from predicators.approaches.agent_planner_approach import AgentPlannerApproach -from predicators.approaches.pp_online_process_learning_approach import \ - OnlineProcessLearningAndPlanningApproach -from predicators.approaches.pp_predicate_invention_approach import \ - PredicateInventionProcessPlanningApproach -from predicators.explorers.base_explorer import BaseExplorer -from predicators.option_model import _OptionModelBase, create_option_model -from predicators.settings import CFG -from predicators.structs import Action, CausalProcess, Dataset, \ - EndogenousProcess, InteractionResult, LowLevelTrajectory, \ - ParameterizedOption, Predicate, State, Task, Type - - -class AgentAbstractionLearningApproach( # type: ignore[misc] - AgentPlannerApproach, PredicateInventionProcessPlanningApproach, - OnlineProcessLearningAndPlanningApproach): - """Abstraction-learning planning approach using Claude Agent SDK. 
- - Maintains a persistent Claude agent session that iteratively refines - abstraction proposals based on observed trajectory data and planning - feedback. The agent cannot see environment source code -- it - observes the world only through custom MCP tools. - """ - - def __init__(self, - initial_predicates: Set[Predicate], - initial_options: Set[ParameterizedOption], - types: Set[Type], - action_space: Box, - train_tasks: List[Task], - task_planning_heuristic: str = "default", - max_skeletons_optimized: int = -1, - bilevel_plan_without_sim: Optional[bool] = None, - option_model: Optional[_OptionModelBase] = None) -> None: - # Agent-specific attributes (before super().__init__) - self._helper_types: Set[Type] = set() - self._augment_task_fn: Optional[Callable] = None - self._augment_task_code: str = "" - self._agent_proposed_options: Set[ParameterizedOption] = set() - self._agent_proposed_processes: Set[CausalProcess] = set() - self._iteration_history: List[Dict[str, Any]] = [] - self._planning_results: Dict[str, Any] = {} - self._last_context_message: str = "" - self._last_agent_responses: List[Any] = [] - self._agent_session_id: Optional[str] = None - self._option_model = create_option_model(CFG.option_model_name) - - self._init_agent_session_state(types, initial_predicates, - initial_options, train_tasks) - - super().__init__(initial_predicates, - initial_options, - types, - action_space, - train_tasks, - task_planning_heuristic, - max_skeletons_optimized, - bilevel_plan_without_sim, - option_model=option_model) - - @classmethod - def get_name(cls) -> str: - return "agent_abstraction_learning" - - # ------------------------------------------------------------------ # - # AgentSessionMixin hooks - # ------------------------------------------------------------------ # - - def _get_log_dir(self) -> str: - """Use the mixin's simple log dir (no run_id subdirectory).""" - # pylint: disable-next=protected-access - return AgentSessionMixin._get_log_dir(self) - - def 
_get_agent_system_prompt(self) -> str: - return _SYSTEM_PROMPT - - # ------------------------------------------------------------------ # - # Overridable helpers (from AgentPlannerApproach) - # ------------------------------------------------------------------ # - - def _get_all_options(self) -> Set[ParameterizedOption]: - return self._initial_options | self._agent_proposed_options - - def _get_all_predicates(self) -> Set[Predicate]: - return self._get_current_predicates() - - def _get_all_trajectories(self) -> list: - return (self._offline_dataset.trajectories + - self._online_dataset.trajectories) - - # ------------------------------------------------------------------ # - # Learning - # ------------------------------------------------------------------ # - - def learn_from_offline_dataset(self, dataset: Dataset) -> None: - """Store the offline dataset. - - Do NOT start agent session yet. - """ - self._offline_dataset = dataset - self._tool_context.offline_trajectories = dataset.trajectories - # Set example state from first trajectory - if dataset.trajectories: - self._tool_context.example_state = \ - dataset.trajectories[0].states[0] - self.save() - - def learn_from_interaction_results( - self, results: Sequence[InteractionResult]) -> None: - """Learn from interaction results via the Claude agent.""" - # 1. Convert results to trajectories, append to online dataset - assert self._requests_train_task_idxs is not None - for i, result in enumerate(results): - task_idx = self._requests_train_task_idxs[i] - traj = LowLevelTrajectory(result.states, - result.actions, - _train_task_idx=task_idx) - self._online_dataset.append(traj) - - all_trajs = self._offline_dataset.trajectories + \ - self._online_dataset.trajectories - - # 2. Update tool context with current state - self._sync_tool_context(all_trajs) - - # 3. Run agent iteration - self._run_agent_iteration(all_trajs) - - # 4. 
Integrate proposals from tool context - proposals = self._tool_context.iteration_proposals - self._integrate_proposals(proposals) - - # 5. Use agent-proposed processes (not data-driven learning) - # The processes are already integrated in _integrate_proposals - # Optionally learn parameters for the agent-proposed processes - if CFG.learn_process_parameters and self._get_current_processes(): - self._learn_process_parameters(all_trajs) - - # 7. Log iteration summary - summary = self._build_iteration_summary(proposals) - self._iteration_history.append(summary) - self._tool_context.iteration_history = self._iteration_history - logging.info(f"Iteration {self._online_learning_cycle} summary: " - f"{json.dumps(summary, default=str)}") - - # 8. Save and log agent responses - self._save_iteration_logs(self._online_learning_cycle) - self.save(self._online_learning_cycle) - - # 9. Increment cycle - self._online_learning_cycle += 1 - - # pylint: disable-next=arguments-differ - def _sync_tool_context( # type: ignore[override] - self, all_trajs: List[LowLevelTrajectory]) -> None: - """Synchronize ToolContext with current approach state.""" - self._tool_context.types = self._types - self._tool_context.predicates = self._get_current_predicates() - self._tool_context.processes = self._get_current_processes() - self._tool_context.options = self._initial_options | \ - self._agent_proposed_options - self._tool_context.train_tasks = self._train_tasks - self._tool_context.offline_trajectories = \ - self._offline_dataset.trajectories - self._tool_context.online_trajectories = \ - self._online_dataset.trajectories - self._tool_context.planning_results = self._planning_results - self._tool_context.iteration_history = self._iteration_history - self._tool_context.option_model = self._option_model - self._tool_context.iteration_id = self._online_learning_cycle - - if all_trajs: - self._tool_context.example_state = all_trajs[0].states[0] - - # Reset proposals for this iteration - 
self._tool_context.iteration_proposals = ProposalBundle() - - def _run_agent_iteration(self, - all_trajs: List[LowLevelTrajectory]) -> None: - """Build iteration message and query the Claude agent.""" - self._ensure_agent_session() - - # Build the iteration message - num_new = len(self._online_dataset.trajectories) - num_total = len(all_trajs) - task_success = self._compute_task_success_rate(all_trajs) - - type_str = ", ".join( - f"{t.name}[{','.join(t.feature_names)}]" - for t in sorted(self._types, key=lambda t: t.name)) - preds = self._get_current_predicates() - pred_str = ", ".join(f"{p.name}({','.join(t.name for t in p.types)})" - for p in sorted(preds, key=lambda p: p.name)) - procs = self._get_current_processes() - proc_str = ", ".join(p.name - for p in sorted(procs, key=lambda p: p.name)) - opt_str = ", ".join( - o.name - for o in sorted(self._initial_options, key=lambda o: o.name)) - - plan_success = self._planning_results.get("success_str", - "Not yet evaluated") - avg_nodes = str(self._planning_results.get("avg_nodes_expanded", - "N/A")) - failures = self._planning_results.get("failure_summaries", - "None recorded") - - prev_outcomes = "No previous iterations." 
if not \ - self._iteration_history else json.dumps( - self._iteration_history[-1], default=str, indent=2) - - message = build_iteration_message( - cycle=self._online_learning_cycle, - num_new_trajs=num_new, - num_total_trajs=num_total, - task_success_rate=task_success, - type_names_with_features=type_str, - predicate_signatures=pred_str, - num_predicates=len(preds), - process_summaries=proc_str, - num_processes=len(procs), - option_names=opt_str, - num_options=len(self._initial_options), - planning_success=plan_success, - avg_nodes=avg_nodes, - failure_summaries=failures, - previous_iteration_outcomes=prev_outcomes, - available_tools=self._agent_session.tool_names - if self._agent_session else None, - ) - - # Save the context message - self._last_context_message = message - - # Run async query via mixin helper - self._last_agent_responses = self._query_agent_sync(message, - kind="learn") - - def _integrate_proposals(self, proposals: ProposalBundle) -> None: - """Integrate validated proposals into approach state.""" - # Types - if proposals.proposed_types: - self._types = self._types | proposals.proposed_types - self._helper_types |= proposals.proposed_types - logging.info(f"Integrated {len(proposals.proposed_types)} " - f"new types") - - # Predicates - if proposals.proposed_predicates: - self._learned_predicates |= proposals.proposed_predicates - logging.info(f"Integrated {len(proposals.proposed_predicates)} " - f"new predicates") - - # Task augmentor - if proposals.augment_task_fn is not None: - self._augment_task_fn = proposals.augment_task_fn - self._augment_task_code = proposals.augment_task_code or "" - logging.info("Integrated new task augmentor") - - # Processes (agent-proposed, not data-driven) - if proposals.proposed_processes: - self._agent_proposed_processes |= proposals.proposed_processes - logging.info(f"Integrated {len(proposals.proposed_processes)} " - f"new processes (total: " - f"{len(self._get_current_processes())})") - - # Options - if 
proposals.proposed_options: - self._agent_proposed_options |= proposals.proposed_options - logging.info(f"Integrated {len(proposals.proposed_options)} " - f"new options") - - # Retractions - if proposals.retract_type_names: - removed = { - t - for t in self._helper_types - if t.name in proposals.retract_type_names - } - self._helper_types -= removed - self._types -= removed - logging.info(f"Retracted {len(removed)} helper types: " - f"{[t.name for t in removed]}") - - if proposals.retract_predicate_names: - before = len(self._learned_predicates) - self._learned_predicates = { - p - for p in self._learned_predicates - if p.name not in proposals.retract_predicate_names - } - logging.info( - f"Retracted " - f"{before - len(self._learned_predicates)} predicates") - - if proposals.retract_object_augmentor: - self._augment_task_fn = None - self._augment_task_code = "" - logging.info("Retracted object augmentor") - - if proposals.retract_process_names: - before = len(self._agent_proposed_processes) - self._agent_proposed_processes = { - p - for p in self._agent_proposed_processes - if p.name not in proposals.retract_process_names - } - logging.info(f"Retracted " - f"{before - len(self._agent_proposed_processes)} " - f"processes") - - if proposals.retract_option_names: - before = len(self._agent_proposed_options) - self._agent_proposed_options = { - o - for o in self._agent_proposed_options - if o.name not in proposals.retract_option_names - } - logging.info(f"Retracted " - f"{before - len(self._agent_proposed_options)} " - f"options") - - def _get_current_processes(self) -> Set[CausalProcess]: - """Get current processes including agent-proposed ones.""" - return self._processes | self._agent_proposed_processes - - def _compute_task_success_rate(self, - trajs: List[LowLevelTrajectory]) -> float: - """Compute fraction of trajectories that achieved their task goal.""" - if not trajs: - return 0.0 - successes = 0 - counted = 0 - for traj in trajs: - idx = traj._train_task_idx 
# pylint: disable=protected-access - if idx is not None and \ - idx < len(self._train_tasks): - task = self._train_tasks[idx] - goal_preds = {a.predicate for a in task.goal} - final_atoms = utils.abstract(traj.states[-1], goal_preds) - if task.goal.issubset(final_atoms): - successes += 1 - counted += 1 - return successes / max(counted, 1) - - def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: - """Solve via agent-driven option plan generation.""" - if self._augment_task_fn is not None: - try: - task = self._augment_task_fn(task) - except Exception as e: # pylint: disable=broad-except - logging.warning(f"Task augmentation failed: {e}. " - f"Using original task.") - - all_trajs = self._get_all_trajectories() - self._tool_context.current_task = task - self._sync_tool_context(all_trajs) - try: - return super()._solve(task, timeout) - finally: - self._tool_context.current_task = None - - def _build_solve_prompt(self, task: Task) -> str: - """Build the prompt for generating an option plan.""" - init_state = task.init - objects = list(init_state) - - # Objects - obj_strs = [] - for obj in sorted(objects, key=lambda o: o.name): - obj_strs.append(f" {obj.name}: {obj.type.name}") - - # Goal - goal_strs = [str(a) for a in sorted(task.goal, key=str)] - - # Options (include agent-proposed) - option_strs = [] - for opt in sorted(self._get_all_options(), key=lambda o: o.name): - type_sig = ", ".join(t.name for t in opt.types) - params_dim = opt.params_space.shape[0] - if params_dim > 0: - low = opt.params_space.low.tolist() - high = opt.params_space.high.tolist() - param_info = (f", params_dim={params_dim}, " - f"low={low}, high={high}") - else: - param_info = "" - option_strs.append(f" {opt.name}({type_sig}{param_info})") - - # Current atoms (include learned predicates) - atoms = utils.abstract(init_state, self._get_all_predicates()) - atom_strs = [str(a) for a in sorted(atoms, key=str)] - - # Trajectory summary - traj_summary = 
self._build_trajectory_summary() - - # State features - state_str = init_state.dict_str(indent=2) - - # Processes summary - procs = self._get_current_processes() - proc_strs = [] - for proc in sorted(procs, key=lambda p: p.name): - conds = ", ".join(str(a) for a in sorted(proc.condition_at_start)) - adds = ", ".join(str(a) for a in sorted(proc.add_effects)) - dels = ", ".join(str(a) for a in sorted(proc.delete_effects)) - proc_strs.append(f" {proc.name}: conds={{{conds}}}, " - f"add={{{adds}}}, del={{{dels}}}") - - proc_section = "" - if proc_strs: - proc_section = (f"\n## Processes ({len(procs)})\n" + - "\n".join(proc_strs) + "\n") - - prompt = f"""You are solving a task. Generate an option plan \ -to achieve the goal. - -## Goal -{chr(10).join(goal_strs)} - -## Initial State Atoms -{chr(10).join(atom_strs)} - -## Initial State Features -{state_str} - -## Objects -{chr(10).join(obj_strs)} - -## Available Options -{chr(10).join(option_strs)} -{proc_section}{traj_summary} -## Available Tools -You have access to planning tools: - - generate_bilevel_plan: Get a complete plan with sampled params from the bilevel planner - - generate_abstract_plan: Get a plan skeleton with parameter space info - - test_option_plan: Test an option plan on the current task - - inspect_trajectories, inspect_options, inspect_predicates, etc. - -## Instructions -Use your available tools to generate and test plans before committing. - -Recommended workflow: -1. Call generate_bilevel_plan (no task_idx needed - uses current task) to get a baseline plan -2. Optionally call test_option_plan to verify the plan works -3. 
Adjust parameters if needed and test again - -Output the final plan with one option per line in this exact format: - OptionName(obj1:type1, obj2:type2)[param1, param2] - -If an option has no continuous parameters, use empty brackets: OptionName(obj1:type1)[] - -Output ONLY the option plan lines at the end, after any analysis.""" - - return prompt - - # ------------------------------------------------------------------ # - # Explorer - # ------------------------------------------------------------------ # - - def _create_explorer(self) -> BaseExplorer: - """Create explorer, passing agent context if using agent explorer.""" - if CFG.explorer == "agent_plan": - all_trajs = (self._offline_dataset.trajectories + - self._online_dataset.trajectories) - self._sync_tool_context(all_trajs) - preds = self._get_current_predicates() - return self._create_agent_explorer( - preds, self._initial_options | self._agent_proposed_options) - return super()._create_explorer() - - # ------------------------------------------------------------------ # - # Iteration summary / logs - # ------------------------------------------------------------------ # - - def _build_iteration_summary(self, - proposals: ProposalBundle) -> Dict[str, Any]: - """Build a summary dict of what happened this iteration.""" - return { - "cycle": self._online_learning_cycle, - "proposed_types": [t.name for t in proposals.proposed_types], - "proposed_predicates": - [p.name for p in proposals.proposed_predicates], - "proposed_augmentor": proposals.augment_task_code is not None, - "proposed_processes": - [p.name for p in proposals.proposed_processes], - "proposed_options": [o.name for o in proposals.proposed_options], - "retracted_types": sorted(proposals.retract_type_names), - "retracted_predicates": sorted(proposals.retract_predicate_names), - "retracted_augmentor": proposals.retract_object_augmentor, - "retracted_processes": sorted(proposals.retract_process_names), - "retracted_options": 
sorted(proposals.retract_option_names), - "errors": proposals.errors, - "total_predicates": len(self._get_current_predicates()), - "total_processes": len(self._get_current_processes()), - } - - def _save_iteration_logs(self, cycle: int) -> None: - """Save iteration-specific logs to disk.""" - log_dir = os.path.join(self._get_log_dir(), f"iteration_{cycle}") - os.makedirs(log_dir, exist_ok=True) - - # Context message - if hasattr(self, '_last_context_message'): - with open(os.path.join(log_dir, "context_message.txt"), - "w", - encoding="utf-8") as f: - f.write(self._last_context_message) - - # Agent responses - if CFG.agent_sdk_log_agent_responses and \ - hasattr(self, '_last_agent_responses'): - resp_path = os.path.join(log_dir, "agent_responses.jsonl") - with open(resp_path, "w", encoding="utf-8") as f: - for resp in self._last_agent_responses: - f.write(json.dumps(resp, default=str) + "\n") - - # Proposals directory - proposals_dir = os.path.join(log_dir, "proposals") - os.makedirs(proposals_dir, exist_ok=True) - - proposals = self._tool_context.iteration_proposals - if proposals.proposed_types: - with open(os.path.join(proposals_dir, "types.json"), - "w", - encoding="utf-8") as f: - json.dump([t.name for t in proposals.proposed_types], - f, - indent=2) - if proposals.proposed_predicates: - with open(os.path.join(proposals_dir, "predicates_validated.json"), - "w", - encoding="utf-8") as f: - json.dump([p.name for p in proposals.proposed_predicates], - f, - indent=2) - if proposals.augment_task_code: - with open(os.path.join(proposals_dir, "augmentor_code.py"), - "w", - encoding="utf-8") as f: - f.write(proposals.augment_task_code) - if proposals.proposed_processes: - with open(os.path.join(proposals_dir, "processes_code.json"), - "w", - encoding="utf-8") as f: - json.dump([p.name for p in proposals.proposed_processes], - f, - indent=2) - - any_retractions = any([ - proposals.retract_type_names, - proposals.retract_predicate_names, - 
proposals.retract_object_augmentor, - proposals.retract_process_names, - proposals.retract_option_names, - ]) - if any_retractions: - with open(os.path.join(proposals_dir, "retractions.json"), - "w", - encoding="utf-8") as f: - json.dump( - { - "types": sorted(proposals.retract_type_names), - "predicates": sorted( - proposals.retract_predicate_names), - "augmentor": proposals.retract_object_augmentor, - "processes": sorted(proposals.retract_process_names), - "options": sorted(proposals.retract_option_names), - }, - f, - indent=2) - - # Session info - if self._agent_session is not None: - self._agent_session.save_session_info() - - # ------------------------------------------------------------------ # - # Save / Load - # ------------------------------------------------------------------ # - - def save(self, online_learning_cycle: Optional[int] = None) -> None: - """Save approach state.""" - save_path = utils.get_approach_save_path_str() - with open( - f"{save_path}_{online_learning_cycle}.AgentAbstractionLearning", - "wb") as f: - save_dict = { - "processes": - self._processes, - "learned_predicates": - self._learned_predicates, - "offline_dataset": - self._offline_dataset, - "online_dataset": - self._online_dataset, - "online_learning_cycle": - self._online_learning_cycle, - "helper_types": - self._helper_types, - "augment_task_code": - self._augment_task_code, - "agent_proposed_options": - self._agent_proposed_options, - "agent_proposed_processes": - self._agent_proposed_processes, - "iteration_history": - self._iteration_history, - "agent_session_id": (self._agent_session.session_id - if self._agent_session else None), - } - pkl.dump(save_dict, f) - logging.info(f"Saved approach to {save_path}_" - f"{online_learning_cycle}.AgentAbstractionLearning") - - def load(self, online_learning_cycle: Optional[int] = None) -> None: - """Load previously saved approach state.""" - save_path = utils.get_approach_load_path_str() - with open( - 
f"{save_path}_{online_learning_cycle}.AgentAbstractionLearning", - "rb") as f: - save_dict = pkl.load(f) - - self._processes = save_dict["processes"] - self._learned_predicates = save_dict["learned_predicates"] - self._offline_dataset = save_dict["offline_dataset"] - self._online_dataset = save_dict["online_dataset"] - self._online_learning_cycle = save_dict["online_learning_cycle"] + 1 - self._helper_types = save_dict.get("helper_types", set()) - self._augment_task_code = save_dict.get("augment_task_code", "") - self._agent_proposed_options = save_dict.get("agent_proposed_options", - set()) - self._agent_proposed_processes = save_dict.get( - "agent_proposed_processes", set()) - self._iteration_history = save_dict.get("iteration_history", []) - self._agent_session_id = save_dict.get("agent_session_id") - - # Re-exec augment_task_code to restore the function - if self._augment_task_code: - exec_ctx = build_exec_context(self._types, - self._get_current_predicates(), - self._initial_options) - result, error = exec_code_safely(self._augment_task_code, exec_ctx, - "augment_task") - if error: - logging.warning( - f"Failed to restore augment_task function: {error}") - self._augment_task_fn = None - else: - self._augment_task_fn = result - - # Restore types - self._types = self._types | self._helper_types - - # Reseed options - for proc in self._processes: - if isinstance(proc, EndogenousProcess): - proc.option.params_space.seed(CFG.seed) - - logging.info( - f"Loaded {len(self._processes)} processes, " - f"{len(self._learned_predicates)} learned predicates, " - f"{len(self._offline_dataset.trajectories)} offline trajectories, " - f"{len(self._online_dataset.trajectories)} online trajectories") - - -# ------------------------------------------------------------------ # -# Prompt helpers (abstraction-learning specific) -# ------------------------------------------------------------------ # - -_SYSTEM_PROMPT = """\ -You are an abstraction inventor for a bilevel process 
planning system. Your \ -role is to propose types, predicates, helper objects, processes, and options \ -that help a task planner solve planning problems. - -## What You Observe - -You observe the world ONLY through: -- **Trajectory data**: sequences of states (feature vectors per object) and \ -actions -- **Task goals**: symbolic goal descriptions -- **Planning metrics**: success rate, nodes expanded, failure reasons -- **Current abstractions**: the types, predicates, processes, and options \ -currently in use - -You do NOT have access to environment source code, simulator internals, or \ -ground-truth models. You must infer useful abstractions from observed data. - -## What You Can Propose - -1. **Types**: New object types with named features -2. **Predicates**: Boolean classifiers over states and objects -3. **Helper Objects / Task Augmentation**: Functions that add helper objects \ -to tasks (e.g., grid locations, reference frames) -4. **Processes**: Causal processes (exogenous events triggered by conditions) -5. 
**Options**: Parameterized actions - -## Code Conventions - -When writing proposal code, the following variables are available in the exec \ -context: - -### Imports (already available — no need to import) -- `np`, `numpy`, `torch` -- `Box` (from gym.spaces) -- `Type`, `Predicate`, `DerivedPredicate`, `NSPredicate` -- `Object`, `Variable`, `LiftedAtom`, `GroundAtom` -- `ExogenousProcess`, `EndogenousProcess`, `CausalProcess` -- `ParameterizedOption`, `State`, `Task` -- `ConstantDelay`, `DiscreteGaussianDelay` -- `List`, `Set`, `Sequence` (from typing) - -### Current abstractions -- Each type `T` is available as `T_type` (e.g., `domino_type`, `robot_type`) -- Each predicate `P` is available by name (e.g., `Fallen`, `Standing`) -- Each predicate classifier is available as `_P_holds` \ -(e.g., `_Fallen_holds`) -- Each option `O` is available by name (e.g., `Push`) - -### Expected output variables per proposal tool -- `propose_types`: must define `proposed_types` (a list of Type objects) -- `propose_predicates`: must define `proposed_predicates` \ -(a list of Predicate objects) -- `propose_object_augmentor`: must define `augment_task(task) -> Task` -- `propose_processes`: must define `proposed_processes` \ -(a list of CausalProcess objects) -- `propose_options`: must define `proposed_options` \ -(a list of ParameterizedOption objects) - -## Key API Reference - -### State -```python -state.get(obj, "feature_name") # get a feature value -state.set(obj, "feature_name", value) # set a feature value -state.get_objects(some_type) # get all objects of a type -list(state) # iterate over all objects -state.copy() # copy the state -``` - -### Predicate -```python -pred = Predicate("MyPred", [type1_type, type2_type], - lambda state, objects: state.get(objects[0], "feat") > 0.5) -pred.holds(state, [obj1, obj2]) # evaluate -``` - -### Process (ExogenousProcess) -```python -v1 = Variable("?x", some_type) -v2 = Variable("?y", other_type) -proc = ExogenousProcess( - name="MyProcess", 
- parameters=[v1, v2], - condition_at_start={LiftedAtom(SomePred, [v1, v2])}, - condition_overall={LiftedAtom(SomePred, [v1, v2])}, - condition_at_end=set(), - add_effects={LiftedAtom(ResultPred, [v1])}, - delete_effects=set(), - delay_distribution=ConstantDelay(1), - strength=torch.tensor([1.0]), -) -``` - -### Type -```python -my_type = Type("my_type", ["feature1", "feature2"]) -``` - -## Iteration Protocol - -At each learning iteration: -1. **Inspect** the trajectory data and planning results using inspection tools -2. **Form hypotheses** about what abstractions are missing or insufficient -3. **Propose** new abstractions using proposal tools -4. **Test** your proposals using testing tools -5. **Refine** based on test results - fix errors and retry - -Focus on proposing abstractions that will help the planner solve more tasks. \ -Pay attention to: -- States where planning fails - what conditions are missing? -- Patterns in trajectory data that aren't captured by current predicates -- Whether helper objects (like grid positions) could simplify the problem -""" - - -def build_iteration_message( - cycle: int, - num_new_trajs: int, - num_total_trajs: int, - task_success_rate: float, - type_names_with_features: str, - predicate_signatures: str, - num_predicates: int, - process_summaries: str, - num_processes: int, - option_names: str, - num_options: int, - planning_success: str, - avg_nodes: str, - failure_summaries: str, - previous_iteration_outcomes: str, - available_tools: Optional[List[Any]] = None) -> str: - """Build the message sent to the agent at each iteration.""" - tools_section = "" - if available_tools: - tool_list = "\n".join(f" - {t}" for t in available_tools) - tools_section = f"\nAVAILABLE TOOLS:\n{tool_list}\n" - - return f"""\ -== Online Learning Iteration {cycle} == - -TRAJECTORY SUMMARY: -- {num_new_trajs} new trajectories collected this cycle -- {num_total_trajs} total trajectories (offline + online) -- Task success rate: {task_success_rate:.1%} 
- -CURRENT ABSTRACTIONS: -- Types: {type_names_with_features} -- Predicates ({num_predicates}): {predicate_signatures} -- Processes ({num_processes}): {process_summaries} -- Options ({num_options}): {option_names} - -PLANNING PERFORMANCE: -{planning_success} -- Avg nodes expanded: {avg_nodes} -- Failures: {failure_summaries} - -PREVIOUS ITERATION OUTCOMES: -{previous_iteration_outcomes} -{tools_section} -YOUR TASK: -Inspect the trajectory data and planning results. Propose new or improved \ -abstractions that will help the planner solve more tasks. Use the proposal \ -tools to register your proposals and the testing tools to validate them. -""" diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index ddd06df79..4226a780d 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -13,7 +13,9 @@ --num_online_learning_cycles 1 --explorer agent_plan """ import logging +import os import time +from collections import Counter from typing import Any, Callable, List, Optional, Sequence, Set, Tuple import numpy as np @@ -68,6 +70,14 @@ def reset_for_new_episode(self) -> None: super().reset_for_new_episode() self._exec_status = None self._exec_replans_left = CFG.agent_bilevel_max_execution_replans + # Optionally give each test solve a fresh agent conversation: close + # the session here (once per test task, before its first solve; not + # on mid-episode replans, which go through step() not reset()). The + # next query lazily rebuilds the session — same sandbox + learned + # artifacts, empty chat context. Gated to the test phase so + # exploration episodes keep their shared session. 
+ if CFG.agent_fresh_session_per_test_task and self._in_test_phase: + self._close_agent_session() def get_execution_monitoring_info(self) -> List[Any]: if self._exec_status is None: @@ -126,14 +136,18 @@ def _get_agent_system_prompt(self) -> str: # Solve prompt (no continuous params, subgoal format) # ------------------------------------------------------------------ # - def _build_solve_prompt(self, task: Task) -> str: + def _build_solve_prompt(self, + task: Task, + prior_failures: Optional[List[str]] = None) -> str: """Build prompt asking for a plan sketch without continuous params.""" + failures_text = "\n\n".join(prior_failures) if prior_failures else "" return bilevel_sketch.build_solve_prompt( task, all_predicates=self._get_all_predicates(), all_options=self._get_all_options(), trajectory_summary=self._build_trajectory_summary(), tool_names=self._get_solve_tool_names(), + prior_failures=failures_text, ) # ------------------------------------------------------------------ # @@ -149,16 +163,34 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._sync_tool_context() self._tool_context.current_task = task start = time.perf_counter() - + # Exclude the (minutes-long) LLM sketch query from the refinement + # budget, else a slow query overruns `timeout` and starves the + # refine loop -- failing the solve without ever refining. + llm_query_time = 0.0 + + def _refine_remaining() -> float: + elapsed = time.perf_counter() - start - llm_query_time + return timeout - elapsed + + sketches_tried = 0 + # Pre-formatted summaries of earlier sketches the search could not + # refine; threaded into the next sketch query so the agent revises + # the dead skeleton instead of re-emitting it. 
+ prior_failures: List[str] = [] for sketch_attempt in range(max_sketch_retries): - if timeout - (time.perf_counter() - start) <= 0: + if _refine_remaining() <= 0: break + query_start = time.perf_counter() try: - sketch = self._query_agent_for_plan_sketch(task) + sketch = self._query_agent_for_plan_sketch( + task, prior_failures=prior_failures) except Exception as e: # pylint: disable=broad-except + llm_query_time += time.perf_counter() - query_start logging.warning("Sketch query failed (attempt %d): %s", sketch_attempt, e) continue + llm_query_time += time.perf_counter() - query_start + sketches_tried += 1 sketch_lines = [] for i, s in enumerate(sketch): @@ -171,13 +203,19 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: logging.info("[%s] Sketch (attempt %d):\n%s", self._run_id, sketch_attempt, "\n".join(sketch_lines)) + # Aggregate per-step failures across this sketch's refine + # retries (same skeleton, so the obstruction is the same): + # deepest step the search reached, and a tally of the distinct + # failure reasons it hit there and earlier. + record_fail, fail_state = self._make_step_fail_recorder() + # Resample continuous params with a fresh seed before paying # for another agent query: a sketch that refines but fails # forward validation is a continuous-params problem, not a # wrong skeleton, and re-querying rarely changes the skeleton # while always costing an LLM call. 
for refine_attempt in range(max_refine_retries): - remaining = timeout - (time.perf_counter() - start) + remaining = _refine_remaining() if remaining <= 0: break # Flatten the two loop indices so every (sketch, refine) @@ -187,12 +225,19 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: plan, success = self._refine_sketch(task, sketch, remaining, - attempt=seed_offset) + attempt=seed_offset, + on_step_fail=record_fail) if not success: + reason_msg = "" + if fail_state["deepest_idx"] >= 0: + reason_msg = ( + f" (stuck at step {fail_state['deepest_idx']}: " + f"{fail_state['deepest_reason']})") + logging.info( f"Refinement failed (sketch " f"{sketch_attempt}, refine {refine_attempt}), " - f"{len(sketch)} steps.") + f"{len(sketch)} steps{reason_msg}.") continue plan_strs = [] @@ -230,22 +275,52 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: f"{refine_attempt}): {reason}") # Fall through to the next seed on the same sketch. + # Every refine retry for this skeleton failed: save a full + # per-step refinement log to the sandbox and add a preview + + # pointer so the next sketch query revises this dead skeleton. 
+ preview = self._record_refinement_failure( + sketch_attempt, sketch_lines, sketch, + fail_state["deepest_idx"], fail_state["deepest_reason"], + fail_state["counts"]) + if preview: + prior_failures.append(preview) + raise ApproachFailure( - f"Bilevel solve failed after {max_sketch_retries} sketches.") + f"Bilevel solve failed after {sketches_tried} sketch(es) " + f"(LLM query time {llm_query_time:.1f}s excluded from the " + f"{timeout}s refinement budget).") # ------------------------------------------------------------------ # # Plan sketch extraction # ------------------------------------------------------------------ # - def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: - """Query agent for a plan sketch and parse it.""" + def _query_agent_for_plan_sketch( + self, + task: Task, + prior_failures: Optional[List[str]] = None) -> List[_SketchStep]: + """Query agent for a plan sketch and parse it. + + ``prior_failures`` carries preview+pointer blocks for earlier + sketches the search could not refine; they are injected into the + prompt so the re-query revises the dead skeleton. + """ sketch_file = CFG.agent_bilevel_plan_sketch_file if sketch_file: - with open(sketch_file, "r", encoding="utf-8") as f: + # An absolute path is used as-is; a bare name resolves under + # scripts/{CFG.agent_bilevel_plan_sketch_dir}/.
+ if os.path.isabs(sketch_file): + filepath = sketch_file + else: + filepath = os.path.join(utils.get_path_to_predicators_root(), + "scripts", + CFG.agent_bilevel_plan_sketch_dir, + sketch_file) + with open(filepath, "r", encoding="utf-8") as f: plan_text = f.read().strip() logging.info("Loaded plan sketch from file: %s", sketch_file) else: - prompt = self._build_solve_prompt(task) + prompt = self._build_solve_prompt(task, + prior_failures=prior_failures) responses = self._query_agent_sync(prompt, kind="test") plan_text = self._extract_option_plan_text(responses) @@ -272,6 +347,109 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: f"with subgoals.") return sketch + @staticmethod + def _make_step_fail_recorder( + ) -> Tuple[Callable[[int, List[Optional[_Option]], str], None], "dict"]: + """Build an ``on_step_fail`` callback and its accumulator state. + + Returns ``(callback, state)`` where ``state`` is a dict with + keys ``deepest_idx`` (the deepest step index the search reached + before failing), ``deepest_reason`` (the failure reason there), + and ``counts`` (a ``Counter`` over ``(step_idx, reason)``). + Built as a factory so the closure captures fresh per-sketch + state instead of loop variables. + """ + state: dict = { + "deepest_idx": -1, + "deepest_reason": "", + "counts": Counter(), + } + + def _record(idx: int, _plan: List[Optional[_Option]], + reason: str) -> None: + state["counts"][(idx, reason)] += 1 + if idx > state["deepest_idx"]: + state["deepest_idx"] = idx + state["deepest_reason"] = reason + + return _record, state + + def _record_refinement_failure( + self, + attempt_idx: int, + sketch_lines: List[str], + sketch: List[_SketchStep], + deepest_idx: int, + deepest_reason: str, + reason_counts: "Counter[Tuple[int, str]]", + ) -> str: + """Persist a full refinement-failure log to the sandbox and return a + preview+pointer block for the next sketch prompt. 
+ + Writes ``{sandbox}/refinement_logs/sketch_{NN}_refine.md`` with the + tried skeleton, where backtracking got stuck (deepest step), and a + per-step tally of the distinct failure reasons. The returned block + embeds a short preview and a relative pointer to that file so the + agent can ``Read`` the detail. Returns ``""`` if there is nothing + to report (no recorded failures). + """ + if not reason_counts: + return "" + + def _step_desc(idx: int) -> str: + if 0 <= idx < len(sketch): + objs = ", ".join(o.name for o in sketch[idx].objects) + return f"step {idx}: {sketch[idx].option.name}({objs})" + return f"step {idx}" + + total_fail = sum(reason_counts.values()) + deepest_desc = _step_desc(deepest_idx) + + full_lines = [ + f"# Refinement failure — sketch attempt {attempt_idx}", + "", + "## Sketch (could not be refined)", + *sketch_lines, + "", + "## Outcome", + f"FAILED. Deepest step the search reached: {deepest_desc}.", + f"Dominant failure there: {deepest_reason}", + f"Total failed samples: {total_fail}.", + "", + "## Per-step failure reasons (count)", + ] + for (idx, reason), cnt in sorted(reason_counts.items(), + key=lambda kv: (kv[0][0], -kv[1])): + full_lines.append(f"- {_step_desc(idx)}: {cnt}x {reason}") + full_text = "\n".join(full_lines) + "\n" + + # Prefer the agent-visible sandbox cwd so the pointer is a valid + # relative path for the agent; fall back to the run log dir.
+ sandbox = getattr(self._tool_context, "sandbox_dir", None) \ + or self._get_log_dir() + rel_dir = "refinement_logs" + out_dir = os.path.join(sandbox, rel_dir) + os.makedirs(out_dir, exist_ok=True) + fname = f"sketch_{attempt_idx:02d}_refine.md" + try: + with open(os.path.join(out_dir, fname), "w", + encoding="utf-8") as f: + f.write(full_text) + pointer = f"./{rel_dir}/{fname}" + except OSError as e: # pragma: no cover - best-effort logging + logging.warning("Could not write refinement log: %s", e) + pointer = "(refinement log unavailable)" + + preview = "\n".join([ + f"### Attempt {attempt_idx} (FAILED)", + *sketch_lines, + f" -> Refinement FAILED. Deepest step reached: {deepest_desc}. " + f"Dominant failure: {deepest_reason} " + f"({total_fail} failed samples).", + f" Full per-step refinement log: {pointer}", + ]) + return preview + # ------------------------------------------------------------------ # # Backtracking refinement # ------------------------------------------------------------------ # @@ -282,6 +460,8 @@ def _refine_sketch( sketch: List[_SketchStep], timeout: float, attempt: int = 0, + on_step_fail: Optional[Callable[[int, List[Optional[_Option]], str], + None]] = None, ) -> Tuple[List[_Option], bool]: """Backtracking search over continuous parameters for a plan sketch. 
@@ -315,6 +495,8 @@ def _refine_sketch( check_subgoals=CFG.agent_bilevel_check_subgoals, log_state=CFG.agent_bilevel_log_state, run_id=self._run_id, + option_samplers=self._get_all_samplers(), + on_step_fail=on_step_fail, ) return plan, success @@ -446,14 +628,7 @@ def _option_policy(state: State) -> _Option: inner = utils.option_policy_to_policy(_option_policy, abstract_function=_abstract) - - def _policy(s: State) -> Action: - try: - return inner(s) - except utils.OptionExecutionFailure as e: - raise ApproachFailure(e.args[0], e.info) - - return _policy + return self._wrap_option_failures(inner) def _replan_suffix( self, diff --git a/predicators/approaches/agent_closed_loop_approach.py b/predicators/approaches/agent_closed_loop_approach.py deleted file mode 100644 index 3ef1112d7..000000000 --- a/predicators/approaches/agent_closed_loop_approach.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Agent closed-loop planning approach. - -Like AgentPlannerApproach, but instead of generating the full option plan -upfront, the agent is queried at each option boundary to decide the next -single option based on the current state. This makes the approach reactive -to actual execution outcomes. - -Example command: - python predicators/main.py --env pybullet_domino \ - --approach agent_closed_loop --seed 0 \ - --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent_plan -""" -import logging -from typing import Callable, List - -import numpy as np - -from predicators import utils -from predicators.agent_sdk.tools import create_mcp_tools -from predicators.approaches import ApproachFailure -from predicators.approaches.agent_planner_approach import AgentPlannerApproach -from predicators.structs import Action, State, Task, _Option - - -class AgentClosedLoopApproach(AgentPlannerApproach): - """Closed-loop planning via Claude Agent SDK. 
- - At each option boundary, queries the agent for the next single - option based on the current state, goal, and execution history. - """ - - @classmethod - def get_name(cls) -> str: - return "agent_closed_loop" - - def _create_agent_mcp_tools(self) -> list: - return create_mcp_tools( - self._tool_context, - tool_names=[ - "inspect_options", "inspect_trajectories", - "inspect_train_tasks" - ], - ) - - def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: - step_history: List[str] = [] - - def _option_policy(state: State) -> _Option: - try: - prompt = self._build_step_prompt(state, task, step_history) - responses = self._query_agent_sync(prompt, kind="test") - text = self._extract_option_plan_text(responses) - option = self._parse_single_option(text, task) - step_history.append(option.simple_str()) - return option - except ApproachFailure: - raise - except Exception as e: - raise ApproachFailure( - f"Agent failed to produce next option: {e}") - - policy = utils.option_policy_to_policy(_option_policy) - - def _policy(s: State) -> Action: - try: - return policy(s) - except utils.OptionExecutionFailure as e: - raise ApproachFailure(e.args[0], e.info) - - return _policy - - def _build_step_prompt(self, state: State, task: Task, - step_history: List[str]) -> str: - """Build prompt asking for the next single option.""" - objects = list(state) - - # Objects - obj_strs = [] - for obj in sorted(objects, key=lambda o: o.name): - obj_strs.append(f" {obj.name}: {obj.type.name}") - - # Goal - goal_strs = [str(a) for a in sorted(task.goal, key=str)] - - # Options - option_strs = [] - for opt in sorted(self._initial_options, key=lambda o: o.name): - type_sig = ", ".join(t.name for t in opt.types) - params_dim = opt.params_space.shape[0] - if params_dim > 0: - low = opt.params_space.low.tolist() - high = opt.params_space.high.tolist() - if opt.params_description: - desc = ", ".join(opt.params_description) - param_info = (f", params=[{desc}], " - f"low={low}, 
high={high}") - else: - param_info = (f", params_dim={params_dim}, " - f"low={low}, high={high}") - else: - param_info = "" - option_strs.append(f" {opt.name}({type_sig}{param_info})") - - # Current atoms - atoms = utils.abstract(state, self._initial_predicates) - atom_strs = [str(a) for a in sorted(atoms, key=str)] - - # State features - state_str = state.dict_str(indent=2) - - # Trajectory summary - traj_summary = self._build_trajectory_summary() - - # Step history - if step_history: - history_str = "\n## Options Executed So Far\n" - for i, s in enumerate(step_history): - history_str += f" Step {i + 1}: {s}\n" - else: - history_str = ("\n## Options Executed So Far\n" - "None yet (first step).\n") - - prompt = f"""You are solving a task step by step. \ -Decide the NEXT SINGLE option to execute. - -## Goal -{chr(10).join(goal_strs)} - -## Current State Atoms -{chr(10).join(atom_strs)} - -## Current State Features -{state_str} - -## Objects -{chr(10).join(obj_strs)} - -## Available Options -{chr(10).join(option_strs)} -{history_str}{traj_summary} -## Instructions -You can use the inspect tools to examine types, predicates, options, and past trajectories in more detail. - -Based on the current state and execution history, output the NEXT SINGLE option to execute. 
-Output exactly ONE option line in this format: - OptionName(obj1:type1, obj2:type2)[param1, param2] - -If an option has no continuous parameters, use empty brackets: OptionName(obj1:type1)[] - -Output ONLY the single option line at the end, after any analysis.""" - - return prompt - - def _parse_single_option(self, text: str, task: Task) -> _Option: - """Parse a single option from agent response and ground it.""" - if not text.strip(): - raise ApproachFailure("Agent returned empty response.") - - objects = list(task.init) - parsed = utils.parse_model_output_into_option_plan( - text, - objects, - self._types, - self._initial_options, - parse_continuous_params=True) - - if not parsed: - raise ApproachFailure( - "Could not parse any option from agent response.") - - # Take the last parsed option (agent may include analysis before it) - option, objs, params = parsed[-1] - try: - params_arr = np.array(params, dtype=np.float32) - ground_opt = option.ground(objs, params_arr) - except Exception as e: - raise ApproachFailure( - f"Failed to ground option {option.name}: {e}") - - logging.info(f"Agent selected next option: " - f"{ground_opt.simple_str()}") - return ground_opt diff --git a/predicators/approaches/agent_option_learning_approach.py b/predicators/approaches/agent_option_learning_approach.py index 201514a2b..c63d59623 100644 --- a/predicators/approaches/agent_option_learning_approach.py +++ b/predicators/approaches/agent_option_learning_approach.py @@ -17,7 +17,6 @@ from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Set -import dill as pkl from gym.spaces import Box from predicators import utils @@ -38,13 +37,15 @@ class AgentOptionLearningApproach(AgentPlannerApproach): then plans with them in the same query. 
""" + _save_suffix = "AgentOptionLearning" + def __init__(self, initial_predicates: Set[Predicate], initial_options: Set[ParameterizedOption], types: Set[Type], action_space: Box, train_tasks: List[Task], *args: Any, **kwargs: Any) -> None: - # Agent-specific state (before super().__init__) + # Agent-specific state (before super().__init__). + # (_agent_session_id is initialized by the session mixin.) self._agent_proposed_options: Set[ParameterizedOption] = set() - self._agent_session_id: Optional[str] = None super().__init__(initial_predicates, initial_options, types, action_space, train_tasks, *args, **kwargs) @@ -70,7 +71,7 @@ def _get_agent_system_prompt(self) -> str: 2. **Invent** new options if needed — either by writing and executing Python code directly, or by using the `propose_options` tool 3. **Test** — either write and run Python experiments to verify your - options, or use `test_option_plan` to check that a plan achieves + options, or use `evaluate_option_plan` to check that a plan achieves the goal. Use `retract_abstractions` to remove options that don't work. 4. 
**Plan** — output the final option plan @@ -124,20 +125,20 @@ def _get_agent_system_prompt(self) -> str: - Only propose new options if existing ones cannot achieve the goal - You can invent and test options in two ways: (a) write and execute Python code directly in the sandbox, or (b) use the `propose_options`, - `retract_abstractions`, and `test_option_plan` tools + `retract_abstractions`, and `evaluate_option_plan` tools - Always test your plan before committing - Output the final plan in the standard format at the end ## Debugging Tips - Use `inspect_options` with `option_name` to save an option's source code to ./proposed_code/.py, then Read it to study the implementation -- `test_option_plan` automatically saves scene images to ./test_images/ +- `evaluate_option_plan` automatically saves scene images to ./test_images/ after each step — check them to debug spatial issues - Your session logs are in ./session_logs/ — Glob and Read them to review past attempts when iterating - All proposal and option source code is in ./proposed_code/ — Read files there to understand how existing options work -- When `test_option_plan` fails, check the "Object poses at failure" +- When `evaluate_option_plan` fails, check the "Object poses at failure" and "Missing goal atoms" in the output""" def _get_solve_tool_names(self) -> Optional[List[str]]: @@ -149,7 +150,7 @@ def _get_solve_tool_names(self) -> Optional[List[str]]: "inspect_past_proposals", "propose_options", "retract_abstractions", - "test_option_plan", + "evaluate_option_plan", ] def _get_sandbox_reference_files( # pylint: disable=useless-super-delegation @@ -291,7 +292,7 @@ def _build_solve_prompt(self, task: Task) -> str: 3. **Test** — Verify your options and plan work correctly: - **Python code**: Write and run Python experiments to unit-test \ individual options or full plans. 
- - **MCP tools**: Use `test_option_plan` to check that a plan \ + - **MCP tools**: Use `evaluate_option_plan` to check that a plan \ (including any new options) achieves the goal. Iterate until the test passes. 4. **Commit** — Once the test passes, output the final plan. Your \ @@ -329,55 +330,14 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: # Save / Load # ------------------------------------------------------------------ # - def save(self, online_learning_cycle: Optional[int] = None) -> None: - save_path = utils.get_approach_save_path_str() - with open(f"{save_path}_{online_learning_cycle}.AgentOptionLearning", - "wb") as f: - save_dict = { - "offline_dataset": - self._offline_dataset, - "online_trajectories": - self._online_trajectories, - "online_learning_cycle": - self._online_learning_cycle, - "run_id": - self._run_id, - "agent_proposed_options": - self._agent_proposed_options, - "agent_session_id": (self._agent_session.session_id - if self._agent_session else None), - } - pkl.dump(save_dict, f) - logging.info(f"[Run {self._run_id}] Saved approach to {save_path}_" - f"{online_learning_cycle}.AgentOptionLearning") - - def load(self, online_learning_cycle: Optional[int] = None) -> None: - save_path = utils.get_approach_load_path_str() - with open(f"{save_path}_{online_learning_cycle}.AgentOptionLearning", - "rb") as f: - save_dict = pkl.load(f) - - self._offline_dataset = save_dict["offline_dataset"] - self._online_trajectories = save_dict["online_trajectories"] - self._online_learning_cycle = \ - save_dict["online_learning_cycle"] + 1 - self._agent_session_id = save_dict.get("agent_session_id") + def _extra_save_state(self) -> Dict[str, Any]: + return {"agent_proposed_options": self._agent_proposed_options} + + def _load_extra_save_state(self, save_dict: Dict[str, Any]) -> None: self._agent_proposed_options = save_dict.get("agent_proposed_options", set()) - - import datetime # pylint: disable=import-outside-toplevel - 
original_run_id = save_dict.get("run_id", "unknown") - self._run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - # Re-sync tool context - self._sync_tool_context() - - logging.info( - f"[Run {self._run_id}] Loaded from previous run " - f"{original_run_id}: " - f"{len(self._offline_dataset.trajectories)} offline, " - f"{len(self._online_trajectories)} online trajectories, " - f"{len(self._agent_proposed_options)} agent-proposed options") + logging.info("[Run %s] Restored %d agent-proposed options.", + self._run_id, len(self._agent_proposed_options)) # --------------------------------------------------------------------------- # diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 4fe2ca802..31475d3db 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -31,7 +31,7 @@ from predicators.settings import CFG from predicators.structs import Action, Dataset, GroundAtom, \ InteractionRequest, InteractionResult, LowLevelTrajectory, Object, \ - ParameterizedOption, Predicate, State, Task, Type + OptionSampler, ParameterizedOption, Predicate, State, Task, Type class AgentPlannerApproach(AgentSessionMixin, BaseApproach): @@ -68,11 +68,24 @@ def __init__(self, Any, self._option_model)._abstract_function = ( lambda s: utils.abstract(s, self._get_all_predicates())) self._online_learning_cycle = 0 + # Synthesized per-skill samplers (option name -> sampler). Empty for + # the base planner; learning subclasses that synthesize samplers + # populate it. Threaded into bilevel refinement via + # _get_all_samplers() so continuous-parameter search can aim at each + # step's subgoal instead of drawing uniformly. 
+ self._synthesized_samplers: Dict[str, OptionSampler] = {} self._requests_train_task_idxs: Optional[List[int]] = None self._run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") self._pre_test_conversation_log: Optional[List[Dict[str, Any]]] = None - self._agent_session_id: Optional[str] = None - + # True only between begin_test_phase / end_test_phase, so per-episode + # hooks can act on test solves without touching exploration episodes. + self._in_test_phase = False + # 0-based index of the test task being solved, mirroring main.py's + # ``test_task_idx``. Incremented per test solve; threaded into the + # session-log filename via the ToolContext. + self._test_task_idx = -1 + + # Initializes _tool_context and _agent_session_id (see mixin). self._init_agent_session_state(types, initial_predicates, initial_options, train_tasks) @@ -115,6 +128,15 @@ def _get_all_predicates(self) -> Set[Predicate]: """Return the full set of predicates for abstraction.""" return self._initial_predicates + def _get_all_samplers(self) -> Dict[str, OptionSampler]: + """Return synthesized per-skill samplers (option name -> sampler). + + Empty by default; learning subclasses populate the backing + field. Threaded into bilevel refinement to aim continuous- + parameter search at each step's subgoal. + """ + return self._synthesized_samplers + def _get_all_trajectories(self) -> List[LowLevelTrajectory]: """Return all trajectories (offline + online).""" return self._offline_dataset.trajectories + self._online_trajectories @@ -125,7 +147,7 @@ def _create_planner_option_model(self) -> Optional[_OptionModelBase]: Honors two CFG knobs: * ``agent_planner_use_simulator`` -- when False, returns ``None`` - so the agent gets no ``test_option_plan`` rollouts and must + so the agent gets no ``evaluate_option_plan`` rollouts and must plan open-loop from data + LLM reasoning (the model-free baseline). 
* ``agent_planner_use_base_simulator`` -- when True (and a @@ -169,9 +191,9 @@ def _create_planner_option_model(self) -> Optional[_OptionModelBase]: ## Scratchpad — CRITICAL You MUST maintain `./notes.md` as your working memory. \ **Read it at the very start of the session** and **read it \ -again before every test_option_plan call** to remind yourself \ +again before every evaluate_option_plan call** to remind yourself \ what you already tried. **Update it immediately after every \ -test_option_plan call** — no exceptions. +evaluate_option_plan call** — no exceptions. Use this exact format for each option you are tuning: @@ -202,7 +224,7 @@ def _create_planner_option_model(self) -> Optional[_OptionModelBase]: rotation, water_volume, is_on, etc.) and renders the scene \ WITHOUT running the full simulation. It is FREE (no physics, \ no failure modes) — use it liberally to build spatial \ -understanding before spending expensive test_option_plan calls. +understanding before spending expensive evaluate_option_plan calls. **When to use visualize_state:** - **At the start**: visualize key objects to understand the \ @@ -279,7 +301,7 @@ def _get_agent_system_prompt(self) -> str: if use_scratchpad: steps.append( "**Read `./notes.md` before every test**, then **update it " - "immediately after every test_option_plan call**. Record " + "immediately after every evaluate_option_plan call**. Record " "what you tried, what happened, and what you learned. " "This is your memory — without it you will repeat failures.") steps += [ @@ -289,8 +311,9 @@ def _get_agent_system_prompt(self) -> str: "**Inspect rendered images** from `./test_images/` when " "something goes wrong to understand the actual outcome. 
" "For finer-grained debugging, pass `save_low_level_action_images: " - "true` to test_option_plan — this saves per-simulator-step images " - "to `./test_images_low_level/`.", + "true` to evaluate_option_plan — this saves " + "per-simulator-step images to " + "`./test_images_low_level/`.", "**Expect geometric offsets.** The target position for " "options is often offset from the reference object's reported " "position due to object geometry. Explore a wide range around " @@ -345,11 +368,14 @@ def _get_solve_tool_names(self) -> Optional[List[str]]: "inspect_options", "inspect_trajectories", "inspect_train_tasks" ] # The remaining tools all require a simulator / live env: - # test_option_plan rolls plans out through the option model, and - # visualize_state / annotate_scene render env states. None are - # offered when the planner has no simulator. + # evaluate_option_plan rolls fully-specified plans out through the + # option model, refine_plan_sketch runs backtracking refinement + + # forward validation on a param-free sketch, and visualize_state / + # annotate_scene render env states. None are offered when the + # planner has no simulator. 
if CFG.agent_planner_use_simulator: - tools.append("test_option_plan") + tools.append("evaluate_option_plan") + tools.append("refine_plan_sketch") if CFG.agent_planner_use_annotate_scene: tools.append("annotate_scene") if CFG.agent_planner_use_visualize_state: @@ -403,6 +429,9 @@ def learn_from_interaction_results( preds_version: Optional[str] = getattr(self, "_current_predicates_version", None) + samplers_version: Optional[str] = getattr(self, + "_current_samplers_version", + None) for i, result in enumerate(results): task_idx = self._requests_train_task_idxs[i] traj = LowLevelTrajectory( @@ -411,6 +440,7 @@ def learn_from_interaction_results( _train_task_idx=task_idx, _source_simulator_version=sim_version, _source_predicates_version=preds_version, + _source_samplers_version=samplers_version, ) self._online_trajectories.append(traj) @@ -429,9 +459,30 @@ def learn_from_interaction_results( # Solving # ------------------------------------------------------------------ # + @staticmethod + def _wrap_option_failures( + policy: Callable[[State], Action]) -> Callable[[State], Action]: + """Wrap a policy so OptionExecutionFailure surfaces as ApproachFailure. + + Bilevel planning and the base open-loop planner both build a + low-level policy from a grounded option plan; this adapter gives + them a single place to translate the option-execution exception + the harness raises into the ApproachFailure CogMan expects. + """ + + def _policy(s: State) -> Action: + try: + return policy(s) + except utils.OptionExecutionFailure as e: + raise ApproachFailure(e.args[0], e.info) + + return _policy + def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._sync_tool_context() self._tool_context.current_task = task + # Render the initial state so the agent can see the scene layout. 
+ self._render_initial_state_image(task) try: option_plan = self._query_agent_for_option_plan(task) except Exception as e: @@ -441,13 +492,61 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: policy = utils.option_plan_to_policy( option_plan, abstract_function=lambda s: utils.abstract(s, preds)) - def _policy(s: State) -> Action: - try: - return policy(s) - except utils.OptionExecutionFailure as e: - raise ApproachFailure(e.args[0], e.info) + return self._wrap_option_failures(policy) - return _policy + def _render_initial_state_image(self, task: Task) -> Optional[str]: + """Render the initial state of the task and save to the sandbox. + + Returns the sandbox-relative path to the saved image, or None if + rendering is not available. + """ + env = self._tool_context.env + if env is None: + return None + save_dir = self._tool_context.image_save_dir + if save_dir is None: + return None + try: + # pylint: disable=import-outside-toplevel + from PIL import Image as PILImage + + # For PyBullet envs, set state then use render() (render_state + # raises NotImplementedError for arbitrary states). + # For other envs, use render_state directly. + try: + from predicators.envs.pybullet_env import PyBulletEnv + is_pybullet = isinstance(env, PyBulletEnv) + except ImportError: + is_pybullet = False + + if is_pybullet: + env._set_state(task.init) # pylint: disable=protected-access + video = env.render() + else: + # Build a minimal EnvironmentTask for the render_state API. 
+ from predicators.structs import EnvironmentTask + env_task = EnvironmentTask(task.init, task.goal) + video = env.render_state(task.init, env_task) + + if not video: + return None + + rgb_array = np.asarray(video[0], dtype=np.uint8) + img = PILImage.fromarray( + rgb_array) # type: ignore[no-untyped-call] + os.makedirs(save_dir, exist_ok=True) + task_id = self._tool_context.test_task_idx + if task_id is not None: + filename = f"task{task_id:03d}_initial_state.png" + else: + filename = "initial_state.png" + saved_path = os.path.join(save_dir, filename) + img.save(saved_path) + logging.info("Saved initial state image to %s", saved_path) + return saved_path + except Exception as e: # pylint: disable=broad-except + logging.warning("Failed to render initial state image: %s", e) + return None # ------------------------------------------------------------------ # # Test phase lifecycle @@ -455,6 +554,8 @@ def _policy(s: State) -> Action: def begin_test_phase(self) -> None: """Snapshot the learning conversation log before testing.""" + self._in_test_phase = True + self._test_task_idx = -1 if self._agent_session is not None: import copy # pylint: disable=import-outside-toplevel self._pre_test_conversation_log = copy.deepcopy( @@ -464,12 +565,29 @@ def begin_test_phase(self) -> None: def end_test_phase(self) -> None: """Restore the conversation log to its pre-test state.""" + self._in_test_phase = False + self._tool_context.test_task_idx = None if self._agent_session is not None \ and self._pre_test_conversation_log is not None: self._agent_session._conversation_log = \ self._pre_test_conversation_log # pylint: disable=protected-access self._pre_test_conversation_log = None + def reset_for_new_episode(self) -> None: + """Advance the test-task counter at each test episode start. 
+ + CogMan calls this exactly once per test task (via + ``cogman.reset`` in main.py's ``_solve_task``) and never on mid- + episode replans, so the counter stays in lockstep with main.py's + ``test_task_idx``. The index is exposed to the sandbox via the + ToolContext and lands in the session-log filename. No-op outside + the test phase. + """ + super().reset_for_new_episode() + if self._in_test_phase: + self._test_task_idx += 1 + self._tool_context.test_task_idx = self._test_task_idx + def _query_agent_for_option_plan(self, task: Task) -> list: """Query the agent for an option plan and parse it.""" prompt = self._build_solve_prompt(task) @@ -490,7 +608,8 @@ def _solve_prompt_scratchpad_line(self) -> str: """Return the notes.md bullet for the solve prompt, or empty.""" if CFG.agent_planner_use_scratchpad: return ( - "- **Read `./notes.md` before every test_option_plan call** " + "- **Read `./notes.md` before every " + "evaluate_option_plan call** " "and **update it immediately after each call** — append a " "row to the parameter table and update the explored-ranges " "summary. If you realize you forgot to update, STOP and " @@ -554,6 +673,25 @@ def _build_solve_prompt(self, task: Task) -> str: {task.goal_nl} """ + # Initial state image reference + initial_image_section = "" + if self._tool_context.image_save_dir: + task_id = self._tool_context.test_task_idx + if task_id is not None: + img_name = f"task{task_id:03d}_initial_state.png" + else: + img_name = "initial_state.png" + initial_img_path = os.path.join(self._tool_context.image_save_dir, + img_name) + if os.path.exists(initial_img_path): + # Use sandbox-relative path for the agent + initial_image_section = ( + "\n## Initial State Image\n" + "A rendering of the initial scene has been saved to " + f"`./test_images/{img_name}`. 
**Read this image " + "first** to understand the spatial layout before " + "planning.\n") + if CFG.agent_planner_use_simulator: instructions_intro = ( "Use your available tools to inspect the environment and " @@ -575,7 +713,7 @@ def _build_solve_prompt(self, task: Task) -> str: ## Initial State Features {state_str} - +{initial_image_section} ## Objects {chr(10).join(obj_strs)} @@ -797,7 +935,7 @@ def _create_explorer(self) -> BaseExplorer: def _sync_tool_context(self) -> None: """Push current approach state into the shared ToolContext. - The MCP tools (inspect_options, test_option_plan, etc.) read + The MCP tools (inspect_options, evaluate_option_plan, etc.) read from the ToolContext dataclass, not from the approach directly. This method keeps them in sync after mutations (e.g. new trajectories collected, options added). Called before each @@ -818,6 +956,9 @@ def _sync_tool_context(self) -> None: self._tool_context.log_dir = self._get_log_dir() self._tool_context.option_model = self._option_model + # Synthesized samplers, so the explorer and synthesis tools thread + # the same per-skill samplers into refinement that the approach uses. + self._tool_context.option_samplers = self._get_all_samplers() # Wire the active-experiment info-gain scorer when a learning # subclass exposes one and info-seeking exploration is on. Syncing # the bound method (not a snapshot) keeps it pointed at the latest @@ -849,45 +990,68 @@ def _sync_tool_context(self) -> None: # Save / Load # ------------------------------------------------------------------ # + # Filename suffix for the pickled approach state. Subclasses that + # persist extra fields override this so their saves don't collide + # with the base planner's. + _save_suffix: str = "AgentPlanner" + + def _extra_save_state(self) -> Dict[str, Any]: + """Subclass hook: extra (key -> value) pairs to persist. + + Merged into the base save dict; restored by the matching + :meth:`_load_extra_save_state`. 
+ """ + return {} + + def _load_extra_save_state(self, save_dict: Dict[str, Any]) -> None: + """Subclass hook: restore fields written by _extra_save_state. + + Called after the base fields are restored and ``_run_id`` has + been refreshed, but before the tool context is re-synced. + """ + def save(self, online_learning_cycle: Optional[int] = None) -> None: """Save approach state to disk.""" save_path = utils.get_approach_save_path_str() - with open(f"{save_path}_{online_learning_cycle}.AgentPlanner", - "wb") as f: - save_dict = { - "offline_dataset": - self._offline_dataset, - "online_trajectories": - self._online_trajectories, - "online_learning_cycle": - self._online_learning_cycle, - "run_id": - self._run_id, - "agent_session_id": (self._agent_session.session_id - if self._agent_session else None), - } + path = f"{save_path}_{online_learning_cycle}.{self._save_suffix}" + save_dict = { + "offline_dataset": + self._offline_dataset, + "online_trajectories": + self._online_trajectories, + "online_learning_cycle": + self._online_learning_cycle, + "run_id": + self._run_id, + "agent_session_id": + (self._agent_session.session_id if self._agent_session else None), + **self._extra_save_state(), + } + with open(path, "wb") as f: pkl.dump(save_dict, f) - logging.info(f"[Run {self._run_id}] Saved approach to {save_path}_" - f"{online_learning_cycle}.AgentPlanner") + logging.info(f"[Run {self._run_id}] Saved approach to {path}") def load(self, online_learning_cycle: Optional[int] = None) -> None: save_path = utils.get_approach_load_path_str() - with open(f"{save_path}_{online_learning_cycle}.AgentPlanner", - "rb") as f: + path = f"{save_path}_{online_learning_cycle}.{self._save_suffix}" + with open(path, "rb") as f: save_dict = pkl.load(f) self._offline_dataset = save_dict["offline_dataset"] self._online_trajectories = save_dict["online_trajectories"] - self._online_learning_cycle = \ - save_dict["online_learning_cycle"] + 1 + self._online_learning_cycle = 
save_dict["online_learning_cycle"] + 1 + # pylint: disable=attribute-defined-outside-init self._agent_session_id = save_dict.get("agent_session_id") + # pylint: enable=attribute-defined-outside-init # Create new run_id for continued execution (each run gets own dir) - # but log the original run_id for reference + # but log the original run_id for reference. original_run_id = save_dict.get("run_id", "unknown") self._run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - # Re-sync tool context + self._load_extra_save_state(save_dict) + + # Re-sync tool context (subclass fields are restored first). self._sync_tool_context() logging.info( diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py index d0a8d4794..4a4719642 100644 --- a/predicators/approaches/agent_sim_learning_approach.py +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -29,8 +29,9 @@ from gym.spaces import Box from predicators import utils -from predicators.agent_sdk.tools import SYNTHESIS_TOOL_NAMES, \ - _SnapshotTarget, create_synthesis_tools, finalize_versioned_snapshot, \ +from predicators.agent_sdk.tools import SAMPLER_SYNTHESIS_TOOL_NAMES, \ + SYNTHESIS_TOOL_NAMES, _SnapshotTarget, create_sampler_synthesis_tools, \ + create_synthesis_tools, finalize_versioned_snapshot, \ make_write_snapshot_hook from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach from predicators.code_sim_learning.active_experiment import laplace_ensemble, \ @@ -44,12 +45,12 @@ iter_feature_residuals, merge_updates, read_latent_init, \ read_simulator_components from predicators.envs import create_new_env -from predicators.ground_truth_models import get_gt_simulator +from predicators.ground_truth_models import get_gt_samplers, get_gt_simulator from predicators.option_model import _OptionModelBase, _OracleOptionModel from predicators.settings import CFG from predicators.structs import Action, Dataset, GroundAtom, \ - 
InteractionResult, LowLevelTrajectory, ParameterizedOption, Predicate, \ - State, Task, Type + InteractionResult, LowLevelTrajectory, OptionSampler, \ + ParameterizedOption, Predicate, State, Task, Type logger = logging.getLogger(__name__) @@ -151,6 +152,13 @@ def __init__(self, # provenance (consumed in the next learn-phase prompt). self._current_simulator_version: Optional[str] = None self._current_predicates_version: Optional[str] = None + self._current_samplers_version: Optional[str] = None + # Whether this run learns samplers (vs. using ground-truth ones). + # Refined per cycle in _learn_simulator once GT availability is known; + # this default is what the synthesis-session tool surface reads. + self._do_synthesize_samplers: bool = ( + CFG.agent_sim_learn_synthesize_samplers + and not CFG.agent_sim_learn_oracle_samplers) # Partial-observability latent block: loaded from a simulator's # LATENT_INIT export (None ⇒ no latent state). When the loaded # rules use the recurrent 5-arg signature, fitting, the combined @@ -186,8 +194,13 @@ def _get_synthesis_tool_names(self) -> Optional[List[str]]: ``ctx.extra_mcp_tools`` inside :meth:`_synthesize_with_agent`. The mixin asserts the attached instances and this list agree. """ - return ["inspect_types", "inspect_options", "inspect_trajectories"] +\ + names = ["inspect_types", "inspect_options", "inspect_trajectories"] +\ list(SYNTHESIS_TOOL_NAMES) + # When the agent is learning samplers in this session (not using + # ground-truth ones), expose the evaluate_sampler tool. 
+ if self._do_synthesize_samplers: + names += list(SAMPLER_SYNTHESIS_TOOL_NAMES) + return names # ── Subclass hooks ────────────────────────────────────────── # Default implementations are no-ops so subclasses can add @@ -292,6 +305,296 @@ def _build_synthesis_session_hooks( ], } + # ── Per-skill sampler synthesis ───────────────────────────── + # Samplers are a first-class artifact of the base sim-learning + # approach (gated by a flag), not a subclass extension like + # predicates — so they are woven into _synthesize_with_agent and + # _learn_simulator directly rather than via the _extra_synthesis_* + # hooks, which keeps them independent of the predicate subclass's + # (non-super-calling) hook overrides. When a sim-synthesis session + # runs (oracle_sim_program=False) the sampler tool/snapshot/message + # ride along in it; when none runs (oracle_sim_program=True) they get + # a dedicated session via _synthesize_samplers_standalone. + + @staticmethod + def _samplers_enabled() -> bool: + """Whether per-skill samplers are used at all this run.""" + return CFG.agent_sim_learn_synthesize_samplers + + def _maybe_install_oracle_samplers(self) -> None: + """Resolve sampler mode for this cycle and install GT ones if used. + + Sets ``self._do_synthesize_samplers`` (learn vs. use ground + truth). When ``agent_sim_learn_oracle_samplers`` is on and the + env provides ground-truth samplers, installs them and skips + synthesis; if none exist, warns and falls back to synthesis. 
+ """ + gt_samplers = None + if self._samplers_enabled() and CFG.agent_sim_learn_oracle_samplers: + gt_samplers = get_gt_samplers(CFG.env) + if gt_samplers: + self._synthesized_samplers = dict(gt_samplers) + self._current_samplers_version = "oracle" + logger.info("Using %d ground-truth sampler(s): %s", + len(gt_samplers), ", ".join(sorted(gt_samplers))) + else: + logger.warning( + "agent_sim_learn_oracle_samplers=True but no ground-truth " + "samplers for env %s; falling back to synthesis.", CFG.env) + self._do_synthesize_samplers = (self._samplers_enabled() + and not gt_samplers) + + def _sampler_paths(self, base: str) -> Dict[str, str]: + """Sandbox path bindings for samplers.py (host + agent-visible).""" + samplers_file = os.path.join(base, "samplers.py") + samplers_versions_dir = os.path.join(base, "samplers_versions") + if CFG.agent_sdk_use_local_sandbox: + samplers_file_for_agent = "./samplers.py" + elif self._tool_context.sandbox_dir: + samplers_file_for_agent = "/sandbox/samplers.py" + else: + samplers_file_for_agent = samplers_file + return { + "samplers_file": samplers_file, + "samplers_versions_dir": samplers_versions_dir, + "samplers_file_for_agent": samplers_file_for_agent, + } + + def _make_sampler_tools(self, paths: Dict[str, str]) -> List[Any]: + """Build the evaluate_sampler MCP tool for a synthesis session.""" + return create_sampler_synthesis_tools( + samplers_file=paths["samplers_file"], + samplers_versions_dir=paths["samplers_versions_dir"], + approach=self, + cycle_index_provider=self._learning_cycle_index, + ) + + def _sampler_snapshot_target(self, paths: Dict[str, + str]) -> _SnapshotTarget: + """Snapshot target that versions samplers.py on every Write/Edit.""" + return _SnapshotTarget( + live_file=paths["samplers_file"], + versions_dir=paths["samplers_versions_dir"], + artifact_name="samplers", + cycle_index_provider=self._learning_cycle_index, + ) + + def _sampler_synthesis_message(self, paths: Dict[str, str]) -> str: + """Instructions 
appended to the agent's first synthesis message.""" + path = paths["samplers_file_for_agent"] + return f"""\ +## Per-Skill Sampler Synthesis + +Backtracking refinement draws each option's continuous parameters \ +*uniformly* from its params box by default. When a sketch step's subgoal \ +pins the parameters into a tiny region (e.g. a placement that must land \ +within a few cm of an exact point and at a specific orientation), uniform \ +sampling almost never hits it and refinement exhausts its budget. Fix this \ +by writing per-skill samplers to `{path}` as a dict \ +`LEARNED_SAMPLERS = {{"OptionName": sampler_fn, ...}}` keyed by option name. + +Each sampler has signature \ +`fn(state, subgoal_atoms, rng, objects) -> params` (the same signature as \ +the env's NSRT samplers) where: +- `state` is the current `State` (read object features with `state.get(obj, "feat")`), +- `subgoal_atoms` is the set of `GroundAtom`s the step must establish — \ +read the target relation here (e.g. an `InFront`/at-target atom names the \ +two objects whose geometry the placement must satisfy) and compute the \ +parameters that achieve it, +- `rng` is a `numpy` `Generator` (use it for small jitter so retries differ), +- `objects` is the list of typed objects bound to this option call. +Return a `float32` array whose length matches the option's params box \ +(see `inspect_options` for the dimension and ranges); refinement clips it \ +to that box, so stay within the ranges. + +Aim the parameters at the subgoal geometrically (then add a little `rng` \ +jitter); do NOT just return uniform draws. Read the option signatures with \ +`inspect_options` and the predicate classifiers (for the subgoal geometry) \ +with the predicate listing above. + +Workflow: write `{path}`, call `evaluate_sampler` (snapshots + installs \ +them and sanity-checks shape/box), then call `evaluate_plan_refinement` \ +with a sketch using those options — the samples-to-refine count should \ +drop sharply versus uniform. 
Iterate with `Edit` and re-run. Every \ +successful Write/Edit of `{path}` is snapshotted to `samplers_versions/` \ +as `cycle_XXX_vers_YYY_samplers.py`.""" + + def _finalize_and_load_samplers(self, paths: Dict[str, str]) -> None: + """Snapshot the final samplers.py and load it into approach state.""" + tag = finalize_versioned_snapshot( + paths["samplers_file"], + paths["samplers_versions_dir"], + cycle_idx=self._learning_cycle_index(), + artifact_name="samplers", + ) + if tag is not None: + self._current_samplers_version = tag + logger.info("Final samplers snapshot: %s", tag) + loaded = self._load_samplers_from_module_file(paths["samplers_file"]) + self._synthesized_samplers = loaded + logger.info("Loaded %d per-skill sampler(s) from %s.", len(loaded), + paths["samplers_file"]) + for name in sorted(loaded): + logger.info(" sampler: %s", name) + + def _load_samplers_from_module_file(self, + path: str) -> Dict[str, OptionSampler]: + """Load LEARNED_SAMPLERS from ``path``; validate each entry. + + Mirrors ``_load_predicates_from_module_file``. Returns an empty + dict on missing file or exec failure (samplers are optional). + Skips entries keyed by an unknown option name or whose value is + not callable. + """ + # pylint: disable=import-outside-toplevel + from predicators.agent_sdk.proposal_parser import build_exec_context, \ + exec_code_safely + from predicators.agent_sdk.tools import _ParamsView + + # pylint: enable=import-outside-toplevel + # ParamSpec is imported at module scope (used by exec'd samplers + # that close over learned params, mirroring the predicate loader). 
+ + if not os.path.isfile(path): + logger.info("No samplers file at %s; sampler set is empty.", path) + return {} + + with open(path, "r", encoding="utf-8") as f: + code = f.read() + + ctx = build_exec_context(types=self._types, + predicates=self._get_all_predicates(), + options=self._get_all_options(), + extra_context={ + "params": + _ParamsView(self._fitted_params), + "ParamSpec": ParamSpec, + }) + + result, err = exec_code_safely(code, ctx, "LEARNED_SAMPLERS") + if err is not None: + logger.warning("Failed to load %s:\n%s", path, err) + return {} + if not isinstance(result, dict): + logger.warning("%s: LEARNED_SAMPLERS must be a dict, got %s.", + path, + type(result).__name__) + return {} + + option_names = {o.name for o in self._get_all_options()} + valid: Dict[str, OptionSampler] = {} + for name, fn in result.items(): + if name not in option_names: + logger.warning( + "Skipped sampler '%s' (not a known option name).", name) + continue + if not callable(fn): + logger.warning("Skipped sampler '%s' (value is not callable).", + name) + continue + valid[name] = fn + return valid + + def _synthesize_samplers_standalone( + self, trajectories: List[LowLevelTrajectory], + base_pred_triples: List[Tuple[State, Action, State]], + inferred_hint: Dict[str, List[str]]) -> None: + """Run a dedicated sampler-synthesis session. + + Used when oracle_sim_program short-circuits the sim-synthesis + session, so samplers still get learned. Reuses that session's + sandbox/snapshot/tool machinery. Called from _learn_simulator + after the option model is built, so evaluate_plan_refinement has + a working simulator. + """ + if CFG.agent_sdk_use_local_sandbox: + sandbox_dir: Optional[str] = os.path.abspath( + os.path.join(self._get_log_dir(), "sandbox")) + else: + sandbox_dir = self._tool_context.sandbox_dir + base = sandbox_dir or self._get_log_dir() + + if CFG.agent_sdk_use_local_sandbox: + sandbox_dir_for_agent: Optional[str] = "." 
+ elif sandbox_dir: + sandbox_dir_for_agent = "/sandbox" + else: + sandbox_dir_for_agent = None + + paths = self._sampler_paths(base) + simulator_file = os.path.join(base, "simulator.py") + versions_dir = os.path.join(base, "simulator_versions") + + exec_ns: Dict[str, Any] = { + "trajectories": + trajectories, + "train_tasks": + self._train_tasks, + "is_goal_state": + lambda state, task_idx: self._train_tasks[task_idx].goal_holds( + state), + "np": + np, + "ParamSpec": + ParamSpec, + } + # evaluate_plan_refinement (from the standard synthesis tools) gives + # the agent the samples-to-refine feedback signal; the sampler tool + # installs + sanity-checks the samplers. + tools = create_synthesis_tools( + exec_ns, + base_pred_triples, + inferred_hint, + simulator_file=simulator_file, + versions_dir=versions_dir, + approach=self, + sandbox_dir=base, + sandbox_dir_for_agent=sandbox_dir_for_agent, + cycle_index_provider=self._learning_cycle_index, + ) + tools.extend(self._make_sampler_tools(paths)) + # Use the same declared surface as the mixin will assert against + # (_get_synthesis_tool_names already includes the sampler tool since + # _do_synthesize_samplers is True here). The rule-fitting tools are + # exposed but irrelevant — the message steers the agent to samplers. + declared = set(self._get_synthesis_tool_names() or ()) + self._tool_context.extra_mcp_tools = [ + t for t in tools if getattr(t, "name", "") in declared + ] + self._learning_mode = True + self._tool_context.extra_session_hooks = ( + self._build_synthesis_session_hooks( + [self._sampler_snapshot_target(paths)], base)) + + self._close_agent_session() + self._ensure_agent_session() + + predicate_listing = self._format_predicate_signatures( + self._get_all_predicates()) + message = f"""\ +Synthesize per-skill samplers for this environment's options. 
The \ +simulator dynamics are already fixed (oracle/learned); your only job is \ +to make backtracking refinement land each option's continuous parameters \ +on its sketch-step subgoal instead of drawing them uniformly. + +## Available Predicates (subgoal geometry) +{predicate_listing} + +Read the option signatures with `inspect_options` and explore the \ +trajectory data with `run_python` (variables: `trajectories`, \ +`train_tasks`, `is_goal_state`, `np`, `ParamSpec`).""" + message = message + "\n\n" + self._sampler_synthesis_message(paths) + + try: + self._query_agent_sync(message, kind="learn") + finally: + self._tool_context.extra_session_hooks = {} + self._tool_context.extra_mcp_tools = [] + self._learning_mode = False + self._close_agent_session() + + self._finalize_and_load_samplers(paths) + # ── Learning ──────────────────────────────────────────────── def learn_from_offline_dataset(self, dataset: Dataset) -> None: @@ -310,6 +613,12 @@ def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: # (latent threads within a trajectory, not across). Harmless for # fully-observable (legacy) simulators, which never regroup. self._fit_trajectories = list(trajectories) + # Decide how samplers are obtained this cycle: ground-truth (if + # requested and available for the env) else agent synthesis. GT + # samplers are static, so install them up front — independent of + # whether simulator learning runs below (it is skipped when there + # are no step transitions, e.g. when every demo failed). + self._maybe_install_oracle_samplers() # Two parallel triple lists drive the rest of this method: # * obs_triples — raw (s_t, a, s_{t+1}) from the data. 
# * base_pred_triples — same triples but s_t replaced by the @@ -350,6 +659,18 @@ def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: self._option_model = self._build_option_model(combined_sim) logger.info("Built learned option model (SSE: %.6f).", self._fit_sse) + # When the simulator came from the oracle short-circuit no agent + # session ran above, so per-skill samplers (if enabled) get their + # own session here — after the option model is built, so the + # session's evaluate_plan_refinement has a working simulator. When + # the agent *did* synthesize the simulator, samplers already rode + # along in that session and this is skipped. + if self._do_synthesize_samplers and \ + CFG.agent_sim_learn_oracle_sim_program: + self._synthesize_samplers_standalone(trajectories, + base_pred_triples, + inferred_hint) + def _build_option_model( self, simulator_fn: Callable[[State, Action], State], @@ -570,6 +891,9 @@ def _synthesize_with_agent( simulator_file = os.path.join(base, "simulator.py") versions_dir = os.path.join(base, "simulator_versions") extra_paths = self._compute_extra_synthesis_paths(base) + # Per-skill samplers ride along in this session when enabled. + sampler_paths = (self._sampler_paths(base) + if self._do_synthesize_samplers else {}) # Path the agent sees: cwd-relative for local-sandbox (the # validation hook resolves against cwd and rejects literal @@ -620,6 +944,8 @@ def _synthesize_with_agent( tools.extend( self._extra_synthesis_tools(exec_ns, base_pred_triples, inferred_hint, extra_paths)) + if self._do_synthesize_samplers: + tools.extend(self._make_sampler_tools(sampler_paths)) declared = set(self._get_synthesis_tool_names() or ()) self._tool_context.extra_mcp_tools = [ t for t in tools if getattr(t, "name", "") in declared @@ -633,6 +959,9 @@ def _synthesize_with_agent( # call). Only active for this synthesis session. 
snapshot_targets = self._build_write_snapshot_targets( simulator_file, versions_dir, extra_paths) + if self._do_synthesize_samplers: + snapshot_targets.append( + self._sampler_snapshot_target(sampler_paths)) self._tool_context.extra_session_hooks = ( self._build_synthesis_session_hooks(snapshot_targets, base)) @@ -697,6 +1026,9 @@ def _synthesize_with_agent( extra_message = self._extra_synthesis_message(extra_paths) if extra_message: message = message + "\n\n" + extra_message + if self._do_synthesize_samplers: + message = message + "\n\n" + \ + self._sampler_synthesis_message(sampler_paths) try: self._query_agent_sync(message, kind="learn") @@ -735,6 +1067,8 @@ def _synthesize_with_agent( logger.info("Agent synthesized %d rules, %d params.", len(rules), len(specs)) self._post_synthesis_loading(extra_paths, specs) + if self._do_synthesize_samplers: + self._finalize_and_load_samplers(sampler_paths) self._process_rules = rules self._process_features = process_features @@ -1231,7 +1565,11 @@ def _format_predicate_signatures(predicates: Set[Predicate]) -> str: lines = [] for pred in sorted(predicates, key=lambda p: p.name): type_sig = ", ".join(t.name for t in pred.types) - lines.append(f" {pred.name}({type_sig})") + line = f" {pred.name}({type_sig})" + if pred.natural_language_assertion is not None: + names = [t.name for t in pred.types] + line += f" — {pred.natural_language_assertion(names)}" + lines.append(line) return "\n".join(lines) @staticmethod diff --git a/predicators/approaches/human_option_control_approach.py b/predicators/approaches/human_option_control_approach.py index 0f069f635..b8307e6fb 100644 --- a/predicators/approaches/human_option_control_approach.py +++ b/predicators/approaches/human_option_control_approach.py @@ -80,9 +80,12 @@ def _get_current_processes(self) -> Set[CausalProcess]: """ return self._processes - def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: + def _solve(self, + task: Task, + timeout: int, + _allow_replan: 
bool = True) -> Callable[[State], Action]: """Create a policy that prompts the user for process selection.""" - del timeout # Unused parameter + del timeout, _allow_replan # Unused parameters # If scripted option is enabled, use the scripted plan if CFG.human_option_control_approach_use_scripted_option: diff --git a/predicators/approaches/maple_q_process_approach.py b/predicators/approaches/maple_q_process_approach.py index bfea00ae4..58106438e 100644 --- a/predicators/approaches/maple_q_process_approach.py +++ b/predicators/approaches/maple_q_process_approach.py @@ -68,10 +68,11 @@ def get_name(cls) -> str: return "maple_q_with_process" # pylint: disable=arguments-differ - def _solve(self, - task: Task, - timeout: int, - train_or_test: str = "") -> Callable[[State], Action]: + def _solve( # type: ignore[override] + self, + task: Task, + timeout: int, + train_or_test: str = "") -> Callable[[State], Action]: def _option_policy(state: State) -> _Option: option = self._q_function.get_option( diff --git a/predicators/approaches/pp_oracle_approach.py b/predicators/approaches/pp_oracle_approach.py index 110544751..b27d03c37 100644 --- a/predicators/approaches/pp_oracle_approach.py +++ b/predicators/approaches/pp_oracle_approach.py @@ -1,16 +1,15 @@ """Oracle bilevel process planning approach.""" -from typing import Callable, List, Optional, Set +from typing import List, Optional, Set from gym.spaces import Box from predicators.approaches.process_planning_approach import \ BilevelProcessPlanningApproach -from predicators.ground_truth_models import augment_task_with_helper_objects, \ - get_gt_helper_predicates, get_gt_helper_types, get_gt_processes +from predicators.ground_truth_models import get_gt_processes from predicators.option_model import _OptionModelBase from predicators.settings import CFG -from predicators.structs import NSRT, Action, CausalProcess, \ - ParameterizedOption, Predicate, State, Task, Type +from predicators.structs import NSRT, CausalProcess, 
ParameterizedOption, \ + Predicate, Task, Type class OracleBilevelProcessPlanningApproach(BilevelProcessPlanningApproach): @@ -36,12 +35,8 @@ def __init__(self, max_skeletons_optimized, bilevel_plan_without_sim, option_model=option_model) - # Add optional helpful types and predicates (such as in dominoes the - # ones about positions and directions) - helper_types = get_gt_helper_types(CFG.env) - helper_predicates = get_gt_helper_predicates(CFG.env) - self._types = types | helper_types - self._initial_predicates = initial_predicates | helper_predicates + # The optional helper types/predicates (e.g. the domino grid) are + # added by the base class because _use_gt_helpers() returns True here. if processes is None: # use only_endogenous for the no_invent baseline @@ -75,14 +70,14 @@ def get_name(cls) -> str: def is_learning_based(self) -> bool: return False + def _use_gt_helpers(self) -> bool: + # The oracle always uses the ground-truth helper types/predicates/ + # objects (e.g. the domino grid), independent of the CFG flag. 
+ return True + def _get_current_processes(self) -> Set[CausalProcess]: return self._processes def _get_current_nsrts(self) -> Set[NSRT]: """Get the current set of NSRTs.""" return set() - - def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: - # Augment task with helper objects if needed - task = augment_task_with_helper_objects(task, CFG.env) - return super()._solve(task, timeout) diff --git a/predicators/approaches/process_planning_approach.py b/predicators/approaches/process_planning_approach.py index 65770ce06..3408f259b 100644 --- a/predicators/approaches/process_planning_approach.py +++ b/predicators/approaches/process_planning_approach.py @@ -9,6 +9,8 @@ from predicators.approaches import ApproachFailure, ApproachTimeout from predicators.approaches.bilevel_planning_approach import \ BilevelPlanningApproach +from predicators.ground_truth_models import augment_task_with_helper_objects, \ + get_gt_helper_predicates, get_gt_helper_types from predicators.option_model import _OptionModelBase from predicators.planning import PlanningFailure, PlanningTimeout from predicators.planning_with_processes import ProcessWorldModel, \ @@ -45,6 +47,30 @@ def __init__(self, option_model=option_model) self._last_option_plan: List[_Option] = [] # used if plan WITH sim + # Optionally augment with ground-truth helper types and predicates + # (e.g. the domino grid loc/angle/direction types and predicates). + # The oracle always uses them (overrides _use_gt_helpers); other + # process-planning approaches opt in via CFG. No-op for envs without + # a helper factory. + if self._use_gt_helpers(): + self._types = self._types | get_gt_helper_types(CFG.env) + # Helper predicates take precedence on name collisions (e.g. the + # grid's derived InFront replaces the position-based InFront). 
+ # A plain set union does NOT enforce this: the two same-named + # predicates are ``==``-equal but hash differently (DerivedPredicate + # vs Predicate), so both survive the union and ``abstract`` then + # evaluates BOTH -- the looser position-based InFront injects + # spurious adjacencies (e.g. the start block "in front of" a staged + # movable 0.13 m away), which lets the task planner build a + # physically-impossible single-block bridge. Drop any base predicate + # whose name a helper predicate already provides, then union. + helper_preds = get_gt_helper_predicates(CFG.env) + helper_names = {p.name for p in helper_preds} + self._initial_predicates = helper_preds | { + p + for p in self._initial_predicates if p.name not in helper_names + } + # Conditionally load VLM components if an abstract policy is used. self._vlm = None self.base_prompt = "" @@ -62,15 +88,33 @@ def __init__(self, with open(filepath_to_vlm_prompt, "r", encoding="utf-8") as f: self.base_prompt = f.read() + def _use_gt_helpers(self) -> bool: + """Whether to augment with ground-truth helper + types/predicates/objects. + + The oracle always uses them (overrides this to return True); + other process-planning approaches opt in via + ``CFG.process_planning_use_gt_helpers`` (e.g. for + ExoPredicator). + """ + return CFG.process_planning_use_gt_helpers + @abc.abstractmethod def _get_current_processes(self) -> Set[CausalProcess]: """Get the current set of Processes.""" raise NotImplementedError("Override me!") - def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: + def _solve(self, + task: Task, + timeout: int, + _allow_replan: bool = True) -> Callable[[State], Action]: self._num_calls += 1 # ensure random over successive seed = self._seed + self._num_calls + # Augment with ground-truth helper objects (e.g. the domino grid + # locations) when enabled; see _use_gt_helpers. No-op otherwise. 
+ if self._use_gt_helpers(): + task = augment_task_with_helper_objects(task, CFG.env) processes = self._get_current_processes() preds = self._get_current_predicates() @@ -125,11 +169,39 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._save_metrics(metrics, processes, preds) + # A raw (replanned) policy is returned unwrapped so the wrapper below + # owns all replanning, avoiding nested replanning loops. + if not _allow_replan: + return policy + + max_replans = CFG.process_planning_max_execution_replans + def _policy(s: State) -> Action: - try: - return policy(s) - except utils.OptionExecutionFailure as e: - raise ApproachFailure(e.args[0], e.info) + nonlocal policy + replans = 0 + while True: + try: + return policy(s) + except utils.OptionExecutionFailure as e: + if replans >= max_replans: + raise ApproachFailure(e.args[0], e.info) + replans += 1 + # An option failed mid-execution (typically a fresh BiRRT + # collision from drift between the refinement simulator and + # the real environment). Re-refine from the current state + # so the remaining options use parameters valid for the + # actual world, then retry. Bounded by the setting above. 
+ logging.info( + "[ProcessPlanning] Execution failure (%s); replanning " + "from the current state (attempt %d/%d).", e.args[0], + replans, max_replans) + try: + policy = self._solve(Task(s, task.goal), + timeout, + _allow_replan=False) + except (ApproachFailure, ApproachTimeout, PlanningFailure, + PlanningTimeout) as solve_err: + raise ApproachFailure(e.args[0], e.info) from solve_err return _policy diff --git a/predicators/code_sim_learning/synthesis_validation.py b/predicators/code_sim_learning/synthesis_validation.py index 344f11508..810242744 100644 --- a/predicators/code_sim_learning/synthesis_validation.py +++ b/predicators/code_sim_learning/synthesis_validation.py @@ -150,15 +150,11 @@ def run_refinement_for_synthesis( "every line names a known option with typed `obj:type` " "arguments matching what the inspect tools report.") - if timeout is None: - timeout = float( - max(CFG.agent_bilevel_refinement_timeout_min, - CFG.agent_bilevel_refinement_timeout_per_step * len(sketch))) - timeout_source = "auto" - else: - timeout = float(timeout) - timeout_source = "explicit" - assert timeout is not None + timeout, timeout_source = bilevel_sketch.resolve_refine_timeout( + timeout, + len(sketch), + per_step=CFG.agent_bilevel_refinement_timeout_per_step, + minimum=CFG.agent_bilevel_refinement_timeout_min) logger.info("Refining plan sketch (task %d, %d steps, timeout=%.0fs/%s):", task_idx, len(sketch), timeout, timeout_source) @@ -170,10 +166,10 @@ def run_refinement_for_synthesis( line += f" [subgoals: {atoms}]" logger.info(line) - step_samples_cumulative: List[int] = [0] * len(sketch) - termination_reason: List[str] = [] - elapsed_holder: List[float] = [] - plan, success, n_samples = bilevel_sketch.refine_sketch( + # Shared refinement + forward-validation + report builder (also used + # by the planner's refine_plan_sketch tool). Synthesis-specific extra: + # the post-fit SSE line, and the "Task N:" prefix on the verdict. 
+ _, report = bilevel_sketch.refine_and_validate_report( task, sketch, candidate_om, @@ -183,81 +179,12 @@ def run_refinement_for_synthesis( max_samples_per_step=CFG.agent_bilevel_max_samples_per_step, check_subgoals=CFG.agent_bilevel_check_subgoals, log_state=CFG.agent_bilevel_log_state, + option_samplers=approach._get_all_samplers(), run_id=f"{getattr(approach, '_run_id', 'sim_learn')}_validate", - step_samples_cumulative=step_samples_cumulative, - termination_reason=termination_reason, - elapsed_holder=elapsed_holder, + timeout_source=timeout_source, + extra_summary_lines=[f" Post-fit SSE: {fit_sse:.6f}"], ) - - reason = termination_reason[0] if termination_reason else ( - "success" if success else "exhausted") - elapsed = elapsed_holder[0] if elapsed_holder else 0.0 - cap = CFG.agent_bilevel_max_samples_per_step - if success: - verdict = "SUCCESS" - elif reason == "timeout": - verdict = "FAILURE: TIMEOUT" - elif reason == "exhausted": - verdict = "FAILURE: SAMPLE_EXHAUSTED" - else: - verdict = "FAILURE" - - lines = [ - f"Task {task_idx}: {verdict}", - f" Sketch: {len(sketch)} steps Refined: {len(plan)} steps " - f"Samples: {n_samples} total", - f" Per-step samples: {step_samples_cumulative} (cap " - f"{cap}/step)", - f" Time: {elapsed:.1f}s used / {timeout:.1f}s allotted " - f"(timeout source: {timeout_source})", - f" Post-fit SSE: {fit_sse:.6f}", - ] - if not success and len(plan) < len(sketch): - stuck_idx = len(plan) - stuck = sketch[stuck_idx] - objs = ", ".join(f"{o.name}:{o.type.name}" for o in stuck.objects) - lines.append(f" Stuck at step {stuck_idx}: " - f"{stuck.option.name}({objs})") - if stuck.subgoal_atoms: - atoms = ", ".join(str(a) for a in stuck.subgoal_atoms) - lines.append(f" subgoals: {atoms}") - - # Forward validation: re-execute the refined plan continuously - # (state carries forward across all options, single shot per step). 
- # Refinement's per-step resets and resampling can mask test-time - # failures — running the same plan through validate_plan_forward - # under the same option model surfaces them here, *before* the - # agent declares synthesis done. - if success: - try: - fv_ok, fv_reason = bilevel_sketch.validate_plan_forward( - task, - plan, - candidate_om, - predicates=approach._get_all_predicates(), - sketch=sketch, - run_id=f"{getattr(approach, '_run_id', 'sim_learn')}_validate", - ) - except Exception as e: # pylint: disable=broad-except - fv_ok = False - fv_reason = f"forward validation raised: {e}" - if fv_ok: - lines.append(" Forward validation: SUCCESS") - else: - # Demote the headline verdict: refinement passed but the - # plan doesn't survive continuous execution, which is what - # test time will see. - lines[0] = (f"Task {task_idx}: FAILURE: " - f"FORWARD_VALIDATION_FAILED") - lines.append(f" Forward validation: FAIL — {fv_reason}") - lines.append( - " (Refinement passed because it resets state between " - "options and resamples; forward validation runs the same " - "plan continuously. 
A divergence here usually means a " - "learned threshold or rule is more permissive than the " - "env's effective behavior — see the INFO log for the " - "step-by-step divergence.)") - return "\n".join(lines) + return f"Task {task_idx}: {report}" def get_or_build_sketch( diff --git a/predicators/envs/gymnasium_wrapper.py b/predicators/envs/gymnasium_wrapper.py index 3829e33cd..98a65cb79 100644 --- a/predicators/envs/gymnasium_wrapper.py +++ b/predicators/envs/gymnasium_wrapper.py @@ -211,7 +211,7 @@ def close(self) -> None: "predicators.envs.pybullet_coffee:PyBulletCoffeeEnv"), ("robodisco/Cover-v0", "predicators.envs.pybullet_cover:PyBulletCoverEnv"), ("robodisco/Domino-v0", - "predicators.envs.pybullet_domino.composed_env:PyBulletDominoEnvNew"), + "predicators.envs.pybullet_domino.env:PyBulletDominoEnv"), ("robodisco/Fan-v0", "predicators.envs.pybullet_fan:PyBulletFanEnv"), ("robodisco/Float-v0", "predicators.envs.pybullet_float:PyBulletFloatEnv"), ("robodisco/Grow-v0", "predicators.envs.pybullet_grow:PyBulletGrowEnv"), diff --git a/predicators/envs/pybullet_domino/__init__.py b/predicators/envs/pybullet_domino/__init__.py index 90c82811d..59f111fce 100644 --- a/predicators/envs/pybullet_domino/__init__.py +++ b/predicators/envs/pybullet_domino/__init__.py @@ -13,12 +13,8 @@ env = PyBulletDominoFanEnv(use_gui=True) """ -from predicators.envs.pybullet_domino.composed_env import \ - PyBulletDominoEnvNew, PyBulletDominoFanEnvNew - -# Backward-compatible aliases -PyBulletDominoEnv = PyBulletDominoEnvNew -PyBulletDominoFanEnv = PyBulletDominoFanEnvNew +from predicators.envs.pybullet_domino.env import PyBulletDominoEnv, \ + PyBulletDominoFanEnv __all__ = [ "PyBulletDominoEnv", diff --git a/predicators/envs/pybullet_domino/chain_reward.py b/predicators/envs/pybullet_domino/chain_reward.py deleted file mode 100644 index 222c7043f..000000000 --- a/predicators/envs/pybullet_domino/chain_reward.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Reward function for domino chain-reaction 
tasks. - -Evaluates whether targets were toppled via a genuine chain reaction -starting from the start domino, rather than direct robot manipulation. - -The reward decomposes into five components: - 1. target_score: fraction of targets actually toppled - 2. order_score: start domino toppled before any target - 3. robot_dist_score: robot far from targets at the moment they topple - 4. propagation_score: topple times correlate with distance from start - 5. spread_score: topples spread over time (not simultaneous) -""" - -from typing import Dict, List, Optional, Sequence, Set, Tuple - -import numpy as np - -from predicators.structs import LowLevelTrajectory, Object, State, Type - -# From domino_component.py -FALLEN_THRESHOLD = np.pi * 2 / 5 # ~72 deg — domino considered toppled - -# Color constants (r, g, b) for domino classification -_START_COLOR = (0.56, 0.93, 0.56) -_TARGET_COLOR = (0.85, 0.7, 0.85) -_MOVEABLE_COLOR = (0.6, 0.8, 1.0) - -# Reward tuning -_ROBOT_SAFE_DIST = 0.20 # metres; ~3 domino widths -_COLOR_TOL = 0.1 # tolerance for RGB matching -_MIN_SPREAD_PER_DOMINO = 3 # expected timesteps between consecutive topples - - -def _color_matches(state: State, - obj: Object, - target_rgb: Tuple[float, float, float], - tol: float = _COLOR_TOL) -> bool: - r, g, b = state.get(obj, "r"), state.get(obj, "g"), state.get(obj, "b") - return (abs(r - target_rgb[0]) < tol and abs(g - target_rgb[1]) < tol - and abs(b - target_rgb[2]) < tol) - - -def _classify_dominoes( - state: State, - dominoes: Sequence[Object], -) -> Tuple[List[Object], List[Object], List[Object]]: - """Classify dominoes into (start, moveable, target) by colour.""" - start, moveable, targets = [], [], [] - for d in dominoes: - if _color_matches(state, d, _START_COLOR): - start.append(d) - elif _color_matches(state, d, _TARGET_COLOR): - targets.append(d) - else: - moveable.append(d) - return start, moveable, targets - - -def _find_topple_times( - states: Sequence[State], - dominoes: Sequence[Object], -) 
-> Dict[Object, int]: - """Return {domino: first_timestep_where_toppled}.""" - topple_times: Dict[Object, int] = {} - for d in dominoes: - for t, state in enumerate(states): - if abs(state.get(d, "roll")) >= FALLEN_THRESHOLD: - topple_times[d] = t - break - return topple_times - - -def _spearman_corr(x: Sequence[float], y: Sequence[float]) -> float: - """Spearman rank correlation (no scipy dependency).""" - n = len(x) - if n < 3: - return 0.0 - xa, ya = np.asarray(x, dtype=float), np.asarray(y, dtype=float) - - def _ranks(arr: np.ndarray) -> np.ndarray: - order = np.argsort(arr) - r = np.empty_like(order, dtype=float) - r[order] = np.arange(n, dtype=float) - return r - - rx, ry = _ranks(xa), _ranks(ya) - mx, my = rx.mean(), ry.mean() - dx, dy = rx - mx, ry - my - denom = np.sqrt(float((dx**2).sum() * (dy**2).sum())) - if denom < 1e-12: - return 0.0 - return float((dx * dy).sum() / denom) - - -# ------------------------------------------------------------------ # -# Main reward function -# ------------------------------------------------------------------ # - - -def domino_chain_reward( - trajectory: LowLevelTrajectory, - types: Set[Type], - weights: Optional[Dict[str, float]] = None, -) -> float: - """Score a trajectory on how well it achieves a domino chain reaction. - - Args: - trajectory: recorded states (and actions) from an episode. - types: the environment's type set (must contain "domino", "robot"). - weights: optional dict overriding default component weights. - Keys: target, order, robot_dist, propagation, spread. - - Returns: - float in [0, 1]. 
- """ - w = { - "target": 0.30, - "order": 0.20, - "robot_dist": 0.20, - "propagation": 0.15, - "spread": 0.15, - } - if weights: - w.update(weights) - - states = trajectory.states - if len(states) < 2: - return 0.0 - - # --- resolve types --- - domino_type = next((t for t in types if t.name == "domino"), None) - robot_type = next((t for t in types if t.name == "robot"), None) - if domino_type is None: - return 0.0 - - all_dominoes = states[0].get_objects(domino_type) - robot = (states[0].get_objects(robot_type)[0] - if robot_type and states[0].get_objects(robot_type) else None) - - start, _moveable, targets = _classify_dominoes(states[0], all_dominoes) - if not start or not targets: - return 0.0 - - topple_times = _find_topple_times(states, all_dominoes) - - # ---- 1. target_score: fraction of targets toppled ---- - n_toppled = sum(1 for t in targets if t in topple_times) - target_score = n_toppled / len(targets) - if target_score == 0.0: - return 0.0 # nothing else to evaluate - - # ---- 2. order_score: start topples before every target ---- - start_time = min(topple_times.get(s, len(states)) for s in start) - earliest_target = min(topple_times[t] for t in targets - if t in topple_times) - order_score = 1.0 if start_time < earliest_target else 0.0 - - # ---- 3. robot_dist_score: robot far from ALL dominoes when they topple -- - # Exception: the start domino (robot must push it to initiate the chain). - if robot is not None: - dists: List[float] = [] - non_start = [d for d in all_dominoes if d not in start] - for d in non_start: - if d not in topple_times: - continue - s = states[topple_times[d]] - rx, ry = s.get(robot, "x"), s.get(robot, "y") - dx, dy = s.get(d, "x"), s.get(d, "y") - dist = np.hypot(rx - dx, ry - dy) - dists.append(min(dist / _ROBOT_SAFE_DIST, 1.0)) - robot_dist_score = float(np.mean(dists)) if dists else 1.0 - else: - robot_dist_score = 1.0 - - # ---- 4. 
propagation_score: topple order matches distance from start ---- - start_xy = np.array([ - states[0].get(start[0], "x"), - states[0].get(start[0], "y"), - ]) - toppled_items = [(d, topple_times[d]) for d in all_dominoes - if d in topple_times] - if len(toppled_items) >= 3: - dists_from_start = [ - np.hypot(states[0].get(d, "x") - start_xy[0], - states[0].get(d, "y") - start_xy[1]) - for d, _ in toppled_items - ] - times = [float(tt) for _, tt in toppled_items] - corr = _spearman_corr(dists_from_start, times) - propagation_score = max(0.0, corr) - else: - propagation_score = 0.5 # insufficient data, neutral - - # ---- 5. spread_score: topples spread over time, not simultaneous ---- - if len(toppled_items) >= 2: - sorted_times = sorted(tt for _, tt in toppled_items) - spread = sorted_times[-1] - sorted_times[0] - expected = len(toppled_items) * _MIN_SPREAD_PER_DOMINO - spread_score = min(spread / max(expected, 1), 1.0) - else: - spread_score = 0.5 - - # ---- weighted combination ---- - reward = (w["target"] * target_score + w["order"] * order_score + - w["robot_dist"] * robot_dist_score + - w["propagation"] * propagation_score + - w["spread"] * spread_score) - - return float(np.clip(reward, 0.0, 1.0)) diff --git a/predicators/envs/pybullet_domino/components/domino_component.py b/predicators/envs/pybullet_domino/components/domino_component.py index 8375ffba3..79155fcac 100644 --- a/predicators/envs/pybullet_domino/components/domino_component.py +++ b/predicators/envs/pybullet_domino/components/domino_component.py @@ -25,8 +25,7 @@ from predicators.structs import Object, Predicate, State, Type if TYPE_CHECKING: - from predicators.envs.pybullet_domino.composed_env import \ - PyBulletDominoComposedEnv + from predicators.envs.pybullet_domino.env import PyBulletDominoComposedEnv @dataclass @@ -41,6 +40,13 @@ class PlacementResult: target_count: int = 0 just_turned_90: bool = False just_placed_target: bool = False + # Yaw to place the *next* block at. 
Tracks the smooth 45-deg-per-turn + # increment, which after a turn differs from ``rotation`` (the travel + # direction used to lay out positions) by 180 deg — same physical box, + # but the increment representation keeps a straight run reading as one + # constant yaw instead of flipping. ``None`` means "same as rotation" + # (no turn has happened yet). + block_yaw: Optional[float] = None class DominoComponent(DominoEnvComponent): @@ -87,7 +93,7 @@ class DominoComponent(DominoEnvComponent): @staticmethod def _get_env_class() -> TypingType["PyBulletDominoComposedEnv"]: """Get PyBulletDominoComposedEnv class to access shared config.""" - from predicators.envs.pybullet_domino.composed_env import \ + from predicators.envs.pybullet_domino.env import \ PyBulletDominoComposedEnv # pylint: disable=import-outside-toplevel return PyBulletDominoComposedEnv @@ -163,10 +169,14 @@ def __init__(self, self.z_lb = workspace_bounds["z_lb"] self.z_ub = workspace_bounds["z_ub"] - # Domino-specific placement bounds (narrower than workspace) - # to avoid placing dominoes too close to edges - # 1.1 + 0.07 = 1.17 - self.domino_y_lb = self.y_lb + self.domino_width + # Domino-specific placement bounds (narrower than workspace) to avoid + # placing dominoes too close to edges. The lower (robot-side) margin is + # 1.5x the width: keeping the start block farther from the near edge + # makes it reliably reachable for the push, which lifts the oracle + # push-only solve rate from ~92% to ~99% (the misses were robot + # reach/push failures, not cascade stalls) while keeping task diversity. 
+ # 1.1 + 1.5 * 0.07 = 1.205 + self.domino_y_lb = self.y_lb + 1.5 * self.domino_width # 1.6 - 0.21 = 1.39 self.domino_y_ub = self.y_ub - 3 * self.domino_width self.domino_x_lb = self.x_lb @@ -232,6 +242,18 @@ def _create_predicates(self) -> None: self._MovableBlock_holds) self._DominoNotGlued = Predicate("DominoNotGlued", [self._domino_type], self._DominoNotGlued_holds) + # Position-based InFront over continuous domino poses. When the grid is + # in use, GridComponent's derived InFront replaces this one (helper + # predicates take precedence on name collisions). + self._InFront = Predicate( + "InFront", [self._domino_type, self._domino_type], + self._InFront_holds, + natural_language_assertion=lambda os: + ("the two dominoes are chain-adjacent: one sits one spacing-gap " + "ahead of the other along that other's facing (toppling) " + "direction -- straight or bent 45 degrees left/right for a turn, " + "in both placement direction and yaw -- so that toppling the " + "back domino knocks the front one over")) # ------------------------------------------------------------------------- # DominoEnvComponent interface implementation @@ -252,6 +274,7 @@ def get_predicates(self) -> Set[Predicate]: self._Tilting, self._InitialBlock, self._MovableBlock, + self._InFront, } if CFG.domino_has_glued_dominos: preds.add(self._DominoNotGlued) @@ -476,6 +499,87 @@ def _DominoNotGlued_holds(cls, state: State, """Check if domino is NOT glued.""" return not cls._DominoGlued_holds(state, objects) + def _InFront_holds(self, state: State, objects: Sequence[Object]) -> bool: + """Position-based ``InFront`` classifier over continuous poses. + + ``InFront(d1, d2)`` holds when one domino sits roughly one + ``pos_gap`` ahead of the other along that other's facing + (toppling) direction, with a discrete turn offset between their + yaws (straight / 45-left / 45-right). It reads the continuous + domino poses directly, so it is available to grid-free agent + approaches. 
+ """ + domino1, domino2 = objects + if state.get(domino1, "is_held") or state.get(domino2, "is_held"): + return False + + pos_gap = self.pos_gap + pos_tol = pos_gap * 0.3 + ang_tol = np.radians(15) + # Cardinal-facing slack for the reference (back) domino. A domino + # the robot re-places settles ~1 deg off cardinal, so a 1e-3 rad + # (~0.06 deg) gate makes InFront(front, placed_back) unsatisfiable + # for chained placements; allow a few degrees of slack instead. + card_thresh = float(np.sin(np.radians(10))) + # Straight, 45-degree right turn, and 45-degree left turn. + turn_offsets = (-np.pi / 4, 0.0, np.pi / 4) + + def _ahead(back: Object, front: Object) -> bool: + x_b = state.get(back, "x") + y_b = state.get(back, "y") + rot_b = state.get(back, "yaw") + # The relationship only holds for (roughly) cardinal back-facings. + if not (abs(np.sin(rot_b)) < card_thresh + or abs(np.cos(rot_b)) < card_thresh): + return False + # The front domino's yaw differs from the back's by a discrete + # turn offset (straight / +-45 deg). + diff = utils.wrap_angle(state.get(front, "yaw") - rot_b) + if not any(abs(diff - off) < ang_tol for off in turn_offsets): + return False + # The front domino sits one pos_gap from the back, along the + # back's facing -- which may itself be rotated by a turn offset, + # so the chain can bend through a turn (the next block then lies + # diagonally off the back rather than straight ahead). + fx = state.get(front, "x") + fy = state.get(front, "y") + # A domino is 180-degree symmetric, so its facing names a + # bidirectional topple axis: the front may sit one gap along + # either end of that (possibly turn-rotated) axis. + # + # A turn-completing block always carries a half-width lateral + # ("side") offset, applied orthogonal to the reference's facing + # by the task generator (see DominoTaskGenerator. + # _place_turn90_domino) so the toppling chain stays overlapping + # through the corner. 
A turn placement (dir_off != 0) therefore + # sits at +-side_offset along the perpendicular -- NOT on the bare + # axis. Excluding lateral 0 here is what lets the Place sampler + # distinguish the cascade-enabling offset pose from the + # symbolically-equivalent-but-physically-dead on-axis pose (an + # on-axis turn block fails this edge, so scoring prefers the + # offset). Straight placements (dir_off == 0) stay exactly on the + # axis, so no spurious edges appear. + side_offset = self.domino_width / 2 + perp_x = np.cos(rot_b) + perp_y = -np.sin(rot_b) + for dir_off in turn_offsets: + ang = rot_b + dir_off + laterals = ((0.0, ) if abs(dir_off) < 1e-9 else + (side_offset, -side_offset)) + for sgn in (1.0, -1.0): + base_x = x_b + sgn * pos_gap * np.sin(ang) + base_y = y_b + sgn * pos_gap * np.cos(ang) + for lat in laterals: + expected_x = base_x + lat * perp_x + expected_y = base_y + lat * perp_y + if (abs(fx - expected_x) < pos_tol + and abs(fy - expected_y) < pos_tol): + return True + return False + + # InFront(d1, d2) := d1 is ahead of d2, or d2 is ahead of d1. 
+ return _ahead(domino2, domino1) or _ahead(domino1, domino2) + @classmethod def _DominoGlued_holds(cls, state: State, objects: Sequence[Object]) -> bool: diff --git a/predicators/envs/pybullet_domino/components/grid_component.py b/predicators/envs/pybullet_domino/components/grid_component.py index 29f98576e..a980e0816 100644 --- a/predicators/envs/pybullet_domino/components/grid_component.py +++ b/predicators/envs/pybullet_domino/components/grid_component.py @@ -62,6 +62,7 @@ def __init__(self, self._position_type = Type("loc", ["xx", "yy"], sim_features=["id", "xx", "yy"]) self._angle_type = Type("angle", ["angle"]) + self._direction_type = Type("direction", ["dir"]) # Create rotation objects for 8 discrete angles self.rotations: List[Object] = [] @@ -97,8 +98,7 @@ def _create_predicates(self) -> None: self._PosClear_holds) self._InFrontDirection = DerivedPredicate( "InFrontDirection", - [self._domino_type, self._domino_type, - Type("direction", ["dir"])], + [self._domino_type, self._domino_type, self._direction_type], self._InFrontDirection_holds, auxiliary_predicates={self._DominoAtPos, self._DominoAtRot}) self._InFront = DerivedPredicate( @@ -115,7 +115,7 @@ def _create_predicates(self) -> None: # ------------------------------------------------------------------------- def get_types(self) -> Set[Type]: - return {self._position_type, self._angle_type} + return {self._position_type, self._angle_type, self._direction_type} def get_predicates(self) -> Set[Predicate]: if self._domino_type is None: @@ -168,15 +168,35 @@ def reset_state(self, state: State) -> None: self._debug_line_ids.append(line_id) def extract_feature(self, obj: Object, feature: str) -> Optional[float]: - if obj.type == self._position_type: - if feature == "xx": - return obj.xx - if feature == "yy": - return obj.yy - elif obj.type == self._angle_type: - if feature == "angle": - angle_str = obj.name.split("_")[1] - return float(angle_str) + # Grid helper-object features (loc/angle/direction) are 
encoded in + # their names; reuse the canonical name-based reconstruction. + return self.reconstruct_feature_from_name(obj, feature) + + @staticmethod + def reconstruct_feature_from_name(obj: Object, + feature: str) -> Optional[float]: + """Reconstruct a grid helper-object feature from its name. + + The grid helper objects (loc/angle/direction) are injected into + tasks by the ground-truth models and carry no PyBullet body, so + their feature values are encoded in their names (e.g. + "loc_0.47_1.28", "ang_-90", "straight"). The composed env calls + this during its _get_state round-trip, where there is no live + GridComponent to query (these objects appear only inside oracle / + process-planning, which requires the grid). + + Returns None for non-grid objects/features so the caller can fall + through to its own error handling. + """ + if obj.type.name == "loc" and feature in ("xx", "yy"): + # Name format: "loc__", e.g. "loc_0.47_1.28". + _, x_str, y_str = obj.name.split("_") + return float(x_str) if feature == "xx" else float(y_str) + if obj.type.name == "angle" and feature == "angle": + # Name format: "ang_", e.g. "ang_-90". + return float(obj.name.split("_")[1]) + if obj.type.name == "direction" and feature == "dir": + return {"straight": 0.0, "left": 1.0, "right": 2.0}[obj.name] return None def get_init_dict_entries( @@ -291,21 +311,43 @@ def _Connected_holds(self, state: State, y_adjacent = abs(dy - self.pos_gap) < tolerance and dx < tolerance return x_adjacent or y_adjacent - def _PosClear_holds(self, state: State, objects: Sequence[Object]) -> bool: - """Check if a grid position is unoccupied by any domino.""" + @staticmethod + def _PosClear_holds(state: State, objects: Sequence[Object]) -> bool: + """Check if a position is clear (not occupied by any domino). + + A position is considered clear if no domino is currently at that + position. The occupancy tolerance is derived from the grid + spacing (half the smallest gap between location objects). 
+ """ position, = objects + target_x = state.get(position, "xx") target_y = state.get(position, "yy") - position_tolerance = self.pos_gap * 0.5 - - assert self._domino_type is not None - for domino in state.get_objects(self._domino_type): - domino_x = state.get(domino, "x") - domino_y = state.get(domino, "y") - if (abs(domino_x - target_x) <= position_tolerance - and abs(domino_y - target_y) <= position_tolerance - and not state.get(domino, "is_held")): - return False + + # Calculate grid spacing (minimum distance between positions). + position_type = position.type + positions = list(state.get_objects(position_type)) + min_distance = float('inf') + for i, pos1 in enumerate(positions): + for pos2 in positions[i + 1:]: + x1 = state.get(pos1, "xx") + y1 = state.get(pos1, "yy") + x2 = state.get(pos2, "xx") + y2 = state.get(pos2, "yy") + distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2) + if distance > 1e-6: # Skip identical positions + min_distance = min(min_distance, distance) + position_tolerance = (min_distance * + 0.5 if min_distance != float('inf') else 0.1) + + for obj in state: + if obj.type.name == "domino": + domino_x = state.get(obj, "x") + domino_y = state.get(obj, "y") + if (abs(domino_x - target_x) <= position_tolerance + and abs(domino_y - target_y) <= position_tolerance + and not state.get(obj, "is_held")): + return False return True @staticmethod @@ -317,16 +359,15 @@ def _InFrontDirection_holds(atoms: Set[GroundAtom], """ domino1, domino2, direction_obj = objects - _pos_coord_cache: Dict[Object, Tuple[int, int]] = {} + _pos_coord_cache: Dict[Object, Tuple[float, float]] = {} _rot_rad_cache: Dict[Object, float] = {} - def extract_grid_coords(pos_obj: Object) -> Tuple[int, int]: + def extract_coords(pos_obj: Object) -> Tuple[float, float]: + # Location names encode continuous coords, e.g. "loc_0.49_1.23". 
if pos_obj in _pos_coord_cache: return _pos_coord_cache[pos_obj] name_parts = pos_obj.name.split("_") - y_idx = int(name_parts[1][1:]) - x_idx = int(name_parts[2][1:]) - result = (x_idx, y_idx) + result = (float(name_parts[1]), float(name_parts[2])) _pos_coord_cache[pos_obj] = result return result @@ -339,7 +380,7 @@ def extract_rotation_angle_rad(rot_obj: Object) -> float: return result d1_positions = { - extract_grid_coords(a.objects[1]) + extract_coords(a.objects[1]) for a in atoms if a.predicate.name == "DominoAtPos" and a.objects[0] == domino1 } @@ -349,7 +390,7 @@ def extract_rotation_angle_rad(rot_obj: Object) -> float: if a.predicate.name == "DominoAtRot" and a.objects[0] == domino1 } d2_positions = { - extract_grid_coords(a.objects[1]) + extract_coords(a.objects[1]) for a in atoms if a.predicate.name == "DominoAtPos" and a.objects[0] == domino2 } @@ -359,25 +400,35 @@ def extract_rotation_angle_rad(rot_obj: Object) -> float: if a.predicate.name == "DominoAtRot" and a.objects[0] == domino2 } - def _check_case(front_pos: Set[Tuple[int, int]], + def _check_case(front_pos: Set[Tuple[float, float]], front_rot: Set[float], - back_pos: Set[Tuple[int, int]], + back_pos: Set[Tuple[float, float]], back_rot: Set[float], direction_name: str, tolerance: float = 1e-6) -> bool: if not all([front_pos, front_rot, back_pos, back_rot]): return False + # pos_gap is the physical spacing between adjacent grid cells. + from predicators.envs.pybullet_domino.env import \ + PyBulletDominoComposedEnv # pylint: disable=import-outside-toplevel + pos_gap = PyBulletDominoComposedEnv.pos_gap + position_possible = False for (x_b, y_b) in back_pos: for rot_b in back_rot: + # Relationship only holds for cardinal rotations. 
if not (abs(np.sin(rot_b)) < tolerance or abs(np.cos(rot_b)) < tolerance): continue - dx_idx = round(np.sin(rot_b)) - dy_idx = round(np.cos(rot_b)) - if (x_b + dx_idx, y_b + dy_idx) in front_pos: - position_possible = True + expected_x = x_b + pos_gap * np.sin(rot_b) + expected_y = y_b + pos_gap * np.cos(rot_b) + for (x_f, y_f) in front_pos: + if (abs(x_f - expected_x) < pos_gap * 0.3 + and abs(y_f - expected_y) < pos_gap * 0.3): + position_possible = True + break + if position_possible: break if position_possible: break @@ -427,27 +478,42 @@ def _InFront_holds(atoms: Set[GroundAtom], @staticmethod def _AdjacentTo_holds(atoms: Set[GroundAtom], objects: Sequence[Object]) -> bool: - """Check if a position is adjacent to a domino in cardinal - directions.""" + """Check if a position is adjacent to a domino in cardinal directions. + + Adjacent means about one ``pos_gap`` away in a cardinal + direction (up/down/left/right) but not diagonal, over the + continuous-coordinate location names (e.g. ``loc_0.49_1.23``). + """ position, domino = objects - def extract_grid_coords(pos_obj: Object) -> Tuple[int, int]: + _pos_coord_cache: Dict[Object, Tuple[float, float]] = {} + + def extract_coords(pos_obj: Object) -> Tuple[float, float]: + if pos_obj in _pos_coord_cache: + return _pos_coord_cache[pos_obj] name_parts = pos_obj.name.split("_") - y_idx = int(name_parts[1][1:]) - x_idx = int(name_parts[2][1:]) - return (x_idx, y_idx) + result = (float(name_parts[1]), float(name_parts[2])) + _pos_coord_cache[pos_obj] = result + return result + + # pos_gap is the physical spacing between adjacent grid cells. 
+ from predicators.envs.pybullet_domino.env import \ + PyBulletDominoComposedEnv # pylint: disable=import-outside-toplevel + pos_gap = PyBulletDominoComposedEnv.pos_gap - target_x, target_y = extract_grid_coords(position) + target_x, target_y = extract_coords(position) domino_positions = { - extract_grid_coords(a.objects[1]) + extract_coords(a.objects[1]) for a in atoms if a.predicate.name == "DominoAtPos" and a.objects[0] == domino } - for dx, dy in domino_positions: - if (abs(target_x - dx) == 1 and target_y == dy) or \ - (target_x == dx and abs(target_y - dy) == 1): + for domino_x, domino_y in domino_positions: + dx = abs(target_x - domino_x) + dy = abs(target_y - domino_y) + if ((abs(dx - pos_gap) < pos_gap * 0.3 and dy < pos_gap * 0.3) or + (abs(dy - pos_gap) < pos_gap * 0.3 and dx < pos_gap * 0.3)): return True return False diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/env.py similarity index 78% rename from predicators/envs/pybullet_domino/composed_env.py rename to predicators/envs/pybullet_domino/env.py index f5ad85721..5e8973be6 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/env.py @@ -17,6 +17,8 @@ DominoComponent from predicators.envs.pybullet_domino.components.fan_component import \ FanComponent +from predicators.envs.pybullet_domino.components.grid_component import \ + GridComponent from predicators.envs.pybullet_domino.components.ramp_component import \ RampComponent from predicators.envs.pybullet_domino.components.stairs_component import \ @@ -174,6 +176,8 @@ def predicates(self) -> Set[Predicate]: all_preds.add(self._Holding) for comp in self._components: all_preds |= comp.get_predicates() + if self._ball_component is not None: + all_preds.add(self._ball_component.BallAtTarget) return all_preds @property @@ -182,6 +186,8 @@ def goal_predicates(self) -> Set[Predicate]: goal_preds: Set[Predicate] = set() for comp in self._components: goal_preds |= 
comp.get_goal_predicates() + if self._ball_component is not None: + goal_preds.add(self._ball_component.BallAtTarget) return goal_preds # ========================================================================= @@ -254,6 +260,16 @@ def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: if result is not None: return result + # Grid helper objects (loc/angle/direction) are injected by the + # ground-truth models during oracle / process planning and own no + # live component here. GridComponent is the canonical home for the + # grid logic, so reconstruct their features from their names. This + # lets the _get_state round-trip in _set_state succeed even when the + # env itself is built grid-free. + result = GridComponent.reconstruct_feature_from_name(obj, feature) + if result is not None: + return result + raise ValueError(f"Unknown feature {feature} for object {obj}") def _set_domain_specific_state(self, state: State) -> None: @@ -294,6 +310,37 @@ def _Holding_holds(self, state: State, objects: Sequence[Object]) -> bool: _, domino = objects return state.get(domino, "is_held") > 0.5 + # ========================================================================= + # COMPONENT CONSTRUCTION HELPERS + # ========================================================================= + + @classmethod + def _default_workspace_bounds(cls) -> Dict[str, float]: + """Workspace bounds shared by all concrete domino environments.""" + return { + "x_lb": cls.x_lb, + "x_ub": cls.x_ub, + "y_lb": cls.y_lb, + "y_ub": cls.y_ub, + "z_lb": cls.z_lb, + "z_ub": cls.z_ub, + } + + @classmethod + def _make_domino_component( + cls, workspace_bounds: Dict[str, float]) -> DominoComponent: + """Build a domino component sized to the configured task ranges.""" + max_dominos = max(max(CFG.domino_train_num_dominos), + max(CFG.domino_test_num_dominos)) + max_targets = max(max(CFG.domino_train_num_targets), + max(CFG.domino_test_num_targets)) + max_pivots = 
max(max(CFG.domino_train_num_pivots), + max(CFG.domino_test_num_pivots)) + return DominoComponent(num_dominos_max=max_dominos, + num_targets_max=max_targets, + num_pivots_max=max_pivots, + workspace_bounds=workspace_bounds) + # ========================================================================= # TASK GENERATION # ========================================================================= @@ -371,31 +418,12 @@ def _make_tasks(self, # ============================================================================= -class PyBulletDominoEnvNew(PyBulletDominoComposedEnv): +class PyBulletDominoEnv(PyBulletDominoComposedEnv): """Backward-compatible domino environment class.""" def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: - workspace_bounds = { - "x_lb": self.x_lb, - "x_ub": self.x_ub, - "y_lb": self.y_lb, - "y_ub": self.y_ub, - "z_lb": self.z_lb, - "z_ub": self.z_ub, - } - - max_dominos = max(max(CFG.domino_train_num_dominos), - max(CFG.domino_test_num_dominos)) - max_targets = max(max(CFG.domino_train_num_targets), - max(CFG.domino_test_num_targets)) - max_pivots = max(max(CFG.domino_train_num_pivots), - max(CFG.domino_test_num_pivots)) - - domino_comp = DominoComponent(num_dominos_max=max_dominos, - num_targets_max=max_targets, - num_pivots_max=max_pivots, - workspace_bounds=workspace_bounds) - + bounds = self._default_workspace_bounds() + domino_comp = self._make_domino_component(bounds) super().__init__(components=[domino_comp], use_gui=use_gui, **kwargs) @classmethod @@ -403,38 +431,17 @@ def get_name(cls) -> str: return "pybullet_domino" -class PyBulletDominoFanEnvNew(PyBulletDominoComposedEnv): +class PyBulletDominoFanEnv(PyBulletDominoComposedEnv): """Backward-compatible domino + fan + ball environment class.""" def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: - workspace_bounds = { - "x_lb": self.x_lb, - "x_ub": self.x_ub, - "y_lb": self.y_lb, - "y_ub": self.y_ub, - "z_lb": self.z_lb, - "z_ub": self.z_ub, - } - - 
max_dominos = max(max(CFG.domino_train_num_dominos), - max(CFG.domino_test_num_dominos)) - max_targets = max(max(CFG.domino_train_num_targets), - max(CFG.domino_test_num_targets)) - max_pivots = max(max(CFG.domino_train_num_pivots), - max(CFG.domino_test_num_pivots)) - - domino_comp = DominoComponent(num_dominos_max=max_dominos, - num_targets_max=max_targets, - num_pivots_max=max_pivots, - workspace_bounds=workspace_bounds) - - fan_comp = FanComponent(workspace_bounds=workspace_bounds, + bounds = self._default_workspace_bounds() + domino_comp = self._make_domino_component(bounds) + fan_comp = FanComponent(workspace_bounds=bounds, table_height=self.table_height, table_width=self.table_width) - - ball_comp = BallComponent(workspace_bounds=workspace_bounds, + ball_comp = BallComponent(workspace_bounds=bounds, table_height=self.table_height) - super().__init__(components=[domino_comp, fan_comp, ball_comp], use_gui=use_gui, **kwargs) @@ -443,59 +450,21 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: def get_name(cls) -> str: return "pybullet_domino_fan" - @property - def predicates(self) -> Set[Predicate]: - """Include BallAtTarget in predicates.""" - preds = super().predicates - if self._ball_component is not None: - preds.add(self._ball_component.BallAtTarget) - return preds - - @property - def goal_predicates(self) -> Set[Predicate]: - """Goals can be ball at target OR dominoes toppled.""" - preds = super().goal_predicates - if self._ball_component is not None: - preds.add(self._ball_component.BallAtTarget) - return preds - class PyBulletDominoFanRampEnv(PyBulletDominoComposedEnv): """Domino + fan + ball + ramp environment class.""" def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: - workspace_bounds = { - "x_lb": self.x_lb, - "x_ub": self.x_ub, - "y_lb": self.y_lb, - "y_ub": self.y_ub, - "z_lb": self.z_lb, - "z_ub": self.z_ub, - } - - max_dominos = max(max(CFG.domino_train_num_dominos), - max(CFG.domino_test_num_dominos)) - 
max_targets = max(max(CFG.domino_train_num_targets), - max(CFG.domino_test_num_targets)) - max_pivots = max(max(CFG.domino_train_num_pivots), - max(CFG.domino_test_num_pivots)) - - domino_comp = DominoComponent(num_dominos_max=max_dominos, - num_targets_max=max_targets, - num_pivots_max=max_pivots, - workspace_bounds=workspace_bounds) - - fan_comp = FanComponent(workspace_bounds=workspace_bounds, + bounds = self._default_workspace_bounds() + domino_comp = self._make_domino_component(bounds) + fan_comp = FanComponent(workspace_bounds=bounds, table_height=self.table_height, table_width=self.table_width) - - ball_comp = BallComponent(workspace_bounds=workspace_bounds, + ball_comp = BallComponent(workspace_bounds=bounds, table_height=self.table_height) - - ramp_comp = RampComponent(workspace_bounds=workspace_bounds, + ramp_comp = RampComponent(workspace_bounds=bounds, table_height=self.table_height, max_ramps=5) - super().__init__( components=[domino_comp, fan_comp, ball_comp, ramp_comp], use_gui=use_gui, @@ -505,65 +474,26 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: def get_name(cls) -> str: return "pybullet_domino_fan_ramp" - @property - def predicates(self) -> Set[Predicate]: - """Include BallAtTarget in predicates.""" - preds = super().predicates - if self._ball_component is not None: - preds.add(self._ball_component.BallAtTarget) - return preds - - @property - def goal_predicates(self) -> Set[Predicate]: - """Goals can be ball at target OR dominoes toppled.""" - preds = super().goal_predicates - if self._ball_component is not None: - preds.add(self._ball_component.BallAtTarget) - return preds - class PyBulletDominoFanRampStairsEnv(PyBulletDominoComposedEnv): """Domino + fan + ball + ramp + stairs environment class.""" def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: - workspace_bounds = { - "x_lb": self.x_lb, - "x_ub": self.x_ub, - "y_lb": self.y_lb, - "y_ub": self.y_ub, - "z_lb": self.z_lb, - "z_ub": self.z_ub, - } - - 
max_dominos = max(max(CFG.domino_train_num_dominos), - max(CFG.domino_test_num_dominos)) - max_targets = max(max(CFG.domino_train_num_targets), - max(CFG.domino_test_num_targets)) - max_pivots = max(max(CFG.domino_train_num_pivots), - max(CFG.domino_test_num_pivots)) - - domino_comp = DominoComponent(num_dominos_max=max_dominos, - num_targets_max=max_targets, - num_pivots_max=max_pivots, - workspace_bounds=workspace_bounds) - - fan_comp = FanComponent(workspace_bounds=workspace_bounds, + bounds = self._default_workspace_bounds() + domino_comp = self._make_domino_component(bounds) + fan_comp = FanComponent(workspace_bounds=bounds, table_height=self.table_height, table_width=self.table_width) - - ball_comp = BallComponent(workspace_bounds=workspace_bounds, + ball_comp = BallComponent(workspace_bounds=bounds, table_height=self.table_height) - - ramp_comp = RampComponent(workspace_bounds=workspace_bounds, + ramp_comp = RampComponent(workspace_bounds=bounds, table_height=self.table_height, max_ramps=5) - # Stairs component needs reference to domino type for positioning - stairs_comp = StairsComponent(workspace_bounds=workspace_bounds, + stairs_comp = StairsComponent(workspace_bounds=bounds, table_height=self.table_height, domino_type=domino_comp.domino_type, enabled=True) - super().__init__(components=[ domino_comp, fan_comp, ball_comp, ramp_comp, stairs_comp ], @@ -577,46 +507,33 @@ def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: def get_name(cls) -> str: return "pybullet_domino_fan_ramp_stairs" - @property - def predicates(self) -> Set[Predicate]: - """Include BallAtTarget in predicates.""" - preds = super().predicates - if self._ball_component is not None: - preds.add(self._ball_component.BallAtTarget) - return preds - - @property - def goal_predicates(self) -> Set[Predicate]: - """Goals can be ball at target OR dominoes toppled.""" - preds = super().goal_predicates - if self._ball_component is not None: - 
preds.add(self._ball_component.BallAtTarget) - return preds - if __name__ == "__main__": import sys import time + from predicators import utils + # Choose which environment to test # Options: "domino", "domino_fan", "domino_fan_ramp", # "domino_fan_ramp_stairs" # Change this to test different environments test_env = "domino_fan_ramp_stairs" + test_env = "domino" if len(sys.argv) > 1: test_env = sys.argv[1] # Configure environment - CFG.seed = 0 + CFG.seed = 1 CFG.num_train_tasks = 0 - CFG.num_test_tasks = 3 + CFG.num_test_tasks = 1 # Domino configuration - CFG.domino_initialize_at_finished_state = True + CFG.domino_initialize_at_finished_state = False CFG.domino_use_domino_blocks_as_target = True CFG.domino_has_glued_dominos = False - CFG.domino_test_num_dominos = [3, 4] - CFG.domino_test_num_targets = [1] + CFG.domino_test_num_dominos = [3] + CFG.domino_test_num_targets = [1, 2] CFG.domino_test_num_pivots = [0] # Fan/ball configuration @@ -627,13 +544,13 @@ def goal_predicates(self) -> Set[Predicate]: # Create environment based on selection env: PyBulletDominoComposedEnv if test_env == "domino": - print("Creating PyBulletDominoEnvNew...") + print("Creating PyBulletDominoEnv...") CFG.env = "pybullet_domino" - env = PyBulletDominoEnvNew(use_gui=True) + env = PyBulletDominoEnv(use_gui=True) elif test_env == "domino_fan": - print("Creating PyBulletDominoFanEnvNew...") + print("Creating PyBulletDominoFanEnv...") CFG.env = "pybullet_domino_fan" - env = PyBulletDominoFanEnvNew(use_gui=True) + env = PyBulletDominoFanEnv(use_gui=True) elif test_env == "domino_fan_ramp": print("Creating PyBulletDominoFanRampEnv...") CFG.env = "pybullet_domino_fan_ramp" @@ -666,8 +583,18 @@ def goal_predicates(self) -> Set[Predicate]: for atom in task.goal: print(f" {atom}") + # Print the initial abstract atoms (what the agent sees). 
+ init_atoms = utils.abstract(task.init, env.predicates) + print("\nInitial atoms (abstract state seen by the agent):") + for atom in sorted(init_atoms, key=str): + print(f" {atom}") + + # Print task pretty_str + print("\n Initial state:") + print(task.init.pretty_str()) + try: - for step in range(100000): + for step in range(100): # pylint: disable=protected-access cur_action = Action( np.array(env._pybullet_robot.initial_joint_positions)) diff --git a/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py b/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py index e8b3655b9..ebbdb1513 100644 --- a/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py +++ b/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py @@ -1,6 +1,6 @@ """Task generator for domino-based tasks.""" -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np @@ -98,23 +98,41 @@ def _generate_single_task( for attempt_num in range(max_attempts): if log_debug: print(f"\nAttempt {attempt_num} for task {task_idx}") - obj_dict = self._generate_domino_sequence(rng, n_dominos, - n_targets, n_pivots, - log_debug, task_idx, - domino_in_upper_half) - if obj_dict is not None: - if log_debug: - print("Found satisfying domino sequence") - break + candidate_obj_dict = self._generate_domino_sequence( + rng, n_dominos, n_targets, n_pivots, log_debug, task_idx, + domino_in_upper_half) + if candidate_obj_dict is None: + continue + + # Make the chain's terminal block(s) the target(s). The placement + # loop can otherwise mark a mid-chain block as the target, leaving + # movable blocks after the goal -- which makes the bridge length + # ambiguous (an agent over-builds past the target, e.g. a 2-gap + # task that admits one intermediate but is planned with two). 
+ # Blocks are placed start-first along the chain, so the + # highest-index ones are the chain end; re-designating those keeps + # the target last. + if CFG.domino_use_domino_blocks_as_target: + self._retarget_terminal_dominoes(candidate_obj_dict, n_targets) + + # Move intermediate objects if needed. This can fail if the + # unfinished staging area is too full after collision checking, so + # keep it inside the attempt loop and resample the solved chain. + if not CFG.domino_initialize_at_finished_state: + candidate_obj_dict = \ + self._move_intermediate_objects_to_unfinished_state( + candidate_obj_dict) + if candidate_obj_dict is None: + continue + + obj_dict = candidate_obj_dict + if log_debug: + print("Found satisfying domino sequence") + break if obj_dict is None: return None - # Move intermediate objects if needed - if not CFG.domino_initialize_at_finished_state: - obj_dict = self._move_intermediate_objects_to_unfinished_state( - obj_dict) - init_dict.update(obj_dict) # Add entries from additional components @@ -137,11 +155,14 @@ def _generate_single_task( for target_obj in init_state.get_objects(self.domino.target_type): goal_atoms.add(GroundAtom(self.domino.Toppled, [target_obj])) + if len(goal_atoms) == 1: + target_word, target_verb = "the purple domino", "is" + else: + target_word, target_verb = "the purple dominoes", "are" goal_nl = ( - "Arrange the moveable domino blocks into a chain so that when " - "the start domino is pushed, the chain reaction topples the " - "target(s). Do NOT directly push " - "or topple the target dominoes yourself.") + f"Move the blue dominoes such that when the green domino is " + f"pushed, {target_word} {target_verb} toppled. 
Do NOT directly " + f"push or topple {target_word} yourself.") return EnvironmentTask(init_state, goal_atoms, goal_nl=goal_nl) @@ -188,13 +209,47 @@ def _in_bounds(nx: float, ny: float) -> bool: task_idx=task_idx) domino_count += 1 - # Main placement loop + expected_count = self._get_expected_domino_count(n_dominos, n_targets) + + # When targets are domino blocks, they are re-designated as the + # chain's terminal block(s) after generation (see + # _retarget_terminal_dominoes), so here we just fill the chain to + # length with regular blocks. This also avoids overrunning the + # fixed-size dominos[] list: the interleaved loop below could let a + # turn (which places two blocks at once) push the count past the last + # slot when a max-size task leaves no slack for the targets — the + # index-out-of-range crash. The turn90 guard + # (domino_count + 1 >= expected_count -> straight) keeps this loop in + # bounds. + if CFG.domino_use_domino_blocks_as_target: + # ``block_yaw`` tracks the smooth 45-deg-per-turn yaw increment so + # straight runs after a turn keep one constant yaw; positions still + # follow ``rotation`` (the travel direction). + block_yaw = rotation + while domino_count < expected_count: + result = self._place_next_domino( + rng, obj_dict, x, y, rotation, gap, domino_count, + pivot_count, target_count, n_pivots, n_dominos, n_targets, + just_placed_target, just_turned_90, _in_bounds, task_idx, + block_yaw) + if not result.success: + return None + x, y, rotation = result.x, result.y, result.rotation + domino_count = result.domino_count + pivot_count = result.pivot_count + just_turned_90 = result.just_turned_90 + block_yaw = (result.block_yaw + if result.block_yaw is not None else rotation) + if domino_count == expected_count and pivot_count == n_pivots: + return obj_dict + return None + + # Separate target objects (use_domino_blocks_as_target=False): + # interleave regular dominoes and target-typed objects. 
while self._should_continue_placement(domino_count, target_count, n_dominos, n_targets): can_place_target = (domino_count >= 2 and target_count < n_targets and not just_placed_target) - expected_count = self._get_expected_domino_count( - n_dominos, n_targets) can_place_domino = domino_count < expected_count should_place_domino = (not can_place_target @@ -232,6 +287,31 @@ def _in_bounds(nx: float, ny: float) -> bool: return obj_dict return None + def _retarget_terminal_dominoes(self, obj_dict: Dict[Object, Any], + n_targets: int) -> None: + """Recolor so the last ``n_targets`` placed blocks are the target(s). + + Mutates ``obj_dict`` in place. Dominoes are placed start-first + along the chain, so ``self.domino.dominos`` index order is chain + order: the terminal ``n_targets`` blocks become targets (purple) + and every other non-start block becomes movable (blue). No-op + for ``n_targets <= 0``. (Glue state is not preserved; it only + applies when ``domino_has_glued_dominos`` is set, which is off + by default.) 
+ """ + if n_targets <= 0: + return + placed = [d for d in self.domino.dominos if d in obj_dict] + terminal = set(placed[-n_targets:]) + target_color = self.domino.target_domino_color + movable_color = self.domino.domino_color + for idx, domino_obj in enumerate(placed): + if idx == 0: + continue # start block keeps its color + color = target_color if domino_obj in terminal else movable_color + entry = obj_dict[domino_obj] + entry["r"], entry["g"], entry["b"] = color[0], color[1], color[2] + def _get_expected_domino_count(self, n_dominos: int, n_targets: int) -> int: if CFG.domino_use_domino_blocks_as_target: @@ -255,23 +335,25 @@ def _check_placement_complete(self, domino_count: int, target_count: int, return (domino_count == n_dominos and target_count == n_targets and pivot_count == n_pivots) - def _place_next_domino(self, - rng: np.random.Generator, - obj_dict: Dict, - x: float, - y: float, - rotation: float, - gap: float, - domino_count: int, - pivot_count: int, - target_count: int, - n_pivots: int, - n_dominos: int, - n_targets: int, - just_placed_target: bool, - just_turned_90: bool, - _in_bounds: Callable[[float, float], bool], - task_idx: Optional[int] = None) -> PlacementResult: + def _place_next_domino( + self, + rng: np.random.Generator, + obj_dict: Dict, + x: float, + y: float, + rotation: float, + gap: float, + domino_count: int, + pivot_count: int, + target_count: int, + n_pivots: int, + n_dominos: int, + n_targets: int, + just_placed_target: bool, + just_turned_90: bool, + _in_bounds: Callable[[float, float], bool], + task_idx: Optional[int] = None, + block_yaw: Optional[float] = None) -> PlacementResult: """Place the next domino using various strategies.""" turn_choices = self.domino.turn_choices.copy() if pivot_count >= n_pivots and "pivot180" in turn_choices: @@ -293,25 +375,39 @@ def _place_next_domino(self, if choice == "straight": return self._place_straight_domino(rng, obj_dict, x, y, rotation, gap, domino_count, _in_bounds, - task_idx) + 
task_idx, block_yaw) if choice == "turn90": return self._place_turn90_domino(rng, obj_dict, x, y, rotation, gap, domino_count, n_dominos, n_targets, _in_bounds, task_idx, - should_place_target_at_end) + should_place_target_at_end, + block_yaw) if choice == "pivot180": return self._place_pivot180_domino(rng, obj_dict, x, y, rotation, gap, domino_count, pivot_count, _in_bounds, task_idx, should_place_target_at_end) return self._place_straight_domino(rng, obj_dict, x, y, rotation, gap, - domino_count, _in_bounds, task_idx) + domino_count, _in_bounds, task_idx, + block_yaw) - def _place_straight_domino(self, rng: np.random.Generator, - obj_dict: Dict[Object, Any], x: float, y: float, - rotation: float, gap: float, domino_count: int, - _in_bounds: Callable[[float, float], bool], - task_idx: Optional[int]) -> PlacementResult: + def _place_straight_domino( + self, + rng: np.random.Generator, + obj_dict: Dict[Object, Any], + x: float, + y: float, + rotation: float, + gap: float, + domino_count: int, + _in_bounds: Callable[[float, float], bool], + task_idx: Optional[int], + block_yaw: Optional[float] = None) -> PlacementResult: + # Travel direction (positions) follows ``rotation``; the block is laid + # at ``block_yaw`` (the smooth turn increment) when one has been + # established, else at ``rotation``. They are the same box, so a run + # after a turn reads as one constant yaw instead of flipping 180 deg. 
+ yaw = rotation if block_yaw is None else block_yaw dx = gap * np.sin(rotation) dy = gap * np.cos(rotation) new_x, new_y = x + dx, y + dy @@ -321,13 +417,14 @@ def _place_straight_domino(self, rng: np.random.Generator, x=x, y=y, rotation=rotation, - domino_count=domino_count) + domino_count=domino_count, + block_yaw=block_yaw) obj_dict[self.domino.dominos[domino_count]] = self.domino.place_domino( domino_count, new_x, new_y, - rotation, + yaw, is_start_block=False, rng=rng, task_idx=task_idx) @@ -336,34 +433,54 @@ def _place_straight_domino(self, rng: np.random.Generator, x=new_x, y=new_y, rotation=rotation, - domino_count=domino_count + 1) + domino_count=domino_count + 1, + block_yaw=block_yaw) def _place_turn90_domino( - self, rng: np.random.Generator, obj_dict: Dict[Object, Any], - x: float, y: float, rotation: float, gap: float, domino_count: int, - n_dominos: int, n_targets: int, - _in_bounds: Callable[[float, float], - bool], task_idx: Optional[int], - should_place_target_at_end: bool) -> PlacementResult: + self, + rng: np.random.Generator, + obj_dict: Dict[Object, Any], + x: float, + y: float, + rotation: float, + gap: float, + domino_count: int, + n_dominos: int, + n_targets: int, + _in_bounds: Callable[[float, float], bool], + task_idx: Optional[int], + should_place_target_at_end: bool, + block_yaw: Optional[float] = None) -> PlacementResult: expected_count = self._get_expected_domino_count(n_dominos, n_targets) if domino_count + 1 >= expected_count: return self._place_straight_domino(rng, obj_dict, x, y, rotation, gap, domino_count, _in_bounds, - task_idx) - + task_idx, block_yaw) + + # The two turn blocks' yaws step 45 deg per block off the running block + # yaw (``block_yaw``, = ``rotation`` before any turn), so successive + # turns keep incrementing rather than resetting and a 90 deg turn reads + # as a smooth increment (yaw, yaw +/- 45, yaw +/- 90). 
Positions are + # independent of this representation and follow ``rotation`` (the + # travel direction): ``d1_dir`` is the chain's toppling direction one + # 45 deg step into the turn; d1 sits one gap ahead of the current block + # along the entry direction and is then shifted laterally by the + # first-block side offset, and d2 sits one gap ahead of d1 along d1_dir. + base_yaw = rotation if block_yaw is None else block_yaw turn_direction = rng.choice([-1, 1]) - dx = gap * np.sin(rotation) - dy = gap * np.cos(rotation) - d1_base_x, d1_base_y = x + dx, y + dy - d1_rot = rotation - turn_direction * np.pi / 4 - - shift_magnitude = self.domino.domino_width * self.domino.turn_shift_frac - shift_dx = shift_magnitude * (turn_direction * np.cos(rotation) - - np.sin(rotation)) - shift_dy = shift_magnitude * (-turn_direction * np.sin(rotation) - - np.cos(rotation)) - d1_x = d1_base_x + shift_dx - d1_y = d1_base_y + shift_dy + d1_dir = rotation - turn_direction * np.pi / 4 + d1_yaw = base_yaw + turn_direction * np.pi / 4 + d1_x = x + gap * np.sin(rotation) + d1_y = y + gap * np.cos(rotation) + # Lateral "side" offset for the first turn block: minus half a domino + # width, applied orthogonal to its post-turn travel direction + # ``d1_dir``. Exposed as an explicit tunable knob for the overlap + # entering the bend; 0 disables the shift (described as matching the + # legacy generator, which only nudged the turn-completing block). 
+ d1_side_offset = -self.domino.domino_width / 2 + # d1_side_offset = 0 + d1_x += turn_direction * d1_side_offset * np.cos(d1_dir) + d1_y -= turn_direction * d1_side_offset * np.sin(d1_dir) if not _in_bounds(d1_x, d1_y): return PlacementResult(success=False, @@ -376,21 +493,34 @@ def _place_turn90_domino( domino_count, d1_x, d1_y, - d1_rot, + d1_yaw, is_start_block=False, rng=rng, task_idx=task_idx) domino_count += 1 - d2_rot = d1_rot - turn_direction * np.pi / 4 - sin_d1 = np.sin(d1_rot) - cos_d1 = np.cos(d1_rot) - disp_x = (gap * turn_direction * cos_d1 + - (2 * shift_magnitude - gap) * sin_d1) / np.sqrt(2) - disp_y = (-gap * turn_direction * sin_d1 + - (2 * shift_magnitude - gap) * cos_d1) / np.sqrt(2) - d2_x = d1_x + disp_x - d2_y = d1_y + disp_y + # Second turn block: one gap ahead of d1 along the chain direction, + # completing the 90 deg turn. Its yaw continues the +/-45 increment; + # ``d2_rot`` (the same cardinal orientation, 180 deg off) is returned + # as the travel direction so subsequent straight blocks lay out + # correctly, while ``d2_yaw`` is threaded as the running block yaw so + # those blocks keep this orientation instead of flipping. + d2_yaw = base_yaw + turn_direction * np.pi / 2 + d2_rot = rotation - turn_direction * np.pi / 2 + d2_x = d1_x + gap * np.sin(d1_dir) + d2_y = d1_y + gap * np.cos(d1_dir) + # Lateral "side" offset (ported from the legacy turn generator): in + # addition to stepping the turn-completing block one gap *along* the + # chain, nudge it a half-width *orthogonal* to its own travel + # direction. Without this sideways shift the falling chain only moves + # along one axis and clips past the corner block, so the cascade + # stalls; the inward nudge keeps the toppling dominoes overlapping + # through the bend. ``(cos d2_rot, -sin d2_rot)`` is the unit vector + # perpendicular to the block's facing, signed by the turn direction. 
+ side_offset = -self.domino.domino_width / 2 + # side_offset = 0 + d2_x += turn_direction * side_offset * np.cos(d2_rot) + d2_y -= turn_direction * side_offset * np.sin(d2_rot) if not _in_bounds(d2_x, d2_y): return PlacementResult(success=False, @@ -403,7 +533,7 @@ def _place_turn90_domino( domino_count, d2_x, d2_y, - d2_rot, + d2_yaw, is_start_block=False, is_target_block=should_place_target_at_end, rng=rng, @@ -417,7 +547,8 @@ def _place_turn90_domino( domino_count=domino_count + 1, target_count=target_inc, just_turned_90=True, - just_placed_target=should_place_target_at_end) + just_placed_target=should_place_target_at_end, + block_yaw=d2_yaw) def _place_pivot180_domino( self, rng: np.random.Generator, obj_dict: Dict[Object, Any], @@ -523,8 +654,8 @@ def _place_next_target(self, rng: np.random.Generator, domino_count=domino_count, target_count=target_count + 1) - def _move_intermediate_objects_to_unfinished_state(self, - obj_dict: Dict) -> Dict: + def _move_intermediate_objects_to_unfinished_state( + self, obj_dict: Dict) -> Optional[Dict]: """Move intermediate dominoes and pivots to unfinished positions.""" intermediate_objects = [] eps = 1e-3 @@ -570,30 +701,164 @@ def _move_intermediate_objects_to_unfinished_state(self, if not intermediate_objects: return obj_dict - start_x = self.domino.domino_x_lb + self.domino.domino_width + occupied = { + obj: data + for obj, data in obj_dict.items() + if all(obj != intermediate[0] + for intermediate in intermediate_objects) + } + + x_margin = self.domino.domino_width + y_margin = self.domino.domino_width spacing = self.domino.domino_width * 1.5 - y_position = (self.domino.domino_y_lb + self.domino.domino_y_ub) / 2 - - for i, (obj, obj_type) in enumerate(intermediate_objects): - new_x = start_x + i * spacing - if obj_type == "domino": - obj_dict[obj] = { - "x": new_x, - "y": y_position, - "z": self.domino.z_lb + self.domino.domino_height / 2, - "yaw": 0.0, - "roll": 0.0, - "r": self.domino.domino_color[0], - "g": 
self.domino.domino_color[1], - "b": self.domino.domino_color[2], - "is_held": 0.0, - } - elif obj_type == "pivot": - obj_dict[obj] = { - "x": new_x, - "y": y_position, - "z": self.domino.z_lb, - "yaw": 0.0, - } + + # Gripper swept-footprint half-extents for a top-down grasp of a + # staged (yaw=0) domino. The open fingers span the domino's depth axis + # (local y) and reach ~1.45x the domino width from the grasp center; + # the hand spans ~0.85x along the long axis (local x). Measured from + # the Fetch gripper at the descend pose. A staged domino must keep this + # footprint clear of every other object, otherwise it lands placed but + # *un-pickable* -- BiRRT finds no collision-free descent because a + # neighbor (especially a perpendicular one a few cm away in y) sits + # inside the finger sweep even though the footprints don't overlap. + grasp_clear_hand = self.domino.domino_width * 0.85 + grasp_clear_finger = self.domino.domino_width * 1.45 + x_values = np.arange(self.domino.domino_x_lb + x_margin, + self.domino.domino_x_ub - x_margin + eps, spacing) + y_values = np.arange(self.domino.domino_y_lb + y_margin, + self.domino.domino_y_ub - y_margin + eps, spacing) + candidate_xy = [(float(x), float(y)) for y in y_values + for x in x_values] + + for obj, obj_type in intermediate_objects: + placed = False + for new_x, new_y in candidate_xy: + candidate: Dict[str, float] + if obj_type == "domino": + candidate = { + "x": new_x, + "y": new_y, + "z": self.domino.z_lb + self.domino.domino_height / 2, + "yaw": 0.0, + "roll": 0.0, + "r": self.domino.domino_color[0], + "g": self.domino.domino_color[1], + "b": self.domino.domino_color[2], + "is_held": 0.0, + } + else: + candidate = { + "x": new_x, + "y": new_y, + "z": self.domino.z_lb, + "yaw": 0.0, + } + if self._placement_collides(obj, candidate, occupied): + continue + if obj_type == "domino" and self._grasp_clearance_blocked( + candidate, occupied, grasp_clear_hand, + grasp_clear_finger): + continue + obj_dict[obj] = 
candidate + occupied[obj] = candidate + placed = True + break + if not placed: + return None return obj_dict + + def _placement_collides(self, obj: Object, candidate: Dict[str, float], + occupied: Dict[Object, Dict[str, float]]) -> bool: + """Check whether ``candidate`` overlaps any occupied object.""" + candidate_rect = self._placement_rect(obj, candidate) + for other_obj, other_data in occupied.items(): + other_rect = self._placement_rect(other_obj, other_data) + if self._rectangles_overlap(candidate_rect, other_rect): + return True + return False + + def _grasp_clearance_blocked(self, candidate: Dict[str, float], + occupied: Dict[Object, Dict[str, float]], + half_hand: float, half_finger: float) -> bool: + """Whether the gripper's swept grasp footprint at ``candidate`` would + overlap another object, leaving the staged domino un-pickable. + + ``half_hand``/``half_finger`` are the gripper footprint half- + extents along the domino's long axis (local x) and depth/finger- + span axis (local y). The check is the same oriented-rectangle + overlap test used for placement, but against the larger gripper + footprint. 
+ """ + clear_rect = self._oriented_rect_corners(candidate["x"], + candidate["y"], + candidate.get("yaw", 0.0), + half_hand, half_finger) + for other_obj, other_data in occupied.items(): + if self._rectangles_overlap( + clear_rect, self._placement_rect(other_obj, other_data)): + return True + return False + + @staticmethod + def _oriented_rect_corners(x: float, y: float, yaw: float, half_w: float, + half_d: float) -> Tuple[np.ndarray, np.ndarray]: + """Return (center, corners) of an oriented rectangle with the given + half-extents along its local x (``half_w``) and y (``half_d``) axes.""" + center = np.array([x, y], dtype=np.float64) + local = np.array( + [[-half_w, -half_d], [-half_w, half_d], [half_w, half_d], + [half_w, -half_d]], + dtype=np.float64, + ) + rot = np.array([[np.cos(yaw), -np.sin(yaw)], + [np.sin(yaw), np.cos(yaw)]], + dtype=np.float64) + return center, center + local @ rot.T + + def _placement_rect( + self, obj: Object, + data: Dict[str, float]) -> Tuple[np.ndarray, np.ndarray]: + """Return center and corners for an object's conservative footprint.""" + if obj.type == self.domino.domino_type: + width = self.domino.domino_width + depth = self.domino.domino_depth + elif obj.type == self.domino.pivot_type: + width = self.domino.pivot_width + depth = self.domino.pivot_width + else: + width = self.domino.domino_width + depth = self.domino.domino_width + + padding = 0.003 + half_w = width / 2 + padding + half_d = depth / 2 + padding + yaw = data["yaw"] + center = np.array([data["x"], data["y"]], dtype=np.float64) + local = np.array( + [[-half_w, -half_d], [-half_w, half_d], [half_w, half_d], + [half_w, -half_d]], + dtype=np.float64, + ) + rot = np.array([[np.cos(yaw), -np.sin(yaw)], + [np.sin(yaw), np.cos(yaw)]], + dtype=np.float64) + return center, center + local @ rot.T + + @staticmethod + def _rectangles_overlap(rect1: Tuple[np.ndarray, np.ndarray], + rect2: Tuple[np.ndarray, np.ndarray]) -> bool: + """Separating-axis overlap test for two oriented 
rectangles.""" + + def _axes(corners: np.ndarray) -> List[np.ndarray]: + edges = [corners[1] - corners[0], corners[2] - corners[1]] + return [edge / np.linalg.norm(edge) for edge in edges] + + _, corners1 = rect1 + _, corners2 = rect2 + for axis in _axes(corners1) + _axes(corners2): + proj1 = corners1 @ axis + proj2 = corners2 @ axis + if max(proj1) <= min(proj2) or max(proj2) <= min(proj1): + return False + return True diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 7cabcad43..b8aa6e8b0 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -449,7 +449,14 @@ def _step_base(self, action: Action) -> None: """Run robot control, physics stepping, and grasp management.""" # Send the action to the robot. target_joint_positions, base_delta = self._split_action(action) - if base_delta.size: + # Only relocate the (kinematic) base when there is an actual move. + # Calling set_base_pose (resetBasePositionAndOrientation) every step, + # even for a zero delta, perturbs the arm's contact dynamics — it makes + # the mobile_fetch switch-push wander off target — whereas fixed-base + # robots never touch the base. A zero delta is a no-op, so skip it. + base_moved = bool( + base_delta.size) and not bool(np.allclose(base_delta, 0.0)) + if base_moved: self._apply_base_delta(base_delta) self._pybullet_robot.set_motors(target_joint_positions.tolist()) @@ -457,8 +464,17 @@ def _step_base(self, action: Action) -> None: # object, we need to reset the pose of the held object directly. This # is because the PyBullet constraints don't seem to play nicely with # resetJointState (the robot will sometimes drop the object). 
- if CFG.pybullet_control_mode == "reset" and \ - self._held_obj_id is not None: + # + # The same hand-off is needed whenever the kinematic base just + # teleported with an object in hand (mobile robots): set_base_pose jumps + # the gripper, and over the single physics step the grasp constraint + # would yank the object across the jump -- the jug lags, tips, or slides + # in the gripper and then collides at the subsequent place/retreat. Pre- + # placing it at the gripper (it tracks the constant grasp offset, so + # this is exact for a rigid grasp) makes the carry follow the base + # smoothly. + if self._held_obj_id is not None and (CFG.pybullet_control_mode + == "reset" or base_moved): world_to_base_link = get_link_state( self._pybullet_robot.robot_id, self._pybullet_robot.end_effector_id, @@ -538,6 +554,12 @@ def _set_state(self, state: State) -> None: # any features this reset could not round-trip. self._last_unreconstructible_features = [] + # Mobile base: restore the base pose first, since every arm/object + # world pose is expressed relative to it. _robot_matches_state also + # checks the base, so a base move forces the joints + grasp constraint + # to be rebuilt in the restored base frame below. + self._restore_base_pose_from_state(state) + wrote_anything = False # 1) Robot pose diff. Skipping this branch when the live joints @@ -603,6 +625,11 @@ def _set_state(self, state: State) -> None: self._pybullet_robot.reset_state(self._extract_robot_state(state), joint_positions=joint_positions, trust_joints=trust_joints) + # reset_state snaps the base back to the robot's fixed home pose; + # for a mobile base, re-apply the requested base pose so the joints + # (recorded for that base) place the arm in the right world frame + # and the grasp constraint below is recorded in the correct frame. 
+ self._restore_base_pose_from_state(state) wrote_anything = True for obj in objects_to_reset: @@ -817,7 +844,42 @@ def _robot_matches_state(self, state: State, atol: float = 1e-3) -> bool: cur_jp = self._pybullet_robot.get_joints() except (KeyError, ValueError): return False - return bool(np.allclose(jp, cur_jp, atol=atol)) + if not bool(np.allclose(jp, cur_jp, atol=atol)): + return False + # Mobile base: a base move (with identical joints) still relocates the + # whole arm, so it must count as a robot change. + want_base = self._base_pose_from_state(state) + if want_base is not None: + cur_base = self._robot_base_pose_tuple() + if cur_base is not None and not ( + np.allclose(want_base[0], cur_base[0], atol=atol) + and np.allclose(want_base[1], cur_base[1], atol=atol)): + return False + return True + + @staticmethod + def _base_pose_from_state( + state: State + ) -> Optional[Tuple[Tuple[float, float, float], Tuple[float, float, float, + float]]]: + """Pull a mobile base pose out of a State's simulator_state, if any.""" + sim_state = getattr(state, "simulator_state", None) + if isinstance(sim_state, dict): + return sim_state.get("base_pose", None) + return None + + def _restore_base_pose_from_state(self, state: State) -> None: + """Set the mobile base pose from the State's simulator_state, if it + carries one (no-op for fixed-base robots / states without it).""" + base_pose = self._base_pose_from_state(state) + if base_pose is None: + return + robot = self._pybullet_robot + if not hasattr(robot, "set_base_pose"): + return + pos, orn = base_pose + robot.set_base_pose( # type: ignore[attr-defined] + Pose((pos[0], pos[1], pos[2]), (orn[0], orn[1], orn[2], orn[3]))) def _object_pose_matches_state(self, obj: Object, @@ -911,9 +973,17 @@ def _reset_single_object(self, obj: Object, state: State) -> None: # Convert from 2D angle to a 3D quaternion (assuming rotation around # z) orn = p.getQuaternionFromEuler([0.0, 0.0, angle]) - elif "yaw" in features: - angle = 
state.get(obj, "yaw") - orn = p.getQuaternionFromEuler([0.0, 0.0, angle]) + elif {"yaw", "roll", "pitch"} & set(features): + # Rebuild the full orientation from whichever Euler angles the type + # carries (PyBullet's convention is [roll, pitch, yaw]). Dropping + # roll/pitch here would make toppled objects — e.g. a fallen domino + # with roll≈π — unreconstructible: _get_state reads the angle back, + # the mismatch exceeds _reconstruction_raise_atol, and _set_state + # raises instead of round-tripping. Missing angles default to 0. + roll = state.get(obj, "roll") if "roll" in features else 0.0 + pitch = state.get(obj, "pitch") if "pitch" in features else 0.0 + yaw = state.get(obj, "yaw") if "yaw" in features else 0.0 + orn = p.getQuaternionFromEuler([roll, pitch, yaw]) else: orn = self._default_orn # e.g. (0,0,0,1) @@ -1063,16 +1133,34 @@ def _get_state(self, _render_obs: bool = False) -> State: state = utils.create_state_from_dict(state_dict) joint_positions = self._pybullet_robot.get_joints() - pyb_state = PyBulletState(state.data, - simulator_state={ - "joint_positions": joint_positions, - "physics_client_id": - self._physics_client_id, - "robot_id": - self._pybullet_robot.robot_id, - }) + sim_state_dict: Dict[str, Any] = { + "joint_positions": joint_positions, + "physics_client_id": self._physics_client_id, + "robot_id": self._pybullet_robot.robot_id, + } + # Mobile robots: carry the base pose so it round-trips through + # _set_state (the base is not a State feature, so without this a + # reconstruction would silently keep the live base pose, breaking + # option-model / refinement rollouts that move the base). 
+ base_pose = self._robot_base_pose_tuple() + if base_pose is not None: + sim_state_dict["base_pose"] = base_pose + pyb_state = PyBulletState(state.data, simulator_state=sim_state_dict) return pyb_state + def _robot_base_pose_tuple( + self + ) -> Optional[Tuple[Tuple[float, float, float], Tuple[float, float, float, + float]]]: + """Return the mobile base pose as (position, orientation) tuples, or + None for fixed-base robots.""" + robot = self._pybullet_robot + if int(getattr(robot, "base_action_dim", 0)) <= 0 or \ + not hasattr(robot, "get_base_pose"): + return None + base_pose = robot.get_base_pose() # type: ignore[attr-defined] + return (tuple(base_pose.position), tuple(base_pose.orientation)) + def _get_robot_state_dict(self) -> Dict[str, float]: """Build a feature dict for the robot from PyBullet state. diff --git a/predicators/explorers/agent_bilevel_explorer.py b/predicators/explorers/agent_bilevel_explorer.py index 5055581be..1003dfd66 100644 --- a/predicators/explorers/agent_bilevel_explorer.py +++ b/predicators/explorers/agent_bilevel_explorer.py @@ -176,6 +176,7 @@ def _get_exploration_strategy(self, train_task_idx: int, run_id="agent_bilevel_explorer", info_scorer=info_scorer, info_n_feasible_target=info_n_feasible_target, + option_samplers=self._tool_context.option_samplers, ) # Record the honest verdict so get_interaction_requests can # stamp it onto this request: early stopping should not treat a diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index 54b6155d9..56fbbb3ad 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -2,7 +2,7 @@ import abc import sys from pathlib import Path -from typing import Dict, List, Sequence, Set +from typing import Dict, List, Optional, Sequence, Set from gym.spaces import Box @@ -10,7 +10,8 @@ from predicators.envs import BaseEnv, get_or_create_env from predicators.settings import CFG from predicators.structs 
import NSRT, CausalProcess, EndogenousProcess, \ - LiftedDecisionList, ParameterizedOption, Predicate, Task, Type + LiftedDecisionList, OptionSampler, ParameterizedOption, Predicate, Task, \ + Type class GroundTruthOptionFactory(abc.ABC): @@ -87,6 +88,28 @@ def get_env_names(cls) -> Set[str]: raise NotImplementedError("Override me!") +class GroundTruthSamplerFactory(abc.ABC): + """Parent class for ground-truth per-skill samplers. + + Provides a mapping ``option name -> OptionSampler`` consulted by + bilevel-sketch refinement (the grid-free counterpart of the NSRT + samplers in ``processes.py``). Lets an env supply hand-written + samplers instead of having the agent synthesize them. + """ + + @classmethod + @abc.abstractmethod + def get_env_names(cls) -> Set[str]: + """Get the env names that this factory builds samplers for.""" + raise NotImplementedError("Override me!") + + @classmethod + @abc.abstractmethod + def get_samplers(cls, env_name: str) -> Dict[str, OptionSampler]: + """Return ``option name -> OptionSampler`` for the given env.""" + raise NotImplementedError("Override me!") + + class GroundTruthLDLBridgePolicyFactory(abc.ABC): """Ground-truth policies implemented with LDLs saved in text files.""" @@ -221,7 +244,9 @@ def get_gt_processes(env_name: str, env = get_or_create_env(env_name) env_options = get_gt_options(env_name) helper_predicates = get_gt_helper_predicates(env_name) - all_predicates = env.predicates | helper_predicates + # Helper predicates take precedence over env predicates on name collisions + # (e.g. the grid's derived InFront replaces the position-based InFront). 
+ all_predicates = helper_predicates | env.predicates helper_types = get_gt_helper_types(env_name) all_types = env.types | helper_types assert predicates_to_keep.issubset(all_predicates) @@ -292,6 +317,20 @@ def get_gt_simulator(env_name: str) -> tuple: f"env: {env_name}") +def get_gt_samplers(env_name: str) -> Optional[Dict[str, OptionSampler]]: + """Return ``option name -> ground-truth OptionSampler`` for an env. + + Merges the samplers from every ``GroundTruthSamplerFactory`` bound + to ``env_name``. Returns ``None`` when no factory provides samplers + for the env, so callers can fall back to learning/uniform sampling. + """ + out: Dict[str, OptionSampler] = {} + for cls in utils.get_all_subclasses(GroundTruthSamplerFactory): + if not cls.__abstractmethods__ and env_name in cls.get_env_names(): + out.update(cls.get_samplers(env_name)) + return out or None + + def get_gt_ldl_bridge_policy(env_name: str, types: Set[Type], predicates: Set[Predicate], options: Set[ParameterizedOption], diff --git a/predicators/ground_truth_models/boil/options.py b/predicators/ground_truth_models/boil/options.py index 59b2ccd48..0ef9bec31 100644 --- a/predicators/ground_truth_models/boil/options.py +++ b/predicators/ground_truth_models/boil/options.py @@ -83,6 +83,17 @@ def _get_options_skill_factories( robot_home_pos=(env_cls.robot_init_x, env_cls.robot_init_y, env_cls.robot_init_z), transport_z=cls._transport_z, + # Mobile-base (mobile_fetch) positioning: park the base 0.6 m in + # front of each reach target with its x aligned to the target x, so + # the arm reaches straight forward at a comfortable distance instead + # of sideways over the burner or fully extended. base_y is clamped + # to keep the base clear of the table front (y_lb). 
+ base_standoff=(CFG.boil_mobile_base_standoff + if CFG.boil_mobile_base_park else None), + base_y_max=env_cls.y_lb - 0.28, + base_align_x=CFG.boil_mobile_base_align_x, + base_home_xy=(env_cls.robot_base_pos[0], + env_cls.robot_base_pos[1]), simulator=simulator, ) diff --git a/predicators/ground_truth_models/boil/options_legacy.py b/predicators/ground_truth_models/boil/options_legacy.py index a0a5d3d45..eb26098fb 100644 --- a/predicators/ground_truth_models/boil/options_legacy.py +++ b/predicators/ground_truth_models/boil/options_legacy.py @@ -578,7 +578,8 @@ def _get_current_and_target_pose_and_finger_status( cls._move_to_pose_tol, CFG.pybullet_max_vel_norm, cls._finger_action_nudge_magnitude, - validate=CFG.pybullet_ik_validate) + validate=CFG.pybullet_ik_validate, + stall_limit=8) @classmethod def _create_boil_move_to_above_jug_option( @@ -631,7 +632,8 @@ def _get_current_and_target_pose_and_finger_status( cls._move_to_pose_tol, CFG.pybullet_max_vel_norm, cls._finger_action_nudge_magnitude, - validate=CFG.pybullet_ik_validate) + validate=CFG.pybullet_ik_validate, + stall_limit=8) @classmethod def _create_boil_move_to_push_switch_option( diff --git a/predicators/ground_truth_models/boil/processes.py b/predicators/ground_truth_models/boil/processes.py index bd9d004d3..389657b7b 100644 --- a/predicators/ground_truth_models/boil/processes.py +++ b/predicators/ground_truth_models/boil/processes.py @@ -63,9 +63,17 @@ def _place_outside_sampler(state: State, goal: Set[GroundAtom], if not CFG.boil_use_skill_factories: return np.array([], dtype=np.float32) del state, goal, rng, objs + # Drop the idle jug at a single tuned spot in the open table region between + # the burner (south) and the faucet (east), clear of the table edges. This + # deterministic point is reachable and collision-free for both the fixed + # and mobile bases. 
(A per-sample randomized spread was tried for + # mobile_fetch but measured strictly worse -- 8/10 vs 10/10 on the 2-jug + # tasks -- because the spread occasionally lands near the burner or past the + # arm's reach, so it was removed.) x = PyBulletBoilEnv.x_mid - 0.15 y = PyBulletBoilEnv.y_mid + 0.10 - return np.array([x, y, _BOIL_DROP_Z, 0.0], dtype=np.float32) + z = _BOIL_DROP_Z + return np.array([x, y, z, 0.0], dtype=np.float32) class PyBulletBoilGroundTruthProcessFactory(GroundTruthProcessFactory): @@ -118,7 +126,15 @@ def get_processes( # Options PickJug = options["PickJug"] - Place = options["Place"] + if CFG.boil_use_skill_factories: + Place = options["Place"] + else: + # Legacy options expose object-keyed place options instead of a + # generic Place; the samplers already return empty params for the + # legacy path, so each place process just selects the right one. + PlaceUnderFaucetOpt = options["PlaceUnderFaucet"] + PlaceOnBurnerOpt = options["PlaceOnBurner"] + PlaceOutsideOpt = options["PlaceOutsideBurnerAndFaucet"] # Having swtich for each because of the type SwitchFaucetOn = options["SwitchFaucetOn"] SwitchFaucetOff = options["SwitchFaucetOff"] @@ -216,8 +232,12 @@ def get_processes( jug = Variable("?jug", jug_type) burner = Variable("?burner", burner_type) parameters = [robot, jug, burner] - option_vars = [robot] - option = Place + if CFG.boil_use_skill_factories: + option_vars = [robot] + option = Place + else: + option_vars = [robot, burner] + option = PlaceOnBurnerOpt condition_at_start = { LiftedAtom(Holding, [robot, jug]), LiftedAtom(NoJugAtBurner, [burner]), @@ -243,8 +263,12 @@ def get_processes( jug = Variable("?jug", jug_type) faucet = Variable("?faucet", faucet_type) parameters = [robot, jug, faucet] - option_vars = [robot] - option = Place + if CFG.boil_use_skill_factories: + option_vars = [robot] + option = Place + else: + option_vars = [robot, faucet] + option = PlaceUnderFaucetOpt condition_at_start = { LiftedAtom(Holding, [robot, jug]), 
LiftedAtom(NoJugAtFaucet, [faucet]), @@ -270,8 +294,12 @@ def get_processes( robot = Variable("?robot", robot_type) jug = Variable("?jug", jug_type) parameters = [robot, jug] - option_vars = [robot] - option = Place + if CFG.boil_use_skill_factories: + option_vars = [robot] + option = Place + else: + option_vars = [robot] + option = PlaceOutsideOpt condition_at_start = { LiftedAtom(Holding, [robot, jug]), } @@ -442,7 +470,18 @@ def get_processes( # delete_effects = { # LiftedAtom(JugNotFilled, [jug]), # } - delay_distribution = DiscreteGaussianDelay(mu=torch.tensor(5.0), + # Legacy options take fewer low-level steps per option than the + # skill-factory options, so the jug does not physically reach the + # fill threshold within the SwitchFaucetOn(1)+SwitchBurnerOn(3)+ + # SwitchFaucetOff(1) window the skill-factory timing was calibrated + # for. Use a longer symbolic fill delay for the legacy options so + # the planner emits an explicit Wait (which terminates exactly on + # JugFilled), filling robustly regardless of option duration. Keep + # the original delay for the skill-factory options, whose longer + # rollouts already fill within the window and would overfill / spill + # if the faucet kept running through an added Wait. 
+ _fill_mu = 5.0 if CFG.boil_use_skill_factories else 8.0 + delay_distribution = DiscreteGaussianDelay(mu=torch.tensor(_fill_mu), sigma=torch.tensor(0.1)) fill_jug_process = ExogenousProcess("FillJug", parameters, condition_at_start, diff --git a/predicators/ground_truth_models/domino/__init__.py b/predicators/ground_truth_models/domino/__init__.py index f72ccd981..0f75ff78f 100644 --- a/predicators/ground_truth_models/domino/__init__.py +++ b/predicators/ground_truth_models/domino/__init__.py @@ -1,5 +1,6 @@ """Ground-truth models for coffee environment and variants.""" +from .gt_simulator import PyBulletDominoGroundTruthSimulatorFactory from .nsrts import PyBulletDominoGroundTruthNSRTFactory from .options import PyBulletDominoGroundTruthOptionFactory from .predicates import PyBulletDominoGroundTruthPredicateFactory @@ -11,6 +12,6 @@ "PyBulletDominoGroundTruthOptionFactory", "PyBulletDominoGroundTruthPredicateFactory", "PyBulletDominoGroundTruthProcessFactory", - "PyBulletDominoGroundTruthProcessFactory", + "PyBulletDominoGroundTruthSimulatorFactory", "PyBulletDominoGroundTruthTypeFactory", ] diff --git a/predicators/ground_truth_models/domino/gt_simulator.py b/predicators/ground_truth_models/domino/gt_simulator.py new file mode 100644 index 000000000..e253a7e0a --- /dev/null +++ b/predicators/ground_truth_models/domino/gt_simulator.py @@ -0,0 +1,57 @@ +"""Ground-truth simulator program for pybullet_domino process dynamics. + +This is an intentionally *empty* (no-op) simulator: it carries no +process dynamics and predicts no state features. It exists so that +``get_gt_simulator("pybullet_domino")`` resolves to a valid module +instead of raising ``NotImplementedError``. + +The contract enforced by ``read_simulator_components`` requires a +non-empty ``PROCESS_RULES`` list and a non-empty ``PARAM_SPECS`` list, +so we provide a single identity rule (returns updates unchanged) and a +single placeholder parameter. 
``PROCESS_FEATURES`` is empty, signalling +that no features are predicted by the GT process model. +""" + +from __future__ import annotations + +from typing import Dict, List + +from predicators.code_sim_learning.training import ParamSpec +from predicators.code_sim_learning.utils import Params, ProcessUpdate +from predicators.ground_truth_models import GroundTruthSimulatorFactory +from predicators.structs import State + +# ── Process rules ──────────────────────────────────────────────── + + +def _identity(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """No-op rule: domino dynamics are not modelled, so pass through.""" + del state, params # unused + return updates + + +# ── Public API: consumed by read_simulator_components ──────────── + +PROCESS_RULES = [_identity] + +# A single placeholder spec keeps PARAM_SPECS non-empty (the loader +# rejects an empty list) while leaving the dynamics a true no-op. +PARAM_SPECS: List[ParamSpec] = [ParamSpec("placeholder", 0.0, lo=0.0)] + +PROCESS_FEATURES: Dict[str, List[str]] = {} + +# ── Factory binding ────────────────────────────────────────────── + + +class PyBulletDominoGroundTruthSimulatorFactory(GroundTruthSimulatorFactory): + """Empty GT process-dynamics simulator for pybullet_domino. + + Only pins the env-name binding so ``get_gt_simulator`` can locate + this module via the factory registry; the simulator components live + as module globals above. + """ + + @classmethod + def get_env_names(cls) -> set: + return {"pybullet_domino"} diff --git a/predicators/ground_truth_models/domino/predicates.py b/predicators/ground_truth_models/domino/predicates.py index 430a63407..a1acaf959 100644 --- a/predicators/ground_truth_models/domino/predicates.py +++ b/predicators/ground_truth_models/domino/predicates.py @@ -1,13 +1,15 @@ -"""Helper predicates for the domino environment.""" +"""Helper predicates for the domino environment. 
-from typing import Dict, Sequence, Set +The grid predicates (DominoAtPos, DominoAtRot, PosClear, +InFrontDirection, InFront, AdjacentTo) are defined canonically by +``GridComponent``; this factory simply delegates to it so there is a +single source of truth. +""" -import numpy as np +from typing import Dict, Set -from predicators import utils from predicators.ground_truth_models import GroundTruthPredicateFactory -from predicators.structs import DerivedPredicate, GroundAtom, Object, \ - Predicate, State, Type +from predicators.structs import Predicate, Type class PyBulletDominoGroundTruthPredicateFactory(GroundTruthPredicateFactory): @@ -22,370 +24,12 @@ def get_helper_predicates(cls, env_name: str, types: Dict[str, Type]) -> Set[Predicate]: """Get helper predicates for the domino environment. - Returns DominoAtPos, DominoAtRot, and InFront predicates. + Delegates to ``GridComponent``, the canonical definition of the + grid predicates. Only oracle / process-planning approaches + consume these helpers; agent approaches run grid-free. 
""" del env_name # unused - # Get the required types from the passed-in types dict - domino_type = types["domino"] - position_type = types["loc"] - angle_type = types["angle"] - direction_type = types["direction"] - - # DominoAtPos predicate - DominoAtPos = Predicate("DominoAtPos", [domino_type, position_type], - cls._DominoAtPos_holds) - - # DominoAtRot predicate - DominoAtRot = Predicate("DominoAtRot", [domino_type, angle_type], - cls._DominoAtRot_holds) - - # PosClear predicate - PosClear = Predicate("PosClear", [position_type], cls._PosClear_holds) - - # InFrontDirection derived predicate - InFrontDirection = DerivedPredicate( - "InFrontDirection", [domino_type, domino_type, direction_type], - cls._InFrontDirection_holds, - auxiliary_predicates={DominoAtPos, DominoAtRot}) - - # InFront derived predicate - InFront = DerivedPredicate("InFront", [domino_type, domino_type], - cls._InFront_holds, - auxiliary_predicates={InFrontDirection}) - - # AdjacentTo derived predicate - AdjacentTo = DerivedPredicate("AdjacentTo", - [position_type, domino_type], - cls._AdjacentTo_holds, - auxiliary_predicates={DominoAtPos}) - - return { - DominoAtPos, DominoAtRot, InFrontDirection, InFront, PosClear, - AdjacentTo - } - - @staticmethod - def _DominoAtPos_holds(state: State, objects: Sequence[Object]) -> bool: - """Check if domino is at a specific position.""" - domino, position = objects - if state.get(domino, "is_held"): - return False - - # Get domino's actual position - domino_x = state.get(domino, "x") - domino_y = state.get(domino, "y") - - # Get position type to find all positions - position_type = position.type - - # Find closest position to the domino - closest_position = None - closest_distance = float('inf') - for pos in state.get_objects(position_type): - pos_x = state.get(pos, "xx") - pos_y = state.get(pos, "yy") - distance = np.sqrt((domino_x - pos_x)**2 + (domino_y - pos_y)**2) - if distance < closest_distance: - closest_distance = distance - closest_position = pos 
- - return closest_position == position - - @staticmethod - def _DominoAtRot_holds(state: State, objects: Sequence[Object]) -> bool: - """Check if domino is at a specific rotation.""" - domino, rotation = objects - if state.get(domino, "is_held"): - return False - - # Get domino's actual rotation (in radians) - domino_rot = state.get(domino, "yaw") - - # Get the target rotation (convert from degrees to radians) - target_rot_degrees = state.get(rotation, "angle") - target_rot_radians = np.radians(target_rot_degrees) - - # Check if domino rotation is close enough to target rotation - rotation_tolerance = np.radians(15) # 15 degrees tolerance - angle_diff = abs(utils.wrap_angle(domino_rot - target_rot_radians)) - - return angle_diff <= rotation_tolerance - - @staticmethod - def _InFrontDirection_holds(atoms: Set[GroundAtom], - objects: Sequence[Object]) -> bool: - """Check if domino1 is in front of domino2 in the given direction. - - This is an optimized implementation for heuristic evaluation. 
- """ - domino1, domino2, direction_obj = objects - - # Note: No longer need to filter "loc_other_" positions since we use - # exact coordinates - - # Helper functions to parse object names and cache results - _pos_coord_cache: Dict[Object, tuple] = {} - _rot_rad_cache: Dict[Object, float] = {} - - def extract_coords(pos_obj: Object) -> tuple: - """Extract x, y coordinates from location name like - 'loc_0.49_1.23'.""" - if pos_obj in _pos_coord_cache: - return _pos_coord_cache[pos_obj] - name_parts = pos_obj.name.split("_") - x_coord = float(name_parts[1]) # Extract from "0.49" part - y_coord = float(name_parts[2]) # Extract from "1.23" part - result = (x_coord, y_coord) - _pos_coord_cache[pos_obj] = result - return result - - def extract_rotation_angle_rad(rot_obj: Object) -> float: - if rot_obj in _rot_rad_cache: - return _rot_rad_cache[rot_obj] - angle_str = rot_obj.name.split("_")[1] - result = np.radians(float(angle_str)) - _rot_rad_cache[rot_obj] = result - return result - - # Gather all possible states for each domino - d1_positions_coords = { - extract_coords(atom.objects[1]) - for atom in atoms if atom.predicate.name == "DominoAtPos" - and atom.objects[0] == domino1 - } - d1_rotations_rad = { - extract_rotation_angle_rad(atom.objects[1]) - for atom in atoms if atom.predicate.name == "DominoAtRot" - and atom.objects[0] == domino1 - } - d2_positions_coords = { - extract_coords(atom.objects[1]) - for atom in atoms if atom.predicate.name == "DominoAtPos" - and atom.objects[0] == domino2 - } - d2_rotations_rad = { - extract_rotation_angle_rad(atom.objects[1]) - for atom in atoms if atom.predicate.name == "DominoAtRot" - and atom.objects[0] == domino2 - } - - def _check_case(front_domino_positions: Set[tuple], - front_domino_rotations: Set[float], - back_domino_positions: Set[tuple], - back_domino_rotations: Set[float], - direction_name: str, - tolerance: float = 1e-6) -> bool: - """Perform decoupled checks for positional and rotational - possibility.""" - # 
Fail fast if any required sets are empty - if not all([ - front_domino_positions, front_domino_rotations, - back_domino_positions, back_domino_rotations - ]): - return False - - # Import pos_gap for spatial calculations - from predicators.envs.pybullet_domino.composed_env import \ - PyBulletDominoComposedEnv # pylint: disable=import-outside-toplevel - pos_gap = PyBulletDominoComposedEnv.pos_gap - - # Positional Check: Is there ANY valid geometric placement? - position_possible = False - for (x_back, y_back) in back_domino_positions: - for rot_back_rad in back_domino_rotations: - # Relationship only holds for cardinal rotations - if not (abs(np.sin(rot_back_rad)) < tolerance - or abs(np.cos(rot_back_rad)) < tolerance): - continue - # Calculate expected position using actual spatial offset - dx = pos_gap * np.sin(rot_back_rad) - dy = pos_gap * np.cos(rot_back_rad) - expected_x = x_back + dx - expected_y = y_back + dy - - # Check if any front position matches (within tolerance) - for (x_front, y_front) in front_domino_positions: - if (abs(x_front - expected_x) < pos_gap * 0.3 - and abs(y_front - expected_y) < pos_gap * 0.3): - position_possible = True - break - if position_possible: - break - if position_possible: - break - - if not position_possible: - return False - - # Rotational Check: Is there ANY pair with correct rotation diff? 
- if direction_name == "left": - expected_rot_diff = np.pi / 4 - elif direction_name == "straight": - expected_rot_diff = 0 - elif direction_name == "right": - expected_rot_diff = -np.pi / 4 - else: - return False - - for rot_back_rad in back_domino_rotations: - for rot_front_rad in front_domino_rotations: - diff = utils.wrap_angle(rot_front_rad - rot_back_rad) - if abs(diff - expected_rot_diff) < tolerance: - return True - - return False - - # Check both symmetric cases for the relationship - dir_name = direction_obj.name - if dir_name == "left": - opposite_dir_name = "right" - elif dir_name == "right": - opposite_dir_name = "left" - else: # "straight" - opposite_dir_name = "straight" - - # Case 1: Is domino1 in front of domino2 in dir_name? - if _check_case(front_domino_positions=d1_positions_coords, - front_domino_rotations=d1_rotations_rad, - back_domino_positions=d2_positions_coords, - back_domino_rotations=d2_rotations_rad, - direction_name=dir_name): - return True - - # Case 2: Is domino2 in front of domino1 in opposite_dir_name? - if _check_case(front_domino_positions=d2_positions_coords, - front_domino_rotations=d2_rotations_rad, - back_domino_positions=d1_positions_coords, - back_domino_rotations=d1_rotations_rad, - direction_name=opposite_dir_name): - return True - - return False - - @staticmethod - def _InFront_holds(atoms: Set[GroundAtom], - objects: Sequence[Object]) -> bool: - """Check if domino1 is in front of domino2 in any direction.""" - domino1, domino2 = objects - - # Check if there exists any InFrontDirection atom with these dominos - for atom in atoms: - if (atom.predicate.name == "InFrontDirection" - and len(atom.objects) == 3 and atom.objects[0] == domino1 - and atom.objects[1] == domino2): - return True - - return False - - @staticmethod - def _PosClear_holds(state: State, objects: Sequence[Object]) -> bool: - """Check if a position is clear (not occupied by any domino). 
- - A position is considered clear if no domino is currently at that - position. - """ - position, = objects - - # Get the position coordinates - target_x = state.get(position, "xx") - target_y = state.get(position, "yy") - - # Calculate grid spacing (minimum distance between positions) - position_type = position.type - positions = list(state.get_objects(position_type)) - - min_distance = float('inf') - for i, pos1 in enumerate(positions): - for pos2 in positions[i + 1:]: - x1 = state.get(pos1, "xx") - y1 = state.get(pos1, "yy") - x2 = state.get(pos2, "xx") - y2 = state.get(pos2, "yy") - distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2) - if distance > 1e-6: # Skip identical positions - min_distance = min(min_distance, distance) - - # Use half the grid spacing as tolerance - position_tolerance = min_distance * 0.5 if min_distance != float( - 'inf') else 0.1 - - # Check if any domino is at this position - for obj in state: - if obj.type.name == "domino": - domino_x = state.get(obj, "x") - domino_y = state.get(obj, "y") - - # If domino is close enough to this position, position is not - # clear - if (abs(domino_x - target_x) <= position_tolerance - and abs(domino_y - target_y) <= position_tolerance - and not state.get(obj, "is_held")): - return False - - return True - - @staticmethod - def _AdjacentTo_holds(atoms: Set[GroundAtom], - objects: Sequence[Object]) -> bool: - """Check if a position is adjacent to a domino in cardinal directions. - - This is similar to _InFrontDirection_holds but checks if - a position is adjacent to any position where the domino - could be placed, considering that the domino can be in - multiple positions during heuristic computation. - - Adjacent positions are those that are exactly one grid step away in - cardinal directions (up, down, left, right) but not diagonal. 
- """ - position, domino = objects - - # Note: No longer need to filter "loc_other_" positions since we use - # exact coordinates - - # Helper functions to parse object names and cache results - _pos_coord_cache: Dict[Object, tuple] = {} - - def extract_coords(pos_obj: Object) -> tuple: - """Extract x, y coordinates from location name like - 'loc_0.49_1.23'.""" - if pos_obj in _pos_coord_cache: - return _pos_coord_cache[pos_obj] - name_parts = pos_obj.name.split("_") - x_coord = float(name_parts[1]) # Extract from "0.49" part - y_coord = float(name_parts[2]) # Extract from "1.23" part - result = (x_coord, y_coord) - _pos_coord_cache[pos_obj] = result - return result - - # Import pos_gap for spatial calculations - from predicators.envs.pybullet_domino.composed_env import \ - PyBulletDominoComposedEnv # pylint: disable=import-outside-toplevel - pos_gap = PyBulletDominoComposedEnv.pos_gap - - # Get coordinates of the target position - target_coords = extract_coords(position) - target_x, target_y = target_coords - - # Get all possible positions where the domino could be - domino_positions_coords = { - extract_coords(atom.objects[1]) - for atom in atoms if atom.predicate.name == "DominoAtPos" - and atom.objects[0] == domino - } - - # Check if the target position is adjacent to any domino position - # Adjacent means approximately one pos_gap away in cardinal directions - for domino_x, domino_y in domino_positions_coords: - # Calculate the actual distance in each dimension - dx = abs(target_x - domino_x) - dy = abs(target_y - domino_y) - - # Adjacent in cardinal directions means: - # - ~pos_gap away in one dir AND close to 0 in other - # Use 30% tolerance for matching pos_gap - if ((abs(dx - pos_gap) < pos_gap * 0.3 and dy < pos_gap * 0.3) or - (abs(dy - pos_gap) < pos_gap * 0.3 and dx < pos_gap * 0.3)): - return True - - return False + from predicators.envs.pybullet_domino.components.grid_component import \ + GridComponent # pylint: disable=import-outside-toplevel + 
return GridComponent(domino_type=types["domino"]).get_predicates() diff --git a/predicators/ground_truth_models/domino/processes.py b/predicators/ground_truth_models/domino/processes.py index a5b934d15..6d61debba 100644 --- a/predicators/ground_truth_models/domino/processes.py +++ b/predicators/ground_truth_models/domino/processes.py @@ -1,21 +1,26 @@ """Ground-truth processes for the domino environment.""" -from typing import Dict, Sequence, Set +from typing import Dict, List, Sequence, Set, Tuple import numpy as np import torch -from predicators.ground_truth_models import GroundTruthProcessFactory +from predicators.ground_truth_models import GroundTruthProcessFactory, \ + GroundTruthSamplerFactory from predicators.settings import CFG from predicators.structs import Array, CausalProcess, EndogenousProcess, \ - ExogenousProcess, GroundAtom, LiftedAtom, Object, ParameterizedOption, \ - Predicate, State, Type, Variable + ExogenousProcess, GroundAtom, LiftedAtom, Object, OptionSampler, \ + ParameterizedOption, Predicate, State, Type, Variable from predicators.utils import ConstantDelay, DiscreteGaussianDelay, \ - null_sampler + null_sampler, wrap_angle # Fixed parameter values for domino environment. _DOMINO_GRASP_Z_OFFSET = 0.0825 # domino_height * 0.55 -_DOMINO_DROP_Z = 0.5695 # table_height + domino_height * 1.13 +# Slightly above the legacy drop height. With the skill-factory Pick grasp +# transform, 0.5695 leaves the held domino penetrating the table at the +# collision-aware Place goal; 0.58 clears the table and still settles to the +# intended upright pose. 
+_DOMINO_DROP_Z = 0.58 _DOMINO_OFFSET_X = 0.045 # domino_depth * 3 _DOMINO_OFFSET_Z = 0.0825 # domino_height * 0.55 @@ -38,18 +43,65 @@ def _push_sampler(state: State, goal: Set[GroundAtom], def _place_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, objs: Sequence[Object]) -> Array: - """Return placement params from process objects.""" + """Return a generator-faithful placement for the open-loop oracle. + + ``objs = [robot, domino1, domino2, target_pos, rotation]``. The process + planner picks a discrete grid cell (``target_pos``) and angle (``rotation``) + for the held ``domino1`` next to the reference ``domino2``. The grid is a + uniform lattice (see ``augment_task_with_helper_objects``), so a turn block + lands at the *same* cell a straight block would, differing only in angle -- + the generator's inward ``domino_width/2`` corner offset is absent from the + lattice. Placing the held domino at the bare cell stalls corner cascades. + + Instead pick from the placements the generator would lay next to ``domino2`` + (``_generator_placements``, which carry the corner offset), rank-summing + three signals that each, alone, mishandle one case -- future-target bridge + (greedy: pulls a straight run onto the target), grid-cell distance (a + uniform-grid turn cell sits on the straight position, missing corners), and + angle error (the planner stamps spurious turn angles on straight runs). The + cascade-correct candidate is top-ranked on >=2 of the three. Deterministic; + final tiebreak is the planner's cell; bare cell if no candidate at all. 
+ """ if not CFG.domino_use_skill_factories: return np.array([], dtype=np.float32) - del state, goal, rng + del goal, rng # objs = [robot, domino1, domino2, target_pos, rotation] + held = objs[1] + ref = objs[2] target_pos = objs[3] rotation = objs[4] - x = float(target_pos.name.split("_")[1]) - y = float(target_pos.name.split("_")[2]) - angle_deg = float(rotation.name.split("_")[-1]) - yaw = np.radians(angle_deg) - return np.array([x, y, _DOMINO_DROP_Z, yaw], dtype=np.float32) + gx = float(target_pos.name.split("_")[1]) + gy = float(target_pos.name.split("_")[2]) + gyaw = np.radians(float(rotation.name.split("_")[-1])) + + rx = state.get(ref, "x") + ry = state.get(ref, "y") + ryaw = state.get(ref, "yaw") + candidates = _generator_placements(rx, ry, ryaw) + if not candidates: + # Fallback: bare lattice cell (no generator candidate available). + return np.array([gx, gy, _DOMINO_DROP_Z, gyaw], dtype=np.float32) + bridges = [ + _future_target_bridge_score(state, held, c[0], c[1], c[2]) + for c in candidates + ] + dgrids = [float(np.hypot(c[0] - gx, c[1] - gy)) for c in candidates] + angerrs = [abs(wrap_angle(c[2] - gyaw)) for c in candidates] + + def _rank(vals: List[float], i: int, higher_better: bool = False) -> int: + # Number of candidates strictly better than ``i`` (ties share a rank). 
+ if higher_better: + return sum(1 for v in vals if v > vals[i] + 1e-9) + return sum(1 for v in vals if v < vals[i] - 1e-9) + + def _total(i: int) -> Tuple[int, float]: + rank_sum = (_rank(bridges, i, higher_better=True) + _rank(dgrids, i) + + _rank(angerrs, i)) + return (rank_sum, dgrids[i]) + + best_i = min(range(len(candidates)), key=_total) + cx, cy, cyaw = candidates[best_i] + return np.array([cx, cy, _DOMINO_DROP_Z, cyaw], dtype=np.float32) class PyBulletDominoGroundTruthProcessFactory(GroundTruthProcessFactory): @@ -67,6 +119,10 @@ def get_processes( options: Dict[str, ParameterizedOption]) -> Set[CausalProcess]: del env_name # unused + # These processes are defined over the grid (loc/angle/direction). + # Only oracle / process-planning approaches request them, and they do + # so unconditionally, so the grid is intrinsic to those approaches. + # Types robot_type = types["robot"] domino_type = types["domino"] @@ -107,7 +163,13 @@ def get_processes( robot = Variable("?robot", robot_type) domino = Variable("?domino", domino_type) parameters = [robot, domino] - option_vars = [robot, domino] + # With restricted push the "Push" option finds the start block from + # the state itself, so it takes only the robot. The unrestricted + # option also takes the domino to push. + if CFG.domino_restricted_push: + option_vars = [robot] + else: + option_vars = [robot, domino] option = Push condition_at_start = { LiftedAtom(HandEmpty, [robot]), @@ -274,3 +336,252 @@ def get_processes( processes.add(domino_tilting_delete_process) return processes + + +# --------------------------------------------------------------------------- +# Grid-free per-skill samplers (NSRTSampler / OptionSampler signature) for +# bilevel refinement. The NSRT samplers above read the placement off grid +# ``loc``/``angle`` objects in ``objs``; these instead compute it +# geometrically from the step's ``InFront`` subgoal (passed in the atoms +# slot), so they work in the grid-free agent_bilevel path. 
Both versions +# coexist intentionally. Refinement clips the returned params to the box. +# --------------------------------------------------------------------------- + +_DOMINO_POS_GAP = 0.098 # PyBulletDominoEnv.pos_gap (domino_width * 1.4) +_DOMINO_WIDTH = 0.07 # PyBulletDominoEnv.domino_width +_DOMINO_TARGET_COLOR = (0.85, 0.7, 0.85) +_DOMINO_COLOR_EPS = 1e-3 + + +def _deterministic(sampler: OptionSampler) -> OptionSampler: + """Flag a sampler as returning constant params (ignores state/rng). + + Backtracking refinement reads this flag to cap such a step's retries at + 1: re-drawing a constant sampler yields the identical option, so spending + the full per-step budget re-descending through it on every backtrack is + wasted work (it can never produce a different outcome). + """ + setattr(sampler, "deterministic", True) + return sampler + + +@_deterministic +def _pick_option_sampler(state: State, subgoal_atoms: Set[GroundAtom], + rng: np.random.Generator, + objects: Sequence[Object]) -> Array: + """Grid-free Pick sampler: fixed grasp height above the domino origin.""" + del state, subgoal_atoms, rng, objects + return np.array([_DOMINO_GRASP_Z_OFFSET], dtype=np.float32) + + +@_deterministic +def _push_option_sampler(state: State, subgoal_atoms: Set[GroundAtom], + rng: np.random.Generator, + objects: Sequence[Object]) -> Array: + """Grid-free Push sampler: fixed approach distance / contact height.""" + del state, subgoal_atoms, rng, objects + return np.array([_DOMINO_OFFSET_X, _DOMINO_OFFSET_Z], dtype=np.float32) + + +def _score_placement(state: State, subgoal_atoms: Set[GroundAtom], + held: Object, hx: float, hy: float, hyaw: float) -> int: + """Count subgoal atoms that hold if ``held`` is placed at (hx, hy, + hyaw).""" + s2 = state.copy() + s2.set(held, "x", hx) + s2.set(held, "y", hy) + s2.set(held, "yaw", hyaw) + s2.set(held, "roll", 0.0) + s2.set(held, "is_held", 0.0) + return sum(1 for atom in subgoal_atoms if atom.holds(s2)) + + +def _is_cardinal(angle: 
float) -> bool: + """True when ``angle`` is within ~10 deg of a cardinal (axis-aligned) yaw. + + Mirrors the cardinal-facing gate in + ``DominoComponent._InFront_holds``: a settled reference domino sits + a degree or two off cardinal, so a hard equality would make chained + placements onto it unsatisfiable. + """ + card_thresh = float(np.sin(np.radians(10))) + return bool( + abs(np.sin(angle)) < card_thresh or abs(np.cos(angle)) < card_thresh) + + +def _generator_placements(xr: float, yr: float, + ryaw: float) -> List[Tuple[float, float, float]]: + """Every placement the task generator would lay next to a reference. + + Reproduces ``DominoTaskGenerator._place_straight_domino`` / + ``_place_turn90_domino`` exactly -- one ``pos_gap`` along a cardinal + travel direction, with 45-deg turn blocks carrying the generator's + half-width inward side offset -- expressed relative to a reference domino + at ``(xr, yr, ryaw)``. Each returned ``(cx, cy, cyaw)`` is a valid + ``InFront`` placement off the reference. + + A cardinal reference yields, for each of the two chain (forward / backward) + directions, the straight successor and the two turn-start (``d1``) blocks + (left / right). A non-cardinal reference -- an already-placed 45-deg + turn-start block -- yields the turn-completing (``d2``) block that bends + the chain the rest of the way through the corner. + """ + gap = _DOMINO_POS_GAP + s_off = -_DOMINO_WIDTH / 2 # generator's d1_side_offset / side_offset + out: List[Tuple[float, float, float]] = [] + if _is_cardinal(ryaw): + for rotation in (ryaw, wrap_angle(ryaw + np.pi)): + # Straight successor: one gap along travel, same (box) yaw. + out.append( + (xr + gap * np.sin(rotation), yr + gap * np.cos(rotation), + wrap_angle(ryaw))) + # Turn-start (d1): one gap ahead, nudged a half width orthogonal + # to the post-turn travel direction, yaw stepped +-45. 
+ for turn in (1.0, -1.0): + d1_dir = wrap_angle(rotation - turn * np.pi / 4) + cx = xr + gap * np.sin(rotation) + turn * s_off * np.cos( + d1_dir) + cy = yr + gap * np.cos(rotation) - turn * s_off * np.sin( + d1_dir) + out.append((cx, cy, wrap_angle(ryaw + turn * np.pi / 4))) + else: + # Turn-completing block (d2) off an already-placed turn-start block. + # Take whichever turn sign(s) leave the pre-turn travel cardinal. + for turn in (1.0, -1.0): + base = wrap_angle(ryaw - turn * np.pi / 4) + if not _is_cardinal(base): + continue + d1_dir = wrap_angle(base - turn * np.pi / 4) + d2_rot = wrap_angle(base - turn * np.pi / 2) + cx = xr + gap * np.sin(d1_dir) + turn * s_off * np.cos(d2_rot) + cy = yr + gap * np.cos(d1_dir) - turn * s_off * np.sin(d2_rot) + out.append((cx, cy, wrap_angle(base + turn * np.pi / 2))) + return out + + +def _is_target_domino(state: State, domino: Object) -> bool: + """Check whether ``domino`` has the target-block color.""" + return all( + abs(state.get(domino, feat) - val) < _DOMINO_COLOR_EPS + for feat, val in zip(("r", "g", "b"), _DOMINO_TARGET_COLOR)) + + +def _future_target_bridge_score(state: State, held: Object, hx: float, + hy: float, hyaw: float) -> float: + """Tie-break score for placements that can be completed to a target. + + The immediate ``InFront(held, ref)`` subgoal underdetermines which + side of the start domino to place the bridge on. Prefer placements + for which one additional domino can be placed at the intersection of + generator-faithful successors from the held domino and from a purple + target domino. This keeps the sampler from spending most refinement + attempts on locally valid but globally dead first placements. 
+ """ + dominoes = [o for o in state if o.type.name == "domino" and o is not held] + targets = [d for d in dominoes if _is_target_domino(state, d)] + if not targets: + return 0.0 + held_next = _generator_placements(hx, hy, hyaw) + if not held_next: + return 0.0 + best_resid = float("inf") + yaw_scale = _DOMINO_POS_GAP / np.pi + for target in targets: + tx = state.get(target, "x") + ty = state.get(target, "y") + tyaw = state.get(target, "yaw") + for hx2, hy2, hyaw2 in held_next: + for tx2, ty2, tyaw2 in _generator_placements(tx, ty, tyaw): + yaw_resid = abs(wrap_angle(hyaw2 - tyaw2)) * yaw_scale + resid = float(np.hypot(hx2 - tx2, hy2 - ty2) + yaw_resid) + best_resid = min(best_resid, resid) + if best_resid == float("inf"): + return 0.0 + return -best_resid + + +def _place_option_sampler(state: State, subgoal_atoms: Set[GroundAtom], + rng: np.random.Generator, + objects: Sequence[Object]) -> Array: + """Grid-free Place sampler that draws a generator-faithful placement. + + Builds the discrete set of placements the task generator could lay next + to each reference domino named in an ``InFront`` subgoal -- straight, or a + 45-deg left / right turn block, in either chain direction (see + ``_generator_placements``) -- scores each by how many of the step's + subgoal atoms it satisfies, and draws one uniformly at random from those + tied for the best score. Randomizing (rather than always returning the + first / straight placement) is what lets backtracking that re-draws this + step reach a turn when the lone subgoal (e.g. ``InFront(d1, d0)``) is + satisfied equally by straight and by a turn and a later step needs the + bend. No jitter is added -- the generator placements are already the + exact, cascade-tuned poses. Raises (so refinement falls back to uniform) + when the held domino or a usable reference can't be found. 
+ """ + del objects + dominoes = [o for o in state if o.type.name == "domino"] + held = [d for d in dominoes if state.get(d, "is_held") > 0.5] + if len(held) != 1: + raise ValueError(f"expected one held domino, found {len(held)}") + held_d = held[0] + + refs = [] + for atom in subgoal_atoms: + if atom.predicate.name != "InFront": + continue + d1, d2 = atom.objects + if held_d is d1 and held_d is not d2: + refs.append(d2) + elif held_d is d2 and held_d is not d1: + refs.append(d1) + if not refs: + raise ValueError("no InFront subgoal references the held domino") + + # Collect every generator-faithful candidate, scored by how many of the + # step's subgoal atoms it satisfies. The candidates come straight from the + # task generator's geometry, so each is a valid InFront placement off its + # reference and the set is exactly what the generator could have laid. + candidates: List[Tuple[int, float, float, float, float]] = [] + for ref in refs: + xr = state.get(ref, "x") + yr = state.get(ref, "y") + rot = state.get(ref, "yaw") + for cx, cy, cyaw in _generator_placements(xr, yr, rot): + score = _score_placement(state, subgoal_atoms, held_d, cx, cy, + cyaw) + future_score = _future_target_bridge_score(state, held_d, cx, cy, + cyaw) + candidates.append((score, future_score, cx, cy, cyaw)) + if not candidates: + raise ValueError("no usable reference domino for placement") + + # Randomize among the placements tied for the best score, so backtracking + # that re-draws this step explores a turn instead of always returning the + # straight pose. Score alone disambiguates: a multi-edge step (a second + # InFront naming the next block) is satisfied only by the turn block that + # bends toward it, which no straight placement matches. 
+ best_score = max(c[0] for c in candidates) + best_future_score = max(c[1] for c in candidates if c[0] == best_score) + tied = [ + c for c in candidates + if c[0] == best_score and abs(c[1] - best_future_score) < 1e-9 + ] + _, _, cx, cy, cyaw = tied[int(rng.integers(len(tied)))] + return np.array([cx, cy, _DOMINO_DROP_Z, cyaw], dtype=np.float32) + + +class PyBulletDominoGroundTruthSamplerFactory(GroundTruthSamplerFactory): + """Ground-truth grid-free per-skill samplers for the domino env.""" + + @classmethod + def get_env_names(cls) -> Set[str]: + return {"pybullet_domino_grid", "pybullet_domino"} + + @classmethod + def get_samplers(cls, env_name: str) -> Dict[str, OptionSampler]: + del env_name + return { + "Pick": _pick_option_sampler, + "Push": _push_option_sampler, + "Place": _place_option_sampler, + } diff --git a/predicators/ground_truth_models/domino/types.py b/predicators/ground_truth_models/domino/types.py index 767705c66..7fe05d0a4 100644 --- a/predicators/ground_truth_models/domino/types.py +++ b/predicators/ground_truth_models/domino/types.py @@ -6,8 +6,7 @@ from predicators.envs.pybullet_domino.components.domino_component import \ DominoComponent -from predicators.envs.pybullet_domino.composed_env import \ - PyBulletDominoComposedEnv +from predicators.envs.pybullet_domino.env import PyBulletDominoComposedEnv from predicators.ground_truth_models import GroundTruthTypeFactory from predicators.structs import Object, Task, Type from predicators.utils import PyBulletState @@ -29,16 +28,14 @@ def get_helper_types(cls, env_name: str) -> Set[Type]: """ del env_name # unused - # Position type with xx, yy coordinates - position_type = Type("loc", ["xx", "yy"]) - - # Angle type for discrete rotations - angle_type = Type("angle", ["angle"]) - - # Direction type for sequence generation - direction_type = Type("direction", ["dir"]) - - return {position_type, angle_type, direction_type} + # The grid types (loc/angle/direction) are defined canonically by + # 
GridComponent; delegate so there is a single source of truth. Only + # oracle / process-planning approaches request these helpers (and the + # grid predicates and processes built on them); the oracle does so + # unconditionally, so the grid is intrinsic to it and needs no flag. + from predicators.envs.pybullet_domino.components.grid_component import \ + GridComponent # pylint: disable=import-outside-toplevel + return GridComponent().get_types() @classmethod def augment_task_with_helper_objects(cls, task: Task) -> Task: diff --git a/predicators/ground_truth_models/skill_factories/base.py b/predicators/ground_truth_models/skill_factories/base.py index d4f17d86b..159693398 100644 --- a/predicators/ground_truth_models/skill_factories/base.py +++ b/predicators/ground_truth_models/skill_factories/base.py @@ -6,8 +6,8 @@ import logging from dataclasses import dataclass, field from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, \ - Sequence, Tuple, cast +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, \ + Optional, Sequence, Tuple, cast if TYPE_CHECKING: from predicators.envs.pybullet_env import PyBulletEnv @@ -18,6 +18,7 @@ from predicators import utils from predicators.pybullet_helpers.controllers import \ + _build_action_from_joints, _robot_supports_base_action, \ get_change_fingers_action, get_move_end_effector_to_pose_action from predicators.pybullet_helpers.geometry import Pose from predicators.pybullet_helpers.inverse_kinematics import \ @@ -82,6 +83,14 @@ class SkillConfig: after push skills. Required by ``create_push_skill``. transport_z: Safe Z height for transit above obstacles during pick, place, push, and pour skills. Default ``0.7``. 
+ base_standoff: For mobile-base robots, the forward (y) distance at + which the base parks in front of a reach target (with its x aligned + to the target x), so the arm reaches it straight forward at a + comfortable distance instead of sideways over the burner/a jug or + fully extended. ``None`` (default) disables base positioning; only + mobile robots use it. + base_y_max: Upper bound on the base y while positioning, to keep the + base clear of the table front. Default ``inf`` (no clamp). extra: Arbitrary dict for environment-specific constants that callbacks may need. Access via ``config.extra["key"]``. """ @@ -99,6 +108,10 @@ class SkillConfig: robot_init_wrist: float = 0.0 robot_home_pos: Optional[Tuple[float, float, float]] = None transport_z: float = 0.7 + base_standoff: Optional[float] = None + base_y_max: float = float("inf") + base_align_x: bool = True + base_home_xy: Optional[Tuple[float, float]] = None simulator: Optional[PyBulletEnv] = None collision_skip_types: Tuple[str, ...] = () sim_extra_collision_bodies: Tuple[int, ...] = () @@ -184,6 +197,15 @@ class Phase: use_motion_planning: bool = field( default_factory=lambda: CFG.skill_phase_use_motion_planning) expect_contact: bool = False + allow_shallow_held_object_contacts: bool = False + # Force validated (iterative) IK for this phase's BiRRT goal pose, even + # when CFG.pybullet_ik_validate is False. Unvalidated IK can return a goal + # config whose EE pose is numerically close but whose gripper slightly + # penetrates the very object being approached (the grasp target), making + # BiRRT reject an otherwise-reachable grasp. Validating only this phase's + # goal fixes that without the cost/regressions of globally validating + # every transport/retreat IK. 
+ validate_ik: bool = False class PhaseSkill: @@ -209,7 +231,8 @@ def __init__(self, params_space: Box, config: SkillConfig, phases: List[Phase], - params_description: Optional[Tuple[str, ...]] = None) -> None: + params_description: Optional[Tuple[str, ...]] = None, + base_mode: Optional[str] = None) -> None: assert len(phases) > 0 self._name = name self._types = types @@ -217,6 +240,14 @@ def __init__(self, self._config = config self._phases = phases self._params_description = params_description + # Mobile-base positioning mode for this skill (None disables it): + # "home" park at the robot's home base (good offset to press a + # switch; diagonal fixed-base reach for far targets). + # "align_left" slide base x toward the target but not right of home + # (frees the over-the-burner reach), forward in y. + # "diag" keep base x at home, move forward in y (diagonal carry + # that clears an adjacent jug / the faucet body). + self._base_mode = base_mode def build(self) -> ParameterizedOption: """Build and return the ParameterizedOption.""" @@ -344,12 +375,129 @@ def _ik_phase_is_terminal(self, phase: Phase, state: State, def _execute_move(self, phase: Phase, state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: - """Dispatch to BiRRT or incremental IK based on phase flag.""" + """Dispatch to BiRRT or incremental IK based on phase flag. + + For mobile-base robots, first drive the base to a pose that puts + the reach target in comfortable arm range (the arm BiRRT/IK then + plans from the repositioned base). + """ + base_action = self._maybe_drive_base(phase, state, memory, objects, + params) + if base_action is not None: + return base_action if phase.use_motion_planning: return self._execute_move_birrt(phase, state, memory, objects, params) return self._execute_move_ik(phase, state, objects, params) + # Mobile-base positioning. 
Before the first reach of an option, drive the + # (kinematic) base to park `base_standoff` in front of the reach target with + # its x aligned to the target x (base y clamped to base_y_max to stay clear + # of the table), so the arm reaches *straight forward at a comfortable + # distance* rather than sideways/over the burner or fully extended. The base + # pose is a deterministic function of the option params, so it is + # reproducible across refinement samples (unlike a per-sample search) and + # adds a few small base-drive steps per option. Enabled per-env by setting + # base_standoff; only active for mobile robots (e.g. mobile_fetch), a no-op + # for fixed bases. + _base_pos_tol: ClassVar[float] = 0.02 # xy tol to call the base positioned + _base_step: ClassVar[float] = 0.08 # max base xy move per step (smooth) + + def _maybe_drive_base(self, phase: Phase, state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> Optional[Action]: + """Return a one-step base-drive Action that stands the base in front of + this option's reach target; None once positioned (or for fixed-base + robots / when base positioning is disabled).""" + robot = self._config.robot + if self._config.base_standoff is None \ + or self._base_mode is None \ + or not _robot_supports_base_action(robot): + return None + pb_state = cast(utils.PyBulletState, state) + sim_state = pb_state.simulator_state + if not isinstance(sim_state, dict) or "base_pose" not in sim_state: + return None + if memory.get("_base_pos_done", False): + return None + (cur_x, cur_y, _), _ = sim_state["base_pose"] + home_xy = self._config.base_home_xy + if self._base_mode == "home" and home_xy is not None: + # Push: park at the robot's home base, which sits diagonally off the + # switch (offset opposite the push direction and in front) so the + # arm presses it naturally. Head-on (x-aligned) pins the arm near a + # singularity and makes the push wander off target. 
+ target_bx, target_by = home_xy + else: + _, target_pose, _ = phase.target_fn(state, objects, params, + self._config) + home_x = home_xy[0] if home_xy is not None else ( + self._config.robot_home_pos[0] + if self._config.robot_home_pos is not None else float(cur_x)) + stay_home = False + if self._base_mode == "align_left": + # Pick: slide x toward the target but never to the right of + # home. The over-the-burner reach only happens for targets left + # of home; right targets (front jug, jug under the faucet) keep + # home's diagonal approach, which clears the faucet body. + target_bx = min(float(target_pose.position[0]), home_x) + elif self._base_mode == "approach": + # Pick a jug that may sit beside another jug (the 2-jug boil + # tasks). Reposition only when a second jug actually blocks the + # reach -- one sitting close to the target in both x and y, so + # reaching it from home would sweep the arm across it (the jug0- + # vs-jug1 grasp/lift collision a fixed base cannot avoid). Then + # stand to the target's far side from that jug, offset laterally + # (NOT x-aligned, which would pin this arm at a + # singularity -- see the "home" push note). With no blocker, + # keep home's diagonal approach: moving the base in would only + # risk that singularity (e.g. re-picking a jug under the + # faucet, which has no neighbor). 
+ tx = float(target_pose.position[0]) + ty = float(target_pose.position[1]) + blocker_x: Optional[float] = None + for other in state: + if other.type.name != "jug" or other in objects: + continue + ox = float(state.get(other, "x")) + oy = float(state.get(other, "y")) + if abs(ox - tx) < 0.4 and abs(oy - ty) < 0.4: + blocker_x = ox + break + if blocker_x is None: + target_bx = home_x + stay_home = True + else: + side = 1.0 if tx >= blocker_x else -1.0 + target_bx = tx + side * 0.15 + else: + # Place ("diag"): keep base x at home and only move forward in + # y, so the carry stays diagonal (clearing an adjacent jug or + # the faucet body) yet close enough for a comfortable reach. + target_bx = home_x + if stay_home: + # No reposition needed: return to (or stay at) the home base so + # the reach keeps home's well-conditioned diagonal geometry. + target_by = home_xy[1] if home_xy is not None else float(cur_y) + else: + target_by = min( + float(target_pose.position[1]) - + self._config.base_standoff, self._config.base_y_max) + dx, dy = target_bx - cur_x, target_by - cur_y + dist = float(np.hypot(dx, dy)) + if dist < self._base_pos_tol: + memory["_base_pos_done"] = True + return None + # Move the base toward the target in small increments rather than one + # teleport, so a held jug follows the grasp constraint smoothly instead + # of being yanked across the jump (which destabilizes the carry). 
+ if dist > self._base_step: + dx *= self._base_step / dist + dy *= self._base_step / dist + base_delta = np.array([dx, dy, 0.0], dtype=np.float32) + return _build_action_from_joints(robot, pb_state.joint_positions, + base_delta) + def _execute_move_birrt(self, phase: Phase, state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: @@ -385,7 +533,8 @@ def _execute_move_birrt(self, phase: Phase, state: State, memory: Dict, if self._config.simulator is not None: traj = self._plan_with_simulator(pb_state, target_pose, phase.name, - phase.expect_contact) + phase.expect_contact, objects, + phase) else: traj = self._plan_without_simulator(pb_state, target_pose, phase.name) @@ -450,12 +599,11 @@ def _execute_move_birrt(self, phase: Phase, state: State, memory: Dict, joint_action[finger_idx_l] = f_action joint_action[finger_idx_r] = f_action - action_arr = np.clip( - np.array(joint_action, dtype=np.float32), - robot.action_space.low, - robot.action_space.high, - ) - return Action(action_arr) + # _build_action_from_joints pads zero base deltas for mobile robots + # (BiRRT replays a fixed-base arm trajectory) and is a no-op clip for + # fixed-base robots, keeping the action shape matched to the robot's + # action space. + return _build_action_from_joints(robot, joint_action) # ------------------------------------------------------------------ # BiRRT planning helpers @@ -530,6 +678,8 @@ def _plan_with_simulator( target_pose: Pose, phase_name: str, expect_contact: bool = False, + objects: Sequence[Object] = (), + phase: Optional[Phase] = None, ) -> Optional[Sequence[JointPositions]]: """Plan using the simulator env for collision-aware motion planning. @@ -537,6 +687,7 @@ def _plan_with_simulator( the simulator, collects collision body IDs, and runs IK + BiRRT on the simulator's physics client. """ + del objects # Currently unused; kept for caller-signature parity. 
sim = self._config.simulator assert sim is not None @@ -590,11 +741,22 @@ def _plan_with_simulator( # 5. IK + motion planning on simulator's robot planning_robot = sim._pybullet_robot # pylint: disable=protected-access planning_robot.set_joints(pb_state.joint_positions) + + # Compute base_link_to_held_obj if an object is held (needed both for + # motion planning and the collision-aware IK below). + base_link_to_held_obj = None + if held_object is not None and sim._held_obj_to_base_link is not None: # pylint: disable=protected-access + base_link_to_held_obj = p.invertTransform( + *sim._held_obj_to_base_link) # pylint: disable=protected-access + + # Validate the goal IK when globally enabled, or when this phase + # requests it (e.g. a grasp approach, where an imprecise goal config + # clips the target object and BiRRT then rejects a reachable grasp). + validate_goal_ik = self._config.ik_validate or (phase is not None + and phase.validate_ik) try: target_joints: JointPositions = planning_robot.inverse_kinematics( - target_pose, - validate=self._config.ik_validate, - set_joints=True) + target_pose, validate=validate_goal_ik, set_joints=True) except InverseKinematicsError: pos = target_pose.position logging.warning( @@ -603,12 +765,6 @@ def _plan_with_simulator( "falling back to incremental IK.") return None - # Compute base_link_to_held_obj if an object is held. 
- base_link_to_held_obj = None - if held_object is not None and sim._held_obj_to_base_link is not None: # pylint: disable=protected-access - base_link_to_held_obj = p.invertTransform( - *sim._held_obj_to_base_link) # pylint: disable=protected-access - traj = run_motion_planning( robot=planning_robot, initial_positions=pb_state.joint_positions, @@ -618,8 +774,43 @@ def _plan_with_simulator( physics_client_id=sim._physics_client_id, # pylint: disable=protected-access held_object=held_object, base_link_to_held_obj=base_link_to_held_obj, + allow_shallow_held_object_contacts=( + phase.allow_shallow_held_object_contacts + if phase is not None else False), ) + if traj is None and not validate_goal_ik: + # A single unvalidated PyBullet IK call can return a joint + # configuration whose EE pose is close enough numerically but whose + # carried object is in collision. Before declaring the option + # infeasible, retry with validated IK, which iterates to a better + # Cartesian target solution while preserving the fast path for the + # common case. 
+ sim._set_state(remapped_state) # pylint: disable=protected-access + planning_robot.set_joints(pb_state.joint_positions) + try: + validated_target_joints = \ + planning_robot.inverse_kinematics( + target_pose, validate=True, set_joints=True) + except InverseKinematicsError: + validated_target_joints = None + if validated_target_joints is not None: + traj = run_motion_planning( + robot=planning_robot, + initial_positions=pb_state.joint_positions, + target_positions=validated_target_joints, + collision_bodies=collision_bodies, + seed=CFG.seed, + physics_client_id=sim._physics_client_id, # pylint: disable=protected-access + held_object=held_object, + base_link_to_held_obj=base_link_to_held_obj, + allow_shallow_held_object_contacts=( + phase.allow_shallow_held_object_contacts + if phase is not None else False), + ) + if traj is not None: + target_joints = validated_target_joints + if traj is None and not expect_contact: self._log_collision_diagnostics( planning_robot, @@ -677,15 +868,19 @@ def _check(joints: JointPositions, label: str) -> None: body, physicsClientId=physics_client_id) if any(c[8] < margin for c in contacts): + min_dist = min(c[8] for c in contacts) logging.error(f"[{self._name}/{phase_name}] {label} ROBOT " - f"collision with body {body} ({body_name})") + f"collision with body {body} ({body_name}); " + f"min contact distance {min_dist:.6f}") if held_object is not None: contacts = p.getContactPoints( held_object, body, physicsClientId=physics_client_id) if any(c[8] < margin for c in contacts): + min_dist = min(c[8] for c in contacts) logging.error( f"[{self._name}/{phase_name}] {label} HELD " - f"collision with body {body} ({body_name})") + f"collision with body {body} ({body_name}); " + f"min contact distance {min_dist:.6f}") _check(start_joints, "START") _check(goal_joints, "GOAL") @@ -709,6 +904,10 @@ def _execute_move_ik(self, phase: Phase, state: State, finger_action_nudge_magnitude=( self._config.finger_action_nudge_magnitude), 
validate=self._config.ik_validate, + # Base positioning is handled once per option by + # _maybe_drive_base; keep incremental IK arm-only so the base + # doesn't drift during contact phases (e.g. a switch push). + move_base=False, ) except utils.OptionExecutionFailure: cur = current_pose.position diff --git a/predicators/ground_truth_models/skill_factories/move_to.py b/predicators/ground_truth_models/skill_factories/move_to.py index 08d3e21bb..25af34339 100644 --- a/predicators/ground_truth_models/skill_factories/move_to.py +++ b/predicators/ground_truth_models/skill_factories/move_to.py @@ -101,6 +101,8 @@ def make_move_to_phase( get_target_pose_fn: TargetPoseFn, finger_status: Optional[str] = None, expect_contact: bool = False, + allow_shallow_held_object_contacts: bool = False, + validate_ik: bool = False, ) -> Phase: """Create a MOVE_TO_POSE phase for use in a ``PhaseSkill``. @@ -166,4 +168,6 @@ def _target_fn( action_type=PhaseAction.MOVE_TO_POSE, target_fn=_target_fn, expect_contact=expect_contact, + allow_shallow_held_object_contacts=allow_shallow_held_object_contacts, + validate_ik=validate_ik, ) diff --git a/predicators/ground_truth_models/skill_factories/pick.py b/predicators/ground_truth_models/skill_factories/pick.py index 7c53f765e..bc143531b 100644 --- a/predicators/ground_truth_models/skill_factories/pick.py +++ b/predicators/ground_truth_models/skill_factories/pick.py @@ -140,7 +140,13 @@ def _slight_lift_pose( phases = [] phases.extend([ make_move_to_phase("MoveAbove", _above_pose, "closed"), - make_move_to_phase("MoveToGrasp", _descend_pose, "open"), + # Validate the grasp goal IK: the gripper descends to envelop the + # target, and an imprecise (unvalidated) IK config can clip the target + # object, making BiRRT reject a reachable grasp. See Phase.validate_ik. 
+ make_move_to_phase("MoveToGrasp", + _descend_pose, + "open", + validate_ik=True), Phase( name="Grasp", action_type=PhaseAction.CHANGE_FINGERS, @@ -148,7 +154,10 @@ def _slight_lift_pose( terminal_fn=None, finger_direction="close", ), - make_move_to_phase("LiftSlightly", _slight_lift_pose, "closed") + make_move_to_phase("LiftSlightly", + _slight_lift_pose, + "closed", + allow_shallow_held_object_contacts=True) ]) return PhaseSkill(name, @@ -156,4 +165,5 @@ def _slight_lift_pose( params_space, config, phases, - params_description=params_description).build() + params_description=params_description, + base_mode="home").build() diff --git a/predicators/ground_truth_models/skill_factories/place.py b/predicators/ground_truth_models/skill_factories/place.py index 502120636..2ffad692e 100644 --- a/predicators/ground_truth_models/skill_factories/place.py +++ b/predicators/ground_truth_models/skill_factories/place.py @@ -139,4 +139,5 @@ def _drop_pose( params_space, config, phases, - params_description=params_description).build() + params_description=params_description, + base_mode="home").build() diff --git a/predicators/ground_truth_models/skill_factories/push.py b/predicators/ground_truth_models/skill_factories/push.py index d03db748b..bf673eb25 100644 --- a/predicators/ground_truth_models/skill_factories/push.py +++ b/predicators/ground_truth_models/skill_factories/push.py @@ -207,4 +207,5 @@ def _get_target( params_space, config, phases, - params_description=params_description).build() + params_description=params_description, + base_mode="home").build() diff --git a/predicators/ground_truth_models/skill_factories/wait.py b/predicators/ground_truth_models/skill_factories/wait.py index b2c7c0042..678a43b90 100644 --- a/predicators/ground_truth_models/skill_factories/wait.py +++ b/predicators/ground_truth_models/skill_factories/wait.py @@ -70,12 +70,18 @@ def _policy(state: State, memory: Dict, objects: Sequence[Object], joint_positions[robot.left_finger_joint_idx] = 
f_action joint_positions[robot.right_finger_joint_idx] = f_action + # Pad base-action dims with zeros for mobile robots so the action + # matches the (arm + base) action space; a no-op for fixed bases. + action_arr = np.array(joint_positions, dtype=np.float32) + n_action = robot.action_space.shape[0] + if action_arr.shape[0] < n_action: + action_arr = np.concatenate([ + action_arr, + np.zeros(n_action - action_arr.shape[0], dtype=np.float32) + ]) return Action( - np.clip( - np.array(joint_positions, dtype=np.float32), - robot.action_space.low, - robot.action_space.high, - )) + np.clip(action_arr, robot.action_space.low, + robot.action_space.high)) return ParameterizedOption( name, diff --git a/predicators/pybullet_helpers/controllers.py b/predicators/pybullet_helpers/controllers.py index 6e3bc9292..50c2eb570 100644 --- a/predicators/pybullet_helpers/controllers.py +++ b/predicators/pybullet_helpers/controllers.py @@ -85,12 +85,19 @@ def get_move_end_effector_to_pose_action( max_vel_norm: float, finger_action_nudge_magnitude: float, validate: bool = True, + move_base: bool = True, ) -> Action: """Get an action for moving the end effector to a target pose. See create_move_end_effector_to_pose_option() for more info. + + For mobile-base robots the base is also driven toward the target by + default. Callers that position the base separately (e.g. the skill + factories' ``_maybe_drive_base``) should pass ``move_base=False`` to keep + this purely an arm motion -- otherwise the base would drift during delicate + incremental-IK phases such as a switch push. 
""" - if _robot_supports_base_action(robot): + if move_base and _robot_supports_base_action(robot): max_base_vel_norm = getattr(robot, "default_base_vel_norm", max_vel_norm) max_base_rot_vel = getattr(robot, "default_base_rot_vel", max_vel_norm) @@ -211,6 +218,7 @@ def create_move_end_effector_to_pose_option( initiable: ParameterizedInitiable = lambda _1, _2, _3, _4: True, terminal: Optional[ParameterizedTerminal] = None, validate: bool = True, + stall_limit: Optional[int] = None, ) -> ParameterizedOption: """A generic utility that creates a ParameterizedOption for moving the end effector to a target pose, given a function that takes in the current @@ -248,7 +256,6 @@ def _policy(state: State, memory: Dict, objects: Sequence[Object], def _terminal(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> bool: - del memory # unused current_pose, target_pose, _ = \ get_current_and_target_pose_and_finger_status( state, objects, params) @@ -257,7 +264,24 @@ def _terminal(state: State, memory: Dict, objects: Sequence[Object], current = current_pose.position target = target_pose.position squared_dist = np.sum(np.square(np.subtract(current, target))) - return squared_dist < move_to_pose_tol + if squared_dist < move_to_pose_tol: + return True + # When opted in via ``stall_limit``, also terminate once the end + # effector has frozen: incremental IK can plateau a couple of cm + # short of a reach-edge target under the fixed wrist orientation, + # otherwise burning the whole option horizon. The near-target gate + # keeps this from masking genuine far-from-goal failures. 
+ if stall_limit is not None: + last = memory.get("_stall_last_pos") + if last is not None and \ + np.sum(np.square(np.subtract(current, last))) < 1e-8: + memory["_stall_count"] = memory.get("_stall_count", 0) + 1 + else: + memory["_stall_count"] = 0 + memory["_stall_last_pos"] = current + if memory["_stall_count"] >= stall_limit and squared_dist < 0.01: + return True + return False return ParameterizedOption( name, diff --git a/predicators/pybullet_helpers/motion_planning.py b/predicators/pybullet_helpers/motion_planning.py index b2749b10e..bfb605795 100644 --- a/predicators/pybullet_helpers/motion_planning.py +++ b/predicators/pybullet_helpers/motion_planning.py @@ -5,6 +5,7 @@ import numpy as np import pybullet as p +from gym.spaces import Box from numpy.typing import NDArray from predicators import utils @@ -23,13 +24,23 @@ def run_motion_planning( physics_client_id: int, held_object: Optional[int] = None, base_link_to_held_obj: Optional[NDArray] = None, + allow_shallow_held_object_contacts: bool = False, ) -> Optional[Sequence[JointPositions]]: """Run BiRRT to find a collision-free sequence of joint positions. Note that this function changes the state of the robot. """ rng = np.random.default_rng(seed) + # BiRRT plans in the arm-joint space. For mobile robots, action_space also + # includes base-delta dims (appended last); strip them so sampled configs + # match the arm joints that set_joints / forward_kinematics expect. For + # fixed-base robots (base_action_dim == 0) this is a no-op. 
joint_space = robot.action_space + base_dim = int(getattr(robot, "base_action_dim", 0)) + if base_dim > 0: + joint_space = Box(low=np.asarray(joint_space.low[:-base_dim]), + high=np.asarray(joint_space.high[:-base_dim]), + dtype=np.float32) joint_space.seed(seed) num_interp = CFG.pybullet_birrt_extend_num_interp @@ -58,6 +69,20 @@ def _set_state(pt: JointPositions) -> None: world_to_held_obj[1], physicsClientId=physics_client_id) + allowed_shallow_held_collision_bodies = set() + if allow_shallow_held_object_contacts and held_object is not None: + _set_state(initial_positions) + p.performCollisionDetection(physicsClientId=physics_client_id) + shallow_margin = CFG.pybullet_birrt_shallow_held_contact_margin + hard_margin = CFG.pybullet_birrt_contact_margin + for body in collision_bodies: + contacts = p.getContactPoints(held_object, + body, + physicsClientId=physics_client_id) + penetrating = [c[8] for c in contacts if c[8] < hard_margin] + if penetrating and min(penetrating) >= shallow_margin: + allowed_shallow_held_collision_bodies.add(body) + def _extend_fn(pt1: JointPositions, pt2: JointPositions) -> Iterator[JointPositions]: pt1_arr = np.array(pt1) @@ -87,7 +112,14 @@ def _collision_fn(pt: JointPositions) -> bool: if held_object is not None: contacts = p.getContactPoints( held_object, body, physicsClientId=physics_client_id) - if any(c[8] < margin for c in contacts): + contact_distances = [c[8] for c in contacts] + if body in allowed_shallow_held_collision_bodies: + shallow_margin = \ + CFG.pybullet_birrt_shallow_held_contact_margin + if any(d < shallow_margin for d in contact_distances): + return True + continue + if any(d < margin for d in contact_distances): return True return False diff --git a/predicators/settings.py b/predicators/settings.py index b8038d6ef..ca4aa7914 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -197,6 +197,10 @@ class GlobalSettings: pybullet_birrt_extend_num_interp = 10 pybullet_birrt_path_subsample_ratio = 1 
pybullet_birrt_contact_margin = -0.001 + # During a lift after grasping, the held object can start in shallow + # penetration from grasp settling. Allow escaping these initial contacts + # only up to this depth; deeper penetration remains a collision. + pybullet_birrt_shallow_held_contact_margin = -0.003 pybullet_control_mode = "position" pybullet_max_vel_norm = 0.05 # env -> robot -> quaternion @@ -434,7 +438,6 @@ class GlobalSettings: domino_some_dominoes_are_connected = False domino_initialize_at_finished_state = True domino_use_domino_blocks_as_target = False - domino_use_grid = False domino_include_connected_predicate = False domino_has_glued_dominos = True domino_prune_actions = False # Set to True to enable action pruning @@ -530,6 +533,21 @@ class GlobalSettings: boil_num_burner_train = [1] boil_num_burner_test = [1] boil_water_fill_speed = 0.002 + # For the mobile_fetch robot: park the base (x-aligned to each reach + # target, a stand-off in front in y) before reaching, so the arm reaches + # straight forward at a comfortable distance instead of sideways over the + # burner or fully extended. No-op for fixed bases. Set False to disable + # (e.g. to isolate base-positioning effects). + boil_mobile_base_park = True + # Forward (y) stand-off distance for the parked base. Smaller = closer to + # the target = more reach margin (incl. sideways switch push-through), + # bounded by the table-clear y cap. + boil_mobile_base_standoff = 0.45 + # Align the parked base x with the reach target x. True is best for picks + # (straight approach avoids sweeping over the burner); False keeps the base + # at the home x (diagonal approach) which leaves room for sideways switch + # push-throughs. 
+ boil_mobile_base_align_x = True # parameters for random options approach random_options_max_tries = 100 @@ -766,6 +784,16 @@ class GlobalSettings: process_planning_use_abstract_policy = False process_planning_max_policy_guided_rollout = 10 process_planning_set_parameters_one = False + # On an execution-time option failure (e.g. a fresh BiRRT collision caused + # by drift between the refinement simulator and the real environment), + # re-refine from the current state and retry, up to this many times. 0 + # disables replanning (the option failure is terminal, as before). + process_planning_max_execution_replans = 0 + # Whether non-oracle process-planning approaches (process/param learning, + # predicate invention, etc.) augment with the ground-truth helper types, + # predicates, and objects (e.g. the domino grid). The oracle always does; + # the others opt in via this flag (e.g. for ExoPredicator). + process_planning_use_gt_helpers = False process_task_planning_heuristic = 'h_ff' wait_option_terminate_on_atom_change = True running_no_invent_baseline = False @@ -999,7 +1027,7 @@ class GlobalSettings: # agent SDK online abstraction learning parameters agent_sdk_model_name = "claude-sonnet-4-6" - agent_sdk_max_agent_turns_per_iteration = 20 + agent_sdk_max_agent_turns_per_iteration = 50 agent_sdk_agent_timeout = 300 # seconds per iteration agent_sdk_resume_session = True # resume previous session if available agent_sdk_propose_types = True @@ -1026,7 +1054,7 @@ class GlobalSettings: agent_planner_use_visualize_state = False # include visualize_state tool agent_planner_use_annotate_scene = False # include annotate_scene tool # Whether the planner is given a simulator to test candidate plans with - # (the test_option_plan tool / option-model rollouts). When False, the + # (the evaluate_option_plan tool / option-model rollouts). When False, the # agent must plan open-loop from trajectory data and LLM reasoning alone # -- the genuinely model-free baseline. 
agent_planner_use_simulator = True @@ -1044,6 +1072,12 @@ class GlobalSettings: # reseed refinement on the same skeleton before re-querying the agent agent_bilevel_max_refine_retries = 5 agent_bilevel_check_subgoals = True # check subgoal atoms after each step + # When True, close the agent SDK session at the start of each test task + # so every test solve begins with a FRESH conversation (no context from + # earlier test tasks). The sandbox filesystem and learned artifacts are + # untouched. Default False keeps the current behavior: all test tasks + # share one continuous agent conversation. + agent_fresh_session_per_test_task = False # Test-time closed-loop recovery. After each option in the refined plan # finishes, the subgoal_annotations execution monitor checks the # sketch's subgoal annotation for that step against the REAL state; on @@ -1058,6 +1092,8 @@ class GlobalSettings: agent_bilevel_max_execution_replans = 0 # log state pretty_str before/after each step agent_bilevel_log_state = False + # load sketch from file instead of LLM + agent_bilevel_plan_sketch_dir = "plan_sketches" agent_bilevel_plan_sketch_file = "" # load sketch from file instead of LLM # When evaluate_plan_refinement is called without an explicit timeout, # the synthesis tool computes @@ -1122,6 +1158,20 @@ class GlobalSettings: agent_sim_learn_oracle_sim_param_noise_scale = 0.2 # When True, use GT parameter values directly, skipping MCMC fitting. agent_sim_learn_oracle_sim_params = False + # When True, the agent synthesizes per-skill (per-option) samplers that + # aim continuous option parameters at each sketch step's subgoal, instead + # of bilevel refinement drawing them uniformly from the option's box. The + # agent authors a versioned ``samplers.py`` (LEARNED_SAMPLERS keyed by + # option name) and tunes it with the ``evaluate_sampler`` tool. 
Sampler + # learning rides along in the sim/predicate synthesis session when one + # runs (oracle_sim_program=False); when no synthesis session runs + # (oracle_sim_program=True) it gets a dedicated session of its own. + agent_sim_learn_synthesize_samplers = False + # When True (and synthesize_samplers is on), use ground-truth per-skill + # samplers from the env's GroundTruthSamplerFactory instead of having the + # agent learn them — if such samplers exist for the env; otherwise warn + # and fall back to synthesis. Mirrors agent_sim_learn_oracle_sim_program. + agent_sim_learn_oracle_samplers = False # Names of env predicates kept (not stripped) for the # agent_sim_predicate_invention approach. Empty list defers to the diff --git a/predicators/structs.py b/predicators/structs.py index 9499b1b04..67ea23d98 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -1793,6 +1793,7 @@ class LowLevelTrajectory: _train_task_idx: Optional[int] = field(default=None) _source_simulator_version: Optional[str] = field(default=None) _source_predicates_version: Optional[str] = field(default=None) + _source_samplers_version: Optional[str] = field(default=None) def __post_init__(self) -> None: assert len(self._states) == len(self._actions) + 1 @@ -1835,6 +1836,12 @@ def source_predicates_version(self) -> Optional[str]: collected this trajectory, or ``None`` if not tracked.""" return self._source_predicates_version + @property + def source_samplers_version(self) -> Optional[str]: + """Snapshot tag of the per-skill samplers used to generate the plan + that collected this trajectory, or ``None`` if not tracked.""" + return self._source_samplers_version + @dataclass(frozen=True, repr=False, eq=False) class AtomOptionTrajectory: @@ -3263,6 +3270,15 @@ def copy(self) -> _GroundExogenousProcess: NSRTSamplerWithEpsilonIndicator = Callable[ [State, Set[GroundAtom], np.random.Generator, Sequence[Object]], Tuple[Array, bool]] +# Per-skill sampler consulted during bilevel-sketch 
refinement. Shares +# NSRTSampler's call signature (state, atoms, rng, objects) so the two are +# interchangeable, but the GroundAtom set it receives is the step's +# *subgoal* (not the task goal), letting it aim continuous params at the +# subgoal instead of drawing uniformly. Returns a params array matching the +# option's params_space; refinement clips it to that box and falls back to +# uniform on a wrong-shaped return. +OptionSampler = Callable[ + [State, Set[GroundAtom], np.random.Generator, Sequence[Object]], Array] Metrics = DefaultDict[str, float] LiftedOrGroundAtom = TypeVar("LiftedOrGroundAtom", LiftedAtom, GroundAtom, _Atom) diff --git a/predicators/utils.py b/predicators/utils.py index c50660f88..bd2b55c2c 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2889,6 +2889,23 @@ def create_vlm_by_name( f"{CFG.pretrained_model_service_provider}") +def strip_enumeration_prefix(line: str) -> str: + """Strip a leading list-enumeration prefix like ``0:``, ``1.``, ``2)``. + + Agents sometimes number plan/sketch lines, mirroring the numbered + format the system itself prints in logs and prior-failure previews + (e.g. ``0: Pick(robot:robot, block:block)``). The option-plan parser + keys on the option name being the first token of the line, so an + unstripped number prefix turns ``0: Pick(...)`` into the bogus token + ``"0: Pick"`` and the whole plan parses as empty. Stripping is + deliberately conservative: it matches only a leading run of digits + followed by one of ``:.)`` so prose bullets like ``- Step 1:`` are + left untouched (their option name is still not the first token, so + they are correctly ignored as preamble). 
+ """ + return re.sub(r'^\s*\d+\s*[:.)]\s*', '', line) + + def parse_model_output_into_option_plan( model_prediction: str, objects: Collection[Object], types: Collection[Type], options: Collection[ParameterizedOption], @@ -2912,7 +2929,11 @@ def parse_model_output_into_option_plan( obj_name_to_obj = {o.name: o for o in objects} options_str_list = model_prediction.split('\n') for option_str in options_str_list: - option_str_stripped = option_str.strip() + # Tolerate a leading enumeration prefix ("0:", "1.", "2)") that + # agents emit when mirroring the numbered sketch format shown in + # logs; without this the bogus first token makes the plan parse + # as empty. + option_str_stripped = strip_enumeration_prefix(option_str.strip()) option_name = option_str_stripped.split('(')[0] # Skip empty option strs. if not option_str: diff --git a/scripts/configs/ExoPredicator/causal_predicator.yaml b/scripts/configs/ExoPredicator/causal_predicator.yaml index 2721749e9..d1349b5aa 100644 --- a/scripts/configs/ExoPredicator/causal_predicator.yaml +++ b/scripts/configs/ExoPredicator/causal_predicator.yaml @@ -378,7 +378,7 @@ ENVS: horizon: 200 domino_initialize_at_finished_state: False domino_use_domino_blocks_as_target: True - domino_use_grid: True + process_planning_use_gt_helpers: True domino_include_connected_predicate: False # necessary to generate valid plan domino_prune_actions: False process_planning_heuristic_weight: 2.0 diff --git a/scripts/configs/ExoPredicator/causal_predicator_baselines.yaml b/scripts/configs/ExoPredicator/causal_predicator_baselines.yaml index e7dffa44e..677e72f2e 100644 --- a/scripts/configs/ExoPredicator/causal_predicator_baselines.yaml +++ b/scripts/configs/ExoPredicator/causal_predicator_baselines.yaml @@ -176,7 +176,7 @@ ENVS: horizon: 200 domino_initialize_at_finished_state: False domino_use_domino_blocks_as_target: True - domino_use_grid: True + process_planning_use_gt_helpers: True domino_include_connected_predicate: False # necessary to 
generate valid plan domino_prune_actions: False process_planning_heuristic_weight: 2 # too large will generate suboptimal plans in some cases (e.g. w 2 moveble and 2 targets) diff --git a/scripts/configs/ExoPredicator/mara_bench.yaml b/scripts/configs/ExoPredicator/mara_bench.yaml index 786c34890..cc57331f8 100644 --- a/scripts/configs/ExoPredicator/mara_bench.yaml +++ b/scripts/configs/ExoPredicator/mara_bench.yaml @@ -269,7 +269,7 @@ ENVS: # option_model_terminate_on_repeat: False domino_initialize_at_finished_state: False domino_use_domino_blocks_as_target: True - domino_use_grid: True + process_planning_use_gt_helpers: True domino_include_connected_predicate: False # necessary to generate valid plan domino_prune_actions: False domino_num_dominos_max: 3 diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index 07e158c0a..9bd564514 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -4,8 +4,26 @@ includes: - common.yaml - envs/all.yaml +# Agent-only env overrides: deep-merged on top of envs/all.yaml. These +# excluded_predicates are dropped for "ours" runs but kept for oracle.yaml. 
+ENVS: + domino: + FLAGS: + excluded_predicates: "InitialBlock,MovableBlock,Tilting,Upright" APPROACHES: - # agent_planner: + # # Baseline: agent planning does NOT have a simulator / world model + # agent_model_free_planning: + # NAME: "agent_planner" + # FLAGS: + # explorer: "agent_plan" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_planner_use_scratchpad: False + # agent_planner_use_simulator: False + # # Baseline: ground truth monolithic sim + agent planner + # agent_model_based_planning: # NAME: "agent_planner" # FLAGS: # explorer: "agent_plan" @@ -13,11 +31,11 @@ APPROACHES: # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 # agent_planner_use_scratchpad: False + # agent_planner_use_simulator: True # agent_planner_use_visualize_state: True # agent_planner_use_annotate_scene: True - # option_model_use_gui: True + # Oracle: ground truth monolithic sim + predicates + our planning pipeline # agent_bilevel: # NAME: "agent_bilevel" # FLAGS: @@ -26,13 +44,57 @@ APPROACHES: # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 # agent_planner_use_scratchpad: False # agent_planner_use_visualize_state: True # agent_planner_use_annotate_scene: True # option_model_use_gui: True # agent_bilevel_log_state: False # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + # Oracle: hybrid sim + agent_oracle_hybrid_sim_oracle_samplers_demo: + NAME: "agent_sim_learning" + FLAGS: + demonstrator: "oracle_process_planning" + bilevel_plan_without_sim: True # for the demonstrator + explorer: "agent_bilevel" + 
terminate_on_goal_reached_and_option_terminated: True + agent_sdk_use_local_sandbox: True + option_model_terminate_on_repeat: False + agent_planner_use_visualize_state: True + agent_planner_use_annotate_scene: True + option_model_use_gui: False + agent_bilevel_log_state: False + agent_sim_learn_oracle_sim_program: True + agent_sim_learn_oracle_sim_params: True + agent_sim_learn_synthesize_samplers: True + agent_sim_learn_oracle_samplers: True + num_online_learning_cycles: 0 + agent_explorer_info_seeking: True + execution_monitor: "subgoal_annotations" + agent_bilevel_max_execution_replans: 2 + domino_restricted_push: False + # agent_oracle_hybrid_sim_no_oracle_samplers: + # NAME: "agent_sim_learning" + # FLAGS: + # demonstrator: "oracle_process_planning" + # bilevel_plan_without_sim: True # for the demonstrator + # explorer: "agent_bilevel" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True + # option_model_use_gui: False + # agent_bilevel_log_state: False + # agent_sim_learn_oracle_sim_program: True + # agent_sim_learn_oracle_sim_params: True + # agent_sim_learn_synthesize_samplers: False + # agent_sim_learn_oracle_samplers: False + # num_online_learning_cycles: 0 + # agent_explorer_info_seeking: True + # execution_monitor: "subgoal_annotations" + # agent_bilevel_max_execution_replans: 2 + # Oracle: ground truth hybrid sim / predicates; learn params # agent_param_learning: # NAME: "agent_sim_learning" # FLAGS: @@ -41,17 +103,16 @@ APPROACHES: # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 # agent_planner_use_visualize_state: True # agent_planner_use_annotate_scene: True # option_model_use_gui: True # agent_bilevel_log_state: False # 
agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" - # skip_test_until_last_ite_or_early_stopping: False # agent_sim_learn_oracle_sim_program: True # agent_sim_learn_oracle_sim_params: False # agent_sim_learn_oracle_sim_param_noise_scale: 1.0 # 1.0 allows successful planning but insatisficing plan; 0.8 produces satisficing plan # code_sim_learning_num_mcmc_steps: 0 + # Oracle: ground truth predicates; learn hybrid sim and params # agent_rule_learning: # NAME: "agent_sim_learning" # FLAGS: @@ -60,17 +121,16 @@ APPROACHES: # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 # agent_planner_use_visualize_state: True # agent_planner_use_annotate_scene: True # option_model_use_gui: True # agent_bilevel_log_state: False # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" - # skip_test_until_last_ite_or_early_stopping: False # agent_sim_learn_oracle_sim_program: False # agent_sim_learn_oracle_sim_params: False # code_sim_learning_num_mcmc_steps: 0 - # agent_predicate_invention: + # Oracle: see state; no ground truth; learn predicates, hybrid sim, and params + # agent_predicate_invention: # NAME: "agent_sim_predicate_invention" # FLAGS: # explorer: "agent_bilevel" @@ -87,49 +147,29 @@ APPROACHES: # agent_sim_learn_oracle_sim_params: False # code_sim_learning_num_mcmc_steps: 0 # agent_sim_predicate_invention_kept_predicate_names: ["Holding"] - # agent_po_gt_sim: - # NAME: "agent_sim_learning" + # # Ours: no ground truth; learn predicates, hybrid sim, and params in PO setting + # agent_po_predicate_invention_al: + # NAME: "agent_po_sim_predicate_invention" # FLAGS: # demonstrator: "oracle_process_planning" # explorer: "agent_bilevel" # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True # option_model_terminate_on_repeat: False + # 
agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True # option_model_use_gui: False # agent_bilevel_log_state: False - # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + # online_learning_early_stopping: True + # agent_sim_learn_oracle_sim_program: False + # agent_sim_learn_oracle_sim_params: False + # code_sim_learning_num_mcmc_steps: 0 + # code_sim_learning_warm_start_with_lm: True + # agent_sim_predicate_invention_kept_predicate_names: ["Holding"] # partially_observable: True - # agent_sim_learn_oracle_sim_program: True - # agent_sim_learn_oracle_sim_params: True - # num_online_learning_cycles: 0 - agent_po_predicate_invention_al: - NAME: "agent_po_sim_predicate_invention" - FLAGS: - demonstrator: "oracle_process_planning" - explorer: "agent_bilevel" - terminate_on_goal_reached_and_option_terminated: True - agent_sdk_use_local_sandbox: True - option_model_terminate_on_repeat: False - agent_sdk_max_agent_turns_per_iteration: 50 - agent_planner_use_visualize_state: True - agent_planner_use_annotate_scene: True - option_model_use_gui: False - agent_bilevel_log_state: False - skip_test_until_last_ite_or_early_stopping: False - online_learning_early_stopping: True - agent_sim_learn_oracle_sim_program: False - agent_sim_learn_oracle_sim_params: False - code_sim_learning_num_mcmc_steps: 0 - code_sim_learning_warm_start_with_lm: True - agent_sim_predicate_invention_kept_predicate_names: ["Holding"] - partially_observable: True - agent_explorer_info_seeking: True - # Closed-loop test execution: replan when a finished step's subgoal - # annotation fails in the real state (chaotic place landings were - # costing 2-4 test tasks per run; see boil-…_al seed0/seed1 logs). - # The monitor detects divergence; the budget caps recoveries. 
- execution_monitor: "subgoal_annotations" - agent_bilevel_max_execution_replans: 2 + # agent_explorer_info_seeking: True + # execution_monitor: "subgoal_annotations" + # agent_bilevel_max_execution_replans: 2 # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: @@ -138,4 +178,3 @@ APPROACHES: # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True # agent_sdk_use_local_sandbox: True - # agent_sdk_max_agent_turns_per_iteration: 50 diff --git a/scripts/configs/predicatorv3/common.yaml b/scripts/configs/predicatorv3/common.yaml index 7e1640a1c..1442fa281 100644 --- a/scripts/configs/predicatorv3/common.yaml +++ b/scripts/configs/predicatorv3/common.yaml @@ -31,4 +31,4 @@ FLAGS: log: 'logs/' no_repeated_arguments_in_grounding: True START_SEED: 0 -NUM_SEEDS: 5 +NUM_SEEDS: 1 \ No newline at end of file diff --git a/scripts/configs/predicatorv3/envs/all.yaml b/scripts/configs/predicatorv3/envs/all.yaml index 07861a6b3..2ed80802d 100644 --- a/scripts/configs/predicatorv3/envs/all.yaml +++ b/scripts/configs/predicatorv3/envs/all.yaml @@ -7,27 +7,25 @@ ENVS: # grow_weak_pour_terminate_condition: True # grow_place_option_no_sampler: True # horizon: 400 - # domino: - # NAME: "pybullet_domino" - # FLAGS: - # excluded_objects_in_state_str: "loc,rot,angle,direction" - # horizon: 200 - # domino_initialize_at_finished_state: False - # domino_use_domino_blocks_as_target: True - # domino_use_grid: True - # domino_include_connected_predicate: False - # domino_use_continuous_place: True - # domino_restricted_push: True - # domino_prune_actions: False - # process_planning_heuristic_weight: 2.0 - # process_planning_use_abstract_policy: False - # domino_has_glued_dominos: False - # keep_failed_demos: True - # env_has_impossible_goals: True - # process_param_learning_use_empirical: True - # process_learning_use_empirical: True - # predicate_invent_invent_derived_predicates: True - # script_option_file_name: "domino2.txt" + domino: + 
NAME: "pybullet_domino" + FLAGS: + excluded_objects_in_state_str: "loc,rot,angle,direction" + # excluded_predicates is set per-approach: agents.yaml excludes + # these (test ours); oracle.yaml leaves them in (test oracle). + horizon: 400 + domino_initialize_at_finished_state: False + domino_use_domino_blocks_as_target: True + domino_use_continuous_place: True + domino_restricted_push: True + process_planning_heuristic_weight: 2.0 + domino_has_glued_dominos: False + keep_failed_demos: True + predicate_invent_invent_derived_predicates: True + # script_option_file_name: "domino2.txt" + # agent_bilevel_plan_sketch_file: "domino3.txt" + pybullet_birrt_extend_num_interp: 20 # increase this to avoid collisions when placing dominoes + pybullet_birrt_path_subsample_ratio: 2 # coffee: # NAME: "pybullet_coffee" # FLAGS: @@ -43,18 +41,18 @@ ENVS: # max_num_steps_option_rollout: 100 # horizon: 300 # script_option_file_name: "coffee.txt" - boil: - NAME: "pybullet_boil" - FLAGS: - excluded_objects_in_state_str: "switch" - max_num_steps_option_rollout: 100 - horizon: 500 - boil_goal: "simple" - boil_require_jug_full_to_heatup: True - script_option_file_name: "boil.txt" - boil_water_fill_speed: 0.0015 - pybullet_birrt_path_subsample_ratio: 2 - boil_num_jugs_test: [1] + # boil: + # NAME: "pybullet_boil" + # FLAGS: + # excluded_objects_in_state_str: "switch" + # max_num_steps_option_rollout: 100 + # horizon: 500 + # boil_goal: "simple" + # boil_require_jug_full_to_heatup: True + # script_option_file_name: "boil.txt" + # boil_water_fill_speed: 0.0015 + # pybullet_birrt_path_subsample_ratio: 2 + # boil_num_jugs_test: [1] # fan: # NAME: "pybullet_fan" # FLAGS: diff --git a/scripts/configs/predicatorv3/oracle.yaml b/scripts/configs/predicatorv3/oracle.yaml index 84ae737ab..55c78c4e7 100644 --- a/scripts/configs/predicatorv3/oracle.yaml +++ b/scripts/configs/predicatorv3/oracle.yaml @@ -10,6 +10,19 @@ APPROACHES: FLAGS: demonstrator: "oracle_process_planning" 
terminate_on_goal_reached_and_option_terminated: True + # Plan open-loop (task plan + greedy execution), NOT sim-in-the-loop. + # The deterministic ``_place_sampler`` rank-sums three signals + # (future-target bridge, planner cell, planner angle) over the + # generator-faithful placements, so the cascade-correct pose is chosen on + # the first try at corners, straights, and spurious-turn plans alike -- + # no per-step pybullet rollout / backtracking needed. (Sim-in-the-loop was + # both slower and lower-scoring: the cascade is too physics-sensitive and + # the option-model resets diverge from continuous execution.) + bilevel_plan_without_sim: True + # Greedy execution validates only the final Toppled(target) goal, not the + # full per-step grid state (whose exact Tilting/Upright cascade timing the + # physics can't match step for step). + sesame_check_expected_atoms: False # human_interaction: # NAME: "human_interaction" # FLAGS: diff --git a/scripts/configs/predicatorv3/predicator_v3.yaml b/scripts/configs/predicatorv3/predicator_v3.yaml index 479248aec..dd139f40e 100644 --- a/scripts/configs/predicatorv3/predicator_v3.yaml +++ b/scripts/configs/predicatorv3/predicator_v3.yaml @@ -80,7 +80,6 @@ ENVS: # horizon: 200 # domino_initialize_at_finished_state: False # domino_use_domino_blocks_as_target: True - # domino_use_grid: True # domino_include_connected_predicate: False # necessary to generate valid plan # domino_use_continuous_place: True # domino_restricted_push: True diff --git a/scripts/configs/predicatorv3/random_actions_pybullet.yaml b/scripts/configs/predicatorv3/random_actions_pybullet.yaml index 344982c3b..150d7fec4 100644 --- a/scripts/configs/predicatorv3/random_actions_pybullet.yaml +++ b/scripts/configs/predicatorv3/random_actions_pybullet.yaml @@ -81,7 +81,6 @@ ENVS: bilevel_plan_without_sim: True domino_initialize_at_finished_state: False domino_use_domino_blocks_as_target: True - domino_use_grid: True domino_include_connected_predicate: False 
domino_prune_actions: False float: diff --git a/scripts/dbg_domino_infront.py b/scripts/dbg_domino_infront.py new file mode 100644 index 000000000..ad6268ebb --- /dev/null +++ b/scripts/dbg_domino_infront.py @@ -0,0 +1,86 @@ +"""Crack task-2 spurious InFront: use the EXACT approach machinery.""" + +from predicators import utils +from predicators.envs.pybullet_domino.env import PyBulletDominoEnv +from predicators.ground_truth_models import augment_task_with_helper_objects, \ + get_gt_helper_predicates +from predicators.structs import Task + +utils.reset_config({ + "env": "pybullet_domino", + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 5, + "domino_use_domino_blocks_as_target": True, + "domino_use_continuous_place": True, + "domino_restricted_push": True, + "domino_initialize_at_finished_state": False, + "domino_has_glued_dominos": False, +}) + +env = PyBulletDominoEnv() +tasks = env._generate_test_tasks() # pylint: disable=protected-access +env_task = tasks[1] # task 2 +task = augment_task_with_helper_objects(Task(env_task.init, env_task.goal), + "pybullet_domino") +s = task.init +helper_preds = get_gt_helper_predicates("pybullet_domino") +# How many distinct InFront predicate OBJECTS exist, and from where? +env_infronts = [p for p in env.predicates if p.name == "InFront"] +helper_infronts = [p for p in helper_preds if p.name == "InFront"] +print(f"env InFront objs: {len(env_infronts)} " + f"derived={[type(p).__name__ for p in env_infronts]}") +print(f"helper InFront objs: {len(helper_infronts)} " + f"derived={[type(p).__name__ for p in helper_infronts]}") +if env_infronts and helper_infronts: + print("same object?", env_infronts[0] is helper_infronts[0]) + print("equal (==)?", env_infronts[0] == helper_infronts[0]) + +# The approach does: helpers | initial_predicates (helpers win on collision). 
+full_preds = helper_preds | set(env.predicates) +preds = {p.name: p for p in full_preds} +infront_in_full = [p for p in full_preds if p.name == "InFront"] +print(f"InFront objs in (helpers|env): {len(infront_in_full)} " + f"-> {[type(p).__name__ for p in infront_in_full]}") + +# Apply the FIX: drop base predicates whose name a helper already provides. +helper_names = {p.name for p in helper_preds} +fixed_preds = helper_preds | { + p + for p in env.predicates if p.name not in helper_names +} +fixed_infront = [p for p in fixed_preds if p.name == "InFront"] +fixed_atoms = utils.abstract(s, fixed_preds) +print(f"FIXED: InFront objs={len(fixed_infront)} " + f"types={[type(p).__name__ for p in fixed_infront]}") +print("FIXED InFront atoms:", + sorted(str(a) for a in fixed_atoms if a.predicate.name == "InFront")) + +atoms = utils.abstract(s, full_preds) +atpos = { + a.objects[0].name: a.objects[1] + for a in atoms if a.predicate.name == "DominoAtPos" +} +print("=== DominoAtPos ===") +for d in sorted(atpos): + loc = atpos[d] + print(f" {d} -> {loc.name} " + f"(xx={s.get(loc,'xx'):.4f} yy={s.get(loc,'yy'):.4f})") +print("=== InFront atoms ===") +for a in sorted(str(x) for x in atoms if x.predicate.name == "InFront"): + print(" ", a) +print("=== InFrontDirection atoms ===") +for a in sorted( + str(x) for x in atoms if x.predicate.name == "InFrontDirection"): + print(" ", a) + +c0, c1 = atpos["domino_0"], atpos["domino_1"] +n0 = tuple(float(v) for v in c0.name.split("_")[1:]) +n1 = tuple(float(v) for v in c1.name.split("_")[1:]) +print("=== manual d0 vs d1 ===") +print(f" d0 cell name->coords: {n0} feats: " + f"({s.get(c0,'xx'):.4f},{s.get(c0,'yy'):.4f})") +print(f" d1 cell name->coords: {n1} feats: " + f"({s.get(c1,'xx'):.4f},{s.get(c1,'yy'):.4f})") +print(f" name dx={abs(n0[0]-n1[0]):.4f} dy={abs(n0[1]-n1[1]):.4f} " + f"(pos_gap=0.098, tol={0.098*0.3:.4f})") diff --git a/scripts/dbg_domino_tasks.py b/scripts/dbg_domino_tasks.py new file mode 100644 index 
000000000..fc1a3a3bc --- /dev/null +++ b/scripts/dbg_domino_tasks.py @@ -0,0 +1,52 @@ +"""Dump domino test-task geometry (roles + poses), no physics. + +Usage: python scripts/dbg_domino_tasks.py [seed] +""" +import sys + +import numpy as np + +from predicators import utils +from predicators.envs.pybullet_domino.components.domino_component import \ + DominoComponent +from predicators.envs.pybullet_domino.env import PyBulletDominoEnv + +_SEED = int(sys.argv[1]) if len(sys.argv) > 1 else 0 +utils.reset_config({ + "env": "pybullet_domino", + "seed": _SEED, + "num_train_tasks": 1, + "num_test_tasks": 5, + "domino_use_domino_blocks_as_target": True, + "domino_use_continuous_place": True, + "domino_restricted_push": True, + "domino_initialize_at_finished_state": True, + "domino_has_glued_dominos": False, +}) + +env = PyBulletDominoEnv() +tasks = env._generate_test_tasks() # pylint: disable=protected-access + +for ti, task in enumerate(tasks): + s = task.init + dt = None + for o in s: + if o.type.name == "domino": + dt = o.type + break + dominoes = sorted((o for o in s if o.type == dt), key=lambda o: o.name) + print(f"\n===== TASK {ti+1} =====") + print("goal:", sorted(str(a) for a in task.goal)) + for d in dominoes: + x = s.get(d, "x") + y = s.get(d, "y") + yaw = np.degrees(s.get(d, "yaw")) + # pylint: disable=protected-access + is_start = DominoComponent._StartBlock_holds(s, [d]) + is_target = DominoComponent._TargetDomino_holds(s, [d]) + is_movable = DominoComponent._MovableBlock_holds(s, [d]) \ + if hasattr(DominoComponent, "_MovableBlock_holds") else ( + not is_start and not is_target) + role = ("START" if is_start else + "TARGET" if is_target else "MOVABLE" if is_movable else "?") + print(f" {d.name:10s} {role:8s} pos=({x:.3f},{y:.3f}) yaw={yaw:6.1f}") diff --git a/scripts/plan_sketches/domino3.txt b/scripts/plan_sketches/domino3.txt new file mode 100644 index 000000000..b14c27d5e --- /dev/null +++ b/scripts/plan_sketches/domino3.txt @@ -0,0 +1,7 @@ +Plan: 
+Pick(robot:robot, domino_1:domino) -> {Holding(robot:robot, domino_1:domino)} +Place(robot:robot) -> {HandEmpty(robot:robot), InFront(domino_1:domino, domino_0:domino)} +Pick(robot:robot, domino_2:domino) -> {Holding(robot:robot, domino_2:domino)} +Place(robot:robot) -> {HandEmpty(robot:robot), InFront(domino_3:domino, domino_2:domino), InFront(domino_2:domino, domino_1:domino)} +Push(robot:robot) -> {Toppled(domino_0:domino)} +Wait(robot:robot) -> {Toppled(domino_3:domino)} diff --git a/scripts/plan_sketches/domino4.txt b/scripts/plan_sketches/domino4.txt new file mode 100644 index 000000000..d4f2a3e5a --- /dev/null +++ b/scripts/plan_sketches/domino4.txt @@ -0,0 +1,7 @@ +Plan: +Pick(robot:robot, domino_1:domino) -> {Holding(robot:robot, domino_1:domino)} +Place(robot:robot) -> {HandEmpty(robot:robot), InFront(domino_1:domino, domino_0:domino)} +Pick(robot:robot, domino_2:domino) -> {Holding(robot:robot, domino_2:domino)} +Place(robot:robot) -> {HandEmpty(robot:robot), InFront(domino_3:domino, domino_2:domino), InFront(domino_2:domino, domino_1:domino)} +Push(robot:robot) -> {Toppled(domino_0:domino)} +Wait(robot:robot) -> {Toppled(domino_4:domino), Toppled(domino_3:domino)} \ No newline at end of file diff --git a/scripts/plan_sketches/domino_repro_s1t0.txt b/scripts/plan_sketches/domino_repro_s1t0.txt new file mode 100644 index 000000000..b14c27d5e --- /dev/null +++ b/scripts/plan_sketches/domino_repro_s1t0.txt @@ -0,0 +1,7 @@ +Plan: +Pick(robot:robot, domino_1:domino) -> {Holding(robot:robot, domino_1:domino)} +Place(robot:robot) -> {HandEmpty(robot:robot), InFront(domino_1:domino, domino_0:domino)} +Pick(robot:robot, domino_2:domino) -> {Holding(robot:robot, domino_2:domino)} +Place(robot:robot) -> {HandEmpty(robot:robot), InFront(domino_3:domino, domino_2:domino), InFront(domino_2:domino, domino_1:domino)} +Push(robot:robot) -> {Toppled(domino_0:domino)} +Wait(robot:robot) -> {Toppled(domino_3:domino)} diff --git a/scripts/render_domino_initial_states.py 
b/scripts/render_domino_initial_states.py new file mode 100644 index 000000000..ed8c591e6 --- /dev/null +++ b/scripts/render_domino_initial_states.py @@ -0,0 +1,82 @@ +"""Render initial states of the domino test tasks for debugging. + +Reproduces the exact test tasks from a run (same env, seed, +test_env_seed_offset and domino flags) and saves a PNG of each test +task's initial state so failed tasks can be visualized. + +Usage: + PYTHONPATH=. python scripts/render_domino_initial_states.py +""" +import os + +import numpy as np +from PIL import Image + +from predicators import utils +from predicators.envs import create_new_env + +# Domino flags copied verbatim from the run namespace (info.log) so the +# generated test tasks match the run exactly. +_DOMINO_FLAGS = { + "env": "pybullet_domino", + "num_train_tasks": 1, + "num_test_tasks": 5, + "test_env_seed_offset": 10000, + "pybullet_camera_width": 1340, + "pybullet_camera_height": 720, + "domino_test_num_dominos": [3], + "domino_test_num_targets": [1, 2], + "domino_test_num_pivots": [0], + "domino_test_num_pos_x": 4, + "domino_test_num_pos_y": 3, + "domino_train_num_dominos": [2], + "domino_train_num_targets": [1], + "domino_train_num_pivots": [0], + "domino_train_num_pos_x": 3, + "domino_train_num_pos_y": 2, + "domino_use_continuous_place": True, + "domino_use_domino_blocks_as_target": True, + "domino_restricted_push": True, + "domino_only_straight_sequence_in_training": True, + "domino_use_skill_factories": True, + "domino_prune_actions": False, + "domino_has_glued_dominos": False, + "domino_some_dominoes_are_connected": False, + "domino_include_connected_predicate": False, + "domino_initialize_at_finished_state": False, + "domino_debug_layout": False, + "domino_domino_on_stairs": False, +} + +# Which 1-indexed test tasks failed in each seed (for labeling). 
+_FAILED = {0: {2, 3}, 2: {1, 2, 3, 5}} + +_OUT_DIR = ("logs/agent_sim_learning/" + "domino-agent_oracle_hybrid_sim_oracle_samplers/initial_states") + + +def main() -> None: + """Render init-state PNGs for the domino test tasks of seeds 0 and 2.""" + os.makedirs(_OUT_DIR, exist_ok=True) + for seed in (0, 2): + utils.reset_config({**_DOMINO_FLAGS, "seed": seed}) + # do_cache=False: a cached env keeps its seed-0 test tasks, so each + # seed must build a fresh env to regenerate its own test tasks. + env = create_new_env("pybullet_domino", do_cache=False) + tasks = env.get_test_tasks() + for idx in range(len(tasks)): + env.reset("test", idx) + rgb = np.asarray(env.render()[0], dtype=np.uint8) + task_num = idx + 1 # 1-indexed to match the run logs + status = "FAILED" if task_num in _FAILED.get(seed, set()) \ + else "solved" + fname = f"seed{seed}_task{task_num}_{status}.png" + path = os.path.join(_OUT_DIR, fname) + Image.fromarray(rgb).save(path) + goal = sorted(str(a) for a in tasks[idx].goal) + print(f"seed{seed} task{task_num} [{status}] -> {path}") + print(f" goal: {goal}") + + +if __name__ == "__main__": + main() diff --git a/scripts/render_unsolved_domino_states.py b/scripts/render_unsolved_domino_states.py new file mode 100644 index 000000000..b4e377415 --- /dev/null +++ b/scripts/render_unsolved_domino_states.py @@ -0,0 +1,162 @@ +"""Render init-state PNGs for the unsolved domino tasks (oracle-samplers runs). + +Uses the geometry-affecting flags from the experiment command line so the +regenerated test scenes match the runs exactly (verified: seed1 = [4,4,5,4,4] +dominoes, and the seed1.t2 grasp-infeasibility matches the run). Run ONE seed +per process (task-gen RNG is shared across seeds in one interpreter). + +Usage: PYTHONPATH=. 
python scripts/render_unsolved_domino_states.py +""" +import os +import sys + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from predicators import utils + + +def _project(xyz, view_matrix, proj_matrix, width, height): + """World (x,y,z) -> (u,v) pixel using pybullet's column-major matrices.""" + V = np.array(view_matrix).reshape((4, 4), order="F") + P = np.array(proj_matrix).reshape((4, 4), order="F") + clip = P @ (V @ np.array([xyz[0], xyz[1], xyz[2], 1.0])) + if clip[3] == 0: + return None + ndc = clip[:3] / clip[3] + return ((ndc[0] * 0.5 + 0.5) * width, + (1.0 - (ndc[1] * 0.5 + 0.5)) * height) + + +def _font(size): + for path in ("/System/Library/Fonts/Supplemental/Arial Bold.ttf", + "/System/Library/Fonts/Helvetica.ttc"): + try: + return ImageFont.truetype(path, size) + except Exception: # pylint: disable=broad-except + pass + try: + return ImageFont.load_default(size=size) + except TypeError: + return ImageFont.load_default() + + +def _caption(rgb, lines): + """Draw a header banner (top-left) with the given text lines.""" + img = Image.fromarray(rgb) + draw = ImageDraw.Draw(img, "RGBA") + font = _font(22) + pad, lh = 8, 26 + w = max(draw.textlength(t, font=font) for t in lines) + draw.rectangle([0, 0, w + 2 * pad, lh * len(lines) + pad], + fill=(0, 0, 0, 170)) + for i, t in enumerate(lines): + draw.text((pad, pad + i * lh), t, fill=(255, 255, 255), font=font) + return np.asarray(img) + + +def _annotate(rgb, init_state, cam): + """Label each domino with its index at its initial-state position.""" + img = Image.fromarray(rgb) + draw = ImageDraw.Draw(img) + font = _font(26) + for o in sorted([o for o in init_state if o.type.name == "domino"], + key=lambda o: o.name): + x, y, z = (init_state.get(o, "x"), init_state.get(o, "y"), + init_state.get(o, "z")) + uv = _project((x, y, z + 0.13), *cam) + if uv is None: + continue + u, v = uv + idx = o.name.split("_")[-1] + col = (int(init_state.get(o, "r") * 255), + int(init_state.get(o, "g") * 255), 
+ int(init_state.get(o, "b") * 255)) + r = 15 + draw.ellipse([u - r, v - r, u + r, v + r], + fill=(0, 0, 0), + outline=col, + width=3) + tb = draw.textbbox((0, 0), idx, font=font) + draw.text((u - (tb[2] - tb[0]) / 2, v - (tb[3] - tb[1]) / 2 - tb[1]), + idx, + fill=(255, 255, 255), + font=font) + return np.asarray(img) + + +# 1-indexed tasks unsolved in EITHER arm, with (arms, failure-mode) labels. +UNSOLVED = { + 0: { + 1: ("both", "push-dropped"), + 2: ("both", "place-MP+InFront"), + 3: ("no_demo", "pick+place-MP") + }, + 1: { + 1: ("demo", "exec-retreat-collision"), + 3: ("both", "pick+place-MP") + }, + 2: { + 1: ("no_demo", "pick+place-MP"), + 2: ("demo", "toppled-cascade"), + 4: ("both", "pick+place-MP"), + 5: ("both", "place-MP+toppled") + }, + 3: { + 5: ("demo", "holding+InFront+place-MP") + }, +} +FLAGS = { + "env": "pybullet_domino", + "num_train_tasks": 1, + "num_test_tasks": 5, + "pybullet_ik_validate": False, + "pybullet_camera_width": 900, + "pybullet_camera_height": 900, + "domino_initialize_at_finished_state": False, + "domino_use_domino_blocks_as_target": True, + "domino_use_continuous_place": True, + "domino_restricted_push": True, + "domino_has_glued_dominos": False, + "pybullet_birrt_extend_num_interp": 20, + "pybullet_birrt_path_subsample_ratio": 2, +} +OUT = "logs/agent_sim_learning/unsolved_init_states" + + +def main(): + """Render annotated init-state PNGs for one seed's unsolved tasks.""" + seed = int(sys.argv[1]) + os.makedirs(OUT, exist_ok=True) + utils.reset_config(dict(FLAGS, seed=seed)) + # Deferred until after reset_config: create_new_env reads CFG at import. 
+ # pylint: disable=import-outside-toplevel + from predicators.envs import create_new_env + env = create_new_env("pybullet_domino", do_cache=False) + tasks = env.get_test_tasks() + counts = [ + len([o for o in t.init if o.type.name == "domino"]) for t in tasks + ] + print(f"seed{seed} domino counts per task = {counts}") + cam = env._get_camera_matrices() # pylint: disable=protected-access + for t1, (arms, mode) in sorted(UNSOLVED.get(seed, {}).items()): + idx = t1 - 1 + env.reset("test", idx) + rgb = np.asarray(env.render()[0], dtype=np.uint8) + rgb = _annotate(rgb, tasks[idx].init, cam) + goal_ids = ",".join( + sorted( + str(a).rsplit("_", maxsplit=1)[-1].rstrip(":domino)") + for a in tasks[idx].goal)) + rgb = _caption(rgb, [ + f"seed {seed} task {t1} ({arms})", + f"goal: Toppled({goal_ids}) fail: {mode}" + ]) + fname = f"seed{seed}_task{t1}_{arms}_{mode}.png" + Image.fromarray(rgb).save(os.path.join(OUT, fname)) + goal = sorted(str(a) for a in tasks[idx].goal) + print(f" saved {fname} | {counts[idx]} dominoes | goal={goal}") + + +if __name__ == "__main__": + main() diff --git a/scripts/replay_domino_sketches.py b/scripts/replay_domino_sketches.py new file mode 100644 index 000000000..18dffb0f1 --- /dev/null +++ b/scripts/replay_domino_sketches.py @@ -0,0 +1,213 @@ +"""Faithfully reproduce domino refinement failures by replaying the recorded +LLM sketches through the real bilevel refinement -- no LLM required. + +The agent's plan sketches were logged verbatim in each run's ``info.log`` +(``Sketch (attempt N):`` blocks). This script extracts them, regenerates the +deterministic test task, and runs the *exact same* ``refine_sketch`` the +pipeline uses (oracle option model + oracle samplers + subgoal checks, same +per-(sketch,refine) RNG seeding). The pass/fail outcome and the "stuck at step +K" reason therefore reproduce the run's solve-time failures deterministically. + +Run ONE seed per process (task-gen RNG is shared; see +reproduce_domino_failures). 
+ +Usage: + PYTHONPATH=. python scripts/replay_domino_sketches.py \ + [--all] + --all replays every task; default replays only tasks the run + did not solve. +""" + +import logging +import re +import sys +from glob import glob + +logging.disable(logging.CRITICAL) + +ANSI = re.compile(r"\x1b\[[0-9;]*m") +STEP = re.compile( + r"^\s*\d+:\s*([A-Za-z]\w*)\((.*?)\)(?:\s*->\s*\{(.*)\})?\s*$") +SKETCH_HDR = re.compile(r"Sketch \(attempt (\d+)\)") +TASK_RES = re.compile( + r"\[main\.py\] Task (\d+) / \d+: (.*)|Task (\d+) / \d+: (SOLVED)") + +_FLAGS = { + "env": "pybullet_domino", + "approach": "agent_sim_learning", + "num_train_tasks": 1, + "num_test_tasks": 5, + "skill_phase_use_motion_planning": True, + "pybullet_ik_validate": False, + "demonstrator": "oracle_process_planning", + "bilevel_plan_without_sim": True, + "explorer": "agent_bilevel", + "agent_sim_learn_oracle_sim_program": True, + "agent_sim_learn_oracle_sim_params": True, + "agent_sim_learn_synthesize_samplers": True, + "agent_sim_learn_oracle_samplers": True, + "execution_monitor": "subgoal_annotations", + "agent_bilevel_max_execution_replans": 2, + "horizon": 400, + "excluded_objects_in_state_str": "loc,rot,angle,direction", + "excluded_predicates": "InitialBlock,MovableBlock,Tilting,Upright", + "domino_initialize_at_finished_state": False, + "domino_use_domino_blocks_as_target": True, + "domino_use_continuous_place": True, + "domino_restricted_push": True, + "process_planning_heuristic_weight": 2.0, + "domino_has_glued_dominos": False, + "pybullet_birrt_extend_num_interp": 20, + "pybullet_birrt_path_subsample_ratio": 2, + "agent_sdk_use_local_sandbox": True, + "option_model_terminate_on_repeat": False, + "agent_planner_use_simulator": True, +} + + +def find_info_log(seed, arm): + """Return the newest info.log path for the given seed and arm.""" + exp = f"domino-agent_oracle_hybrid_sim_oracle_samplers_{arm}" + pat = f"logs/agent_sim_learning/{exp}/seed{seed}/run_*/info.log" + hits = sorted(glob(pat)) + if not 
hits: + raise SystemExit(f"no info.log at {pat}") + return hits[-1] + + +def extract_sketches(info_log): + """Return {task_idx (0-based): {"outcome": str, "sketches": [[step,...]]}}. + + Each step is (option_name, [obj_names], raw_subgoal_str). + """ + tasks, pending, cur = {}, [], None + with open(info_log, encoding="utf-8") as f: + for raw in f: + line = ANSI.sub("", raw.rstrip("\n")) + if SKETCH_HDR.search(line): + cur = [] + pending.append(cur) + continue + m = STEP.match(line) + if m and cur is not None: + opt, args, sg = m.group(1), m.group(2), m.group(3) or "" + objs = [ + a.split(":")[0].strip() for a in args.split(",") + if a.strip() + ] + cur.append((opt, objs, sg)) + continue + cur = None # any non-step line ends the current sketch block + tm = TASK_RES.search(line) + if tm: + ti = int(tm.group(1) or tm.group(3)) - 1 + outcome = (tm.group(2) or tm.group(4) or "").strip() + tasks[ti] = {"outcome": outcome, "sketches": pending} + pending = [] + return tasks + + +def typed_text(steps, name_to_type): + """Rebuild typed sketch text the option-plan parser expects.""" + lines = [] + for opt, objs, sg in steps: + typed = ", ".join(f"{o}:{name_to_type.get(o, 'object')}" for o in objs) + line = f"{opt}({typed})" + if sg: + line += f" -> {{{sg}}}" + lines.append(line) + return "\n".join(lines) + + +def main(): + """Replay the recorded sketches for one seed through real refinement.""" + seed = int(sys.argv[1]) + arm = sys.argv[2] if len(sys.argv) > 2 else "no_demo" + replay_all = "--all" in sys.argv + + info_log = find_info_log(seed, arm) + tasks = extract_sketches(info_log) + + # These imports are deferred until after reset_config because the + # imported modules read CFG at import time. 
+ # pylint: disable=import-outside-toplevel + from predicators import utils + utils.reset_config(dict(_FLAGS, seed=seed)) + from predicators.agent_sdk import bilevel_sketch + from predicators.approaches import create_approach + from predicators.envs import get_or_create_env + from predicators.ground_truth_models import get_gt_options + from predicators.settings import CFG + + # pylint: enable=import-outside-toplevel + + env = get_or_create_env("pybullet_domino") + options = get_gt_options(env.get_name()) + preds, _ = utils.parse_config_excluded_predicates(env) + train_tasks = [t.task for t in env.get_train_tasks()] + approach = create_approach("agent_sim_learning", preds, options, env.types, + env.action_space, train_tasks) + approach._maybe_install_oracle_samplers() # pylint: disable=protected-access + test_tasks = env.get_test_tasks() + name_to_type = {o.name: o.type.name for o in test_tasks[0].task.init} + + print(f"# seed{seed} {arm}: replaying recorded sketches through real " + f"refinement (oracle option-model + oracle samplers, no LLM)") + for ti in sorted(tasks): + rec = tasks[ti] + solved = rec["outcome"].upper().startswith("SOLVED") + if solved and not replay_all: + continue + task = test_tasks[ti].task + print(f"\n== task{ti} (run Task{ti+1}) | " + f"run outcome: {rec['outcome'][:60]}") + if not rec["sketches"]: + print(" (no sketches recorded)") + continue + for si, steps in enumerate(rec["sketches"]): + sketch = bilevel_sketch.parse_sketch_from_text( + typed_text(steps, name_to_type), + task, + predicates=preds, + options=set(options), + types=env.types) + if not sketch: + print(f" sketch{si}: unparseable") + continue + any_success = False + deepest = (-1, "") + for r in range(CFG.agent_bilevel_max_refine_retries): + fail = {"idx": -1, "reason": ""} + + # _f snapshots this iteration's ``fail`` dict at definition + # time; rec_fail mutates it in place and is consumed within + # this same iteration, so the default-arg capture is safe. 
+ def rec_fail( # pylint: disable=dangerous-default-value + idx, + _prefix, + reason, + _f=fail): + if idx > _f["idx"]: + _f["idx"], _f["reason"] = idx, reason + + attempt = si * CFG.agent_bilevel_max_refine_retries + r + _, success = approach._refine_sketch( # pylint: disable=protected-access + task, + sketch, + 600.0, + attempt=attempt, + on_step_fail=rec_fail) + if success: + any_success = True + break + if fail["idx"] > deepest[0]: + deepest = (fail["idx"], fail["reason"]) + verdict = "REFINED-OK" if any_success else \ + f"FAILED (stuck step {deepest[0]}: {deepest[1][:60]})" + head = " -> ".join(f"{o}({','.join(a)})" for o, a, _ in steps) + print(f" sketch{si} [{len(steps)} steps]: {verdict}") + print(f" {head}") + + +if __name__ == "__main__": + main() diff --git a/scripts/reproduce_domino_failures.py b/scripts/reproduce_domino_failures.py new file mode 100644 index 000000000..b96b2b5b5 --- /dev/null +++ b/scripts/reproduce_domino_failures.py @@ -0,0 +1,142 @@ +"""Deterministic, LLM-free reproduction of the domino oracle-samplers failures. + +Reproduces the geometric / parsing root causes behind the unsolved tasks in +``domino-agent_oracle_hybrid_sim_oracle_samplers_{demo,no_demo}`` (seeds 0-4), +*without* invoking the LLM sketcher. The test-task scenes are deterministic +given the seed, so the BiRRT motion-planning infeasibilities and the option-plan +parser bug reproduce exactly. + +IMPORTANT: run ONE seed per process. Generating tasks for several seeds inside +one interpreter advances the shared RNG and changes the scenes (e.g. seed1 would +regenerate as [4,5,5,5,4] dominoes instead of the real [4,4,5,4,4]). The bash +wrapper at the bottom of the module docstring loops correctly. + +Usage: + # motion-planning reproduction for a single seed (fresh process each): + for s in 0 1 2 3 4; do \ + PYTHONPATH=. python scripts/reproduce_domino_failures.py mp $s; \ + done + # option-plan parser (Push) bug: + PYTHONPATH=. 
python scripts/reproduce_domino_failures.py push 0 +""" + +import logging +import sys + +import numpy as np + +from predicators import utils + +logging.disable(logging.CRITICAL) + +# Geometry-affecting flags copied verbatim from the experiment command line. +_ARGS = { + "env": "pybullet_domino", + "approach": "oracle", + "num_train_tasks": 1, + "num_test_tasks": 5, + "pybullet_ik_validate": False, + "skill_phase_use_motion_planning": True, + "domino_initialize_at_finished_state": False, + "domino_use_domino_blocks_as_target": True, + "domino_use_continuous_place": True, + "domino_restricted_push": True, + "domino_has_glued_dominos": False, + "pybullet_birrt_extend_num_interp": 20, + "pybullet_birrt_path_subsample_ratio": 2, +} +_GRASP_Z_OFFSET = 0.0825 # value used by the oracle Pick sampler in the runs. +_POS_GAP = 0.098 # domino chain spacing (env.py: domino_width * 1.4). +_MAX_STEPS = 80 + + +def _setup(seed): # pylint: disable=redefined-outer-name + args = dict(_ARGS, seed=seed) + utils.reset_config(args) + # Deferred until after reset_config: these modules read CFG at import. 
+ # pylint: disable=import-outside-toplevel + from predicators.envs import get_or_create_env + from predicators.ground_truth_models import get_gt_options + + # pylint: enable=import-outside-toplevel + env = get_or_create_env("pybullet_domino") + options = get_gt_options(env.get_name()) + return env, options + + +def _run_option(env, opt, state): + """Drive a grounded option to termination; return (ok, failure_msg).""" + if not opt.initiable(state): + return None, "not-initiable" + s = state + for _ in range(_MAX_STEPS): + try: + a = opt.policy(s) + except utils.OptionExecutionFailure as e: + return False, str(e) + s = env.step(a) + if opt.terminal(s): + return True, "ok" + return True, "ran-max-steps" + + +def reproduce_mp(seed): # pylint: disable=redefined-outer-name + """For each test task, report which dominoes are grasp-infeasible and probe + one Place/MoveToDrop into the tight InFront gap.""" + env, options = _setup(seed) + Pick = next(o for o in options if o.name == "Pick") + tasks = env.get_test_tasks() + for ti in range(len(tasks)): + env.reset("test", ti) + st = env._current_state # pylint: disable=protected-access + dominoes = sorted([o for o in st if o.type.name == "domino"], + key=lambda o: o.name) + infeasible = [] + for d in dominoes: + env.reset("test", ti) + s = env._current_state # pylint: disable=protected-access + rb = next(o for o in s if o.type.name == "robot") + dd = next(o for o in s if o.name == d.name) + opt = Pick.ground([rb, dd], + np.array([_GRASP_Z_OFFSET], dtype=np.float32)) + ok, _ = _run_option(env, opt, s) + if ok is False: + infeasible.append(d.name) + print( + f"seed{seed} task{ti} (run Task{ti+1}): {len(dominoes)} dominoes " + f"| grasp-INFEASIBLE: {infeasible if infeasible else 'none'}") + + +def reproduce_push_bug(seed): # pylint: disable=redefined-outer-name + """Show the option-plan parser silently drops a Push line that names a + target domino, because Push is registered with types=[robot].""" + env, options = _setup(seed) + 
push = next(o for o in options if o.name == "Push") + print(f"Push option signature: types={[t.name for t in push.types]}") + state = env.get_test_tasks()[0].init + objects = list(state) + cases = { + "LLM-style 'Push(robot, domino_0)'": + "Pick(robot:robot, domino_1:domino)\n" + "Push(robot:robot, domino_0:domino)\nWait(robot:robot)", + "legal 'Push(robot)'": + "Pick(robot:robot, domino_1:domino)\n" + "Push(robot:robot)\nWait(robot:robot)", + } + for label, txt in cases.items(): + plan = utils.parse_model_output_into_option_plan( + txt, objects, env.types, options, parse_continuous_params=False) + names = [op.name for op, _, _ in plan] + flag = "PUSH DROPPED!" if "Push" not in names else "ok" + print(f" {label:42s} -> {names} ({flag})") + + +if __name__ == "__main__": + mode = sys.argv[1] if len(sys.argv) > 1 else "mp" + seed = int(sys.argv[2]) if len(sys.argv) > 2 else 0 + if mode == "mp": + reproduce_mp(seed) + elif mode == "push": + reproduce_push_bug(seed) + else: + raise SystemExit(f"unknown mode {mode!r} (expected 'mp' or 'push')") diff --git a/scripts/scripted_option_policies/domino2.txt b/scripts/scripted_option_policies/domino2.txt index 825bc2047..0ebfd4d9a 100644 --- a/scripts/scripted_option_policies/domino2.txt +++ b/scripts/scripted_option_policies/domino2.txt @@ -1,6 +1,6 @@ Plan: Pick(robot:robot, domino_1:domino)[0.0825] -Place(robot:robot)[0.75, 1.259, 0.5695, -1.57] +Place(robot:robot)[0.889810, 1.326641, 0.58, 0.785398] Pick(robot:robot, domino_2:domino)[0.0825] -Place(robot:robot)[0.85, 1.259, 0.5695, -1.57] +Place(robot:robot)[0.820513, 1.360937, 0.58, 1.570796] Push(robot:robot)[0.045, 0.0825] diff --git a/tests/agent_sdk/test_bilevel_sketch_samplers.py b/tests/agent_sdk/test_bilevel_sketch_samplers.py new file mode 100644 index 000000000..47ae023e7 --- /dev/null +++ b/tests/agent_sdk/test_bilevel_sketch_samplers.py @@ -0,0 +1,243 @@ +"""Tests for per-skill synthesized samplers in bilevel_sketch refinement. 
+ +Verifies that a sampler registered under an option name in +``option_samplers`` is consulted (with the step's subgoal + objects + +the option's params box) to draw that option's continuous params during +refinement — on both the plain and info-seeking paths — and that a +missing / misbehaving sampler falls back to uniform sampling so +refinement is byte-for-byte unchanged when no usable sampler is +supplied. +""" + +# pylint: disable=unused-import + +import numpy as np +from gym.spaces import Box + +from predicators import utils # noqa: F401 (settles import order) +from predicators.agent_sdk import bilevel_sketch +from predicators.agent_sdk.bilevel_sketch import SketchStep, sample_params +from predicators.structs import Action, GroundAtom, Object, \ + ParameterizedOption, Predicate, State, Task, Type + +_block_type = Type("block", ["x"]) +_block = Object("block0", _block_type) + + +def _noop_policy(_s, _m, _o, _p): + return Action(np.zeros(1, dtype=np.float32)) + + +def _true(_s, _m, _o, _p): + return True + + +def _false(_s, _m, _o, _p): + return False + + +# A 1-D option whose parameter becomes the post-state x of the block. +_Move = ParameterizedOption( + "Move", + types=[_block_type], + params_space=Box(low=np.array([0.0], dtype=np.float32), + high=np.array([1.0], dtype=np.float32)), + policy=_noop_policy, + initiable=_true, + terminal=_false, +) + + +class _FakeOptionModel: + """Deterministic model: Move sets block.x to its parameter value.""" + + last_execution_failure = None + + def __init__(self): + self.num_calls = 0 + + def get_next_state_and_num_actions(self, state, option): + """Roll the option forward one step, counting the call.""" + self.num_calls += 1 + nxt = state.copy() + if len(option.params): + nxt.set(_block, "x", float(option.params[0])) + return nxt, 1 + + +# Subgoal uniform sampling hits only ~10% of the time (x >= 0.9), but a +# targeted sampler lands on the first draw. 
+_ReachedHi = Predicate("ReachedHi", [_block_type], + lambda s, o: s.get(o[0], "x") >= 0.9) +# Always-true subgoal so the first draw (uniform or sampled) is accepted. +_Reached = Predicate("Reached", [_block_type], lambda s, o: True) + + +def _task_hi(): + init = State({_block: np.array([0.0], dtype=np.float32)}) + return Task(init, {GroundAtom(_ReachedHi, [_block])}) + + +def _sketch_hi(): + return [ + SketchStep(option=_Move, + objects=[_block], + subgoal_atoms={GroundAtom(_ReachedHi, [_block])}) + ] + + +def _easy_task_and_sketch(): + sketch = [ + SketchStep(option=_Move, + objects=[_block], + subgoal_atoms={GroundAtom(_Reached, [_block])}) + ] + task = Task(State({_block: np.array([0.0], dtype=np.float32)}), + {GroundAtom(_Reached, [_block])}) + return task, sketch + + +def test_registered_sampler_is_used(): + """A targeted sampler lands the hard subgoal on the first sample.""" + calls = [] + + def sampler(state, subgoal_atoms, rng, objects): + del state, rng + calls.append((objects, subgoal_atoms)) + return np.array([0.95], dtype=np.float32) + + model = _FakeOptionModel() + plan, success, total = bilevel_sketch.refine_sketch( + _task_hi(), + _sketch_hi(), + model, + predicates={_ReachedHi}, + timeout=10.0, + rng=np.random.default_rng(0), + max_samples_per_step=50, + check_subgoals=True, + check_final_goal=False, + option_samplers={"Move": sampler}) + assert success + assert np.isclose(float(plan[0].params[0]), 0.95) + # Feasible on the very first attempt — none of the uniform churn. + assert total == 1 + assert model.num_calls == 1 + # The sampler saw the right subgoal and objects. 
+ objs, subgoal = calls[0] + assert [o.name for o in objs] == ["block0"] + assert GroundAtom(_ReachedHi, [_block]) in subgoal + + +def test_missing_entry_falls_back_to_uniform(): + """A sampler keyed by another option leaves Move on the uniform path.""" + seed = 7 + first = float(sample_params(_Move, np.random.default_rng(seed))[0]) + task, sketch = _easy_task_and_sketch() + + def other(*_args): + raise AssertionError("sampler for a different option was called") + + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + _FakeOptionModel(), + predicates={_Reached}, + timeout=10.0, + rng=np.random.default_rng(seed), + max_samples_per_step=50, + check_subgoals=True, + check_final_goal=False, + option_samplers={"OtherOption": other}) + assert success + # Identical to the no-sampler uniform draw. + assert float(plan[0].params[0]) == first + + +def test_bad_shape_falls_back_to_uniform(): + """A wrong-shaped return is rejected; uniform sampling still succeeds.""" + task, sketch = _easy_task_and_sketch() + + def bad(*_args): + return np.array([0.5, 0.5], dtype=np.float32) # shape (2,) != (1,) + + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + _FakeOptionModel(), + predicates={_Reached}, + timeout=10.0, + rng=np.random.default_rng(0), + max_samples_per_step=50, + check_subgoals=True, + check_final_goal=False, + option_samplers={"Move": bad}) + assert success + assert 0.0 <= float(plan[0].params[0]) <= 1.0 + + +def test_raising_sampler_falls_back_to_uniform(): + """A sampler that raises is caught and uniform sampling proceeds.""" + task, sketch = _easy_task_and_sketch() + + def boom(*_args): + raise ValueError("nope") + + _, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + _FakeOptionModel(), + predicates={_Reached}, + timeout=10.0, + rng=np.random.default_rng(0), + max_samples_per_step=50, + check_subgoals=True, + check_final_goal=False, + option_samplers={"Move": boom}) + assert success + + +def 
test_none_samplers_unchanged(): + """option_samplers=None reproduces the plain first-uniform-draw param.""" + seed = 7 + first = float(sample_params(_Move, np.random.default_rng(seed))[0]) + task, sketch = _easy_task_and_sketch() + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + _FakeOptionModel(), + predicates={_Reached}, + timeout=10.0, + rng=np.random.default_rng(seed), + max_samples_per_step=50, + check_subgoals=True, + check_final_goal=False, + option_samplers=None) + assert success + assert float(plan[0].params[0]) == first + + +def test_sampler_used_on_info_seeking_path(): + """The info-seeking draw loop also routes through the sampler.""" + + def sampler(_s, _a, rng, _o): + # Jitter so candidates differ but all clear the x>=0.9 subgoal. + return np.array([0.9 + 0.05 * rng.random()], dtype=np.float32) + + model = _FakeOptionModel() + plan, success, _ = bilevel_sketch.refine_sketch( + _task_hi(), + _sketch_hi(), + model, + predicates={_ReachedHi}, + timeout=10.0, + rng=np.random.default_rng(0), + max_samples_per_step=50, + check_subgoals=True, + check_final_goal=False, + info_scorer=lambda s, _a: s.get(_block, "x"), + info_n_feasible_target=4, + option_samplers={"Move": sampler}) + assert success + # Every pooled candidate came from the sampler => satisfies x >= 0.9. 
+ assert float(plan[0].params[0]) >= 0.9 diff --git a/tests/agent_sdk/test_tool_registry.py b/tests/agent_sdk/test_tool_registry.py index 4e2095a1a..eb3cf4ccc 100644 --- a/tests/agent_sdk/test_tool_registry.py +++ b/tests/agent_sdk/test_tool_registry.py @@ -90,14 +90,14 @@ def test_solve_and_synthesis_tool_names_are_independent() -> None: class _Approach(AgentSessionMixin): def _get_solve_tool_names(self) -> Optional[List[str]]: - return ["inspect_options", "test_option_plan"] + return ["inspect_options", "evaluate_option_plan"] def _get_synthesis_tool_names(self) -> Optional[List[str]]: return ["inspect_trajectories", "visualize_state"] obj = _Approach() assert obj._get_solve_tool_names() == [ - "inspect_options", "test_option_plan" + "inspect_options", "evaluate_option_plan" ] assert obj._get_synthesis_tool_names() == [ "inspect_trajectories", "visualize_state" diff --git a/tests/approaches/test_agent_bilevel_approach.py b/tests/approaches/test_agent_bilevel_approach.py index 63faf2d19..4323ddb63 100644 --- a/tests/approaches/test_agent_bilevel_approach.py +++ b/tests/approaches/test_agent_bilevel_approach.py @@ -245,6 +245,31 @@ def test_typed_object_refs_in_subgoals(self): pos2, _ = result[1] assert GroundAtom(_On, [_block0, _block1]) in pos2 + def test_numbered_prefix_subgoals(self): + """Agent numbers the lines (0:, 1:) — annotations must still align. + + Mirrors a real failure: the agent mirrored the numbered sketch + format shown in logs, embedding it between prose, and the + numbered prefix made every line parse as a non-option line so + the annotation list came back empty/misaligned. 
+ """ + approach, _, _ = _make_approach() + text = ("Some analysis the agent wrote first.\n" + " 0: Pick(block0:block) -> {Holding(block0:block)}\n" + " 1: Place(block0:block, block1:block) " + "-> {On(block0:block, block1:block)}\n" + "Rationale: ...\n") + result = approach._parse_subgoal_annotations(text, _ALL_PREDICATES, + _ALL_OBJECTS) + + assert len(result) == 2 + assert result[0] is not None + pos, _ = result[0] + assert GroundAtom(_Holding, [_block0]) in pos + assert result[1] is not None + pos2, _ = result[1] + assert GroundAtom(_On, [_block0, _block1]) in pos2 + def test_preamble_ignored(self): """Non-option lines should be ignored.""" approach, _, _ = _make_approach() diff --git a/tests/envs/test_pybullet_domino_composed.py b/tests/envs/test_pybullet_domino_composed.py index 8b8212346..d1ff94246 100644 --- a/tests/envs/test_pybullet_domino_composed.py +++ b/tests/envs/test_pybullet_domino_composed.py @@ -7,6 +7,8 @@ DominoComponent from predicators.envs.pybullet_domino.components.grid_component import \ GridComponent +from predicators.envs.pybullet_domino.task_generators import \ + DominoTaskGenerator from predicators.settings import CFG from predicators.structs import Object, State, Type @@ -77,6 +79,57 @@ def test_place_target_domino(self) -> None: assert d["r"] == pytest.approx(0.85, abs=0.01) +def test_unfinished_state_avoids_staging_collisions() -> None: + """Test unfinished movable blocks avoid start/target blocks.""" + workspace_bounds = { + "x_lb": 0.4, + "x_ub": 1.1, + "y_lb": 1.1, + "y_ub": 1.6, + "z_lb": 0.4, + "z_ub": 0.95, + } + CFG.domino_use_domino_blocks_as_target = True + CFG.domino_has_glued_dominos = False + comp = DominoComponent(num_dominos_max=5, + num_targets_max=2, + num_pivots_max=1, + workspace_bounds=workspace_bounds) + robot = Object("robot", Type("robot", ["x", "y", "z"])) + generator = DominoTaskGenerator(comp, robot, {}) + + first_staging_x = comp.domino_x_lb + comp.domino_width + first_staging_y = comp.domino_y_lb + 
comp.domino_width + obj_dict = { + comp.dominos[0]: + comp.place_domino(0, + first_staging_x, + first_staging_y, + 0.0, + is_start_block=True), + comp.dominos[1]: + comp.place_domino(1, + first_staging_x + 0.25, + first_staging_y, + 0.0, + is_target_block=True), + comp.dominos[2]: + comp.place_domino(2, 0.9, 1.35, 0.0), + } + + # pylint: disable=protected-access + moved = generator._move_intermediate_objects_to_unfinished_state(obj_dict) + + assert moved is not None + movable = comp.dominos[2] + assert not generator._placement_collides( + movable, moved[movable], { + comp.dominos[0]: moved[comp.dominos[0]], + comp.dominos[1]: moved[comp.dominos[1]], + }) + assert moved[movable]["x"] != pytest.approx(first_staging_x) + + class TestGridComponent: """Tests for GridComponent.""" diff --git a/tests/ground_truth_models/test_domino_gt_samplers.py b/tests/ground_truth_models/test_domino_gt_samplers.py new file mode 100644 index 000000000..8f35b4f51 --- /dev/null +++ b/tests/ground_truth_models/test_domino_gt_samplers.py @@ -0,0 +1,225 @@ +"""Tests for the domino ground-truth grid-free per-skill samplers. + +Exercises the ``OptionSampler``-signature samplers exposed by +``PyBulletDominoGroundTruthSamplerFactory`` (in domino/processes.py). +The ``Place`` sampler is checked against the *real* ``InFront`` / +``Upright`` classifiers (called via a lightweight stub ``self`` so no +PyBullet env is built) — a placement it returns for an ``InFront`` +subgoal must actually satisfy that subgoal. 
+""" + +# pylint: disable=unused-import + +from typing import cast + +import numpy as np +from gym.spaces import Box + +from predicators import utils # noqa: F401 (settles import order) +from predicators.envs.pybullet_domino.components.domino_component import \ + DominoComponent +from predicators.ground_truth_models import get_gt_samplers +from predicators.ground_truth_models.domino.processes import \ + _place_option_sampler +from predicators.structs import GroundAtom, Object, Predicate, State, Type + +# Domino feature layout (matches the env's domino type). +_domino_type = Type("domino", + ["x", "y", "z", "yaw", "roll", "r", "g", "b", "is_held"]) +_robot_type = Type("robot", ["x"]) + +# Place option's continuous-parameter box: (target_x, target_y, release_z, +# target_yaw). +_PLACE_BOX = Box(low=np.array([0.4, 1.1, 0.5, -np.pi], dtype=np.float32), + high=np.array([1.1, 1.6, 0.6, np.pi], dtype=np.float32)) + + +class _ClassifierStub: + """Stub exposing the constants the InFront/Upright classifiers read.""" + pos_gap = 0.098 + domino_width = 0.07 + domino_roll_threshold = np.deg2rad(5) + + +_stub = cast(DominoComponent, _ClassifierStub()) +_InFront = Predicate("InFront", [_domino_type, _domino_type], + lambda s, o: DominoComponent._InFront_holds(_stub, s, o)) # pylint: disable=protected-access +_Upright = Predicate("Upright", [_domino_type], + lambda s, o: DominoComponent._Upright_holds(_stub, s, o)) # pylint: disable=protected-access + + +def _domino(name, x, y, yaw, is_held=0.0, rgb=(0.5, 0.5, 0.5)): + feats = { + "x": x, + "y": y, + "z": 0.475, + "yaw": yaw, + "roll": 0.0, + "r": rgb[0], + "g": rgb[1], + "b": rgb[2], + "is_held": is_held, + } + obj = Object(name, _domino_type) + return obj, feats + + +def _make_state(objs_and_feats): + data = {} + for obj, feats in objs_and_feats: + data[obj] = np.array([feats[f] for f in _domino_type.feature_names], + dtype=np.float32) + return State(data) + + +def test_factory_exposes_place_pick_push(): + """The domino factory 
registers grid-free samplers for all 3 skills.""" + samplers = get_gt_samplers("pybullet_domino") + assert samplers is not None + assert set(samplers) == {"Pick", "Push", "Place"} + + +def test_place_sampler_satisfies_infront_subgoal(): + """Placement for InFront(held, ref) actually makes InFront hold.""" + robot = Object("robot", _robot_type) + # Reference domino_0 at a cardinal facing (yaw=0); held domino_1 parked + # elsewhere (its current pose is irrelevant — the sampler computes a new + # placement from the subgoal). + d0, f0 = _domino("domino_0", x=0.8, y=1.3, yaw=0.0) + d1, f1 = _domino("domino_1", x=0.5, y=1.5, yaw=0.0, is_held=1.0) + state = _make_state([(d0, f0), (d1, f1)]) + state.data[robot] = np.array([0.0], dtype=np.float32) + + subgoal = {GroundAtom(_InFront, [d1, d0]), GroundAtom(_Upright, [d1])} + rng = np.random.default_rng(0) + params = _place_option_sampler(state, subgoal, rng, [robot]) + + assert params.shape == (4, ) + assert np.all(params >= _PLACE_BOX.low - 1e-6) + assert np.all(params <= _PLACE_BOX.high + 1e-6) + + # Apply the placement and confirm the subgoal now holds. + placed = state.copy() + placed.set(d1, "x", float(params[0])) + placed.set(d1, "y", float(params[1])) + placed.set(d1, "yaw", float(params[3])) + placed.set(d1, "roll", 0.0) + placed.set(d1, "is_held", 0.0) + assert GroundAtom(_InFront, [d1, d0]).holds(placed) + assert GroundAtom(_Upright, [d1]).holds(placed) + # The exact pose is no longer pinned: with a single subgoal the sampler + # randomizes among the tied-best straight / +-45 turn placements (see + # test_place_sampler_randomizes_turn_offset). All that is guaranteed is + # that the drawn placement satisfies the subgoal, checked above. + + +def test_place_sampler_randomizes_turn_offset(): + """The sampler explores straight and +-45 turn placements across draws. 
+ + A single InFront subgoal is satisfied equally by a straight + placement and by a +-45 turn, so if the sampler always returned the + same one, backtracking that re-draws an upstream Place could never + turn a chain that needs a bend. Every draw must still satisfy the + subgoal. + """ + robot = Object("robot", _robot_type) + d0, f0 = _domino("domino_0", x=0.8, y=1.3, yaw=0.0) + d1, f1 = _domino("domino_1", x=0.5, y=1.5, yaw=0.0, is_held=1.0) + state = _make_state([(d0, f0), (d1, f1)]) + state.data[robot] = np.array([0.0], dtype=np.float32) + subgoal = {GroundAtom(_InFront, [d1, d0]), GroundAtom(_Upright, [d1])} + + saw_straight = False + saw_turn = False + for seed in range(40): + params = _place_option_sampler(state, subgoal, + np.random.default_rng(seed), [robot]) + placed = state.copy() + placed.set(d1, "x", float(params[0])) + placed.set(d1, "y", float(params[1])) + placed.set(d1, "yaw", float(params[3])) + placed.set(d1, "roll", 0.0) + placed.set(d1, "is_held", 0.0) + # Whatever offset was drawn, the subgoal must hold. + assert GroundAtom(_InFront, [d1, d0]).holds(placed) + turn = abs(utils.wrap_angle(float(params[3]))) + if turn < np.radians(10): + saw_straight = True + elif abs(turn - np.pi / 4) < np.radians(10): + saw_turn = True + assert saw_straight, "sampler never produced a straight placement" + assert saw_turn, "sampler never produced a +-45 turn placement" + + +def test_place_sampler_prefers_target_bridgeable_first_placement(): + """When a purple target is visible, tie-break toward a completable chain. + + In the seed-0 test layout, every first placement of domino_1 + satisfies ``InFront(domino_1, domino_0)`` locally, but only the + +45-degree placement leaves a one-domino bridge point that can also + connect to the purple target. 
+ """ + robot = Object("robot", _robot_type) + d0, f0 = _domino("domino_0", x=0.9146, y=1.2534, yaw=0.0) + d1, f1 = _domino("domino_1", x=0.47, y=1.2975, yaw=0.0, is_held=1.0) + d2, f2 = _domino("domino_2", x=0.575, y=1.2975, yaw=0.0) + d3, f3 = _domino("domino_3", + x=0.7225, + y=1.3609, + yaw=np.pi / 2, + rgb=(0.85, 0.7, 0.85)) + state = _make_state([(d0, f0), (d1, f1), (d2, f2), (d3, f3)]) + state.data[robot] = np.array([0.0], dtype=np.float32) + subgoal = {GroundAtom(_InFront, [d1, d0]), GroundAtom(_Upright, [d1])} + + params = _place_option_sampler(state, subgoal, np.random.default_rng(0), + [robot]) + + assert np.allclose(params[:2], [0.88985, 1.32665], atol=1e-3) + assert np.isclose(float(params[2]), 0.58) + assert abs(utils.wrap_angle(float(params[3]) - np.pi / 4)) < 1e-3 + + +def test_place_sampler_chain_between_two_references(): + """A two-InFront subgoal lands the held domino on the shared chain + point.""" + robot = Object("robot", _robot_type) + # Collinear chain along +y at pos_gap spacing: d1 -- (d2 held) -- d3. + gap = 0.098 + d1, f1 = _domino("domino_1", x=0.8, y=1.30, yaw=0.0) + d3, f3 = _domino("domino_3", x=0.8, y=1.30 + 2 * gap, yaw=0.0) + d2, f2 = _domino("domino_2", x=0.5, y=1.5, yaw=0.0, is_held=1.0) + state = _make_state([(d1, f1), (d2, f2), (d3, f3)]) + state.data[robot] = np.array([0.0], dtype=np.float32) + + subgoal = { + GroundAtom(_InFront, [d2, d1]), + GroundAtom(_InFront, [d3, d2]), + GroundAtom(_Upright, [d2]), + } + params = _place_option_sampler(state, subgoal, np.random.default_rng(0), + [robot]) + placed = state.copy() + placed.set(d2, "x", float(params[0])) + placed.set(d2, "y", float(params[1])) + placed.set(d2, "yaw", float(params[3])) + placed.set(d2, "roll", 0.0) + placed.set(d2, "is_held", 0.0) + # Both InFront atoms satisfied at once (the shared midpoint). 
+ assert GroundAtom(_InFront, [d2, d1]).holds(placed) + assert GroundAtom(_InFront, [d3, d2]).holds(placed) + + +def test_place_sampler_raises_without_held_domino(): + """No held domino => raise so refinement falls back to uniform.""" + robot = Object("robot", _robot_type) + d0, f0 = _domino("domino_0", x=0.8, y=1.3, yaw=0.0) + state = _make_state([(d0, f0)]) + state.data[robot] = np.array([0.0], dtype=np.float32) + subgoal = {GroundAtom(_Upright, [d0])} + try: + _place_option_sampler(state, subgoal, np.random.default_rng(0), + [robot]) + assert False, "expected ValueError" + except ValueError: + pass diff --git a/tests/test_agent_sdk_tools.py b/tests/test_agent_sdk_tools.py index 0b17bcb3e..1657e9a78 100644 --- a/tests/test_agent_sdk_tools.py +++ b/tests/test_agent_sdk_tools.py @@ -2,9 +2,9 @@ Validates: 1. inspect_options with option_name saves source code to sandbox -2. test_option_plan always saves scene images -3. test_option_plan shows "Missing goal atoms" when goal not achieved -4. test_option_plan shows object poses on failure +2. evaluate_option_plan always saves scene images +3. evaluate_option_plan shows "Missing goal atoms" when goal not achieved +4. evaluate_option_plan shows object poses on failure 5. propose_options saves code to sandbox/proposed_code/ 6. _format_object_poses helper 7. 
_render_scene_image helper @@ -37,7 +37,6 @@ "domino_use_continuous_place": True, "domino_use_skill_factories": True, "domino_use_domino_blocks_as_target": True, - "domino_use_grid": True, "domino_has_glued_dominos": False, "domino_initialize_at_finished_state": False, "num_train_tasks": 1, @@ -263,14 +262,15 @@ def _get_valid_option_plan_step(ctx: Any) -> dict[str, Any] | None: def test_option_plan_missing_goal_atoms(ctx: Any) -> None: - """test_option_plan reports missing goal atoms when goal not achieved.""" - tools = _make_tools(ctx, ["test_option_plan"]) + """evaluate_option_plan reports missing goal atoms when goal not + achieved.""" + tools = _make_tools(ctx, ["evaluate_option_plan"]) step = _get_valid_option_plan_step(ctx) assert step is not None, "No valid option found for testing" plan = [step] - result = _run(tools["test_option_plan"]({ + result = _run(tools["evaluate_option_plan"]({ "option_plan": plan, "include_atoms": True, })) @@ -284,21 +284,21 @@ def test_option_plan_missing_goal_atoms(ctx: Any) -> None: # agents). assert ("Missing goal atoms:" in text or "Goal (natural language):" in text) - print(" PASS: test_option_plan (failure diagnostic shown)") + print(" PASS: evaluate_option_plan (failure diagnostic shown)") elif "Goal achieved: True" in text: assert "Missing goal atoms:" not in text - print(" PASS: test_option_plan (goal achieved, no missing atoms)") + print(" PASS: evaluate_option_plan (goal achieved, no missing atoms)") else: # Plan failed early (grounding error, NOT INITIABLE, etc.) 
assert ("NOT INITIABLE" in text or "FAILURE REASON:" in text or "EXECUTION ERROR" in text or "Failed to ground" in text) - print(" PASS: test_option_plan (plan failed early, " + print(" PASS: evaluate_option_plan (plan failed early, " "goal check not reached)") def test_option_plan_not_initiable_shows_poses(ctx: Any) -> None: - """test_option_plan shows object poses when option is NOT INITIABLE.""" - tools = _make_tools(ctx, ["test_option_plan"]) + """evaluate_option_plan shows object poses when option is NOT INITIABLE.""" + tools = _make_tools(ctx, ["evaluate_option_plan"]) # Find Place option and try it without Pick first place_opt = None @@ -308,7 +308,7 @@ def test_option_plan_not_initiable_shows_poses(ctx: Any) -> None: break if place_opt is None: - print(" SKIP: test_option_plan (no Place option)") + print(" SKIP: evaluate_option_plan (no Place option)") return # Build object names from types @@ -330,33 +330,34 @@ def test_option_plan_not_initiable_shows_poses(ctx: Any) -> None: "params": params, }] - result = _run(tools["test_option_plan"]({ + result = _run(tools["evaluate_option_plan"]({ "option_plan": plan, })) text = result["content"][0]["text"] if "NOT INITIABLE" in text: assert "Object poses at failure:" in text - print(" PASS: test_option_plan (NOT INITIABLE shows poses)") + print(" PASS: evaluate_option_plan (NOT INITIABLE shows poses)") elif "Failed to ground" in text: - print(" SKIP: test_option_plan (Place could not be grounded)") + print(" SKIP: evaluate_option_plan (Place could not be grounded)") else: - print(" SKIP: test_option_plan (Place was initiable, " + print(" SKIP: evaluate_option_plan (Place was initiable, " "can't test NOT INITIABLE path)") def test_option_plan_saves_images(ctx: Any) -> None: - """test_option_plan always saves scene images (never returns inline).""" + """evaluate_option_plan always saves scene images (never returns + inline).""" with tempfile.TemporaryDirectory() as tmpdir: ctx.image_save_dir = tmpdir - tools = 
_make_tools(ctx, ["test_option_plan"]) + tools = _make_tools(ctx, ["evaluate_option_plan"]) step = _get_valid_option_plan_step(ctx) assert step is not None, "No valid option found for testing" plan = [step] - result = _run(tools["test_option_plan"]({ + result = _run(tools["evaluate_option_plan"]({ "option_plan": plan, })) @@ -368,22 +369,23 @@ def test_option_plan_saves_images(ctx: Any) -> None: # Check files were saved if env rendering works saved = [f for f in os.listdir(tmpdir) if f.endswith(".png")] if saved: - print(f" PASS: test_option_plan ({len(saved)} images saved)") + print(f" PASS: evaluate_option_plan ({len(saved)} images saved)") else: - print(" SKIP: test_option_plan (rendering not available)") + print(" SKIP: evaluate_option_plan (rendering not available)") ctx.image_save_dir = None def test_option_plan_failure_shows_poses(ctx: Any) -> None: - """test_option_plan shows object poses when option returns 0 actions.""" - tools = _make_tools(ctx, ["test_option_plan"]) + """evaluate_option_plan shows object poses when option returns 0 + actions.""" + tools = _make_tools(ctx, ["evaluate_option_plan"]) step = _get_valid_option_plan_step(ctx) assert step is not None, "No valid option found for testing" plan = [step] - result = _run(tools["test_option_plan"]({ + result = _run(tools["evaluate_option_plan"]({ "option_plan": plan, })) text = result["content"][0]["text"] @@ -394,12 +396,12 @@ def test_option_plan_failure_shows_poses(ctx: Any) -> None: or "Testing option plan" in text) if "FAILURE REASON:" in text: assert "Object poses at failure:" in text - print(" PASS: test_option_plan (failure shows poses)") + print(" PASS: evaluate_option_plan (failure shows poses)") elif "NOT INITIABLE" in text: assert "Object poses at failure:" in text - print(" PASS: test_option_plan (NOT INITIABLE shows poses)") + print(" PASS: evaluate_option_plan (NOT INITIABLE shows poses)") else: - print(" PASS: test_option_plan (no failures in output)") + print(" PASS: 
evaluate_option_plan (no failures in output)") def test_format_object_poses(ctx: Any) -> None: @@ -684,8 +686,8 @@ def main() -> None: test_inspect_options_unknown(ctx) test_inspect_options_proposed_code(ctx) - # test_option_plan tests - print("\n2. test_option_plan tests:") + # evaluate_option_plan tests + print("\n2. evaluate_option_plan tests:") test_option_plan_missing_goal_atoms(ctx) test_option_plan_not_initiable_shows_poses(ctx) test_option_plan_saves_images(ctx) diff --git a/tests/test_docker_option_plan.py b/tests/test_docker_option_plan.py index 488600647..d7b5fa2b9 100644 --- a/tests/test_docker_option_plan.py +++ b/tests/test_docker_option_plan.py @@ -1,4 +1,4 @@ -"""Test that test_option_plan produces correct results. +"""Test that evaluate_option_plan produces correct results. Validates that multi-step option plans (Pick→Place→Pick→Place→Push) produce non-zero actions at every step, both in-process and in a subprocess that @@ -38,7 +38,6 @@ "domino_use_continuous_place": True, "domino_use_skill_factories": True, "domino_use_domino_blocks_as_target": True, - "domino_use_grid": True, "domino_has_glued_dominos": False, "domino_initialize_at_finished_state": False, "num_train_tasks": 1, diff --git a/tests/test_skill_factories.py b/tests/test_skill_factories.py index 0e5991620..68d66637d 100644 --- a/tests/test_skill_factories.py +++ b/tests/test_skill_factories.py @@ -240,6 +240,7 @@ def dummy_target(_state, _objects, _params, _cfg): assert phase.action_type == PhaseAction.MOVE_TO_POSE assert phase.terminal_fn is None assert phase.use_motion_planning is False # default from CFG + assert not phase.allow_shallow_held_object_contacts def test_change_fingers_phase(self): """Test change fingers phase.""" @@ -276,6 +277,20 @@ def test_no_motion_planning_flag(self): ) assert phase.use_motion_planning is False + def test_move_to_phase_collision_metadata(self): + """Test move-to phase stores collision metadata.""" + + def dummy_pose(_state, _objects, _params, 
_cfg): + return 0.0, 0.0, 0.0, 0.0 + + phase = make_move_to_phase( + "Move", + dummy_pose, + allow_shallow_held_object_contacts=True, + ) + + assert phase.allow_shallow_held_object_contacts + # =========================================================================== # 3. PhaseSkill — structure and public-interface behaviour diff --git a/tests/test_skill_factories_integration.py b/tests/test_skill_factories_integration.py index 459795838..6cb0852c0 100644 --- a/tests/test_skill_factories_integration.py +++ b/tests/test_skill_factories_integration.py @@ -1124,7 +1124,6 @@ def test_pick_holds_domino_with_motion_planning(): "domino_use_skill_factories": True, "skill_phase_use_motion_planning": True, "pybullet_ik_validate": False, - "domino_use_grid": True, "domino_use_domino_blocks_as_target": True, "domino_restricted_push": True, "num_train_tasks": 1, @@ -1208,7 +1207,6 @@ def test_pick_holds_domino_without_motion_planning(): "domino_use_skill_factories": True, "skill_phase_use_motion_planning": False, "pybullet_ik_validate": False, - "domino_use_grid": True, "domino_use_domino_blocks_as_target": True, "domino_restricted_push": True, "num_train_tasks": 1, @@ -1297,7 +1295,6 @@ def test_domino_pick_place_no_collisions(): "pybullet_ik_validate": False, "domino_initialize_at_finished_state": False, "domino_use_domino_blocks_as_target": True, - "domino_use_grid": True, "domino_include_connected_predicate": False, "domino_use_continuous_place": True, "domino_restricted_push": True, @@ -1393,6 +1390,106 @@ def _check_moved(before, st, skip_names=()): f"Non-held dominoes moved during Place: {place_collisions}" +def test_domino_second_place_with_unvalidated_ik(): + """The seed-0 bridge placement for domino_2 should refine with + pybullet_ik_validate disabled. 
+ + This covers a failure mode where the fast one-shot IK solution + reaches the EE target but leaves the held domino colliding with the + table, so collision-aware BiRRT needs to retry the IK target with + validation before declaring Place infeasible. + """ + try: + from predicators.envs.pybullet_domino import PyBulletDominoEnv + except ImportError: + pytest.skip("pybullet_domino not available") + + from predicators.ground_truth_models.domino.processes import \ + _pick_option_sampler, _place_option_sampler + from predicators.option_model import _OracleOptionModel + from predicators.structs import GroundAtom + + utils.reset_config({ + "env": "pybullet_domino", + "use_gui": False, + "pybullet_control_mode": "position", + "pybullet_robot": "fetch", + "domino_use_skill_factories": True, + "skill_phase_use_motion_planning": True, + "option_model_terminate_on_repeat": False, + "pybullet_ik_validate": False, + "domino_initialize_at_finished_state": False, + "domino_use_domino_blocks_as_target": True, + "domino_include_connected_predicate": False, + "domino_use_continuous_place": True, + "domino_restricted_push": True, + "domino_prune_actions": False, + "domino_has_glued_dominos": False, + "pybullet_birrt_extend_num_interp": 20, + "pybullet_birrt_path_subsample_ratio": 2, + "num_train_tasks": 1, + "num_test_tasks": 1, + }) + + class _ExposedDominoEnv( # type: ignore[misc] + _ExposedEnvMixin, PyBulletDominoEnv): + pass + + env = _ExposedDominoEnv(use_gui=False) + options = env._options + model = _OracleOptionModel(set(options.values()), env.simulate) + state = env.get_test_tasks()[0].init + objs = {o.name: o for o in state} + preds = {p.name: p for p in env.predicates} + robot = objs["robot"] + d0 = objs["domino_0"] + d1 = objs["domino_1"] + d2 = objs["domino_2"] + d3 = objs["domino_3"] + + def _run_option(option, cur_state): + next_state, num_actions = model.get_next_state_and_num_actions( + cur_state, option) + assert num_actions > 0, model.last_execution_failure + 
return next_state + + pick1 = options["Pick"].ground([robot, d1], + _pick_option_sampler( + state, set(), np.random.default_rng(0), + [robot, d1])) + state = _run_option(pick1, state) + + subgoal1 = { + GroundAtom(preds["InFront"], [d1, d0]), + GroundAtom(preds["HandEmpty"], [robot]), + } + place1 = options["Place"].ground([robot], + _place_option_sampler( + state, subgoal1, + np.random.default_rng(0), [robot])) + state = _run_option(place1, state) + + pick2 = options["Pick"].ground([robot, d2], + _pick_option_sampler( + state, set(), np.random.default_rng(0), + [robot, d2])) + state = _run_option(pick2, state) + + subgoal2 = { + GroundAtom(preds["InFront"], [d3, d2]), + GroundAtom(preds["InFront"], [d2, d1]), + GroundAtom(preds["HandEmpty"], [robot]), + } + place2 = options["Place"].ground([robot], + _place_option_sampler( + state, subgoal2, + np.random.default_rng(0), [robot])) + state = _run_option(place2, state) + + assert GroundAtom(preds["HandEmpty"], [robot]).holds(state) + assert state.get(d2, "is_held") < 0.5 + + @pytest.mark.xfail(reason="Button detection zone overlaps dispense area " "approach path — robot arm triggers button during place") def test_coffee_place_no_button_press(): @@ -1506,7 +1603,6 @@ def test_human_option_control_scripted_domino_solves_task(): "domino_use_skill_factories": True, "domino_initialize_at_finished_state": False, "domino_use_domino_blocks_as_target": True, - "domino_use_grid": True, "domino_include_connected_predicate": False, "domino_use_continuous_place": True, "domino_restricted_push": True, diff --git a/tests/test_utils.py b/tests/test_utils.py index 700c656e2..572be7a74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3574,3 +3574,29 @@ def test_parse_model_output_into_option_plan(): utils.parse_model_output_into_option_plan(options_str, [obj], [obj_type], options, False)) == 0 + # A numbered/enumerated line prefix ("0:", "1.") that agents emit when + # mirroring the logged sketch format must parse 
identically to the + # bare line; without prefix stripping the whole plan parses as empty. + pick_opt = next(o for o in options if o.name == "Pick") + robot_type, block_type = pick_opt.types + robby = Object("robby", robot_type) + b0 = Object("b0", block_type) + types = [robot_type, block_type] + objs = [robby, b0] + bare = "Pick(robby:robot, b0:block)" + bare_plan = utils.parse_model_output_into_option_plan( + bare, objs, types, options, False) + assert len(bare_plan) == 1 + for prefix in ("0: ", "1. ", "2) ", " 3: "): + numbered = prefix + bare + numbered_plan = utils.parse_model_output_into_option_plan( + numbered, objs, types, options, False) + assert len(numbered_plan) == 1 + assert numbered_plan[0][0].name == bare_plan[0][0].name + assert numbered_plan[0][1] == bare_plan[0][1] + # A prose bullet that merely mentions an option name is NOT a numbered + # plan line and must still be ignored (it is not stripped to an option). + prose = "- Step 1: Pick(robby:robot, b0:block) at the left side" + assert len( + utils.parse_model_output_into_option_plan(prose, objs, types, options, + False)) == 0