From 285283b4b89d564333ac2f752a19beabd02d2d7e Mon Sep 17 00:00:00 2001 From: Tai An Date: Sat, 25 Apr 2026 03:19:16 -0700 Subject: [PATCH 1/4] fix(middleware): parse OpenAI-spec tool_choice in /v1/chat/completions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follows up on #9526 (the 3-site setter fix) by addressing the remaining clause in #9508 — string mode and OpenAI-spec specific-function shape both silently failed in the /v1/chat/completions parsing path. Signed-off-by: Ettore Di Giacinto --- core/http/middleware/request.go | 1335 ++++++++++++++++--------------- 1 file changed, 684 insertions(+), 651 deletions(-) diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index 1c0ee0ec7cf1..b03f6f7e6d7c 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -1,651 +1,684 @@ -package middleware - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strconv" - "strings" - - "github.com/google/uuid" - "github.com/labstack/echo/v4" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services/galleryop" - "github.com/mudler/LocalAI/core/templates" - "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/model" - "github.com/mudler/LocalAI/pkg/utils" - "github.com/mudler/xlog" -) - -type correlationIDKeyType string - -// CorrelationIDKey to track request across process boundary -const CorrelationIDKey correlationIDKeyType = "correlationID" - -type RequestExtractor struct { - modelConfigLoader *config.ModelConfigLoader - modelLoader *model.ModelLoader - applicationConfig *config.ApplicationConfig -} - -func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { - return &RequestExtractor{ - modelConfigLoader: modelConfigLoader, - modelLoader: modelLoader, - 
applicationConfig: applicationConfig, - } -} - -const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" -const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" -const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" - -// TODO: Refactor to not return error if unchanged -func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { - model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && model != "" { - return - } - model = c.Param("model") - - if model == "" { - model = c.QueryParam("model") - } - - // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting) - if model == "" { - model = c.FormValue("model") - } - - if model == "" { - // Set model from bearer token, if available - auth := c.Request().Header.Get("Authorization") - bearer := strings.TrimPrefix(auth, "Bearer ") - if bearer != "" && bearer != auth { - exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE) - if err == nil && exists { - model = bearer - } - } - } - - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) -} - -func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - re.setModelNameFromRequest(c) - localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if !ok || localModelName == "" { - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) - xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) - } - return next(c) - } - } -} - -func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - re.setModelNameFromRequest(c) - localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if localModelName != "" { // 
Don't overwrite existing values - return next(c) - } - - modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) - if err != nil { - xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) - return next(c) - } - - if len(modelNames) == 0 { - xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") - // This is non-fatal - making it so was breaking the case of direct installation of raw models - // return errors.New("this endpoint requires at least one model to be installed") - return next(c) - } - - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) - xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) - return next(c) - } - } -} - -// TODO: If context and cancel above belong on all methods, move that part of above into here! -// Otherwise, it's in its own method below for now -func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - input := initializer() - if input == nil { - return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") - } - if err := c.Bind(input); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) - } - - // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain - if input.ModelName(nil) == "" { - localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && localModelName != "" { - xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) - input.ModelName(&localModelName) - } - } - - modelName := input.ModelName(nil) - cfg, err := 
re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig) - - if err != nil { - xlog.Warn("Model Configuration File not found", "model", modelName, "error", err) - } else if cfg.Model == "" && modelName != "" { - xlog.Debug("config does not include model, using input", "input.ModelName", modelName) - cfg.Model = modelName - } - - // If a model name was specified, verify it actually exists before proceeding. - // Check both configured models and loose model files in the model path. - // Skip the check for HuggingFace model IDs (contain "/") since backends - // like diffusers may download these on the fly. - if modelName != "" && !strings.Contains(modelName, "/") { - exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE) - if existsErr == nil && !exists { - return c.JSON(http.StatusNotFound, schema.ErrorResponse{ - Error: &schema.APIError{ - Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName), - Code: http.StatusNotFound, - Type: "invalid_request_error", - }, - }) - } - } - - // Check if the model is disabled - if cfg != nil && cfg.IsDisabled() { - return c.JSON(http.StatusForbidden, schema.ErrorResponse{ - Error: &schema.APIError{ - Message: fmt.Sprintf("model %q is disabled and cannot be loaded. 
Enable it via the System page or API to use it.", modelName), - Code: http.StatusForbidden, - Type: "model_disabled", - }, - }) - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return next(c) - } - } -} - -func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { - input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) - if !ok || input.Model == "" { - return echo.ErrBadRequest - } - - cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return echo.ErrBadRequest - } - - // Extract or generate the correlation ID - correlationID := c.Request().Header.Get("X-Correlation-ID") - if correlationID == "" { - correlationID = uuid.New().String() - } - c.Response().Header().Set("X-Correlation-ID", correlationID) - - // Use the request context directly - Echo properly supports context cancellation! - // No need for workarounds like handleConnectionCancellation - reqCtx := c.Request().Context() - c1, cancel := context.WithCancel(re.applicationConfig.Context) - - // Cancel when request context is cancelled (client disconnects) - go func() { - select { - case <-reqCtx.Done(): - cancel() - case <-c1.Done(): - // Already cancelled - } - }() - - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := mergeOpenAIRequestAndModelConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) - cfg.Model = input.Model - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return nil -} - -func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK 
!= nil { - config.TopK = input.TopK - } - if input.TopP != nil { - config.TopP = input.TopP - } - if input.MinP != nil { - config.MinP = input.MinP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - if input.Maxtokens != nil { - config.Maxtokens = input.Maxtokens - } - - if input.ResponseFormat != nil { - switch responseFormat := input.ResponseFormat.(type) { - case string: - config.ResponseFormat = responseFormat - case map[string]any: - config.ResponseFormatMap = responseFormat - } - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []any: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if len(input.Tools) > 0 { - for _, tool := range input.Tools { - input.Functions = append(input.Functions, tool.Function) - } - } - - if input.ToolsChoice != nil { - var toolChoice functions.Tool - - switch content := input.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]any: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - input.FunctionCall = map[string]any{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - imgIndex, vidIndex, audioIndex := 0, 0, 0 - for i, m := range input.Messages { - nrOfImgsInMessage := 
0 - nrOfVideosInMessage := 0 - nrOfAudiosInMessage := 0 - - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []any: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - - textContent := "" - // we will template this at the end - - CONTENT: - for _, pp := range c { - switch pp.Type { - case "text": - textContent += pp.Text - //input.Messages[i].StringContent = pp.Text - case "video", "video_url": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) - if err != nil { - xlog.Error("Failed encoding video", "error", err) - continue CONTENT - } - input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff - vidIndex++ - nrOfVideosInMessage++ - case "audio_url", "audio": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) - if err != nil { - xlog.Error("Failed encoding audio", "error", err) - continue CONTENT - } - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff - audioIndex++ - nrOfAudiosInMessage++ - case "input_audio": - // TODO: make sure that we only return base64 stuff - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) - audioIndex++ - nrOfAudiosInMessage++ - case "image_url", "image": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) - if err != nil { - xlog.Error("Failed encoding image", "error", err) - continue CONTENT - } - - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - - imgIndex++ - nrOfImgsInMessage++ - } - } - - // When the backend handles templating itself 
(UseTokenizerTemplate), - // it also injects media markers server-side (see - // oaicompat_chat_params_parse in llama.cpp). Emitting our own markers - // here would double-mark them and downstream consumers ignore - // StringContent in that path anyway, so just pass through plain text. - if config.TemplateConfig.UseTokenizerTemplate { - input.Messages[i].StringContent = textContent - } else { - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ - TotalImages: imgIndex, - TotalVideos: vidIndex, - TotalAudios: audioIndex, - ImagesInMessage: nrOfImgsInMessage, - VideosInMessage: nrOfVideosInMessage, - AudiosInMessage: nrOfAudiosInMessage, - }, textContent) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.FrequencyPenalty != 0 { - config.FrequencyPenalty = input.FrequencyPenalty - } - - if input.PresencePenalty != 0 { - config.PresencePenalty = input.PresencePenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != nil { - config.Seed = input.Seed - } - - if input.TypicalP != nil { - config.TypicalP = input.TypicalP - } - - xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input)) - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []any: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []any: - tokens := []int{} - inputStrings := []string{} - for _, ii := range i { - switch ii := ii.(type) { - case int: - tokens = append(tokens, ii) - case float64: - tokens = append(tokens, int(ii)) - case string: - inputStrings = append(inputStrings, ii) - default: - xlog.Error("Unknown input type", "type", 
fmt.Sprintf("%T", ii)) - } - } - config.InputToken = append(config.InputToken, tokens) - config.InputStrings = append(config.InputStrings, inputStrings...) - } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]any: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []any: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } - - // If a quality was defined as number, convert it to step - if input.Quality != "" { - q, err := strconv.Atoi(input.Quality) - if err == nil { - config.Step = q - } - } - - if valid, _ := config.Validate(); valid { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} - -func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { - input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) - if !ok || input.Model == "" { - return echo.ErrBadRequest - } - - // Convert input items to Messages (this will be done in the endpoint handler) - // We store the input in the request for the endpoint to process - cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return echo.ErrBadRequest - } - - // Extract or generate the correlation ID (Open Responses uses x-request-id) - correlationID := c.Request().Header.Get("x-request-id") - if correlationID == "" { - correlationID = uuid.New().String() - } - c.Response().Header().Set("x-request-id", correlationID) - - // Use the request context directly - Echo properly supports context cancellation! 
- reqCtx := c.Request().Context() - c1, cancel := context.WithCancel(re.applicationConfig.Context) - - // Cancel when request context is cancelled (client disconnects) - go func() { - select { - case <-reqCtx.Done(): - cancel() - case <-c1.Done(): - // Already cancelled - } - }() - - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := MergeOpenResponsesConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) - cfg.Model = input.Model - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return nil -} - -// MergeOpenResponsesConfig merges request parameters into the model configuration. -func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { - // Temperature - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - // TopP - if input.TopP != nil { - config.TopP = input.TopP - } - - // MaxOutputTokens -> Maxtokens - if input.MaxOutputTokens != nil { - config.Maxtokens = input.MaxOutputTokens - } - - // Convert tools to functions - this will be handled in the endpoint handler - // We just validate that tools are present if needed - - // Handle tool_choice - if input.ToolChoice != nil { - switch tc := input.ToolChoice.(type) { - case string: - // "auto", "required", or "none" - if tc == "required" { - config.SetFunctionCallString("required") - } else if tc == "none" { - // Don't use tools - handled in endpoint - } - // "auto" is default - let model decide - case map[string]any: - // Specific tool. 
OpenAI spec nests the function name under "function": - // {"type":"function", "function":{"name":"..."}} - // Legacy/Anthropic-compat form puts it at the top level: - // {"type":"function", "name":"..."} - // The old code only handled the legacy shape AND used the wrong - // setter (SetFunctionCallString writes the mode field; the - // specific-function name lives in a separate field read by - // ShouldCallSpecificFunction / FunctionToCall). Net effect: a - // correctly-formed OpenAI tool_choice never engaged grammar-based - // forcing, the model got the tools but no selection hint, and - // streamed raw JSON as delta.content instead of delta.tool_calls. - if tcType, ok := tc["type"].(string); ok && tcType == "function" { - var name string - if fn, ok := tc["function"].(map[string]any); ok { - if n, ok := fn["name"].(string); ok { - name = n - } - } - if name == "" { - if n, ok := tc["name"].(string); ok { - name = n - } - } - if name != "" { - config.SetFunctionCallNameString(name) - } - } - } - } - - if valid, _ := config.Validate(); valid { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" +) + +type correlationIDKeyType string + +// CorrelationIDKey to track request across process boundary +const CorrelationIDKey correlationIDKeyType = "correlationID" + +type RequestExtractor struct { + modelConfigLoader *config.ModelConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig +} + +func NewRequestExtractor(modelConfigLoader 
*config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { + return &RequestExtractor{ + modelConfigLoader: modelConfigLoader, + modelLoader: modelLoader, + applicationConfig: applicationConfig, + } +} + +const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" +const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" +const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" + +// TODO: Refactor to not return error if unchanged +func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { + model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && model != "" { + return + } + model = c.Param("model") + + if model == "" { + model = c.QueryParam("model") + } + + // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting) + if model == "" { + model = c.FormValue("model") + } + + if model == "" { + // Set model from bearer token, if available + auth := c.Request().Header.Get("Authorization") + bearer := strings.TrimPrefix(auth, "Bearer ") + if bearer != "" && bearer != auth { + exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE) + if err == nil && exists { + model = bearer + } + } + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) +} + +func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if !ok || localModelName == "" { + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) + xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) + } + return next(c) + } + } +} + +func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { + 
return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if localModelName != "" { // Don't overwrite existing values + return next(c) + } + + modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) + if err != nil { + xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) + return next(c) + } + + if len(modelNames) == 0 { + xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") + // This is non-fatal - making it so was breaking the case of direct installation of raw models + // return errors.New("this endpoint requires at least one model to be installed") + return next(c) + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) + xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) + return next(c) + } + } +} + +// TODO: If context and cancel above belong on all methods, move that part of above into here! 
+// Otherwise, it's in its own method below for now +func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + input := initializer() + if input == nil { + return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") + } + if err := c.Bind(input); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) + } + + // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain + if input.ModelName(nil) == "" { + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && localModelName != "" { + xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) + input.ModelName(&localModelName) + } + } + + modelName := input.ModelName(nil) + cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig) + + if err != nil { + xlog.Warn("Model Configuration File not found", "model", modelName, "error", err) + } else if cfg.Model == "" && modelName != "" { + xlog.Debug("config does not include model, using input", "input.ModelName", modelName) + cfg.Model = modelName + } + + // If a model name was specified, verify it actually exists before proceeding. + // Check both configured models and loose model files in the model path. + // Skip the check for HuggingFace model IDs (contain "/") since backends + // like diffusers may download these on the fly. 
+ if modelName != "" && !strings.Contains(modelName, "/") { + exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE) + if existsErr == nil && !exists { + return c.JSON(http.StatusNotFound, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName), + Code: http.StatusNotFound, + Type: "invalid_request_error", + }, + }) + } + } + + // Check if the model is disabled + if cfg != nil && cfg.IsDisabled() { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName), + Code: http.StatusForbidden, + Type: "model_disabled", + }, + }) + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return next(c) + } + } +} + +func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID + correlationID := c.Request().Header.Get("X-Correlation-ID") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("X-Correlation-ID", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! 
+ // No need for workarounds like handleConnectionCancellation + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := mergeOpenAIRequestAndModelConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + if input.MinP != nil { + config.MinP = input.MinP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + if input.ResponseFormat != nil { + switch responseFormat := 
input.ResponseFormat.(type) { + case string: + config.ResponseFormat = responseFormat + case map[string]any: + config.ResponseFormatMap = responseFormat + } + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []any: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + // OpenAI tool_choice has three valid shapes: + // + // 1. string mode: "auto" | "none" | "required" + // 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec) + // 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat) + // + // The previous code unmarshalled all three into functions.Tool via + // json.Unmarshal and then unconditionally set input.FunctionCall = + // {"name": toolChoice.Function.Name}. That had two consequences: + // + // - For string modes, json.Unmarshal([]byte("required"), &Tool{}) fails; + // the error was silently discarded, name was "", and downstream + // SetFunctionCallNameString("") meant the requested mode never applied. + // - For the OpenAI-spec map shape, the json keys did not match + // functions.Tool's field tags, so name was "" again. + // + // Mirror the parsing pattern from MergeOpenResponsesConfig (#9509) and + // route results through the existing input.FunctionCall string/map + // dispatch downstream (see the switch on input.FunctionCall in this + // same function). Tracked in #9508; sibling fix in #9526. + switch content := input.ToolsChoice.(type) { + case string: + // "auto" is the default and needs no override. "none" and "required" + // are both forwarded to SetFunctionCallString via input.FunctionCall; + // the endpoint layer then skips tool wiring when the mode is "none".
+ if content != "" && content != "auto" { + input.FunctionCall = content + } + case map[string]any: + if tcType, ok := content["type"].(string); ok && tcType == "function" { + var name string + if fn, ok := content["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + } + if name == "" { + if n, ok := content["name"].(string); ok { + name = n + } + } + if name != "" { + input.FunctionCall = map[string]any{"name": name} + } + } + } + } + + // Decode each request's message content + imgIndex, vidIndex, audioIndex := 0, 0, 0 + for i, m := range input.Messages { + nrOfImgsInMessage := 0 + nrOfVideosInMessage := 0 + nrOfAudiosInMessage := 0 + + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []any: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + + textContent := "" + // we will template this at the end + + CONTENT: + for _, pp := range c { + switch pp.Type { + case "text": + textContent += pp.Text + //input.Messages[i].StringContent = pp.Text + case "video", "video_url": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) + if err != nil { + xlog.Error("Failed encoding video", "error", err) + continue CONTENT + } + input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff + vidIndex++ + nrOfVideosInMessage++ + case "audio_url", "audio": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) + if err != nil { + xlog.Error("Failed encoding audio", "error", err) + continue CONTENT + } + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff + audioIndex++ + nrOfAudiosInMessage++ + case "input_audio": + // TODO: make sure that we only return 
base64 stuff + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) + audioIndex++ + nrOfAudiosInMessage++ + case "image_url", "image": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) + if err != nil { + xlog.Error("Failed encoding image", "error", err) + continue CONTENT + } + + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + + imgIndex++ + nrOfImgsInMessage++ + } + } + + // When the backend handles templating itself (UseTokenizerTemplate), + // it also injects media markers server-side (see + // oaicompat_chat_params_parse in llama.cpp). Emitting our own markers + // here would double-mark them and downstream consumers ignore + // StringContent in that path anyway, so just pass through plain text. + if config.TemplateConfig.UseTokenizerTemplate { + input.Messages[i].StringContent = textContent + } else { + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + TotalVideos: vidIndex, + TotalAudios: audioIndex, + ImagesInMessage: nrOfImgsInMessage, + VideosInMessage: nrOfVideosInMessage, + AudiosInMessage: nrOfAudiosInMessage, + }, textContent) + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + xlog.Debug("input.Input", "input", 
fmt.Sprintf("%+v", input.Input)) + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []any: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []any: + tokens := []int{} + inputStrings := []string{} + for _, ii := range i { + switch ii := ii.(type) { + case int: + tokens = append(tokens, ii) + case float64: + tokens = append(tokens, int(ii)) + case string: + inputStrings = append(inputStrings, ii) + default: + xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii)) + } + } + config.InputToken = append(config.InputToken, tokens) + config.InputStrings = append(config.InputStrings, inputStrings...) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]any: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []any: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } + + // If a quality was defined as number, convert it to step + if input.Quality != "" { + q, err := strconv.Atoi(input.Quality) + if err == nil { + config.Step = q + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} + +func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + // Convert input items to Messages (this will be done in the 
endpoint handler) + // We store the input in the request for the endpoint to process + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID (Open Responses uses x-request-id) + correlationID := c.Request().Header.Get("x-request-id") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("x-request-id", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := MergeOpenResponsesConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +// MergeOpenResponsesConfig merges request parameters into the model configuration. 
+func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { + // Temperature + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + // TopP + if input.TopP != nil { + config.TopP = input.TopP + } + + // MaxOutputTokens -> Maxtokens + if input.MaxOutputTokens != nil { + config.Maxtokens = input.MaxOutputTokens + } + + // Convert tools to functions - this will be handled in the endpoint handler + // We just validate that tools are present if needed + + // Handle tool_choice + if input.ToolChoice != nil { + switch tc := input.ToolChoice.(type) { + case string: + // "auto", "required", or "none" + if tc == "required" { + config.SetFunctionCallString("required") + } else if tc == "none" { + // Don't use tools - handled in endpoint + } + // "auto" is default - let model decide + case map[string]any: + // Specific tool. OpenAI spec nests the function name under "function": + // {"type":"function", "function":{"name":"..."}} + // Legacy/Anthropic-compat form puts it at the top level: + // {"type":"function", "name":"..."} + // The old code only handled the legacy shape AND used the wrong + // setter (SetFunctionCallString writes the mode field; the + // specific-function name lives in a separate field read by + // ShouldCallSpecificFunction / FunctionToCall). Net effect: a + // correctly-formed OpenAI tool_choice never engaged grammar-based + // forcing, the model got the tools but no selection hint, and + // streamed raw JSON as delta.content instead of delta.tool_calls. 
+ if tcType, ok := tc["type"].(string); ok && tcType == "function" { + var name string + if fn, ok := tc["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + } + if name == "" { + if n, ok := tc["name"].(string); ok { + name = n + } + } + if name != "" { + config.SetFunctionCallNameString(name) + } + } + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} From d833d2dbe9c1f429e0ea88971d6433abd7bce5ee Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 13 May 2026 20:14:55 +0000 Subject: [PATCH 2/4] fix(middleware): restore LF endings and cover tool_choice parsing with specs The previous commit on this branch saved core/http/middleware/request.go with CRLF line endings, ballooning the diff against master to 684 / 651 for what is in reality a ~50-line parsing change. Restore LF (matches .editorconfig end_of_line = lf). Add 11 Ginkgo specs under "SetModelAndConfig tool_choice parsing (chat completions)" that parallel the existing MergeOpenResponsesConfig specs from #9509. They drive the full middleware chain (SetModelAndConfig + SetOpenAIRequest) and assert: * "required" -> ShouldUseFunctions=true, no specific name * "none" -> ShouldUseFunctions=false (tools disabled per OpenAI spec) * "auto" -> default, tools available, no specific name * {type:function, function:{name:X}} (spec) -> X is forced * {type:function, name:X} (legacy) -> X is forced * nested wins when both forms are present * malformed shapes (no type, wrong type, no name, empty name) are no-ops Update the inline comment on the string case to describe the actual mechanism: "none" reaches SetFunctionCallString("none") downstream and is then honored by ShouldUseFunctions() returning false. Before this PR json.Unmarshal([]byte("none"), &functions.Tool{}) failed silently, so "none" was ignored - making "none" actually work is a real behavior fix this PR brings. 
Signed-off-by: Ettore Di Giacinto Assisted-by: Claude:opus-4-7 [Claude Code] --- core/http/middleware/request.go | 1371 +++++++++++++------------- core/http/middleware/request_test.go | 197 ++++ 2 files changed, 884 insertions(+), 684 deletions(-) diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index b03f6f7e6d7c..cd791e096201 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -1,684 +1,687 @@ -package middleware - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strconv" - "strings" - - "github.com/google/uuid" - "github.com/labstack/echo/v4" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services/galleryop" - "github.com/mudler/LocalAI/core/templates" - "github.com/mudler/LocalAI/pkg/model" - "github.com/mudler/LocalAI/pkg/utils" - "github.com/mudler/xlog" -) - -type correlationIDKeyType string - -// CorrelationIDKey to track request across process boundary -const CorrelationIDKey correlationIDKeyType = "correlationID" - -type RequestExtractor struct { - modelConfigLoader *config.ModelConfigLoader - modelLoader *model.ModelLoader - applicationConfig *config.ApplicationConfig -} - -func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { - return &RequestExtractor{ - modelConfigLoader: modelConfigLoader, - modelLoader: modelLoader, - applicationConfig: applicationConfig, - } -} - -const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" -const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" -const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" - -// TODO: Refactor to not return error if unchanged -func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { - model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && model != "" { - return - } - model = c.Param("model") - - if model 
== "" { - model = c.QueryParam("model") - } - - // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting) - if model == "" { - model = c.FormValue("model") - } - - if model == "" { - // Set model from bearer token, if available - auth := c.Request().Header.Get("Authorization") - bearer := strings.TrimPrefix(auth, "Bearer ") - if bearer != "" && bearer != auth { - exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE) - if err == nil && exists { - model = bearer - } - } - } - - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) -} - -func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - re.setModelNameFromRequest(c) - localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if !ok || localModelName == "" { - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) - xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) - } - return next(c) - } - } -} - -func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - re.setModelNameFromRequest(c) - localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if localModelName != "" { // Don't overwrite existing values - return next(c) - } - - modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) - if err != nil { - xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) - return next(c) - } - - if len(modelNames) == 0 { - xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") - // This is non-fatal - making it so was 
breaking the case of direct installation of raw models - // return errors.New("this endpoint requires at least one model to be installed") - return next(c) - } - - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) - xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) - return next(c) - } - } -} - -// TODO: If context and cancel above belong on all methods, move that part of above into here! -// Otherwise, it's in its own method below for now -func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - input := initializer() - if input == nil { - return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") - } - if err := c.Bind(input); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) - } - - // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain - if input.ModelName(nil) == "" { - localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && localModelName != "" { - xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) - input.ModelName(&localModelName) - } - } - - modelName := input.ModelName(nil) - cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig) - - if err != nil { - xlog.Warn("Model Configuration File not found", "model", modelName, "error", err) - } else if cfg.Model == "" && modelName != "" { - xlog.Debug("config does not include model, using input", "input.ModelName", modelName) - cfg.Model = modelName - } - - // If a model name was specified, verify it actually exists before proceeding. - // Check both configured models and loose model files in the model path. 
- // Skip the check for HuggingFace model IDs (contain "/") since backends - // like diffusers may download these on the fly. - if modelName != "" && !strings.Contains(modelName, "/") { - exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE) - if existsErr == nil && !exists { - return c.JSON(http.StatusNotFound, schema.ErrorResponse{ - Error: &schema.APIError{ - Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName), - Code: http.StatusNotFound, - Type: "invalid_request_error", - }, - }) - } - } - - // Check if the model is disabled - if cfg != nil && cfg.IsDisabled() { - return c.JSON(http.StatusForbidden, schema.ErrorResponse{ - Error: &schema.APIError{ - Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName), - Code: http.StatusForbidden, - Type: "model_disabled", - }, - }) - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return next(c) - } - } -} - -func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { - input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) - if !ok || input.Model == "" { - return echo.ErrBadRequest - } - - cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return echo.ErrBadRequest - } - - // Extract or generate the correlation ID - correlationID := c.Request().Header.Get("X-Correlation-ID") - if correlationID == "" { - correlationID = uuid.New().String() - } - c.Response().Header().Set("X-Correlation-ID", correlationID) - - // Use the request context directly - Echo properly supports context cancellation! 
- // No need for workarounds like handleConnectionCancellation - reqCtx := c.Request().Context() - c1, cancel := context.WithCancel(re.applicationConfig.Context) - - // Cancel when request context is cancelled (client disconnects) - go func() { - select { - case <-reqCtx.Done(): - cancel() - case <-c1.Done(): - // Already cancelled - } - }() - - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := mergeOpenAIRequestAndModelConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) - cfg.Model = input.Model - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return nil -} - -func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != nil { - config.TopK = input.TopK - } - if input.TopP != nil { - config.TopP = input.TopP - } - if input.MinP != nil { - config.MinP = input.MinP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - if input.Maxtokens != nil { - config.Maxtokens = input.Maxtokens - } - - if input.ResponseFormat != nil { - switch responseFormat := 
input.ResponseFormat.(type) { - case string: - config.ResponseFormat = responseFormat - case map[string]any: - config.ResponseFormatMap = responseFormat - } - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []any: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if len(input.Tools) > 0 { - for _, tool := range input.Tools { - input.Functions = append(input.Functions, tool.Function) - } - } - - if input.ToolsChoice != nil { - // OpenAI tool_choice has three valid shapes: - // - // 1. string mode: "auto" | "none" | "required" - // 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec) - // 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat) - // - // The previous code unmarshalled all three into functions.Tool via - // json.Unmarshal and then unconditionally set input.FunctionCall = - // {"name": toolChoice.Function.Name}. That had two consequences: - // - // - For string modes, json.Unmarshal([]byte("required"), &Tool{}) fails; - // the error was silently discarded, name was "", and downstream - // SetFunctionCallNameString("") meant the requested mode never applied. - // - For the OpenAI-spec map shape, the json keys did not match - // functions.Tool's field tags, so name was "" again. - // - // Mirror the parsing pattern from MergeOpenResponsesConfig (#9509) and - // route results through the existing input.FunctionCall string/map - // dispatch downstream (see the switch on input.FunctionCall in this - // same function). Tracked in #9508; sibling fix in #9526. - switch content := input.ToolsChoice.(type) { - case string: - // "auto" is the default and needs no override; "none" is handled - // at the endpoint layer by skipping tool wiring. "required" must - // reach SetFunctionCallString to engage the mode field. 
- if content != "" && content != "auto" { - input.FunctionCall = content - } - case map[string]any: - if tcType, ok := content["type"].(string); ok && tcType == "function" { - var name string - if fn, ok := content["function"].(map[string]any); ok { - if n, ok := fn["name"].(string); ok { - name = n - } - } - if name == "" { - if n, ok := content["name"].(string); ok { - name = n - } - } - if name != "" { - input.FunctionCall = map[string]any{"name": name} - } - } - } - } - - // Decode each request's message content - imgIndex, vidIndex, audioIndex := 0, 0, 0 - for i, m := range input.Messages { - nrOfImgsInMessage := 0 - nrOfVideosInMessage := 0 - nrOfAudiosInMessage := 0 - - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []any: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - - textContent := "" - // we will template this at the end - - CONTENT: - for _, pp := range c { - switch pp.Type { - case "text": - textContent += pp.Text - //input.Messages[i].StringContent = pp.Text - case "video", "video_url": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) - if err != nil { - xlog.Error("Failed encoding video", "error", err) - continue CONTENT - } - input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff - vidIndex++ - nrOfVideosInMessage++ - case "audio_url", "audio": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) - if err != nil { - xlog.Error("Failed encoding audio", "error", err) - continue CONTENT - } - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff - audioIndex++ - nrOfAudiosInMessage++ - case "input_audio": - // TODO: make sure that we only return 
base64 stuff - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) - audioIndex++ - nrOfAudiosInMessage++ - case "image_url", "image": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) - if err != nil { - xlog.Error("Failed encoding image", "error", err) - continue CONTENT - } - - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - - imgIndex++ - nrOfImgsInMessage++ - } - } - - // When the backend handles templating itself (UseTokenizerTemplate), - // it also injects media markers server-side (see - // oaicompat_chat_params_parse in llama.cpp). Emitting our own markers - // here would double-mark them and downstream consumers ignore - // StringContent in that path anyway, so just pass through plain text. - if config.TemplateConfig.UseTokenizerTemplate { - input.Messages[i].StringContent = textContent - } else { - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ - TotalImages: imgIndex, - TotalVideos: vidIndex, - TotalAudios: audioIndex, - ImagesInMessage: nrOfImgsInMessage, - VideosInMessage: nrOfVideosInMessage, - AudiosInMessage: nrOfAudiosInMessage, - }, textContent) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.FrequencyPenalty != 0 { - config.FrequencyPenalty = input.FrequencyPenalty - } - - if input.PresencePenalty != 0 { - config.PresencePenalty = input.PresencePenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != nil { - config.Seed = input.Seed - } - - if input.TypicalP != nil { - config.TypicalP = input.TypicalP - } - - xlog.Debug("input.Input", "input", 
fmt.Sprintf("%+v", input.Input)) - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []any: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []any: - tokens := []int{} - inputStrings := []string{} - for _, ii := range i { - switch ii := ii.(type) { - case int: - tokens = append(tokens, ii) - case float64: - tokens = append(tokens, int(ii)) - case string: - inputStrings = append(inputStrings, ii) - default: - xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii)) - } - } - config.InputToken = append(config.InputToken, tokens) - config.InputStrings = append(config.InputStrings, inputStrings...) - } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]any: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []any: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } - - // If a quality was defined as number, convert it to step - if input.Quality != "" { - q, err := strconv.Atoi(input.Quality) - if err == nil { - config.Step = q - } - } - - if valid, _ := config.Validate(); valid { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} - -func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { - input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) - if !ok || input.Model == "" { - return echo.ErrBadRequest - } - - // Convert input items to Messages (this will be done in the 
endpoint handler) - // We store the input in the request for the endpoint to process - cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return echo.ErrBadRequest - } - - // Extract or generate the correlation ID (Open Responses uses x-request-id) - correlationID := c.Request().Header.Get("x-request-id") - if correlationID == "" { - correlationID = uuid.New().String() - } - c.Response().Header().Set("x-request-id", correlationID) - - // Use the request context directly - Echo properly supports context cancellation! - reqCtx := c.Request().Context() - c1, cancel := context.WithCancel(re.applicationConfig.Context) - - // Cancel when request context is cancelled (client disconnects) - go func() { - select { - case <-reqCtx.Done(): - cancel() - case <-c1.Done(): - // Already cancelled - } - }() - - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := MergeOpenResponsesConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) - cfg.Model = input.Model - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return nil -} - -// MergeOpenResponsesConfig merges request parameters into the model configuration. 
-func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { - // Temperature - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - // TopP - if input.TopP != nil { - config.TopP = input.TopP - } - - // MaxOutputTokens -> Maxtokens - if input.MaxOutputTokens != nil { - config.Maxtokens = input.MaxOutputTokens - } - - // Convert tools to functions - this will be handled in the endpoint handler - // We just validate that tools are present if needed - - // Handle tool_choice - if input.ToolChoice != nil { - switch tc := input.ToolChoice.(type) { - case string: - // "auto", "required", or "none" - if tc == "required" { - config.SetFunctionCallString("required") - } else if tc == "none" { - // Don't use tools - handled in endpoint - } - // "auto" is default - let model decide - case map[string]any: - // Specific tool. OpenAI spec nests the function name under "function": - // {"type":"function", "function":{"name":"..."}} - // Legacy/Anthropic-compat form puts it at the top level: - // {"type":"function", "name":"..."} - // The old code only handled the legacy shape AND used the wrong - // setter (SetFunctionCallString writes the mode field; the - // specific-function name lives in a separate field read by - // ShouldCallSpecificFunction / FunctionToCall). Net effect: a - // correctly-formed OpenAI tool_choice never engaged grammar-based - // forcing, the model got the tools but no selection hint, and - // streamed raw JSON as delta.content instead of delta.tool_calls. 
- if tcType, ok := tc["type"].(string); ok && tcType == "function" { - var name string - if fn, ok := tc["function"].(map[string]any); ok { - if n, ok := fn["name"].(string); ok { - name = n - } - } - if name == "" { - if n, ok := tc["name"].(string); ok { - name = n - } - } - if name != "" { - config.SetFunctionCallNameString(name) - } - } - } - } - - if valid, _ := config.Validate(); valid { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" +) + +type correlationIDKeyType string + +// CorrelationIDKey to track request across process boundary +const CorrelationIDKey correlationIDKeyType = "correlationID" + +type RequestExtractor struct { + modelConfigLoader *config.ModelConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig +} + +func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { + return &RequestExtractor{ + modelConfigLoader: modelConfigLoader, + modelLoader: modelLoader, + applicationConfig: applicationConfig, + } +} + +const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" +const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" +const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" + +// TODO: Refactor to not return error if unchanged +func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { + model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && model != "" { + return + } + model = 
c.Param("model") + + if model == "" { + model = c.QueryParam("model") + } + + // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting) + if model == "" { + model = c.FormValue("model") + } + + if model == "" { + // Set model from bearer token, if available + auth := c.Request().Header.Get("Authorization") + bearer := strings.TrimPrefix(auth, "Bearer ") + if bearer != "" && bearer != auth { + exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE) + if err == nil && exists { + model = bearer + } + } + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) +} + +func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if !ok || localModelName == "" { + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) + xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) + } + return next(c) + } + } +} + +func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if localModelName != "" { // Don't overwrite existing values + return next(c) + } + + modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) + if err != nil { + xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) + return next(c) + } + + if len(modelNames) == 0 { + xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") + // This 
is non-fatal - making it so was breaking the case of direct installation of raw models + // return errors.New("this endpoint requires at least one model to be installed") + return next(c) + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) + xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) + return next(c) + } + } +} + +// TODO: If context and cancel above belong on all methods, move that part of above into here! +// Otherwise, it's in its own method below for now +func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + input := initializer() + if input == nil { + return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") + } + if err := c.Bind(input); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) + } + + // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain + if input.ModelName(nil) == "" { + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && localModelName != "" { + xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) + input.ModelName(&localModelName) + } + } + + modelName := input.ModelName(nil) + cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig) + + if err != nil { + xlog.Warn("Model Configuration File not found", "model", modelName, "error", err) + } else if cfg.Model == "" && modelName != "" { + xlog.Debug("config does not include model, using input", "input.ModelName", modelName) + cfg.Model = modelName + } + + // If a model name was specified, verify it actually exists before proceeding. 
+ // Check both configured models and loose model files in the model path. + // Skip the check for HuggingFace model IDs (contain "/") since backends + // like diffusers may download these on the fly. + if modelName != "" && !strings.Contains(modelName, "/") { + exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE) + if existsErr == nil && !exists { + return c.JSON(http.StatusNotFound, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName), + Code: http.StatusNotFound, + Type: "invalid_request_error", + }, + }) + } + } + + // Check if the model is disabled + if cfg != nil && cfg.IsDisabled() { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName), + Code: http.StatusForbidden, + Type: "model_disabled", + }, + }) + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return next(c) + } + } +} + +func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID + correlationID := c.Request().Header.Get("X-Correlation-ID") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("X-Correlation-ID", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! 
+ // No need for workarounds like handleConnectionCancellation + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := mergeOpenAIRequestAndModelConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + if input.MinP != nil { + config.MinP = input.MinP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + if input.ResponseFormat != nil { + switch responseFormat := 
input.ResponseFormat.(type) { + case string: + config.ResponseFormat = responseFormat + case map[string]any: + config.ResponseFormatMap = responseFormat + } + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []any: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + // OpenAI tool_choice has three valid shapes: + // + // 1. string mode: "auto" | "none" | "required" + // 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec) + // 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat) + // + // The previous code unmarshalled all three into functions.Tool via + // json.Unmarshal and then unconditionally set input.FunctionCall = + // {"name": toolChoice.Function.Name}. That had two consequences: + // + // - For string modes, json.Unmarshal([]byte("required"), &Tool{}) fails; + // the error was silently discarded, name was "", and downstream + // SetFunctionCallNameString("") meant the requested mode never applied. + // - For the OpenAI-spec map shape, the json keys did not match + // functions.Tool's field tags, so name was "" again. + // + // Mirror the parsing pattern from MergeOpenResponsesConfig (#9509) and + // route results through the existing input.FunctionCall string/map + // dispatch downstream (see the switch on input.FunctionCall in this + // same function). Tracked in #9508; sibling fix in #9526. + switch content := input.ToolsChoice.(type) { + case string: + // "auto" is the default and needs no override. 
"none" and "required" + // both reach SetFunctionCallString via the input.FunctionCall string + // branch below; ShouldUseFunctions() then returns false for "none" + // (tools disabled) and true for "required" (mode engaged), matching + // the OpenAI spec. Before this fix "none" was silently ignored + // (json.Unmarshal(`none`, &Tool{}) failed) so tools stayed enabled. + if content != "" && content != "auto" { + input.FunctionCall = content + } + case map[string]any: + if tcType, ok := content["type"].(string); ok && tcType == "function" { + var name string + if fn, ok := content["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + } + if name == "" { + if n, ok := content["name"].(string); ok { + name = n + } + } + if name != "" { + input.FunctionCall = map[string]any{"name": name} + } + } + } + } + + // Decode each request's message content + imgIndex, vidIndex, audioIndex := 0, 0, 0 + for i, m := range input.Messages { + nrOfImgsInMessage := 0 + nrOfVideosInMessage := 0 + nrOfAudiosInMessage := 0 + + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []any: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + + textContent := "" + // we will template this at the end + + CONTENT: + for _, pp := range c { + switch pp.Type { + case "text": + textContent += pp.Text + //input.Messages[i].StringContent = pp.Text + case "video", "video_url": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) + if err != nil { + xlog.Error("Failed encoding video", "error", err) + continue CONTENT + } + input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff + vidIndex++ + nrOfVideosInMessage++ + case "audio_url", "audio": + // Decode content as base64 either if it's an URL or base64 text + base64, err := 
utils.GetContentURIAsBase64(pp.AudioURL.URL) + if err != nil { + xlog.Error("Failed encoding audio", "error", err) + continue CONTENT + } + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff + audioIndex++ + nrOfAudiosInMessage++ + case "input_audio": + // TODO: make sure that we only return base64 stuff + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) + audioIndex++ + nrOfAudiosInMessage++ + case "image_url", "image": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) + if err != nil { + xlog.Error("Failed encoding image", "error", err) + continue CONTENT + } + + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + + imgIndex++ + nrOfImgsInMessage++ + } + } + + // When the backend handles templating itself (UseTokenizerTemplate), + // it also injects media markers server-side (see + // oaicompat_chat_params_parse in llama.cpp). Emitting our own markers + // here would double-mark them and downstream consumers ignore + // StringContent in that path anyway, so just pass through plain text. 
+ if config.TemplateConfig.UseTokenizerTemplate { + input.Messages[i].StringContent = textContent + } else { + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + TotalVideos: vidIndex, + TotalAudios: audioIndex, + ImagesInMessage: nrOfImgsInMessage, + VideosInMessage: nrOfVideosInMessage, + AudiosInMessage: nrOfAudiosInMessage, + }, textContent) + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input)) + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []any: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []any: + tokens := []int{} + inputStrings := []string{} + for _, ii := range i { + switch ii := ii.(type) { + case int: + tokens = append(tokens, ii) + case float64: + tokens = append(tokens, int(ii)) + case string: + inputStrings = append(inputStrings, ii) + default: + xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii)) + } + } + config.InputToken = append(config.InputToken, tokens) + config.InputStrings = append(config.InputStrings, inputStrings...) 
+ } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]any: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []any: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } + + // If a quality was defined as number, convert it to step + if input.Quality != "" { + q, err := strconv.Atoi(input.Quality) + if err == nil { + config.Step = q + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} + +func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + // Convert input items to Messages (this will be done in the endpoint handler) + // We store the input in the request for the endpoint to process + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID (Open Responses uses x-request-id) + correlationID := c.Request().Header.Get("x-request-id") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("x-request-id", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! 
+ reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := MergeOpenResponsesConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +// MergeOpenResponsesConfig merges request parameters into the model configuration. +func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { + // Temperature + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + // TopP + if input.TopP != nil { + config.TopP = input.TopP + } + + // MaxOutputTokens -> Maxtokens + if input.MaxOutputTokens != nil { + config.Maxtokens = input.MaxOutputTokens + } + + // Convert tools to functions - this will be handled in the endpoint handler + // We just validate that tools are present if needed + + // Handle tool_choice + if input.ToolChoice != nil { + switch tc := input.ToolChoice.(type) { + case string: + // "auto", "required", or "none" + if tc == "required" { + config.SetFunctionCallString("required") + } else if tc == "none" { + // Don't use tools - handled in endpoint + } + // "auto" is default - let model decide + case map[string]any: + // Specific tool. 
OpenAI spec nests the function name under "function": + // {"type":"function", "function":{"name":"..."}} + // Legacy/Anthropic-compat form puts it at the top level: + // {"type":"function", "name":"..."} + // The old code only handled the legacy shape AND used the wrong + // setter (SetFunctionCallString writes the mode field; the + // specific-function name lives in a separate field read by + // ShouldCallSpecificFunction / FunctionToCall). Net effect: a + // correctly-formed OpenAI tool_choice never engaged grammar-based + // forcing, the model got the tools but no selection hint, and + // streamed raw JSON as delta.content instead of delta.tool_calls. + if tcType, ok := tc["type"].(string); ok && tcType == "function" { + var name string + if fn, ok := tc["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + } + if name == "" { + if n, ok := tc["name"].(string); ok { + name = n + } + } + if name != "" { + config.SetFunctionCallNameString(name) + } + } + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} diff --git a/core/http/middleware/request_test.go b/core/http/middleware/request_test.go index 0b1a04cf70d2..b6acdd36fd60 100644 --- a/core/http/middleware/request_test.go +++ b/core/http/middleware/request_test.go @@ -306,3 +306,200 @@ var _ = Describe("MergeOpenResponsesConfig tool_choice parsing", func() { }) }) }) + +// --------------------------------------------------------------------------- +// SetModelAndConfig + SetOpenAIRequest - /v1/chat/completions tool_choice parsing +// --------------------------------------------------------------------------- +// +// Parallel to the MergeOpenResponsesConfig specs above, but for the chat +// completions path. 
The parsing block lives in mergeOpenAIRequestAndModelConfig +// (called from SetOpenAIRequest), so these tests drive the full middleware +// chain the way the production /v1/chat/completions route does. +// +// What we assert per shape: +// - "required" -> ShouldUseFunctions=true, no specific name +// - "none" -> ShouldUseFunctions=false (tools disabled) +// - "auto" -> ShouldUseFunctions=true, no specific name +// - {type:function, function:{name:"X"}} (spec) -> ShouldCallSpecificFunction=true, FunctionToCall="X" +// - {type:function, name:"X"} (legacy) -> ShouldCallSpecificFunction=true, FunctionToCall="X" +// - nested+flat both present -> nested wins +// - malformed (no type / no name) -> no-op +var _ = Describe("SetModelAndConfig tool_choice parsing (chat completions)", func() { + var ( + app *echo.Echo + modelDir string + capturedConfig *config.ModelConfig + ) + + BeforeEach(func() { + var err error + modelDir, err = os.MkdirTemp("", "localai-test-models-*") + Expect(err).ToNot(HaveOccurred()) + + cfgContent := []byte("name: test-model\nbackend: llama-cpp\n") + Expect(os.WriteFile(filepath.Join(modelDir, "test-model.yaml"), cfgContent, 0644)).To(Succeed()) + + ss := &system.SystemState{ + Model: system.Model{ModelsPath: modelDir}, + } + appConfig := config.NewApplicationConfig() + appConfig.SystemState = ss + + mcl := config.NewModelConfigLoader(modelDir) + ml := model.NewModelLoader(ss) + re := NewRequestExtractor(mcl, ml, appConfig) + + capturedConfig = nil + app = echo.New() + app.POST("/v1/chat/completions", + func(c echo.Context) error { + if cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig); ok { + capturedConfig = cfg + } + return c.String(http.StatusOK, "ok") + }, + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, + 
) + }) + + AfterEach(func() { + os.RemoveAll(modelDir) + }) + + // chatReq wraps a tool_choice JSON fragment in a minimal valid chat-completions + // payload. The tools array is non-empty so downstream code paths that gate on + // len(input.Functions) see something to work with. + chatReq := func(toolChoiceJSON string) string { + return `{"model":"test-model",` + + `"messages":[{"role":"user","content":"hi"}],` + + `"tools":[{"type":"function","function":{"name":"get_weather"}}],` + + `"tool_choice":` + toolChoiceJSON + `}` + } + + Context("string tool_choice", func() { + It("engages mode for tool_choice=\"required\"", func() { + rec := postJSON(app, "/v1/chat/completions", chatReq(`"required"`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + Expect(capturedConfig.ShouldUseFunctions()).To(BeTrue()) + }) + + It("disables tools for tool_choice=\"none\"", func() { + // Before #9559 this was a silent no-op (json.Unmarshal of "none" + // into functions.Tool failed); now "none" is honored per OpenAI spec. + rec := postJSON(app, "/v1/chat/completions", chatReq(`"none"`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldUseFunctions()).To(BeFalse()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + }) + + It("leaves config untouched for tool_choice=\"auto\"", func() { + rec := postJSON(app, "/v1/chat/completions", chatReq(`"auto"`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + // "auto" is the default: tools available, model decides. 
+ Expect(capturedConfig.ShouldUseFunctions()).To(BeTrue()) + Expect(capturedConfig.FunctionToCall()).To(Equal("")) + }) + }) + + Context("specific-function tool_choice (OpenAI spec shape)", func() { + It("parses {type:function, function:{name:...}} and forces the named function", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"type":"function","function":{"name":"get_weather"}}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + // Key invariant: a correctly-formed OpenAI tool_choice must engage + // grammar-based forcing via SetFunctionCallNameString. + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue()) + Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather")) + }) + + It("prefers the nested function.name over a stray top-level name", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"type":"function","function":{"name":"correct_name"},"name":"legacy_name"}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.FunctionToCall()).To(Equal("correct_name")) + }) + }) + + Context("specific-function tool_choice (legacy Anthropic-compat shape)", func() { + It("parses {type:function, name:...} and forces the named function", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"type":"function","name":"get_weather"}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue()) + Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather")) + }) + }) + + Context("malformed tool_choice", func() { + It("is a no-op when type is missing", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"function":{"name":"get_weather"}}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + }) + + It("is 
a no-op when type is not \"function\"", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"type":"object","function":{"name":"get_weather"}}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + }) + + It("is a no-op when name is missing from both shapes", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"type":"function","function":{}}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + Expect(capturedConfig.FunctionToCall()).To(Equal("")) + }) + + It("is a no-op when name is empty string", func() { + rec := postJSON(app, "/v1/chat/completions", + chatReq(`{"type":"function","function":{"name":""}}`)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + }) + }) + + Context("nil tool_choice", func() { + It("is a no-op", func() { + rec := postJSON(app, "/v1/chat/completions", + `{"model":"test-model","messages":[{"role":"user","content":"hi"}]}`) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + Expect(capturedConfig.FunctionToCall()).To(Equal("")) + }) + }) +}) From b864231a4e0f3a43a031768e1bdafb714dfad458 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 13 May 2026 20:45:40 +0000 Subject: [PATCH 3/4] fix(middleware): preserve pre-#9559 support for JSON-string-encoded tool_choice Some non-spec clients send tool_choice as a JSON-encoded string of an object form, e.g. "{\"type\":\"function\",\"function\":{\"name\":\"X\"}}". 
The pre-#9559 code accepted this by accident: its case string: branch ran json.Unmarshal([]byte(content), &functions.Tool{}), which succeeded for that double-encoded shape even though it failed for the legitimate plain string modes "auto" / "none" / "required". The first version of this PR routed every string straight to SetFunctionCallString as a mode, which fixed the plain-string cases but silently regressed the double-encoded one (funcs.Select("{...}") returns nothing). Restore the fallback: when a string looks like a JSON object, try parsing it as a tool_choice map first; fall through to mode-string handling only when no usable name comes out. Factor the map-name extraction into a small helper (extractToolChoiceFunctionName) so the string-fallback and the regular map case go through identical code, and accept both the OpenAI-spec nested shape and the legacy/Anthropic flat shape from either entry point. Add 3 Ginkgo specs covering the double-encoded case (nested form, legacy form, and the fall-through when the JSON has no usable name). Signed-off-by: Ettore Di Giacinto Assisted-by: Claude:opus-4-7 [Claude Code] --- core/http/middleware/request.go | 91 +++++++++++++++++----------- core/http/middleware/request_test.go | 48 +++++++++++++++ 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index cd791e096201..7979682601e2 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -240,6 +240,28 @@ func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { return nil } +// extractToolChoiceFunctionName parses a tool_choice map and returns the +// specific function name. Accepts both the OpenAI-spec nested shape +// ({type:function, function:{name:...}}) and the legacy/Anthropic-compat +// flat shape ({type:function, name:...}); the nested form wins when both +// are present. 
Returns "" for malformed input or when the shape names a +// mode rather than a specific tool. +func extractToolChoiceFunctionName(m map[string]any) string { + tcType, ok := m["type"].(string) + if !ok || tcType != "function" { + return "" + } + if fn, ok := m["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok && n != "" { + return n + } + } + if n, ok := m["name"].(string); ok { + return n + } + return "" +} + func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { if input.Echo { config.Echo = input.Echo @@ -319,53 +341,54 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. } if input.ToolsChoice != nil { - // OpenAI tool_choice has three valid shapes: + // OpenAI tool_choice has three valid shapes plus one tolerated + // non-spec form seen in the wild: // - // 1. string mode: "auto" | "none" | "required" - // 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec) - // 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat) + // 1. string mode: "auto" | "none" | "required" + // 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec) + // 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat) + // 4. double-encoded: "{\"type\":\"function\", ...}" (some clients serialize the object) // - // The previous code unmarshalled all three into functions.Tool via - // json.Unmarshal and then unconditionally set input.FunctionCall = - // {"name": toolChoice.Function.Name}. That had two consequences: + // The pre-#9559 code unmarshalled the string case through + // json.Unmarshal([]byte(content), &functions.Tool{}), which: + // - failed for plain string modes (so "required" / "none" were + // silently ignored and tools stayed enabled regardless), but + // - happened to handle shape 4 by accident. + // It also could not parse shape 3 because functions.Tool has no + // flat top-level Name field. 
+ // + // - For string modes, json.Unmarshal([]byte("required"), &Tool{}) fails; - // the error was silently discarded, name was "", and downstream - // SetFunctionCallNameString("") meant the requested mode never applied. - // - For the OpenAI-spec map shape, the json keys did not match - // functions.Tool's field tags, so name was "" again. - // - // Mirror the parsing pattern from MergeOpenResponsesConfig (#9509) and + // Mirror the parsing pattern from MergeOpenResponsesConfig (#9509), // route results through the existing input.FunctionCall string/map // dispatch downstream (see the switch on input.FunctionCall in this - // same function). Tracked in #9508; sibling fix in #9526. + // same function), and preserve the shape-4 fallback so non-spec + // clients don't silently break. Tracked in #9508; sibling fix in #9526. switch content := input.ToolsChoice.(type) { case string: // "auto" is the default and needs no override. "none" and "required" // both reach SetFunctionCallString via the input.FunctionCall string // branch below; ShouldUseFunctions() then returns false for "none" - // (tools disabled) and true for "required" (mode engaged), matching - // the OpenAI spec. Before this fix "none" was silently ignored - // (json.Unmarshal(`none`, &Tool{}) failed) so tools stayed enabled. - if content != "" && content != "auto" { - input.FunctionCall = content + // (tools disabled) and true for "required" (mode engaged). + // + // If the string looks like a JSON object, try shape 4 first: parse + // it as a tool_choice map and use the resulting name. Falling back + // to mode-string handling when the parse yields no usable name keeps + // genuinely-malformed input from accidentally forcing a specific tool.
+ if content == "" || content == "auto" { + break } - case map[string]any: - if tcType, ok := content["type"].(string); ok && tcType == "function" { - var name string - if fn, ok := content["function"].(map[string]any); ok { - if n, ok := fn["name"].(string); ok { - name = n - } - } - if name == "" { - if n, ok := content["name"].(string); ok { - name = n + if strings.HasPrefix(strings.TrimSpace(content), "{") { + var nested map[string]any + if err := json.Unmarshal([]byte(content), &nested); err == nil { + if name := extractToolChoiceFunctionName(nested); name != "" { + input.FunctionCall = map[string]any{"name": name} + break } } - if name != "" { - input.FunctionCall = map[string]any{"name": name} - } + } + input.FunctionCall = content + case map[string]any: + if name := extractToolChoiceFunctionName(content); name != "" { + input.FunctionCall = map[string]any{"name": name} } } } diff --git a/core/http/middleware/request_test.go b/core/http/middleware/request_test.go index b6acdd36fd60..e57fb5b4a354 100644 --- a/core/http/middleware/request_test.go +++ b/core/http/middleware/request_test.go @@ -452,6 +452,54 @@ var _ = Describe("SetModelAndConfig tool_choice parsing (chat completions)", fun }) }) + // Some non-spec clients send the object form serialized as a JSON string. + // The pre-#9559 code accepted that by accident; this Context locks in + // continued tolerance so those clients do not silently regress. + Context("double-encoded tool_choice (JSON string of an object, non-spec)", func() { + It("parses a serialized OpenAI-spec nested object", func() { + // tool_choice value is itself a JSON-encoded string containing the + // object form. Use json.Marshal of the inner blob so the escapes + // are correct regardless of the test reader. 
+ inner := `{"type":"function","function":{"name":"get_weather"}}` + encoded, err := json.Marshal(inner) + Expect(err).ToNot(HaveOccurred()) + rec := postJSON(app, "/v1/chat/completions", chatReq(string(encoded))) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue()) + Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather")) + }) + + It("parses a serialized legacy/Anthropic flat object", func() { + inner := `{"type":"function","name":"get_weather"}` + encoded, err := json.Marshal(inner) + Expect(err).ToNot(HaveOccurred()) + rec := postJSON(app, "/v1/chat/completions", chatReq(string(encoded))) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue()) + Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather")) + }) + + It("falls back to mode-string handling when the JSON string parses but has no usable name", func() { + // A JSON-string that decodes to a map without a function name + // should not engage specific-function forcing. We expect it to + // fall through to the mode-string path; the resulting mode is + // the raw blob (nonsense), but ShouldCallSpecificFunction stays + // false - the invariant that matters. 
+ inner := `{"type":"function"}` + encoded, err := json.Marshal(inner) + Expect(err).ToNot(HaveOccurred()) + rec := postJSON(app, "/v1/chat/completions", chatReq(string(encoded))) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(capturedConfig).ToNot(BeNil()) + Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse()) + }) + }) + Context("malformed tool_choice", func() { It("is a no-op when type is missing", func() { rec := postJSON(app, "/v1/chat/completions", From 96e5dfbca92dda63a27f8f0d17c2a1a8a0d4afab Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 13 May 2026 20:48:21 +0000 Subject: [PATCH 4/4] test(middleware): silence errcheck on AfterEach os.RemoveAll The new tool_choice parsing tests added a second AfterEach that calls os.RemoveAll(modelDir) without checking the error; errcheck flagged it. Suppress with the standard _ = idiom. The pre-existing AfterEach on the earlier Describe still elides the check the same way it did before - leaving that untouched to keep this commit minimal. Assisted-by: Claude:opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto --- core/http/middleware/request_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/http/middleware/request_test.go b/core/http/middleware/request_test.go index e57fb5b4a354..70e8a05b13c8 100644 --- a/core/http/middleware/request_test.go +++ b/core/http/middleware/request_test.go @@ -371,7 +371,7 @@ var _ = Describe("SetModelAndConfig tool_choice parsing (chat completions)", fun }) AfterEach(func() { - os.RemoveAll(modelDir) + _ = os.RemoveAll(modelDir) }) // chatReq wraps a tool_choice JSON fragment in a minimal valid chat-completions