Skip to content

Commit 6c65b10

Browse files
committed
Merge branch 'fix/specify-overload-client-return-type' of github.com:humanloop/humanloop-python into fix/specify-overload-client-return-type
2 parents 7577bec + d0bfa83 commit 6c65b10

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

src/humanloop/overload.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import logging
33
import types
4-
from typing import Any, Callable, Dict, Optional, Union, TypeVar, Protocol
4+
from typing import Any, Callable, Dict, Optional, TypeVar, Union
55

66
from humanloop.agents.client import AgentsClient
77
from humanloop.context import (
@@ -59,13 +59,13 @@ def _get_file_type_from_client(
5959
return "dataset"
6060
elif isinstance(client, EvaluatorsClient):
6161
return "evaluator"
62+
else:
63+
raise ValueError(f"Unsupported client type: {type(client)}")
6264

63-
raise ValueError(f"Unsupported client type: {type(client)}")
6465

65-
66-
def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, Any]:
66+
def _handle_tracing_context(kwargs: Dict[str, Any], client: T) -> Dict[str, Any]:
6767
"""Handle tracing context for both log and call methods."""
68-
trace_id = get_trace_id()
68+
trace_id = get_trace_id()
6969
if trace_id is not None:
7070
if "flow" in str(type(client).__name__).lower():
7171
context = get_decorator_context()
@@ -90,7 +90,7 @@ def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, An
9090

9191
def _handle_local_files(
9292
kwargs: Dict[str, Any],
93-
client: Any,
93+
client: T,
9494
sync_client: Optional[SyncClient],
9595
use_local_files: bool,
9696
) -> Dict[str, Any]:
@@ -140,7 +140,7 @@ def _handle_evaluation_context(kwargs: Dict[str, Any]) -> tuple[Dict[str, Any],
140140
return kwargs, None
141141

142142

143-
def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> LogResponseType:
143+
def _overload_log(self: T, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> LogResponseType:
144144
try:
145145
# Special handling for flows - prevent direct log usage
146146
if type(self) is FlowsClient and get_trace_id() is not None:
@@ -162,7 +162,7 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
162162
kwargs = _handle_local_files(kwargs, self, sync_client, use_local_files)
163163

164164
kwargs, eval_callback = _handle_evaluation_context(kwargs)
165-
response = self._log(**kwargs) # Use stored original method
165+
response = self._log(**kwargs) # type: ignore[union-attr] # Use stored original method
166166
if eval_callback is not None:
167167
eval_callback(response.id)
168168
return response
@@ -174,11 +174,11 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
174174
raise HumanloopRuntimeError from e
175175

176176

177-
def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> CallResponseType:
177+
def _overload_call(self: T, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> CallResponseType:
178178
try:
179179
kwargs = _handle_tracing_context(kwargs, self)
180180
kwargs = _handle_local_files(kwargs, self, sync_client, use_local_files)
181-
return self._call(**kwargs) # Use stored original method
181+
return self._call(**kwargs) # type: ignore[union-attr] # Use stored original method
182182
except HumanloopRuntimeError:
183183
# Re-raise HumanloopRuntimeError without wrapping to preserve the message
184184
raise
@@ -199,7 +199,7 @@ def overload_client(
199199
client._log = client.log # type: ignore
200200

201201
# Create a closure to capture sync_client and use_local_files
202-
def log_wrapper(self: Any, **kwargs) -> LogResponseType:
202+
def log_wrapper(self: T, **kwargs) -> LogResponseType:
203203
return _overload_log(self, sync_client, use_local_files, **kwargs)
204204

205205
# Replace the log method with type ignore
@@ -215,7 +215,7 @@ def log_wrapper(self: Any, **kwargs) -> LogResponseType:
215215
client._call = client.call # type: ignore
216216

217217
# Create a closure to capture sync_client and use_local_files
218-
def call_wrapper(self: Any, **kwargs) -> CallResponseType:
218+
def call_wrapper(self: T, **kwargs) -> CallResponseType:
219219
return _overload_call(self, sync_client, use_local_files, **kwargs)
220220

221221
# Replace the call method with type ignore

0 commit comments

Comments
 (0)