Skip to content

Commit 5c4765b

Browse files
refactor: simplify embedding tracer handlers (OPEN-10480)
Code review findings addressed:

- Move per-call imports of _openai_embedding_common to module-level (was in hot path of every embedding call).
- Extract build_embedding_step_kwargs into _openai_embedding_common so that sync and async OpenAI handlers each become ~10 lines instead of ~50, and LiteLLM reuses the same kwargs assembly.
- Drop LiteLLM's local _parse_embedding_response and _get_embedding_model_parameters; both now delegate to the shared helpers (LiteLLM-specific timeout/api_base/api_version/cost/metadata are layered on top of the common kwargs).
- Type Bedrock _parse_embedding_output return as Tuple[Union[List[float], List[List[float]]], int, int] instead of bare tuple.

Net: -34 lines across the 5 touched source files. Tests unchanged, all 77 embedding tests + 448 lib tests still green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0f1191a commit 5c4765b

5 files changed

Lines changed: 101 additions & 135 deletions

File tree

src/openlayer/lib/integrations/_openai_embedding_common.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1-
"""Shared parsing helpers for OpenAI sync + async embedding tracers."""
1+
"""Shared parsing helpers for OpenAI-shaped embedding tracers (OpenAI, AsyncOpenAI, LiteLLM)."""
22

3-
from typing import Any, Dict, List, Tuple, Union
3+
from typing import Any, Dict, List, Optional, Tuple, Union
44

55

66
def parse_embedding_response(
77
response: Any,
88
) -> Tuple[Union[List[float], List[List[float]]], int, int]:
9-
"""Extract (embeddings, dimensions, count) from an OpenAI EmbeddingResponse.
9+
"""Extract (embeddings, dimensions, count) from an OpenAI-shaped EmbeddingResponse.
1010
1111
For a single input, returns the vector directly.
1212
For a batch, returns a list of vectors.
1313
"""
1414
try:
1515
data = getattr(response, "data", None)
16-
if data is None:
16+
if data is None and isinstance(response, dict):
17+
data = response.get("data", [])
18+
if not data:
1719
return [], 0, 0
1820
embeddings = [
1921
item["embedding"] if isinstance(item, dict) else item.embedding
@@ -35,3 +37,48 @@ def get_embedding_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]:
3537
"encoding_format": kwargs.get("encoding_format"),
3638
"user": kwargs.get("user"),
3739
}
40+
41+
42+
def build_embedding_step_kwargs(
    response: Any,
    call_kwargs: Dict[str, Any],
    start_time: float,
    end_time: float,
    *,
    name: str,
    provider: str,
    inference_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the kwargs to pass to ``tracer.add_embedding_step_to_trace``.

    Common boilerplate for OpenAI-shaped responses (OpenAI sync/async, LiteLLM).
    Callers may layer extra fields (cost, extra_metadata, model_parameters) on
    top of the returned dict before invoking the tracer helper.

    Args:
        response: The embedding response — either an attribute-style object
            (OpenAI ``CreateEmbeddingResponse``) or a plain dict with the same
            shape (some LiteLLM code paths).
        call_kwargs: The kwargs the user passed to ``embeddings.create``.
        start_time: Wall-clock start of the call, in seconds (``time.time()``).
        end_time: Wall-clock end of the call, in seconds (``time.time()``).
        name: Step name shown in the trace (e.g. ``"OpenAI Embedding"``).
        provider: Provider label recorded on the step and its metadata.
        inference_id: Optional explicit inference id to attach to the step.

    Returns:
        A dict of keyword arguments for ``tracer.add_embedding_step_to_trace``.
    """
    # ``parse_embedding_response`` accepts plain-dict responses, so the model
    # and usage lookups mirror that: attribute access first, then dict keys,
    # then the caller-supplied model name as a last resort. (The previous
    # getattr-only lookup silently reported "unknown" model and 0 tokens for
    # dict-shaped responses.)
    model_name = getattr(response, "model", None)
    if model_name is None and isinstance(response, dict):
        model_name = response.get("model")
    if model_name is None:
        model_name = call_kwargs.get("model", "unknown")

    embeddings, dim, count = parse_embedding_response(response)

    usage = getattr(response, "usage", None)
    if usage is None and isinstance(response, dict):
        usage = response.get("usage")
    if isinstance(usage, dict):
        prompt_tokens = usage.get("prompt_tokens") or 0
        total_tokens = usage.get("total_tokens") or prompt_tokens
    else:
        prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0
        total_tokens = getattr(usage, "total_tokens", prompt_tokens) if usage else prompt_tokens

    return {
        "name": name,
        "end_time": end_time,
        "inputs": {"input": call_kwargs.get("input")},
        "output": embeddings,
        # The tracer expects latency in milliseconds.
        "latency": (end_time - start_time) * 1000,
        "tokens": total_tokens,
        "prompt_tokens": prompt_tokens,
        "model": model_name,
        "model_parameters": get_embedding_model_parameters(call_kwargs),
        "embedding_dimensions": dim,
        "embedding_count": count,
        # Pydantic models serialize via model_dump(); anything else is
        # stringified as a best-effort raw capture.
        "raw_output": (
            response.model_dump() if hasattr(response, "model_dump") else str(response)
        ),
        "provider": provider,
        "id": inference_id,
        "metadata": {"provider": provider},
    }

