Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions core/http/middleware/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
Expand Down Expand Up @@ -241,6 +240,28 @@ func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
return nil
}

// extractToolChoiceFunctionName pulls the target function name out of a
// decoded tool_choice object.
//
// Two layouts are understood, and only when "type" is the string "function":
//
//   - nested (OpenAI spec):          {"type":"function","function":{"name":"X"}}
//   - flat (legacy/Anthropic-compat): {"type":"function","name":"X"}
//
// A non-empty nested name takes precedence over the flat one. The empty
// string is returned when "type" is missing, is not a string, or is not
// "function", and when neither layout yields a string name.
func extractToolChoiceFunctionName(m map[string]any) string {
	// A failed assertion leaves kind as "", which is rejected just like any
	// other non-"function" value.
	if kind, _ := m["type"].(string); kind != "function" {
		return ""
	}

	// Nested layout: only a non-empty string name counts.
	if nested, isMap := m["function"].(map[string]any); isMap {
		if name, _ := nested["name"].(string); name != "" {
			return name
		}
	}

	// Flat layout: a missing or non-string "name" collapses to "".
	flat, _ := m["name"].(string)
	return flat
}

func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
Expand Down Expand Up @@ -320,17 +341,55 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
}

if input.ToolsChoice != nil {
var toolChoice functions.Tool

// OpenAI tool_choice has three valid shapes plus one tolerated
// non-spec form seen in the wild:
//
// 1. string mode: "auto" | "none" | "required"
// 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec)
// 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat)
// 4. double-encoded: "{\"type\":\"function\", ...}" (some clients serialize the object)
//
// The pre-#9559 code unmarshalled the string case through
// json.Unmarshal([]byte(content), &functions.Tool{}), which:
// - failed for plain string modes (so "required" / "none" were
// silently ignored and tools stayed enabled regardless), but
// - happened to handle shape 4 by accident.
// It also could not parse shape 3 because functions.Tool has no
// flat top-level Name field.
//
// Mirror the parsing pattern from MergeOpenResponsesConfig (#9509),
// route results through the existing input.FunctionCall string/map
// dispatch downstream (see the switch on input.FunctionCall in this
// same function), and preserve the shape-4 fallback so non-spec
// clients don't silently break. Tracked in #9508; sibling fix in #9526.
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
// "auto" is the default and needs no override. "none" and "required"
// both reach SetFunctionCallString via the input.FunctionCall string
// branch below; ShouldUseFunctions() then returns false for "none"
// (tools disabled) and true for "required" (mode engaged).
//
// If the string looks like a JSON object, try shape 4 first: parse
// it as a tool_choice map and use the resulting name. Falling back
// to mode-string handling when the parse yields no usable name keeps
// genuinely-malformed input from accidentally engaging a mode.
if content == "" || content == "auto" {
break
}
if strings.HasPrefix(strings.TrimSpace(content), "{") {
var nested map[string]any
if err := json.Unmarshal([]byte(content), &nested); err == nil {
if name := extractToolChoiceFunctionName(nested); name != "" {
input.FunctionCall = map[string]any{"name": name}
break
}
}
}
input.FunctionCall = content
case map[string]any:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]any{
"name": toolChoice.Function.Name,
if name := extractToolChoiceFunctionName(content); name != "" {
input.FunctionCall = map[string]any{"name": name}
}
}
}

Expand Down
245 changes: 245 additions & 0 deletions core/http/middleware/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,248 @@ var _ = Describe("MergeOpenResponsesConfig tool_choice parsing", func() {
})
})
})

