diff --git a/src/llm/language_model/legacy/servable.cpp b/src/llm/language_model/legacy/servable.cpp index 8e244df219..d8faa20b42 100644 --- a/src/llm/language_model/legacy/servable.cpp +++ b/src/llm/language_model/legacy/servable.cpp @@ -234,13 +234,12 @@ absl::Status LegacyServable::preparePartialResponse(std::shared_ptrresults.finish_reasons.empty() ? ov::genai::GenerationFinishReason::STOP : legacyExecutionContext->results.finish_reasons[0]; + executionContext->apiHandler->setPromptTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_input_tokens()); + executionContext->apiHandler->setCompletionTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_generated_tokens()); std::string serializedChunk = executionContext->apiHandler->serializeStreamingChunk(lastTextChunk, finishReason); if (!serializedChunk.empty()) { executionContext->response = wrapTextInServerSideEventMessage(serializedChunk); } - - executionContext->apiHandler->setPromptTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_input_tokens()); - executionContext->apiHandler->setCompletionTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_generated_tokens()); if (executionContext->apiHandler->getStreamOptions().includeUsage) executionContext->response += wrapTextInServerSideEventMessage(executionContext->apiHandler->serializeStreamingUsageChunk()); diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index 033cb8641d..6e6fcadcd0 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -250,12 +250,12 @@ absl::Status VisualLanguageModelLegacyServable::preparePartialResponse(std::shar } // Legacy generation path always runs with batch=1, so we read the single finish reason at index 0. ov::genai::GenerationFinishReason finishReason = legacyExecutionContext->results.finish_reasons.empty() ? ov::genai::GenerationFinishReason::STOP : legacyExecutionContext->results.finish_reasons[0]; + executionContext->apiHandler->setPromptTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_input_tokens()); + executionContext->apiHandler->setCompletionTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_generated_tokens()); std::string serializedChunk = executionContext->apiHandler->serializeStreamingChunk(lastTextChunk, finishReason); if (!serializedChunk.empty()) { executionContext->response = wrapTextInServerSideEventMessage(serializedChunk); } - executionContext->apiHandler->setPromptTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_input_tokens()); - executionContext->apiHandler->setCompletionTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_generated_tokens()); if (executionContext->apiHandler->getStreamOptions().includeUsage) executionContext->response += wrapTextInServerSideEventMessage(executionContext->apiHandler->serializeStreamingUsageChunk()); diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 46a31a3337..3e36b68a0b 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -28,6 +28,9 @@ #include "../filesystem/filesystem.hpp" #include "../llm/apis/openai_completions.hpp" #include "../llm/apis/openai_responses.hpp" +#include "../llm/language_model/legacy/servable.hpp" +#include "../llm/visual_language_model/legacy/servable.hpp" +#include "../client_connection.hpp" #include #include "../module_names.hpp" #include "../servablemanagermodule.hpp" @@ -5652,3 +5655,167 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesStreamOptionsRejected) { EXPECT_EQ(status, absl::InvalidArgumentError("stream_options is not supported in Responses API.")); } + +// Stub client that is never disconnected, used by the LM legacy servable tests below. +namespace { +struct NeverDisconnectedClient : public ovms::ClientConnection { + bool isDisconnected() const override { return false; } + void registerDisconnectionCallback(std::function) override {} +}; +} // namespace + +static std::shared_ptr makeLegacyResponsesContext( + const std::shared_ptr& tok, + size_t numInputTokens, size_t numGeneratedTokens, + ov::genai::GenerationFinishReason finishReason = ov::genai::GenerationFinishReason::STOP) { + auto ctx = std::make_shared(); + + ctx->payload.client = std::make_shared(); + ctx->payload.parsedJson = std::make_shared(); + ctx->payload.parsedJson->Parse(R"({"model":"llama","input":"test","stream":true})"); + ctx->endpoint = ovms::Endpoint::RESPONSES; + + auto apiHandler = std::make_shared( + *ctx->payload.parsedJson, ovms::Endpoint::RESPONSES, + std::chrono::system_clock::now(), *tok); + std::optional maxTokensLimit; + const absl::Status parseStatus = apiHandler->parseRequest(maxTokensLimit, 0, std::nullopt); + EXPECT_TRUE(parseStatus.ok()) << parseStatus; + if (!parseStatus.ok()) { + return nullptr; + } + ctx->apiHandler = apiHandler; + + ctx->results.finish_reasons.push_back(finishReason); + ctx->results.perf_metrics.num_input_tokens = numInputTokens; + ctx->results.perf_metrics.num_generated_tokens = numGeneratedTokens; + ctx->success = true; + // Signal that generation is done so preparePartialResponse goes straight to + // the "finish generation" branch without waiting. + ctx->readySignal.set_value(); + + ctx->textStreamer = std::make_shared( + *tok, [](std::string) { return ov::genai::StreamingStatus::RUNNING; }); + + return ctx; +} + +TEST_F(HttpOpenAIHandlerParsingTest, legacyServablePreparePartialResponseResponsesEndpointHasCorrectUsageInCompletedEvent) { + auto ctx = makeLegacyResponsesContext(tokenizer, /*numInputTokens=*/10, /*numGeneratedTokens=*/5); + std::shared_ptr ctxBase = ctx; + + ovms::LegacyServable servable; + ASSERT_EQ(servable.preparePartialResponse(ctxBase), absl::OkStatus()); + + const std::string& response = ctxBase->response; + ASSERT_NE(response.find("\"type\":\"response.completed\""), std::string::npos) + << "response.completed event must be present: " << response; + ASSERT_NE(response.find("\"output_tokens\":5"), std::string::npos) + << "output_tokens must equal num_generated_tokens from perf_metrics: " << response; + ASSERT_NE(response.find("\"input_tokens\":10"), std::string::npos) + << "input_tokens must equal num_input_tokens from perf_metrics: " << response; + ASSERT_NE(response.find("\"total_tokens\":15"), std::string::npos) + << "total_tokens must be input+output: " << response; + ASSERT_FALSE(ctxBase->sendLoopbackSignal); +} + +TEST_F(HttpOpenAIHandlerParsingTest, legacyServablePreparePartialResponseResponsesEndpointHasCorrectUsageOnLength) { + auto ctx = makeLegacyResponsesContext(tokenizer, /*numInputTokens=*/8, /*numGeneratedTokens=*/3, + ov::genai::GenerationFinishReason::LENGTH); + std::shared_ptr ctxBase = ctx; + + ovms::LegacyServable servable; + ASSERT_EQ(servable.preparePartialResponse(ctxBase), absl::OkStatus()); + + const std::string& response = ctxBase->response; + ASSERT_NE(response.find("\"type\":\"response.incomplete\""), std::string::npos) + << "response.incomplete event must be present for LENGTH finish reason: " << response; + ASSERT_NE(response.find("\"output_tokens\":3"), std::string::npos) + << "output_tokens must equal num_generated_tokens from perf_metrics: " << response; + ASSERT_NE(response.find("\"input_tokens\":8"), std::string::npos) + << "input_tokens must equal num_input_tokens from perf_metrics: " << response; +} + +TEST_F(HttpOpenAIHandlerParsingTest, vlmLegacyServablePreparePartialResponseResponsesEndpointHasCorrectUsageInCompletedEvent) { + auto ctx = std::make_shared(); + + ctx->payload.client = std::make_shared(); + ctx->payload.parsedJson = std::make_shared(); + ctx->payload.parsedJson->Parse(R"({"model":"llama","input":"test","stream":true})"); + ctx->endpoint = ovms::Endpoint::RESPONSES; + + auto apiHandler = std::make_shared( + *ctx->payload.parsedJson, ovms::Endpoint::RESPONSES, + std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + const absl::Status parseStatus = apiHandler->parseRequest(maxTokensLimit, 0, std::nullopt); + ASSERT_TRUE(parseStatus.ok()) << parseStatus; + ctx->apiHandler = apiHandler; + + ctx->results.finish_reasons.push_back(ov::genai::GenerationFinishReason::STOP); + ctx->results.perf_metrics.num_input_tokens = 12; + ctx->results.perf_metrics.num_generated_tokens = 6; + ctx->success = true; + ctx->readySignal.set_value(); + ctx->textStreamer = std::make_shared( + *tokenizer, [](std::string) { return ov::genai::StreamingStatus::RUNNING; }); + + ovms::VisualLanguageModelLegacyServable servable; + std::shared_ptr ctxBase = ctx; + ASSERT_EQ(servable.preparePartialResponse(ctxBase), absl::OkStatus()); + + const std::string& response = ctxBase->response; + ASSERT_NE(response.find("\"type\":\"response.completed\""), std::string::npos) + << "response.completed event must be present: " << response; + ASSERT_NE(response.find("\"output_tokens\":6"), std::string::npos) + << "output_tokens must equal num_generated_tokens from perf_metrics: " << response; + ASSERT_NE(response.find("\"input_tokens\":12"), std::string::npos) + << "input_tokens must equal num_input_tokens from perf_metrics: " << response; + ASSERT_NE(response.find("\"total_tokens\":18"), std::string::npos) + << "total_tokens must be input+output: " << response; + ASSERT_FALSE(ctxBase->sendLoopbackSignal); +} + +TEST_F(HttpOpenAIHandlerParsingTest, legacyServablePreparePartialResponseChatCompletionsStreamingHasCorrectUsageChunk) { + auto ctx = std::make_shared(); + + ctx->payload.client = std::make_shared(); + ctx->payload.parsedJson = std::make_shared(); + ctx->payload.parsedJson->Parse( + R"({"model":"llama","stream":true,"stream_options":{"include_usage":true},"messages":[{"role":"user","content":"hi"}]})"); + ctx->endpoint = ovms::Endpoint::CHAT_COMPLETIONS; + + auto apiHandler = std::make_shared( + *ctx->payload.parsedJson, ovms::Endpoint::CHAT_COMPLETIONS, + std::chrono::system_clock::now(), *tokenizer); + uint32_t maxTokensLimit = 100; + const absl::Status parseStatus = apiHandler->parseRequest(maxTokensLimit, 0, std::nullopt); + ASSERT_TRUE(parseStatus.ok()) << parseStatus; + ctx->apiHandler = apiHandler; + + ctx->results.finish_reasons.push_back(ov::genai::GenerationFinishReason::STOP); + ctx->results.perf_metrics.num_input_tokens = 10; + ctx->results.perf_metrics.num_generated_tokens = 5; + ctx->success = true; + ctx->readySignal.set_value(); + ctx->textStreamer = std::make_shared( + *tokenizer, [](std::string) { return ov::genai::StreamingStatus::RUNNING; }); + + ovms::LegacyServable servable; + std::shared_ptr ctxBase = ctx; + ASSERT_EQ(servable.preparePartialResponse(ctxBase), absl::OkStatus()); + + // For chat_completions, usage is in the separate SSE usage chunk (not in the + // final delta chunk), so it should be present even though set*Usage was called + // before serializeStreamingChunk in the fixed code. + const std::string& response = ctxBase->response; + ASSERT_NE(response.find("\"completion_tokens\":5"), std::string::npos) + << "completion_tokens must be in usage chunk: " << response; + ASSERT_NE(response.find("\"prompt_tokens\":10"), std::string::npos) + << "prompt_tokens must be in usage chunk: " << response; + ASSERT_NE(response.find("\"total_tokens\":15"), std::string::npos) + << "total_tokens must be in usage chunk: " << response; + ASSERT_NE(response.find("[DONE]"), std::string::npos) + << "[DONE] must be present: " << response; + ASSERT_FALSE(ctxBase->sendLoopbackSignal); +}