|
1 | 1 | import inspect |
2 | 2 | import logging |
3 | 3 | import types |
4 | | -from typing import Any, Dict, Optional, Union, Callable |
| 4 | +from typing import Any, Callable, Dict, Optional, Union, TypeVar, Protocol |
5 | 5 |
|
| 6 | +from humanloop.agents.client import AgentsClient |
6 | 7 | from humanloop.context import ( |
7 | 8 | get_decorator_context, |
8 | 9 | get_evaluation_context, |
9 | 10 | get_trace_id, |
10 | 11 | ) |
| 12 | +from humanloop.datasets.client import DatasetsClient |
11 | 13 | from humanloop.error import HumanloopRuntimeError |
12 | | -from humanloop.sync.sync_client import SyncClient |
13 | | -from humanloop.prompts.client import PromptsClient |
| 14 | +from humanloop.evaluators.client import EvaluatorsClient |
14 | 15 | from humanloop.flows.client import FlowsClient |
15 | | -from humanloop.datasets.client import DatasetsClient |
16 | | -from humanloop.agents.client import AgentsClient |
| 16 | +from humanloop.prompts.client import PromptsClient |
| 17 | +from humanloop.sync.sync_client import SyncClient |
17 | 18 | from humanloop.tools.client import ToolsClient |
18 | | -from humanloop.evaluators.client import EvaluatorsClient |
19 | 19 | from humanloop.types import FileType |
| 20 | +from humanloop.types.agent_call_response import AgentCallResponse |
20 | 21 | from humanloop.types.create_evaluator_log_response import CreateEvaluatorLogResponse |
21 | 22 | from humanloop.types.create_flow_log_response import CreateFlowLogResponse |
22 | 23 | from humanloop.types.create_prompt_log_response import CreatePromptLogResponse |
23 | 24 | from humanloop.types.create_tool_log_response import CreateToolLogResponse |
24 | 25 | from humanloop.types.prompt_call_response import PromptCallResponse |
25 | | -from humanloop.types.agent_call_response import AgentCallResponse |
26 | 26 |
|
27 | 27 | logger = logging.getLogger("humanloop.sdk") |
28 | 28 |
|
| 29 | + |
29 | 30 | LogResponseType = Union[ |
30 | 31 | CreatePromptLogResponse, |
31 | 32 | CreateToolLogResponse, |
|
39 | 40 | ] |
40 | 41 |
|
41 | 42 |
|
| 43 | +T = TypeVar("T", bound=Union[PromptsClient, AgentsClient, ToolsClient, FlowsClient, DatasetsClient, EvaluatorsClient]) |
| 44 | + |
| 45 | + |
42 | 46 | def _get_file_type_from_client( |
43 | 47 | client: Union[PromptsClient, AgentsClient, ToolsClient, FlowsClient, DatasetsClient, EvaluatorsClient], |
44 | 48 | ) -> FileType: |
@@ -184,33 +188,39 @@ def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files |
184 | 188 |
|
185 | 189 |
|
186 | 190 | def overload_client( |
187 | | - client: Any, |
| 191 | + client: T, |
188 | 192 | sync_client: Optional[SyncClient] = None, |
189 | 193 | use_local_files: bool = False, |
190 | | -) -> Any: |
| 194 | +) -> T: |
191 | 195 | """Overloads client methods to add tracing, local file handling, and evaluation context.""" |
192 | 196 | # Store original log method as _log for all clients. Used in flow decorator |
193 | 197 | if hasattr(client, "log") and not hasattr(client, "_log"): |
194 | | - client._log = client.log # type: ignore[attr-defined] |
| 198 | + # Store original method - using getattr/setattr to avoid type errors |
| 199 | + original_log = getattr(client, "log") |
| 200 | + setattr(client, "_log", original_log) |
195 | 201 |
|
196 | 202 | # Create a closure to capture sync_client and use_local_files |
197 | 203 | def log_wrapper(self: Any, **kwargs) -> LogResponseType: |
198 | 204 | return _overload_log(self, sync_client, use_local_files, **kwargs) |
199 | 205 |
|
200 | | - client.log = types.MethodType(log_wrapper, client) |
| 206 | + # Replace the log method |
| 207 | + setattr(client, "log", types.MethodType(log_wrapper, client)) |
201 | 208 |
|
202 | 209 | # Overload call method for Prompt and Agent clients |
203 | 210 | if _get_file_type_from_client(client) in ["prompt", "agent"]: |
204 | 211 | if sync_client is None and use_local_files: |
205 | 212 | logger.error("sync_client is None but client has call method and use_local_files=%s", use_local_files) |
206 | 213 | raise HumanloopRuntimeError("sync_client is required for clients that support call operations") |
207 | 214 | if hasattr(client, "call") and not hasattr(client, "_call"): |
208 | | - client._call = client.call # type: ignore[attr-defined] |
| 215 | + # Store original method - using getattr/setattr to avoid type errors |
| 216 | + original_call = getattr(client, "call") |
| 217 | + setattr(client, "_call", original_call) |
209 | 218 |
|
210 | 219 | # Create a closure to capture sync_client and use_local_files |
211 | 220 | def call_wrapper(self: Any, **kwargs) -> CallResponseType: |
212 | 221 | return _overload_call(self, sync_client, use_local_files, **kwargs) |
213 | 222 |
|
214 | | - client.call = types.MethodType(call_wrapper, client) |
| 223 | + # Replace the call method |
| 224 | + setattr(client, "call", types.MethodType(call_wrapper, client)) |
215 | 225 |
|
216 | 226 | return client |
0 commit comments