// ---------------------------------------------------------------------------
// SetModelAndConfig + SetOpenAIRequest - /v1/chat/completions tool_choice parsing
// ---------------------------------------------------------------------------
//
// Parallel to the MergeOpenResponsesConfig specs above, but for the chat
// completions path. The parsing block lives in mergeOpenAIRequestAndModelConfig
// (called from SetOpenAIRequest), so these tests drive the full middleware
// chain the way the production /v1/chat/completions route does.
//
// What we assert per shape:
// - "required" -> ShouldUseFunctions=true, no specific name
// - "none" -> ShouldUseFunctions=false (tools disabled)
// - "auto" -> ShouldUseFunctions=true, no specific name
// - {type:function, function:{name:"X"}} (spec) -> ShouldCallSpecificFunction=true, FunctionToCall="X"
// - {type:function, name:"X"} (legacy) -> ShouldCallSpecificFunction=true, FunctionToCall="X"
// - nested+flat both present -> nested wins
// - double-encoded JSON string of either object shape -> same result as the object form
// - malformed (no type / no name) -> no-op
// Ginkgo container for the chat-completions tool_choice specs. Every spec
// posts a JSON payload to the route built in BeforeEach and then inspects the
// *config.ModelConfig that the route's handler captured from the echo context.
var _ = Describe("SetModelAndConfig tool_choice parsing (chat completions)", func() {
var (
// app is the echo instance hosting the route under test.
app *echo.Echo
// modelDir is the per-spec temporary models directory (removed in AfterEach).
modelDir string
// capturedConfig is the model config the handler read from the echo
// context; it stays nil if the handler never ran or found no config.
capturedConfig *config.ModelConfig
)

BeforeEach(func() {
var err error
// Fresh temporary models directory for each spec.
modelDir, err = os.MkdirTemp("", "localai-test-models-*")
Expect(err).ToNot(HaveOccurred())

// Minimal model definition named "test-model", the same name the request
// payloads below put in their "model" field.
cfgContent := []byte("name: test-model\nbackend: llama-cpp\n")
Expect(os.WriteFile(filepath.Join(modelDir, "test-model.yaml"), cfgContent, 0644)).To(Succeed())

// Point the application config's system state at the temp directory.
ss := &system.SystemState{
Model: system.Model{ModelsPath: modelDir},
}
appConfig := config.NewApplicationConfig()
appConfig.SystemState = ss

mcl := config.NewModelConfigLoader(modelDir)
ml := model.NewModelLoader(ss)
re := NewRequestExtractor(mcl, ml, appConfig)

// Reset the capture so a config from a previous spec cannot leak in.
capturedConfig = nil
app = echo.New()
app.POST("/v1/chat/completions",
// Handler: record the model config stored in the context under
// CONTEXT_LOCALS_KEY_MODEL_CONFIG (if present) and reply 200 "ok".
func(c echo.Context) error {
if cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig); ok {
capturedConfig = cfg
}
return c.String(http.StatusOK, "ok")
},
// The remaining arguments are route middleware.
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
// Adapt SetOpenAIRequest (called as func(echo.Context) error) into
// middleware form: stop on error, otherwise continue to next.
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
)
})

AfterEach(func() {
// Best-effort cleanup of the temp models directory; the error is
// deliberately ignored.
_ = os.RemoveAll(modelDir)
})

// chatReq wraps a tool_choice JSON fragment in a minimal valid chat-completions
// payload. The tools array is non-empty so downstream code paths that gate on
// len(input.Functions) see something to work with.
chatReq := func(toolChoiceJSON string) string {
return `{"model":"test-model",` +
`"messages":[{"role":"user","content":"hi"}],` +
`"tools":[{"type":"function","function":{"name":"get_weather"}}],` +
`"tool_choice":` + toolChoiceJSON + `}`
}

// NOTE(review): postJSON is defined outside this block; from its use here it
// appears to send the body to the given path on app and return a recorder
// exposing the response status as Code - confirm against its definition.
Context("string tool_choice", func() {
It("engages mode for tool_choice=\"required\"", func() {
rec := postJSON(app, "/v1/chat/completions", chatReq(`"required"`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
Expect(capturedConfig.ShouldUseFunctions()).To(BeTrue())
})

It("disables tools for tool_choice=\"none\"", func() {
// Before #9559 this was a silent no-op (json.Unmarshal of "none"
// into functions.Tool failed); now "none" is honored per OpenAI spec.
rec := postJSON(app, "/v1/chat/completions", chatReq(`"none"`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldUseFunctions()).To(BeFalse())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
})

It("leaves config untouched for tool_choice=\"auto\"", func() {
rec := postJSON(app, "/v1/chat/completions", chatReq(`"auto"`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
// "auto" is the default: tools available, model decides.
Expect(capturedConfig.ShouldUseFunctions()).To(BeTrue())
Expect(capturedConfig.FunctionToCall()).To(Equal(""))
})
})

Context("specific-function tool_choice (OpenAI spec shape)", func() {
It("parses {type:function, function:{name:...}} and forces the named function", func() {
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"type":"function","function":{"name":"get_weather"}}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
// Key invariant: a correctly-formed OpenAI tool_choice must engage
// grammar-based forcing via SetFunctionCallNameString.
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue())
Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather"))
})

It("prefers the nested function.name over a stray top-level name", func() {
// Both names are present and differ, so the assertion shows which one won.
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"type":"function","function":{"name":"correct_name"},"name":"legacy_name"}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.FunctionToCall()).To(Equal("correct_name"))
})
})

Context("specific-function tool_choice (legacy Anthropic-compat shape)", func() {
It("parses {type:function, name:...} and forces the named function", func() {
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"type":"function","name":"get_weather"}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue())
Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather"))
})
})

