Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@
"method": "get_details",
"span_name": "watsonx.get_details",
},
{
"module": "ibm_watsonx_ai.foundation_models",
"object": "ModelInference",
"method": "chat",
"span_name": "watsonx.chat",
},
]

WATSON_MODULES = [
Expand Down Expand Up @@ -487,16 +493,45 @@ def wrapper(wrapped, instance, args, kwargs):


@dont_throw
def _handle_input(span, event_logger, name, instance, args, kwargs):
    """Record request-side telemetry for a watsonx call.

    Sets API-level span attributes, captures prompt/message content (only
    when content capture is enabled), and emits input events when event
    emission is configured.

    Args:
        span: the active OpenTelemetry span for this call.
        event_logger: event logger used when events are enabled; may be None.
        name: the span name of the wrapped method (e.g. "watsonx.chat",
            "watsonx.generate") — used to pick the capture path.
        instance: the wrapped ModelInference instance.
        args: positional arguments of the wrapped call.
        kwargs: keyword arguments of the wrapped call.
    """
    messages = None
    _set_api_attributes(span)

    if "generate" in name:
        set_model_input_attributes(span, instance)
        if should_send_prompts():
            _set_input_attributes(span, args=args, kwargs=kwargs)

    if "chat" in name:
        # Messages may come as a keyword or as the first positional argument.
        messages = kwargs.get("messages")
        if messages is None and args and isinstance(args[0], list):
            messages = args[0]
        # Guard prompt content behind should_send_prompts(), mirroring the
        # "generate" path above — otherwise message content would leak onto
        # the span even when content capture is disabled.
        if messages and span.is_recording() and should_send_prompts():
            for index, msg in enumerate(messages):
                if not isinstance(msg, dict):
                    continue
                role = msg.get("role")
                content = msg.get("content")
                if role and isinstance(content, str):
                    _set_span_attribute(
                        span,
                        f"{GenAIAttributes.GEN_AI_PROMPT}.{index}.{role}",
                        content.strip(),
                    )

    if should_emit_events() and event_logger:
        if "chat" in name and isinstance(messages, list):
            # One MessageEvent per chat message; role defaults to "user".
            for msg in messages:
                if not isinstance(msg, dict):
                    continue
                emit_event(
                    MessageEvent(
                        content=msg.get("content"),
                        role=msg.get("role", "user"),
                    ),
                    event_logger,
                )
        else:
            _emit_input_events(args, kwargs, event_logger)

@dont_throw
Expand Down Expand Up @@ -547,6 +582,79 @@ def _handle_stream_response(
_set_stream_response_attributes(span, stream_response)


@dont_throw
def _handle_chat_response(
    span,
    event_logger,
    response,
    token_histogram,
    response_counter,
    duration_histogram,
    duration,
):
    """Record response-side telemetry for a watsonx chat call.

    Sets response model / completion attributes on the span, emits choice
    events when events are enabled, and records the response counter, token
    histogram, and duration histogram metrics.

    Args:
        span: the active OpenTelemetry span for this call.
        event_logger: event logger used when events are enabled; may be None.
        response: the raw chat response; expected to be a dict-shaped payload
            with optional "choices" and "usage" keys.
        token_histogram: histogram for input/output token counts; may be None.
        response_counter: counter incremented once per response; may be None.
        duration_histogram: histogram for call duration; may be None.
        duration: elapsed wall-clock seconds for the call.
    """
    if not span.is_recording() or not isinstance(response, dict):
        return

    model_id = response.get("model_id") or response.get("model") or "unknown"
    _set_span_attribute(span, GenAIAttributes.GEN_AI_RESPONSE_MODEL, model_id)

    # Content: one completion attribute or choice event per choice.
    # Initialize finish_reason so an empty/missing "choices" list does not
    # raise NameError below (which @dont_throw would swallow, silently
    # dropping the counter, usage attributes, and histogram recordings).
    finish_reason = "unknown"
    choices = response.get("choices") or []
    for index, choice in enumerate(choices):
        message = choice.get("message", {})
        content = message.get("content")
        finish_reason = choice.get("finish_reason") or "unknown"

        if should_emit_events() and event_logger:
            emit_event(
                ChoiceEvent(
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                ),
                event_logger,
            )
        elif content and should_send_prompts():
            _set_span_attribute(
                span,
                f"{GenAIAttributes.GEN_AI_COMPLETION}.{index}.content",
                content,
            )

    if response_counter:
        # Recorded once per response; the stop reason is the last choice's
        # finish reason, or "unknown" when the response had no choices.
        response_counter.add(
            1,
            attributes={
                GenAIAttributes.GEN_AI_RESPONSE_MODEL: model_id,
                SpanAttributes.LLM_RESPONSE_STOP_REASON: finish_reason,
            },
        )

    # Usage: fall back to summing prompt + completion when the payload does
    # not report a total.
    usage = response.get("usage") or {}
    prompt_tokens = usage.get("prompt_tokens", 0)
    completion_tokens = usage.get("completion_tokens", 0)
    total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)

    _set_span_attribute(span, GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens)
    _set_span_attribute(span, GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens)
    _set_span_attribute(span, SpanAttributes.LLM_USAGE_TOTAL_TOKENS, total_tokens)

    shared_attributes = _metric_shared_attributes(response_model=model_id)

    if token_histogram:
        token_histogram.record(
            completion_tokens,
            attributes={**shared_attributes, GenAIAttributes.GEN_AI_TOKEN_TYPE: "output"},
        )
        token_histogram.record(
            prompt_tokens,
            attributes={**shared_attributes, GenAIAttributes.GEN_AI_TOKEN_TYPE: "input"},
        )

    if duration and duration_histogram:
        duration_histogram.record(duration, attributes=shared_attributes)

@_with_tracer_wrapper
def _wrap(
tracer,
Expand Down Expand Up @@ -578,15 +686,22 @@ def _wrap(
},
)

_handle_input(span, event_logger, name, instance, args, kwargs)

if "generate" in name:
if "generate" in name or "chat" in name:
if to_wrap.get("method") == "generate_text_stream":
if (raw_flag := kwargs.get("raw_response", None)) is None:
kwargs = {**kwargs, "raw_response": True}
elif raw_flag is False:
kwargs["raw_response"] = True

if to_wrap.get("method") == "chat":
if "prompt" in kwargs and "messages" not in kwargs:
prompt = kwargs.pop("prompt")
kwargs["messages"] = [
{"role": "user", "content": prompt}
]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
_handle_input(span, event_logger, name, instance, args, kwargs)

try:
start_time = time.time()
response = wrapped(*args, **kwargs)
Expand Down Expand Up @@ -629,6 +744,18 @@ def _wrap(
duration_histogram,
duration,
)

if "chat" in name:
duration = end_time - start_time
_handle_chat_response(
span,
event_logger,
response,
token_histogram,
response_counter,
duration_histogram,
duration,
)
span.end()
return response

Expand Down