diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index d69147d7..4c642220 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -6,13 +6,12 @@
 from functools import partial
 from typing import Optional
 
+import agentlab.llm.tracking as tracking
 import anthropic
 import openai
-from openai import NOT_GIVEN, OpenAI
-
-import agentlab.llm.tracking as tracking
 from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
 from agentlab.llm.llm_utils import AIMessage, Discussion
+from openai import NOT_GIVEN, OpenAI
 
 
 def make_system_message(content: str) -> dict:
@@ -90,6 +89,19 @@ def make_model(self):
         )
 
 
+@dataclass
+class LiteLLMModelArgs(BaseModelArgs):
+    """Serializable object for instantiating a chat model served through LiteLLM."""
+
+    def make_model(self):
+        return LiteLLMChatModel(
+            model_name=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
+        )
+
+
 @dataclass
 class OpenAIModelArgs(BaseModelArgs):
     """Serializable object for instantiating a generic chat model with an OpenAI
@@ -627,3 +638,114 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
         )
+
+
+class LiteLLMChatModel(AbstractChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        max_retry=4,
+        min_retry_wait_time=60,
+        api_key_env_var=None,
+        client_class=OpenAI,
+        client_args=None,
+        pricing_func=None,
+        log_probs=False,
+    ):
+        assert max_retry > 0, "max_retry should be greater than 0"
+
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_retry = max_retry
+        self.min_retry_wait_time = min_retry_wait_time
+        self.log_probs = log_probs
+
+        # Get pricing information
+        if pricing_func:
+            pricings = pricing_func()
+            try:
+                self.input_cost = float(pricings[model_name]["prompt"])
+                self.output_cost = float(pricings[model_name]["completion"])
+            except KeyError:
+                logging.warning(
+                    f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
+                )
+                self.input_cost = 0.0
+                self.output_cost = 0.0
+        else:
+            self.input_cost = 0.0
+            self.output_cost = 0.0
+
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
+        # Lazy import keeps litellm an optional dependency.
+        from litellm import completion as litellm_completion
+
+        # Initialize retry tracking attributes
+        self.retries = 0
+        self.success = False
+        self.error_types = []
+
+        completion = None
+        temperature = temperature if temperature is not None else self.temperature
+        for itr in range(self.max_retry):
+            self.retries += 1
+            try:
+                completion = litellm_completion(
+                    model=self.model_name,
+                    messages=messages,
+                    temperature=temperature,
+                    max_tokens=self.max_tokens,
+                    n=n_samples,
+                )
+
+                if completion.usage is None:
+                    raise OpenRouterError(
+                        "The completion object does not contain usage information. This is likely a bug in the LiteLLM backend."
+                    )
+
+                self.success = True
+                break
+            except openai.OpenAIError as e:
+                error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry)
+                self.error_types.append(error_type)
+
+        if not completion:
+            raise RetryError(
+                f"Failed to get a response from the API after {self.max_retry} retries\n"
+                f"Last error: {self.error_types[-1]}"
+            )
+
+        input_tokens = completion.usage.prompt_tokens
+        output_tokens = completion.usage.completion_tokens
+        cost = input_tokens * self.input_cost + output_tokens * self.output_cost
+
+        if hasattr(tracking.TRACKER, "instance") and isinstance(
+            tracking.TRACKER.instance, tracking.LLMTracker
+        ):
+            tracking.TRACKER.instance(input_tokens, output_tokens, cost)
+
+        if n_samples == 1:
+            res_text = completion.choices[0].message.content
+            if res_text is not None:
+                res_text = res_text.removesuffix("<|end|>").strip()
+            else:
+                res_text = ""
+            res = AIMessage(res_text)
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].logprobs
+            return res
+        else:
+            return [
+                AIMessage(c.message.content.removesuffix("<|end|>").strip())
+                for c in completion.choices
+            ]
+
+    def get_stats(self):
+        return {
+            "n_retry_llm": self.retries,
+            # "busted_retry_llm": int(not self.success),  # not logged if it occurs anyway
+        }
diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py
index 46860b5f..eab6ec7c 100644
--- a/src/agentlab/llm/llm_configs.py
+++ b/src/agentlab/llm/llm_configs.py
@@ -1,13 +1,13 @@
-from openai import NOT_GIVEN
-
 from agentlab.llm.chat_api import (
     AnthropicModelArgs,
     AzureModelArgs,
     BedrockModelArgs,
+    LiteLLMModelArgs,
     OpenAIModelArgs,
     OpenRouterModelArgs,
     SelfHostedModelArgs,
 )
+from openai import NOT_GIVEN
 
 default_oss_llms_args = {
     "n_retry_server": 4,
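For reviewers, here is a minimal usage sketch of the new backend. It is not part of the diff and makes a few assumptions: `litellm` is installed, the relevant provider key (e.g. `OPENAI_API_KEY`) is exported, the model string is illustrative, and `make_user_message` exists in `chat_api.py` alongside the `make_system_message` helper shown above.

```python
# Minimal sketch, assuming `litellm` is installed and the provider API key
# (e.g. OPENAI_API_KEY) is set in the environment. The model name is illustrative.
from agentlab.llm.chat_api import (
    LiteLLMModelArgs,
    make_system_message,
    make_user_message,  # assumed helper, analogous to make_system_message
)

model_args = LiteLLMModelArgs(
    model_name="gpt-4o-mini",  # any model string LiteLLM can route
    max_new_tokens=256,
    temperature=0.2,
)
llm = model_args.make_model()  # -> LiteLLMChatModel

messages = [
    make_system_message("You are a helpful assistant."),
    make_user_message("Summarize this change in one sentence."),
]
response = llm(messages)  # AIMessage, a dict-like object (see res["log_probs"] above)
print(response["content"])
```

Because `make_model()` only forwards `model_name`, `temperature`, `max_tokens`, and `log_probs`, the remaining `__init__` parameters (`api_key`, `client_class`, `pricing_func`, ...) appear to be kept for signature compatibility with the other chat models and are unused by the LiteLLM path.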