Skip to content

Commit b3060c8

Browse files
author
Andrei Bratu
committed
Support prompt.call inside utilities
1 parent 24a6fc3 commit b3060c8

File tree

21 files changed

+1028
-391
lines changed

21 files changed

+1028
-391
lines changed

poetry.lock

Lines changed: 35 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ parse = ">=1"
4949
pydantic = ">= 1.9.2"
5050
pydantic-core = "^2.18.2"
5151
typing_extensions = ">= 4.0.0"
52+
deepdiff = {extras = ["murmur"], version = "^8.2.0"}
53+
mmh3 = "^5.1.0"
5254

5355
[tool.poetry.dev-dependencies]
5456
mypy = "1.0.1"

src/humanloop/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from opentelemetry.trace import Tracer
1010

1111
from humanloop.core.client_wrapper import SyncClientWrapper
12+
from humanloop.eval_utils.run import prompt_call_evaluation_aware
1213
from humanloop.utilities.types import DecoratorPromptKernelRequestParams
1314

1415
from humanloop.eval_utils import log_with_evaluation_context, run_eval
@@ -120,6 +121,7 @@ def __init__(
120121

121122
# Overload the .log method of the clients to be aware of Evaluation Context
122123
self.prompts = log_with_evaluation_context(client=self.prompts)
124+
self.prompts = prompt_call_evaluation_aware(client=self.prompts)
123125
self.flows = log_with_evaluation_context(client=self.flows)
124126

125127
if opentelemetry_tracer_provider is not None:
@@ -135,6 +137,7 @@ def __init__(
135137
instrument_provider(provider=self._tracer_provider)
136138
self._tracer_provider.add_span_processor(
137139
HumanloopSpanProcessor(
140+
client=self,
138141
exporter=HumanloopSpanExporter(client=self),
139142
),
140143
)

src/humanloop/eval_utils/context.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextvars import ContextVar
22
from dataclasses import dataclass
33
from typing import Any, Callable
4+
from opentelemetry.trace import Tracer
45

56

67
@dataclass
@@ -26,11 +27,9 @@ class EvaluationContext:
2627
run_id: str
2728

2829

29-
EVALUATION_CONTEXT_VARIABLE_NAME = "__EVALUATION_CONTEXT"
30+
_EVALUATION_CONTEXT_VAR: ContextVar[EvaluationContext] = ContextVar("__EVALUATION_CONTEXT")
3031

31-
_EVALUATION_CONTEXT_VAR: ContextVar[EvaluationContext] = ContextVar(EVALUATION_CONTEXT_VARIABLE_NAME)
32-
33-
_UnsafeEvaluationContextRead = RuntimeError("EvaluationContext not set in the current thread.")
32+
_UnsafeContextRead = RuntimeError("Attempting to read from thread Context when variable was not set.")
3433

3534

3635
def set_evaluation_context(context: EvaluationContext):
@@ -41,7 +40,7 @@ def get_evaluation_context() -> EvaluationContext:
4140
try:
4241
return _EVALUATION_CONTEXT_VAR.get()
4342
except LookupError:
44-
raise _UnsafeEvaluationContextRead
43+
raise _UnsafeContextRead
4544

4645

4746
def evaluation_context_set() -> bool:
@@ -66,4 +65,59 @@ def is_evaluated_file(file_path) -> bool:
6665
evaluation_context = _EVALUATION_CONTEXT_VAR.get()
6766
return evaluation_context.path == file_path
6867
except LookupError:
69-
raise _UnsafeEvaluationContextRead
68+
raise _UnsafeContextRead
69+
70+
71+
@dataclass
class PromptUtilityContext:
    """State tracked while code runs inside a prompt utility.

    The nesting counter is incremented by `set_prompt_utility_context` and
    decremented by `unset_prompt_utility_context`.
    """

    # Tracer stored when the context is first created
    tracer: Tracer
    # Number of prompt-utility scopes currently open (nesting depth)
    _in_prompt_utility: int

    @property
    def in_prompt_utility(self) -> bool:
        """Return True while at least one prompt-utility scope is open."""
        depth = self._in_prompt_utility
        return depth > 0
79+
80+
81+
_PROMPT_UTILITY_CONTEXT_VAR: ContextVar[PromptUtilityContext] = ContextVar("__PROMPT_UTILITY_CONTEXT")
82+
83+
84+
def in_prompt_utility_context() -> bool:
    """Tell whether the caller is currently running inside a prompt utility.

    Returns False when no PromptUtilityContext has been set at all, instead
    of raising.
    """
    try:
        context = _PROMPT_UTILITY_CONTEXT_VAR.get()
    except LookupError:
        # The variable was never set in this Context
        return False
    return context.in_prompt_utility
89+
90+
91+
def set_prompt_utility_context(tracer: Tracer):
    """Enter a prompt-utility scope.

    On first entry a new PromptUtilityContext holding `tracer` is stored with
    a nesting depth of 1. On nested entries the existing context is reused
    and only its depth is incremented; the `tracer` argument is not consulted
    in that case.
    """
    try:
        current = _PROMPT_UTILITY_CONTEXT_VAR.get()
    except LookupError:
        # First entry: create the context
        fresh = PromptUtilityContext(tracer=tracer, _in_prompt_utility=1)
        _PROMPT_UTILITY_CONTEXT_VAR.set(fresh)
        return

    # Already set, push another level
    current._in_prompt_utility += 1
    _PROMPT_UTILITY_CONTEXT_VAR.set(current)
105+
106+
107+
def get_prompt_utility_context() -> PromptUtilityContext:
    """Return the PromptUtilityContext stored in the current Context.

    Raises:
        RuntimeError: the shared `_UnsafeContextRead` instance, when the
            variable has not been set.
    """
    try:
        context = _PROMPT_UTILITY_CONTEXT_VAR.get()
    except LookupError:
        raise _UnsafeContextRead
    return context
112+
113+
114+
def unset_prompt_utility_context():
    """Leave a prompt-utility scope by decrementing the nesting depth.

    The context object itself stays stored in the ContextVar; only its
    counter changes, so `in_prompt_utility_context()` reports False once the
    depth reaches 0.

    Raises:
        RuntimeError: the shared `_UnsafeContextRead` instance, when no
            context was ever set.
        ValueError: when the depth is already 0, i.e. this call has no
            matching `set_prompt_utility_context()` call.
    """
    try:
        prompt_utility_context = _PROMPT_UTILITY_CONTEXT_VAR.get()
    except LookupError:
        raise _UnsafeContextRead

    if prompt_utility_context._in_prompt_utility < 1:
        # More exits than entries: the missing call is the `set`, not the `unset`
        raise ValueError("No matching set_prompt_utility_context() call.")
    prompt_utility_context._in_prompt_utility -= 1

src/humanloop/eval_utils/run.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from humanloop.eval_utils.context import (
3030
EvaluationContext,
3131
get_evaluation_context,
32+
get_prompt_utility_context,
33+
in_prompt_utility_context,
3234
log_belongs_to_evaluated_file,
3335
set_evaluation_context,
3436
)
@@ -37,6 +39,8 @@
3739
# We use TypedDicts for requests, which is consistent with the rest of the SDK
3840
from humanloop.evaluators.client import EvaluatorsClient
3941
from humanloop.flows.client import FlowsClient
42+
from humanloop.otel.constants import HUMANLOOP_INTERCEPTED_HL_CALL_RESPONSE, HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME
43+
from humanloop.otel.helpers import write_to_opentelemetry_span
4044
from humanloop.prompts.client import PromptsClient
4145
from humanloop.requests import CodeEvaluatorRequestParams as CodeEvaluatorDict
4246
from humanloop.requests import ExternalEvaluatorRequestParams as ExternalEvaluator
@@ -62,6 +66,7 @@
6266
from humanloop.types.datapoint_response import DatapointResponse
6367
from humanloop.types.dataset_response import DatasetResponse
6468
from humanloop.types.evaluation_run_response import EvaluationRunResponse
69+
from humanloop.types.prompt_call_response import PromptCallResponse
6570
from humanloop.types.run_stats_response import RunStatsResponse
6671
from pydantic import ValidationError
6772

@@ -94,6 +99,47 @@
9499
CLIENT_TYPE = TypeVar("CLIENT_TYPE", PromptsClient, ToolsClient, FlowsClient, EvaluatorsClient)
95100

96101

102+
class HumanloopUtilitySyntaxError(Exception):
    """Raised when a Humanloop call made inside a utility fails.

    The message is available both as `.message` and through `str()`.
    """

    def __init__(self, message):
        # Forward to Exception so `args` is populated even when the error is
        # built with `message=` as a keyword; otherwise `args` stays empty,
        # `repr()` shows no message and pickling/copying the error fails.
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message
108+
109+
110+
def prompt_call_evaluation_aware(client: PromptsClient) -> PromptsClient:
    """Wrap `client.call` so calls made inside a prompt utility are intercepted.

    The original bound method is kept on the client as `_call`. Outside a
    prompt utility the wrapper forwards the keyword arguments unchanged.
    Inside one it forces `save=False`, performs the call, and records the
    response on a dedicated OpenTelemetry span created with the tracer from
    the prompt utility context.

    Args:
        client: the PromptsClient whose `call` method is replaced in place.

    Returns:
        The same client instance, with `call` overloaded.

    Raises:
        HumanloopUtilitySyntaxError: from the wrapped `call`, when the
            underlying call fails inside a prompt utility.
    """
    client._call = client.call

    def _overload_call(self, **kwargs) -> PromptCallResponse:
        if not in_prompt_utility_context():
            # Plain call: unpack the keyword arguments instead of passing the
            # dict positionally.
            return self._call(**kwargs)

        # NOTE(review): presumably `save=False` stops the backend from
        # persisting this call on its own — confirm against the API.
        kwargs = {**kwargs, "save": False}

        try:
            response = self._call(**kwargs)
            response = typing.cast(PromptCallResponse, response)
        except Exception as e:
            # TODO: Bug found in backend: not specifying a model 400s but creates a File
            raise HumanloopUtilitySyntaxError(message=str(e)) from e

        prompt_utility_context = get_prompt_utility_context()

        # Record the response on its own span, named so that it can be
        # recognised as an intercepted Humanloop call.
        with prompt_utility_context.tracer.start_as_current_span(HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME) as span:
            write_to_opentelemetry_span(
                span=span,
                key=HUMANLOOP_INTERCEPTED_HL_CALL_RESPONSE,
                value=response.dict(),
            )
        return response

    # Replace the original call method with the overloaded one
    client.call = types.MethodType(_overload_call, client)
    # Return the client with the overloaded call method
    logger.debug("Overloaded the .call method of %s", client)
    return client
141+
142+
97143
def log_with_evaluation_context(client: CLIENT_TYPE) -> CLIENT_TYPE:
98144
"""
99145
Wrap the `log` method of the provided Humanloop client to use EVALUATION_CONTEXT.
@@ -142,7 +188,7 @@ def _overload_log(
142188
# Replace the original log method with the overloaded one
143189
client.log = types.MethodType(_overload_log, client)
144190
# Return the client with the overloaded log method
145-
logger.debug("Overloaded the .log method of %s", client)
191+
logger.debug("Overloaded the .call method of %s", client)
146192
return client
147193

148194

src/humanloop/otel/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@
66
HUMANLOOP_PATH_KEY = "humanloop.file.path"
77
# Required for the exporter to know when to mark the Flow Log as complete
88
HUMANLOOP_FLOW_PREREQUISITES_KEY = "humanloop.flow.prerequisites"
9+
HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME = "humanloop_intercepted_hl_call"
10+
HUMANLOOP_INTERCEPTED_HL_CALL_RESPONSE = "intercepted_call_response"

src/humanloop/otel/exporter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def _do_work(self):
186186

187187
def _export_span_dispatch(self, span: ReadableSpan) -> None:
188188
"""Call the appropriate BaseHumanloop.X.log based on the Span type."""
189-
hl_file = read_from_opentelemetry_span(span, key=HUMANLOOP_FILE_KEY)
190189
file_type = span._attributes.get(HUMANLOOP_FILE_TYPE_KEY) # type: ignore
191190
parent_span_id = span.parent.span_id if span.parent else None
192191

src/humanloop/otel/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from opentelemetry.trace import SpanKind
66
from opentelemetry.util.types import AttributeValue
77

8+
from humanloop.otel.constants import HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME
9+
810
NestedDict = dict[str, Union["NestedDict", AttributeValue]]
911
NestedList = list[Union["NestedList", NestedDict]]
1012

@@ -262,6 +264,10 @@ def is_llm_provider_call(span: ReadableSpan) -> bool:
262264
)
263265

264266

267+
def is_intercepted_call(span: ReadableSpan) -> bool:
    """Check if the Span carries the intercepted-Humanloop-call span name."""
    span_name = span.name
    return span_name == HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME
269+
270+
265271
def is_humanloop_span(span: ReadableSpan) -> bool:
    """Tell whether the Span was created by the Humanloop SDK.

    The check is made on the span name, which must begin with "humanloop.".
    """
    span_name = span.name
    return span_name.startswith("humanloop.")

src/humanloop/otel/processor/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import logging
33
from collections import defaultdict
44
from typing import Optional
5+
import typing
56

67
from opentelemetry.sdk.trace import ReadableSpan
78
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter
89

10+
from humanloop.base_client import BaseHumanloop
911
from humanloop.otel.constants import (
1012
HUMANLOOP_FILE_TYPE_KEY,
1113
HUMANLOOP_FLOW_PREREQUISITES_KEY,
14+
HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME,
1215
HUMANLOOP_LOG_KEY,
1316
)
1417
from humanloop.otel.helpers import (
@@ -18,6 +21,10 @@
1821
)
1922
from humanloop.otel.processor.prompts import enhance_prompt_span
2023

24+
if typing.TYPE_CHECKING:
25+
from humanloop.base_client import BaseHumanloop
26+
27+
2128
logger = logging.getLogger("humanloop.sdk")
2229

2330

@@ -49,6 +56,7 @@ class HumanloopSpanProcessor(SimpleSpanProcessor):
4956
def __init__(
5057
self,
5158
exporter: SpanExporter,
59+
client: "BaseHumanloop",
5260
) -> None:
5361
super().__init__(exporter)
5462
# span parent to span children map
@@ -58,6 +66,7 @@ def __init__(
5866
# They are passed to the Exporter as a span attribute
5967
# so the Exporter knows when to complete a trace
6068
self._spans_to_complete_flow_trace: dict[int, list[int]] = {}
69+
self._client = client
6170

6271
def shutdown(self):
6372
return super().shutdown()
@@ -172,6 +181,7 @@ def _send_to_exporter(
172181
span_id = span.context.span_id
173182
if file_type == "prompt":
174183
enhance_prompt_span(
184+
client=self._client,
175185
prompt_span=span,
176186
dependencies=dependencies,
177187
)
@@ -209,7 +219,9 @@ def _is_dependency(cls, span: ReadableSpan) -> bool:
209219
# At the moment we only enrich Spans created by the Prompt decorators
210220
# As we add Instrumentors for other libraries, this function must
211221
# be expanded
212-
return span.parent is not None and is_llm_provider_call(span=span)
222+
return span.parent is not None and (
223+
is_llm_provider_call(span=span) or span.name == HUMANLOOP_INTERCEPTED_HL_CALL_SPAN_NAME
224+
)
213225

214226
@classmethod
215227
def _write_start_end_times(cls, span: ReadableSpan):

0 commit comments

Comments
 (0)