Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions apisix/cli/ngx_tpl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,14 @@ http {
set $llm_model '';
set $llm_prompt_tokens '0';
set $llm_completion_tokens '0';
set $llm_total_tokens '0';
set $llm_stream 'false';
set $llm_has_tool_calls 'false';
set $llm_tool_count '0';
set $llm_end_user_id '';
set $llm_cache_read_input_tokens '0';
set $llm_cache_creation_input_tokens '0';
set $llm_reasoning_tokens '0';


{% if use_apisix_base then %}
Expand Down
8 changes: 8 additions & 0 deletions apisix/core/ctx.lua
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ do
llm_model = true,
llm_prompt_tokens = true,
llm_completion_tokens = true,
llm_total_tokens = true,
llm_stream = true,
llm_has_tool_calls = true,
llm_tool_count = true,
llm_end_user_id = true,
llm_cache_read_input_tokens = true,
llm_cache_creation_input_tokens = true,
llm_reasoning_tokens = true,

upstream_mirror_host = true,
upstream_mirror_uri = true,
Expand Down
12 changes: 12 additions & 0 deletions apisix/plugins/ai-protocols/anthropic-messages.lua
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ function _M.parse_sse_event(event, ctx, state)
end
return { type = "skip" }

elseif event.type == "content_block_start" then
local data = core.json.decode(event.data, { null_as_nil = true })
if data and type(data.content_block) == "table"
and data.content_block.type == "tool_use" then
return { type = "skip", has_tool_call = true }
end
return { type = "skip" }

elseif event.type == "message_delta" then
local data, err = core.json.decode(event.data, { null_as_nil = true })
if not data then
Expand Down Expand Up @@ -102,6 +110,8 @@ function _M.parse_sse_event(event, ctx, state)
prompt_tokens = usage.input_tokens or 0,
completion_tokens = usage.output_tokens or 0,
total_tokens = (usage.input_tokens or 0) + (usage.output_tokens or 0),
cache_read_input_tokens = usage.cache_read_input_tokens or 0,
cache_creation_input_tokens = usage.cache_creation_input_tokens or 0,
},
raw_usage = usage,
}
Expand Down Expand Up @@ -169,6 +179,8 @@ function _M.extract_usage(res_body)
prompt_tokens = prompt,
completion_tokens = completion,
total_tokens = prompt + completion,
cache_read_input_tokens = raw.cache_read_input_tokens or 0,
cache_creation_input_tokens = raw.cache_creation_input_tokens or 0,
}, raw
end

Expand Down
29 changes: 24 additions & 5 deletions apisix/plugins/ai-protocols/openai-chat.lua
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,18 @@ function _M.parse_sse_event(event, ctx, state)

local result = { type = "delta", data = data }

-- Extract text content from choices
-- Extract text content and detect tool calls from choices
if type(data.choices) == "table" and #data.choices > 0 then
local texts = {}
for _, choice in ipairs(data.choices) do
if type(choice) == "table"
and type(choice.delta) == "table"
and type(choice.delta.content) == "string" then
core.table.insert(texts, choice.delta.content)
if type(choice) == "table" and type(choice.delta) == "table" then
if type(choice.delta.content) == "string" then
core.table.insert(texts, choice.delta.content)
end
if type(choice.delta.tool_calls) == "table"
and #choice.delta.tool_calls > 0 then
result.has_tool_call = true
end
end
end
if #texts > 0 then
Expand All @@ -97,10 +101,18 @@ function _M.parse_sse_event(event, ctx, state)
-- Extract usage (null for non-final chunks; cjson decodes null as userdata)
if type(data.usage) == "table" then
result.type = "usage"
local pd = type(data.usage.prompt_tokens_details) == "table"
and data.usage.prompt_tokens_details
local cd = type(data.usage.completion_tokens_details) == "table"
and data.usage.completion_tokens_details
result.usage = {
prompt_tokens = data.usage.prompt_tokens or 0,
completion_tokens = data.usage.completion_tokens or 0,
total_tokens = data.usage.total_tokens or 0,
cache_read_input_tokens = pd and pd.cached_tokens
or data.usage.prompt_cache_hit_tokens or 0,
cache_creation_input_tokens = pd and pd.cache_creation_input_tokens or 0,
reasoning_tokens = cd and cd.reasoning_tokens or 0,
}
result.raw_usage = data.usage
end
Expand Down Expand Up @@ -160,10 +172,17 @@ function _M.extract_usage(res_body)
return nil, nil
end
local raw = res_body.usage
local pd = type(raw.prompt_tokens_details) == "table" and raw.prompt_tokens_details
local cd = type(raw.completion_tokens_details) == "table" and raw.completion_tokens_details
-- OpenAI uses prompt_tokens_details.cached_tokens; DeepSeek uses prompt_cache_hit_tokens
local cache_read = pd and pd.cached_tokens or raw.prompt_cache_hit_tokens or 0
return {
prompt_tokens = raw.prompt_tokens or 0,
completion_tokens = raw.completion_tokens or 0,
total_tokens = raw.total_tokens or (raw.prompt_tokens or 0) + (raw.completion_tokens or 0),
cache_read_input_tokens = cache_read,
cache_creation_input_tokens = pd and pd.cache_creation_input_tokens or 0,
reasoning_tokens = cd and cd.reasoning_tokens or 0,
}, raw
end

