35 changes: 34 additions & 1 deletion packages/optimization/src/ldai_optimizer/client.py
@@ -157,6 +157,7 @@ def __init__(self, ldClient: LDAIClient) -> None:
self._last_succeeded_context: Optional[OptimizationContext] = None
self._last_optimization_result_id: Optional[str] = None
self._initial_tool_keys: List[str] = []
self._initial_model_custom: Optional[Dict[str, Any]] = None
self._total_token_usage: int = 0

if os.environ.get("LAUNCHDARKLY_API_KEY"):
@@ -861,6 +862,11 @@ async def _get_agent_config(
if isinstance(t, dict) and "key" in t
]

raw_model = raw_variation.get("model")
self._initial_model_custom = (
raw_model.get("custom") if isinstance(raw_model, dict) else None
)

agent_config = dataclasses.replace(
agent_config, instructions=raw_instructions
)
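
Aside: the extraction above only keeps `custom` when the variation's model is a dict; anything else, including a bare model-name string, yields None. A minimal standalone sketch of that behavior, using an illustrative raw_variation dict rather than a real LaunchDarkly response:

raw_variation = {"model": {"modelName": "gpt-4o", "custom": {"myApp": {"debug": True}}}}
raw_model = raw_variation.get("model")
# Same guard as in _get_agent_config: non-dict models fall through to None.
custom = raw_model.get("custom") if isinstance(raw_model, dict) else None
assert custom == {"myApp": {"debug": True}}

raw_model = "gpt-4o"  # bare string model: the isinstance guard short-circuits
assert (raw_model.get("custom") if isinstance(raw_model, dict) else None) is None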
@@ -1231,7 +1237,32 @@ def _apply_new_variation_response(
for msg in placeholder_warnings:
logger.warning("[Iteration %d] -> %s", iteration, msg)

self._current_parameters = response_data["current_parameters"]
# Merge the LLM's returned parameters into the existing ones so that custom
# parameters (e.g. response_format, max_tokens, structured-output config)
# are preserved even when the LLM omits them from its response.
original_params = self._current_parameters.copy()
new_params = response_data["current_parameters"]
merged_params = {**original_params, **new_params}

# Tools must be returned "unchanged" per the variation prompt. Always restore
# the original tools so that (a) user-defined tools are never silently dropped
# and (b) internal framework tools (e.g. structured-output tool injected by
# the agent SDK) cannot leak in from the LLM's response.
original_tools = original_params.get("tools")
if original_tools is not None:
returned_tools = new_params.get("tools")
if returned_tools is not None and returned_tools != original_tools:
logger.warning(
"[Iteration %d] -> LLM returned a modified tools list; restoring "
"original tools to prevent tool drift or internal-tool leakage. "
"Original: %s Returned: %s",
iteration,
[t.get("name") if isinstance(t, dict) else getattr(t, "name", t) for t in original_tools],
[t.get("name") if isinstance(t, dict) else getattr(t, "name", t) for t in returned_tools],
)
merged_params["tools"] = original_tools

self._current_parameters = merged_params

# Update model — it should always be provided since it's required in the schema
model_value = (
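
The merge above relies on plain dict-spread semantics: keys the LLM omits survive from the previous parameters, keys it returns win, and tools are pinned back to the original list whenever the config defined any. A minimal sketch with illustrative values, not the client's actual state:

original_params = {"temperature": 0.7, "max_tokens": 512, "tools": [{"name": "user-lookup"}]}
new_params = {"temperature": 0.3, "tools": []}  # LLM dropped max_tokens and the tools

merged_params = {**original_params, **new_params}
assert merged_params["max_tokens"] == 512   # omitted by the LLM, preserved
assert merged_params["temperature"] == 0.3  # returned by the LLM, overridden

# Tools restoration: the originals always win when the config had any.
if original_params.get("tools") is not None:
    merged_params["tools"] = original_params["tools"]
assert merged_params["tools"] == [{"name": "user-lookup"}]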
@@ -2017,6 +2048,8 @@ def _commit_variation(
}
if self._initial_tool_keys:
payload["toolKeys"] = list(self._initial_tool_keys)
if self._initial_model_custom:
payload["model"] = {"custom": self._initial_model_custom}

last_exc: Optional[Exception] = None
for attempt in range(1, 4):
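
Given the two lines added above, the commit payload gains a "model" key only when custom metadata was captured at startup; None and {} are both falsy and leave the payload untouched. A quick sketch with a hypothetical captured value:

payload = {"toolKeys": ["user-lookup"]}
initial_model_custom = {"myApp": {"debug": True}}  # hypothetical value captured in __init__
if initial_model_custom:  # None or {} would skip this branch entirely
    payload["model"] = {"custom": initial_model_custom}
assert payload == {"toolKeys": ["user-lookup"], "model": {"custom": {"myApp": {"debug": True}}}}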
228 changes: 228 additions & 0 deletions packages/optimization/tests/test_client.py
@@ -948,6 +948,162 @@ async def test_raises_after_max_retries_exhausted(self):
assert self.handle_agent_call.call_count == 3


# ---------------------------------------------------------------------------
# Parameter persistence across variation generation
# ---------------------------------------------------------------------------


class TestParameterPersistence:
"""Ensure custom parameters are preserved when the LLM generates a new variation."""

def setup_method(self):
self.client = _make_client()
agent_config = _make_agent_config()
self.client._agent_key = "test-agent"
self.client._agent_config = agent_config
self.client._initial_instructions = AGENT_INSTRUCTIONS
self.client._initialize_class_members_from_config(agent_config)

def _set_params(self, params: Dict[str, Any]) -> None:
self.client._current_parameters = params

def _run_variation(self, returned_params: Dict[str, Any]) -> None:
"""Helper: simulate _apply_new_variation_response with a given returned params dict."""
variation_ctx = OptimizationContext(
scores={},
completion_response="",
current_instructions=AGENT_INSTRUCTIONS,
current_parameters={"temperature": 0.1},
current_variables={},
current_model="gpt-4o",
user_input=None,
iteration=1,
)
response_data = {
"current_instructions": "Improved instructions.",
"current_parameters": returned_params,
"model": "gpt-4o",
}
self.client._options = _make_options()
self.client._apply_new_variation_response(response_data, variation_ctx, json.dumps(response_data), 1)

async def test_custom_param_preserved_when_llm_omits_it(self):
"""Parameters not in LLM response should be preserved from the original config."""
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7, "max_tokens": 512, "seed": 42}
self._run_variation({"temperature": 0.5})
assert self.client._current_parameters["max_tokens"] == 512
assert self.client._current_parameters["seed"] == 42
assert self.client._current_parameters["temperature"] == 0.5

async def test_response_format_preserved_when_llm_omits_it(self):
"""response_format (structured output config) is preserved even if LLM returns only temperature."""
self.client._options = _make_options()
self.client._current_parameters = {
"temperature": 0.7,
"response_format": {"type": "json_schema", "json_schema": {"name": "output"}},
}
self._run_variation({"temperature": 0.5})
assert self.client._current_parameters["response_format"] == {
"type": "json_schema",
"json_schema": {"name": "output"},
}

async def test_empty_returned_params_preserves_all_original_params(self):
"""If LLM returns {}, all original parameters survive."""
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7, "max_tokens": 256}
self._run_variation({})
assert self.client._current_parameters["temperature"] == 0.7
assert self.client._current_parameters["max_tokens"] == 256

async def test_llm_explicit_param_override_is_applied(self):
"""If the LLM explicitly returns a parameter, the new value is used."""
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7, "max_tokens": 256}
self._run_variation({"temperature": 0.3, "max_tokens": 128})
assert self.client._current_parameters["temperature"] == 0.3
assert self.client._current_parameters["max_tokens"] == 128

async def test_original_tools_always_restored(self):
"""Tools from the original config are always restored regardless of LLM response."""
original_tool = {"name": "my-tool", "type": "function", "description": "desc", "parameters": {}}
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
self._run_variation({"temperature": 0.5, "tools": []})
assert self.client._current_parameters["tools"] == [original_tool]

async def test_internal_tool_leakage_is_blocked(self):
"""If LLM returns tools including an internal framework tool, original tools are restored."""
original_tool = {"name": "user-lookup", "type": "function", "description": "Looks up users", "parameters": {}}
internal_tool = {"name": "FinalAnswer", "type": "function", "description": "internal", "parameters": {}}
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
self._run_variation({"temperature": 0.5, "tools": [original_tool, internal_tool]})
result_tools = self.client._current_parameters["tools"]
assert result_tools == [original_tool]
assert not any(t.get("name") == "FinalAnswer" for t in result_tools)

async def test_internal_tool_leakage_logs_warning(self):
"""Tool mismatch should emit a warning."""
original_tool = {"name": "my-tool", "type": "function", "description": "d", "parameters": {}}
internal_tool = {"name": "structured_output_tool", "type": "function", "description": "internal", "parameters": {}}
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
with patch("ldai_optimizer.client.logger") as mock_logger:
self._run_variation({"temperature": 0.5, "tools": [internal_tool]})
warning_calls = [c for c in mock_logger.warning.call_args_list if "tool" in str(c).lower()]
assert len(warning_calls) >= 1

async def test_no_original_tools_allows_llm_returned_tools(self):
"""When the original config had no tools, the LLM is free to return tools."""
new_tool = {"name": "new-tool", "type": "function", "description": "desc", "parameters": {}}
self.client._options = _make_options()
self.client._current_parameters = {"temperature": 0.7}
self._run_variation({"temperature": 0.5, "tools": [new_tool]})
assert self.client._current_parameters.get("tools") == [new_tool]

async def test_params_preserved_across_full_optimization_loop(self):
"""End-to-end: custom params survive through a full failed-then-succeeded optimization."""
custom_params_response = json.dumps({
"current_instructions": "Improved.",
"current_parameters": {"temperature": 0.3}, # omits max_tokens and response_format
"model": "gpt-4o",
})
agent_config_with_params = _make_agent_config(
parameters={"temperature": 0.7, "max_tokens": 512, "response_format": {"type": "json_object"}},
)
mock_ldai = _make_ldai_client(agent_config=agent_config_with_params)
mock_ldai._client.variation.return_value = {
"instructions": AGENT_INSTRUCTIONS,
}
agent_responses = [
OptimizationResponse(output="Bad answer."), # iteration 1: agent
OptimizationResponse(output=custom_params_response), # iteration 1: variation
OptimizationResponse(output="Good answer."), # iteration 2: agent
OptimizationResponse(output="Good answer."), # iteration 2: validation
]
handle_agent_call = AsyncMock(side_effect=agent_responses)
judge_responses = [
OptimizationResponse(output=JUDGE_FAIL_RESPONSE),
OptimizationResponse(output=JUDGE_PASS_RESPONSE),
OptimizationResponse(output=JUDGE_PASS_RESPONSE),
]
handle_judge_call = AsyncMock(side_effect=judge_responses)
client = _make_client(mock_ldai)
options = _make_options(
handle_agent_call=handle_agent_call,
handle_judge_call=handle_judge_call,
max_attempts=3,
)
result = await client.optimize_from_options("test-agent", options)
assert result.scores["accuracy"].score == 1.0
# After variation, max_tokens and response_format should still be present
assert client._current_parameters.get("max_tokens") == 512
assert client._current_parameters.get("response_format") == {"type": "json_object"}
assert client._current_parameters.get("temperature") == 0.3 # LLM's update applied


# ---------------------------------------------------------------------------
# Full optimization loop
# ---------------------------------------------------------------------------
@@ -4048,6 +4204,48 @@ def test_toolkeys_not_in_payload_when_no_tools(self):
payload = api_client.create_ai_config_variation.call_args[0][2]
assert "toolKeys" not in payload

# --- model.custom propagation ---

def test_model_custom_included_in_payload_when_set(self):
client = self._make_client()
client._initial_model_custom = {"myApp": {"debug": True, "region": "us-east-1"}}
api_client = _make_api_client_for_commit()

client._commit_variation(
_make_winning_context(), project_key="my-project",
ai_config_key="my-agent", output_key="k", api_client=api_client,
)

payload = api_client.create_ai_config_variation.call_args[0][2]
assert payload["model"] == {"custom": {"myApp": {"debug": True, "region": "us-east-1"}}}

def test_model_not_in_payload_when_model_custom_is_none(self):
client = self._make_client()
client._initial_model_custom = None
api_client = _make_api_client_for_commit()

client._commit_variation(
_make_winning_context(), project_key="my-project",
ai_config_key="my-agent", output_key="k", api_client=api_client,
)

payload = api_client.create_ai_config_variation.call_args[0][2]
assert "model" not in payload

def test_model_not_in_payload_when_model_custom_is_empty_dict(self):
"""An empty custom dict is falsy — treated the same as absent."""
client = self._make_client()
client._initial_model_custom = {}
api_client = _make_api_client_for_commit()

client._commit_variation(
_make_winning_context(), project_key="my-project",
ai_config_key="my-agent", output_key="k", api_client=api_client,
)

payload = api_client.create_ai_config_variation.call_args[0][2]
assert "model" not in payload


# ---------------------------------------------------------------------------
# Tool key extraction from raw variation (_get_agent_config)
@@ -4097,6 +4295,36 @@ async def test_skips_tool_entries_without_key(self):
await client._get_agent_config("test-agent", LD_CONTEXT)
assert client._initial_tool_keys == ["good-tool"]

async def test_extracts_model_custom_from_raw_variation(self):
raw = {
"instructions": AGENT_INSTRUCTIONS,
"model": {"modelName": "gpt-4o", "custom": {"myApp": {"debug": True}}},
}
client = self._make_client_with_variation(raw)
await client._get_agent_config("test-agent", LD_CONTEXT)
assert client._initial_model_custom == {"myApp": {"debug": True}}

async def test_model_custom_is_none_when_variation_has_no_model(self):
raw = {"instructions": AGENT_INSTRUCTIONS}
client = self._make_client_with_variation(raw)
await client._get_agent_config("test-agent", LD_CONTEXT)
assert client._initial_model_custom is None

async def test_model_custom_is_none_when_model_has_no_custom_key(self):
raw = {
"instructions": AGENT_INSTRUCTIONS,
"model": {"modelName": "gpt-4o", "parameters": {"temperature": 0.7}},
}
client = self._make_client_with_variation(raw)
await client._get_agent_config("test-agent", LD_CONTEXT)
assert client._initial_model_custom is None

async def test_model_custom_is_none_when_model_is_not_a_dict(self):
raw = {"instructions": AGENT_INSTRUCTIONS, "model": "gpt-4o"}
client = self._make_client_with_variation(raw)
await client._get_agent_config("test-agent", LD_CONTEXT)
assert client._initial_model_custom is None


# ---------------------------------------------------------------------------
# auto_commit in optimize_from_options