|
1 | 1 | import inspect |
2 | 2 | import logging |
3 | 3 | import types |
4 | | -from typing import Any, Dict, Optional, Union, Callable |
| 4 | +from typing import Any, Callable, Dict, Optional, TypeVar, Union |
5 | 5 |
|
| 6 | +from humanloop.agents.client import AgentsClient |
| 7 | +from humanloop.client import ExtendedPromptsClient |
6 | 8 | from humanloop.context import ( |
7 | 9 | get_decorator_context, |
8 | 10 | get_evaluation_context, |
9 | 11 | get_trace_id, |
10 | 12 | ) |
| 13 | +from humanloop.datasets.client import DatasetsClient |
11 | 14 | from humanloop.error import HumanloopRuntimeError |
12 | | -from humanloop.sync.sync_client import SyncClient |
13 | | -from humanloop.prompts.client import PromptsClient |
| 15 | +from humanloop.evaluators.client import EvaluatorsClient |
14 | 16 | from humanloop.flows.client import FlowsClient |
15 | | -from humanloop.datasets.client import DatasetsClient |
16 | | -from humanloop.agents.client import AgentsClient |
| 17 | +from humanloop.prompts.client import PromptsClient |
| 18 | +from humanloop.sync.sync_client import SyncClient |
17 | 19 | from humanloop.tools.client import ToolsClient |
18 | | -from humanloop.evaluators.client import EvaluatorsClient |
19 | 20 | from humanloop.types import FileType |
| 21 | +from humanloop.types.agent_call_response import AgentCallResponse |
20 | 22 | from humanloop.types.create_evaluator_log_response import CreateEvaluatorLogResponse |
21 | 23 | from humanloop.types.create_flow_log_response import CreateFlowLogResponse |
22 | 24 | from humanloop.types.create_prompt_log_response import CreatePromptLogResponse |
23 | 25 | from humanloop.types.create_tool_log_response import CreateToolLogResponse |
24 | 26 | from humanloop.types.prompt_call_response import PromptCallResponse |
25 | | -from humanloop.types.agent_call_response import AgentCallResponse |
26 | 27 |
|
27 | 28 | logger = logging.getLogger("humanloop.sdk") |
28 | 29 |
|
@@ -183,11 +184,22 @@ def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files |
183 | 184 | raise HumanloopRuntimeError from e |
184 | 185 |
|
185 | 186 |
|
| 187 | +ClientTemplateType = TypeVar( |
| 188 | + "ClientTemplateType", |
| 189 | + bound=Union[ |
| 190 | + FlowsClient, |
| 191 | + ExtendedPromptsClient, |
| 192 | + AgentsClient, |
| 193 | + ToolsClient, |
| 194 | + ], |
| 195 | +) |
| 196 | + |
| 197 | + |
186 | 198 | def overload_client( |
187 | | - client: Any, |
| 199 | + client: ClientTemplateType, |
188 | 200 | sync_client: Optional[SyncClient] = None, |
189 | 201 | use_local_files: bool = False, |
190 | | -) -> Any: |
| 202 | +) -> ClientTemplateType: |
191 | 203 | """Overloads client methods to add tracing, local file handling, and evaluation context.""" |
192 | 204 | # Store original log method as _log for all clients. Used in flow decorator |
193 | 205 | if hasattr(client, "log") and not hasattr(client, "_log"): |
|
0 commit comments