From b832ce365bb400c1bb5f30ba8a4b33c3b0f20228 Mon Sep 17 00:00:00 2001
From: Aaron Beckley
Date: Tue, 17 Mar 2026 13:25:03 -0400
Subject: [PATCH] Add native Qwen3-Reranker support to RerankCalculatorOV

Qwen3-Reranker models use CausalLM architecture instead of cross-encoder
text-classification, requiring different input formatting and output
postprocessing. This enables OVMS to natively serve Qwen3-Reranker models
(all sizes: 0.6B, 8B) exported with --task text-generation via the
standard /v3/rerank API, with no client-side workarounds needed.

Changes:
- Auto-detect Qwen3 models via model_type in config.json
- Apply server-side chat template formatting for query-document pairs
- Add CausalLM graph postprocessing (Slice/Squeeze/Gather/Subtract) to
  extract yes/no logits from 3D output, producing scores compatible with
  existing sigmoid scoring
- Handle CausalLM-specific inputs (position_ids, beam_idx)
- Guard token_type_ids to avoid conflicts with CausalLM input layout
- Warn if model was exported as text-classification (random head weights)

Tested with Qwen3-Reranker-0.6B and Qwen3-Reranker-8B (int8) on CPU and
Intel Arc GPU, producing correct relevance scores.
---
 src/rerank/rerank_calculator_ov.cc |  62 ++++++++++++++-
 src/rerank/rerank_servable.hpp     | 122 +++++++++++++++++++++++++++++
 2 files changed, 183 insertions(+), 1 deletion(-)

