From 572a2aa35a39e3eb904005e32e28214855f5883a Mon Sep 17 00:00:00 2001
From: Andrew Klatzke
Date: Mon, 4 May 2026 15:40:13 -0800
Subject: [PATCH] fix: better handling of params and custom params for optimization

---
 .../optimization/src/ldai_optimizer/client.py |  35 ++-
 packages/optimization/tests/test_client.py    | 228 ++++++++++++++++++
 2 files changed, 262 insertions(+), 1 deletion(-)

diff --git a/packages/optimization/src/ldai_optimizer/client.py b/packages/optimization/src/ldai_optimizer/client.py
index c7927d4c..00353e3c 100644
--- a/packages/optimization/src/ldai_optimizer/client.py
+++ b/packages/optimization/src/ldai_optimizer/client.py
@@ -157,6 +157,7 @@ def __init__(self, ldClient: LDAIClient) -> None:
         self._last_succeeded_context: Optional[OptimizationContext] = None
         self._last_optimization_result_id: Optional[str] = None
         self._initial_tool_keys: List[str] = []
+        self._initial_model_custom: Optional[Dict[str, Any]] = None
         self._total_token_usage: int = 0
 
         if os.environ.get("LAUNCHDARKLY_API_KEY"):
@@ -861,6 +862,11 @@ async def _get_agent_config(
                 if isinstance(t, dict) and "key" in t
             ]
 
+            raw_model = raw_variation.get("model")
+            self._initial_model_custom = (
+                raw_model.get("custom") if isinstance(raw_model, dict) else None
+            )
+
             agent_config = dataclasses.replace(
                 agent_config, instructions=raw_instructions
             )
@@ -1231,7 +1237,32 @@ def _apply_new_variation_response(
         for msg in placeholder_warnings:
             logger.warning("[Iteration %d] -> %s", iteration, msg)
 
-        self._current_parameters = response_data["current_parameters"]
+        # Merge the LLM's returned parameters into the existing ones so that custom
+        # parameters (e.g. response_format, max_tokens, structured-output config)
+        # are preserved even when the LLM omits them from its response.
+        original_params = self._current_parameters.copy()
+        new_params = response_data["current_parameters"]
+        merged_params = {**original_params, **new_params}
+
+        # Tools must be returned "unchanged" per the variation prompt. Always restore
+        # the original tools so that (a) user-defined tools are never silently dropped
+        # and (b) internal framework tools (e.g. structured-output tool injected by
+        # the agent SDK) cannot leak in from the LLM's response.
+        original_tools = original_params.get("tools")
+        if original_tools is not None:
+            returned_tools = new_params.get("tools")
+            if returned_tools is not None and returned_tools != original_tools:
+                logger.warning(
+                    "[Iteration %d] -> LLM returned a modified tools list; restoring "
+                    "original tools to prevent tool drift or internal-tool leakage. "
" + "Original: %s Returned: %s", + iteration, + [t.get("name") if isinstance(t, dict) else getattr(t, "name", t) for t in original_tools], + [t.get("name") if isinstance(t, dict) else getattr(t, "name", t) for t in returned_tools], + ) + merged_params["tools"] = original_tools + + self._current_parameters = merged_params # Update model — it should always be provided since it's required in the schema model_value = ( @@ -2017,6 +2048,8 @@ def _commit_variation( } if self._initial_tool_keys: payload["toolKeys"] = list(self._initial_tool_keys) + if self._initial_model_custom: + payload["model"] = {"custom": self._initial_model_custom} last_exc: Optional[Exception] = None for attempt in range(1, 4): diff --git a/packages/optimization/tests/test_client.py b/packages/optimization/tests/test_client.py index c441eedc..d4641084 100644 --- a/packages/optimization/tests/test_client.py +++ b/packages/optimization/tests/test_client.py @@ -948,6 +948,162 @@ async def test_raises_after_max_retries_exhausted(self): assert self.handle_agent_call.call_count == 3 +# --------------------------------------------------------------------------- +# Parameter persistence across variation generation +# --------------------------------------------------------------------------- + + +class TestParameterPersistence: + """Ensure custom parameters are preserved when the LLM generates a new variation.""" + + def setup_method(self): + self.client = _make_client() + agent_config = _make_agent_config() + self.client._agent_key = "test-agent" + self.client._agent_config = agent_config + self.client._initial_instructions = AGENT_INSTRUCTIONS + self.client._initialize_class_members_from_config(agent_config) + + def _set_params(self, params: Dict[str, Any]) -> None: + self.client._current_parameters = params + + def _run_variation(self, returned_params: Dict[str, Any]) -> None: + """Helper: simulate _apply_new_variation_response with a given returned params dict.""" + variation_ctx = OptimizationContext( + scores={}, + completion_response="", + current_instructions=AGENT_INSTRUCTIONS, + current_parameters={"temperature": 0.1}, + current_variables={}, + current_model="gpt-4o", + user_input=None, + iteration=1, + ) + response_data = { + "current_instructions": "Improved instructions.", + "current_parameters": returned_params, + "model": "gpt-4o", + } + self.client._options = _make_options() + self.client._apply_new_variation_response(response_data, variation_ctx, json.dumps(response_data), 1) + + async def test_custom_param_preserved_when_llm_omits_it(self): + """Parameters not in LLM response should be preserved from the original config.""" + self.client._options = _make_options() + self.client._current_parameters = {"temperature": 0.7, "max_tokens": 512, "seed": 42} + self._run_variation({"temperature": 0.5}) + assert self.client._current_parameters["max_tokens"] == 512 + assert self.client._current_parameters["seed"] == 42 + assert self.client._current_parameters["temperature"] == 0.5 + + async def test_response_format_preserved_when_llm_omits_it(self): + """response_format (structured output config) is preserved even if LLM returns only temperature.""" + self.client._options = _make_options() + self.client._current_parameters = { + "temperature": 0.7, + "response_format": {"type": "json_schema", "json_schema": {"name": "output"}}, + } + self._run_variation({"temperature": 0.5}) + assert self.client._current_parameters["response_format"] == { + "type": "json_schema", + "json_schema": {"name": "output"}, + } + + async def 
+        """If LLM returns {}, all original parameters survive."""
+        self.client._options = _make_options()
+        self.client._current_parameters = {"temperature": 0.7, "max_tokens": 256}
+        self._run_variation({})
+        assert self.client._current_parameters["temperature"] == 0.7
+        assert self.client._current_parameters["max_tokens"] == 256
+
+    async def test_llm_explicit_param_override_is_applied(self):
+        """If the LLM explicitly returns a parameter, the new value is used."""
+        self.client._options = _make_options()
+        self.client._current_parameters = {"temperature": 0.7, "max_tokens": 256}
+        self._run_variation({"temperature": 0.3, "max_tokens": 128})
+        assert self.client._current_parameters["temperature"] == 0.3
+        assert self.client._current_parameters["max_tokens"] == 128
+
+    async def test_original_tools_always_restored(self):
+        """Tools from the original config are always restored regardless of LLM response."""
+        original_tool = {"name": "my-tool", "type": "function", "description": "desc", "parameters": {}}
+        self.client._options = _make_options()
+        self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
+        self._run_variation({"temperature": 0.5, "tools": []})
+        assert self.client._current_parameters["tools"] == [original_tool]
+
+    async def test_internal_tool_leakage_is_blocked(self):
+        """If LLM returns tools including an internal framework tool, original tools are restored."""
+        original_tool = {"name": "user-lookup", "type": "function", "description": "Looks up users", "parameters": {}}
+        internal_tool = {"name": "FinalAnswer", "type": "function", "description": "internal", "parameters": {}}
+        self.client._options = _make_options()
+        self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
+        self._run_variation({"temperature": 0.5, "tools": [original_tool, internal_tool]})
+        result_tools = self.client._current_parameters["tools"]
+        assert result_tools == [original_tool]
+        assert not any(t.get("name") == "FinalAnswer" for t in result_tools)
+
+    async def test_internal_tool_leakage_logs_warning(self):
+        """Tool mismatch should emit a warning."""
+        original_tool = {"name": "my-tool", "type": "function", "description": "d", "parameters": {}}
+        internal_tool = {"name": "structured_output_tool", "type": "function", "description": "internal", "parameters": {}}
+        self.client._options = _make_options()
+        self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
+        with patch("ldai_optimizer.client.logger") as mock_logger:
+            self._run_variation({"temperature": 0.5, "tools": [internal_tool]})
+        warning_calls = [c for c in mock_logger.warning.call_args_list if "tool" in str(c).lower()]
+        assert len(warning_calls) >= 1
+
+    async def test_no_original_tools_allows_llm_returned_tools(self):
+        """When the original config had no tools, the LLM is free to return tools."""
+        new_tool = {"name": "new-tool", "type": "function", "description": "desc", "parameters": {}}
+        self.client._options = _make_options()
+        self.client._current_parameters = {"temperature": 0.7}
+        self._run_variation({"temperature": 0.5, "tools": [new_tool]})
+        assert self.client._current_parameters.get("tools") == [new_tool]
+
+    async def test_params_preserved_across_full_optimization_loop(self):
+        """End-to-end: custom params survive through a full failed-then-succeeded optimization."""
+        custom_params_response = json.dumps({
+            "current_instructions": "Improved.",
+            "current_parameters": {"temperature": 0.3},  # omits max_tokens and response_format
+            "model": "gpt-4o",
+        })
+        agent_config_with_params = _make_agent_config(
+            parameters={"temperature": 0.7, "max_tokens": 512, "response_format": {"type": "json_object"}},
+        )
+        mock_ldai = _make_ldai_client(agent_config=agent_config_with_params)
+        mock_ldai._client.variation.return_value = {
+            "instructions": AGENT_INSTRUCTIONS,
+        }
+        agent_responses = [
+            OptimizationResponse(output="Bad answer."),           # iteration 1: agent
+            OptimizationResponse(output=custom_params_response),  # iteration 1: variation
+            OptimizationResponse(output="Good answer."),          # iteration 2: agent
+            OptimizationResponse(output="Good answer."),          # iteration 2: validation
+        ]
+        handle_agent_call = AsyncMock(side_effect=agent_responses)
+        judge_responses = [
+            OptimizationResponse(output=JUDGE_FAIL_RESPONSE),
+            OptimizationResponse(output=JUDGE_PASS_RESPONSE),
+            OptimizationResponse(output=JUDGE_PASS_RESPONSE),
+        ]
+        handle_judge_call = AsyncMock(side_effect=judge_responses)
+        client = _make_client(mock_ldai)
+        options = _make_options(
+            handle_agent_call=handle_agent_call,
+            handle_judge_call=handle_judge_call,
+            max_attempts=3,
+        )
+        result = await client.optimize_from_options("test-agent", options)
+        assert result.scores["accuracy"].score == 1.0
+        # After variation, max_tokens and response_format should still be present
+        assert client._current_parameters.get("max_tokens") == 512
+        assert client._current_parameters.get("response_format") == {"type": "json_object"}
+        assert client._current_parameters.get("temperature") == 0.3  # LLM's update applied
+
+
 # ---------------------------------------------------------------------------
 # Full optimization loop
 # ---------------------------------------------------------------------------
@@ -4048,6 +4204,48 @@ def test_toolkeys_not_in_payload_when_no_tools(self):
         payload = api_client.create_ai_config_variation.call_args[0][2]
         assert "toolKeys" not in payload
 
+    # --- model.custom propagation ---
+
+    def test_model_custom_included_in_payload_when_set(self):
+        client = self._make_client()
+        client._initial_model_custom = {"myApp": {"debug": True, "region": "us-east-1"}}
+        api_client = _make_api_client_for_commit()
+
+        client._commit_variation(
+            _make_winning_context(), project_key="my-project",
+            ai_config_key="my-agent", output_key="k", api_client=api_client,
+        )
+
+        payload = api_client.create_ai_config_variation.call_args[0][2]
+        assert payload["model"] == {"custom": {"myApp": {"debug": True, "region": "us-east-1"}}}
+
+    def test_model_not_in_payload_when_model_custom_is_none(self):
+        client = self._make_client()
+        client._initial_model_custom = None
+        api_client = _make_api_client_for_commit()
+
+        client._commit_variation(
+            _make_winning_context(), project_key="my-project",
+            ai_config_key="my-agent", output_key="k", api_client=api_client,
+        )
+
+        payload = api_client.create_ai_config_variation.call_args[0][2]
+        assert "model" not in payload
+
+    def test_model_not_in_payload_when_model_custom_is_empty_dict(self):
+        """An empty custom dict is falsy — treated the same as absent."""
+        client = self._make_client()
+        client._initial_model_custom = {}
+        api_client = _make_api_client_for_commit()
+
+        client._commit_variation(
+            _make_winning_context(), project_key="my-project",
+            ai_config_key="my-agent", output_key="k", api_client=api_client,
+        )
+
+        payload = api_client.create_ai_config_variation.call_args[0][2]
+        assert "model" not in payload
+
 
 # ---------------------------------------------------------------------------
 # Tool key extraction from raw variation (_get_agent_config)
@@ -4097,6 +4295,36 @@ async def test_skips_tool_entries_without_key(self):
         await client._get_agent_config("test-agent", LD_CONTEXT)
         assert client._initial_tool_keys == ["good-tool"]
 
+    async def test_extracts_model_custom_from_raw_variation(self):
+        raw = {
+            "instructions": AGENT_INSTRUCTIONS,
+            "model": {"modelName": "gpt-4o", "custom": {"myApp": {"debug": True}}},
+        }
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_model_custom == {"myApp": {"debug": True}}
+
+    async def test_model_custom_is_none_when_variation_has_no_model(self):
+        raw = {"instructions": AGENT_INSTRUCTIONS}
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_model_custom is None
+
+    async def test_model_custom_is_none_when_model_has_no_custom_key(self):
+        raw = {
+            "instructions": AGENT_INSTRUCTIONS,
+            "model": {"modelName": "gpt-4o", "parameters": {"temperature": 0.7}},
+        }
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_model_custom is None
+
+    async def test_model_custom_is_none_when_model_is_not_a_dict(self):
+        raw = {"instructions": AGENT_INSTRUCTIONS, "model": "gpt-4o"}
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_model_custom is None
+
 
 # ---------------------------------------------------------------------------
 # auto_commit in optimize_from_options
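
Note for reviewers: the merge semantics introduced in `_apply_new_variation_response`
reduce to a dict-spread merge in which the LLM's explicit values win, followed by
pinning the tools list back to the original. Below is a minimal standalone sketch of
that behavior; the helper name `merge_variation_params` is invented here for
illustration and is not part of the SDK:

    from typing import Any, Dict

    def merge_variation_params(
        original: Dict[str, Any], returned: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Later keys win: the LLM's explicit overrides apply, while anything it
        # omitted (max_tokens, response_format, seed, ...) survives from `original`.
        merged = {**original, **returned}
        # Tools are pinned to the original list: the LLM may neither drop user
        # tools nor inject internal framework tools. If the original config had
        # no tools, any tools the LLM returns are kept.
        if original.get("tools") is not None:
            merged["tools"] = original["tools"]
        return merged

    # The LLM returns only a new temperature and a mutated tools list:
    original = {"temperature": 0.7, "max_tokens": 512, "tools": [{"name": "user-lookup"}]}
    returned = {"temperature": 0.3, "tools": []}
    assert merge_variation_params(original, returned) == {
        "temperature": 0.3,                  # LLM override applied
        "max_tokens": 512,                   # preserved although omitted
        "tools": [{"name": "user-lookup"}],  # original tools restored
    }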