// Some non-spec clients send the object form serialized as a JSON string.
// The pre-#9559 code accepted that by accident; this Context locks in
// continued tolerance so those clients do not silently regress.
Context("double-encoded tool_choice (JSON string of an object, non-spec)", func() {
It("parses a serialized OpenAI-spec nested object", func() {
// tool_choice value is itself a JSON-encoded string containing the
// object form. Use json.Marshal of the inner blob so the escapes
// are correct regardless of the test reader.
inner := `{"type":"function","function":{"name":"get_weather"}}`
encoded, err := json.Marshal(inner)
Expect(err).ToNot(HaveOccurred())
rec := postJSON(app, "/v1/chat/completions", chatReq(string(encoded)))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue())
Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather"))
})

It("parses a serialized legacy/Anthropic flat object", func() {
// Same double-encoding as above, applied to the flat shape.
inner := `{"type":"function","name":"get_weather"}`
encoded, err := json.Marshal(inner)
Expect(err).ToNot(HaveOccurred())
rec := postJSON(app, "/v1/chat/completions", chatReq(string(encoded)))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeTrue())
Expect(capturedConfig.FunctionToCall()).To(Equal("get_weather"))
})

It("falls back to mode-string handling when the JSON string parses but has no usable name", func() {
// A JSON-string that decodes to a map without a function name
// should not engage specific-function forcing. We expect it to
// fall through to the mode-string path; the resulting mode is
// the raw blob (nonsense), but ShouldCallSpecificFunction stays
// false - the invariant that matters.
inner := `{"type":"function"}`
encoded, err := json.Marshal(inner)
Expect(err).ToNot(HaveOccurred())
rec := postJSON(app, "/v1/chat/completions", chatReq(string(encoded)))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
})
})

// Object-form inputs that must not select a specific function. Each spec
// still expects HTTP 200, i.e. the request is accepted rather than rejected.
Context("malformed tool_choice", func() {
It("is a no-op when type is missing", func() {
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"function":{"name":"get_weather"}}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
})

It("is a no-op when type is not \"function\"", func() {
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"type":"object","function":{"name":"get_weather"}}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
})

It("is a no-op when name is missing from both shapes", func() {
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"type":"function","function":{}}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
Expect(capturedConfig.FunctionToCall()).To(Equal(""))
})

It("is a no-op when name is empty string", func() {
rec := postJSON(app, "/v1/chat/completions",
chatReq(`{"type":"function","function":{"name":""}}`))

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
})
})

Context("nil tool_choice", func() {
It("is a no-op", func() {
// Payload is written out by hand (not via chatReq) so that it carries
// neither "tools" nor "tool_choice".
rec := postJSON(app, "/v1/chat/completions",
`{"model":"test-model","messages":[{"role":"user","content":"hi"}]}`)

Expect(rec.Code).To(Equal(http.StatusOK))
Expect(capturedConfig).ToNot(BeNil())
Expect(capturedConfig.ShouldCallSpecificFunction()).To(BeFalse())
Expect(capturedConfig.FunctionToCall()).To(Equal(""))
})
})
})
Loading