From d3dfa85d0ad296ae2d82cbfae4853b6faad99106 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 7 Jan 2026 23:54:19 +0000 Subject: [PATCH 1/4] add litellm chat model for genericagent --- src/agentlab/llm/chat_api.py | 123 +++++++++++++++++++++++++++++++- src/agentlab/llm/llm_configs.py | 63 ++++++++++++++++ 2 files changed, 185 insertions(+), 1 deletion(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index d69147d7..56b2b820 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -89,6 +89,18 @@ def make_model(self): log_probs=self.log_probs, ) +@dataclass +class LiteLLMModelArgs(BaseModelArgs): + + def make_model(self): + return LiteLLMChatModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + log_probs=self.log_probs, + reasoning_effort=self.reasoning_effort, + ) + @dataclass class OpenAIModelArgs(BaseModelArgs): @@ -393,7 +405,6 @@ def __init__( log_probs=log_probs, ) - class AzureChatModel(ChatModel): def __init__( self, @@ -627,3 +638,113 @@ def make_model(self): temperature=self.temperature, max_tokens=self.max_new_tokens, ) + +class LiteLLMChatModel(AbstractChatModel): + def __init__( + self, + model_name, + api_key=None, + temperature=0.5, + max_tokens=100, + max_retry=4, + min_retry_wait_time=60, + api_key_env_var=None, + client_class=OpenAI, + client_args=None, + pricing_func=None, + log_probs=False, + reasoning_effort=None, + ): + assert max_retry > 0, "max_retry should be greater than 0" + + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + self.max_retry = max_retry + self.min_retry_wait_time = min_retry_wait_time + self.log_probs = log_probs + self.reasoning_effort = reasoning_effort + + # Get pricing information + if pricing_func: + pricings = pricing_func() + try: + self.input_cost = float(pricings[model_name]["prompt"]) + self.output_cost = float(pricings[model_name]["completion"]) + except KeyError: + logging.warning( + f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community." + ) + self.input_cost = 0.0 + self.output_cost = 0.0 + else: + self.input_cost = 0.0 + self.output_cost = 0.0 + + + def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict: + from litellm import completion as litellm_completion + # Initialize retry tracking attributes + self.retries = 0 + self.success = False + self.error_types = [] + + completion = None + e = None + for itr in range(self.max_retry): + self.retries += 1 + temperature = temperature if temperature is not None else self.temperature + try: + completion = litellm_completion( + model=self.model_name, + messages=messages, + # n=n_samples, + # temperature=temperature, + # max_completion_tokens=self.max_tokens, + reasoning_effort=self.reasoning_effort, + ) + + if completion.usage is None: + raise OpenRouterError( + "The completion object does not contain usage information. This is likely a bug in the OpenRouter API." 
+ ) + + self.success = True + break + except openai.OpenAIError as e: + error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry) + self.error_types.append(error_type) + + if not completion: + raise RetryError( + f"Failed to get a response from the API after {self.max_retry} retries\n" + f"Last error: {error_type}" + ) + + input_tokens = completion.usage.prompt_tokens + output_tokens = completion.usage.completion_tokens + cost = input_tokens * self.input_cost + output_tokens * self.output_cost + + if hasattr(tracking.TRACKER, "instance") and isinstance( + tracking.TRACKER.instance, tracking.LLMTracker + ): + tracking.TRACKER.instance(input_tokens, output_tokens, cost) + + if n_samples == 1: + res_text = completion.choices[0].message.content + if res_text is not None: + res_text = res_text.removesuffix("<|end|>").strip() + else: + res_text = "" + res = AIMessage(res_text) + if self.log_probs: + res["log_probs"] = completion.choices[0].log_probs + return res + else: + return [AIMessage(c.message.content.removesuffix("<|end|>").strip()) for c in completion.choices] + + def get_stats(self): + return { + "n_retry_llm": self.retries, + # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways + } \ No newline at end of file diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 46860b5f..2ae22606 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -7,6 +7,7 @@ OpenAIModelArgs, OpenRouterModelArgs, SelfHostedModelArgs, + LiteLLMModelArgs, ) default_oss_llms_args = { @@ -200,6 +201,68 @@ temperature=1, # temperature param not supported by gpt-5 vision_support=True, ), + "azure/gpt-5-high-2025-08-07": AzureModelArgs( + model_name="gpt-5", + max_total_tokens=400_000, + max_input_tokens=256_000, + max_new_tokens=128_000, + temperature=1, # temperature param not supported by gpt-5 + vision_support=True, + reasoning_effort="high", + ), + "azure/gpt-5-mini-high-2025-08-07": AzureModelArgs( + model_name="gpt-5-mini", + max_total_tokens=400_000, + max_input_tokens=256_000, + max_new_tokens=128_000, + temperature=1, # temperature param not supported by gpt-5 + vision_support=True, + reasoning_effort="high", + ), + "azure/gpt-5-nano-high-2025-08-07": AzureModelArgs( + model_name="gpt-5-nano", + max_total_tokens=400_000, + max_input_tokens=256_000, + max_new_tokens=128_000, + temperature=1, # temperature param not supported by gpt-5 + vision_support=True, + reasoning_effort="high", + ), + "azure/gpt-oss-120b": AzureModelArgs( + model_name="gpt-oss-120b", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=100_000, + temperature=1, + vision_support=False, + reasoning_effort="low", + ), + "azure/o3-high-2025-04-16": AzureModelArgs( + model_name="o3", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=100_000, + temperature=1, + vision_support=False, + reasoning_effort="high", + ), + "azure/o3-mini-2025-01-31": AzureModelArgs( + model_name="o3-mini", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=100_000, + temperature=1, + vision_support=False, + ), + "azure/o3-mini-high-2025-01-31": AzureModelArgs( + model_name="o3-mini", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=100_000, + temperature=1, + vision_support=False, + reasoning_effort="high", + ), # ---------------- Anthropic ----------------# "anthropic/claude-3-7-sonnet-20250219": AnthropicModelArgs( model_name="claude-3-7-sonnet-20250219", From 
05795d0c33ff7d9fff279ab12818845e81c8c36b Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 8 Jan 2026 00:08:31 +0000 Subject: [PATCH 2/4] remove added models to llm config --- src/agentlab/llm/chat_api.py | 2 -- src/agentlab/llm/llm_configs.py | 62 --------------------------------- 2 files changed, 64 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 56b2b820..301b3995 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -98,7 +98,6 @@ def make_model(self): temperature=self.temperature, max_tokens=self.max_new_tokens, log_probs=self.log_probs, - reasoning_effort=self.reasoning_effort, ) @@ -653,7 +652,6 @@ def __init__( client_args=None, pricing_func=None, log_probs=False, - reasoning_effort=None, ): assert max_retry > 0, "max_retry should be greater than 0" diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 2ae22606..7e571646 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -201,68 +201,6 @@ temperature=1, # temperature param not supported by gpt-5 vision_support=True, ), - "azure/gpt-5-high-2025-08-07": AzureModelArgs( - model_name="gpt-5", - max_total_tokens=400_000, - max_input_tokens=256_000, - max_new_tokens=128_000, - temperature=1, # temperature param not supported by gpt-5 - vision_support=True, - reasoning_effort="high", - ), - "azure/gpt-5-mini-high-2025-08-07": AzureModelArgs( - model_name="gpt-5-mini", - max_total_tokens=400_000, - max_input_tokens=256_000, - max_new_tokens=128_000, - temperature=1, # temperature param not supported by gpt-5 - vision_support=True, - reasoning_effort="high", - ), - "azure/gpt-5-nano-high-2025-08-07": AzureModelArgs( - model_name="gpt-5-nano", - max_total_tokens=400_000, - max_input_tokens=256_000, - max_new_tokens=128_000, - temperature=1, # temperature param not supported by gpt-5 - vision_support=True, - reasoning_effort="high", - ), - "azure/gpt-oss-120b": AzureModelArgs( - model_name="gpt-oss-120b", - max_total_tokens=200_000, - max_input_tokens=200_000, - max_new_tokens=100_000, - temperature=1, - vision_support=False, - reasoning_effort="low", - ), - "azure/o3-high-2025-04-16": AzureModelArgs( - model_name="o3", - max_total_tokens=200_000, - max_input_tokens=200_000, - max_new_tokens=100_000, - temperature=1, - vision_support=False, - reasoning_effort="high", - ), - "azure/o3-mini-2025-01-31": AzureModelArgs( - model_name="o3-mini", - max_total_tokens=200_000, - max_input_tokens=200_000, - max_new_tokens=100_000, - temperature=1, - vision_support=False, - ), - "azure/o3-mini-high-2025-01-31": AzureModelArgs( - model_name="o3-mini", - max_total_tokens=200_000, - max_input_tokens=200_000, - max_new_tokens=100_000, - temperature=1, - vision_support=False, - reasoning_effort="high", - ), # ---------------- Anthropic ----------------# "anthropic/claude-3-7-sonnet-20250219": AnthropicModelArgs( model_name="claude-3-7-sonnet-20250219", From f7a0f4943004a3603465a68dc13ec5d0aef736cc Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 8 Jan 2026 00:09:32 +0000 Subject: [PATCH 3/4] format --- src/agentlab/llm/chat_api.py | 77 +++++++++------------------------ src/agentlab/llm/llm_configs.py | 5 +-- 2 files changed, 23 insertions(+), 59 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 301b3995..fdb277a4 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -6,13 +6,12 @@ from functools import partial from typing import Optional 
+import agentlab.llm.tracking as tracking import anthropic import openai -from openai import NOT_GIVEN, OpenAI - -import agentlab.llm.tracking as tracking from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs from agentlab.llm.llm_utils import AIMessage, Discussion +from openai import NOT_GIVEN, OpenAI def make_system_message(content: str) -> dict: @@ -89,6 +88,7 @@ def make_model(self): log_probs=self.log_probs, ) + @dataclass class LiteLLMModelArgs(BaseModelArgs): @@ -119,9 +119,7 @@ def make_model(self): class AzureModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an Azure model.""" - deployment_name: str = ( - None # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used. - ) + deployment_name: str = None # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used. def make_model(self): return AzureChatModel( @@ -219,9 +217,7 @@ class RetryError(Exception): def handle_error(error, itr, min_retry_wait_time, max_retry): if not isinstance(error, openai.OpenAIError): raise error - logging.warning( - f"Failed to get a response from the API: \n{error}\n" f"Retrying... ({itr+1}/{max_retry})" - ) + logging.warning(f"Failed to get a response from the API: \n{error}\n" f"Retrying... ({itr+1}/{max_retry})") wait_time = _extract_wait_time( error.args[0], min_retry_wait_time=min_retry_wait_time, @@ -320,18 +316,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float self.error_types.append(error_type) if not completion: - raise RetryError( - f"Failed to get a response from the API after {self.max_retry} retries\n" - f"Last error: {error_type}" - ) + raise RetryError(f"Failed to get a response from the API after {self.max_retry} retries\n" f"Last error: {error_type}") input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost - if hasattr(tracking.TRACKER, "instance") and isinstance( - tracking.TRACKER.instance, tracking.LLMTracker - ): + if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker): tracking.TRACKER.instance(input_tokens, output_tokens, cost) if n_samples == 1: @@ -404,6 +395,7 @@ def __init__( log_probs=log_probs, ) + class AzureChatModel(ChatModel): def __init__( self, @@ -417,18 +409,12 @@ def __init__( log_probs=False, ): api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") - assert ( - api_key - ), "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel" + assert api_key, "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel" endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") - assert ( - endpoint - ), "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel" + assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel" if deployment_name is not None: - logging.info( - f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: {model_name}." - ) + logging.info(f"Deployment name is deprecated for Azure OpenAI and won't be used. 
Using model name: {model_name}.") client_args = { "base_url": endpoint, @@ -560,12 +546,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float output_tokens = getattr(usage, "output_tokens", 0) cache_read_tokens = getattr(usage, "cache_input_tokens", 0) cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0) - cache_read_cost = ( - self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"] - ) - cache_write_cost = ( - self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"] - ) + cache_read_cost = self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"] + cache_write_cost = self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"] cost = ( new_input_tokens * self.input_cost + output_tokens * self.output_cost @@ -574,9 +556,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float ) # Track usage if available - if hasattr(tracking.TRACKER, "instance") and isinstance( - tracking.TRACKER.instance, tracking.LLMTracker - ): + if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker): tracking.TRACKER.instance(new_input_tokens, output_tokens, cost) return AIMessage(response.content[0].text) @@ -613,14 +593,8 @@ def __init__( self.max_tokens = max_tokens self.max_retry = max_retry - if ( - not os.getenv("AWS_REGION") - or not os.getenv("AWS_ACCESS_KEY") - or not os.getenv("AWS_SECRET_KEY") - ): - raise ValueError( - "AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel" - ) + if not os.getenv("AWS_REGION") or not os.getenv("AWS_ACCESS_KEY") or not os.getenv("AWS_SECRET_KEY"): + raise ValueError("AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel") self.client = anthropic.AnthropicBedrock( aws_region=os.getenv("AWS_REGION"), @@ -638,6 +612,7 @@ def make_model(self): max_tokens=self.max_new_tokens, ) + class LiteLLMChatModel(AbstractChatModel): def __init__( self, @@ -661,7 +636,6 @@ def __init__( self.max_retry = max_retry self.min_retry_wait_time = min_retry_wait_time self.log_probs = log_probs - self.reasoning_effort = reasoning_effort # Get pricing information if pricing_func: @@ -679,9 +653,9 @@ def __init__( self.input_cost = 0.0 self.output_cost = 0.0 - def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict: from litellm import completion as litellm_completion + # Initialize retry tracking attributes self.retries = 0 self.success = False @@ -696,10 +670,6 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float completion = litellm_completion( model=self.model_name, messages=messages, - # n=n_samples, - # temperature=temperature, - # max_completion_tokens=self.max_tokens, - reasoning_effort=self.reasoning_effort, ) if completion.usage is None: @@ -714,18 +684,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float self.error_types.append(error_type) if not completion: - raise RetryError( - f"Failed to get a response from the API after {self.max_retry} retries\n" - f"Last error: {error_type}" - ) + raise RetryError(f"Failed to get a response from the API after {self.max_retry} retries\n" f"Last error: {error_type}") input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost - if 
hasattr(tracking.TRACKER, "instance") and isinstance( - tracking.TRACKER.instance, tracking.LLMTracker - ): + if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker): tracking.TRACKER.instance(input_tokens, output_tokens, cost) if n_samples == 1: @@ -745,4 +710,4 @@ def get_stats(self): return { "n_retry_llm": self.retries, # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways - } \ No newline at end of file + } diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 7e571646..eab6ec7c 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -1,14 +1,13 @@ -from openai import NOT_GIVEN - from agentlab.llm.chat_api import ( AnthropicModelArgs, AzureModelArgs, BedrockModelArgs, + LiteLLMModelArgs, OpenAIModelArgs, OpenRouterModelArgs, SelfHostedModelArgs, - LiteLLMModelArgs, ) +from openai import NOT_GIVEN default_oss_llms_args = { "n_retry_server": 4, From 97f471c3f0a085277bb0ca11bb7f119cd83a997f Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 8 Jan 2026 00:15:51 +0000 Subject: [PATCH 4/4] format --- src/agentlab/llm/chat_api.py | 65 +++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index fdb277a4..4c642220 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -119,7 +119,9 @@ def make_model(self): class AzureModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an Azure model.""" - deployment_name: str = None # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used. + deployment_name: str = ( + None # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used. + ) def make_model(self): return AzureChatModel( @@ -217,7 +219,9 @@ class RetryError(Exception): def handle_error(error, itr, min_retry_wait_time, max_retry): if not isinstance(error, openai.OpenAIError): raise error - logging.warning(f"Failed to get a response from the API: \n{error}\n" f"Retrying... ({itr+1}/{max_retry})") + logging.warning( + f"Failed to get a response from the API: \n{error}\n" f"Retrying... 
({itr+1}/{max_retry})" + ) wait_time = _extract_wait_time( error.args[0], min_retry_wait_time=min_retry_wait_time, @@ -316,13 +320,18 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float self.error_types.append(error_type) if not completion: - raise RetryError(f"Failed to get a response from the API after {self.max_retry} retries\n" f"Last error: {error_type}") + raise RetryError( + f"Failed to get a response from the API after {self.max_retry} retries\n" + f"Last error: {error_type}" + ) input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost - if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker): + if hasattr(tracking.TRACKER, "instance") and isinstance( + tracking.TRACKER.instance, tracking.LLMTracker + ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) if n_samples == 1: @@ -409,12 +418,18 @@ def __init__( log_probs=False, ): api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") - assert api_key, "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel" + assert ( + api_key + ), "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel" endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") - assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel" + assert ( + endpoint + ), "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel" if deployment_name is not None: - logging.info(f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: {model_name}.") + logging.info( + f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: {model_name}." 
+ ) client_args = { "base_url": endpoint, @@ -546,8 +561,12 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float output_tokens = getattr(usage, "output_tokens", 0) cache_read_tokens = getattr(usage, "cache_input_tokens", 0) cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0) - cache_read_cost = self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"] - cache_write_cost = self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"] + cache_read_cost = ( + self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"] + ) + cache_write_cost = ( + self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"] + ) cost = ( new_input_tokens * self.input_cost + output_tokens * self.output_cost @@ -556,7 +575,9 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float ) # Track usage if available - if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker): + if hasattr(tracking.TRACKER, "instance") and isinstance( + tracking.TRACKER.instance, tracking.LLMTracker + ): tracking.TRACKER.instance(new_input_tokens, output_tokens, cost) return AIMessage(response.content[0].text) @@ -593,8 +614,14 @@ def __init__( self.max_tokens = max_tokens self.max_retry = max_retry - if not os.getenv("AWS_REGION") or not os.getenv("AWS_ACCESS_KEY") or not os.getenv("AWS_SECRET_KEY"): - raise ValueError("AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel") + if ( + not os.getenv("AWS_REGION") + or not os.getenv("AWS_ACCESS_KEY") + or not os.getenv("AWS_SECRET_KEY") + ): + raise ValueError( + "AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel" + ) self.client = anthropic.AnthropicBedrock( aws_region=os.getenv("AWS_REGION"), @@ -684,13 +711,18 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float self.error_types.append(error_type) if not completion: - raise RetryError(f"Failed to get a response from the API after {self.max_retry} retries\n" f"Last error: {error_type}") + raise RetryError( + f"Failed to get a response from the API after {self.max_retry} retries\n" + f"Last error: {error_type}" + ) input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost - if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker): + if hasattr(tracking.TRACKER, "instance") and isinstance( + tracking.TRACKER.instance, tracking.LLMTracker + ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) if n_samples == 1: @@ -704,7 +736,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float res["log_probs"] = completion.choices[0].log_probs return res else: - return [AIMessage(c.message.content.removesuffix("<|end|>").strip()) for c in completion.choices] + return [ + AIMessage(c.message.content.removesuffix("<|end|>").strip()) + for c in completion.choices + ] def get_stats(self): return {
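
Usage sketch (not part of the patches): how the LiteLLMModelArgs / LiteLLMChatModel pair added in PATCH 1 might be driven end to end. Only LiteLLMModelArgs, make_model(), the message-dict call signature, and get_stats() come from the diffs; the model string, token budgets, and prompt content are illustrative, and LiteLLM is assumed to pick up provider credentials from the environment (e.g. OPENAI_API_KEY).

from agentlab.llm.chat_api import LiteLLMModelArgs

# Any "provider/model" string LiteLLM can route should work here;
# "openai/gpt-4o-mini" is illustrative, not taken from the patch.
model_args = LiteLLMModelArgs(
    model_name="openai/gpt-4o-mini",
    max_total_tokens=128_000,  # field names mirror the AzureModelArgs entries above
    max_input_tokens=110_000,
    max_new_tokens=2_000,
    temperature=0.5,
)
llm = model_args.make_model()  # returns a LiteLLMChatModel

messages = [
    {"role": "system", "content": "You are a web navigation agent."},
    {"role": "user", "content": "Summarize the current page in one sentence."},
]
answer = llm(messages)  # AIMessage; dict-like, as the log_probs assignment in the diff suggests
print(answer["content"])
print(llm.get_stats())  # {"n_retry_llm": <number of attempts used>}

Because no pricing_func is passed here, input_cost and output_cost fall back to 0.0 and the tracker records zero cost; token counts are still tracked from completion.usage.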
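
Registration sketch: PATCH 2 deliberately backs the Azure entries out of llm_configs.py, but wiring a LiteLLM model into the registry would mirror those removed entries. CHAT_MODEL_ARGS_DICT is an assumed name for the enclosing dict, which the hunk headers never show; the key and field values are illustrative.

from agentlab.llm.chat_api import LiteLLMModelArgs
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT  # assumed identifier; not visible in the diffs

# Hypothetical entry, by analogy with the AzureModelArgs entries removed in PATCH 2.
CHAT_MODEL_ARGS_DICT["litellm/gpt-4o-mini"] = LiteLLMModelArgs(
    model_name="openai/gpt-4o-mini",
    max_total_tokens=128_000,
    max_input_tokens=110_000,
    max_new_tokens=2_000,
    temperature=0.5,
    vision_support=True,  # assumed field, following the Azure entries shown above
)

With such an entry in place, experiments could select the model by its registry key in the same way the existing "azure/..." and "anthropic/..." entries are selected.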