Expand Down
40 changes: 29 additions & 11 deletions apisix/plugins/ai-protocols/openai-responses.lua
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,30 @@ function _M.parse_sse_event(event, ctx, state)
core.log.warn("failed to decode response.completed SSE data: ", err)
return result
end
if type(data.response) == "table"
and type(data.response.usage) == "table" then
local usage = data.response.usage
result.type = "usage_and_done"
result.usage = {
prompt_tokens = usage.input_tokens or 0,
completion_tokens = usage.output_tokens or 0,
total_tokens = usage.total_tokens or 0,
}
result.raw_usage = usage
if type(data.response) == "table" then
local resp = data.response
if type(resp.usage) == "table" then
local usage = resp.usage
result.type = "usage_and_done"
result.usage = {
prompt_tokens = usage.input_tokens or 0,
completion_tokens = usage.output_tokens or 0,
total_tokens = usage.total_tokens or 0,
cache_read_input_tokens = type(usage.input_tokens_details) == "table"
and usage.input_tokens_details.cached_tokens or 0,
reasoning_tokens = type(usage.output_tokens_details) == "table"
and usage.output_tokens_details.reasoning_tokens or 0,
}
result.raw_usage = usage
end
if type(resp.output) == "table" then
for _, item in ipairs(resp.output) do
if type(item) == "table" and item.type == "function_call" then
result.has_tool_call = true
break
end
end
end
end
return result

Expand Down Expand Up @@ -135,13 +149,17 @@ function _M.extract_usage(res_body)
return nil, nil
end
local raw = res_body.usage
-- Responses API uses input_tokens / output_tokens
local idetails = type(raw.input_tokens_details) == "table" and raw.input_tokens_details
local odetails = type(raw.output_tokens_details) == "table" and raw.output_tokens_details
local prompt = raw.input_tokens or 0
local completion = raw.output_tokens or 0
return {
prompt_tokens = prompt,
completion_tokens = completion,
total_tokens = raw.total_tokens or (prompt + completion),
cache_read_input_tokens = idetails and idetails.cached_tokens or 0,
cache_creation_input_tokens = 0,
reasoning_tokens = odetails and odetails.reasoning_tokens or 0,
}, raw
end

Expand Down
44 changes: 28 additions & 16 deletions apisix/plugins/ai-providers/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ local function merge_usage(ctx, parsed)
ctx.ai_token_usage[k] = v
end
end
-- Recompute total from accumulated parts (handles split events, e.g. Anthropic
-- message_start carries input tokens and message_delta carries output tokens)
local computed = (ctx.ai_token_usage.prompt_tokens or 0)
+ (ctx.ai_token_usage.completion_tokens or 0)
if computed > (ctx.ai_token_usage.total_tokens or 0) then
ctx.ai_token_usage.total_tokens = computed
end
end

local raw = parsed.raw_usage or parsed.usage
Expand Down Expand Up @@ -396,6 +403,9 @@ function _M.parse_response(self, ctx, res, client_proto, converter, conf)
end
ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens or 0
ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens or 0
ctx.var.llm_cache_read_input_tokens = ctx.ai_token_usage.cache_read_input_tokens or 0
ctx.var.llm_cache_creation_input_tokens = ctx.ai_token_usage.cache_creation_input_tokens or 0
ctx.var.llm_reasoning_tokens = ctx.ai_token_usage.reasoning_tokens or 0

local response_text = client_proto.extract_response_text(res_body)
if response_text then
Expand Down Expand Up @@ -446,13 +456,10 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
local bytes_read = 0

-- streaming_flush_interval_ms controls both flush strategy and the thread:
-- == 0 : no thread; lua_response_filter flushes synchronously
-- per chunk via ngx.flush(true), guaranteeing immediate
-- client delivery.
-- > 0 (default: 10): background thread calls ngx.flush(false) every N ms;
-- lua_response_filter skips per-chunk flush for maximum
-- throughput. Useful when the upstream bursts multiple
-- tokens at once.
-- == 0 (default): no thread; lua_response_filter flushes synchronously
-- per chunk, guaranteeing immediate client delivery.
-- > 0 : background thread handles periodic flushing;
-- lua_response_filter skips flush for maximum throughput.
local flush_interval_ms = conf and conf.streaming_flush_interval_ms or 0
-- async_flush: true when the interval thread is responsible for flushing
local async_flush = flush_interval_ms > 0
Expand Down Expand Up @@ -591,6 +598,9 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
for _, event in ipairs(events) do
-- Target protocol parses the provider's SSE format
local parsed = target_proto.parse_sse_event(event, ctx, sse_state)
if parsed and parsed.has_tool_call then
ctx.var.llm_has_tool_calls = "true"
end
if not parsed or parsed.type == "skip" then
goto CONTINUE
end
Expand Down Expand Up @@ -618,6 +628,12 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
merge_usage(ctx, parsed)
ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens
ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens
ctx.var.llm_total_tokens = ctx.ai_token_usage.total_tokens or 0
ctx.var.llm_cache_read_input_tokens =
ctx.ai_token_usage.cache_read_input_tokens or 0
ctx.var.llm_cache_creation_input_tokens =
ctx.ai_token_usage.cache_creation_input_tokens or 0
ctx.var.llm_reasoning_tokens = ctx.ai_token_usage.reasoning_tokens or 0
ctx.var.llm_response_text = table.concat(contents, "")
end

