diff --git a/openevolve/config.py b/openevolve/config.py index bef193da21..3885779b11 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -55,6 +55,7 @@ class LLMModelConfig: api_base: str = None api_key: Optional[str] = None name: str = None + default_headers: Optional[Dict[str, str]] = None # Custom LLM client init_client: Optional[Callable] = None @@ -84,8 +85,14 @@ class LLMModelConfig: _manual_queue_dir: Optional[str] = None def __post_init__(self): - """Post-initialization to resolve ${VAR} env var references in api_key""" + """Post-initialization to resolve ${VAR} env var references in api_key and headers""" self.api_key = _resolve_env_var(self.api_key) + # Resolve environment variables in default_headers values + if self.default_headers: + resolved_headers = {} + for key, value in self.default_headers.items(): + resolved_headers[key] = _resolve_env_var(value) if isinstance(value, str) else value + self.default_headers = resolved_headers @dataclass @@ -170,6 +177,7 @@ def __post_init__(self): shared_config = { "api_base": self.api_base, "api_key": self.api_key, + "default_headers": self.default_headers, "temperature": self.temperature, "top_p": self.top_p, "max_tokens": self.max_tokens, diff --git a/openevolve/llm/openai.py b/openevolve/llm/openai.py index 7477e5b349..7a9cda22b5 100644 --- a/openevolve/llm/openai.py +++ b/openevolve/llm/openai.py @@ -61,6 +61,7 @@ def __init__( self.retry_delay = model_cfg.retry_delay self.api_base = model_cfg.api_base self.api_key = model_cfg.api_key + self.default_headers = getattr(model_cfg, "default_headers", None) self.random_seed = getattr(model_cfg, "random_seed", None) self.reasoning_effort = getattr(model_cfg, "reasoning_effort", None) @@ -82,12 +83,15 @@ def __init__( # Set up API client (normal mode) # OpenAI client requires max_retries to be int, not None max_retries = self.retries if self.retries is not None else 0 - self.client = openai.OpenAI( - api_key=self.api_key, - base_url=self.api_base, - timeout=self.timeout, - max_retries=max_retries, - ) + client_kwargs = { + "api_key": self.api_key, + "base_url": self.api_base, + "timeout": self.timeout, + "max_retries": max_retries, + } + if self.default_headers: + client_kwargs["default_headers"] = self.default_headers + self.client = openai.OpenAI(**client_kwargs) # Only log unique models to reduce duplication if not hasattr(logger, "_initialized_models"):