diff --git a/openevolve/config.py b/openevolve/config.py
index bef193da21..c6bb2f8069 100644
--- a/openevolve/config.py
+++ b/openevolve/config.py
@@ -350,6 +350,7 @@ class DatabaseConfig:
 
     novelty_llm: Optional["LLMInterface"] = None
     embedding_model: Optional[str] = None
+    embedding_base_url: Optional[str] = None
    similarity_threshold: float = 0.99
 
 
diff --git a/openevolve/database.py b/openevolve/database.py
index eca5eab0bb..d6f85f2f6b 100644
--- a/openevolve/database.py
+++ b/openevolve/database.py
@@ -204,7 +204,7 @@ def __init__(self, config: DatabaseConfig):
 
         self.novelty_llm = config.novelty_llm
         self.embedding_client = (
-            EmbeddingClient(config.embedding_model) if config.embedding_model else None
+            EmbeddingClient(config.embedding_model, base_url=config.embedding_base_url) if config.embedding_model else None
         )
         self.similarity_threshold = config.similarity_threshold
 
diff --git a/openevolve/embedding.py b/openevolve/embedding.py
index 302d4513f6..7c3bd3484a 100644
--- a/openevolve/embedding.py
+++ b/openevolve/embedding.py
@@ -10,51 +10,40 @@
 
 logger = logging.getLogger(__name__)
 
-M = 1_000_000
-
-OPENAI_EMBEDDING_MODELS = [
-    "text-embedding-3-small",
-    "text-embedding-3-large",
-]
-
 AZURE_EMBEDDING_MODELS = [
     "azure-text-embedding-3-small",
     "azure-text-embedding-3-large",
 ]
 
-OPENAI_EMBEDDING_COSTS = {
-    "text-embedding-3-small": 0.02 / M,
-    "text-embedding-3-large": 0.13 / M,
-}
-
 
 class EmbeddingClient:
-    def __init__(self, model_name: str = "text-embedding-3-small"):
+    def __init__(self, model_name: str = "text-embedding-3-small", base_url: str | None = None):
         """
         Initialize the EmbeddingClient.
 
         Args:
-            model (str): The OpenAI embedding model name to use.
+            model_name: The embedding model name to use.
+            base_url: Optional base URL for the embedding API endpoint.
         """
-        self.client, self.model = self._get_client_model(model_name)
+        self.client, self.model = self._get_client_model(model_name, base_url)
 
-    def _get_client_model(self, model_name: str) -> tuple[openai.OpenAI, str]:
-        if model_name in OPENAI_EMBEDDING_MODELS:
-            # Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY
-            # This allows users to use OpenRouter for LLMs while using OpenAI for embeddings
-            embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
-            client = openai.OpenAI(api_key=embedding_api_key)
-            model_to_use = model_name
-        elif model_name in AZURE_EMBEDDING_MODELS:
+    def _get_client_model(
+        self, model_name: str, base_url: str | None = None
+    ) -> tuple[openai.OpenAI, str]:
+        if model_name in AZURE_EMBEDDING_MODELS:
             # get rid of the azure- prefix
             model_to_use = model_name.split("azure-")[-1]
             client = openai.AzureOpenAI(
                 api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                 api_version=os.getenv("AZURE_API_VERSION"),
-                azure_endpoint=os.getenv("AZURE_API_ENDPOINT"),
+                azure_endpoint=os.environ["AZURE_API_ENDPOINT"],
             )
         else:
-            raise ValueError(f"Invalid embedding model: {model_name}")
+            # Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY
+            # This allows users to use OpenRouter for LLMs while using OpenAI for embeddings
+            embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
+            client = openai.OpenAI(api_key=embedding_api_key, base_url=base_url)
+            model_to_use = model_name
 
         return client, model_to_use
 