Expand Down Expand Up @@ -702,15 +718,11 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
return status, limit_hit .. " exceeded"
end

-- WORKAROUND, not a real fix: yield to the nginx scheduler so other
-- coroutines on this worker (health checks, concurrent requests) can
-- run. body_reader() and ngx.flush() do not yield when the upstream
-- socket already has data buffered or the downstream client drains
-- immediately, so under bursty SSE upstreams this loop can monopolize
-- the worker CPU. ngx.sleep(0) only prevents a single request from
-- monopolizing the worker; it does not bound per-stream CPU time, add
-- backpressure, or time out stalled streams. See #13256 for a proper
-- solution.
-- Yield to the nginx scheduler so other coroutines on this worker
-- (health checks, concurrent requests) can run. body_reader() and
-- ngx.flush() do not yield when the upstream socket already has data
-- buffered or the downstream client drains immediately, so under
-- bursty SSE upstreams this loop can monopolize the worker CPU.
ngx.sleep(0)

end
Expand Down
88 changes: 87 additions & 1 deletion apisix/plugins/ai-proxy/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,84 @@ local require = require
local pcall = pcall
local pairs = pairs
local type = type
local ipairs = ipairs
local table = table
local exporter = require("apisix.plugins.prometheus.exporter")
local protocols = require("apisix.plugins.ai-protocols")
local transport_http = require("apisix.plugins.ai-transport.http")
local log_sanitize = require("apisix.utils.log-sanitize")
local apisix_upstream = require("resty.apisix.upstream")

local function extract_end_user_id(body, protocol)
if type(body) ~= "table" then
return nil
end
if protocol == "anthropic-messages" then
local meta = body.metadata
if type(meta) == "table" and type(meta.user_id) == "string" then
return meta.user_id
end
return nil
end
-- openai-chat, openai-responses: safety_identifier takes precedence over user
if type(body.safety_identifier) == "string" then
return body.safety_identifier
end
if type(body.user) == "string" then
return body.user
end
return nil
end


local function count_request_tools(body)
if type(body) ~= "table" then
return 0
end
local tools = body.tools
if type(tools) == "table" then
return #tools
end
return 0
end


local function detect_tool_calls_in_response(body)
if type(body) ~= "table" then
return false
end
-- OpenAI Chat / Responses: choices[].message.tool_calls
if type(body.choices) == "table" then
for _, choice in ipairs(body.choices) do
if type(choice) == "table" then
local msg = choice.message
if type(msg) == "table" and type(msg.tool_calls) == "table"
and #msg.tool_calls > 0 then
return true
end
end
end
end
-- Anthropic Messages: content[].type == "tool_use"
if type(body.content) == "table" then
for _, block in ipairs(body.content) do
if type(block) == "table" and block.type == "tool_use" then
return true
end
end
end
-- OpenAI Responses: output[].type == "function_call"
if type(body.output) == "table" then
for _, item in ipairs(body.output) do
if type(item) == "table" and item.type == "function_call" then
return true
end
end
end
return false
end


local _M = {}


Expand Down Expand Up @@ -208,6 +279,15 @@ function _M.before_proxy(conf, ctx, on_error)
return 500, body
end

-- Compute built-in AI log fields from the final upstream request
local final_body = params.body
ctx.var.llm_stream = ctx.var.request_type == "ai_stream" and "true" or "false"
ctx.var.llm_tool_count = count_request_tools(final_body)
local end_user = extract_end_user_id(final_body, target_proto)
if end_user then
ctx.var.llm_end_user_id = end_user
end

core.log.info("sending request to LLM server: ",
core.json.delay_encode(log_sanitize.redact_params(params), true))

Expand Down Expand Up @@ -309,12 +389,18 @@ function _M.before_proxy(conf, ctx, on_error)
code, body = ai_provider:parse_streaming_response(
ctx, res, target_proto_module, converter, conf)
else
local _, parse_err, parse_status = ai_provider:parse_response(
local res_body, parse_err, parse_status = ai_provider:parse_response(
ctx, res, client_proto, converter, conf)
if parse_err then
code = parse_status or 500
body = parse_err
end
if ctx.ai_token_usage then
ctx.var.llm_total_tokens = ctx.ai_token_usage.total_tokens or 0
end
if res_body and detect_tool_calls_in_response(res_body) then
ctx.var.llm_has_tool_calls = "true"
end
end

-- Finalize upstream state with response_time after body is consumed
Expand Down
Loading
Loading