src/openlayer/lib/integrations/async_openai_tracer.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import openai
1818

1919
from ..tracing import tracer
20+
from ._openai_embedding_common import build_embedding_step_kwargs
2021
from .openai_tracer import (
2122
get_model_parameters,
2223
create_trace_args,
@@ -725,46 +726,21 @@ async def handle_embedding_async(
725726
**kwargs,
726727
) -> Any:
727728
"""Trace an async AsyncOpenAI client.embeddings.create() call."""
728-
from ._openai_embedding_common import (
729-
get_embedding_model_parameters as _get_embedding_model_parameters,
730-
)
731-
from ._openai_embedding_common import (
732-
parse_embedding_response as _parse_embedding_response,
733-
)
734-
735729
start_time = time.time()
736730
response = await original_func(*args, **kwargs)
737731
end_time = time.time()
738732

739733
try:
740-
model_name = getattr(response, "model", kwargs.get("model", "unknown"))
741-
embeddings, dim, count = _parse_embedding_response(response)
742-
usage = getattr(response, "usage", None)
743-
prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0
744-
total_tokens = (
745-
getattr(usage, "total_tokens", prompt_tokens) if usage else prompt_tokens
746-
)
747-
748734
tracer.add_embedding_step_to_trace(
749-
name="OpenAI Embedding",
750-
end_time=end_time,
751-
inputs={"input": kwargs.get("input")},
752-
output=embeddings,
753-
latency=(end_time - start_time) * 1000,
754-
tokens=total_tokens,
755-
prompt_tokens=prompt_tokens,
756-
model=model_name,
757-
model_parameters=_get_embedding_model_parameters(kwargs),
758-
embedding_dimensions=dim,
759-
embedding_count=count,
760-
raw_output=(
761-
response.model_dump()
762-
if hasattr(response, "model_dump")
763-
else str(response)
764-
),
765-
provider="OpenAI",
766-
id=inference_id,
767-
metadata={"provider": "OpenAI"},
735+
**build_embedding_step_kwargs(
736+
response,
737+
kwargs,
738+
start_time,
739+
end_time,
740+
name="OpenAI Embedding",
741+
provider="OpenAI",
742+
inference_id=inference_id,
743+
)
768744
)
769745
except Exception as e:
770746
logger.error(

src/openlayer/lib/integrations/bedrock_tracer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
from functools import wraps
8-
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Union
8+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
99

1010
from botocore.response import StreamingBody
1111

@@ -237,7 +237,7 @@ def _parse_embedding_input(body_data: Dict[str, Any], model_id: str) -> Dict[str
237237

238238
def _parse_embedding_output(
239239
response_data: Dict[str, Any], model_id: str
240-
) -> tuple:
240+
) -> Tuple[Union[List[float], List[List[float]]], int, int]:
241241
"""Returns (embeddings, dimensions, count)."""
242242
if model_id.startswith("amazon.titan-embed"):
243243
emb = response_data.get("embedding", [])

src/openlayer/lib/integrations/litellm_tracer.py

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from ..tracing import tracer
2020
from ..tracing import enums as tracer_enums
21+
from ._openai_embedding_common import build_embedding_step_kwargs
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -367,40 +368,38 @@ def handle_embedding(
367368
try:
368369
model_name = kwargs.get("model", getattr(response, "model", "unknown"))
369370
provider = detect_provider_from_response(response, model_name)
370-
embeddings, dim, count = _parse_embedding_response(response)
371-
usage_data = extract_usage_from_response(response)
372371
extra_metadata = extract_litellm_metadata(response, model_name)
373-
cost = extra_metadata.get("cost", None)
374-
375-
prompt_tokens = usage_data.get("prompt_tokens") or 0
376-
total_tokens = usage_data.get("total_tokens") or prompt_tokens
372+
usage_data = extract_usage_from_response(response)
377373

378-
tracer.add_embedding_step_to_trace(
374+
step_kwargs = build_embedding_step_kwargs(
375+
response,
376+
kwargs,
377+
start_time,
378+
end_time,
379379
name="LiteLLM Embedding",
380-
end_time=end_time,
381-
inputs={"input": kwargs.get("input")},
382-
output=embeddings,
383-
latency=(end_time - start_time) * 1000,
384-
tokens=total_tokens,
385-
prompt_tokens=prompt_tokens,
386-
model=model_name,
387-
model_parameters=_get_embedding_model_parameters(kwargs),
388-
embedding_dimensions=dim,
389-
embedding_count=count,
390-
raw_output=(
391-
response.model_dump()
392-
if hasattr(response, "model_dump")
393-
else str(response)
394-
),
395380
provider=provider,
396-
cost=cost,
397-
id=inference_id,
398-
metadata={
399-
"provider": provider,
400-
"litellm_model": model_name,
401-
**extra_metadata,
402-
},
381+
inference_id=inference_id,
403382
)
383+
384+
# LiteLLM-specific overlays: usage uses LiteLLM's normalized dict, extra
385+
# connection params, response cost, and provider metadata.
386+
prompt_tokens = usage_data.get("prompt_tokens") or 0
387+
step_kwargs["prompt_tokens"] = prompt_tokens
388+
step_kwargs["tokens"] = usage_data.get("total_tokens") or prompt_tokens
389+
step_kwargs["model_parameters"] = {
390+
**step_kwargs["model_parameters"],
391+
"timeout": kwargs.get("timeout"),
392+
"api_base": kwargs.get("api_base"),
393+
"api_version": kwargs.get("api_version"),
394+
}
395+
step_kwargs["cost"] = extra_metadata.get("cost", None)
396+
step_kwargs["metadata"] = {
397+
**step_kwargs["metadata"],
398+
"litellm_model": model_name,
399+
**extra_metadata,
400+
}
401+
402+
tracer.add_embedding_step_to_trace(**step_kwargs)
404403
except Exception as e:
405404
logger.error(
406405
"Failed to trace the LiteLLM embedding request with Openlayer. %s", e
@@ -409,38 +408,6 @@ def handle_embedding(
409408
return response
410409

411410

412-
def _parse_embedding_response(response: Any) -> tuple:
413-
"""Returns (embeddings, dimensions, count). Mirrors OpenAI EmbeddingResponse."""
414-
try:
415-
data = getattr(response, "data", None)
416-
if data is None and isinstance(response, dict):
417-
data = response.get("data", [])
418-
if not data:
419-
return [], 0, 0
420-
embeddings = [
421-
item["embedding"] if isinstance(item, dict) else item.embedding
422-
for item in data
423-
]
424-
if not embeddings:
425-
return [], 0, 0
426-
if len(embeddings) == 1:
427-
return embeddings[0], len(embeddings[0]), 1
428-
return embeddings, len(embeddings[0]), len(embeddings)
429-
except Exception:
430-
return [], 0, 0
431-
432-
433-
def _get_embedding_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]:
434-
return {
435-
"dimensions": kwargs.get("dimensions"),
436-
"encoding_format": kwargs.get("encoding_format"),
437-
"user": kwargs.get("user"),
438-
"timeout": kwargs.get("timeout"),
439-
"api_base": kwargs.get("api_base"),
440-
"api_version": kwargs.get("api_version"),
441-
}
442-
443-
444411
def get_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]:
445412
"""Gets the model parameters from the kwargs."""
446413
return {

src/openlayer/lib/integrations/openai_tracer.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ImageContent,
2828
TextContent,
2929
)
30+
from ._openai_embedding_common import build_embedding_step_kwargs
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -1635,46 +1636,21 @@ def handle_embedding(
16351636
**kwargs,
16361637
) -> Any:
16371638
"""Trace a sync OpenAI client.embeddings.create() call."""
1638-
from ._openai_embedding_common import (
1639-
get_embedding_model_parameters as _get_embedding_model_parameters,
1640-
)
1641-
from ._openai_embedding_common import (
1642-
parse_embedding_response as _parse_embedding_response,
1643-
)
1644-
16451639
start_time = time.time()
16461640
response = original_func(*args, **kwargs)
16471641
end_time = time.time()
16481642

16491643
try:
1650-
model_name = getattr(response, "model", kwargs.get("model", "unknown"))
1651-
embeddings, dim, count = _parse_embedding_response(response)
1652-
usage = getattr(response, "usage", None)
1653-
prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0
1654-
total_tokens = (
1655-
getattr(usage, "total_tokens", prompt_tokens) if usage else prompt_tokens
1656-
)
1657-
16581644
tracer.add_embedding_step_to_trace(
1659-
name="OpenAI Embedding",
1660-
end_time=end_time,
1661-
inputs={"input": kwargs.get("input")},
1662-
output=embeddings,
1663-
latency=(end_time - start_time) * 1000,
1664-
tokens=total_tokens,
1665-
prompt_tokens=prompt_tokens,
1666-
model=model_name,
1667-
model_parameters=_get_embedding_model_parameters(kwargs),
1668-
embedding_dimensions=dim,
1669-
embedding_count=count,
1670-
raw_output=(
1671-
response.model_dump()
1672-
if hasattr(response, "model_dump")
1673-
else str(response)
1674-
),
1675-
provider="OpenAI",
1676-
id=inference_id,
1677-
metadata={"provider": "OpenAI"},
1645+
**build_embedding_step_kwargs(
1646+
response,
1647+
kwargs,
1648+
start_time,
1649+
end_time,
1650+
name="OpenAI Embedding",
1651+
provider="OpenAI",
1652+
inference_id=inference_id,
1653+
)
16781654
)
16791655
except Exception as e:
16801656
logger.error(

0 commit comments

Comments
 (0)