125 changes: 122 additions & 3 deletions src/agentlab/llm/chat_api.py
@@ -6,13 +6,12 @@
from functools import partial
from typing import Optional

import agentlab.llm.tracking as tracking
import anthropic
import openai
from openai import NOT_GIVEN, OpenAI

import agentlab.llm.tracking as tracking
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
from agentlab.llm.llm_utils import AIMessage, Discussion
from openai import NOT_GIVEN, OpenAI


def make_system_message(content: str) -> dict:
@@ -90,6 +89,18 @@ def make_model(self):
        )


@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable object for instantiating a LiteLLMChatModel, which routes
    requests through the LiteLLM SDK to any provider it supports."""

    def make_model(self):
        return LiteLLMChatModel(
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
            log_probs=self.log_probs,
        )


@dataclass
class OpenAIModelArgs(BaseModelArgs):
"""Serializable object for instantiating a generic chat model with an OpenAI
Expand Down Expand Up @@ -627,3 +638,111 @@ def make_model(self):
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
        )


class LiteLLMChatModel(AbstractChatModel):
    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        max_retry=4,
        min_retry_wait_time=60,
        api_key_env_var=None,
        client_class=OpenAI,
        client_args=None,
        pricing_func=None,
        log_probs=False,
    ):
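        # Note: api_key, api_key_env_var, client_class, and client_args are
        # accepted for signature compatibility with the other chat models in
        # this module but are not used below; LiteLLM resolves credentials on
        # its own from provider-specific environment variables (e.g.
        # OPENAI_API_KEY, ANTHROPIC_API_KEY).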
        assert max_retry > 0, "max_retry should be greater than 0"

        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry
        self.min_retry_wait_time = min_retry_wait_time
        self.log_probs = log_probs

        # Get pricing information
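        # `pricing_func` is expected to return a mapping of the form
        # {model_name: {"prompt": usd_per_input_token, "completion": usd_per_output_token}},
        # which is how the lookups below index it.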
        if pricing_func:
            pricings = pricing_func()
            try:
                self.input_cost = float(pricings[model_name]["prompt"])
                self.output_cost = float(pricings[model_name]["completion"])
            except KeyError:
                logging.warning(
                    f"Model {model_name} not found in the pricing information, "
                    "prices are set to 0. Maybe try upgrading langchain_community."
                )
                self.input_cost = 0.0
                self.output_cost = 0.0
        else:
            self.input_cost = 0.0
            self.output_cost = 0.0

    def __call__(
        self, messages: list[dict], n_samples: int = 1, temperature: Optional[float] = None
    ) -> dict:
        # Imported lazily so that litellm stays an optional dependency.
        from litellm import completion as litellm_completion

        # Initialize retry tracking attributes
        self.retries = 0
        self.success = False
        self.error_types = []

        temperature = temperature if temperature is not None else self.temperature

        completion = None
        for itr in range(self.max_retry):
            self.retries += 1
            try:
                completion = litellm_completion(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=self.max_tokens,
                    n=n_samples,
                    # Request token log probabilities only when configured.
                    logprobs=self.log_probs or None,
                )

                if completion.usage is None:
                    raise OpenRouterError(
                        "The completion object does not contain usage information. "
                        "This is likely a bug in the provider's API."
                    )

                self.success = True
                break
            # LiteLLM raises OpenAI-compatible exception types, so the shared
            # handle_error helper applies here as well.
            except openai.OpenAIError as e:
                error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry)
                self.error_types.append(error_type)

        if not completion:
            raise RetryError(
                f"Failed to get a response from the API after {self.max_retry} retries\n"
                f"Last error: {self.error_types[-1]}"
            )

        input_tokens = completion.usage.prompt_tokens
        output_tokens = completion.usage.completion_tokens
        cost = input_tokens * self.input_cost + output_tokens * self.output_cost

        # Report token counts and dollar cost to the global tracker, if one is
        # currently active (see agentlab.llm.tracking).
        if hasattr(tracking.TRACKER, "instance") and isinstance(
            tracking.TRACKER.instance, tracking.LLMTracker
        ):
            tracking.TRACKER.instance(input_tokens, output_tokens, cost)

        if n_samples == 1:
            res_text = completion.choices[0].message.content
            if res_text is not None:
                res_text = res_text.removesuffix("<|end|>").strip()
            else:
                res_text = ""
            res = AIMessage(res_text)
            if self.log_probs:
                # OpenAI-style completion objects expose this as `logprobs`.
                res["log_probs"] = completion.choices[0].logprobs
            return res
        else:
            return [
                AIMessage((c.message.content or "").removesuffix("<|end|>").strip())
                for c in completion.choices
            ]

    def get_stats(self):
        return {
            "n_retry_llm": self.retries,
            # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
        }
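
For reference, a minimal sketch of how the new classes might be exercised end
to end. It assumes that BaseModelArgs exposes model_name, temperature,
max_new_tokens, and log_probs fields, that litellm is installed, and that a
provider key such as OPENAI_API_KEY is set in the environment:

    from agentlab.llm.chat_api import LiteLLMModelArgs

    args = LiteLLMModelArgs(
        model_name="gpt-4o-mini",  # any LiteLLM-routable model id (illustrative)
        temperature=0.3,
        max_new_tokens=128,
    )
    model = args.make_model()  # returns a LiteLLMChatModel

    answer = model([{"role": "user", "content": "Say hello."}])
    print(answer)  # an AIMessage wrapping the completion text
    print(model.get_stats())  # {"n_retry_llm": 1} when the first attempt succeeds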
4 changes: 2 additions & 2 deletions src/agentlab/llm/llm_configs.py
@@ -1,13 +1,13 @@
from openai import NOT_GIVEN

from agentlab.llm.chat_api import (
    AnthropicModelArgs,
    AzureModelArgs,
    BedrockModelArgs,
    LiteLLMModelArgs,
    OpenAIModelArgs,
    OpenRouterModelArgs,
    SelfHostedModelArgs,
)
from openai import NOT_GIVEN

default_oss_llms_args = {
    "n_retry_server": 4,
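
Presumably the new import is then used to register LiteLLM-backed entries in
the module's model-args dictionary (CHAT_MODEL_ARGS_DICT in agentlab's
llm_configs, outside the lines shown in this diff). A hypothetical entry might
look like the following, where the token-limit field names follow the other
BaseModelArgs entries and are assumptions here:

    "litellm/gpt-4o-mini": LiteLLMModelArgs(
        model_name="gpt-4o-mini",
        max_total_tokens=128_000,
        max_input_tokens=126_000,
        max_new_tokens=2_000,
    ),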