diff --git a/src/rerank/rerank_calculator_ov.cc b/src/rerank/rerank_calculator_ov.cc
index 48036362e8..4921447e63 100644
--- a/src/rerank/rerank_calculator_ov.cc
+++ b/src/rerank/rerank_calculator_ov.cc
@@ -60,6 +60,8 @@ class RerankCalculatorOV : public CalculatorBase {
     static const std::string RERANK_MODEL_INPUT_IDS_NAME;
     static const std::string RERANK_MODEL_ATTENTION_MASK_NAME;
     static const std::string RERANK_MODEL_TOKEN_TYPE_IDS_NAME;
+    static const std::string RERANK_MODEL_POSITION_IDS_NAME;
+    static const std::string RERANK_MODEL_BEAM_IDX_NAME;
     static constexpr size_t NUMBER_OF_SPECIAL_TOKENS = 4;
 
     mediapipe::Timestamp timestamp{0};
@@ -151,6 +153,39 @@ class RerankCalculatorOV : public CalculatorBase {
         // Validate batch size before tokenizing
         if (handler.getDocumentsList().size() > this->max_allowed_chunks)
             throw std::runtime_error("Number of documents exceeds max_allowed_chunks");
+        if (rerank_session->isQwen3) {
+            // Qwen3 reranker: format each query-document pair using the chat template
+            // Template from openvino-2026.0-genai/tests/python_tests/utils/qwen3_reranker_utils.py
+            auto batchSize = handler.getDocumentsList().size();
+            std::vector<std::string> data(batchSize);
+
+            std::string prefix = "<|im_start|>system\nJudge whether the Document meets the requirements "
+                                 "based on the Query and the Instruct provided. Note that the answer can only be "
+                                 "\"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+                                 "<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n"
+                                 "<Query>: " + handler.getQuery() + "\n";
+            std::string suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+            for (size_t i = 0; i < batchSize; i++) {
+                data[i] = prefix + "<Document>: " + handler.getDocumentsList()[i] + suffix;
+            }
+
+            chunk_mapping.resize(batchSize);
+            std::iota(chunk_mapping.begin(), chunk_mapping.end(), 0);
+            auto tokens = rerank_session->getTokenizer().encode(data);
+            if (tokens.input_ids.get_shape().size() != 2) {
+                throw std::runtime_error("Tokens shape invalid.");
+            }
+            if (this->max_position_embeddings < tokens.input_ids.get_shape()[1]) {
+                std::ostringstream msg;
+                msg << "Qwen3 rerank request length of " << tokens.input_ids.get_shape()[1]
+                    << " tokens exceeds the model context of " << max_position_embeddings;
+                throw std::runtime_error(msg.str());
+            }
+            SPDLOG_LOGGER_DEBUG(rerank_calculator_logger, "Qwen3 rerank: {} documents, {} tokens per sequence",
+                batchSize, tokens.input_ids.get_shape()[1]);
+            return std::make_pair(tokens.input_ids, tokens.attention_mask);
+        }
         if (!rerank_session->addBosToken) {
             auto batchSize = handler.getDocumentsList().size();
             std::vector<std::string> data(batchSize);
@@ -257,6 +292,27 @@ class RerankCalculatorOV : public CalculatorBase {
         if (typeIds.has_value()) {
             inferRequest.set_tensor(RERANK_MODEL_TOKEN_TYPE_IDS_NAME, typeIds.value());
         }
+        // For CausalLM models (e.g. Qwen3 rerankers): set position_ids and beam_idx
+        if (rerank_session->hasPositionIds) {
+            size_t batch = input_ids.get_shape()[0];
+            size_t seq_len = input_ids.get_shape()[1];
+            auto position_ids = ov::Tensor(ov::element::i64, input_ids.get_shape());
+            int64_t* pos_data = position_ids.data<int64_t>();
+            int64_t* attn_data = attention_mask.data<int64_t>();
+            for (size_t b = 0; b < batch; b++) {
+                int64_t pos = 0;
+                for (size_t s = 0; s < seq_len; s++) {
+                    pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0;
+                }
+            }
+            inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids);
+        }
+        if (rerank_session->hasBeamIdx) {
+            size_t batch = input_ids.get_shape()[0];
+            auto beam_idx = ov::Tensor(ov::element::i32, {batch});
+            std::fill_n(beam_idx.data<int32_t>(), batch, 0);
+            inferRequest.set_tensor(RERANK_MODEL_BEAM_IDX_NAME, beam_idx);
+        }
         inferRequest.start_async();
         inferRequest.wait();
         auto logits = inferRequest.get_tensor("logits");
@@ -321,7 +377,9 @@ class RerankCalculatorOV : public CalculatorBase {
         std::vector<size_t> chunk_mapping;
         auto [input_ids, attention_mask] = PrepareInputsForRerankModel(handler, chunk_mapping);
         std::optional<ov::Tensor> typeIds;
-        if (rerank_session->getNumberOfModelInputs() == 3) {
+        // Only create token_type_ids for non-Qwen3 models with 3 inputs
+        // (Qwen3 CausalLM has position_ids as 3rd input, not token_type_ids)
+        if (!rerank_session->isQwen3 && rerank_session->getNumberOfModelInputs() == 3) {
             typeIds = ov::Tensor{ov::element::i64, input_ids.get_shape()};
             std::fill_n(typeIds->data<int64_t>(), input_ids.get_size(), 0);
         }
@@ -359,6 +417,8 @@ const std::string RerankCalculatorOV::OUTPUT_TAG_NAME{"RESPONSE_PAYLOAD"};
const std::string RerankCalculatorOV::OUTPUT_TAG_NAME{"RESPONSE_PAYLOAD"};
 const std::string RerankCalculatorOV::RERANK_MODEL_INPUT_IDS_NAME{"input_ids"};
 const std::string RerankCalculatorOV::RERANK_MODEL_ATTENTION_MASK_NAME{"attention_mask"};
 const std::string RerankCalculatorOV::RERANK_MODEL_TOKEN_TYPE_IDS_NAME{"token_type_ids"};
+const std::string RerankCalculatorOV::RERANK_MODEL_POSITION_IDS_NAME{"position_ids"};
+const std::string RerankCalculatorOV::RERANK_MODEL_BEAM_IDX_NAME{"beam_idx"};
 
 REGISTER_CALCULATOR(RerankCalculatorOV);
diff --git a/src/rerank/rerank_servable.hpp b/src/rerank/rerank_servable.hpp
index 15e23983c4..9d9bc798bb 100644
--- a/src/rerank/rerank_servable.hpp
+++ b/src/rerank/rerank_servable.hpp
@@ -23,10 +23,18 @@
 #include <string>
 #include <unordered_map>
 
+#include <rapidjson/document.h>
+#include <rapidjson/istreamwrapper.h>
+#include <openvino/openvino.hpp>
+
 namespace ovms {
 
 struct RerankServable : SidepacketServable {
     bool addBosToken = true;
+    bool isQwen3 = false;
+    bool hasPositionIds = false;
+    bool hasBeamIdx = false;
+
     RerankServable(const std::string& modelDir, const std::string& targetDevice, const std::string& pluginConfig, const std::string& graphPath) :
         SidepacketServable(modelDir, targetDevice, pluginConfig, graphPath) {
         std::filesystem::path tokenizerConfigPath = (parsedModelsPath / "tokenizer_config.json");
@@ -49,6 +57,120 @@ struct RerankServable : SidepacketServable {
             addBosToken = false;
         }
     }
+
+protected:
+    std::shared_ptr<ov::Model> applyPrePostProcessing(ov::Core& core, std::shared_ptr<ov::Model> model, ov::AnyMap& properties) override {
+        // Detect Qwen3 model type from config.json
+        std::filesystem::path configPath = parsedModelsPath / "config.json";
+        if (std::filesystem::exists(configPath)) {
+            std::ifstream ifs(configPath.string());
+            if (ifs.is_open()) {
+                rapidjson::Document modelConfig;
+                rapidjson::IStreamWrapper isw(ifs);
+                rapidjson::ParseResult parseResult = modelConfig.ParseStream(isw);
+                if (!parseResult.Code()) {
+                    if (modelConfig.HasMember("model_type") && modelConfig["model_type"].IsString()) {
+                        std::string modelType = modelConfig["model_type"].GetString();
+                        if (modelType == "qwen3") {
+                            SPDLOG_INFO("Detected Qwen3 reranker model, applying specialized postprocessing");
+                            isQwen3 = true;
+                        }
+                    }
+                }
+            }
+        }
+
+        if (!isQwen3) {
+            return model;
+        }
+
+        // Check model inputs for position_ids and beam_idx
+        for (const auto& input : model->inputs()) {
+            if (input.get_any_name() == "position_ids") {
+                hasPositionIds = true;
+                SPDLOG_DEBUG("Qwen3 reranker model has position_ids input");
+            }
+            if (input.get_any_name() == "beam_idx") {
+                hasBeamIdx = true;
+                SPDLOG_DEBUG("Qwen3 reranker model has beam_idx input");
+            }
+        }
+
+        // Check output shape — only apply postprocessing for CausalLM models (3D output)
+        ov::PartialShape outputShape = model->get_output_partial_shape(0);
+        if (outputShape.rank().get_length() == 2) {
+            // Already a 2D output (text-classification export) — postprocessing won't help
+            // because the classification head has random weights
+            SPDLOG_WARN("Qwen3 reranker has 2D output shape (text-classification export). "
+                        "Re-export with --task text-generation for correct scoring.");
+            return model;
+        }
+
+        // Look up yes/no token IDs
+        int64_t yesTokenId = -1;
+        int64_t noTokenId = -1;
+        {
+            auto yesTokens = tokenizer->encode("yes");
+            if (yesTokens.input_ids.get_size() == 1 && yesTokens.input_ids.get_element_type() == ov::element::i64) {
+                yesTokenId = reinterpret_cast<const int64_t*>(yesTokens.input_ids.data())[0];
+            }
+            auto noTokens = tokenizer->encode("no");
+            if (noTokens.input_ids.get_size() == 1 && noTokens.input_ids.get_element_type() == ov::element::i64) {
+                noTokenId = reinterpret_cast<const int64_t*>(noTokens.input_ids.data())[0];
+            }
+        }
+
+        if (yesTokenId < 0 || noTokenId < 0) {
+            SPDLOG_ERROR("Failed to look up yes/no token IDs for Qwen3 reranker");
+            return model;
+        }
+        SPDLOG_INFO("Qwen3 reranker token IDs: yes={}, no={}", yesTokenId, noTokenId);
+
+        // Apply Qwen3 postprocessing to model graph
+        // Ported from openvino-2026.0-genai text_rerank_pipeline.cpp apply_qwen3_postprocessing()
+        //
+        // Input: model output logits [batch, seq_len, vocab_size]
+        // Output: [batch, 1] tensor containing (yes_logit - no_logit)
+        // sigmoid of this equals softmax P(yes), so OVMS's existing sigmoid scoring works.
+        ov::preprocess::PrePostProcessor processor(model);
+
+        processor.output().postprocess().custom(
+            [yesTokenId, noTokenId](const ov::Output<ov::Node>& node) -> std::shared_ptr<ov::Node> {
+                // Step 1: Slice last token — [batch, seq_len, vocab] → [batch, 1, vocab]
+                auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+                auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{std::numeric_limits<int64_t>::max()});
+                auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+                auto axis1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+
+                auto lastTokenSlice = std::make_shared<ov::op::v8::Slice>(node, start, stop, step, axis1);
+
+                // Step 2: Squeeze seq_len dim — [batch, 1, vocab] → [batch, vocab]
+                auto squeezed = std::make_shared<ov::op::v0::Squeeze>(lastTokenSlice, axis1);
+
+                // Step 3: Gather yes and no logits — [batch, vocab] → [batch, 2]
+                auto indices = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{2},
+                    std::vector<int64_t>{noTokenId, yesTokenId});
+                auto gatherAxis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+                auto gathered = std::make_shared<ov::op::v8::Gather>(squeezed, indices, gatherAxis);
+
+                // Step 4: Compute yes_logit - no_logit → [batch, 1]
+                // gathered[:, 0] = no_logit, gathered[:, 1] = yes_logit
+                auto yesStart = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+                auto yesStop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{2});
+                auto yesSlice = std::make_shared<ov::op::v8::Slice>(gathered, yesStart, yesStop, step, gatherAxis);
+
+                auto noStart = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{0});
+                auto noStop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+                auto noSlice = std::make_shared<ov::op::v8::Slice>(gathered, noStart, noStop, step, gatherAxis);
+
+                // yes_logit - no_logit → sigmoid of this = softmax P(yes)
+                auto diff = std::make_shared<ov::op::v1::Subtract>(yesSlice, noSlice);
+
+                return diff;  // [batch, 1]
+            });
+
+        return processor.build();
+    }
 };
 
 using RerankServableMap = std::unordered_map<std::string, std::shared_ptr<RerankServable>>;