From af0db1419c0e1b5ba57e4b4228e98d3c9a653ad5 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 14:28:39 +0000 Subject: [PATCH 01/13] test(http): make the suite listen port configurable The core/http specs hardcoded 127.0.0.1:9090 in ~70 call sites, so the pre-commit coverage gate fails on any machine where an unrelated service holds 9090. Centralize the address in the suite file behind LOCALAI_TEST_HTTP_PORT (default unchanged: 9090). Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- core/http/app_test.go | 76 ++++++++++++++++----------------- core/http/http_suite_test.go | 14 ++++++ core/http/openresponses_test.go | 60 +++++++++++++------------- 3 files changed, 82 insertions(+), 68 deletions(-) diff --git a/core/http/app_test.go b/core/http/app_test.go index 735edaf1c27a..70e20918bca8 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -383,13 +383,13 @@ var _ = Describe("API test", func() { Expect(err).ToNot(HaveOccurred()) go func() { - if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig(apiKey) - defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + defaultConfig.BaseURL = testHTTPBase + "/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL @@ -418,7 +418,7 @@ var _ = Describe("API test", func() { Context("Auth Tests", func() { It("Should fail if the api key is missing", func() { - err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available") + err, sc := postInvalidRequest(testHTTPBase + "/models/available") Expect(err).ToNot(BeNil()) Expect(sc).To(Equal(401)) }) @@ -427,7 +427,7 @@ var _ = Describe("API test", func() { Context("URL routing Tests", func() { It("Should support reverse-proxy when unauthenticated", func() { - err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ + err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{ "X-Forwarded-Proto": {"https"}, "X-Forwarded-Host": {"example.org"}, "X-Forwarded-Prefix": {"/myprefix/"}, @@ -441,7 +441,7 @@ var _ = Describe("API test", func() { It("Should support reverse-proxy when authenticated", func() { - err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ + err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{ "Authorization": {bearerKey}, "X-Forwarded-Proto": {"https"}, "X-Forwarded-Host": {"example.org"}, @@ -459,7 +459,7 @@ var _ = Describe("API test", func() { // requests them through the proxy. It("Should support reverse-proxy when prefix is stripped by the proxy", func() { - err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{ + err, sc, body := getRequest(testHTTPBase+"/app", http.Header{ "X-Forwarded-Proto": {"https"}, "X-Forwarded-Host": {"example.org"}, "X-Forwarded-Prefix": {"/myprefix"}, @@ -477,7 +477,7 @@ var _ = Describe("API test", func() { // from a foreign origin. BasePathPrefix must reject these via // SafeForwardedPrefix and fall back to "/". It("Should ignore an unsafe X-Forwarded-Prefix and not poison asset URLs", func() { - err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{ + err, sc, body := getRequest(testHTTPBase+"/app", http.Header{ "X-Forwarded-Proto": {"https"}, "X-Forwarded-Host": {"example.org"}, "X-Forwarded-Prefix": {"//evil.com"}, @@ -492,13 +492,13 @@ var _ = Describe("API test", func() { Context("Applying models", func() { It("applies models from a gallery", func() { - models, err := getModels("http://127.0.0.1:9090/models/available") + models, err := getModels(testHTTPBase + "/models/available") Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) - response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{ ID: "test@bert2", }) @@ -507,7 +507,7 @@ var _ = Describe("API test", func() { uuid := response["uuid"].(string) resp := map[string]any{} Eventually(func() bool { - response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid) fmt.Println(response) resp = response return response["processed"].(bool) @@ -526,7 +526,7 @@ var _ = Describe("API test", func() { Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) Expect(content["foo"]).To(Equal("bar")) - models, err = getModels("http://127.0.0.1:9090/models/available") + models, err = getModels(testHTTPBase + "/models/available") Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) @@ -541,7 +541,7 @@ var _ = Describe("API test", func() { }) It("overrides models", func() { - response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{ URL: bertEmbeddingsURL, Name: "bert", Overrides: map[string]any{ @@ -554,7 +554,7 @@ var _ = Describe("API test", func() { uuid := response["uuid"].(string) Eventually(func() bool { - response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid) return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) @@ -567,7 +567,7 @@ var _ = Describe("API test", func() { Expect(content["backend"]).To(Equal("llama")) }) It("apply models without overrides", func() { - response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{ URL: bertEmbeddingsURL, Name: "bert", Overrides: map[string]any{}, @@ -578,7 +578,7 @@ var _ = Describe("API test", func() { uuid := response["uuid"].(string) Eventually(func() bool { - response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid) return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) @@ -622,14 +622,14 @@ parameters: } var response schema.GalleryResponse - err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response) Expect(err).ToNot(HaveOccurred()) Expect(response.ID).ToNot(BeEmpty()) uuid := response.ID resp := map[string]any{} Eventually(func() bool { - response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid) resp = response return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) @@ -657,7 +657,7 @@ parameters: } var response schema.GalleryResponse - err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response) // The endpoint should return an error immediately Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("failed to discover model config")) @@ -693,14 +693,14 @@ parameters: } var response schema.GalleryResponse - err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response) Expect(err).ToNot(HaveOccurred()) Expect(response.ID).ToNot(BeEmpty()) uuid := response.ID resp := map[string]any{} Eventually(func() bool { - response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid) resp = response return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) @@ -751,13 +751,13 @@ parameters: app, err = API(localAIApp) Expect(err).ToNot(HaveOccurred()) go func() { - if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig("") - defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + defaultConfig.BaseURL = testHTTPBase + "/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL @@ -801,7 +801,7 @@ parameters: // Mock-backend is registered via SetExternalBackend so it appears // alongside any built-in entries; verifying that string proves the // endpoint is wired up regardless of which real backends exist. - resp, err := http.Get("http://127.0.0.1:9090/system") + resp, err := http.Get(testHTTPBase + "/system") Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) dat, err := io.ReadAll(resp.Body) @@ -824,14 +824,14 @@ parameters: } var createResp map[string]any - err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) + err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) Expect(createResp["id"]).ToNot(BeEmpty()) taskID := createResp["id"].(string) // Get the task var task schema.Task - resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID) + resp, err := http.Get(testHTTPBase + "/api/agent/tasks/" + taskID) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, _ := io.ReadAll(resp.Body) @@ -839,7 +839,7 @@ parameters: Expect(task.Name).To(Equal("Test Task")) // List tasks - resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks") + resp, err = http.Get(testHTTPBase + "/api/agent/tasks") Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) var tasks []schema.Task @@ -849,18 +849,18 @@ parameters: // Update task taskBody["name"] = "Updated Task" - err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody) + err = putRequestJSON(testHTTPBase+"/api/agent/tasks/"+taskID, &taskBody) Expect(err).ToNot(HaveOccurred()) // Verify update - resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID) + resp, err = http.Get(testHTTPBase + "/api/agent/tasks/" + taskID) Expect(err).ToNot(HaveOccurred()) body, _ = io.ReadAll(resp.Body) json.Unmarshal(body, &task) Expect(task.Name).To(Equal("Updated Task")) // Delete task - req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil) + req, _ := http.NewRequest("DELETE", testHTTPBase+"/api/agent/tasks/"+taskID, nil) req.Header.Set("Authorization", bearerKey) resp, err = http.DefaultClient.Do(req) Expect(err).ToNot(HaveOccurred()) @@ -877,7 +877,7 @@ parameters: } var createResp map[string]any - err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) + err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) taskID := createResp["id"].(string) @@ -888,14 +888,14 @@ parameters: } var jobResp schema.JobExecutionResponse - err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp) + err = postRequestResponseJSON(testHTTPBase+"/api/agent/jobs/execute", &jobBody, &jobResp) Expect(err).ToNot(HaveOccurred()) Expect(jobResp.JobID).ToNot(BeEmpty()) jobID := jobResp.JobID // Get job status var job schema.Job - resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID) + resp, err := http.Get(testHTTPBase + "/api/agent/jobs/" + jobID) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, _ := io.ReadAll(resp.Body) @@ -904,7 +904,7 @@ parameters: Expect(job.TaskID).To(Equal(taskID)) // List jobs - resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs") + resp, err = http.Get(testHTTPBase + "/api/agent/jobs") Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) var jobs []schema.Job @@ -914,7 +914,7 @@ parameters: // Cancel job (if still pending/running) if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning { - req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil) + req, _ := http.NewRequest("POST", testHTTPBase+"/api/agent/jobs/"+jobID+"/cancel", nil) req.Header.Set("Authorization", bearerKey) resp, err = http.DefaultClient.Do(req) Expect(err).ToNot(HaveOccurred()) @@ -932,13 +932,13 @@ parameters: } var createResp map[string]any - err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) + err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) // Execute by name paramsBody := map[string]string{"param1": "value1"} var jobResp schema.JobExecutionResponse - err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp) + err = postRequestResponseJSON(testHTTPBase+"/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp) Expect(err).ToNot(HaveOccurred()) Expect(jobResp.JobID).ToNot(BeEmpty()) }) @@ -998,13 +998,13 @@ parameters: Expect(err).ToNot(HaveOccurred()) go func() { - if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig("") - defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + defaultConfig.BaseURL = testHTTPBase + "/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready diff --git a/core/http/http_suite_test.go b/core/http/http_suite_test.go index 744c2c25285c..0e022d3a59e0 100644 --- a/core/http/http_suite_test.go +++ b/core/http/http_suite_test.go @@ -21,6 +21,20 @@ var ( mockBackendPath string ) +// testHTTPAddr is the listen address used by specs that start a full HTTP +// server. Configurable so the suite can run on machines where the default +// port is taken by an unrelated service (override: LOCALAI_TEST_HTTP_PORT). +var testHTTPAddr = func() string { + port := os.Getenv("LOCALAI_TEST_HTTP_PORT") + if port == "" { + port = "9090" + } + return "127.0.0.1:" + port +}() + +// testHTTPBase is the matching http://host:port prefix for client requests. +var testHTTPBase = "http://" + testHTTPAddr + // findMockBackendBinary locates the mock-backend binary built by // `make build-mock-backend`. Mirrors the lookup used by // tests/e2e/e2e_suite_test.go so both suites consume the same artifact. diff --git a/core/http/openresponses_test.go b/core/http/openresponses_test.go index fb28df38099c..6299a22ecdae 100644 --- a/core/http/openresponses_test.go +++ b/core/http/openresponses_test.go @@ -59,14 +59,14 @@ var _ = Describe("Open Responses API", func() { Expect(err).ToNot(HaveOccurred()) go func() { - if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() // Wait for API to be ready Eventually(func() error { - resp, err := http.Get("http://127.0.0.1:9090/healthz") + resp, err := http.Get(testHTTPBase + "/healthz") if err != nil { return err } @@ -95,7 +95,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -118,7 +118,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -143,7 +143,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -168,7 +168,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -196,7 +196,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -241,7 +241,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -269,7 +269,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -297,7 +297,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -328,7 +328,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -358,7 +358,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -386,7 +386,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -418,7 +418,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -454,7 +454,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -490,7 +490,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -539,7 +539,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -590,7 +590,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -624,7 +624,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -658,7 +658,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -680,7 +680,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -727,7 +727,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -756,7 +756,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -799,7 +799,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -835,7 +835,7 @@ var _ = Describe("Open Responses API", func() { payload1, err := json.Marshal(reqBody1) Expect(err).ToNot(HaveOccurred()) - req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1)) + req1, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload1)) Expect(err).ToNot(HaveOccurred()) req1.Header.Set("Content-Type", "application/json") req1.Header.Set("Authorization", bearerKey) @@ -869,7 +869,7 @@ var _ = Describe("Open Responses API", func() { payload2, err := json.Marshal(reqBody2) Expect(err).ToNot(HaveOccurred()) - req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2)) + req2, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload2)) Expect(err).ToNot(HaveOccurred()) req2.Header.Set("Content-Type", "application/json") req2.Header.Set("Authorization", bearerKey) @@ -897,7 +897,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) @@ -933,7 +933,7 @@ var _ = Describe("Open Responses API", func() { payload1, err := json.Marshal(reqBody1) Expect(err).ToNot(HaveOccurred()) - req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1)) + req1, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload1)) Expect(err).ToNot(HaveOccurred()) req1.Header.Set("Content-Type", "application/json") req1.Header.Set("Authorization", bearerKey) @@ -983,7 +983,7 @@ var _ = Describe("Open Responses API", func() { payload2, err := json.Marshal(reqBody2) Expect(err).ToNot(HaveOccurred()) - req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2)) + req2, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload2)) Expect(err).ToNot(HaveOccurred()) req2.Header.Set("Content-Type", "application/json") req2.Header.Set("Authorization", bearerKey) @@ -1009,7 +1009,7 @@ var _ = Describe("Open Responses API", func() { payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) From 778f85c2a01c8bf18d3fd6c0443a8d09239297a2 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 14:50:39 +0000 Subject: [PATCH 02/13] feat(dllm): purego backend scaffold over the dllm.cpp C-ABI Binds the 9-symbol flat C-ABI of dllm.cpp (DiffusionGemma engine) via purego: typed wrappers with correct string ownership (malloc'd returns freed via dllm_capi_free_string, borrowed last_error never freed), once-allocated stream-callback trampolines, and a gated Ginkgo binding smoke against the tiny fixture model. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- backend/go/dllm/.gitignore | 10 ++ backend/go/dllm/Makefile | 89 ++++++++++++ backend/go/dllm/capi.go | 265 +++++++++++++++++++++++++++++++++++ backend/go/dllm/dllm_test.go | 144 +++++++++++++++++++ backend/go/dllm/main.go | 85 +++++++++++ backend/go/dllm/package.sh | 24 ++++ backend/go/dllm/run.sh | 16 +++ 7 files changed, 633 insertions(+) create mode 100644 backend/go/dllm/.gitignore create mode 100644 backend/go/dllm/Makefile create mode 100644 backend/go/dllm/capi.go create mode 100644 backend/go/dllm/dllm_test.go create mode 100644 backend/go/dllm/main.go create mode 100755 backend/go/dllm/package.sh create mode 100755 backend/go/dllm/run.sh diff --git a/backend/go/dllm/.gitignore b/backend/go/dllm/.gitignore new file mode 100644 index 000000000000..5b1edf6d31ea --- /dev/null +++ b/backend/go/dllm/.gitignore @@ -0,0 +1,10 @@ +.cache/ +sources/ +build/ +package/ +dllm-grpc +# build artifacts staged in-tree by the Makefile (cp from sources/) or +# symlinked for local dev; the real sources live in dllm.cpp upstream. +*.so +*.so.* +compile_commands.json diff --git a/backend/go/dllm/Makefile b/backend/go/dllm/Makefile new file mode 100644 index 000000000000..3b7114c12ed5 --- /dev/null +++ b/backend/go/dllm/Makefile @@ -0,0 +1,89 @@ +# dllm backend Makefile. +# +# Upstream pin lives below as DLLM_VERSION?= so .github/bump_deps.sh +# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4 +# convention. +# +# Local dev shortcut: if you already have an out-of-tree dllm.cpp build, +# you can symlink the .so into this directory and skip the clone/cmake +# steps entirely, e.g.: +# +# ln -sf /path/to/dllm.cpp/build/libdllm.so . +# go build -o dllm-grpc . +# +# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The +# default target below does the proper clone-at-pin + cmake build so CI +# doesn't need a side-checkout. + +DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7 +DLLM_REPO?=https://github.com/mudler/dllm.cpp + +GOCMD?=go +GO_TAGS?= +JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4) + +BUILD_TYPE?= +NATIVE?=false + +# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml +# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no +# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the +# runtime image already provides. Tests/CLI are upstream-only concerns. +CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF + +ifeq ($(NATIVE),false) + CMAKE_ARGS+=-DGGML_NATIVE=OFF +endif + +# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake +# for their cublas images; override for a native build. +CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real + +# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so +# forward that instead of a bare -DGGML_CUDA=ON. +ifeq ($(BUILD_TYPE),cublas) + CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)" +endif + +.PHONY: dllm-grpc package build clean purge test all + +all: dllm-grpc + +# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as +# a submodule). Directory acts as the target so make only re-clones when +# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch. +sources/dllm.cpp: + mkdir -p sources/dllm.cpp + cd sources/dllm.cpp && \ + git init -q && \ + git remote add origin $(DLLM_REPO) && \ + git fetch --depth 1 origin $(DLLM_VERSION) && \ + git checkout FETCH_HEAD && \ + git submodule update --init --recursive --depth 1 --single-branch + +# Build the shared lib out-of-tree, then stage it next to the Go sources so +# purego.Dlopen("libdllm.so") and the packaging step both pick it up. +libdllm.so: sources/dllm.cpp + cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS) + cmake --build sources/dllm.cpp/build --config Release -j$(JOBS) + cp -fv sources/dllm.cpp/build/libdllm.so ./ + +dllm-grpc: libdllm.so main.go capi.go + CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc . + +package: dllm-grpc + bash package.sh + +build: package + +# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY + +# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the +# pure-Go helper specs run. +test: + LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1 + +clean: purge + rm -rf libdllm.so* package dllm-grpc + +purge: + rm -rf sources/dllm.cpp diff --git a/backend/go/dllm/capi.go b/backend/go/dllm/capi.go new file mode 100644 index 000000000000..d8c0ca11e1ba --- /dev/null +++ b/backend/go/dllm/capi.go @@ -0,0 +1,265 @@ +package main + +// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1). +// +// Contract highlights the wrappers encode (see the header + src/capi.cpp): +// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as +// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string. +// - last_error returns a BORROWED pointer (valid until the next call on the +// same ctx): bound as a plain string (purego copies), never freed, and only +// read AFTER the failing call has returned - reading it while a generate is +// in flight on the same ctx violates the per-ctx serialization contract. +// - All entry points except dllm_capi_cancel must be externally serialized +// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips +// an atomic and may be called from any goroutine mid-generate. +// - No C++ exception crosses the boundary; failures land in last_error. + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "unsafe" + + "github.com/ebitengine/purego" + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written +// against; main.go refuses to start against a libdllm.so reporting another. +const dllmABIVersion = 1 + +// purego-bound entry points from libdllm.so. Names match dllm_capi.h +// exactly; loadCAPI (main.go) fills these in at boot. +var ( + cppAbiVersion func() int32 + cppLoad func(ggufPath, paramsJSON string) uintptr + cppFree func(ctx uintptr) + cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free + cppFreeString func(s uintptr) + // malloc'd char* returns, hence uintptr (see loadCAPI's doc comment). + cppTokenizeJSON func(ctx uintptr, text string) uintptr + cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr + // on_block/on_step are C function pointers produced by purego.NewCallback; + // userData carries the streamCallStates registry key. + cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32 + cppCancel func(ctx uintptr) +) + +// Dllm is the LocalAI gRPC backend over the dllm.cpp C-ABI. T1 ships only +// the binding scaffold; Load/PredictRich/PredictStreamRich (and the move to +// a dedicated dllm.go with the per-model worker goroutine) land in T4. +type Dllm struct { + base.Base +} + +// Load is not wired yet: the binding smoke drives the C functions directly. +func (d *Dllm) Load(opts *pb.ModelOptions) error { + return errors.New("dllm: model loading not implemented yet (backend wiring lands in T4)") +} + +// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION. +func cAbiVersion() int32 { + return cppAbiVersion() +} + +// cLoad opens the GGUF at path with the flat params JSON (e.g. +// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there +// is no ctx to carry the reason, the C side logs it to stderr (and +// cLastError(0) only yields the static NULL-ctx message). +func cLoad(path, paramsJSON string) uintptr { + return cppLoad(path, paramsJSON) +} + +// cFree releases a ctx; safe on 0 (delete nullptr). +func cFree(h uintptr) { + cppFree(h) +} + +// cLastError returns the ctx's last error message (or the static NULL-ctx +// message for h==0). The C pointer is borrowed and only valid until the next +// call on the same ctx; purego's string return copies it immediately, so the +// returned Go string is safe to keep. Must not be called while another call +// on the same ctx is in flight. +func cLastError(h uintptr) string { + return cppLastError(h) +} + +// lastErrorOr is cLastError with a fallback for the empty-message case, so +// wrapped errors never end in ": ". +func lastErrorOr(h uintptr, fallback string) string { + if msg := cLastError(h); msg != "" { + return msg + } + return fallback +} + +// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos) +// and returns the token ids as a JSON array string, e.g. "[2,18]". +func cTokenizeJSON(h uintptr, text string) (string, error) { + ret := cppTokenizeJSON(h, text) + if ret == 0 { + return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error")) + } + out := goStringFromCPtr(ret) + cppFreeString(ret) + return out, nil +} + +// cGenerate runs a blocking generation and returns the detokenized text. +// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C +// parser rejects nested objects/arrays. NULL return -> last_error (read only +// after the call returned, per the serialization contract); a cancelled call +// surfaces as the "cancelled" message. +func cGenerate(h uintptr, prompt, optsJSON string) (string, error) { + ret := cppGenerate(h, prompt, optsJSON) + if ret == 0 { + return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error")) + } + out := goStringFromCPtr(ret) + cppFreeString(ret) + return out, nil +} + +// streamCallState carries the Go callbacks for one in-flight +// cGenerateStream call; the registry key travels through C as user_data. +// The map shape mirrors the whisper backend's streamCallStates: only one +// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying +// by call survives multiple models/processes sharing the package. +type streamCallState struct { + onBlock func(text string) + onStep func(step, total int, preview string) +} + +var ( + streamCallStates sync.Map // uint64 -> *streamCallState + streamCallSeq atomic.Uint64 + + // purego.NewCallback allocates a finite, never-released callback slot, so + // the two trampolines are created exactly once and reused across calls. + streamCbOnce sync.Once + blockCbPtr uintptr + stepCbPtr uintptr +) + +// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C +// calling thread, mid-generate: keep it tiny and non-blocking (callers that +// bridge to goroutines must hand off via buffered channels). The text +// pointer is only valid for the duration of the invocation, so it is copied +// to a Go string immediately. +func onBlockTrampoline(text uintptr, userData uintptr) { + v, ok := streamCallStates.Load(uint64(userData)) + if !ok { + return // call already torn down + } + state := v.(*streamCallState) + if state.onBlock != nil { + state.onBlock(goStringFromCPtr(text)) + } +} + +// onStepTrampoline is the Go side of dllm_step_cb; same threading and +// lifetime caveats as onBlockTrampoline. +func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) { + v, ok := streamCallStates.Load(uint64(userData)) + if !ok { + return + } + state := v.(*streamCallState) + if state.onStep != nil { + state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview)) + } +} + +// cGenerateStream runs a generation with per-committed-block (onBlock) and +// per-denoising-step (onStep) callbacks; either may be nil. The callbacks +// run on the C thread (see the trampoline docs). Returns an error carrying +// last_error on failure; cancellation surfaces as the "cancelled" message. +func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error { + streamCbOnce.Do(func() { + blockCbPtr = purego.NewCallback(onBlockTrampoline) + stepCbPtr = purego.NewCallback(onStepTrampoline) + }) + + id := streamCallSeq.Add(1) + streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep}) + defer streamCallStates.Delete(id) + + // Pass NULL for absent callbacks so the C side skips the per-block / + // per-step detokenize work entirely. + var blockPtr, stepPtr uintptr + if onBlock != nil { + blockPtr = blockCbPtr + } + if onStep != nil { + stepPtr = stepCbPtr + } + + if rc := cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, uintptr(id)); rc != 0 { + return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error")) + } + return nil +} + +// cCancel requests cancellation of the in-flight generate on h. This is the +// ONE entry point safe to call from any goroutine while a generate runs (it +// only flips an atomic). Note the cancel-reset race from the header: each +// generate resets the flag on entry, so a watchdog should re-issue cancel if +// the call has not returned. +func cCancel(h uintptr) { + cppCancel(h) +} + +// buildOptsJSON renders generation options as the flat JSON object the +// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The +// C-side scanner only understands scalar number/string values and rejects +// nested objects/arrays loudly; bools are rejected here too because the +// scanner has no concept of them. Fail loud rather than let an option be +// silently misread. +func buildOptsJSON(opts map[string]any) (string, error) { + if len(opts) == 0 { + return "{}", nil + } + for k, v := range opts { + switch v.(type) { + case string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + json.Number: + // scalar: fine + default: + return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v) + } + } + b, err := json.Marshal(opts) + if err != nil { + return "", fmt.Errorf("dllm: marshal opts: %w", err) + } + return string(b), nil +} + +// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is +// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the +// caller owns, or a callback argument only valid during the invocation); +// owning callers must free it via cppFreeString after the copy lands. +// +// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr +// check, which can't distinguish a C-owned heap pointer from Go-managed +// memory. It is safe here: the pointer addresses C memory the Go GC neither +// tracks nor moves, and we dereference it immediately to copy the bytes out, +// the same pattern (and the same tolerated warning) as the parakeet-cpp and +// whisper backends. +func goStringFromCPtr(cptr uintptr) string { + if cptr == 0 { + return "" + } + p := unsafe.Pointer(cptr) //nolint:govet // C-owned buffer, not Go-GC memory (see doc above) + n := 0 + for *(*byte)(unsafe.Add(p, n)) != 0 { + n++ + } + return string(unsafe.Slice((*byte)(p), n)) +} diff --git a/backend/go/dllm/dllm_test.go b/backend/go/dllm/dllm_test.go new file mode 100644 index 000000000000..a6f1e569775c --- /dev/null +++ b/backend/go/dllm/dllm_test.go @@ -0,0 +1,144 @@ +package main + +import ( + "os" + "sync" + "testing" + "unsafe" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestDllm(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "dllm Backend Suite") +} + +var ( + libLoadOnce sync.Once + libLoadErr error +) + +// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the +// C-ABI bridge without spinning up the gRPC server. The library path comes +// from DLLM_TEST_LIBRARY (gated specs Skip when it is unset). +func ensureLibLoaded() { + libLoadOnce.Do(func() { + libLoadErr = loadCAPI(os.Getenv("DLLM_TEST_LIBRARY")) + }) +} + +// C-ABI binding smoke: drives the real libdllm.so against the tiny GGUF +// fixture from dllm.cpp (tests/fixtures/tiny_with_vocab.gguf). Gated on: +// +// DLLM_TEST_LIBRARY absolute path to libdllm.so +// DLLM_TEST_TINY_MODEL absolute path to tiny_with_vocab.gguf +var _ = Describe("C-ABI binding", func() { + BeforeEach(func() { + if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" { + Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the C-ABI binding smoke") + } + ensureLibLoaded() + Expect(libLoadErr).ToNot(HaveOccurred()) + }) + + It("binds the 9 symbols and round-trips the tiny model", func() { + Expect(cAbiVersion()).To(Equal(int32(1))) + + h := cLoad(os.Getenv("DLLM_TEST_TINY_MODEL"), "{}") + Expect(h).ToNot(BeZero(), "dllm_capi_load of the tiny fixture") + + // Tiny fixture vocab: "hello" tokenizes to ids [2,18] (bos prepended + // by the C side: vocab.add_bos). + toks, err := cTokenizeJSON(h, "hello") + Expect(err).ToNot(HaveOccurred()) + Expect(toks).To(Equal("[2,18]")) + + // Deterministic generation: an explicit non-negative seed seeds + // mt19937, so two identical calls must produce identical text. + out1, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`) + Expect(err).ToNot(HaveOccurred()) + Expect(out1).ToNot(BeEmpty()) + // Cancel with no call in flight is dropped: each generate resets the + // cancel flag on entry (header contract), so this must not affect + // the next call. Also binds the 9th symbol; safe on NULL too. + cCancel(h) + cCancel(0) + + out2, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`) + Expect(err).ToNot(HaveOccurred()) + Expect(out2).To(Equal(out1)) + + // Streaming variant: same opts, blocks arrive via the purego + // callback trampoline. The per-block detokenize can differ from the + // seamless full-text decode at block boundaries, so only assert that + // blocks arrived and were non-trivial, not byte equality with out1. + var blocks []string + var steps int + err = cGenerateStream(h, "hello", `{"n_predict":16,"seed":7}`, + func(text string) { blocks = append(blocks, text) }, + func(step, total int, preview string) { steps++ }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(blocks).ToNot(BeEmpty()) + Expect(steps).To(BeNumerically(">", 0)) + + // Load failure path: NULL ctx back, and last_error(NULL) returns the + // static NULL-ctx message (there is no ctx to carry the real reason). + bad := cLoad("/nonexistent/dllm-model.gguf", "{}") + Expect(bad).To(BeZero()) + Expect(cLastError(0)).ToNot(BeEmpty()) + + // Free is safe on a live handle and a NULL one (delete nullptr). + cFree(h) + cFree(0) + }) +}) + +// Ungated specs for the pure-Go helpers (no libdllm.so required). +var _ = Describe("buildOptsJSON", func() { + It("renders flat scalars as a JSON object", func() { + out, err := buildOptsJSON(map[string]any{ + "n_predict": 16, + "seed": int64(7), + "eb_t_min": 0.5, + "kv_cache": "auto", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(MatchJSON(`{"n_predict":16,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`)) + }) + + It("renders an empty object for no options", func() { + out, err := buildOptsJSON(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal("{}")) + }) + + It("rejects nested objects (the C-side scanner only reads flat scalars)", func() { + _, err := buildOptsJSON(map[string]any{"sampler": map[string]any{"seed": 1}}) + Expect(err).To(HaveOccurred()) + }) + + It("rejects arrays", func() { + _, err := buildOptsJSON(map[string]any{"stop": []string{"a"}}) + Expect(err).To(HaveOccurred()) + }) + + It("rejects booleans (the C-side scanner only understands numbers and strings)", func() { + _, err := buildOptsJSON(map[string]any{"flag": true}) + Expect(err).To(HaveOccurred()) + }) +}) + +var _ = Describe("goStringFromCPtr", func() { + It("copies a NUL-terminated buffer", func() { + buf := []byte("dllm\x00") + s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0]))) + Expect(s).To(Equal("dllm")) + }) + + It("returns the empty string for NULL", func() { + Expect(goStringFromCPtr(0)).To(Equal("")) + }) +}) diff --git a/backend/go/dllm/main.go b/backend/go/dllm/main.go new file mode 100644 index 000000000000..41d4368f2752 --- /dev/null +++ b/backend/go/dllm/main.go @@ -0,0 +1,85 @@ +package main + +// Started internally by LocalAI - one gRPC server per loaded model. +// +// Loads libdllm.so via purego and registers the 9-symbol flat C-ABI +// declared in dllm.cpp's include/dllm_capi.h (ABI v1). The library name can +// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY / +// WHISPER_LIBRARY convention in the sibling backends); the default looks +// for the .so next to this binary (run.sh puts the package dir on +// LD_LIBRARY_PATH). +import ( + "flag" + "fmt" + "os" + + "github.com/ebitengine/purego" + grpc "github.com/mudler/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +type LibFuncs struct { + FuncPtr any + Name string +} + +// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to +// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot +// drift. Shared with the test suite (ensureLibLoaded), which drives the +// bridge without the gRPC server. +// +// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we +// register those as uintptr so we get the raw pointer back and can call +// dllm_capi_free_string on it (purego's string return would copy and forget +// the original pointer, leaking it on every call). last_error returns a +// BORROWED pointer instead, so it is registered as a plain string: purego +// copies it and nothing must be freed. +func loadCAPI(libName string) error { + lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + return fmt.Errorf("dllm: dlopen %q: %w", libName, err) + } + + libFuncs := []LibFuncs{ + {&cppAbiVersion, "dllm_capi_abi_version"}, + {&cppLoad, "dllm_capi_load"}, + {&cppFree, "dllm_capi_free"}, + {&cppLastError, "dllm_capi_last_error"}, + {&cppFreeString, "dllm_capi_free_string"}, + {&cppTokenizeJSON, "dllm_capi_tokenize_json"}, + {&cppGenerate, "dllm_capi_generate"}, + {&cppGenerateStream, "dllm_capi_generate_stream"}, + {&cppCancel, "dllm_capi_cancel"}, + } + for _, lf := range libFuncs { + purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name) + } + return nil +} + +func main() { + libName := os.Getenv("DLLM_LIBRARY") + if libName == "" { + libName = "libdllm.so" + } + + if err := loadCAPI(libName); err != nil { + panic(err) + } + + // Hard-fail on an ABI mismatch: the flat-pointer bindings above would + // otherwise misbehave silently against a future libdllm.so. + if v := cAbiVersion(); v != dllmABIVersion { + panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion)) + } + fmt.Fprintf(os.Stderr, "[dllm] ABI=%d\n", cAbiVersion()) + + flag.Parse() + + if err := grpc.StartServer(*addr, &Dllm{}); err != nil { + panic(err) + } +} diff --git a/backend/go/dllm/package.sh b/backend/go/dllm/package.sh new file mode 100755 index 000000000000..5b2b8f8b935a --- /dev/null +++ b/backend/go/dllm/package.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# +# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/. +# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch +# detection) lands with the registration task, mirroring +# backend/go/whisper/package.sh. + +set -e + +CURDIR=$(dirname "$(realpath "$0")") + +mkdir -p "$CURDIR/package/lib" + +cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/" +cp -avf "$CURDIR/run.sh" "$CURDIR/package/" + +# libdllm.so + any soname symlinks, should upstream ever add them. +cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || { + echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2 + exit 1 +} + +echo "T1 package layout (full ldd walk lands with registration):" +ls -liah "$CURDIR/package/" "$CURDIR/package/lib/" diff --git a/backend/go/dllm/run.sh b/backend/go/dllm/run.sh new file mode 100755 index 000000000000..ab30af4b01c7 --- /dev/null +++ b/backend/go/dllm/run.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e + +CURDIR=$(dirname "$(realpath "$0")") + +export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}" + +# If a self-contained ld.so was packaged, route through it so the +# packaged libc / libstdc++ are used instead of the host's (matches the +# whisper / parakeet-cpp backends' runtime layout). +if [ -f "$CURDIR/lib/ld.so" ]; then + echo "Using lib/ld.so" + exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@" +fi + +exec "$CURDIR/dllm-grpc" "$@" From 294c04ae2f386d29798c9ff3f6f220a3d44e30ae Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 15:55:27 +0000 Subject: [PATCH 03/13] feat(dllm): gemma4 streaming parser emitting ChatDeltas Fragment-safe state machine (content / channel header / thought / tool-call / done) classifying model output into content, reasoning_content and tool_calls deltas. Tool-call payload decoder is a non-partial port of vLLM's gemma4 parser grammar; ~25 of its test cases are ported with citations, plus a 2-split invariance property over every byte position. Recursion depth-capped against model-generated deep nesting; marker constants shared with the renderer. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- backend/go/dllm/gemma4_parser.go | 562 +++++++++++++ backend/go/dllm/gemma4_parser_test.go | 592 +++++++++++++ backend/go/dllm/gemma4_renderer.go | 1026 +++++++++++++++++++++++ backend/go/dllm/gemma4_renderer_test.go | 347 ++++++++ 4 files changed, 2527 insertions(+) create mode 100755 backend/go/dllm/gemma4_parser.go create mode 100755 backend/go/dllm/gemma4_parser_test.go create mode 100755 backend/go/dllm/gemma4_renderer.go create mode 100755 backend/go/dllm/gemma4_renderer_test.go diff --git a/backend/go/dllm/gemma4_parser.go b/backend/go/dllm/gemma4_parser.go new file mode 100755 index 000000000000..cf381cea1cef --- /dev/null +++ b/backend/go/dllm/gemma4_parser.go @@ -0,0 +1,562 @@ +// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in +// arbitrary fragments (per committed diffusion block; a fragment can split +// anywhere, including mid-marker and mid-payload), is turned into +// pb.ChatDelta events (content / reasoning_content / tool_calls). +// +// Normative sources: +// - The chat template embedded at the top of gemma4_renderer.go ("tpl L" +// citations below refer to its numbered lines). The OUTPUT format mirrors +// what the template renders for assistant history: thought channels +// (<|channel>thought\n ... , tpl L240), tool calls +// (<|tool_call>call:name{...}, tpl L246-L257) and turn ends +// (, tpl L351). +// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker +// handling, the call:name{...} argument grammar and its decoder, ported +// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers, +// the "thought\n" role label, is_reasoning_end semantics). +// +// Initial state (derived from the generation prompt, tpl L356-L362, see +// RenderGemma4): +// - enable_thinking=false: the prompt ends with "<|turn>model\n" + +// "<|channel>thought\n" - an EMPTY thought channel, pre-opened +// AND pre-closed by the template. The model's output therefore starts in +// plain content. Use NewGemma4Parser(false). +// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model +// opens and closes its own thought channel in the OUTPUT +// ("<|channel>thought\n...reasoning...final answer", per the +// vLLM Gemma4ReasoningParser docstring). The parser still starts in +// content state - the channel markers in the output drive the switch. +// Use NewGemma4Parser(false) here too. +// - NewGemma4Parser(true) is for callers that pre-open the thought channel +// in the prompt themselves (appending "<|channel>thought\n" after the +// generation prompt to force thinking): the output then begins mid-thought +// and everything is reasoning until the first . +// +// State diagram (markers are consumed, never emitted): +// +// <|channel> \n (channel name dropped: the +// [content] --------------> [chan-header] ----> [thought] "thought\n" role +// ^ | (stray close: swallowed, label, stripped +// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does) +// ^ +// +----------------------------------------- [thought] +// ^ | <|tool_call> (implicit +// +-------------- [tool-call] <-------------------+ reasoning end, vLLM +// | <|tool_call> ^ is_reasoning_end) +// +-------------------+ +// [content]/[thought] --- ---> [done] (everything after is dropped) +// +// Buffering rules: +// - content/thought states hold back at most len(longest marker)-1 bytes: +// the longest tail that is still a proper prefix of a watched marker. +// Content is otherwise emitted immediately (no unbounded buffering). +// - the tool-call state buffers the whole payload until . This +// is unbounded in principle but bounded in practice by the model's +// diffusion canvas, and is required because the call:name{...} payload +// only becomes decodable (and trustworthy) once complete - the same +// reason vLLM's parser accumulates before parsing. +// - Close() flushes whatever is still held: partial markers come out as +// content/reasoning (per the state that held them); an unterminated +// channel header or tool-call payload is re-emitted RAW (including its +// opening marker) as content - malformed output is never silently +// dropped (mirrors vLLM extract_tool_calls returning the raw text as +// content when its regex does not match). +// +// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial +// payload on every token and streams argument-JSON diffs (its `partial=True` +// decoder mode plus withholding logic exist only for that). Our fragments are +// whole committed diffusion blocks, so each completed tool call is emitted +// once, as a single ToolCallDelta carrying index + id + name + the full +// arguments JSON - exactly the shape backend/python/vllm/backend.py emits +// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates. +package main + +import ( + "encoding/json" + "regexp" + "strconv" + "strings" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// gemma4CallRE is vLLM's tool_call_regex +// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}`, DOTALL) anchored to +// a single already-extracted payload: name charset [\w\-.], braces mandatory. +var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`) + +type g4State int + +const ( + g4Content g4State = iota + g4ChanHeader + g4Thought + g4ToolCall + g4Done +) + +// Markers watched per emitting state. A stray outside a tool +// call is deliberately NOT watched: it passes through verbatim, consistent +// with the malformed-payload fallback re-emitting it as content. +var ( + gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd} + gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd} +) + +type Gemma4Parser struct { + state g4State + // held is the per-state carry-over between Feed calls: a partial marker + // (content/thought), a partial channel header (chan-header) or the + // payload accumulated so far (tool-call). + held string + toolIdx int +} + +// NewGemma4Parser returns a parser positioned per the initial-state rules in +// the header comment: startInThought=true only when the caller pre-opened a +// thought channel in the prompt. +func NewGemma4Parser(startInThought bool) *Gemma4Parser { + state := g4Content + if startInThought { + state = g4Thought + } + return &Gemma4Parser{state: state} +} + +// Feed consumes the next output fragment and returns the deltas it completes. +func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta { + if text == "" || p.state == g4Done { + return nil + } + pending := p.held + text + p.held = "" + var em g4Emitter + for pending != "" { + switch p.state { + case g4Content, g4Thought: + markers := gemma4ContentMarkers + if p.state == g4Thought { + markers = gemma4ThoughtMarkers + } + idx, marker := findEarliestGemma4Marker(pending, markers) + if idx == -1 { + hold := gemma4MarkerHoldback(pending, markers) + p.emitText(&em, pending[:len(pending)-hold]) + p.held = pending[len(pending)-hold:] + pending = "" + continue + } + p.emitText(&em, pending[:idx]) + pending = pending[idx+len(marker):] + switch marker { + case gemma4ChannelOpen: + p.state = g4ChanHeader + case gemma4ChannelClose: + // In thought: channel ends. In content: stray close, + // swallowed (strip_thinking keeps both sides, tpl L148-L158). + p.state = g4Content + case gemma4ToolCallOpen: + p.state = g4ToolCall + case gemma4TurnEnd: + p.state = g4Done + } + case g4ChanHeader: + // The channel header is "\n"; the template only ever writes + // "thought" (tpl L240/L360) and the label is structural, so it is + // dropped, not emitted (vLLM strips the same "thought\n" prefix). + nl := strings.IndexByte(pending, '\n') + if nl == -1 { + p.held = pending + pending = "" + continue + } + pending = pending[nl+1:] + p.state = g4Thought + case g4ToolCall: + end := strings.Index(pending, gemma4ToolCallClose) + if end == -1 { + p.held = pending + pending = "" + continue + } + p.emitToolCall(&em, pending[:end]) + pending = pending[end+len(gemma4ToolCallClose):] + p.state = g4Content + case g4Done: + pending = "" + } + } + return em.deltas +} + +// Close flushes held-back partials. Incomplete structures (open channel +// header, unterminated tool payload) are re-emitted raw as content rather +// than dropped. The parser is finished afterwards. +func (p *Gemma4Parser) Close() []*pb.ChatDelta { + var em g4Emitter + switch p.state { + case g4Content: + em.content(p.held) + case g4Thought: + em.reasoning(p.held) + case g4ChanHeader: + em.content(gemma4ChannelOpen + p.held) + case g4ToolCall: + em.content(gemma4ToolCallOpen + p.held) + case g4Done: + } + p.held = "" + p.state = g4Done + return em.deltas +} + +func (p *Gemma4Parser) emitText(em *g4Emitter, s string) { + if p.state == g4Thought { + em.reasoning(s) + return + } + em.content(s) +} + +// emitToolCall decodes one complete <|tool_call>... payload. On a +// payload that does not match call:name{...} the raw text (markers included) +// is emitted as content, mirroring vLLM's extract_tool_calls fallback. +func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) { + m := gemma4CallRE.FindStringSubmatch(payload) + if m == nil { + em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose) + return + } + // Index-based ids: deterministic (the split-invariance property relies + // on it) and matching the call_ convention of pkg/grpc/rich_test.go; + // core only needs ids to be non-empty and unique within the response. + em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0)) + p.toolIdx++ +} + +// g4Emitter collects ChatDeltas; empty text events are dropped. +type g4Emitter struct { + deltas []*pb.ChatDelta +} + +func (e *g4Emitter) content(s string) { + if s != "" { + e.deltas = append(e.deltas, &pb.ChatDelta{Content: s}) + } +} + +func (e *g4Emitter) reasoning(s string) { + if s != "" { + e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s}) + } +} + +func (e *g4Emitter) tool(index int, id, name, argsJSON string) { + e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{ + Index: int32(index), + Id: id, + Name: name, + Arguments: argsJSON, + }}}) +} + +// findEarliestGemma4Marker returns the position and value of the first +// complete marker occurrence, or (-1, ""). +func findEarliestGemma4Marker(s string, markers []string) (int, string) { + best, bestMarker := -1, "" + for _, m := range markers { + if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) { + best, bestMarker = idx, m + } + } + return best, bestMarker +} + +// gemma4MarkerHoldback returns the length of the longest suffix of s that is +// a proper prefix of a watched marker - the only bytes that may still grow +// into a marker and therefore must not be emitted yet (bounded by the +// longest marker, so content is never buffered unboundedly). +func gemma4MarkerHoldback(s string, markers []string) int { + maxHold := 0 + for _, m := range markers { + if len(m)-1 > maxHold { + maxHold = len(m) - 1 + } + } + if len(s) < maxHold { + maxHold = len(s) + } + for k := maxHold; k >= 1; k-- { + tail := s[len(s)-k:] + for _, m := range markers { + if strings.HasPrefix(m, tail) { + return k + } + } + } + return 0 +} + +// --------------------------------------------------------------------------- +// call:name{...} argument decoder +// +// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array / +// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this +// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to +// raw content at Close), so vLLM's partial-withholding machinery +// (trailing-dot floats, withheld bare tails) is intentionally not ported. +// +// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147): +// +// args := pair (',' pair)* +// pair := key ':' value (keys unquoted, up to the first ':') +// value := string | object | array | bare +// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest) +// object := '{' args '}' (delimited strings skipped when +// array := '[' value,* ']' counting braces/brackets) +// bare := true | false | null/none/nil | number | bare-string +// +// Output is a JSON object/array string with keys in payload order (Python +// dict insertion order), built with HTML escaping off so payload text +// survives byte-for-byte. +// --------------------------------------------------------------------------- + +func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' } + +// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and +// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack +// overflow is a fatal process kill, not a recoverable error, so past the cap +// a nested body gracefully degrades to a JSON string of its raw text. +const gemma4MaxArgsDepth = 100 + +// decodeGemma4Args decodes one args body (the text between the outer braces +// of call:name{...}) into a JSON object string. depth is the current nesting +// level (0 at the payload root); see gemma4MaxArgsDepth. +func decodeGemma4Args(s string, depth int) string { + if depth > gemma4MaxArgsDepth { + return gemma4JSONString(s) + } + var b strings.Builder + b.WriteString("{") + first := true + pair := func(key, val string) { + if !first { + b.WriteString(",") + } + first = false + b.WriteString(gemma4JSONString(key)) + b.WriteString(":") + b.WriteString(val) + } + i, n := 0, len(s) + for i < n { + for i < n && (isGemma4Space(s[i]) || s[i] == ',') { + i++ + } + if i >= n { + break + } + keyStart := i + for i < n && s[i] != ':' { + i++ + } + if i >= n { + break // no ':' -> trailing junk, dropped (vLLM does the same) + } + key := strings.TrimSpace(s[keyStart:i]) + i++ // skip ':' + for i < n && isGemma4Space(s[i]) { + i++ + } + if i >= n { + pair(key, `""`) // "key:" with nothing after -> empty string + break + } + switch { + case strings.HasPrefix(s[i:], gemma4StringDelim): + i += len(gemma4StringDelim) + if end := strings.Index(s[i:], gemma4StringDelim); end == -1 { + pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest + i = n + } else { + pair(key, gemma4JSONString(s[i:i+end])) + i += end + len(gemma4StringDelim) + } + case s[i] == '{': + inner, next := scanGemma4Balanced(s, i, '{', '}') + pair(key, decodeGemma4Args(inner, depth+1)) + i = next + case s[i] == '[': + inner, next := scanGemma4Balanced(s, i, '[', ']') + pair(key, decodeGemma4Array(inner, depth+1)) + i = next + default: + valStart := i + for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' { + i++ + } + if i == valStart { + // No progress (value starts on a stray '}'/']'): abort on + // malformed input rather than loop, like vLLM. + i = n + continue + } + pair(key, decodeGemma4Bare(s[valStart:i])) + } + } + b.WriteString("}") + return b.String() +} + +// decodeGemma4Array decodes one array body (the text between '[' and ']') +// into a JSON array string. depth is the current nesting level; see +// gemma4MaxArgsDepth. +func decodeGemma4Array(s string, depth int) string { + if depth > gemma4MaxArgsDepth { + return gemma4JSONString(s) + } + var b strings.Builder + b.WriteString("[") + first := true + item := func(val string) { + if !first { + b.WriteString(",") + } + first = false + b.WriteString(val) + } + i, n := 0, len(s) + for i < n { + for i < n && (isGemma4Space(s[i]) || s[i] == ',') { + i++ + } + if i >= n { + break + } + switch { + case strings.HasPrefix(s[i:], gemma4StringDelim): + i += len(gemma4StringDelim) + if end := strings.Index(s[i:], gemma4StringDelim); end == -1 { + item(gemma4JSONString(s[i:])) + i = n + } else { + item(gemma4JSONString(s[i : i+end])) + i += end + len(gemma4StringDelim) + } + case s[i] == '{': + inner, next := scanGemma4Balanced(s, i, '{', '}') + item(decodeGemma4Args(inner, depth+1)) + i = next + case s[i] == '[': + inner, next := scanGemma4Balanced(s, i, '[', ']') + item(decodeGemma4Array(inner, depth+1)) + i = next + default: + valStart := i + for i < n && s[i] != ',' && s[i] != ']' { + i++ + } + if i == valStart { + i = n // no progress: abort on malformed input, like vLLM + continue + } + item(decodeGemma4Bare(s[valStart:i])) + } + } + b.WriteString("]") + return b.String() +} + +// scanGemma4Balanced scans a brace/bracket-balanced span starting at the +// opener s[start], skipping over <|"|>-delimited strings so structural +// characters inside them do not count (vLLM's depth scan). Returns the inner +// text and the index just past the closer; an unterminated span yields the +// rest of the string (the inner decoder still extracts what is there - this +// path is only reachable from genuinely malformed complete payloads). +func scanGemma4Balanced(s string, start int, open, close byte) (string, int) { + depth := 1 + i := start + 1 + innerStart := i + n := len(s) + for i < n && depth > 0 { + if strings.HasPrefix(s[i:], gemma4StringDelim) { + i += len(gemma4StringDelim) + if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 { + i = n + } else { + i += nd + len(gemma4StringDelim) + } + continue + } + switch s[i] { + case open: + depth++ + case close: + depth-- + } + i++ + } + if depth > 0 { + return s[innerStart:], n + } + return s[innerStart : i-1], i +} + +// decodeGemma4Bare maps an undelimited value to its JSON form: booleans, +// null aliases (null/none/nil, case-insensitive - the renderer writes +// Python None as "None", tpl L144-L145 via format_argument's else branch), +// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that +// fails parses as a bare string). +func decodeGemma4Bare(raw string) string { + v := strings.TrimSpace(raw) + if v == "" { + return `""` + } + if v == "true" || v == "false" { + return v + } + switch strings.ToLower(v) { + case "null", "none", "nil": + return "null" + } + if strings.Contains(v, ".") { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return formatGemma4Float(f) + } + } else if iv, err := strconv.ParseInt(v, 10, 64); err == nil { + return strconv.FormatInt(iv, 10) + } + return gemma4JSONString(v) +} + +// formatGemma4Float renders like Python's json.dumps(float): integral floats +// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments +// JSON matches what vLLM would have produced for the same payload. +func formatGemma4Float(f float64) string { + s := strconv.FormatFloat(f, 'g', -1, 64) + if !strings.ContainsAny(s, ".eE") { + s += ".0" + } + return s +} + +// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal +// would escape the angle brackets in "
" to \u003c / \u003e sequences; +// payload text should survive +// byte-for-byte, like Python's json.dumps(ensure_ascii=False)). +func gemma4JSONString(s string) string { + var sb strings.Builder + enc := json.NewEncoder(&sb) + enc.SetEscapeHTML(false) + if err := enc.Encode(s); err != nil { + // Unreachable for plain strings; fall back to default escaping + // rather than emitting invalid JSON. + b, mErr := json.Marshal(s) + if mErr != nil { + return `""` + } + return string(b) + } + // Encode appends a trailing newline. + return strings.TrimSuffix(sb.String(), "\n") +} diff --git a/backend/go/dllm/gemma4_parser_test.go b/backend/go/dllm/gemma4_parser_test.go new file mode 100755 index 000000000000..f3c243c0200a --- /dev/null +++ b/backend/go/dllm/gemma4_parser_test.go @@ -0,0 +1,592 @@ +package main + +// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events). +// +// Fixture provenance: +// - Entries marked "vLLM: " are direct ports of the named test from +// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the +// authoritative test-suite for the gemma4 tool-call wire format). The +// streaming tests' chunk lists are reused verbatim as Feed fragments. +// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array +// classes from the same file (non-partial mode only; this parser never +// decodes partial payloads, see the divergence note in gemma4_parser.go). +// - Channel/turn-marker expectations come from the chat template embedded +// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158 +// strip_thinking) and vLLM's Gemma4ReasoningParser +// (vllm/reasoning/gemma4_reasoning_parser.py). + +import ( + "encoding/json" + "fmt" + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core +// folds ToolCallDelta streams (pkg/functions/chat_deltas.go +// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments +// concatenate per index). Tests flatten through the same rules so they +// assert exactly what core will reconstruct. +type flatGemma4Tool struct { + id string + name string + args string +} + +func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) { + var content, reasoning strings.Builder + byIndex := map[int32]*flatGemma4Tool{} + maxIdx := int32(-1) + for _, d := range deltas { + content.WriteString(d.GetContent()) + reasoning.WriteString(d.GetReasoningContent()) + for _, tc := range d.GetToolCalls() { + acc, ok := byIndex[tc.GetIndex()] + if !ok { + acc = &flatGemma4Tool{} + byIndex[tc.GetIndex()] = acc + } + if tc.GetName() != "" { + acc.name = tc.GetName() + } + if tc.GetId() != "" { + acc.id = tc.GetId() + } + acc.args += tc.GetArguments() + if tc.GetIndex() > maxIdx { + maxIdx = tc.GetIndex() + } + } + } + var tools []flatGemma4Tool + for i := int32(0); i <= maxIdx; i++ { + if acc, ok := byIndex[i]; ok { + tools = append(tools, *acc) + } + } + return content.String(), reasoning.String(), tools +} + +type wantGemma4Tool struct { + name string + argsJSON string // compared with MatchJSON (key order irrelevant) +} + +type parseGemma4Case struct { + startInThought bool + fragments []string + wantContent string + wantReasoning string + wantTools []wantGemma4Tool +} + +func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta { + p := NewGemma4Parser(startInThought) + var all []*pb.ChatDelta + for _, f := range fragments { + all = append(all, p.Feed(f)...) + } + return append(all, p.Close()...) +} + +var _ = Describe("Gemma4Parser", func() { + DescribeTable("parses streamed gemma4 output into ChatDeltas", + func(c parseGemma4Case) { + content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments)) + Expect(content).To(Equal(c.wantContent)) + Expect(reasoning).To(Equal(c.wantReasoning)) + Expect(tools).To(HaveLen(len(c.wantTools))) + seenIDs := map[string]bool{} + for i, want := range c.wantTools { + Expect(tools[i].name).To(Equal(want.name), "tool %d name", i) + Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i) + Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i) + Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i) + seenIDs[tools[i].id] = true + } + }, + + // --- (1) pure content ------------------------------------------------- + // vLLM: test_no_tool_calls + Entry("pure content, single fragment", parseGemma4Case{ + fragments: []string{"Hello, how can I help you today?"}, + wantContent: "Hello, how can I help you today?", + }), + + // --- (2) thought -> final transition ---------------------------------- + // enable_thinking render: prompt ends at <|turn>model\n and the model + // opens/closes its own thought channel in the OUTPUT (vLLM + // Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n" + // role label after <|channel> is structural and must be stripped + // (vLLM _THOUGHT_PREFIX handling). + Entry("thought channel then final content", parseGemma4Case{ + fragments: []string{"<|channel>thought\nLet me think about this.\nThe answer is 42."}, + wantReasoning: "Let me think about this.\n", + wantContent: "The answer is 42.", + }), + + // --- (3) startInThought both ways ------------------------------------- + Entry("startInThought=true routes initial text to reasoning until ", parseGemma4Case{ + startInThought: true, + fragments: []string{"I am thinking hard.Done."}, + wantReasoning: "I am thinking hard.", + wantContent: "Done.", + }), + // A stray with no open channel is swallowed, matching the + // template's strip_thinking (tpl L148-L158: the marker is dropped, + // text on both sides is kept). + Entry("startInThought=false keeps the same text as content, stray swallowed", parseGemma4Case{ + startInThought: false, + fragments: []string{"I am thinking hard.Done."}, + wantContent: "I am thinking hard.Done.", + }), + + // --- (4) one tool call, full payload type zoo -------------------------- + Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{ + fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}`}, + wantTools: []wantGemma4Tool{{ + name: "complex_function", + argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`, + }}, + }), + + // --- (5) payload split across 3 fragments ------------------------------ + Entry("tool-call payload split across three fragments", parseGemma4Case{ + fragments: []string{ + "<|tool_call>call:get_weather{loc", + `ation:<|"|>Paris, Fra`, + `nce<|"|>}`, + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}}, + }), + + // --- (6) marker split across fragments ---------------------------------- + Entry("tool-call open marker split across fragments", parseGemma4Case{ + fragments: []string{ + "<|tool_ca", + `ll>call:get_weather{location:<|"|>London<|"|>}`, + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}}, + }), + Entry("channel open marker split across fragments", parseGemma4Case{ + fragments: []string{ + "<|chan", + "nel>thought\ndeep thoughtfinal", + }, + wantReasoning: "deep thought", + wantContent: "final", + }), + + // --- (7) trailing partial marker held, flushed by Close ----------------- + Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{ + fragments: []string{"Hello <|tool"}, + wantContent: "Hello <|tool", + }), + + // --- (8) malformed/incomplete payload -> content fallback --------------- + // vLLM: test_incomplete_tool_call (no end marker: the whole text stays + // content, never silently dropped). + Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`}, + wantContent: `<|tool_call>call:get_weather{location:<|"|>London`, + }), + Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{ + fragments: []string{"<|tool_call>oops no call syntax done"}, + wantContent: "<|tool_call>oops no call syntax done", + }), + + // --- (9) ends the turn ------------------------------------------- + Entry("text after is ignored, including later fragments", parseGemma4Case{ + fragments: []string{ + "beforeafter", + `more <|tool_call>call:f{}`, + }, + wantContent: "before", + }), + Entry(" inside a thought channel ends the turn", parseGemma4Case{ + startInThought: true, + fragments: []string{"thinkingignored"}, + wantReasoning: "thinking", + }), + + // --- (10) ported vLLM non-streaming cases --------------------------------- + // vLLM: test_single_tool_call + Entry("vLLM: test_single_tool_call", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_multiple_arguments + Entry("vLLM: test_multiple_arguments", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}}, + }), + // vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming + // extractor trims the content ("...you."); a streaming parser cannot + // retroactively trim already-emitted text, so the trailing space is + // kept (vLLM's own streaming path keeps it too, see + // test_streaming_text_before_tool_call which only checks a prefix). + Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{ + fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}`}, + wantContent: "Let me check the weather for you. ", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}}, + }), + // vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence) + Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<|tool_call>call:get_time{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{ + {name: "get_weather", argsJSON: `{"location":"London"}`}, + {name: "get_time", argsJSON: `{"location":"London"}`}, + }, + }), + // vLLM: test_nested_arguments + Entry("vLLM: test_nested_arguments", parseGemma4Case{ + fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}`}, + wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}}, + }), + // vLLM: test_tool_call_with_number_and_boolean + Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{ + fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}`}, + wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}}, + }), + // vLLM: test_hyphenated_function_name + Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_dotted_function_name + Entry("vLLM: test_dotted_function_name", parseGemma4Case{ + fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_no_arguments + Entry("vLLM: test_no_arguments", parseGemma4Case{ + fragments: []string{"<|tool_call>call:get_status{}"}, + wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}}, + }), + + // --- ported vLLM streaming cases (chunk lists reused as fragments) -------- + // vLLM: test_basic_streaming_single_tool + Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_weather{", + `location:<|"|>Paris`, + ", France", + `<|"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}}, + }), + // vLLM: test_streaming_multi_arg + Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_weather{", + `location:<|"|>Tokyo<|"|>,`, + `unit:<|"|>celsius<|"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}}, + }), + // vLLM: test_streaming_text_before_tool_call + Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{ + fragments: []string{ + "Let me check ", + "the weather. ", + "<|tool_call>", + "call:get_weather{", + `location:<|"|>London<|"|>}`, + "", + }, + wantContent: "Let me check the weather. ", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_streaming_numeric_args + Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:set_config{", + "count:42,", + "active:true}", + "", + }, + wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}}, + }), + // vLLM: test_streaming_boolean_split_across_chunks + Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:search{input:{all:tru", + "e}}", + "", + }, + wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}}, + }), + // vLLM: test_streaming_false_split_across_chunks + Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:set{flag:fals", + "e}", + "", + }, + wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}}, + }), + // vLLM: test_streaming_number_split_across_chunks + Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:set{count:4", + "2}", + "", + }, + wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}}, + }), + // vLLM: test_streaming_empty_args + Entry("vLLM: test_streaming_empty_args", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_status{}", + "", + }, + wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}}, + }), + // vLLM: test_streaming_split_delimiter_no_invalid_json (string + // delimiter <|"|> split across fragments must not leak fragments). + Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:todowrite{", + `content:<|"|>Buy milk<|`, + `"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}}, + }), + // vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call + Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_weather{", + `location:<|"|>Paris<|"|>}`, + "<", + "div>", + }, + wantContent: "
", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}}, + }), + // vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes + Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:write_file{", + `path:<|"|>index.html<|"|>,`, + `content:<|"|>` + "\n<", + `html lang="zh-CN">` + "\n<", + "head>\n <", + `meta charset="UTF-8">` + "\n <", + `meta name="viewport" content="width=device-width">` + "\n", + `<|"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{ + name: "write_file", + argsJSON: `{"path":"index.html","content":"\n\n\n \n \n"}`, + }}, + }), + // vLLM: test_streaming_single_chunk_complete_tool_call + Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{ + fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}}, + }), + // vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete + // calls in ONE fragment; both must come out with distinct indices) + Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{ + fragments: []string{ + `<|tool_call>call:get_weather{location:<|"|>London<|"|>}` + + `<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}`, + }, + wantTools: []wantGemma4Tool{ + {name: "get_weather", argsJSON: `{"location":"London"}`}, + {name: "get_time", argsJSON: `{"timezone":"GMT"}`}, + }, + }), + // vLLM: test_streaming_trailing_bare_bool_not_duplicated + Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:Edit{", + `file_path:<|"|>src/env.py<|"|>,`, + `old_string:<|"|>old_val<|"|>,`, + `new_string:<|"|>new_val<|"|>,`, + "replace_all:", + "false}", + "", + }, + wantTools: []wantGemma4Tool{{ + name: "Edit", + argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`, + }}, + }), + + // --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end: + // a tool_call token means reasoning is over) ----------------------------- + Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{ + startInThought: true, + fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}`}, + wantReasoning: "need the weather", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}}, + }), + + // --- (12) empty fragments are no-ops -------------------------------------- + Entry("empty fragments are no-ops", parseGemma4Case{ + fragments: []string{"", "Hello", "", "", " world", ""}, + wantContent: "Hello world", + }), + ) + + It("returns no deltas for an empty fragment and after Close", func() { + p := NewGemma4Parser(false) + Expect(p.Feed("")).To(BeEmpty()) + Expect(p.Feed("hi")).ToNot(BeEmpty()) + Expect(p.Close()).To(BeEmpty()) // nothing held back + // The parser is finished after Close: further input is dropped. + Expect(p.Feed("more")).To(BeEmpty()) + Expect(p.Close()).To(BeEmpty()) + }) + + It("generates index-based tool call ids (call_)", func() { + // Mirrors the index-based id convention of pkg/grpc/rich_test.go and + // keeps ids deterministic for the split-invariance property below. + deltas := parseGemma4Fragments(false, []string{ + `<|tool_call>call:a{}<|tool_call>call:b{}`, + }) + _, _, tools := flattenGemma4Deltas(deltas) + Expect(tools).To(HaveLen(2)) + Expect(tools[0].id).To(Equal("call_0")) + Expect(tools[1].id).To(Equal("call_1")) + }) + + // Property: for a fixed full output, EVERY 2-split position must yield + // exactly the same flattened result as the unsplit parse. This kills + // fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits). + DescribeTable("2-split fragment invariance", + func(startInThought bool, full string) { + refContent, refReasoning, refTools := flattenGemma4Deltas( + parseGemma4Fragments(startInThought, []string{full})) + for i := 0; i <= len(full); i++ { + content, reasoning, tools := flattenGemma4Deltas( + parseGemma4Fragments(startInThought, []string{full[:i], full[i:]})) + Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i)) + Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i)) + Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i)) + } + }, + Entry("thought + content + two tool calls + turn end", false, + "<|channel>thought\nPondering the request...\nSure - calling tools now. "+ + `<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}`+ + `<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}`+ + "Done.ignored tail"), + Entry("startInThought + tool call + trailing partial marker", true, + `Deep thoughtfinal answer <|tool_call>call:noop{} trailing <|tool`), + Entry("malformed payload fallback", false, + `pre <|tool_call>not a call post`), + ) +}) + +// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array +// (non-partial mode; the partial-withholding tests do not apply because this +// parser only ever decodes COMPLETE payloads, see gemma4_parser.go). +var _ = Describe("decodeGemma4Args", func() { + DescribeTable("decodes the gemma4 call syntax into JSON arguments", + func(in, wantJSON string) { + Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON)) + }, + // vLLM: test_empty_string / test_whitespace_only + Entry("empty string", "", `{}`), + Entry("whitespace only", " ", `{}`), + // vLLM: test_single_string_value + Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`), + // vLLM: test_string_value_with_comma + Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`), + // vLLM: test_multiple_string_values + Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`), + // vLLM: test_integer_value / test_float_value + Entry("integer value", "count:42", `{"count":42}`), + Entry("float value", "score:3.14", `{"score":3.14}`), + // vLLM: test_boolean_true / test_boolean_false + Entry("boolean true", "flag:true", `{"flag":true}`), + Entry("boolean false", "flag:false", `{"flag":false}`), + // vLLM: test_null_value (bare null must become JSON null, not "null") + Entry("null value", "param:null", `{"param":null}`), + // vLLM: test_mixed_types + Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`, + `{"name":"test","count":42,"active":true,"score":3.14}`), + // vLLM: test_nested_object + Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`), + // vLLM: test_array_of_strings + Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`), + // vLLM: test_unterminated_string (take everything after the delimiter) + Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`), + // vLLM: test_empty_value (key with no value after colon) + Entry("empty value", "key:", `{"key":""}`), + // vLLM: test_trailing_dot_float_partial_withheld, non-partial branch + // (trailing-dot floats parse normally outside streaming). + Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`), + ) + + It("terminates and yields valid JSON on malformed input", func() { + // vLLM: test_malformed_partial_array (the assertion there is only + // "returns a dict without hanging"; ours is "valid JSON object"). + out := decodeGemma4Args(":[t:[]", 0) + var v map[string]any + Expect(json.Unmarshal([]byte(out), &v)).To(Succeed()) + }) + + It("degrades nesting beyond the recursion cap to a string value", func() { + // 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual + // recursion would grow the stack with the model's output; a Go stack + // overflow is a fatal process kill, so levels past gemma4MaxArgsDepth + // must gracefully fall back to the raw inner text as a JSON string. + const depth = 200 + body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1) + out := decodeGemma4Args(body, 0) + var v map[string]any + Expect(json.Unmarshal([]byte(out), &v)).To(Succeed()) + levels := 0 + var cur any = v + for { + m, ok := cur.(map[string]any) + if !ok { + break + } + Expect(m).To(HaveKey("a")) + cur = m["a"] + levels++ + } + Expect(levels).To(Equal(gemma4MaxArgsDepth + 1)) + Expect(cur).To(BeAssignableToTypeOf("")) + Expect(cur).To(ContainSubstring("a:{")) + }) +}) + +var _ = Describe("decodeGemma4Array", func() { + DescribeTable("decodes gemma4 array bodies into JSON arrays", + func(in, wantJSON string) { + Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON)) + }, + // vLLM: test_string_array / test_empty_array / test_bare_values + Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`), + Entry("empty array", "", `[]`), + Entry("bare values", "42,true,3.14", `[42,true,3.14]`), + // vLLM: test_string_element_with_closing_bracket (a ']' inside a + // delimited string must not close the array) + Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`), + // vLLM: test_stray_closing_bracket (no-progress abort, keep prefix) + Entry("stray closing bracket", "42,]trailing", `[42]`), + ) +}) diff --git a/backend/go/dllm/gemma4_renderer.go b/backend/go/dllm/gemma4_renderer.go new file mode 100755 index 000000000000..868d98e4a8a1 --- /dev/null +++ b/backend/go/dllm/gemma4_renderer.go @@ -0,0 +1,1026 @@ +// Gemma4 (DiffusionGemma) chat template - NORMATIVE REFERENCE. +// +// The block comment below is the FULL `tokenizer.chat_template` extracted +// verbatim from diffusiongemma-26B-A4B-it-BF16.gguf via gguf-py's GGUFReader +// (17466 bytes, md5 8c34cf93c7a7815b3fdb300a009c4c17). Line numbers were +// added for citation only ("tpl L" throughout this file); the template +// text itself is untouched. RenderGemma4 replicates this template +// byte-for-byte (verified against jinja2 renders and the transformers +// fixtures in tests/models/diffusion_gemma/test_modeling_diffusion_gemma.py), +// with ONE deliberate exception: the leading `{{- bos_token -}}` is NOT +// emitted - see the BOS NOTE after the template. +// +/* + 1 {%- macro format_parameters(properties, required, filter_keys=false) -%} + 2 {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + 3 {%- set ns = namespace(found_first=false) -%} + 4 {%- for key, value in properties | dictsort -%} + 5 {%- set add_comma = false -%} + 6 {%- if not filter_keys or key not in standard_keys -%} + 7 {%- if ns.found_first %},{% endif -%} + 8 {%- set ns.found_first = true -%} + 9 {{ key }}:{ + 10 {%- if value['description'] -%} + 11 description:<|"|>{{ value['description'] }}<|"|> + 12 {%- set add_comma = true -%} + 13 {%- endif -%} + 14 {%- if value['type'] | upper == 'STRING' -%} + 15 {%- if value['enum'] -%} + 16 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 17 enum:{{ format_argument(value['enum']) }} + 18 {%- endif -%} + 19 {%- elif value['type'] | upper == 'ARRAY' -%} + 20 {%- if value['items'] is mapping and value['items'] -%} + 21 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 22 items:{ + 23 {%- set ns_items = namespace(found_first=false) -%} + 24 {%- for item_key, item_value in value['items'] | dictsort -%} + 25 {%- if item_value is not none -%} + 26 {%- if ns_items.found_first %},{% endif -%} + 27 {%- set ns_items.found_first = true -%} + 28 {%- if item_key == 'properties' -%} + 29 properties:{ + 30 {%- if item_value is mapping -%} + 31 {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + 32 {%- endif -%} + 33 } + 34 {%- elif item_key == 'required' -%} + 35 required:[ + 36 {%- for req_item in item_value -%} + 37 <|"|>{{- req_item -}}<|"|> + 38 {%- if not loop.last %},{% endif -%} + 39 {%- endfor -%} + 40 ] + 41 {%- elif item_key == 'type' -%} + 42 {%- if item_value is string -%} + 43 type:{{ format_argument(item_value | upper) }} + 44 {%- else -%} + 45 type:{{ format_argument(item_value | map('upper') | list) }} + 46 {%- endif -%} + 47 {%- else -%} + 48 {{ item_key }}:{{ format_argument(item_value) }} + 49 {%- endif -%} + 50 {%- endif -%} + 51 {%- endfor -%} + 52 } + 53 {%- endif -%} + 54 {%- endif -%} + 55 {%- if value['nullable'] %} + 56 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 57 nullable:true + 58 {%- endif -%} + 59 {%- if value['type'] | upper == 'OBJECT' -%} + 60 {%- if value['properties'] is defined and value['properties'] is mapping -%} + 61 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 62 properties:{ + 63 {{- format_parameters(value['properties'], value['required'] | default([])) -}} + 64 } + 65 {%- elif value is mapping -%} + 66 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 67 properties:{ + 68 {{- format_parameters(value, value['required'] | default([]), filter_keys=true) -}} + 69 } + 70 {%- endif -%} + 71 {%- if value['required'] -%} + 72 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 73 required:[ + 74 {%- for item in value['required'] | default([]) -%} + 75 <|"|>{{- item -}}<|"|> + 76 {%- if not loop.last %},{% endif -%} + 77 {%- endfor -%} + 78 ] + 79 {%- endif -%} + 80 {%- endif -%} + 81 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 82 type:<|"|>{{ value['type'] | upper }}<|"|>} + 83 {%- endif -%} + 84 {%- endfor -%} + 85 {%- endmacro -%} + 86 {%- macro format_function_declaration(tool_data) -%} + 87 declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + 88 {%- set params = tool_data['function']['parameters'] -%} + 89 {%- if params -%} + 90 ,parameters:{ + 91 {%- if params['properties'] -%} + 92 properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + 93 {%- endif -%} + 94 {%- if params['required'] -%} + 95 required:[ + 96 {%- for item in params['required'] -%} + 97 <|"|>{{- item -}}<|"|> + 98 {{- ',' if not loop.last -}} + 99 {%- endfor -%} + 100 ], + 101 {%- endif -%} + 102 {%- if params['type'] -%} + 103 type:<|"|>{{- params['type'] | upper -}}<|"|>} + 104 {%- endif -%} + 105 {%- endif -%} + 106 {%- if 'response' in tool_data['function'] -%} + 107 {%- set response_declaration = tool_data['function']['response'] -%} + 108 ,response:{ + 109 {%- if response_declaration['description'] -%} + 110 description:<|"|>{{- response_declaration['description'] -}}<|"|>, + 111 {%- endif -%} + 112 {%- if response_declaration['type'] | upper == 'OBJECT' -%} + 113 type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + 114 {%- endif -%} + 115 {%- endif -%} + 116 } + 117 {%- endmacro -%} + 118 {%- macro format_argument(argument, escape_keys=True) -%} + 119 {%- if argument is string -%} + 120 {{- '<|"|>' + argument + '<|"|>' -}} + 121 {%- elif argument is boolean -%} + 122 {{- 'true' if argument else 'false' -}} + 123 {%- elif argument is mapping -%} + 124 {{- '{' -}} + 125 {%- set ns = namespace(found_first=false) -%} + 126 {%- for key, value in argument | dictsort -%} + 127 {%- if ns.found_first %},{% endif -%} + 128 {%- set ns.found_first = true -%} + 129 {%- if escape_keys -%} + 130 {{- '<|"|>' + key + '<|"|>' -}} + 131 {%- else -%} + 132 {{- key -}} + 133 {%- endif -%} + 134 :{{- format_argument(value, escape_keys=escape_keys) -}} + 135 {%- endfor -%} + 136 {{- '}' -}} + 137 {%- elif argument is sequence -%} + 138 {{- '[' -}} + 139 {%- for item in argument -%} + 140 {{- format_argument(item, escape_keys=escape_keys) -}} + 141 {%- if not loop.last %},{% endif -%} + 142 {%- endfor -%} + 143 {{- ']' -}} + 144 {%- else -%} + 145 {{- argument -}} + 146 {%- endif -%} + 147 {%- endmacro -%} + 148 {%- macro strip_thinking(text) -%} + 149 {%- set ns = namespace(result='') -%} + 150 {%- for part in text.split('') -%} + 151 {%- if '<|channel>' in part -%} + 152 {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + 153 {%- else -%} + 154 {%- set ns.result = ns.result + part -%} + 155 {%- endif -%} + 156 {%- endfor -%} + 157 {{- ns.result | trim -}} + 158 {%- endmacro -%} + 159 + 160 {%- macro format_tool_response_block(tool_name, response) -%} + 161 {{- '<|tool_response>' -}} + 162 {%- if response is mapping -%} + 163 {{- 'response:' + tool_name + '{' -}} + 164 {%- for key, value in response | dictsort -%} + 165 {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + 166 {%- if not loop.last %},{% endif -%} + 167 {%- endfor -%} + 168 {{- '}' -}} + 169 {%- else -%} + 170 {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}} + 171 {%- endif -%} + 172 {{- '' -}} + 173 {%- endmacro -%} + 174 + 175 {%- set ns = namespace(prev_message_type=None) -%} + 176 {%- set loop_messages = messages -%} + 177 {{- bos_token -}} + 178 {#- Handle System/Tool Definitions Block -#} + 179 {%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + 180 {{- '<|turn>system\n' -}} + 181 {#- Inject Thinking token at the very top of the FIRST system turn -#} + 182 {%- if enable_thinking is defined and enable_thinking -%} + 183 {{- '<|think|>\n' -}} + 184 {%- set ns.prev_message_type = 'think' -%} + 185 {%- endif -%} + 186 {%- if messages[0]['role'] in ['system', 'developer'] -%} + 187 {%- if messages[0]['content'] is string -%} + 188 {{- messages[0]['content'] | trim -}} + 189 {%- elif messages[0]['content'] is sequence -%} + 190 {%- for item in messages[0]['content'] -%} + 191 {{- item['text'] | trim + ' '-}} + 192 {%- endfor -%} + 193 {%- endif -%} + 194 {%- set loop_messages = messages[1:] -%} + 195 {%- endif -%} + 196 {%- if tools -%} + 197 {%- for tool in tools %} + 198 {{- '<|tool>' -}} + 199 {{- format_function_declaration(tool) | trim -}} + 200 {{- '' -}} + 201 {%- endfor %} + 202 {%- set ns.prev_message_type = 'tool' -%} + 203 {%- endif -%} + 204 {{- '\n' -}} + 205 {%- endif %} + 206 + 207 {#- Pre-scan: find last user message index for reasoning guard -#} + 208 {%- set ns_turn = namespace(last_user_idx=-1) -%} + 209 {%- for i in range(loop_messages | length) -%} + 210 {%- if loop_messages[i]['role'] == 'user' -%} + 211 {%- set ns_turn.last_user_idx = i -%} + 212 {%- endif -%} + 213 {%- endfor -%} + 214 + 215 {#- Loop through messages -#} + 216 {%- for message in loop_messages -%} + 217 {%- if message['role'] != 'tool' -%} + 218 {%- set ns.prev_message_type = None -%} + 219 {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + 220 {#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#} + 221 {%- set prev_nt = namespace(role=None, found=false) -%} + 222 {%- if loop.index0 > 0 -%} + 223 {%- for j in range(loop.index0 - 1, -1, -1) -%} + 224 {%- if not prev_nt.found -%} + 225 {%- if loop_messages[j]['role'] != 'tool' -%} + 226 {%- set prev_nt.role = loop_messages[j]['role'] -%} + 227 {%- set prev_nt.found = true -%} + 228 {%- endif -%} + 229 {%- endif -%} + 230 {%- endfor -%} + 231 {%- endif -%} + 232 {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%} + 233 {%- if not continue_same_model_turn -%} + 234 {{- '<|turn>' + role + '\n' }} + 235 {%- endif -%} + 236 + 237 {#- Render reasoning/reasoning_content as thinking channel -#} + 238 {%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%} + 239 {%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%} + 240 {{- '<|channel>thought\n' + thinking_text + '\n' -}} + 241 {%- endif -%} + 242 + 243 {%- if message['tool_calls'] -%} + 244 {%- for tool_call in message['tool_calls'] -%} + 245 {%- set function = tool_call['function'] -%} + 246 {{- '<|tool_call>call:' + function['name'] + '{' -}} + 247 {%- if function['arguments'] is mapping -%} + 248 {%- set ns_args = namespace(found_first=false) -%} + 249 {%- for key, value in function['arguments'] | dictsort -%} + 250 {%- if ns_args.found_first %},{% endif -%} + 251 {%- set ns_args.found_first = true -%} + 252 {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + 253 {%- endfor -%} + 254 {%- elif function['arguments'] is string -%} + 255 {{- function['arguments'] -}} + 256 {%- endif -%} + 257 {{- '}' -}} + 258 {%- endfor -%} + 259 {%- set ns.prev_message_type = 'tool_call' -%} + 260 {%- endif -%} + 261 + 262 {%- set ns_tr_out = namespace(flag=false) -%} + 263 {%- if message.get('tool_responses') -%} + 264 {#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#} + 265 {%- for tool_response in message['tool_responses'] -%} + 266 {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}} + 267 {%- set ns_tr_out.flag = true -%} + 268 {%- set ns.prev_message_type = 'tool_response' -%} + 269 {%- endfor -%} + 270 {%- elif message.get('tool_calls') -%} + 271 {#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#} + 272 {%- set ns_tool_scan = namespace(stopped=false) -%} + 273 {%- for k in range(loop.index0 + 1, loop_messages | length) -%} + 274 {%- if ns_tool_scan.stopped -%} + 275 {%- elif loop_messages[k]['role'] != 'tool' -%} + 276 {%- set ns_tool_scan.stopped = true -%} + 277 {%- else -%} + 278 {%- set follow = loop_messages[k] -%} + 279 {#- Resolve tool_call_id to function name -#} + 280 {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%} + 281 {%- for tc in message['tool_calls'] -%} + 282 {%- if tc.get('id') == follow.get('tool_call_id') -%} + 283 {%- set ns_tname.name = tc['function']['name'] -%} + 284 {%- endif -%} + 285 {%- endfor -%} + 286 {#- Handle content as string or content-parts array -#} + 287 {%- set tool_body = follow.get('content') -%} + 288 {%- if tool_body is string -%} + 289 {{- format_tool_response_block(ns_tname.name, tool_body) -}} + 290 {%- elif tool_body is sequence and tool_body is not string -%} + 291 {%- set ns_txt = namespace(s='') -%} + 292 {%- for part in tool_body -%} + 293 {%- if part.get('type') == 'text' -%} + 294 {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%} + 295 {%- endif -%} + 296 {%- endfor -%} + 297 {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}} + 298 {%- for part in tool_body -%} + 299 {%- if part.get('type') == 'image' -%} + 300 {{- '<|image|>' -}} + 301 {%- elif part.get('type') == 'audio' -%} + 302 {{- '<|audio|>' -}} + 303 {%- elif part.get('type') == 'video' -%} + 304 {{- '<|video|>' -}} + 305 {%- endif -%} + 306 {%- endfor -%} + 307 {%- else -%} + 308 {{- format_tool_response_block(ns_tname.name, tool_body) -}} + 309 {%- endif -%} + 310 {%- set ns_tr_out.flag = true -%} + 311 {%- set ns.prev_message_type = 'tool_response' -%} + 312 {%- endif -%} + 313 {%- endfor -%} + 314 {%- endif -%} + 315 + 316 {%- set captured_content -%} + 317 {%- if message['content'] is string -%} + 318 {%- if role == 'model' -%} + 319 {{- strip_thinking(message['content']) -}} + 320 {%- else -%} + 321 {{- message['content'] | trim -}} + 322 {%- endif -%} + 323 {%- elif message['content'] is sequence -%} + 324 {%- for item in message['content'] -%} + 325 {%- if item['type'] == 'text' -%} + 326 {%- if role == 'model' -%} + 327 {{- strip_thinking(item['text']) -}} + 328 {%- else -%} + 329 {{- item['text'] | trim -}} + 330 {%- endif -%} + 331 {%- elif item['type'] == 'image' -%} + 332 {{- '<|image|>' -}} + 333 {%- set ns.prev_message_type = 'image' -%} + 334 {%- elif item['type'] == 'audio' -%} + 335 {{- '<|audio|>' -}} + 336 {%- set ns.prev_message_type = 'audio' -%} + 337 {%- elif item['type'] == 'video' -%} + 338 {{- '<|video|>' -}} + 339 {%- set ns.prev_message_type = 'video' -%} + 340 {%- endif -%} + 341 {%- endfor -%} + 342 {%- endif -%} + 343 {%- endset -%} + 344 + 345 {{- captured_content -}} + 346 {%- set has_content = captured_content | trim | length > 0 -%} + 347 + 348 {%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%} + 349 {{- '<|tool_response>' -}} + 350 {%- elif not (ns_tr_out.flag and not has_content) -%} + 351 {{- '\n' -}} + 352 {%- endif -%} + 353 {%- endif -%} + 354 {%- endfor -%} + 355 + 356 {%- if add_generation_prompt -%} + 357 {%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%} + 358 {{- '<|turn>model\n' -}} + 359 {%- if not enable_thinking | default(false) -%} + 360 {{- '<|channel>thought\n' -}} + 361 {%- endif -%} + 362 {%- endif -%} + 363 {%- endif -%}*/ + +// Every rule below cites "tpl L" line numbers from the numbered template +// text above. +// +// BOS NOTE (tpl L177 `{{- bos_token -}}`): the template emits because +// HF's apply_chat_template is expected to produce the FULL token stream. Our +// renderer feeds dllm_capi_generate, whose run_generate tokenizes with +// prepend_bos = vocab.add_bos (dllm.cpp src/capi.cpp:230-231), and gemma4 +// GGUFs carry add_bos=true - the C side prepends BOS itself. A literal +// "" here would therefore double it, so RenderGemma4 NEVER emits it. + +package main + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "sort" + "strings" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// Gemma4 marker vocabulary (special tokens referenced by the template). +const ( + gemma4StringDelim = `<|"|>` // string delimiter, tpl L119 etc. + gemma4TurnOpen = "<|turn>" // tpl L180/L234/L358 + // gemma4TurnEnd is the turn terminator as the MODEL emits it: the output + // parser (gemma4_parser.go) must trigger on the bare token, while the + // renderer appends the template's inter-turn newline (gemma4TurnClose). + gemma4TurnEnd = "" // tpl L204/L351 + gemma4TurnClose = gemma4TurnEnd + "\n" // tpl L204/L351 + gemma4ThinkToken = "<|think|>\n" // tpl L183 + gemma4ToolOpen = "<|tool>" // tpl L198 + gemma4ToolClose = "" // tpl L200 + gemma4ToolCallOpen = "<|tool_call>" // tpl L246 + gemma4ToolCallClose = "" // tpl L257 + gemma4ToolResponseOpen = "<|tool_response>" // tpl L161/L349 + gemma4ToolResponseClose = "" // tpl L172 + gemma4ChannelOpen = "<|channel>" // tpl L240/L360 + gemma4ChannelClose = "" // tpl L240/L360 + gemma4ThoughtChannel = gemma4ChannelOpen + "thought\n" +) + +// gemma4ToolCall is the wire shape LocalAI core puts into pb.Message.ToolCalls +// (core/schema/message.go ToolCall marshalled by Messages.ToProto): a JSON +// array of {"index":..,"id":..,"type":..,"function":{"name":..,"arguments":..}}. +type gemma4ToolCall struct { + ID string `json:"id"` + Function struct { + Name string `json:"name"` + // Arguments is a JSON-encoded string in the OpenAI wire format + // (schema.FunctionCall.Arguments is a string), but kept raw here so a + // template-native object also works. See renderGemma4ToolCallArgs. + Arguments json.RawMessage `json:"arguments"` + } `json:"function"` +} + +// RenderGemma4 renders an OpenAI-style message list (plus the request's tools +// JSON array) into the gemma4 prompt string, replicating the GGUF chat +// template above byte-for-byte - except for the leading (see BOS NOTE). +// +// enableThinking maps to the template's enable_thinking flag (ds4 convention: +// Metadata["enable_thinking"]); addGenerationPrompt to add_generation_prompt. +func RenderGemma4(msgs []*pb.Message, toolsJSON string, enableThinking bool, addGenerationPrompt bool) (string, error) { + // Fail loud on roles the template does not know about. The jinja would + // happily render any role as a generic turn; we reject instead so typos + // surface at the API boundary rather than as silent bad prompts. + for i, m := range msgs { + switch m.GetRole() { + case "system", "developer", "user", "assistant", "tool": + default: + return "", fmt.Errorf("dllm: gemma4 renderer: unknown role %q in message %d", m.GetRole(), i) + } + } + + tools, err := parseGemma4Tools(toolsJSON) + if err != nil { + return "", err + } + + var b strings.Builder + // ns.prev_message_type (tpl L175); "" stands for jinja None. + prev := "" + + // System/tool-definitions block (tpl L178-L205). + loopMsgs := msgs + firstIsSystem := len(msgs) > 0 && (msgs[0].GetRole() == "system" || msgs[0].GetRole() == "developer") + if enableThinking || len(tools) > 0 || firstIsSystem { + b.WriteString(gemma4TurnOpen + "system\n") // tpl L180 + if enableThinking { + // Thinking token at the very top of the first system turn, + // tpl L182-L185. NOTE: prev_message_type='think' (not used by + // the ending logic, mirrored for fidelity). + b.WriteString(gemma4ThinkToken) + prev = "think" + } + if firstIsSystem { + // First system/developer message is folded into this turn and + // consumed (loop_messages = messages[1:]), tpl L186-L195. + // pb.Message.Content is always a flattened string (core/schema/ + // message.go ToProto), so only the `is string` branch applies. + b.WriteString(strings.TrimSpace(msgs[0].GetContent())) + loopMsgs = msgs[1:] + } + if len(tools) > 0 { + // One <|tool>declaration:... block per tool, tpl L196-L203. + for _, t := range tools { + b.WriteString(gemma4ToolOpen) + b.WriteString(strings.TrimSpace(formatGemma4FunctionDeclaration(t))) + b.WriteString(gemma4ToolClose) + } + prev = "tool" + } + b.WriteString(gemma4TurnClose) // tpl L204 + } + + // Pre-scan: last user message index for the reasoning guard, tpl L207-L213. + lastUserIdx := -1 + for i, m := range loopMsgs { + if m.GetRole() == "user" { + lastUserIdx = i + } + } + + // Message loop, tpl L215-L354. role=tool messages are skipped here: they + // are rendered by the forward-scan from their assistant tool_calls turn. + // consumedTool tracks which of them a forward-scan actually rendered, so + // an orphan tool message (no preceding assistant tool_calls turn) fails + // loud below instead of vanishing from the prompt. + consumedTool := make([]bool, len(loopMsgs)) + for i, m := range loopMsgs { + if m.GetRole() == "tool" { + continue + } + prev = "" // tpl L218 + role := m.GetRole() + if role == "assistant" { + role = "model" // tpl L219 + } + + // Continuation: suppress duplicate <|turn>model when the previous + // non-tool message was also assistant, tpl L220-L235. + prevNonToolRole := "" + for j := i - 1; j >= 0; j-- { + if loopMsgs[j].GetRole() != "tool" { + prevNonToolRole = loopMsgs[j].GetRole() + break + } + } + if !(role == "model" && prevNonToolRole == "assistant") { + b.WriteString(gemma4TurnOpen + role + "\n") + } + + var toolCalls []gemma4ToolCall + if tc := m.GetToolCalls(); strings.TrimSpace(tc) != "" { + if err := json.Unmarshal([]byte(tc), &toolCalls); err != nil { + return "", fmt.Errorf("dllm: gemma4 renderer: message %d: invalid tool_calls JSON: %w", i, err) + } + } + + // reasoning_content renders as a thought channel ONLY on the + // tool-calling turn after the last user message, tpl L237-L241. + if rc := m.GetReasoningContent(); rc != "" && i > lastUserIdx && len(toolCalls) > 0 { + b.WriteString(gemma4ThoughtChannel + rc + "\n" + gemma4ChannelClose) + } + + // Tool calls: <|tool_call>call:name{args}, concatenated + // without separators, tpl L243-L260. + if len(toolCalls) > 0 { + for _, tc := range toolCalls { + b.WriteString(gemma4ToolCallOpen + "call:" + tc.Function.Name + "{") + b.WriteString(renderGemma4ToolCallArgs(tc.Function.Arguments)) + b.WriteString("}" + gemma4ToolCallClose) + } + prev = "tool_call" + } + + // Tool responses: pb has no legacy tool_responses field (tpl + // L263-L269 is unreachable through proto), so only the OpenAI + // forward-scan of consecutive role=tool messages applies, + // tpl L270-L313. + trOut := false + if len(toolCalls) > 0 { + for k := i + 1; k < len(loopMsgs); k++ { + if loopMsgs[k].GetRole() != "tool" { + break + } + follow := loopMsgs[k] + // Resolve tool_call_id to the function name; the message's + // own name (default 'unknown') is the fallback, tpl L278-L285. + name := follow.GetName() + if name == "" { + name = "unknown" + } + for _, tc := range toolCalls { + if tc.ID == follow.GetToolCallId() { + name = tc.Function.Name + } + } + // pb content is a flattened string: only the string body + // branch (tpl L287-L289) is reachable. + b.WriteString(formatGemma4ToolResponseBlock(name, follow.GetContent())) + consumedTool[k] = true + trOut = true + prev = "tool_response" + } + } + + // Captured content, tpl L316-L345. Model content gets thinking + // channels stripped (strip_thinking, tpl L148-L158); other roles are + // trimmed. pb content is a flattened string: the content-parts array + // branch (tpl L322-L342, incl. <|image|> markers) is unreachable. + var content string + if role == "model" { + content = stripGemma4Thinking(m.GetContent()) + } else { + content = strings.TrimSpace(m.GetContent()) + } + b.WriteString(content) + hasContent := strings.TrimSpace(content) != "" // tpl L346 + + // Turn ending, tpl L348-L353: a tool_calls turn with no rendered + // responses ends on an OPEN <|tool_response> (the runtime fills it); + // a turn whose only payload was tool responses stays open (no + // ); everything else closes the turn. + if prev == "tool_call" && !trOut { + b.WriteString(gemma4ToolResponseOpen) + } else if !(trOut && !hasContent) { + b.WriteString(gemma4TurnClose) + } + } + + // Fail loud on orphan role:tool messages no forward-scan consumed (e.g. a + // tool message with no preceding assistant tool_calls turn): the jinja + // would silently drop them from the prompt; we surface the bad request + // instead, same philosophy as the unknown-role check above. + for i, m := range loopMsgs { + if m.GetRole() == "tool" && !consumedTool[i] { + return "", fmt.Errorf("dllm: gemma4 renderer: orphan tool message %d: no preceding assistant tool_calls turn consumed it", i+(len(msgs)-len(loopMsgs))) + } + } + + // Generation prompt, tpl L356-L362: never reopened right after a + // tool_call/tool_response (the model continues its own open turn); the + // thought channel is pre-opened only when thinking is NOT enabled. + if addGenerationPrompt && prev != "tool_response" && prev != "tool_call" { + b.WriteString(gemma4TurnOpen + "model\n") + if !enableThinking { + b.WriteString(gemma4ThoughtChannel + gemma4ChannelClose) + } + } + return b.String(), nil +} + +// parseGemma4Tools decodes the request's OpenAI tools JSON array +// ([{"type":"function","function":{...}}]). Numbers are kept as json.Number +// so 42 / 3.5 / 1.0 render exactly as jinja renders the Python values. +// An empty/null/[] input is jinja-falsy (tpl L196 `{%- if tools -%}`). +func parseGemma4Tools(toolsJSON string) ([]map[string]any, error) { + s := strings.TrimSpace(toolsJSON) + if s == "" || s == "null" { + return nil, nil + } + v, err := decodeGemma4JSON([]byte(s)) + if err != nil { + return nil, fmt.Errorf("dllm: gemma4 renderer: invalid tools JSON: %w", err) + } + list, ok := v.([]any) + if !ok { + return nil, fmt.Errorf("dllm: gemma4 renderer: tools JSON is not an array") + } + tools := make([]map[string]any, 0, len(list)) + for i, e := range list { + m, ok := e.(map[string]any) + if !ok { + return nil, fmt.Errorf("dllm: gemma4 renderer: tools[%d] is not an object", i) + } + tools = append(tools, m) + } + return tools, nil +} + +// decodeGemma4JSON unmarshals with UseNumber so numeric literals survive +// verbatim ("1.0" stays "1.0", matching jinja's rendering of Python 1.0). +// Trailing non-whitespace after the first value is rejected: json.Decoder +// stops at the value boundary, and silently ignoring the rest would render +// a prompt from a prefix of what the caller sent. +func decodeGemma4JSON(data []byte) (any, error) { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.UseNumber() + var v any + if err := dec.Decode(&v); err != nil { + return nil, err + } + if err := dec.Decode(new(any)); !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("trailing data after JSON value") + } + return v, nil +} + +// renderGemma4ToolCallArgs renders the arguments between the braces of +// call:name{...}, tpl L247-L256: a mapping renders as dictsorted +// key:format_argument(value, escape_keys=False) pairs; a string renders +// verbatim; anything else renders nothing (mirroring the if/elif). +// +// DIVERGENCE NOTE: through pb the arguments arrive as a JSON-encoded string +// (OpenAI wire format; schema.FunctionCall.Arguments is a string). HF/vLLM +// parse that string into a dict before applying the template, so we do the +// same: a string that parses as a JSON object takes the mapping branch; only +// a non-object string falls back to the template's verbatim string branch. +// +// Also note: string values containing the literal <|"|> delimiter render +// unescaped (prompt-structure injection), byte-faithful to the jinja which +// has identical behavior. +func renderGemma4ToolCallArgs(raw json.RawMessage) string { + if len(bytes.TrimSpace(raw)) == 0 { + return "" + } + v, err := decodeGemma4JSON(raw) + if err != nil { + // Not JSON at all: treat like the template's string branch on the + // raw bytes (never drop caller data silently). + return string(raw) + } + if s, ok := v.(string); ok { + inner, err := decodeGemma4JSON([]byte(s)) + if err == nil { + if m, ok := inner.(map[string]any); ok { + v = m + } else { + return s // tpl L253-L254: string renders verbatim + } + } else { + return s + } + } + m, ok := v.(map[string]any) + if !ok { + return "" // tpl L247-L255: non-mapping, non-string renders nothing + } + parts := make([]string, 0, len(m)) + for _, k := range gemma4DictsortKeys(m) { + parts = append(parts, k+":"+formatGemma4Argument(m[k], false)) + } + return strings.Join(parts, ",") +} + +// formatGemma4Argument is the format_argument macro, tpl L118-L147: +// strings get <|"|> delimiters, booleans lower-case, mappings dictsorted +// {key:value} (keys delimited only when escape_keys), sequences [..], +// everything else verbatim (json.Number keeps its literal; null renders +// "None" exactly as jinja renders Python None). +func formatGemma4Argument(v any, escapeKeys bool) string { + switch a := v.(type) { + case string: + return gemma4StringDelim + a + gemma4StringDelim + case bool: + if a { + return "true" + } + return "false" + case map[string]any: + var b strings.Builder + b.WriteString("{") + for i, k := range gemma4DictsortKeys(a) { + if i > 0 { + b.WriteString(",") + } + if escapeKeys { + b.WriteString(gemma4StringDelim + k + gemma4StringDelim) + } else { + b.WriteString(k) + } + b.WriteString(":" + formatGemma4Argument(a[k], escapeKeys)) + } + b.WriteString("}") + return b.String() + case []any: + var b strings.Builder + b.WriteString("[") + for i, item := range a { + if i > 0 { + b.WriteString(",") + } + b.WriteString(formatGemma4Argument(item, escapeKeys)) + } + b.WriteString("]") + return b.String() + case json.Number: + return a.String() + case nil: + return "None" // jinja renders Python None as "None" + default: + return fmt.Sprint(a) + } +} + +// gemma4DictsortKeys mirrors jinja's dictsort default: case-insensitive sort +// by key. Distinct keys equal under lowering tie-break on the raw key for +// determinism (Go maps have no insertion order to preserve). +func gemma4DictsortKeys(m map[string]any) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + li, lj := strings.ToLower(keys[i]), strings.ToLower(keys[j]) + if li != lj { + return li < lj + } + return keys[i] < keys[j] + }) + return keys +} + +// gemma4Lookup is jinja's value['key'] on a value of unknown type: missing +// keys and non-mapping receivers yield Undefined (nil here). +func gemma4Lookup(v any, key string) any { + if m, ok := v.(map[string]any); ok { + return m[key] + } + return nil +} + +// gemma4Truthy is jinja truthiness for the decoded JSON value set. +func gemma4Truthy(v any) bool { + switch a := v.(type) { + case nil: + return false + case bool: + return a + case string: + return a != "" + case json.Number: + f, err := a.Float64() + return err != nil || f != 0 + case map[string]any: + return len(a) > 0 + case []any: + return len(a) > 0 + default: + return true + } +} + +// gemma4Str renders a scalar the way `{{ value }}` would (Undefined -> ""). +func gemma4Str(v any) string { + switch a := v.(type) { + case nil: + return "" + case string: + return a + case json.Number: + return a.String() + case bool: + if a { + return "True" // Python bool repr; only reachable via odd schemas + } + return "False" + default: + return fmt.Sprint(a) + } +} + +// gemma4TypeUpper is `value['type'] | upper` (Undefined | upper -> ""). +func gemma4TypeUpper(v any) string { + return strings.ToUpper(gemma4Str(gemma4Lookup(v, "type"))) +} + +// gemma4QuoteJoin renders required-style lists: <|"|>item<|"|> joined by ',' +// (tpl L36-L41, L72-L78, L96-L101). +func gemma4QuoteJoin(list []any) string { + parts := make([]string, 0, len(list)) + for _, item := range list { + parts = append(parts, gemma4StringDelim+gemma4Str(item)+gemma4StringDelim) + } + return strings.Join(parts, ",") +} + +// formatGemma4FunctionDeclaration is the format_function_declaration macro, +// tpl L86-L117: declaration:name{description:<|"|>..<|"|>[,parameters:{..}] +// [,response:{..}]}. Brace placement (incl. the parameters block being closed +// by the type clause) is replicated exactly, quirks and all. +func formatGemma4FunctionDeclaration(tool map[string]any) string { + fn, _ := tool["function"].(map[string]any) + var b strings.Builder + b.WriteString("declaration:" + gemma4Str(gemma4Lookup(fn, "name"))) + b.WriteString("{description:" + gemma4StringDelim + gemma4Str(gemma4Lookup(fn, "description")) + gemma4StringDelim) + params := gemma4Lookup(fn, "parameters") + if gemma4Truthy(params) { // tpl L89 + b.WriteString(",parameters:{") + if props, ok := gemma4Lookup(params, "properties").(map[string]any); ok && gemma4Truthy(gemma4Lookup(params, "properties")) { // tpl L92 + required, _ := gemma4Lookup(params, "required").([]any) + b.WriteString("properties:{" + formatGemma4Parameters(props, required, false) + "},") + } + if required, ok := gemma4Lookup(params, "required").([]any); ok && len(required) > 0 { // tpl L95 + b.WriteString("required:[" + gemma4QuoteJoin(required) + "],") + } + if gemma4Truthy(gemma4Lookup(params, "type")) { // tpl L102: closes the parameters block + b.WriteString("type:" + gemma4StringDelim + gemma4TypeUpper(params) + gemma4StringDelim + "}") + } + } + if fn != nil { // tpl L106: `'response' in tool_data['function']` + if resp, present := fn["response"]; present { + b.WriteString(",response:{") + if gemma4Truthy(gemma4Lookup(resp, "description")) { + b.WriteString("description:" + gemma4StringDelim + gemma4Str(gemma4Lookup(resp, "description")) + gemma4StringDelim + ",") + } + if gemma4TypeUpper(resp) == "OBJECT" { // tpl L112: closes the response block + b.WriteString("type:" + gemma4StringDelim + gemma4TypeUpper(resp) + gemma4StringDelim + "}") + } + } + } + b.WriteString("}") + return b.String() +} + +// formatGemma4Parameters is the format_parameters macro, tpl L1-L85. Each +// property renders as key:{[description][,enum|items][,nullable][,properties] +// [,required],type:<|"|>TYPE<|"|>} with the comma threading of the macro's +// add_comma flag. +func formatGemma4Parameters(properties map[string]any, required []any, filterKeys bool) string { + _ = required // tpl L1: passed through by callers but never read here + standardKeys := map[string]bool{ // tpl L2 + "description": true, "type": true, "properties": true, "required": true, "nullable": true, + } + var b strings.Builder + foundFirst := false + for _, key := range gemma4DictsortKeys(properties) { + if filterKeys && standardKeys[key] { // tpl L6 + continue + } + value := properties[key] + if foundFirst { + b.WriteString(",") + } + foundFirst = true + b.WriteString(key + ":{") // tpl L9 + addComma := false + comma := func() { + if addComma { + b.WriteString(",") + } else { + addComma = true + } + } + typeUpper := gemma4TypeUpper(value) + + if gemma4Truthy(gemma4Lookup(value, "description")) { // tpl L10-L13 + b.WriteString("description:" + gemma4StringDelim + gemma4Str(gemma4Lookup(value, "description")) + gemma4StringDelim) + addComma = true + } + switch typeUpper { + case "STRING": // tpl L14-L19 + if enum := gemma4Lookup(value, "enum"); gemma4Truthy(enum) { + comma() + b.WriteString("enum:" + formatGemma4Argument(enum, true)) + } + case "ARRAY": // tpl L20-L55 + if items, ok := gemma4Lookup(value, "items").(map[string]any); ok && len(items) > 0 { + comma() + b.WriteString("items:{") + itemsFound := false + for _, itemKey := range gemma4DictsortKeys(items) { + itemValue := items[itemKey] + if itemValue == nil { // tpl L25: `is not none` + continue + } + if itemsFound { + b.WriteString(",") + } + itemsFound = true + switch itemKey { + case "properties": // tpl L29-L34 + b.WriteString("properties:{") + if m, ok := itemValue.(map[string]any); ok { + itemsRequired, _ := items["required"].([]any) + b.WriteString(formatGemma4Parameters(m, itemsRequired, false)) + } + b.WriteString("}") + case "required": // tpl L35-L41 + list, _ := itemValue.([]any) + b.WriteString("required:[" + gemma4QuoteJoin(list) + "]") + case "type": // tpl L42-L47 + if s, ok := itemValue.(string); ok { + b.WriteString("type:" + formatGemma4Argument(strings.ToUpper(s), true)) + } else if list, ok := itemValue.([]any); ok { + upped := make([]any, len(list)) + for li, lv := range list { + upped[li] = strings.ToUpper(gemma4Str(lv)) + } + b.WriteString("type:" + formatGemma4Argument(upped, true)) + } + default: // tpl L48-L49 + b.WriteString(itemKey + ":" + formatGemma4Argument(itemValue, true)) + } + } + b.WriteString("}") + } + } + if gemma4Truthy(gemma4Lookup(value, "nullable")) { // tpl L56-L59 + comma() + b.WriteString("nullable:true") + } + if typeUpper == "OBJECT" { // tpl L60-L80 + if props, ok := gemma4Lookup(value, "properties").(map[string]any); ok { // tpl L61: defined and mapping + comma() + req, _ := gemma4Lookup(value, "required").([]any) + b.WriteString("properties:{" + formatGemma4Parameters(props, req, false) + "}") + } else if vm, ok := value.(map[string]any); ok { // tpl L66 + comma() + req, _ := gemma4Lookup(value, "required").([]any) + b.WriteString("properties:{" + formatGemma4Parameters(vm, req, true) + "}") + } + if req, ok := gemma4Lookup(value, "required").([]any); ok && len(req) > 0 { // tpl L72 + comma() + b.WriteString("required:[" + gemma4QuoteJoin(req) + "]") + } + } + comma() // tpl L81-L82: type is always last and closes the property + b.WriteString("type:" + gemma4StringDelim + typeUpper + gemma4StringDelim + "}") + } + return b.String() +} + +// formatGemma4ToolResponseBlock is the format_tool_response_block macro, +// tpl L160-L173, restricted to the string-response branch: pb tool messages +// carry flattened string content, so the mapping branch is unreachable. +func formatGemma4ToolResponseBlock(toolName, response string) string { + return gemma4ToolResponseOpen + + "response:" + toolName + "{value:" + formatGemma4Argument(response, false) + "}" + + gemma4ToolResponseClose +} + +// stripGemma4Thinking is the strip_thinking macro, tpl L148-L158: split on +// , drop everything from <|channel> onward in each part, trim. +func stripGemma4Thinking(text string) string { + var b strings.Builder + for _, part := range strings.Split(text, gemma4ChannelClose) { + if idx := strings.Index(part, gemma4ChannelOpen); idx >= 0 { + b.WriteString(part[:idx]) + } else { + b.WriteString(part) + } + } + return strings.TrimSpace(b.String()) +} diff --git a/backend/go/dllm/gemma4_renderer_test.go b/backend/go/dllm/gemma4_renderer_test.go new file mode 100755 index 000000000000..3600fbf7a07a --- /dev/null +++ b/backend/go/dllm/gemma4_renderer_test.go @@ -0,0 +1,347 @@ +package main + +// Renderer specs for RenderGemma4 against the canonical gemma4 chat template +// (see the normative template comment in gemma4_renderer.go). +// +// Fixture provenance: +// - "single user message" and "enable_thinking" are the EXACT expected +// decodes from transformers tests/models/diffusion_gemma/ +// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template +// and ..._with_thinking) with ONE difference: the transformers fixtures +// start with "" because apply_chat_template tokenizes the rendered +// text with add_bos. Our prompt goes through dllm_capi_generate, whose +// run_generate already tokenizes with prepend_bos = vocab.add_bos +// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must +// NOT emit a literal (it would double) and every expected string +// here drops that leading token. +// - All other expected strings were produced by rendering the verbatim +// GGUF template with jinja2 3.1.2 (bos_token="") and dropping the +// leading "" for the same reason. + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// Two-function tools array used by the tool fixtures (OpenAI wire shape, as +// LocalAI passes it through PredictOptions.Tools). +const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]` + +// The <|tool>... block the template renders for testToolsJSON inside +// the system turn (jinja2-verified). +const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}` + +// A single tool exercising the deep format_parameters branches: array items +// (string-typed and nested-array), nullable, enum+nullable, nested object +// properties/required, and a response declaration. +const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]` + +// jinja2-verified render of complexToolsJSON. Notable template quirks pinned +// here: nested array items go through format_argument with ESCAPED keys and +// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item +// types are uppercased; properties dictsort case-insensitively. +const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}` + +type renderGemma4Case struct { + msgs []*pb.Message + toolsJSON string + enableThinking bool + noGenerationPrompt bool // inverted so the zero value is the common case + expected string +} + +var _ = Describe("RenderGemma4", func() { + DescribeTable("renders the canonical gemma4 prompt", + func(c renderGemma4Case) { + out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal(c.expected)) + // The C-ABI generate prepends BOS itself: a literal + // anywhere in the rendered prompt would double-encode it. + Expect(out).ToNot(ContainSubstring("")) + }, + + // transformers fixture (test_diffusion_gemma_chat_template), sans : + // default thinking pre-opens an EMPTY thought channel in the + // generation prompt. + Entry("single user message, default (no thinking)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Write a long essay about Portugal."}, + }, + expected: "<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n<|channel>thought\n", + }), + + // transformers fixture (test_diffusion_gemma_chat_template_with_thinking), + // sans : a system turn carrying <|think|> and NO auto-opened + // thought channel. + Entry("enable_thinking=true", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Write a long essay about Portugal."}, + }, + enableThinking: true, + expected: "<|turn>system\n<|think|>\n\n<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n", + }), + + Entry("multi-turn user/assistant/user", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Hello, who are you?"}, + {Role: "assistant", Content: "I am Gemma, a helpful assistant."}, + {Role: "user", Content: "Tell me a joke."}, + }, + expected: "<|turn>user\nHello, who are you?\n<|turn>model\nI am Gemma, a helpful assistant.\n<|turn>user\nTell me a joke.\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L178-L195: a leading system message is folded into the system + // turn (trimmed) and consumed from the loop. + Entry("system message folds into the system turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "system", Content: "You are a pirate."}, + {Role: "user", Content: "Hello!"}, + }, + expected: "<|turn>system\nYou are a pirate.\n<|turn>user\nHello!\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L182-L185: <|think|> goes at the very top of the SAME system + // turn, before the system prompt text. + Entry("system message with enable_thinking shares the turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "system", Content: "You are a pirate."}, + {Role: "user", Content: "Hello!"}, + }, + enableThinking: true, + expected: "<|turn>system\n<|think|>\nYou are a pirate.\n<|turn>user\nHello!\n<|turn>model\n", + }), + + // tpl L196-L203: tool declarations render in the system turn, one + // <|tool>declaration:... block per tool, no separators. + Entry("tools array (two functions)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is the weather in Tokyo?"}, + }, + toolsJSON: testToolsJSON, + expected: "<|turn>system\n" + testToolsBlock + "\n<|turn>user\nWhat is the weather in Tokyo?\n<|turn>model\n<|channel>thought\n", + }), + + // format_parameters deep branches (tpl L1-L85) + response declaration + // (tpl L106-L116). + Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + }, + toolsJSON: complexToolsJSON, + expected: "<|turn>system\n" + complexToolsBlock + "\n<|turn>user\ngo\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L243-L313: assistant tool_calls render as + // <|tool_call>call:name{args}; the following role=tool + // message renders inline as <|tool_response>response:name{value:..} + // ; the model turn stays OPEN (no , no new + // generation prompt) so the model continues after the response. + Entry("assistant tool_calls + role=tool result", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is the weather in Tokyo?"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`}, + {Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."}, + }, + toolsJSON: testToolsJSON, + expected: "<|turn>system\n" + testToolsBlock + "\n<|turn>user\nWhat is the weather in Tokyo?\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}`, + }), + + // tpl L348-L349: a tool_calls turn with no rendered responses ends + // on an OPEN <|tool_response> marker for the runtime to fill, and + // add_generation_prompt adds nothing (tpl L357). + Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is the weather in Tokyo?"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`}, + }, + toolsJSON: testToolsJSON, + expected: "<|turn>system\n" + testToolsBlock + "\n<|turn>user\nWhat is the weather in Tokyo?\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<|tool_response>`, + }), + + // tpl L237-L241: reasoning_content renders as a thought channel only + // on a tool-calling turn after the last user message. + Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "weather?"}, + {Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`}, + {Role: "tool", ToolCallId: "c1", Content: "Sunny"}, + }, + expected: "<|turn>user\nweather?\n<|turn>model\n<|channel>thought\nI should call the tool\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}`, + }), + + // tpl L220-L235: the assistant answer following its own tool round + // continues the SAME model turn (no second <|turn>model). + Entry("tool round then final assistant answer then user", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "weather?"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`}, + {Role: "tool", ToolCallId: "c1", Content: "Sunny"}, + {Role: "assistant", Content: "It is sunny."}, + {Role: "user", Content: "thanks"}, + }, + expected: "<|turn>user\nweather?\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}` + "It is sunny.\n<|turn>user\nthanks\n<|turn>model\n<|channel>thought\n", + }), + + // format_argument (tpl L118-L147): numbers keep their JSON literal, + // booleans lower-case, nested maps have unquoted dictsorted keys, + // arrays bracketed; top-level args are dictsorted case-insensitively. + Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<|tool_response>`, + }), + + // jinja dictsort is case-insensitive: alpha sorts before Beta. + Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<|tool_response>", + }), + + // jinja renders Python None as "None" (round-trips through vLLM's + // parser, which lowers "none" back to null). + Entry("tool_call null argument renders as None", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{maybe:None}<|tool_response>", + }), + + Entry("tool_call empty arguments render empty braces", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{}<|tool_response>", + }), + + // tpl L253-L254: a non-object arguments string renders verbatim. + Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{just text}<|tool_response>", + }), + + // tpl L278-L285: unmatched tool_call_id falls back to the tool + // message's own name. + Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`}, + {Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n" + `<|tool_call>call:f{}<|tool_response>response:named_tool{value:<|"|>out<|"|>}`, + }), + + // strip_thinking (tpl L148-L158): historical assistant content loses + // its <|channel>... spans. + Entry("assistant content thinking channels are stripped", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "<|channel>thought\nsecret\nvisible answer"}, + {Role: "user", Content: "more"}, + }, + expected: "<|turn>user\nhi\n<|turn>model\nvisible answer\n<|turn>user\nmore\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L220-L235: consecutive assistant messages suppress the second + // <|turn>model (continuation), but each still closes with . + Entry("consecutive assistant messages continue the model turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "part one"}, + {Role: "assistant", Content: "part two"}, + {Role: "user", Content: "ok"}, + }, + expected: "<|turn>user\nhi\n<|turn>model\npart one\npart two\n<|turn>user\nok\n<|turn>model\n<|channel>thought\n", + }), + + Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + }, + noGenerationPrompt: true, + expected: "<|turn>user\nhi\n", + }), + ) + + Describe("error handling", func() { + It("fails loud on an unknown role", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "narrator", Content: "Meanwhile..."}, + }, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`)) + }) + + It("fails on invalid tools JSON", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, "{not json", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools JSON")) + }) + + It("fails on invalid tool_calls JSON", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "", ToolCalls: "{not json"}, + }, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tool_calls JSON")) + }) + + It("fails on an orphan tool message, naming its index", func() { + // A role:tool message with no preceding assistant tool_calls turn + // would be silently dropped by the jinja; we fail loud instead. + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"}, + }, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("orphan tool message 1")) + }) + + It("fails on trailing garbage after the tools JSON array", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, "[] junk", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools JSON")) + }) + + It("fails when the tools JSON is not an array", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, `{"type":"function"}`, false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools JSON is not an array")) + }) + + It("fails when a tools array element is not an object", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, `[42]`, false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools[0] is not an object")) + }) + + It("rejects a nil message via the unknown-role check", func() { + // Pins current behavior: pb getters are nil-safe, so a nil message + // reads as role "" and trips the fail-loud unknown-role guard. + _, err := RenderGemma4([]*pb.Message{nil}, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`)) + }) + }) +}) From 99184809fa5a310ceba9aa6b036caa806838ad75 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 16:14:37 +0000 Subject: [PATCH 04/13] feat(dllm): rich gRPC backend with ChatDelta streaming Implements PredictRich/PredictStreamRich (legacy methods delegate), TokenizeString, and Load over the purego binding. A single worker goroutine serializes all C calls per the dllm.cpp one-generate-per-ctx contract (cancel is the documented exception); an RWMutex guards Free against in-flight requests. Under use_tokenizer_template the gemma4 renderer and streaming parser own templating and ChatDelta extraction; raw-prompt mode passes through verbatim. enable_thinking is opt-in via request metadata (the gemma4 template treats thinking as opt-in). Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- backend/go/dllm/capi.go | 35 +- backend/go/dllm/dllm.go | 536 ++++++++++++++++++++++++++++++ backend/go/dllm/dllm_test.go | 627 +++++++++++++++++++++++++++++++++++ backend/go/dllm/main.go | 0 4 files changed, 1176 insertions(+), 22 deletions(-) mode change 100644 => 100755 backend/go/dllm/capi.go create mode 100755 backend/go/dllm/dllm.go mode change 100644 => 100755 backend/go/dllm/dllm_test.go mode change 100644 => 100755 backend/go/dllm/main.go diff --git a/backend/go/dllm/capi.go b/backend/go/dllm/capi.go old mode 100644 new mode 100755 index d8c0ca11e1ba..088bb6f26c3b --- a/backend/go/dllm/capi.go +++ b/backend/go/dllm/capi.go @@ -16,15 +16,12 @@ package main import ( "encoding/json" - "errors" "fmt" "sync" "sync/atomic" "unsafe" "github.com/ebitengine/purego" - "github.com/mudler/LocalAI/pkg/grpc/base" - pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) // dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written @@ -48,18 +45,6 @@ var ( cppCancel func(ctx uintptr) ) -// Dllm is the LocalAI gRPC backend over the dllm.cpp C-ABI. T1 ships only -// the binding scaffold; Load/PredictRich/PredictStreamRich (and the move to -// a dedicated dllm.go with the per-model worker goroutine) land in T4. -type Dllm struct { - base.Base -} - -// Load is not wired yet: the binding smoke drives the C functions directly. -func (d *Dllm) Load(opts *pb.ModelOptions) error { - return errors.New("dllm: model loading not implemented yet (backend wiring lands in T4)") -} - // cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION. func cAbiVersion() int32 { return cppAbiVersion() @@ -218,6 +203,11 @@ func cCancel(h uintptr) { // nested objects/arrays loudly; bools are rejected here too because the // scanner has no concept of them. Fail loud rather than let an option be // silently misread. +// +// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g. +// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys +// (kv_cache: auto|on|off) can contain those bytes today; if one ever does, +// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString. func buildOptsJSON(opts map[string]any) (string, error) { if len(opts) == 0 { return "{}", nil @@ -246,17 +236,18 @@ func buildOptsJSON(opts map[string]any) (string, error) { // caller owns, or a callback argument only valid during the invocation); // owning callers must free it via cppFreeString after the copy lands. // -// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr -// check, which can't distinguish a C-owned heap pointer from Go-managed -// memory. It is safe here: the pointer addresses C memory the Go GC neither -// tracks nor moves, and we dereference it immediately to copy the bytes out, -// the same pattern (and the same tolerated warning) as the parakeet-cpp and -// whisper backends. +// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check, +// which can't distinguish a C-owned heap pointer from Go-managed memory (the +// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting +// through &cptr below is equivalent at runtime and keeps plain `go vet` +// clean. It is safe either way: the pointer addresses C memory the Go GC +// neither tracks nor moves, and we dereference it immediately to copy the +// bytes out. func goStringFromCPtr(cptr uintptr) string { if cptr == 0 { return "" } - p := unsafe.Pointer(cptr) //nolint:govet // C-owned buffer, not Go-GC memory (see doc above) + p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above) n := 0 for *(*byte)(unsafe.Add(p, n)) != 0 { n++ diff --git a/backend/go/dllm/dllm.go b/backend/go/dllm/dllm.go new file mode 100755 index 000000000000..cd82ff0b3bb0 --- /dev/null +++ b/backend/go/dllm/dllm.go @@ -0,0 +1,536 @@ +package main + +// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models). +// +// Wiring overview: +// - Load opens the GGUF via dllm_capi_load and starts the per-model worker +// goroutine that serializes every C call (see submit). +// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the +// request carries raw messages (use_tokenizer_template), the backend owns +// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies +// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend. +// - The legacy Predict / PredictStream methods delegate to the rich pair +// (cloud-proxy precedent); the gRPC server prefers the rich path anyway. + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "unicode/utf8" + + "github.com/mudler/LocalAI/pkg/grpc/base" + "github.com/mudler/LocalAI/pkg/grpc/grpcerrors" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" +) + +// generator is the seam between the backend wiring and the dllm.cpp C-ABI: +// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON +// family, while tests substitute a fake to exercise prompt construction, +// parsing and serialization without libdllm.so. +type generator interface { + generate(prompt, optsJSON string) (string, error) + // generateStream invokes onBlock once per committed diffusion block, on + // the thread running the C call, before returning. + generateStream(prompt, optsJSON string, onBlock func(text string)) error + tokenizeJSON(text string) (string, error) + // cancel is the ONE entry point safe to call concurrently with an + // in-flight generate on the same ctx (dllm_capi.h: it only flips an + // atomic; everything else must be externally serialized per ctx). + cancel() + free() +} + +// capiGenerator is the production generator over one dllm_ctx handle. +type capiGenerator struct { + h uintptr +} + +func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) { + return cGenerate(g.h, prompt, optsJSON) +} + +func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error { + // on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is + // passed as nil for now: a future progress hook for the React UI can + // plumb it through without touching the C binding. + return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil) +} + +func (g *capiGenerator) tokenizeJSON(text string) (string, error) { + return cTokenizeJSON(g.h, text) +} + +func (g *capiGenerator) cancel() { + cCancel(g.h) +} + +func (g *capiGenerator) free() { + cFree(g.h) +} + +// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts +// one backend process per model). +type Dllm struct { + base.Base + + gen generator + // genOpts holds the model-level generation overrides parsed from + // ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes + // them per-generate, not per-load, so they are merged into every + // request's opts JSON (requestOptsJSON). + genOpts map[string]any + + // jobs is the per-model worker queue. dllm_capi.h requires every entry + // point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one + // ctx = one concurrent generate/tokenize; last_error is unsafe to read + // while a call is in flight). A single goroutine owning all C calls makes + // that contract structural instead of relying on lock discipline. + jobs chan func() + workerWG sync.WaitGroup + + // genMu guards gen against Free racing in-flight requests: requests hold + // the read lock for their full duration (they stay concurrent with each + // other - the worker still serializes the C calls), Free takes the write + // lock so it can only run when no request is in flight. + genMu sync.RWMutex +} + +func (d *Dllm) startWorker() { + d.jobs = make(chan func()) + d.workerWG.Add(1) + go func() { + defer d.workerWG.Done() + for job := range d.jobs { + job() + } + }() +} + +// submit runs job on the worker goroutine and waits for it to finish. +// Concurrent gRPC requests therefore queue up and execute one at a time +// against the single dllm_ctx. +func (d *Dllm) submit(job func()) { + done := make(chan struct{}) + d.jobs <- func() { + defer close(done) + job() + } + <-done +} + +// Load opens the GGUF and prepares the worker. Load-time engine parameters +// travel as the flat params JSON of dllm_capi_load; generation overrides +// from Options are stored for per-request opts JSON instead (the C-ABI has +// no per-load sampler state). +func (d *Dllm) Load(opts *pb.ModelOptions) error { + if d.gen != nil { + return errors.New("dllm: model already loaded") + } + + params := map[string]any{ + "n_gpu_layers": opts.GetNGPULayers(), + } + if opts.GetThreads() > 0 { + params["n_threads"] = opts.GetThreads() + } + if opts.GetContextSize() > 0 { + params["ctx_len"] = opts.GetContextSize() + } + paramsJSON, err := buildOptsJSON(params) + if err != nil { + return err + } + + d.genOpts = parseModelGenOpts(opts.GetOptions()) + + h := cLoad(opts.GetModelFile(), paramsJSON) + if h == 0 { + // No ctx exists on load failure, so last_error(NULL) only carries the + // static NULL-ctx message; the real reason is on the backend's stderr. + return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)", + opts.GetModelFile(), lastErrorOr(0, "unknown error")) + } + d.gen = &capiGenerator{h: h} + d.startWorker() + xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts) + return nil +} + +// Free releases the dllm ctx and stops the worker. Safe when never loaded. +// +// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the +// model-unload path around line 764) calls Free with no locking of its own, +// and base.Base provides none either. Without it a request racing Free would +// panic sending on the closed jobs channel - or worse, generate on a freed C +// ctx. Holding genMu until gen is nil also turns post-Free requests into a +// clean "model not loaded" error instead of a crash. +func (d *Dllm) Free() error { + d.genMu.Lock() + defer d.genMu.Unlock() + if d.gen == nil { + return nil + } + d.submit(d.gen.free) + close(d.jobs) + d.workerWG.Wait() + d.gen = nil + return nil +} + +// Cancel requests cancellation of the in-flight generate. It deliberately +// bypasses the worker queue: dllm_capi_cancel is the one call the C-ABI +// allows from any goroutine mid-generate (it only flips an atomic). +// +// LIMITATION: nothing invokes this on client disconnect today. The gRPC +// server (pkg/grpc/server.go) does not hand the request/stream context to +// Predict/PredictStreamRich, so a dropped HTTP client cannot reach the +// backend until that plumbing exists; the method is here so future server +// wiring (or an admin RPC) has something to call. Note dllm_capi.h's +// cancel-reset race: each generate resets the flag on entry, so a caller +// racing a new generate should re-issue Cancel. +func (d *Dllm) Cancel() { + if d.gen != nil { + d.gen.cancel() + } +} + +// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to +// the engine. Options is a shared free-form bag (other layers put their own +// entries there), so unknown keys are skipped with a warning, not an error. +var dllmGenOptKeys = map[string]bool{ + "blocks": true, + "kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3 +} + +// parseModelGenOpts parses "key:value" Options entries into the flat scalar +// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler +// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by +// first successful parse (int, then float, else string) to match the C +// scanner's number/string scalars. +func parseModelGenOpts(options []string) map[string]any { + out := map[string]any{} + for _, o := range options { + key, val, found := strings.Cut(o, ":") + if !found { + xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o) + continue + } + if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] { + xlog.Debug("dllm: ignoring unrecognized option", "key", key) + continue + } + out[key] = parseScalarOpt(val) + } + return out +} + +func parseScalarOpt(v string) any { + if iv, err := strconv.ParseInt(v, 10, 64); err == nil { + return iv + } + if fv, err := strconv.ParseFloat(v, 64); err == nil { + return fv + } + return v +} + +// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default +// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat +// template guards every thinking branch with `enable_thinking is defined and +// enable_thinking`, i.e. thinking is opt-in for this model family, and the +// no-thinking render pre-closes an empty thought channel that the OFF +// default must produce. +func metadataEnableThinking(opts *pb.PredictOptions) bool { + v := opts.GetMetadata()["enable_thinking"] + return v == "true" || v == "1" +} + +// buildPrompt resolves the prompt for a request. With use_tokenizer_template +// and raw messages the backend owns templating (RenderGemma4) and the output +// is in the known gemma4 format, so parse=true. Without it the caller +// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or +// a bare completion): the prompt passes through verbatim and the output is +// NOT gemma4-parsed - it is emitted as plain content and the Go side's +// extraction applies, as for any non-autoparsing backend. +func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) { + if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 { + prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true) + return prompt, true, err + } + return opts.GetPrompt(), false, nil +} + +// requestOptsJSON merges the model-level overrides with the request's +// sampling fields into the flat opts JSON for one generate call. +func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) { + m := make(map[string]any, len(d.genOpts)+2) + for k, v := range d.genOpts { + m[k] = v + } + if n := opts.GetTokens(); n > 0 { + // The engine rounds n_predict UP to a whole number of diffusion + // blocks (the canvas is denoised block-wise), so the completion may + // run slightly past the requested budget. Tokens==0 omits the key so + // the engine's GGUF-metadata default applies (the C-ABI documents + // per-key defaults; no hardcoded 256 like ds4's grpc-server). + m["n_predict"] = n + } + if s := opts.GetSeed(); s > 0 { + // The engine seeds mt19937 with explicit non-negative seeds. Seed<=0 + // is omitted: proto3 cannot distinguish 0 from unset, and negative + // values conventionally mean "random" across LocalAI backends. + m["seed"] = s + } + return buildOptsJSON(m) +} + +// prepareRequest is the shared prologue of the rich methods: resolve the +// prompt (and whether the output gets gemma4-parsed) and build the per-call +// opts JSON. +func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) { + prompt, parse, err = buildPrompt(opts) + if err != nil { + return "", false, "", err + } + optsJSON, err = d.requestOptsJSON(opts) + if err != nil { + return "", false, "", err + } + return prompt, parse, optsJSON, nil +} + +// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary +// detokenization and byte-fallback tokens can produce invalid UTF-8, and +// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so +// every string destined for a Reply/ChatDelta must pass through here (or +// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely +// undecodable: replace with U+FFFD rather than crash the stream. +func sanitizeUTF8(s string) string { + if utf8.ValidString(s) { + return s + } + return strings.ToValidUTF8(s, "�") +} + +// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte +// (1 for bytes that can never lead a multi-byte sequence, so they are never +// held back and fall through to sanitizeUTF8's replacement). +func utf8SeqLen(b byte) int { + switch { + case b&0xE0 == 0xC0: + return 2 + case b&0xF0 == 0xE0: + return 3 + case b&0xF8 == 0xF0: + return 4 + default: + return 1 + } +} + +// splitValidUTF8 prepends the previous block's carry to the new block and +// splits the result into text safe to emit now and a trailing INCOMPLETE +// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block: +// the per-block detokenize can split a multi-byte character across block +// boundaries (llama.cpp's grpc-server holds back the same way). Only a +// suffix that can still become a valid rune is withheld; bytes that are +// already undecodable are replaced immediately so the carry stays bounded. +func splitValidUTF8(carry, block string) (emit, newCarry string) { + s := carry + block + cut := len(s) + for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- { + b := s[i] + if b < utf8.RuneSelf { + break // ASCII: everything before the tail scan is complete + } + if !utf8.RuneStart(b) { + continue // continuation byte: keep looking for its leading byte + } + // Leading byte: hold the sequence back iff it declares more bytes + // than the stream has produced so far (it may complete next block). + if utf8SeqLen(b) > len(s)-i { + cut = i + } + break + } + return sanitizeUTF8(s[:cut]), s[cut:] +} + +// PredictRich is the non-streaming inference path (grpc.AIModelRich). +// Returns one Reply whose Message is the aggregated assistant content and +// whose ChatDeltas carry the parsed content/reasoning/tool-call events. +func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) { + d.genMu.RLock() + defer d.genMu.RUnlock() + if d.gen == nil { + return nil, grpcerrors.ModelNotLoaded("dllm") + } + prompt, parse, optsJSON, err := d.prepareRequest(opts) + if err != nil { + return nil, err + } + + var out string + var genErr error + d.submit(func() { + out, genErr = d.gen.generate(prompt, optsJSON) + }) + if genErr != nil { + return nil, genErr + } + // Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings + // must be valid or grpc-go fails the whole reply at marshal time. + out = sanitizeUTF8(out) + + if !parse { + // Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt). + return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil + } + + // The prompt renders with add_generation_prompt; both thinking modes + // leave the model starting in content state (see the Gemma4Parser header + // comment), hence NewGemma4Parser(false). + parser := NewGemma4Parser(false) + if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil { + return reply, nil + } + // Everything was markers (or out was empty): an empty but non-nil Reply. + return &pb.Reply{}, nil +} + +// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one +// Reply per committed diffusion block that produced deltas. Per the +// interface contract the channel is only sent into here - the gRPC server +// closes it after this returns (opposite to legacy PredictStream). +func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error { + d.genMu.RLock() + defer d.genMu.RUnlock() + if d.gen == nil { + return grpcerrors.ModelNotLoaded("dllm") + } + prompt, parse, optsJSON, err := d.prepareRequest(opts) + if err != nil { + return err + } + + var parser *Gemma4Parser + if parse { + parser = NewGemma4Parser(false) + } + // emit runs inside onBlock, i.e. on the thread driving the C generate. + // Sending on results can block on a slow consumer, but the server-side + // pump (pkg/grpc/server.go PredictStream) drains continuously and drops + // undeliverable sends, so this backpressure is brief and bounded - and + // pausing the diffusion loop under it is the desired behavior anyway. + emit := func(text string) { + if !parse { + if text != "" { + results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}} + } + return + } + deltas := parser.Feed(text) + if reply := replyFromDeltas(deltas); reply != nil { + results <- reply + } + } + // onBlock guards emit (and through it the parser) against invalid UTF-8: + // a multi-byte character split across block boundaries is held back until + // it completes (see splitValidUTF8), so proto3 marshaling never fails. + var carry string + onBlock := func(block string) { + var text string + text, carry = splitValidUTF8(carry, block) + emit(text) + } + + var genErr error + d.submit(func() { + genErr = d.gen.generateStream(prompt, optsJSON, onBlock) + }) + if genErr != nil { + return genErr + } + if carry != "" { + // The stream ended mid-sequence: the held-back bytes can no longer + // complete, so flush them through the U+FFFD last resort. + emit(sanitizeUTF8(carry)) + } + if parse { + if reply := replyFromDeltas(parser.Close()); reply != nil { + results <- reply + } + } + return nil +} + +// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply, +// or nil when the batch is empty (markers consumed, nothing emitted yet). +// Message mirrors the batch's content text so legacy chan-string consumers +// see exactly the displayed tokens. +func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply { + if len(deltas) == 0 { + return nil + } + var content strings.Builder + for _, delta := range deltas { + content.WriteString(delta.GetContent()) + } + return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas} +} + +// Predict is the legacy (string, error) signature; the gRPC server prefers +// PredictRich, this exists for non-rich callers (cloud-proxy precedent). +func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) { + reply, err := d.PredictRich(opts) + if err != nil { + return "", err + } + return string(reply.GetMessage()), nil +} + +// PredictStream is the legacy chan-string path: rich replies reduced to +// their content text. Note the inverted channel ownership - the LEGACY +// contract requires the impl to close the channel. +func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error { + defer close(results) + richCh := make(chan *pb.Reply) + errCh := make(chan error, 1) + go func() { + errCh <- d.PredictStreamRich(opts, richCh) + close(richCh) + }() + for reply := range richCh { + if msg := reply.GetMessage(); len(msg) > 0 { + results <- string(msg) + } + } + return <-errCh +} + +// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C +// side prepends bos per the vocab) and decodes the returned id array. +func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { + d.genMu.RLock() + defer d.genMu.RUnlock() + if d.gen == nil { + return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm") + } + var out string + var tokErr error + d.submit(func() { + out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt()) + }) + if tokErr != nil { + return pb.TokenizationResponse{}, tokErr + } + var tokens []int32 + if err := json.Unmarshal([]byte(out), &tokens); err != nil { + return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err) + } + return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil +} diff --git a/backend/go/dllm/dllm_test.go b/backend/go/dllm/dllm_test.go old mode 100644 new mode 100755 index a6f1e569775c..22ef767cd654 --- a/backend/go/dllm/dllm_test.go +++ b/backend/go/dllm/dllm_test.go @@ -1,13 +1,19 @@ package main import ( + "errors" "os" + "runtime" "sync" "testing" + "time" + "unicode/utf8" "unsafe" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) func TestDllm(t *testing.T) { @@ -131,10 +137,59 @@ var _ = Describe("buildOptsJSON", func() { }) }) +var _ = Describe("splitValidUTF8", func() { + It("holds back a trailing incomplete sequence and completes it next block", func() { + emit, carry := splitValidUTF8("", "caf\xe2") + Expect(emit).To(Equal("caf")) + Expect(carry).To(Equal("\xe2")) + + emit, carry = splitValidUTF8(carry, "\x82") + Expect(emit).To(BeEmpty()) + Expect(carry).To(Equal("\xe2\x82")) + + emit, carry = splitValidUTF8(carry, "\xac!") + Expect(emit).To(Equal("€!")) + Expect(carry).To(BeEmpty()) + }) + + It("holds back up to 3 bytes of a 4-byte sequence", func() { + emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte + Expect(emit).To(Equal("x")) + Expect(carry).To(Equal("\xf0\x9f\x98")) + + emit, carry = splitValidUTF8(carry, "\x80") + Expect(emit).To(Equal("😀")) + Expect(carry).To(BeEmpty()) + }) + + It("replaces undecodable bytes immediately instead of carrying them", func() { + // A mid-string invalid byte can never complete: carrying it would let + // the carry grow unboundedly, so it is substituted on the spot. + emit, carry := splitValidUTF8("", "a\xe2bc") + Expect(emit).To(Equal("a�bc")) + Expect(carry).To(BeEmpty()) + + // Orphan continuation bytes at the end have no leading byte to wait + // for either. + emit, carry = splitValidUTF8("", "a\x82") + Expect(emit).To(Equal("a�")) + Expect(carry).To(BeEmpty()) + }) + + It("passes pure ASCII and complete UTF-8 through untouched", func() { + emit, carry := splitValidUTF8("", "héllo €") + Expect(emit).To(Equal("héllo €")) + Expect(carry).To(BeEmpty()) + }) +}) + var _ = Describe("goStringFromCPtr", func() { It("copies a NUL-terminated buffer", func() { buf := []byte("dllm\x00") s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0]))) + // The uintptr round-trip hides buf from the GC's liveness analysis; + // keep it reachable until after the copy. + runtime.KeepAlive(buf) Expect(s).To(Equal("dllm")) }) @@ -142,3 +197,575 @@ var _ = Describe("goStringFromCPtr", func() { Expect(goStringFromCPtr(0)).To(Equal("")) }) }) + +// --------------------------------------------------------------------------- +// Backend wiring (T4): fake-generator specs, no libdllm.so required. +// --------------------------------------------------------------------------- + +type fakeGenCall struct { + prompt string + optsJSON string +} + +// fakeGen implements generator in-process. It records every call (prompt + +// opts JSON), tracks concurrent in-flight calls to prove worker +// serialization, and replays canned output (out for generate/tokenize, +// blocks for generateStream). +type fakeGen struct { + mu sync.Mutex + calls []fakeGenCall + inFlight int + maxInFlight int + + out string + blocks []string + err error + delay time.Duration +} + +func (f *fakeGen) begin(prompt, optsJSON string) { + f.mu.Lock() + defer f.mu.Unlock() + f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON}) + f.inFlight++ + if f.inFlight > f.maxInFlight { + f.maxInFlight = f.inFlight + } +} + +func (f *fakeGen) end() { + f.mu.Lock() + defer f.mu.Unlock() + f.inFlight-- +} + +func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) { + f.mu.Lock() + defer f.mu.Unlock() + return append([]fakeGenCall(nil), f.calls...), f.maxInFlight +} + +func (f *fakeGen) generate(prompt, optsJSON string) (string, error) { + f.begin(prompt, optsJSON) + defer f.end() + if f.delay > 0 { + time.Sleep(f.delay) + } + return f.out, f.err +} + +func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error { + f.begin(prompt, optsJSON) + defer f.end() + if f.err != nil { + return f.err + } + for _, b := range f.blocks { + onBlock(b) + } + return nil +} + +func (f *fakeGen) tokenizeJSON(text string) (string, error) { + f.begin(text, "") + defer f.end() + return f.out, f.err +} + +func (f *fakeGen) cancel() {} +func (f *fakeGen) free() {} + +// newTestDllm assembles a backend around a fake generator (bypassing Load, +// which needs libdllm.so) and registers cleanup of the worker goroutine. +func newTestDllm(g generator, genOpts map[string]any) *Dllm { + d := &Dllm{gen: g, genOpts: genOpts} + d.startWorker() + DeferCleanup(func() { Expect(d.Free()).To(Succeed()) }) + return d +} + +// drainReplies empties ch without blocking, failing the spec if the channel +// was closed (PredictStreamRich must NOT close it - interface.go contract). +// Size ch above the expected reply count: an overflow deadlocks the spec on +// the producer's send instead of failing it. +func drainReplies(ch chan *pb.Reply) []*pb.Reply { + var out []*pb.Reply + for { + select { + case r, ok := <-ch: + if !ok { + Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)") + } + expectValidUTF8Reply(r) + out = append(out, r) + default: + return out + } + } +} + +// expectValidUTF8Reply is the blanket guard for the proto3 marshaling +// contract: grpc-go rejects any string field carrying invalid UTF-8, so every +// reply field that ends up in a proto string must validate. +func expectValidUTF8Reply(r *pb.Reply) { + GinkgoHelper() + Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8") + for _, delta := range r.GetChatDeltas() { + Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8") + Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8") + for _, tc := range delta.GetToolCalls() { + Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8") + Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8") + } + } +} + +var _ = Describe("Dllm backend wiring", func() { + Describe("PredictRich", func() { + It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() { + fake := &fakeGen{out: "<|channel>thought\nponderingThe answer."} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}}, + Metadata: map[string]string{"enable_thinking": "true"}, + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls).To(HaveLen(1)) + // The enable_thinking=true render from the transformers fixture. + Expect(calls[0].prompt).To(Equal( + "<|turn>system\n<|think|>\n\n<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n")) + + Expect(string(reply.GetMessage())).To(Equal("The answer.")) + Expect(reply.GetChatDeltas()).To(HaveLen(2)) + Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering")) + Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer.")) + }) + + It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() { + fake := &fakeGen{out: "hi"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}}, + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + // No-thinking render: the template pre-opens AND pre-closes an + // empty thought channel in the generation prompt. + Expect(calls[0].prompt).To(Equal( + "<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n<|channel>thought\n")) + }) + + It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() { + // Marker-looking text must survive untouched: in raw-prompt mode + // the caller templates themselves and the Go-side extraction + // applies, so the backend must not interpret the output. + fake := &fakeGen{out: "<|channel>thought\nnot parsedtail"} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"}) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].prompt).To(Equal("my raw prompt")) + Expect(string(reply.GetMessage())).To(Equal(fake.out)) + Expect(reply.GetChatDeltas()).To(HaveLen(1)) + Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out)) + }) + + It("sanitizes invalid UTF-8 in the non-streaming output", func() { + // Byte-fallback tokens can decode to lone malformed bytes; the + // whole-output sanitize must replace them so proto3 marshaling of + // Message/ChatDeltas cannot fail. + fake := &fakeGen{out: "a\xe2b"} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + expectValidUTF8Reply(reply) + Expect(string(reply.GetMessage())).To(Equal("a�b")) + Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a�b")) + }) + + It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"}) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7}) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`)) + }) + + It("omits n_predict and seed when unset so the engine defaults apply", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].optsJSON).To(MatchJSON(`{}`)) + }) + + It("surfaces generator errors", func() { + fake := &fakeGen{err: errors.New("boom")} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).To(MatchError("boom")) + }) + + It("errors before generating when no model is loaded", func() { + d := &Dllm{} // no Load, no worker: must fail fast, not hang + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).To(HaveOccurred()) + }) + + It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() { + // server.go's Free has no locking of its own: the backend's genMu + // must hold Free back until the racing generate drains, instead of + // closing the jobs channel (panic) or freeing the C ctx under it. + fake := &fakeGen{out: "x", delay: 50 * time.Millisecond} + d := newTestDllm(fake, nil) + + predictDone := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + predictDone <- err + }() + // Wait until the fake generate is actually in flight (the read + // lock is held from before submit until PredictRich returns). + Eventually(func() int { + _, maxInFlight := fake.snapshot() + return maxInFlight + }).Should(Equal(1)) + + Expect(d.Free()).To(Succeed()) + // Free's write lock means the request finished before Free did. + var predictErr error + Eventually(predictDone).Should(Receive(&predictErr)) + Expect(predictErr).ToNot(HaveOccurred()) + }) + + It("returns model-not-loaded for requests after Free", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + Expect(d.Free()).To(Succeed()) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).To(MatchError(ContainSubstring("model not loaded"))) + }) + + It("serializes concurrent requests through the worker goroutine", func() { + // dllm_capi.h: one ctx = one concurrent generate. Two overlapping + // PredictRich calls must execute the C calls one at a time. + fake := &fakeGen{out: "x", delay: 30 * time.Millisecond} + d := newTestDllm(fake, nil) + + var wg sync.WaitGroup + for range 2 { + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + }() + } + wg.Wait() + + calls, maxInFlight := fake.snapshot() + Expect(calls).To(HaveLen(2)) + Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue") + }) + }) + + Describe("PredictStreamRich", func() { + It("emits one reply per delta-producing block and leaves the channel open", func() { + // Blocks split mid-marker and mid-payload: the parser's holdback + // must keep marker fragments out of the emitted deltas. + fake := &fakeGen{blocks: []string{ + "<|channel>thou", // partial channel open: no deltas yet + "ght\nponder", // header completes, reasoning starts + "ingHi ", // reasoning ends, content starts + "therediscarded", // turn ends: trailing text dropped + }} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }, ch) + Expect(err).ToNot(HaveOccurred()) + + replies := drainReplies(ch) + Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply") + + var content, reasoning string + for _, r := range replies { + for _, delta := range r.GetChatDeltas() { + content += delta.GetContent() + reasoning += delta.GetReasoningContent() + } + } + Expect(reasoning).To(Equal("pondering")) + Expect(content).To(Equal("Hi there")) + // Message mirrors each reply's content so legacy consumers see + // exactly the displayed tokens. + Expect(string(replies[1].GetMessage())).To(Equal("Hi ")) + Expect(string(replies[2].GetMessage())).To(Equal("there")) + }) + + It("streams raw blocks verbatim without use_tokenizer_template", func() { + fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch) + Expect(err).ToNot(HaveOccurred()) + + replies := drainReplies(ch) + Expect(replies).To(HaveLen(2), "empty blocks produce no reply") + Expect(string(replies[0].GetMessage())).To(Equal("abc")) + Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def")) + Expect(replies[1].GetChatDeltas()).To(HaveLen(1)) + }) + + It("flushes parser holdback after the stream ends", func() { + // The unterminated partial marker ""}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }, ch) + Expect(err).ToNot(HaveOccurred()) + + var content string + for _, r := range drainReplies(ch) { + for _, delta := range r.GetChatDeltas() { + content += delta.GetContent() + } + } + Expect(content).To(Equal("caf€")) + }) + + It("replaces an incomplete sequence left at stream end with U+FFFD", func() { + // A byte-fallback token can leave a lone leading byte (0xE2) that + // no later block completes: the final flush must substitute it, + // never emit it raw and never drop into a marshal error. + fake := &fakeGen{blocks: []string{"ok\xe2"}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch) + Expect(err).ToNot(HaveOccurred()) + + var content string + for _, r := range drainReplies(ch) { + content += string(r.GetMessage()) + } + Expect(content).To(Equal("ok�")) + }) + + It("surfaces generator errors without sending replies", func() { + fake := &fakeGen{err: errors.New("stream boom")} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch) + Expect(err).To(MatchError("stream boom")) + Expect(drainReplies(ch)).To(BeEmpty()) + }) + + It("errors before generating when no model is loaded", func() { + d := &Dllm{} // no Load, no worker: must fail fast, not hang + ch := make(chan *pb.Reply, 1) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch) + Expect(err).To(MatchError(ContainSubstring("model not loaded"))) + Expect(drainReplies(ch)).To(BeEmpty()) + }) + }) + + Describe("legacy Predict/PredictStream adapters", func() { + It("Predict returns the aggregated content string", func() { + fake := &fakeGen{out: "plain text"} + d := newTestDllm(fake, nil) + + out, err := d.Predict(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal("plain text")) + }) + + It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() { + fake := &fakeGen{blocks: []string{"a", "b"}} + d := newTestDllm(fake, nil) + + ch := make(chan string, 16) + Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed()) + + var got []string + for s := range ch { // terminates only if the impl closed ch + got = append(got, s) + } + Expect(got).To(Equal([]string{"a", "b"})) + }) + }) + + Describe("TokenizeString", func() { + It("decodes the C-side JSON id array", func() { + fake := &fakeGen{out: "[2,18]"} + d := newTestDllm(fake, nil) + + resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Length).To(Equal(int32(2))) + Expect(resp.Tokens).To(Equal([]int32{2, 18})) + + calls, _ := fake.snapshot() + Expect(calls[0].prompt).To(Equal("hello")) + }) + + It("fails loud on a malformed id array", func() { + fake := &fakeGen{out: "not json"} + d := newTestDllm(fake, nil) + + _, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).To(HaveOccurred()) + }) + + It("errors before tokenizing when no model is loaded", func() { + d := &Dllm{} // no Load, no worker: must fail fast, not hang + _, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).To(MatchError(ContainSubstring("model not loaded"))) + }) + }) + + Describe("parseModelGenOpts", func() { + It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() { + got := parseModelGenOpts([]string{ + "eb_max_steps:16", + "eb_t_min:0.25", + "kv_cache:auto", + "blocks:4", + "unrelated_key:1", // other layers' options: skipped + "malformed", // no colon: skipped + }) + Expect(got).To(Equal(map[string]any{ + "eb_max_steps": int64(16), + "eb_t_min": 0.25, + "kv_cache": "auto", + "blocks": int64(4), + })) + }) + + It("round-trips through buildOptsJSON (only flat scalars are produced)", func() { + got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"}) + out, err := buildOptsJSON(got) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`)) + }) + }) +}) + +// --------------------------------------------------------------------------- +// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture. +// --------------------------------------------------------------------------- + +var _ = Describe("Dllm backend (real tiny model)", func() { + BeforeEach(func() { + if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" { + Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip") + } + ensureLibLoaded() + Expect(libLoadErr).ToNot(HaveOccurred()) + }) + + It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() { + d := &Dllm{} + Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed()) + DeferCleanup(func() { Expect(d.Free()).To(Succeed()) }) + + // TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18]. + resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Tokens).To(Equal([]int32{2, 18})) + Expect(resp.Length).To(Equal(int32(2))) + + req := &pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hello"}}, + Tokens: 16, + Seed: 7, + } + + // Non-streaming: the tiny random-weight model emits arbitrary vocab + // words; with no gemma4 markers in them everything is content. + reply, err := d.PredictRich(req) + Expect(err).ToNot(HaveOccurred()) + Expect(string(reply.GetMessage())).ToNot(BeEmpty()) + Expect(reply.GetChatDeltas()).ToNot(BeEmpty()) + + // Streaming: at least one reply, and the channel-ownership rule is + // honored (drainReplies fails the spec on a closed channel). + ch := make(chan *pb.Reply, 64) + Expect(d.PredictStreamRich(req, ch)).To(Succeed()) + replies := drainReplies(ch) + Expect(replies).ToNot(BeEmpty()) + var streamed string + for _, r := range replies { + streamed += string(r.GetMessage()) + } + Expect(streamed).ToNot(BeEmpty()) + }) +}) diff --git a/backend/go/dllm/main.go b/backend/go/dllm/main.go old mode 100644 new mode 100755 From 52b3b68ceaefc12cbfd7a2dce1fa3d191cf151b2 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 17:05:18 +0000 Subject: [PATCH 05/13] feat(dllm): backend packaging, gallery index, CI matrix Registers the dllm backend across every surface: backend gallery index (cpu amd64+arm64 with manifest merge, cuda13, l4t-cuda13 for GB10-class hardware; no darwin per engine scope), top-level Makefile targets, bump_deps pin tracking for DLLM_VERSION, and the curated known-backends list for /backends/known (pref-only: auto-detecting on .gguf would shadow llama-cpp). Note: image builds and the nightly bump leg stay red until github.com/mudler/dllm.cpp is published (planned at merge time). Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- .github/backend-matrix.yml | 55 +++++++++++++++++++ .github/workflows/bump_deps.yaml | 4 ++ Makefile | 6 +- backend/go/dllm/Makefile | 4 ++ backend/go/dllm/dllm.go | 4 +- backend/index.yaml | 61 +++++++++++++++++++++ core/http/endpoints/localai/backend.go | 4 ++ core/http/endpoints/localai/backend_test.go | 1 + 8 files changed, 136 insertions(+), 3 deletions(-) mode change 100644 => 100755 backend/go/dllm/Makefile diff --git a/.github/backend-matrix.yml b/.github/backend-matrix.yml index 464ffc36c4f6..997f5def0a5b 100644 --- a/.github/backend-matrix.yml +++ b/.github/backend-matrix.yml @@ -1608,6 +1608,19 @@ include: dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-13-dllm' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "dllm" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -1647,6 +1660,19 @@ include: backend: "parakeet-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/arm64' + skip-drivers: 'false' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-cuda-13-arm64-dllm' + base-image: "ubuntu:24.04" + ubuntu-version: '2404' + runs-on: 'ubuntu-24.04-arm' + backend: "dllm" + dockerfile: "./backend/Dockerfile.golang" + context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -3145,6 +3171,35 @@ include: dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' + # dllm + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + platform-tag: 'amd64' + tag-latest: 'auto' + tag-suffix: '-cpu-dllm' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "dllm" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/arm64' + platform-tag: 'arm64' + tag-latest: 'auto' + tag-suffix: '-cpu-dllm' + runs-on: 'ubuntu-24.04-arm' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "dllm" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' - build-type: 'sycl_f32' cuda-major-version: "" cuda-minor-version: "" diff --git a/.github/workflows/bump_deps.yaml b/.github/workflows/bump_deps.yaml index 5f1ac0c21525..5572262d1fc7 100644 --- a/.github/workflows/bump_deps.yaml +++ b/.github/workflows/bump_deps.yaml @@ -38,6 +38,10 @@ jobs: variable: "PARAKEET_VERSION" branch: "master" file: "backend/go/parakeet-cpp/Makefile" + - repository: "mudler/dllm.cpp" + variable: "DLLM_VERSION" + branch: "main" + file: "backend/go/dllm/Makefile" - repository: "leejet/stable-diffusion.cpp" variable: "STABLEDIFFUSION_GGML_VERSION" branch: "master" diff --git a/Makefile b/Makefile index cafcdd44a692..89cc6bf013fb 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio GOCMD=go GOTEST=$(GOCMD) test @@ -1171,6 +1171,9 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr BACKEND_WHISPER = whisper|golang|.|false|true BACKEND_CRISPASR = crispasr|golang|.|false|true BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true +# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine, +# wrapped by the purego backend at backend/go/dllm. +BACKEND_DLLM = dllm|golang|.|false|true BACKEND_VOXTRAL = voxtral|golang|.|false|true BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true @@ -1260,6 +1263,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML))) $(eval $(call generate-docker-build-target,$(BACKEND_WHISPER))) $(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR))) $(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP))) +$(eval $(call generate-docker-build-target,$(BACKEND_DLLM))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL))) $(eval $(call generate-docker-build-target,$(BACKEND_OPUS))) $(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS))) diff --git a/backend/go/dllm/Makefile b/backend/go/dllm/Makefile old mode 100644 new mode 100755 index 3b7114c12ed5..1e0825c73644 --- a/backend/go/dllm/Makefile +++ b/backend/go/dllm/Makefile @@ -14,6 +14,10 @@ # That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The # default target below does the proper clone-at-pin + cmake build so CI # doesn't need a side-checkout. +# +# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned); +# until then the anonymous clone below fails. Use the symlink shortcut above +# with a local checkout, or a git credential helper with access to the repo. DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7 DLLM_REPO?=https://github.com/mudler/dllm.cpp diff --git a/backend/go/dllm/dllm.go b/backend/go/dllm/dllm.go index cd82ff0b3bb0..17d46de2fd2b 100755 --- a/backend/go/dllm/dllm.go +++ b/backend/go/dllm/dllm.go @@ -275,8 +275,8 @@ func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) { // The engine rounds n_predict UP to a whole number of diffusion // blocks (the canvas is denoised block-wise), so the completion may // run slightly past the requested budget. Tokens==0 omits the key so - // the engine's GGUF-metadata default applies (the C-ABI documents - // per-key defaults; no hardcoded 256 like ds4's grpc-server). + // the C-ABI default of 256 applies (hardcoded in capi.cpp's + // parse_gen_opts, independent of canvas_length). m["n_predict"] = n } if s := opts.GetSeed(); s > 0 { diff --git a/backend/index.yaml b/backend/index.yaml index 37e6890710e4..508e87b24d1f 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -95,6 +95,29 @@ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4" metal: "metal-ds4" metal-darwin-arm64: "metal-ds4" +- &dllm + name: "dllm" + alias: "dllm" + license: mit + description: | + mudler/dllm.cpp - DiffusionGemma block-diffusion LLM inference engine + (C++/ggml, GGUF weights). Decodes whole token canvases per diffusion + round instead of autoregressive sampling. Runs on CPU and NVIDIA CUDA 13 + (including Jetson/GB10 L4T targets). + urls: + - https://github.com/mudler/dllm.cpp + tags: + - text-to-text + - LLM + - gguf + - diffusion + - CPU + - CUDA + capabilities: + default: "cpu-dllm" + nvidia: "cuda13-dllm" + nvidia-cuda-13: "cuda13-dllm" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm" - &whispercpp name: "whisper" alias: "whisper" @@ -1272,6 +1295,13 @@ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4-development" metal: "metal-ds4-development" metal-darwin-arm64: "metal-ds4-development" +- !!merge <<: *dllm + name: "dllm-development" + capabilities: + default: "cpu-dllm-development" + nvidia: "cuda13-dllm-development" + nvidia-cuda-13: "cuda13-dllm-development" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm-development" - !!merge <<: *stablediffusionggml name: "stablediffusion-ggml-development" capabilities: @@ -1859,6 +1889,37 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ds4" mirrors: - localai/localai-backends:master-metal-darwin-arm64-ds4 +## dllm +- !!merge <<: *dllm + name: "cpu-dllm" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-dllm" + mirrors: + - localai/localai-backends:latest-cpu-dllm +- !!merge <<: *dllm + name: "cpu-dllm-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-dllm" + mirrors: + - localai/localai-backends:master-cpu-dllm +- !!merge <<: *dllm + name: "cuda13-dllm" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-dllm" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-13-dllm +- !!merge <<: *dllm + name: "cuda13-dllm-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-dllm" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-13-dllm +- !!merge <<: *dllm + name: "cuda13-nvidia-l4t-arm64-dllm" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm +- !!merge <<: *dllm + name: "cuda13-nvidia-l4t-arm64-dllm-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-dllm" + mirrors: + - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-dllm ## whisper - !!merge <<: *whispercpp name: "whisper-development" diff --git a/core/http/endpoints/localai/backend.go b/core/http/endpoints/localai/backend.go index cbda648d69d2..29ac187040ba 100644 --- a/core/http/endpoints/localai/backend.go +++ b/core/http/endpoints/localai/backend.go @@ -25,6 +25,10 @@ var knownPrefOnlyBackends = []schema.KnownBackend{ // Text LLM // ds4: antirez/ds4 - single-model DeepSeek V4 Flash engine; auto-detected via DS4Importer {Name: "ds4", Modality: "text", AutoDetect: false, Description: "antirez/ds4 DeepSeek V4 Flash engine (auto-detected; pref-only fallback)"}, + // dllm consumes GGUF weights like llama-cpp does, but only for the + // DiffusionGemma architecture - auto-detecting on .gguf would shadow + // llama-cpp, so it stays preference-only. + {Name: "dllm", Modality: "text", AutoDetect: false, Description: "dllm.cpp DiffusionGemma block-diffusion engine (preference-only)"}, {Name: "sglang", Modality: "text", AutoDetect: false, Description: "SGLang runtime (preference-only)"}, {Name: "tinygrad", Modality: "text", AutoDetect: false, Description: "tinygrad runtime (preference-only)"}, {Name: "trl", Modality: "text", AutoDetect: false, Description: "Transformers Reinforcement Learning (preference-only)"}, diff --git a/core/http/endpoints/localai/backend_test.go b/core/http/endpoints/localai/backend_test.go index 0c21bb7b4f6a..70877c1b40fe 100644 --- a/core/http/endpoints/localai/backend_test.go +++ b/core/http/endpoints/localai/backend_test.go @@ -135,6 +135,7 @@ var _ = Describe("Backend Endpoints", func() { Expect(entry.Modality).To(Equal(modality)) } + expectPrefOnly("dllm", "text") expectPrefOnly("sglang", "text") expectPrefOnly("tinygrad", "text") expectPrefOnly("trl", "text") From 04d6f66a9a012aa542b1af22b4144b23fc527a64 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 17:05:18 +0000 Subject: [PATCH 06/13] feat(dllm): diffusiongemma gallery entry and e2e coverage Gallery model diffusiongemma-26b-a4b-it (unsloth BF16 GGUF, sha256 verified against the HF LFS oid) with use_tokenizer_template and an honest experimental/throughput description. e2e: BACKEND_BINARY-mode specs boot the real gRPC backend with the tiny fixture model (templated chat + streaming); real-26B specs are separately env-gated. Adds an opt-in BACKEND_TEST_SEED knob so random-weight fixture models run the generic specs deterministically. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- gallery/diffusiongemma.yaml | 27 ++++ gallery/index.yaml | 37 +++++ tests/e2e-backends/backend_test.go | 8 ++ tests/e2e-backends/dllm_test.go | 223 +++++++++++++++++++++++++++++ 4 files changed, 295 insertions(+) create mode 100644 gallery/diffusiongemma.yaml create mode 100644 tests/e2e-backends/dllm_test.go diff --git a/gallery/diffusiongemma.yaml b/gallery/diffusiongemma.yaml new file mode 100644 index 000000000000..a64bfefebc56 --- /dev/null +++ b/gallery/diffusiongemma.yaml @@ -0,0 +1,27 @@ +config_file: | + backend: dllm + known_usecases: + - chat + parameters: + # Forwarded to the engine as ctx_len, but the engine at the current + # pin ignores it - the effective bound is the GGUF's trained context + # (n_ctx_train, 262144 for this model). Kept for forward-compatibility + # once the engine honors it. Note dllm generates by denoising whole + # 256-token canvases, and until the prefix-KV cache lands (dllm P3) + # EVERY denoise step recomputes the full prompt+canvas, so throughput + # drops roughly linearly with context occupancy. + context_size: 4096 + stopwords: + - + # Templating AND output parsing (content/thought channels, tool calls) + # are owned by the dllm backend's native gemma4 renderer/parser - NOT + # llama.cpp's jinja autoparser, so no use_jinja option here. + # Disabling LocalAI's grammar keeps its generated tool grammar from + # overriding the backend's native tool-call pipeline (same reasoning as + # qwen3.yaml / the ds4 importer). + function: + grammar: + disable: true + template: + use_tokenizer_template: true +name: diffusiongemma diff --git a/gallery/index.yaml b/gallery/index.yaml index 9d03a98a9bad..8c24bbd60ee3 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -1,4 +1,41 @@ --- +- name: "diffusiongemma-26b-a4b-it" + url: "github:mudler/LocalAI/gallery/diffusiongemma.yaml@master" + urls: + - https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF + - https://github.com/mudler/dllm.cpp + description: | + DiffusionGemma 26B A4B (instruction-tuned): Google DeepMind's experimental + block-diffusion language model, served by LocalAI's dllm backend + (dllm.cpp). Instead of autoregressive token-by-token decoding, text is + generated by iteratively denoising fixed 256-token canvases. + + Honest expectations: + * Experimental: both the model family and the dllm backend are young - + expect rough edges. + * BF16 weights (~50 GB): CUDA-13-class hardware (DGX Spark / large-VRAM or + unified-memory machines) is recommended; CPU works but is slow. + * Throughput: every denoise step currently recomputes the full + prompt+canvas - the prefix-KV cache that removes this lands with the + dllm backend's P3 work - so long prompts cost proportionally more per + generated block. + * Chat templating, thinking channels and tool calls are rendered and + parsed natively by the dllm backend (gemma4 renderer/parser), not by + llama.cpp's jinja autoparser. + license: apache-2.0 + tags: + - llm + - gguf + - gemma + - diffusion + - dllm + overrides: + parameters: + model: dllm/diffusiongemma-26B-A4B-it-BF16.gguf + files: + - filename: dllm/diffusiongemma-26B-A4B-it-BF16.gguf + sha256: b0ef5dbf246608953ee9945fb03c6056af9e2459799fb179651a20a8bbaa2921 + uri: https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF/resolve/main/diffusiongemma-26B-A4B-it-BF16.gguf - name: "gemma-4-26b-a4b-it-qat" url: "github:mudler/LocalAI/gallery/virtual.yaml@master" urls: diff --git a/tests/e2e-backends/backend_test.go b/tests/e2e-backends/backend_test.go index 4c7dac33cd0e..81bb668fa9ff 100644 --- a/tests/e2e-backends/backend_test.go +++ b/tests/e2e-backends/backend_test.go @@ -66,6 +66,12 @@ import ( // BACKEND_TEST_IMAGE_STEPS Override the diffusion step count for the image spec // (default: 4 — keeps CPU-only runs under a few minutes). // BACKEND_TEST_PROMPT Override the prompt used by predict/stream specs. +// BACKEND_TEST_SEED Optional sampling seed (>0) passed to the predict +// and stream specs. Unset keeps backend-default +// randomness. Needed for random-weight fixture +// models (e.g. dllm's tiny_with_vocab.gguf) where +// unseeded sampling makes the output - and thus the +// spec outcome - nondeterministic. // BACKEND_TEST_CTX_SIZE Override the context size passed to LoadModel (default 512). // BACKEND_TEST_THREADS Override Threads passed to LoadModel (default 4). // BACKEND_TEST_OPTIONS Comma-separated Options[] entries passed to LoadModel, @@ -419,6 +425,7 @@ var _ = Describe("Backend container", Ordered, func() { Temperature: 0.1, TopK: 40, TopP: 0.9, + Seed: envInt32("BACKEND_TEST_SEED", 0), }) Expect(err).NotTo(HaveOccurred()) Expect(res.GetMessage()).NotTo(BeEmpty(), "Predict produced empty output") @@ -438,6 +445,7 @@ var _ = Describe("Backend container", Ordered, func() { Temperature: 0.1, TopK: 40, TopP: 0.9, + Seed: envInt32("BACKEND_TEST_SEED", 0), }) Expect(err).NotTo(HaveOccurred()) diff --git a/tests/e2e-backends/dllm_test.go b/tests/e2e-backends/dllm_test.go new file mode 100644 index 000000000000..dd763603ce0d --- /dev/null +++ b/tests/e2e-backends/dllm_test.go @@ -0,0 +1,223 @@ +package e2ebackends_test + +import ( + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "time" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/phayes/freeport" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// ─── dllm templated chat-completion e2e (opt-in, BACKEND_BINARY mode) ──────── +// +// The generic "Backend container" suite already exercises dllm's +// health/load/predict/stream surface in BACKEND_BINARY mode (ds4 precedent: +// hardware-gated backends skip the Docker image and point BACKEND_BINARY at a +// packaged run.sh). What it does NOT cover is the templated chat path: dllm +// owns prompt rendering AND output parsing natively (the gemma4 +// renderer/parser, not llama.cpp's jinja autoparser), and that path only +// triggers when PredictOptions carries Messages + UseTokenizerTemplate. +// These specs drive exactly that round trip over the real gRPC server +// binary, non-streaming and streaming. +// +// Tiny-model spec (cheap, runs anywhere a libdllm.so build exists): +// +// BACKEND_TEST_DLLM=1 enables the spec (skipped by default, CI-safe) +// BACKEND_BINARY packaged dllm run.sh (backend/go/dllm/run.sh with +// dllm-grpc + libdllm.so next to it, or +// package/run.sh from 'make -C backend/go/dllm package') +// BACKEND_TEST_MODEL_FILE dllm.cpp's tests/fixtures/tiny_with_vocab.gguf +// (random weights + handcrafted 43-token gemma4 vocab) +// +// Real-model spec (the 26B BF16 GGUF, ~50 GB; CUDA-13-class hardware): +// +// BACKEND_TEST_DLLM_REAL_MODEL_FILE path to diffusiongemma-26B-A4B-it-BF16.gguf; +// setting it enables the spec (skipped by +// default; BACKEND_BINARY still required) +// BACKEND_TEST_DLLM_REAL_GPU_LAYERS NGPULayers for the real model +// (default -1 = full offload) +// +// Tool-call e2e is deliberately absent: the tiny fixture has RANDOM weights, +// so it cannot be prompted into emitting gemma4 <|tool_call> markup and a +// live tool-call assertion would be flaky-by-construction. Tool-call +// rendering and parsing are pinned by unit tables in backend/go/dllm +// (gemma4_renderer_test.go / gemma4_parser_test.go) instead; the real-model +// spec can grow a tools cap once a quantized checkpoint is cheap enough to +// gate on. + +// startDllmBackend boots the packaged dllm backend via BACKEND_BINARY's +// run.sh, waits for the gRPC port, loads modelFile, and returns a connected +// client. Fails the spec on any error. Teardown is registered with +// DeferCleanup the moment each resource exists, so a failure anywhere in +// setup (port-wait timeout, dial error, LoadModel failure) still reaps the +// spawned server - critical for the real-model spec, where a failed load +// would otherwise leak a ~50GB process. +func startDllmBackend(modelFile string, gpuLayers int32) pb.BackendClient { + GinkgoHelper() + + binary := os.Getenv("BACKEND_BINARY") + Expect(binary).NotTo(BeEmpty(), + "dllm chat spec requires BACKEND_BINARY pointing at the packaged dllm run.sh") + Expect(filepath.Base(binary)).To(Equal("run.sh"), + "BACKEND_BINARY must point at a run.sh (see backend/go/dllm/package.sh)") + binaryDir := filepath.Dir(binary) + Expect(filepath.Join(binaryDir, "run.sh")).To(BeAnExistingFile()) + Expect(modelFile).To(BeAnExistingFile()) + + port, err := freeport.GetFreePort() + Expect(err).NotTo(HaveOccurred()) + addr := fmt.Sprintf("127.0.0.1:%d", port) + + Expect(os.Chmod(filepath.Join(binaryDir, "run.sh"), 0o755)).To(Succeed()) + cmd := exec.Command(filepath.Join(binaryDir, "run.sh"), "--addr="+addr) + cmd.Stdout = GinkgoWriter + cmd.Stderr = GinkgoWriter + Expect(cmd.Start()).To(Succeed()) + DeferCleanup(func() { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + }) + + Eventually(func() error { + c, derr := net.DialTimeout("tcp", addr, 500*time.Millisecond) + if derr != nil { + return derr + } + _ = c.Close() + return nil + }, 30*time.Second, 200*time.Millisecond).Should(Succeed(), "dllm backend did not start") + + conn, err := grpc.Dial(addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(50*1024*1024)), + ) + Expect(err).NotTo(HaveOccurred()) + DeferCleanup(func() { + _ = conn.Close() + }) + client := pb.NewBackendClient(conn) + + // 15 min: reading the 26B BF16 from a cold disk dominates real-model load. + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + res, err := client.LoadModel(ctx, &pb.ModelOptions{ + Model: modelFile, + ModelFile: modelFile, + ContextSize: envInt32("BACKEND_TEST_CTX_SIZE", 512), + Threads: envInt32("BACKEND_TEST_THREADS", 4), + NGPULayers: gpuLayers, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.GetSuccess()).To(BeTrue(), "dllm LoadModel failed: %s", res.GetMessage()) + + return client +} + +// dllmChatRequest builds the templated chat request shared by both specs. +// The user content is fixed to "hello": the tiny fixture's handcrafted +// 43-token vocab is guaranteed to cover it (and the gemma4 template markup), +// while arbitrary English text is not tokenizable by that vocab. +func dllmChatRequest() *pb.PredictOptions { + return &pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "hello"}}, + UseTokenizerTemplate: true, + // Rounds up to one whole 256-token canvas (dllm commits whole + // canvases); keeps the tiny run fast and the real run bounded. + Tokens: 16, + Temperature: 0.1, + Seed: 7, + } +} + +// assertDllmChat does the non-streaming templated round trip: no error, +// non-empty content, parsed ChatDeltas present. +func assertDllmChat(client pb.BackendClient) { + GinkgoHelper() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute) + defer cancel() + res, err := client.Predict(ctx, dllmChatRequest()) + Expect(err).NotTo(HaveOccurred()) + Expect(string(res.GetMessage())).NotTo(BeEmpty(), "templated chat completion produced empty content") + Expect(res.GetChatDeltas()).NotTo(BeEmpty(), "templated chat completion produced no ChatDeltas") + GinkgoWriter.Printf("dllm chat: %q (deltas=%d)\n", string(res.GetMessage()), len(res.GetChatDeltas())) +} + +// assertDllmChatStream does the streaming variant: >=1 chunk, non-empty +// combined content. +func assertDllmChatStream(client pb.BackendClient) { + GinkgoHelper() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute) + defer cancel() + stream, err := client.PredictStream(ctx, dllmChatRequest()) + Expect(err).NotTo(HaveOccurred()) + + var chunks int + var combined string + for { + msg, rerr := stream.Recv() + if rerr == io.EOF { + break + } + Expect(rerr).NotTo(HaveOccurred()) + if len(msg.GetMessage()) > 0 { + chunks++ + combined += string(msg.GetMessage()) + } + } + Expect(chunks).To(BeNumerically(">=", 1), "no stream chunks received") + Expect(combined).NotTo(BeEmpty(), "streamed chat completion produced empty content") + GinkgoWriter.Printf("dllm chat stream: %d chunks, combined=%q\n", chunks, combined) +} + +var _ = Describe("dllm templated chat-completion (tiny model)", Ordered, func() { + var client pb.BackendClient + + BeforeAll(func() { + if os.Getenv("BACKEND_TEST_DLLM") != "1" { + Skip("dllm chat spec is opt-in; set BACKEND_TEST_DLLM=1 (plus BACKEND_BINARY and BACKEND_TEST_MODEL_FILE) to run it") + } + modelFile := os.Getenv("BACKEND_TEST_MODEL_FILE") + Expect(modelFile).NotTo(BeEmpty(), + "dllm chat spec requires BACKEND_TEST_MODEL_FILE (dllm.cpp's tests/fixtures/tiny_with_vocab.gguf)") + client = startDllmBackend(modelFile, 0) + }) + + It("answers a templated chat completion", func() { + assertDllmChat(client) + }) + + It("streams a templated chat completion", func() { + assertDllmChatStream(client) + }) +}) + +var _ = Describe("dllm templated chat-completion (real model)", Ordered, func() { + var client pb.BackendClient + + BeforeAll(func() { + modelFile := os.Getenv("BACKEND_TEST_DLLM_REAL_MODEL_FILE") + if modelFile == "" { + Skip("real-model dllm spec is opt-in; set BACKEND_TEST_DLLM_REAL_MODEL_FILE (the 26B BF16 GGUF) to run it") + } + client = startDllmBackend(modelFile, + envInt32("BACKEND_TEST_DLLM_REAL_GPU_LAYERS", -1)) + }) + + It("answers a templated chat completion", func() { + assertDllmChat(client) + }) + + It("streams a templated chat completion", func() { + assertDllmChatStream(client) + }) +}) From aba9c4794aaba911a4ef14e1968790edec45caba Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 17:05:18 +0000 Subject: [PATCH 07/13] docs(dllm): backend documentation and agents topic guide User docs: dllm section in text-generation (setup, eb_* options table, n_predict canvas rounding, enable_thinking metadata, honest GB10 throughput numbers). Agents guide: .agents/dllm-backend.md covering the purego C-ABI contract, serialization rules, template provenance, test layers, and known limitations. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- .agents/dllm-backend.md | 134 +++++++++++++++++++++++ AGENTS.md | 1 + docs/content/features/text-generation.md | 117 ++++++++++++++++++++ 3 files changed, 252 insertions(+) create mode 100644 .agents/dllm-backend.md diff --git a/.agents/dllm-backend.md b/.agents/dllm-backend.md new file mode 100644 index 000000000000..6cd8c10dd125 --- /dev/null +++ b/.agents/dllm-backend.md @@ -0,0 +1,134 @@ +# Working on the dllm Backend + +`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma +block-diffusion models. LocalAI wraps it with a **pure-Go** backend at +`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) - +NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating +(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements +the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies). + +> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is +> planned). Until then the Makefile's anonymous clone fails; use the local-dev +> symlink shortcut documented at the top of `backend/go/dllm/Makefile` +> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the +> clone), or a git credential helper with repo access. + +## Pin + +`backend/go/dllm/Makefile` pins `DLLM_VERSION?=` at the top +(whisper / parakeet-cpp / ds4 convention). The bump-deps bot +(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and +rewrites that variable. After a manual bump: `make -C backend/go/dllm purge && +make -C backend/go/dllm` (the clone is keyed on the directory existing, not +the sha). + +## C-ABI and the serialization contract + +The binding covers the 9-symbol flat C-ABI from dllm.cpp's +`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch): +`abi_version, load, free, last_error, free_string, tokenize_json, generate, +generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go` +header comment has the full list): + +- **One ctx = one concurrent generate/tokenize.** A per-model worker + goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the + serialization structural instead of lock discipline. +- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may + be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the + worker queue. The flag resets at the start of each generate, so a watchdog + racing a new generate must re-issue cancel. +- **`last_error` is a borrowed pointer** and must only be read AFTER the + failing call returned (never while a generate is in flight on the same ctx). +- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full + duration; `Free` takes the write lock, so it only runs when nothing is in + flight, then drains and closes the worker. Post-Free requests get a clean + "model not loaded" error. +- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`, + copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT + object of scalars (`buildOptsJSON` rejects anything else). + +## Wire shape + +| RPC | Implementation | +|---|---| +| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` | +| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes | +| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) | +| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) | +| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) | +| Cancel | `dllm_capi_cancel`; currently INERT in practice - the gRPC server does not hand the request/stream context to backends, so client disconnects never reach it (plumbing is future work) | + +`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the +current pin (the context bound comes from GGUF `n_ctx_train`); they are sent +for forward compatibility. + +## Renderer / parser (the templated chat path) + +With `use_tokenizer_template` + raw Messages, the backend owns templating and +parsing (the ds4 precedent, but in Go): + +- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking, + addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template` + jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted + verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g. + `python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the + dllm.cpp checkout - as a numbered comment block; every Go rule cites its + "tpl L" line. Re-verify the md5 before blaming the renderer for a + mismatch with a new GGUF. **BOS exception**: the template emits + `{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's + `run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for + gemma4), so a literal `` would double it. +- `gemma4_parser.go` - streaming state machine turning raw model text + (fragments can split anywhere, including mid-marker) into ChatDeltas: + thought channels → `reasoning_content`, `<|tool_call>call:name{...}` → + ToolCallDelta, `` → done. Marker grammar cross-checked against vLLM + PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted + raw as content, never dropped. +- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`, + default OFF - the inverse of ds4): the template gates every thinking branch + on `enable_thinking`, and the no-thinking render pre-closes an empty thought + channel, so the parser always starts in content state. +- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block + detokenization can split a multi-byte character across block boundaries, and + grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete + trailing sequence (at most 3 bytes) is carried into the next block; genuinely + undecodable bytes become U+FFFD. + +Without `use_tokenizer_template`, the prompt passes through verbatim and the +output is NOT gemma4-parsed (plain content, like any non-autoparsing backend). + +## Tests + +| Layer | Gate | What | +|---|---|---| +| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers | +| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel | +| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming | +| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only | + +Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has +random weights and cannot be coaxed into emitting tool markup; the unit tables +carry that coverage. + +## Build matrix + +`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64 + arm64), and +`cuda13-nvidia-l4t-arm64-dllm` (Jetson / DGX Spark GB10), via +`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward +`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare +`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is +self-contained (ggml statically absorbed, PIC), so packaging only ships the +one .so plus the usual ldd walk. + +## Known limitations + +- **Cancel is unwired**: nothing calls `Dllm.Cancel` on client disconnect + until the gRPC server plumbs the request context through to backends. +- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every + denoise step recomputes the full prompt+canvas. The upstream prefix-KV + cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands + (`auto`/`off` are accepted no-ops). +- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the + repo published (or credentials) before the backend images can build. +- Engine spec/validation references: dllm.cpp `docs/validation.md` and + LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`. diff --git a/AGENTS.md b/AGENTS.md index 9f397e613fca..417efab5fbf3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,6 +26,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants] | [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks | | [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling | | [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix | +| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers | | [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI | | [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control | | [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends | diff --git a/docs/content/features/text-generation.md b/docs/content/features/text-generation.md index c09717a3fbfc..04b481312d32 100644 --- a/docs/content/features/text-generation.md +++ b/docs/content/features/text-generation.md @@ -655,6 +655,123 @@ The `cache_type_k` / `cache_type_v` fields map to llama.cpp's `-ctk` / `-ctv` fl - [Tracked branch: `feature/turboquant-kv-cache`](https://github.com/TheTom/llama-cpp-turboquant/tree/feature/turboquant-kv-cache) +### dllm (DiffusionGemma block-diffusion) + +[dllm.cpp](https://github.com/mudler/dllm.cpp) is a standalone C++/ggml engine for **DiffusionGemma** block-diffusion language models (GGUF weights). Instead of sampling one token at a time, generation works on fixed-size token **canvases** (256 tokens for the published model): each canvas is iteratively denoised with the Entropy-Bound (EB) sampler, committed as a whole block, and committed blocks feed back as prompt for the next canvas. LocalAI wraps the engine with a native Go backend (`dllm`) that also owns chat templating and output parsing: the model's thought channels and tool calls stream natively as `reasoning_content` and `tool_calls` deltas, with no jinja template involved. + +{{% notice note %}} + +This backend is **experimental**, and the engine does not yet have a prompt-KV prefix cache: every denoise step recomputes the full prompt+canvas forward pass, so throughput is low (~0.15 tok/s at default settings on a single GB10 GPU) and drops further as the context fills up. The prefix cache is the planned fix in upstream dllm.cpp. + +{{% /notice %}} + +#### Features + +- [📖 Text generation (GPT)]({{%relref "features/text-generation" %}}) +- [🔥 OpenAI functions]({{%relref "features/openai-functions" %}}) - tool calls are parsed natively by the backend (gemma4 `<|tool_call>` markers), not by LocalAI's grammar/regex fallback +- Reasoning - opt-in thinking streams as `reasoning_content` (see below) + +#### Supported platforms + +| Flavor | Hardware | +|---|---| +| `cpu-dllm` | CPU (amd64 + arm64) - functional but very slow on the 26B model; mainly useful for wiring tests | +| `cuda13-dllm` | NVIDIA CUDA 13 (amd64 + arm64) | +| `cuda13-nvidia-l4t-arm64-dllm` | NVIDIA L4T (Jetson / DGX Spark GB10) | + +macOS/Metal is not available yet. + +#### Setup + +The easiest path is the model gallery; the entry installs the backend and the model together: + +```bash +local-ai models install diffusiongemma-26b-a4b-it +``` + +Or configure it manually with a YAML file pointing at the GGUF (BF16 is the only published file the engine's validation is calibrated for; the model card flags quantized MoE exports as problematic): + +```yaml +name: diffusiongemma +backend: dllm +parameters: + model: diffusiongemma-26B-A4B-it-BF16.gguf +context_size: 4096 +stopwords: + - +# The backend parses tool calls natively; keep LocalAI's generated tool +# grammar from overriding that pipeline. +function: + grammar: + disable: true +template: + use_tokenizer_template: true +``` + +`use_tokenizer_template: true` is what routes chat requests through the backend's native gemma4 renderer/parser (messages and tools in, `content`/`reasoning_content`/`tool_calls` out). Without it, your own prompt template output is passed to the engine verbatim and the raw model text comes back as plain content. + +#### Backend options + +Model-level generation options go in the `options:` array (format: `key:value`), like other backends: + +```yaml +options: + - eb_max_steps:24 + - kv_cache:auto +``` + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `blocks` | integer | unset | Generation budget in whole diffusion canvases (`blocks * canvas_length` tokens, 256 per canvas for the published model). Must be >= 1. When both `blocks` and a token budget are present, `blocks` wins. | +| `kv_cache` | string | `auto` | One of `auto`, `off`, `on`. The engine has no KV cache yet, so `auto` and `off` are accepted no-ops; `kv_cache:on` fails the request until the prefix-KV cache lands upstream. | +| `eb_max_steps` | integer | 48 | Maximum denoise steps per canvas. Blocks exit early once stable **and** confident, so this is a ceiling, not a fixed cost. Lower values are faster but can degrade quality. | +| `eb_t_min` | float | 0.4 | Lower bound of the linear temperature schedule. | +| `eb_t_max` | float | 0.8 | Upper bound of the linear temperature schedule: `t = t_min + (t_max - t_min) * cur_step/max_steps`, with `cur_step` counting down, so denoising anneals from `t_max` toward `t_min`. | +| `eb_entropy_bound` | float | 0.1 | Per-step acceptance budget: canvas positions are sorted by entropy (ascending) and accepted while the cumulative entropy, minus the position's own, stays at or below the bound. Higher accepts more tokens per step (faster, riskier). | +| `eb_stability_threshold` | integer | 1 | Consecutive identical argmax canvases required before a block counts as stable (`0` = always stable; at `1` the earliest exit is the 2nd identical step). | +| `eb_confidence_threshold` | float | 0.005 | Mean-entropy ceiling for the "confident" half of the early-exit test; a block stops denoising only when it is both stable and below this. | + +Defaults for the `eb_*` knobs come from the GGUF's `diffusion.*` metadata when present, falling back to the engine defaults shown (DiffusionGemma's canonical values). The published `diffusiongemma-26B-A4B-it` GGUF carries only `diffusion.canvas_length`, so the fallbacks above are what you actually get. + +Per-request parameters: `max_tokens` maps to the engine's `n_predict` (omitted: engine default of 256), and a **positive** `seed` gives deterministic output (absent, zero or negative = a fresh random seed per call). Autoregressive sampling fields (`temperature`, `top_p`, `top_k`, ...) are **not used**: the EB sampler's own temperature schedule (`eb_t_min`/`eb_t_max`) replaces them. + +{{% notice note %}} + +**`max_tokens` rounds up to whole canvases.** The scheduler always commits whole canvases, so the token budget rounds **up** to `ceil(n_predict / canvas_length)` blocks and the completion may run slightly past the requested `max_tokens` (canonical DiffusionGemma behavior). Generation can still end earlier when the model emits an end-of-turn token, which finalizes the canvas. + +{{% /notice %}} + +#### Thinking + +DiffusionGemma's chat template makes thinking **opt-in** (the default render pre-closes an empty thought channel), so the backend defaults to thinking OFF - the opposite of most reasoning models. Enable it per request via the `metadata` field ([per-request override]({{%relref "advanced/model-configuration#per-request-override-via-metadata" %}})): + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "diffusiongemma", + "messages": [{"role": "user", "content": "Explain quantum computing"}], + "metadata": {"enable_thinking": "true"} + }' +``` + +The model's thought channel then streams as `reasoning_content`, separate from the final `content`. + +#### Performance expectations + +Honest numbers from validation on a DGX Spark (GB10, CUDA 13, BF16 26B model, full GPU offload): + +- Engine load: ~33 s (50 GB of weights to GPU) +- Forward pass: ~5.6 s per denoise step (256-token canvas); a block takes up to `eb_max_steps` steps but typically exits early (24/48 observed on a normal prompt, 4 steps on a trivial one) +- End-to-end: ~0.15 tok/s at default settings, dominated by the per-step full recompute - this is the cost the upstream prefix-KV cache work targets + +On CPU the same forward step takes ~139 s (20 Grace cores): treat the CPU flavor as functional, not practical, for the 26B model. + +#### Reference + +- [dllm.cpp](https://github.com/mudler/dllm.cpp) +- [unsloth/diffusiongemma-26B-A4B-it-GGUF](https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF) + ### vLLM [vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference. From eb61e1d77012a087f3afc76ba1f95c275c5e7ab7 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 17:17:54 +0000 Subject: [PATCH 08/13] chore(dllm): review fixes - file modes and build-matrix doc accuracy Drop the stray executable bit from the Go sources and Makefile (the sibling Go backends commit them 644; only run.sh/package.sh are executable), and correct two documentation claims found in the final branch review: cuda13-dllm is built for amd64 only (arm64 CUDA ships as the l4t flavor), and package.sh is the parakeet-cpp-style stub layout with no ldd walk. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- .agents/dllm-backend.md | 9 +++++---- backend/go/dllm/Makefile | 0 backend/go/dllm/capi.go | 0 backend/go/dllm/dllm.go | 0 backend/go/dllm/dllm_test.go | 0 backend/go/dllm/gemma4_parser.go | 0 backend/go/dllm/gemma4_parser_test.go | 0 backend/go/dllm/gemma4_renderer.go | 0 backend/go/dllm/gemma4_renderer_test.go | 0 backend/go/dllm/main.go | 0 docs/content/features/text-generation.md | 4 ++-- 11 files changed, 7 insertions(+), 6 deletions(-) mode change 100755 => 100644 backend/go/dllm/Makefile mode change 100755 => 100644 backend/go/dllm/capi.go mode change 100755 => 100644 backend/go/dllm/dllm.go mode change 100755 => 100644 backend/go/dllm/dllm_test.go mode change 100755 => 100644 backend/go/dllm/gemma4_parser.go mode change 100755 => 100644 backend/go/dllm/gemma4_parser_test.go mode change 100755 => 100644 backend/go/dllm/gemma4_renderer.go mode change 100755 => 100644 backend/go/dllm/gemma4_renderer_test.go mode change 100755 => 100644 backend/go/dllm/main.go diff --git a/.agents/dllm-backend.md b/.agents/dllm-backend.md index 6cd8c10dd125..9f8586235565 100644 --- a/.agents/dllm-backend.md +++ b/.agents/dllm-backend.md @@ -112,13 +112,14 @@ carry that coverage. ## Build matrix -`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64 + arm64), and -`cuda13-nvidia-l4t-arm64-dllm` (Jetson / DGX Spark GB10), via +`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and +`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via `.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward `-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare `-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is -self-contained (ggml statically absorbed, PIC), so packaging only ships the -one .so plus the usual ldd walk. +self-contained (ggml statically absorbed, PIC), so `package.sh` only ships +the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout; +no ldd walk yet). ## Known limitations diff --git a/backend/go/dllm/Makefile b/backend/go/dllm/Makefile old mode 100755 new mode 100644 diff --git a/backend/go/dllm/capi.go b/backend/go/dllm/capi.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/dllm.go b/backend/go/dllm/dllm.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/dllm_test.go b/backend/go/dllm/dllm_test.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/gemma4_parser.go b/backend/go/dllm/gemma4_parser.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/gemma4_parser_test.go b/backend/go/dllm/gemma4_parser_test.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/gemma4_renderer.go b/backend/go/dllm/gemma4_renderer.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/gemma4_renderer_test.go b/backend/go/dllm/gemma4_renderer_test.go old mode 100755 new mode 100644 diff --git a/backend/go/dllm/main.go b/backend/go/dllm/main.go old mode 100755 new mode 100644 diff --git a/docs/content/features/text-generation.md b/docs/content/features/text-generation.md index 04b481312d32..b6232cbb5985 100644 --- a/docs/content/features/text-generation.md +++ b/docs/content/features/text-generation.md @@ -676,8 +676,8 @@ This backend is **experimental**, and the engine does not yet have a prompt-KV p | Flavor | Hardware | |---|---| | `cpu-dllm` | CPU (amd64 + arm64) - functional but very slow on the 26B model; mainly useful for wiring tests | -| `cuda13-dllm` | NVIDIA CUDA 13 (amd64 + arm64) | -| `cuda13-nvidia-l4t-arm64-dllm` | NVIDIA L4T (Jetson / DGX Spark GB10) | +| `cuda13-dllm` | NVIDIA CUDA 13 (amd64) | +| `cuda13-nvidia-l4t-arm64-dllm` | NVIDIA L4T arm64 (Jetson / DGX Spark GB10) | macOS/Metal is not available yet. From ad6d1dbc8b163101f2d3709d2c794bbaf681d2dd Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 17:50:04 +0000 Subject: [PATCH 09/13] feat(grpc): request cancellation for Go backends via the Cancellable capability The llama.cpp C++ backend aborts generation when its gRPC context is cancelled (grpc-server.cpp polls context->IsCancelled() in the result loops), but Go backends served by pkg/grpc never observed context cancellation: a disconnected client left the generation running to completion. Add an optional Cancellable capability; the server registers context.AfterFunc on the request/stream context (after the Locking block so queued requests cannot abort the current owner) covering both rich and legacy paths. dllm implements it: measured cancel latency ~10ms vs ~10s of orphaned generation, and follow-up requests no longer queue behind cancelled ones (~220ms vs ~9s in the e2e proof). Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- .agents/dllm-backend.md | 13 +- backend/go/dllm/dllm.go | 37 ++++-- backend/go/dllm/dllm_test.go | 36 ++++++ docs/content/features/text-generation.md | 1 + pkg/grpc/cancel_test.go | 158 +++++++++++++++++++++++ pkg/grpc/interface.go | 13 ++ pkg/grpc/server.go | 25 ++++ tests/e2e-backends/dllm_test.go | 83 +++++++++++- 8 files changed, 349 insertions(+), 17 deletions(-) create mode 100644 pkg/grpc/cancel_test.go diff --git a/.agents/dllm-backend.md b/.agents/dllm-backend.md index 9f8586235565..ca9516f31bd4 100644 --- a/.agents/dllm-backend.md +++ b/.agents/dllm-backend.md @@ -56,7 +56,7 @@ header comment has the full list): | PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) | | Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) | | TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) | -| Cancel | `dllm_capi_cancel`; currently INERT in practice - the gRPC server does not hand the request/stream context to backends, so client disconnects never reach it (plumbing is future work) | +| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends | `n_threads` and `ctx_len` are accepted-but-ignored by the engine at the current pin (the context bound comes from GGUF `n_ctx_train`); they are sent @@ -102,8 +102,8 @@ output is NOT gemma4-parsed (plain content, like any non-autoparsing backend). | Layer | Gate | What | |---|---|---| | `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers | -| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel | -| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming | +| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) | +| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) | | Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only | Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has @@ -123,8 +123,11 @@ no ldd walk yet). ## Known limitations -- **Cancel is unwired**: nothing calls `Dllm.Cancel` on client disconnect - until the gRPC server plumbs the request context through to backends. +- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on + every generate entry, so a Cancel racing a NEW generate can be lost, and + with requests queued on the worker it aborts whichever generate is + currently running (acceptable: the server de-registers the hook on normal + completion, one process serves one model). - **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every denoise step recomputes the full prompt+canvas. The upstream prefix-KV cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands diff --git a/backend/go/dllm/dllm.go b/backend/go/dllm/dllm.go index 17d46de2fd2b..8ba049119152 100644 --- a/backend/go/dllm/dllm.go +++ b/backend/go/dllm/dllm.go @@ -21,12 +21,18 @@ import ( "sync" "unicode/utf8" + grpc "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc/base" "github.com/mudler/LocalAI/pkg/grpc/grpcerrors" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/xlog" ) +// The gRPC server cancels in-flight generations on client disconnect only +// for backends advertising the Cancellable capability; keep Dllm pinned to +// it so a signature drift fails the build, not the disconnect path. +var _ grpc.Cancellable = (*Dllm)(nil) + // generator is the seam between the backend wiring and the dllm.cpp C-ABI: // the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON // family, while tests substitute a fake to exercise prompt construction, @@ -181,18 +187,29 @@ func (d *Dllm) Free() error { return nil } -// Cancel requests cancellation of the in-flight generate. It deliberately -// bypasses the worker queue: dllm_capi_cancel is the one call the C-ABI -// allows from any goroutine mid-generate (it only flips an atomic). +// Cancel requests cancellation of the in-flight generate (the +// grpc.Cancellable capability). The gRPC server arms it via +// context.AfterFunc on the request/stream context, so a client +// disconnect or timeout aborts the generation server-side - the same +// semantics the llama.cpp C++ backend gets from polling IsCancelled(). +// It deliberately bypasses the worker queue: dllm_capi_cancel is the one +// call the C-ABI allows from any goroutine mid-generate (it only flips +// an atomic). // -// LIMITATION: nothing invokes this on client disconnect today. The gRPC -// server (pkg/grpc/server.go) does not hand the request/stream context to -// Predict/PredictStreamRich, so a dropped HTTP client cannot reach the -// backend until that plumbing exists; the method is here so future server -// wiring (or an admin RPC) has something to call. Note dllm_capi.h's -// cancel-reset race: each generate resets the flag on entry, so a caller -// racing a new generate should re-issue Cancel. +// Note dllm_capi.h's cancel-reset race: each generate resets the flag on +// entry, so a Cancel racing a NEW generate on the same ctx can be lost +// (and, with requests queued on the worker, it aborts whichever generate +// is currently running). The single-flag granularity is acceptable here +// because the server de-registers the hook on normal completion and one +// backend process serves one model. func (d *Dllm) Cancel() { + // RLock so a server-side AfterFunc firing in the window between a + // request finishing and a model unload cannot touch a freed C ctx + // (Free holds the write lock while tearing gen down). cancel() is the + // one C call that is safe concurrently with an in-flight generate, so + // taking a read lock here cannot deadlock against request holders. + d.genMu.RLock() + defer d.genMu.RUnlock() if d.gen != nil { d.gen.cancel() } diff --git a/backend/go/dllm/dllm_test.go b/backend/go/dllm/dllm_test.go index 22ef767cd654..599e00d7d1a0 100644 --- a/backend/go/dllm/dllm_test.go +++ b/backend/go/dllm/dllm_test.go @@ -768,4 +768,40 @@ var _ = Describe("Dllm backend (real tiny model)", func() { } Expect(streamed).ToNot(BeEmpty()) }) + + It("aborts an in-flight generation promptly on Cancel", func() { + d := &Dllm{} + // eb_max_steps inflates the per-block denoise loop so the full run + // takes ~10s on the tiny fixture (vs ~40ms at engine defaults; 16 + // blocks, first block after ~0.7s) - long enough that a prompt + // post-cancel return is distinguishable from the generation simply + // finishing. + Expect(d.Load(&pb.ModelOptions{ + ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL"), + Options: []string{"eb_max_steps:256"}, + })).To(Succeed()) + DeferCleanup(func() { Expect(d.Free()).To(Succeed()) }) + + ch := make(chan *pb.Reply, 64) + errCh := make(chan error, 1) + go func() { + defer GinkgoRecover() + errCh <- d.PredictStreamRich(&pb.PredictOptions{Prompt: "hello", Tokens: 256, Seed: 7}, ch) + }() + + // Cancel only once the first block proves the generate is in + // flight: the C side resets the cancel flag on generate entry, so + // an earlier Cancel would be swallowed (dllm_capi.h race note). + Eventually(ch, "60s").Should(Receive()) + cancelAt := time.Now() + d.Cancel() + + // Uncancelled, ~10s of generation remain; the cancelled call must + // come back in milliseconds (the flag is checked per denoise step). + var genErr error + Eventually(errCh, "5s").Should(Receive(&genErr)) + latency := time.Since(cancelAt) + Expect(genErr).To(MatchError(ContainSubstring("cancelled"))) + GinkgoWriter.Printf("dllm cancel: PredictStreamRich returned %v after Cancel\n", latency) + }) }) diff --git a/docs/content/features/text-generation.md b/docs/content/features/text-generation.md index b6232cbb5985..199507ba0b96 100644 --- a/docs/content/features/text-generation.md +++ b/docs/content/features/text-generation.md @@ -670,6 +670,7 @@ This backend is **experimental**, and the engine does not yet have a prompt-KV p - [📖 Text generation (GPT)]({{%relref "features/text-generation" %}}) - [🔥 OpenAI functions]({{%relref "features/openai-functions" %}}) - tool calls are parsed natively by the backend (gemma4 `<|tool_call>` markers), not by LocalAI's grammar/regex fallback - Reasoning - opt-in thinking streams as `reasoning_content` (see below) +- Request cancellation - disconnecting the client (or a request timeout) aborts the in-flight generation server-side, so an abandoned slow run does not keep the GPU busy #### Supported platforms diff --git a/pkg/grpc/cancel_test.go b/pkg/grpc/cancel_test.go new file mode 100644 index 000000000000..7bdd3c3ac1db --- /dev/null +++ b/pkg/grpc/cancel_test.go @@ -0,0 +1,158 @@ +package grpc + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var errGenCancelled = errors.New("generation cancelled") + +// cancellableBackend implements AIModel + AIModelRich + Cancellable. Its +// rich predict paths optionally block until Cancel fires (blockUntilCancel), +// which lets the specs prove the server's context.AfterFunc plumbing: a +// cancelled request context must reach Cancel and unblock the generation. +type cancellableBackend struct { + base.SingleThread + + blockUntilCancel bool + + started chan struct{} // closed when a predict call is in flight + startOnce sync.Once + cancelled chan struct{} // closed by Cancel + cancelOnce sync.Once + cancelCalls atomic.Int32 +} + +func newCancellableBackend(blockUntilCancel bool) *cancellableBackend { + return &cancellableBackend{ + blockUntilCancel: blockUntilCancel, + started: make(chan struct{}), + cancelled: make(chan struct{}), + } +} + +func (c *cancellableBackend) Cancel() { + c.cancelCalls.Add(1) + c.cancelOnce.Do(func() { close(c.cancelled) }) +} + +func (c *cancellableBackend) run() error { + c.startOnce.Do(func() { close(c.started) }) + if !c.blockUntilCancel { + return nil + } + select { + case <-c.cancelled: + return errGenCancelled + case <-time.After(30 * time.Second): + // Backstop so a regression (Cancel never wired) fails the spec + // instead of hanging the suite. + return errors.New("cancellableBackend: Cancel never fired") + } +} + +func (c *cancellableBackend) PredictRich(*pb.PredictOptions) (*pb.Reply, error) { + if err := c.run(); err != nil { + return nil, err + } + return &pb.Reply{Message: []byte("done")}, nil +} + +func (c *cancellableBackend) PredictStreamRich(_ *pb.PredictOptions, out chan<- *pb.Reply) error { + out <- &pb.Reply{Message: []byte("first")} + return c.run() +} + +func (c *cancellableBackend) Predict(*pb.PredictOptions) (string, error) { + return "", errors.New("cancellableBackend: legacy Predict should not have been called") +} + +func (c *cancellableBackend) PredictStream(*pb.PredictOptions, chan string) error { + return errors.New("cancellableBackend: legacy PredictStream should not have been called") +} + +var _ AIModelRich = (*cancellableBackend)(nil) +var _ Cancellable = (*cancellableBackend)(nil) + +var _ = Describe("Cancellable capability", func() { + It("PredictStream: cancelling the request context fires Cancel and ends the stream with the backend's error", func() { + backend := newCancellableBackend(true) + addr := "test://cancel-stream" + Provide(addr, backend) + c := NewClient(addr, true, nil, false) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + defer GinkgoRecover() + errCh <- c.PredictStream(ctx, &pb.PredictOptions{}, func(*pb.Reply) {}) + }() + + // Only cancel once the generation is provably in flight; cancelling + // earlier would race the AfterFunc registration in the server. + Eventually(backend.started, "5s").Should(BeClosed()) + cancel() + + var err error + Eventually(errCh, "5s").Should(Receive(&err)) + Expect(err).To(MatchError(errGenCancelled)) + Expect(backend.cancelCalls.Load()).To(BeNumerically(">=", 1)) + }) + + It("Predict: cancelling the request context fires Cancel and unblocks the call", func() { + backend := newCancellableBackend(true) + addr := "test://cancel-predict" + Provide(addr, backend) + c := NewClient(addr, true, nil, false) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := c.Predict(ctx, &pb.PredictOptions{}) + errCh <- err + }() + + Eventually(backend.started, "5s").Should(BeClosed()) + cancel() + + var err error + Eventually(errCh, "5s").Should(Receive(&err)) + Expect(err).To(MatchError(errGenCancelled)) + Expect(backend.cancelCalls.Load()).To(BeNumerically(">=", 1)) + }) + + It("does not call Cancel when the request completes normally", func() { + backend := newCancellableBackend(false) + addr := "test://cancel-clean" + Provide(addr, backend) + c := NewClient(addr, true, nil, false) + + ctx, cancel := context.WithCancel(context.Background()) + + var replies []*pb.Reply + err := c.PredictStream(ctx, &pb.PredictOptions{}, func(r *pb.Reply) { + replies = append(replies, r) + }) + Expect(err).ToNot(HaveOccurred()) + Expect(replies).To(HaveLen(1)) + + // Cancelling AFTER completion must not reach the backend: the + // deferred AfterFunc stop de-registered the hook, so a shared or + // reused context cannot abort someone else's later generation. + cancel() + Consistently(backend.cancelCalls.Load, "200ms").Should(BeZero()) + }) +}) diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 31b9ab26deb6..6647b5ec3de9 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -72,6 +72,19 @@ func newReply(s string) *pb.Reply { return &pb.Reply{Message: []byte(s)} } +// Cancellable is an optional capability: backends that can abort an +// in-flight generation implement it. The server calls Cancel when the +// request's gRPC context is cancelled (client disconnect/timeout), +// giving Go backends the same semantics the llama.cpp C++ backend gets +// from polling context->IsCancelled() in its result loops. +// +// Cancel may be invoked from an arbitrary goroutine while the +// generation is running, so implementations must make it safe to call +// concurrently with Predict/PredictStream (and their rich variants). +type Cancellable interface { + Cancel() +} + // AIModelRich is an optional extension to AIModel for backends that // can produce a full *pb.Reply — including tool-call deltas and // usage tokens — rather than just a content string. The gRPC server diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 5be668497b77..361d4107eddc 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -63,11 +63,32 @@ func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result return &pb.Result{Message: "Loading succeeded", Success: true}, nil } +// cancelOnDone arms the optional Cancellable capability: when ctx is +// cancelled (client disconnect/timeout) the backend's Cancel fires so it +// can abort the in-flight generation - the Go-backend equivalent of the +// llama.cpp C++ server polling context->IsCancelled() in its result loops. +// Callers MUST defer the returned stop so a normally-completed request +// de-registers the hook before returning; otherwise a later cancellation +// of the same ctx would abort an unrelated in-flight generation. +// +// Arm it AFTER the Locking() block: for serialized backends a request +// queued on the lock is not generating yet, and cancelling it must not +// abort whichever request currently owns the backend. +func (s *server) cancelOnDone(ctx context.Context) (stop func() bool) { + if c, ok := s.llm.(Cancellable); ok { + return context.AfterFunc(ctx, c.Cancel) + } + return func() bool { return false } +} + func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { if s.llm.Locking() { s.llm.Lock() defer s.llm.Unlock() } + // One registration covers both the rich and the legacy branch below. + stop := s.cancelOnDone(ctx) + defer stop() if rich, ok := s.llm.(AIModelRich); ok { return rich.PredictRich(in) } @@ -275,6 +296,10 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS defer s.llm.Unlock() } + // One registration covers both the rich and the legacy branch below. + stop := s.cancelOnDone(stream.Context()) + defer stop() + if rich, ok := s.llm.(AIModelRich); ok { replyChan := make(chan *pb.Reply) done := make(chan bool) diff --git a/tests/e2e-backends/dllm_test.go b/tests/e2e-backends/dllm_test.go index dd763603ce0d..5c7766c98d92 100644 --- a/tests/e2e-backends/dllm_test.go +++ b/tests/e2e-backends/dllm_test.go @@ -61,8 +61,9 @@ import ( // DeferCleanup the moment each resource exists, so a failure anywhere in // setup (port-wait timeout, dial error, LoadModel failure) still reaps the // spawned server - critical for the real-model spec, where a failed load -// would otherwise leak a ~50GB process. -func startDllmBackend(modelFile string, gpuLayers int32) pb.BackendClient { +// would otherwise leak a ~50GB process. options are extra ModelOptions +// "key:value" entries (eb_* sampler knobs etc.). +func startDllmBackend(modelFile string, gpuLayers int32, options ...string) pb.BackendClient { GinkgoHelper() binary := os.Getenv("BACKEND_BINARY") @@ -116,6 +117,7 @@ func startDllmBackend(modelFile string, gpuLayers int32) pb.BackendClient { ContextSize: envInt32("BACKEND_TEST_CTX_SIZE", 512), Threads: envInt32("BACKEND_TEST_THREADS", 4), NGPULayers: gpuLayers, + Options: options, }) Expect(err).NotTo(HaveOccurred()) Expect(res.GetSuccess()).To(BeTrue(), "dllm LoadModel failed: %s", res.GetMessage()) @@ -201,6 +203,83 @@ var _ = Describe("dllm templated chat-completion (tiny model)", Ordered, func() }) }) +var _ = Describe("dllm request cancellation (tiny model)", Ordered, func() { + var client pb.BackendClient + + BeforeAll(func() { + if os.Getenv("BACKEND_TEST_DLLM") != "1" { + Skip("dllm cancellation spec is opt-in; set BACKEND_TEST_DLLM=1 (plus BACKEND_BINARY and BACKEND_TEST_MODEL_FILE) to run it") + } + modelFile := os.Getenv("BACKEND_TEST_MODEL_FILE") + Expect(modelFile).NotTo(BeEmpty(), + "dllm cancellation spec requires BACKEND_TEST_MODEL_FILE (dllm.cpp's tests/fixtures/tiny_with_vocab.gguf)") + // eb_max_steps inflates the per-block denoise loop: a 256-token run + // takes ~10s on the tiny fixture (vs ~40ms at engine defaults), so a + // cancelled request is clearly distinguishable from one that simply + // finished. A dedicated backend process keeps the chat specs fast. + client = startDllmBackend(modelFile, 0, "eb_max_steps:256") + }) + + // This is the end-to-end proof of the Cancellable plumbing + // (pkg/grpc/server.go arming backend.Cancel via context.AfterFunc on + // the stream context): a client disconnect mid-stream must abort the + // server-side generation, not just orphan it. + It("aborts the in-flight generation when the client context is cancelled mid-stream", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Raw-prompt mode, not the templated chat request: the templated + // render can hit an end-of-turn token after the first block and + // finish before the cancel lands, which would silently turn this + // into a no-op spec. The raw "hello" run is probed deterministic + // with seed 7: 16 blocks, the eb_max_steps cap hit on every one, + // ~10s total if left to finish. + req := &pb.PredictOptions{Prompt: "hello", Tokens: 256, Seed: 7} + + stream, err := client.PredictStream(ctx, req) + Expect(err).NotTo(HaveOccurred()) + + // First chunk received = the generate is provably in flight (the C + // side resets the cancel flag on generate entry, so cancelling + // before it starts would be swallowed). + _, err = stream.Recv() + Expect(err).NotTo(HaveOccurred()) + cancel() + + // Client side: the stream must end promptly, not after the + // remaining ~9s of generation (the first chunk arrives after one + // ~0.7s block, so plenty of generation is provably outstanding). + recvDone := make(chan error, 1) + go func() { + defer GinkgoRecover() + for { + if _, rerr := stream.Recv(); rerr != nil { + recvDone <- rerr + return + } + } + }() + var rerr error + Eventually(recvDone, "5s").Should(Receive(&rerr)) + Expect(rerr).NotTo(Equal(io.EOF), "stream completed normally despite the cancelled context") + + // Server side: prove the generation actually aborted. dllm + // serializes every C call through one worker goroutine, so if the + // orphaned generation were still grinding, this follow-up would + // queue behind its remaining ~9s instead of completing in ~1s + // (16 tokens = one block at eb_max_steps:256). + followCtx, followCancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer followCancel() + start := time.Now() + res, err := client.Predict(followCtx, dllmChatRequest()) + elapsed := time.Since(start) + Expect(err).NotTo(HaveOccurred()) + Expect(string(res.GetMessage())).NotTo(BeEmpty()) + Expect(elapsed).To(BeNumerically("<", 5*time.Second), + "follow-up request queued behind the cancelled generation - server-side Cancel did not reach the backend") + GinkgoWriter.Printf("dllm cancel e2e: follow-up completed in %v after mid-stream cancellation\n", elapsed) + }) +}) + var _ = Describe("dllm templated chat-completion (real model)", Ordered, func() { var client pb.BackendClient From 8134d6db374a6f6fa6b6ae5784244159e5bec54d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 19:22:02 +0000 Subject: [PATCH 10/13] docs(dllm): record Q4_K_M validation and quantization guidance Q4_K_M validated on GB10: quality holds (cosine 0.9862, coherent generation, 19/48 stopper exit) but a forward step is ~5x slower than BF16 (27.5s vs 5.6s: native BF16 tensor cores vs K-quant MoE dequant). Guidance: prefer BF16 when it fits; Q4_K_M is the memory-bound option. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- docs/content/features/text-generation.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/content/features/text-generation.md b/docs/content/features/text-generation.md index 199507ba0b96..da1f25cd205b 100644 --- a/docs/content/features/text-generation.md +++ b/docs/content/features/text-generation.md @@ -768,6 +768,8 @@ Honest numbers from validation on a DGX Spark (GB10, CUDA 13, BF16 26B model, fu On CPU the same forward step takes ~139 s (20 Grace cores): treat the CPU flavor as functional, not practical, for the 26B model. +**Quantized models.** The Q4_K_M export (16.8 GB vs 50.5 GB BF16) was validated on the same GB10: it loads faster (~12.6 s vs ~32.7 s), quality held up in validation (golden-logits cosine 0.9862, coherent generation on the same prompt as the BF16 run, EB stopper exiting at 19/48 steps, ~0.49 tok/s on that run) - but a forward step takes ~27.5 s, about **5x slower than BF16** (~5.6 s/step) on this hardware. GB10-class GPUs run BF16 natively on tensor cores, while the K-quant MoE weights pay a dequantization cost on every denoise step. Choose Q4_K_M only when you are memory-bound; if BF16 fits, it is both faster and the file the engine's validation tolerances are calibrated for. + #### Reference - [dllm.cpp](https://github.com/mudler/dllm.cpp) From c9c6040fe8d24cd747ca4eeba039ab2b10ccfcf4 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 20:24:26 +0000 Subject: [PATCH 11/13] feat(dllm): default gallery entry on Q4_K_M; add Q8_0 variant Q4_K_M (~17 GB, GB10-validated: cosine 0.9862, coherent generation) is the friendlier default download than the 50 GB BF16; Q8_0 (~27 GB) is the higher-fidelity middle ground. Both descriptions carry the measured caveat that BF16 is ~5x faster per denoise step on BF16-native hardware, with a pointer to fetch it manually when it fits. sha256 values are the HF LFS oids. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- gallery/index.yaml | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/gallery/index.yaml b/gallery/index.yaml index 8c24bbd60ee3..28326f12720b 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -13,8 +13,11 @@ Honest expectations: * Experimental: both the model family and the dllm backend are young - expect rough edges. - * BF16 weights (~50 GB): CUDA-13-class hardware (DGX Spark / large-VRAM or - unified-memory machines) is recommended; CPU works but is slow. + * Q4_K_M weights (~17 GB): the memory-friendly default. Validated on + GB10 (quality holds: golden cosine 0.9862, coherent generation), but + note that on hardware with native BF16 tensor cores (GB10-class) the + BF16 file is ~5x FASTER per denoise step than K-quants - if ~50 GB + fits, fetch diffusiongemma-26B-A4B-it-BF16.gguf manually instead. * Throughput: every denoise step currently recomputes the full prompt+canvas - the prefix-KV cache that removes this lands with the dllm backend's P3 work - so long prompts cost proportionally more per @@ -31,11 +34,36 @@ - dllm overrides: parameters: - model: dllm/diffusiongemma-26B-A4B-it-BF16.gguf + model: dllm/diffusiongemma-26B-A4B-it-Q4_K_M.gguf files: - - filename: dllm/diffusiongemma-26B-A4B-it-BF16.gguf - sha256: b0ef5dbf246608953ee9945fb03c6056af9e2459799fb179651a20a8bbaa2921 - uri: https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF/resolve/main/diffusiongemma-26B-A4B-it-BF16.gguf + - filename: dllm/diffusiongemma-26B-A4B-it-Q4_K_M.gguf + sha256: d2ca2c032ebfb23cf2d1794a3465e615c7545634d46b3c30652a26d8b07c4ad3 + uri: https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF/resolve/main/diffusiongemma-26B-A4B-it-Q4_K_M.gguf +- name: "diffusiongemma-26b-a4b-it-q8_0" + url: "github:mudler/LocalAI/gallery/diffusiongemma.yaml@master" + urls: + - https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF + - https://github.com/mudler/dllm.cpp + description: | + DiffusionGemma 26B A4B (instruction-tuned), Q8_0 quantization (~27 GB): + the higher-fidelity middle ground between Q4_K_M (~17 GB) and BF16 + (~50 GB). Served by LocalAI's dllm backend (dllm.cpp); see the Q4_K_M + entry for the full notes on experimental status, throughput, and the + K-quant-vs-BF16 speed trade-off on BF16-native hardware. + license: apache-2.0 + tags: + - llm + - gguf + - gemma + - diffusion + - dllm + overrides: + parameters: + model: dllm/diffusiongemma-26B-A4B-it-Q8_0.gguf + files: + - filename: dllm/diffusiongemma-26B-A4B-it-Q8_0.gguf + sha256: fa5180660b80d52aae94ed814a6183af303841d8bb425a27f13ea27400a7b430 + uri: https://huggingface.co/unsloth/diffusiongemma-26B-A4B-it-GGUF/resolve/main/diffusiongemma-26B-A4B-it-Q8_0.gguf - name: "gemma-4-26b-a4b-it-qat" url: "github:mudler/LocalAI/gallery/virtual.yaml@master" urls: From b40843cf62d3c9a2f381943377a9ef0a5a417761 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 12 Jun 2026 00:41:04 +0000 Subject: [PATCH 12/13] feat(dllm): image input through the backend (multimodal C-ABI) Routes PredictOptions.Images (raw base64, the core convention) through dllm.cpp's probed multimodal entry points as data: URIs; the gemma4 renderer appends one engine-side marker per image after the last user message (llama.cpp attachment convention; the template's content-parts branch is unreachable through the flattened pb shape). The engine expands markers to boi + soft*n + eoi and splices the vision-tower embeddings. Older libdllm.so without the mm symbols fails with an actionable error (Dlsym probe). DLLM_VERSION pin bumped to the engine's vision-capable commit. Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- backend/go/dllm/Makefile | 6 +- backend/go/dllm/capi.go | 86 ++++++- backend/go/dllm/dllm.go | 101 ++++++-- backend/go/dllm/dllm_test.go | 293 +++++++++++++++++++++++- backend/go/dllm/gemma4_renderer.go | 40 +++- backend/go/dllm/gemma4_renderer_test.go | 81 ++++++- backend/go/dllm/main.go | 19 +- tests/e2e-backends/dllm_test.go | 79 +++++++ 8 files changed, 662 insertions(+), 43 deletions(-) diff --git a/backend/go/dllm/Makefile b/backend/go/dllm/Makefile index 1e0825c73644..9a0f0aeae84c 100644 --- a/backend/go/dllm/Makefile +++ b/backend/go/dllm/Makefile @@ -19,7 +19,11 @@ # until then the anonymous clone below fails. Use the symlink shortcut above # with a local checkout, or a git credential helper with access to the repo. -DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7 +# The pin below is the first commit carrying the multimodal C-ABI entry +# points (dllm_capi_generate_mm / dllm_capi_generate_stream_mm) the +# image-input path probes for; older libs still load, but image requests +# then fail with "library predates the multimodal entry points". +DLLM_VERSION?=e6dcf44cddd65845e3a0814a1c2282a5d90ee98a DLLM_REPO?=https://github.com/mudler/dllm.cpp GOCMD?=go diff --git a/backend/go/dllm/capi.go b/backend/go/dllm/capi.go index 088bb6f26c3b..f389e0721a00 100644 --- a/backend/go/dllm/capi.go +++ b/backend/go/dllm/capi.go @@ -16,6 +16,7 @@ package main import ( "encoding/json" + "errors" "fmt" "sync" "sync/atomic" @@ -45,6 +46,34 @@ var ( cppCancel func(ctx uintptr) ) +// Optional multimodal entry points (dllm_capi.h's P4 surface). The ABI +// version stays 1: presence is detected by PROBING the symbols with Dlsym at +// boot (loadCAPI, mirroring the parakeet-cpp optional-symbol pattern). nil +// means the loaded libdllm.so predates the mm surface; the wrappers below +// then fail with errMMUnsupported instead of crashing on a nil call. +var ( + cppGenerateMM func(ctx uintptr, prompt, imagesJSON, optsJSON string) uintptr + cppGenerateStreamMM func(ctx uintptr, prompt, imagesJSON, optsJSON string, onBlock, onStep, userData uintptr) int32 +) + +// mmImageMarker is the literal placeholder dllm_capi_generate_mm expands to +// + soft-token placeholders + (dllm_capi.h placeholder contract; +// capi.cpp MM_MARKER). The prompt must carry exactly one marker per +// images_json entry, in image order. +const mmImageMarker = "" + +// errMMUnsupported is returned for image-bearing requests against an old +// text-only libdllm.so (the Dlsym probe found no mm symbols). +var errMMUnsupported = errors.New( + "dllm: image input requires libdllm.so with the multimodal entry points (dllm_capi_generate_mm), but the loaded library predates them - rebuild/upgrade the dllm backend to use images") + +// cMMSupported reports whether the loaded libdllm.so carries the multimodal +// generate pair. Both symbols ship together (same dllm.cpp commit), but the +// guard requires both anyway so a half-present surface can never dispatch. +func cMMSupported() bool { + return cppGenerateMM != nil && cppGenerateStreamMM != nil +} + // cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION. func cAbiVersion() int32 { return cppAbiVersion() @@ -108,6 +137,23 @@ func cGenerate(h uintptr, prompt, optsJSON string) (string, error) { return out, nil } +// cGenerateMM is cGenerate's multimodal counterpart. imagesJSON is the flat +// JSON array of image entries (data: base64 URIs here; the C side also takes +// file paths) and the prompt must carry one mmImageMarker per entry - the +// engine enforces the 1:1 match and reports mismatches through last_error. +func cGenerateMM(h uintptr, prompt, imagesJSON, optsJSON string) (string, error) { + if !cMMSupported() { + return "", errMMUnsupported + } + ret := cppGenerateMM(h, prompt, imagesJSON, optsJSON) + if ret == 0 { + return "", fmt.Errorf("dllm: generate_mm failed: %s", lastErrorOr(h, "unknown error")) + } + out := goStringFromCPtr(ret) + cppFreeString(ret) + return out, nil +} + // streamCallState carries the Go callbacks for one in-flight // cGenerateStream call; the registry key travels through C as user_data. // The map shape mirrors the whisper backend's streamCallStates: only one @@ -158,11 +204,12 @@ func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userD } } -// cGenerateStream runs a generation with per-committed-block (onBlock) and -// per-denoising-step (onStep) callbacks; either may be nil. The callbacks -// run on the C thread (see the trampoline docs). Returns an error carrying -// last_error on failure; cancellation surfaces as the "cancelled" message. -func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error { +// withStreamCallbacks registers onBlock/onStep in the trampoline registry +// for the duration of one streaming C call and invokes call with the C +// function pointers (NULL for absent callbacks, so the C side skips the +// per-block / per-step detokenize work entirely) plus the registry key to +// pass as user_data. Shared by the text and multimodal stream wrappers. +func withStreamCallbacks(onBlock func(text string), onStep func(step, total int, preview string), call func(blockPtr, stepPtr, userData uintptr) int32) int32 { streamCbOnce.Do(func() { blockCbPtr = purego.NewCallback(onBlockTrampoline) stepCbPtr = purego.NewCallback(onStepTrampoline) @@ -172,8 +219,6 @@ func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text strin streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep}) defer streamCallStates.Delete(id) - // Pass NULL for absent callbacks so the C side skips the per-block / - // per-step detokenize work entirely. var blockPtr, stepPtr uintptr if onBlock != nil { blockPtr = blockCbPtr @@ -181,13 +226,38 @@ func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text strin if onStep != nil { stepPtr = stepCbPtr } + return call(blockPtr, stepPtr, uintptr(id)) +} - if rc := cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, uintptr(id)); rc != 0 { +// cGenerateStream runs a generation with per-committed-block (onBlock) and +// per-denoising-step (onStep) callbacks; either may be nil. The callbacks +// run on the C thread (see the trampoline docs). Returns an error carrying +// last_error on failure; cancellation surfaces as the "cancelled" message. +func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error { + rc := withStreamCallbacks(onBlock, onStep, func(blockPtr, stepPtr, userData uintptr) int32 { + return cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, userData) + }) + if rc != 0 { return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error")) } return nil } +// cGenerateStreamMM is cGenerateStream's multimodal counterpart; see +// cGenerateMM for the imagesJSON/marker contract. +func cGenerateStreamMM(h uintptr, prompt, imagesJSON, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error { + if !cMMSupported() { + return errMMUnsupported + } + rc := withStreamCallbacks(onBlock, onStep, func(blockPtr, stepPtr, userData uintptr) int32 { + return cppGenerateStreamMM(h, prompt, imagesJSON, optsJSON, blockPtr, stepPtr, userData) + }) + if rc != 0 { + return fmt.Errorf("dllm: generate_stream_mm failed: %s", lastErrorOr(h, "unknown error")) + } + return nil +} + // cCancel requests cancellation of the in-flight generate on h. This is the // ONE entry point safe to call from any goroutine while a generate runs (it // only flips an atomic). Note the cancel-reset race from the header: each diff --git a/backend/go/dllm/dllm.go b/backend/go/dllm/dllm.go index 8ba049119152..394aa106ba3d 100644 --- a/backend/go/dllm/dllm.go +++ b/backend/go/dllm/dllm.go @@ -42,6 +42,13 @@ type generator interface { // generateStream invokes onBlock once per committed diffusion block, on // the thread running the C call, before returning. generateStream(prompt, optsJSON string, onBlock func(text string)) error + // generateMM / generateStreamMM are the multimodal counterparts: + // imagesJSON is a flat JSON array of data: base64 URIs and the prompt + // carries one mmImageMarker per entry (dllm_capi.h placeholder + // contract). Against an old text-only libdllm.so they fail with + // errMMUnsupported. + generateMM(prompt, imagesJSON, optsJSON string) (string, error) + generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error tokenizeJSON(text string) (string, error) // cancel is the ONE entry point safe to call concurrently with an // in-flight generate on the same ctx (dllm_capi.h: it only flips an @@ -66,6 +73,15 @@ func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(tex return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil) } +func (g *capiGenerator) generateMM(prompt, imagesJSON, optsJSON string) (string, error) { + return cGenerateMM(g.h, prompt, imagesJSON, optsJSON) +} + +func (g *capiGenerator) generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error { + // on_step is nil for the same reason as generateStream. + return cGenerateStreamMM(g.h, prompt, imagesJSON, optsJSON, onBlock, nil) +} + func (g *capiGenerator) tokenizeJSON(text string) (string, error) { return cTokenizeJSON(g.h, text) } @@ -267,20 +283,55 @@ func metadataEnableThinking(opts *pb.PredictOptions) bool { } // buildPrompt resolves the prompt for a request. With use_tokenizer_template -// and raw messages the backend owns templating (RenderGemma4) and the output -// is in the known gemma4 format, so parse=true. Without it the caller -// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or -// a bare completion): the prompt passes through verbatim and the output is -// NOT gemma4-parsed - it is emitted as plain content and the Go side's -// extraction applies, as for any non-autoparsing backend. +// and raw messages the backend owns templating (RenderGemma4, including the +// mmImageMarker injection for opts.Images) and the output is in the known +// gemma4 format, so parse=true. Without it the caller templated the prompt +// themselves (LocalAI's Go templates + PEG fallback, or a bare completion): +// the prompt passes through verbatim - for image requests it must already +// carry one literal mmImageMarker per image (the engine enforces the 1:1 +// match) - and the output is NOT gemma4-parsed - it is emitted as plain +// content and the Go side's extraction applies, as for any non-autoparsing +// backend. func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) { if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 { - prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true) + prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), len(opts.GetImages()), metadataEnableThinking(opts), true) return prompt, true, err } return opts.GetPrompt(), false, nil } +// imagesJSON renders opts.Images as the flat JSON array of data: URIs the mm +// C-ABI expects, or "" when the request carries no images. The entries arrive +// as RAW base64 payloads: LocalAI's OpenAI layer decodes every image_url / +// image content part (URL download or data: URI) to plain base64 via +// utils.GetContentURIAsBase64 (core/http/middleware/request.go) and core +// flattens them into PredictOptions.Images (core/backend/llm.go). The +// hardcoded image/jpeg mime mirrors the llama.cpp backend's re-wrapping +// convention (grpc-server.cpp, "data:image/jpeg;base64," + images(i)); the +// engine ignores the declared mime and sniffs the real format from the +// decoded bytes (stb_image), so PNG/BMP payloads work through it too. +func imagesJSON(images []string) (string, error) { + if len(images) == 0 { + return "", nil + } + uris := make([]string, len(images)) + for i, img := range images { + // dllm_capi.h: array entries are read VERBATIM up to the closing + // quote, with NO escape handling. json.Marshal would escape these + // bytes and the C side would misparse the entry, so fail loud (they + // can never appear in genuine base64 anyway). + if strings.ContainsAny(img, "\"\\") { + return "", fmt.Errorf("dllm: image %d is not base64 (contains a quote or backslash; PredictOptions.Images entries must be raw base64 payloads)", i) + } + uris[i] = "data:image/jpeg;base64," + img + } + b, err := json.Marshal(uris) + if err != nil { + return "", fmt.Errorf("dllm: marshal images: %w", err) + } + return string(b), nil +} + // requestOptsJSON merges the model-level overrides with the request's // sampling fields into the flat opts JSON for one generate call. func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) { @@ -307,17 +358,27 @@ func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) { // prepareRequest is the shared prologue of the rich methods: resolve the // prompt (and whether the output gets gemma4-parsed) and build the per-call -// opts JSON. -func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) { +// opts JSON plus the images JSON ("" for text-only requests, which routes +// the call through the text generate entry points). +func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON, imgJSON string, err error) { + // Fail loud on media the engine has no path for, instead of silently + // generating from a prompt that ignores them. + if len(opts.GetVideos()) > 0 || len(opts.GetAudios()) > 0 { + return "", false, "", "", errors.New("dllm: video/audio input is not supported (images only)") + } prompt, parse, err = buildPrompt(opts) if err != nil { - return "", false, "", err + return "", false, "", "", err } optsJSON, err = d.requestOptsJSON(opts) if err != nil { - return "", false, "", err + return "", false, "", "", err } - return prompt, parse, optsJSON, nil + imgJSON, err = imagesJSON(opts.GetImages()) + if err != nil { + return "", false, "", "", err + } + return prompt, parse, optsJSON, imgJSON, nil } // sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary @@ -386,7 +447,7 @@ func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) { if d.gen == nil { return nil, grpcerrors.ModelNotLoaded("dllm") } - prompt, parse, optsJSON, err := d.prepareRequest(opts) + prompt, parse, optsJSON, imgJSON, err := d.prepareRequest(opts) if err != nil { return nil, err } @@ -394,7 +455,11 @@ func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) { var out string var genErr error d.submit(func() { - out, genErr = d.gen.generate(prompt, optsJSON) + if imgJSON != "" { + out, genErr = d.gen.generateMM(prompt, imgJSON, optsJSON) + } else { + out, genErr = d.gen.generate(prompt, optsJSON) + } }) if genErr != nil { return nil, genErr @@ -429,7 +494,7 @@ func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Rep if d.gen == nil { return grpcerrors.ModelNotLoaded("dllm") } - prompt, parse, optsJSON, err := d.prepareRequest(opts) + prompt, parse, optsJSON, imgJSON, err := d.prepareRequest(opts) if err != nil { return err } @@ -467,7 +532,11 @@ func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Rep var genErr error d.submit(func() { - genErr = d.gen.generateStream(prompt, optsJSON, onBlock) + if imgJSON != "" { + genErr = d.gen.generateStreamMM(prompt, imgJSON, optsJSON, onBlock) + } else { + genErr = d.gen.generateStream(prompt, optsJSON, onBlock) + } }) if genErr != nil { return genErr diff --git a/backend/go/dllm/dllm_test.go b/backend/go/dllm/dllm_test.go index 599e00d7d1a0..ab6259ef783d 100644 --- a/backend/go/dllm/dllm_test.go +++ b/backend/go/dllm/dllm_test.go @@ -1,6 +1,7 @@ package main import ( + "encoding/base64" "errors" "os" "runtime" @@ -205,6 +206,9 @@ var _ = Describe("goStringFromCPtr", func() { type fakeGenCall struct { prompt string optsJSON string + // imagesJSON is set only by the multimodal entry points; "" means the + // call went through the text path. + imagesJSON string } // fakeGen implements generator in-process. It records every call (prompt + @@ -224,9 +228,13 @@ type fakeGen struct { } func (f *fakeGen) begin(prompt, optsJSON string) { + f.beginMM(prompt, "", optsJSON) +} + +func (f *fakeGen) beginMM(prompt, imagesJSON, optsJSON string) { f.mu.Lock() defer f.mu.Unlock() - f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON}) + f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON, imagesJSON: imagesJSON}) f.inFlight++ if f.inFlight > f.maxInFlight { f.maxInFlight = f.inFlight @@ -266,6 +274,27 @@ func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text stri return nil } +func (f *fakeGen) generateMM(prompt, imagesJSON, optsJSON string) (string, error) { + f.beginMM(prompt, imagesJSON, optsJSON) + defer f.end() + if f.delay > 0 { + time.Sleep(f.delay) + } + return f.out, f.err +} + +func (f *fakeGen) generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error { + f.beginMM(prompt, imagesJSON, optsJSON) + defer f.end() + if f.err != nil { + return f.err + } + for _, b := range f.blocks { + onBlock(b) + } + return nil +} + func (f *fakeGen) tokenizeJSON(text string) (string, error) { f.begin(text, "") defer f.end() @@ -637,6 +666,164 @@ var _ = Describe("Dllm backend wiring", func() { }) }) + Describe("image input routing", func() { + // "QUJD" is base64("ABC"); core delivers raw base64 payloads in + // PredictOptions.Images (the data: prefix is stripped by the OpenAI + // layer), and the backend re-wraps them as data: URIs for the mm + // C-ABI. + const imgB64 = "QUJD" + const imgURI = "data:image/jpeg;base64," + imgB64 + + It("routes PredictRich through generateMM with data-URI images and a marker-bearing prompt", func() { + fake := &fakeGen{out: "a cat"} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "What is this?"}}, + Images: []string{imgB64}, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(string(reply.GetMessage())).To(Equal("a cat")) + + calls, _ := fake.snapshot() + Expect(calls).To(HaveLen(1)) + Expect(calls[0].imagesJSON).To(MatchJSON(`["` + imgURI + `"]`)) + // One engine marker per image, injected on the user turn by the + // renderer (the engine enforces the 1:1 marker/image match). + Expect(calls[0].prompt).To(Equal( + "<|turn>user\nWhat is this?\n<|turn>model\n<|channel>thought\n")) + }) + + It("routes PredictStreamRich through generateStreamMM with the same images JSON", func() { + fake := &fakeGen{blocks: []string{"a dog"}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "And this?"}}, + Images: []string{imgB64}, + }, ch) + Expect(err).ToNot(HaveOccurred()) + + var content string + for _, r := range drainReplies(ch) { + content += string(r.GetMessage()) + } + Expect(content).To(Equal("a dog")) + + calls, _ := fake.snapshot() + Expect(calls).To(HaveLen(1)) + Expect(calls[0].imagesJSON).To(MatchJSON(`["` + imgURI + `"]`)) + Expect(calls[0].prompt).To(ContainSubstring("And this?")) + }) + + It("keeps image order: one data-URI entry per image, one marker each", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "Compare."}}, + Images: []string{"QQ==", "Qg=="}, // base64("A"), base64("B") + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].imagesJSON).To(MatchJSON( + `["data:image/jpeg;base64,QQ==","data:image/jpeg;base64,Qg=="]`)) + Expect(calls[0].prompt).To(ContainSubstring("Compare.")) + }) + + It("keeps text-only requests on the text entry points (old libs stay usable)", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].imagesJSON).To(BeEmpty(), "text-only request must not dispatch to the mm entry points") + Expect(calls[0].prompt).ToNot(ContainSubstring(mmImageMarker)) + }) + + It("routes raw-prompt (non-templated) image requests through generateMM verbatim", func() { + // Without use_tokenizer_template the caller owns marker placement; + // the backend must not inject anything, just forward the images. + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{ + Prompt: "look: here", + Images: []string{imgB64}, + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].prompt).To(Equal("look: here")) + Expect(calls[0].imagesJSON).To(MatchJSON(`["` + imgURI + `"]`)) + }) + + It("rejects video and audio inputs loudly", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Videos: []string{"vvv"}}) + Expect(err).To(MatchError(ContainSubstring("not supported"))) + + ch := make(chan *pb.Reply, 1) + err = d.PredictStreamRich(&pb.PredictOptions{Prompt: "p", Audios: []string{"aaa"}}, ch) + Expect(err).To(MatchError(ContainSubstring("not supported"))) + + calls, _ := fake.snapshot() + Expect(calls).To(BeEmpty(), "unsupported media must be rejected before any generate call") + }) + + It("fails with a clear error against a libdllm.so without the mm entry points", func() { + // Simulate the old-library probe outcome regardless of whether the + // gated specs loaded a real (mm-capable) libdllm.so first. + oldGen, oldStream := cppGenerateMM, cppGenerateStreamMM + cppGenerateMM, cppGenerateStreamMM = nil, nil + DeferCleanup(func() { cppGenerateMM, cppGenerateStreamMM = oldGen, oldStream }) + + g := &capiGenerator{h: 0} + _, err := g.generateMM("p", `["data:image/png;base64,QQ=="]`, "{}") + Expect(err).To(MatchError(errMMUnsupported)) + err = g.generateStreamMM("p", `["data:image/png;base64,QQ=="]`, "{}", func(string) {}) + Expect(err).To(MatchError(errMMUnsupported)) + // The message must tell the operator what to do, not just fail. + Expect(errMMUnsupported.Error()).To(ContainSubstring("rebuild/upgrade")) + }) + }) + + Describe("imagesJSON", func() { + It("returns empty for no images (text-path sentinel)", func() { + out, err := imagesJSON(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(BeEmpty()) + }) + + It("wraps raw base64 payloads as data: URIs", func() { + out, err := imagesJSON([]string{"QQ==", "Qg=="}) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(MatchJSON(`["data:image/jpeg;base64,QQ==","data:image/jpeg;base64,Qg=="]`)) + }) + + It("rejects entries that cannot survive the C side's verbatim (no-escape) parser", func() { + // dllm_capi.h: entries are read verbatim up to the closing quote; + // a quote or backslash would be JSON-escaped here and misparsed + // there, so fail loud instead. + _, err := imagesJSON([]string{`with"quote`}) + Expect(err).To(MatchError(ContainSubstring("not base64"))) + _, err = imagesJSON([]string{`with\backslash`}) + Expect(err).To(MatchError(ContainSubstring("not base64"))) + }) + }) + Describe("legacy Predict/PredictStream adapters", func() { It("Predict returns the aggregated content string", func() { fake := &fakeGen{out: "plain text"} @@ -805,3 +992,107 @@ var _ = Describe("Dllm backend (real tiny model)", func() { GinkgoWriter.Printf("dllm cancel: PredictStreamRich returned %v after Cancel\n", latency) }) }) + +// --------------------------------------------------------------------------- +// Gated multimodal round-trip against the real libdllm.so + the tiny VISION +// GGUF fixture (dllm.cpp tests/fixtures/tiny_vision_with_vocab.gguf: random +// weights, the same handcrafted vocab as tiny_with_vocab.gguf, plus a tiny +// vision tower). Additional gates on top of the text suite: +// +// DLLM_TEST_TINY_MODEL must point at tiny_vision_with_vocab.gguf +// DLLM_TEST_IMAGE a decodable image fixture +// (dllm.cpp tests/fixtures/test_image_24x17.bmp) +// --------------------------------------------------------------------------- + +var _ = Describe("Dllm backend (real tiny vision model)", func() { + var imageB64 string + + BeforeEach(func() { + if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" || os.Getenv("DLLM_TEST_IMAGE") == "" { + Skip("set DLLM_TEST_LIBRARY, DLLM_TEST_TINY_MODEL (tiny_vision_with_vocab.gguf) and DLLM_TEST_IMAGE to run the vision round-trip") + } + ensureLibLoaded() + Expect(libLoadErr).ToNot(HaveOccurred()) + Expect(cMMSupported()).To(BeTrue(), "this libdllm.so lacks the mm entry points; rebuild dllm.cpp") + + // Deliver the image exactly as LocalAI core does: a raw base64 + // payload in PredictOptions.Images (no data: prefix). + raw, err := os.ReadFile(os.Getenv("DLLM_TEST_IMAGE")) + Expect(err).ToNot(HaveOccurred()) + imageB64 = base64.StdEncoding.EncodeToString(raw) + }) + + // loadVisionDllm loads the tiny vision fixture with eb_max_steps:4 (the + // tiny tower still resizes every image to the full 280-soft-token patch + // budget, so capping the denoise loop keeps the prefill-heavy mm runs + // fast - same trick as dllm.cpp's own test_capi_dlopen mm section). + loadVisionDllm := func() *Dllm { + d := &Dllm{} + Expect(d.Load(&pb.ModelOptions{ + ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL"), + Options: []string{"eb_max_steps:4"}, + })).To(Succeed()) + DeferCleanup(func() { Expect(d.Free()).To(Succeed()) }) + return d + } + + It("answers a templated image request deterministically and streams it", func() { + d := loadVisionDllm() + + req := func() *pb.PredictOptions { + return &pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hello"}}, + Images: []string{imageB64}, + Tokens: 16, + Seed: 7, + } + } + + // Non-streaming, twice with the same seed: the full pipeline (data-URI + // decode -> BMP decode -> preprocess -> vision tower -> splice -> + // diffusion) must be deterministic. + reply1, err := d.PredictRich(req()) + Expect(err).ToNot(HaveOccurred()) + Expect(string(reply1.GetMessage())).ToNot(BeEmpty()) + Expect(reply1.GetChatDeltas()).ToNot(BeEmpty()) + + reply2, err := d.PredictRich(req()) + Expect(err).ToNot(HaveOccurred()) + Expect(string(reply2.GetMessage())).To(Equal(string(reply1.GetMessage()))) + + // The image must CHANGE the generation: same prompt and seed without + // it goes through the text path and must diverge (soft embeddings + // shift every position after the splice). + textOnly := req() + textOnly.Images = nil + replyText, err := d.PredictRich(textOnly) + Expect(err).ToNot(HaveOccurred()) + Expect(string(replyText.GetMessage())).ToNot(Equal(string(reply1.GetMessage()))) + + // Streaming variant over the same request shape. + ch := make(chan *pb.Reply, 64) + Expect(d.PredictStreamRich(req(), ch)).To(Succeed()) + replies := drainReplies(ch) + Expect(replies).ToNot(BeEmpty()) + var streamed string + for _, r := range replies { + streamed += string(r.GetMessage()) + } + Expect(streamed).ToNot(BeEmpty()) + }) + + It("surfaces the engine's marker/image mismatch error", func() { + d := loadVisionDllm() + + // Raw-prompt mode with an image but no marker: the engine must + // reject the 0-marker/1-image mismatch through last_error. + _, err := d.PredictRich(&pb.PredictOptions{ + Prompt: "hello", + Images: []string{imageB64}, + Tokens: 16, + Seed: 7, + }) + Expect(err).To(MatchError(ContainSubstring("markers"))) + }) +}) diff --git a/backend/go/dllm/gemma4_renderer.go b/backend/go/dllm/gemma4_renderer.go index 868d98e4a8a1..ddaeadc2ef13 100644 --- a/backend/go/dllm/gemma4_renderer.go +++ b/backend/go/dllm/gemma4_renderer.go @@ -440,7 +440,26 @@ type gemma4ToolCall struct { // // enableThinking maps to the template's enable_thinking flag (ds4 convention: // Metadata["enable_thinking"]); addGenerationPrompt to add_generation_prompt. -func RenderGemma4(msgs []*pb.Message, toolsJSON string, enableThinking bool, addGenerationPrompt bool) (string, error) { +// +// IMAGE NOTE (tpl L323-L342): the template's content-parts branch renders +// one <|image|> token per image part, at the part's position. Through pb +// that branch is unreachable: LocalAI's OpenAI layer flattens content parts +// before the backend sees them - text parts are concatenated into +// pb.Message.Content (core/schema/message.go ToProto) and image parts are +// decoded to raw base64 in PredictOptions.Images (core/http/middleware/ +// request.go), losing per-message attribution and intra-message position. +// The llama.cpp backend's convention for the same flattened delivery is to +// attach ALL request images to the LAST user message, text first then +// images (grpc-server.cpp, "Add text first" in the last-user-msg branch); +// nImages mirrors that: one marker per image appended directly after the +// last user message's text, in image order (the template emits parts +// back-to-back with no separator either). The marker emitted is the ENGINE +// splice marker mmImageMarker ("", dllm_capi.h placeholder +// contract), NOT the template's <|image|> text token: the engine expands +// "" to + soft-token placeholders + and splices the +// vision embeddings there, whereas a literal <|image|> would just tokenize +// as text and leave a marker/image count mismatch. +func RenderGemma4(msgs []*pb.Message, toolsJSON string, nImages int, enableThinking bool, addGenerationPrompt bool) (string, error) { // Fail loud on roles the template does not know about. The jinja would // happily render any role as a generic turn; we reject instead so typos // surface at the API boundary rather than as silent bad prompts. @@ -493,13 +512,20 @@ func RenderGemma4(msgs []*pb.Message, toolsJSON string, enableThinking bool, add b.WriteString(gemma4TurnClose) // tpl L204 } - // Pre-scan: last user message index for the reasoning guard, tpl L207-L213. + // Pre-scan: last user message index for the reasoning guard, tpl L207-L213 + // (also the image attachment point - see the IMAGE NOTE). lastUserIdx := -1 for i, m := range loopMsgs { if m.GetRole() == "user" { lastUserIdx = i } } + if nImages > 0 && lastUserIdx == -1 { + // No user turn to attach the markers to: the engine would reject the + // markerless prompt anyway (marker/image count mismatch), so surface + // the bad request here with a usable message. + return "", fmt.Errorf("dllm: gemma4 renderer: %d image(s) provided but no user message to attach them to", nImages) + } // Message loop, tpl L215-L354. role=tool messages are skipped here: they // are rendered by the forward-scan from their assistant tool_calls turn. @@ -588,13 +614,21 @@ func RenderGemma4(msgs []*pb.Message, toolsJSON string, enableThinking bool, add // Captured content, tpl L316-L345. Model content gets thinking // channels stripped (strip_thinking, tpl L148-L158); other roles are // trimmed. pb content is a flattened string: the content-parts array - // branch (tpl L322-L342, incl. <|image|> markers) is unreachable. + // branch (tpl L322-L342) is unreachable through it - the image part + // of that branch is reconstructed below from PredictOptions.Images + // (see the IMAGE NOTE on RenderGemma4). var content string if role == "model" { content = stripGemma4Thinking(m.GetContent()) } else { content = strings.TrimSpace(m.GetContent()) } + if i == lastUserIdx && nImages > 0 { + // Markers are part of captured_content in the template (an + // image-only message still counts as has_content and closes its + // turn), so append before the hasContent computation. + content += strings.Repeat(mmImageMarker, nImages) + } b.WriteString(content) hasContent := strings.TrimSpace(content) != "" // tpl L346 diff --git a/backend/go/dllm/gemma4_renderer_test.go b/backend/go/dllm/gemma4_renderer_test.go index 3600fbf7a07a..033d909d705b 100644 --- a/backend/go/dllm/gemma4_renderer_test.go +++ b/backend/go/dllm/gemma4_renderer_test.go @@ -45,8 +45,13 @@ const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool", const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}` type renderGemma4Case struct { - msgs []*pb.Message - toolsJSON string + msgs []*pb.Message + toolsJSON string + // nImages mirrors len(PredictOptions.Images): the OpenAI layer strips + // image content parts out of the messages, so the renderer re-injects + // one engine marker per image on the last user message (see the IMAGE + // NOTE on RenderGemma4). + nImages int enableThinking bool noGenerationPrompt bool // inverted so the zero value is the common case expected string @@ -55,7 +60,7 @@ type renderGemma4Case struct { var _ = Describe("RenderGemma4", func() { DescribeTable("renders the canonical gemma4 prompt", func(c renderGemma4Case) { - out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt) + out, err := RenderGemma4(c.msgs, c.toolsJSON, c.nImages, c.enableThinking, !c.noGenerationPrompt) Expect(err).ToNot(HaveOccurred()) Expect(out).To(Equal(c.expected)) // The C-ABI generate prepends BOS itself: a literal @@ -273,13 +278,55 @@ var _ = Describe("RenderGemma4", func() { noGenerationPrompt: true, expected: "<|turn>user\nhi\n", }), + + // One engine marker per image, appended directly after the user + // text with no separator (tpl L323-L341 emits parts back-to-back; + // "" is dllm_capi.h's splice marker, not the template's + // <|image|> text token - see the IMAGE NOTE on RenderGemma4). + Entry("one image appends one engine marker to the user message", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is in this picture?"}, + }, + nImages: 1, + expected: "<|turn>user\nWhat is in this picture?\n<|turn>model\n<|channel>thought\n", + }), + + Entry("multiple images append markers in image order", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Compare these."}, + }, + nImages: 3, + expected: "<|turn>user\nCompare these.\n<|turn>model\n<|channel>thought\n", + }), + + // Flattened delivery loses per-message attribution, so all images + // attach to the LAST user message (llama.cpp grpc-server convention). + Entry("images attach to the last user message in multi-turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "hello"}, + {Role: "user", Content: "and this?"}, + }, + nImages: 1, + expected: "<|turn>user\nhi\n<|turn>model\nhello\n<|turn>user\nand this?\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L346: the markers count as captured_content, so an image-only + // user message still has content and closes its turn normally. + Entry("image with empty user text still closes the turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: ""}, + }, + nImages: 1, + expected: "<|turn>user\n\n<|turn>model\n<|channel>thought\n", + }), ) Describe("error handling", func() { It("fails loud on an unknown role", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "narrator", Content: "Meanwhile..."}, - }, "", false, true) + }, "", 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`)) }) @@ -287,7 +334,7 @@ var _ = Describe("RenderGemma4", func() { It("fails on invalid tools JSON", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "user", Content: "hi"}, - }, "{not json", false, true) + }, "{not json", 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("tools JSON")) }) @@ -296,7 +343,7 @@ var _ = Describe("RenderGemma4", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "user", Content: "hi"}, {Role: "assistant", Content: "", ToolCalls: "{not json"}, - }, "", false, true) + }, "", 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("tool_calls JSON")) }) @@ -307,7 +354,7 @@ var _ = Describe("RenderGemma4", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "user", Content: "hi"}, {Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"}, - }, "", false, true) + }, "", 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("orphan tool message 1")) }) @@ -315,7 +362,7 @@ var _ = Describe("RenderGemma4", func() { It("fails on trailing garbage after the tools JSON array", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "user", Content: "hi"}, - }, "[] junk", false, true) + }, "[] junk", 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("tools JSON")) }) @@ -323,7 +370,7 @@ var _ = Describe("RenderGemma4", func() { It("fails when the tools JSON is not an array", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "user", Content: "hi"}, - }, `{"type":"function"}`, false, true) + }, `{"type":"function"}`, 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("tools JSON is not an array")) }) @@ -331,7 +378,7 @@ var _ = Describe("RenderGemma4", func() { It("fails when a tools array element is not an object", func() { _, err := RenderGemma4([]*pb.Message{ {Role: "user", Content: "hi"}, - }, `[42]`, false, true) + }, `[42]`, 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("tools[0] is not an object")) }) @@ -339,9 +386,21 @@ var _ = Describe("RenderGemma4", func() { It("rejects a nil message via the unknown-role check", func() { // Pins current behavior: pb getters are nil-safe, so a nil message // reads as role "" and trips the fail-loud unknown-role guard. - _, err := RenderGemma4([]*pb.Message{nil}, "", false, true) + _, err := RenderGemma4([]*pb.Message{nil}, "", 0, false, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`)) }) + + It("fails loud on images with no user message to attach them to", func() { + // The engine would reject the markerless prompt anyway + // (marker/image count mismatch); the renderer surfaces the bad + // request with a usable message instead. + _, err := RenderGemma4([]*pb.Message{ + {Role: "system", Content: "sys"}, + {Role: "assistant", Content: "hi"}, + }, "", 1, false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no user message")) + }) }) }) diff --git a/backend/go/dllm/main.go b/backend/go/dllm/main.go index 41d4368f2752..9657a8ab8aba 100644 --- a/backend/go/dllm/main.go +++ b/backend/go/dllm/main.go @@ -2,8 +2,9 @@ package main // Started internally by LocalAI - one gRPC server per loaded model. // -// Loads libdllm.so via purego and registers the 9-symbol flat C-ABI -// declared in dllm.cpp's include/dllm_capi.h (ABI v1). The library name can +// Loads libdllm.so via purego and registers the flat C-ABI declared in +// dllm.cpp's include/dllm_capi.h (ABI v1): 9 mandatory symbols plus the +// Dlsym-probed optional multimodal pair. The library name can // be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY / // WHISPER_LIBRARY convention in the sibling backends); the default looks // for the .so next to this binary (run.sh puts the package dir on @@ -57,6 +58,18 @@ func loadCAPI(libName string) error { for _, lf := range libFuncs { purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name) } + + // Multimodal entry points (dllm_capi.h's P4 surface). Additive: the ABI + // version stays 1 and consumers detect the surface by probing the symbols + // (the parakeet-cpp optional-symbol pattern), so the backend still loads + // against an older text-only libdllm.so - image requests then fail with + // errMMUnsupported instead of a boot failure. + if sym, err := purego.Dlsym(lib, "dllm_capi_generate_mm"); err == nil && sym != 0 { + purego.RegisterLibFunc(&cppGenerateMM, lib, "dllm_capi_generate_mm") + } + if sym, err := purego.Dlsym(lib, "dllm_capi_generate_stream_mm"); err == nil && sym != 0 { + purego.RegisterLibFunc(&cppGenerateStreamMM, lib, "dllm_capi_generate_stream_mm") + } return nil } @@ -75,7 +88,7 @@ func main() { if v := cAbiVersion(); v != dllmABIVersion { panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion)) } - fmt.Fprintf(os.Stderr, "[dllm] ABI=%d\n", cAbiVersion()) + fmt.Fprintf(os.Stderr, "[dllm] ABI=%d multimodal=%t\n", cAbiVersion(), cMMSupported()) flag.Parse() diff --git a/tests/e2e-backends/dllm_test.go b/tests/e2e-backends/dllm_test.go index 5c7766c98d92..ea54301ad1fa 100644 --- a/tests/e2e-backends/dllm_test.go +++ b/tests/e2e-backends/dllm_test.go @@ -2,6 +2,7 @@ package e2ebackends_test import ( "context" + "encoding/base64" "fmt" "io" "net" @@ -39,6 +40,16 @@ import ( // BACKEND_TEST_MODEL_FILE dllm.cpp's tests/fixtures/tiny_with_vocab.gguf // (random weights + handcrafted 43-token gemma4 vocab) // +// Tiny vision spec (same gating as the tiny-model spec, plus an image): +// +// BACKEND_TEST_DLLM_IMAGE a decodable image fixture (dllm.cpp's +// tests/fixtures/test_image_24x17.bmp); setting it +// enables the spec. BACKEND_TEST_MODEL_FILE must +// then point at tiny_vision_with_vocab.gguf (the +// tiny fixture WITH a vision tower) and the +// packaged libdllm.so must carry the multimodal +// C-ABI entry points (dllm.cpp >= the P4 surface). +// // Real-model spec (the 26B BF16 GGUF, ~50 GB; CUDA-13-class hardware): // // BACKEND_TEST_DLLM_REAL_MODEL_FILE path to diffusiongemma-26B-A4B-it-BF16.gguf; @@ -280,6 +291,74 @@ var _ = Describe("dllm request cancellation (tiny model)", Ordered, func() { }) }) +var _ = Describe("dllm templated vision chat-completion (tiny vision model)", Ordered, func() { + var client pb.BackendClient + var imageB64 string + + BeforeAll(func() { + if os.Getenv("BACKEND_TEST_DLLM") != "1" { + Skip("dllm vision spec is opt-in; set BACKEND_TEST_DLLM=1 (plus BACKEND_BINARY, BACKEND_TEST_MODEL_FILE and BACKEND_TEST_DLLM_IMAGE) to run it") + } + imagePath := os.Getenv("BACKEND_TEST_DLLM_IMAGE") + if imagePath == "" { + Skip("dllm vision spec requires BACKEND_TEST_DLLM_IMAGE (dllm.cpp's tests/fixtures/test_image_24x17.bmp)") + } + modelFile := os.Getenv("BACKEND_TEST_MODEL_FILE") + Expect(modelFile).NotTo(BeEmpty(), + "dllm vision spec requires BACKEND_TEST_MODEL_FILE (dllm.cpp's tests/fixtures/tiny_vision_with_vocab.gguf)") + + // Deliver the image exactly as LocalAI core does: a raw base64 + // payload in PredictOptions.Images (core decodes every image_url / + // image content part to plain base64 before it reaches the backend). + raw, err := os.ReadFile(imagePath) + Expect(err).NotTo(HaveOccurred()) + imageB64 = base64.StdEncoding.EncodeToString(raw) + + // eb_max_steps:4 keeps the prefill-heavy mm runs fast: the tiny + // vision tower still resizes every image to the full 280-soft-token + // patch budget (same trick as dllm.cpp's own mm tests). + client = startDllmBackend(modelFile, 0, "eb_max_steps:4") + }) + + It("answers a templated chat completion with an image attached", func() { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute) + defer cancel() + req := dllmChatRequest() + req.Images = []string{imageB64} + res, err := client.Predict(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(string(res.GetMessage())).NotTo(BeEmpty(), "vision chat completion produced empty content") + Expect(res.GetChatDeltas()).NotTo(BeEmpty(), "vision chat completion produced no ChatDeltas") + GinkgoWriter.Printf("dllm vision chat: %q (deltas=%d)\n", string(res.GetMessage()), len(res.GetChatDeltas())) + }) + + It("streams a templated chat completion with an image attached", func() { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute) + defer cancel() + req := dllmChatRequest() + req.Images = []string{imageB64} + stream, err := client.PredictStream(ctx, req) + Expect(err).NotTo(HaveOccurred()) + + var chunks int + var combined string + for { + msg, rerr := stream.Recv() + if rerr == io.EOF { + break + } + Expect(rerr).NotTo(HaveOccurred()) + if len(msg.GetMessage()) > 0 { + chunks++ + combined += string(msg.GetMessage()) + } + } + Expect(chunks).To(BeNumerically(">=", 1), "no vision stream chunks received") + Expect(combined).NotTo(BeEmpty(), "streamed vision chat completion produced empty content") + GinkgoWriter.Printf("dllm vision chat stream: %d chunks, combined=%q\n", chunks, combined) + }) +}) + var _ = Describe("dllm templated chat-completion (real model)", Ordered, func() { var client pb.BackendClient From b75ab7c3bbe582af8822bdc7ccf14d1f67a65c67 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 13 Jun 2026 00:00:42 +0000 Subject: [PATCH 13/13] chore(dllm): bump dllm.cpp pin to P5 head Signed-off-by: Ettore Di Giacinto --- backend/go/dllm/Makefile | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/backend/go/dllm/Makefile b/backend/go/dllm/Makefile index 9a0f0aeae84c..2ba0183a3c34 100644 --- a/backend/go/dllm/Makefile +++ b/backend/go/dllm/Makefile @@ -19,11 +19,15 @@ # until then the anonymous clone below fails. Use the symlink shortcut above # with a local checkout, or a git credential helper with access to the repo. -# The pin below is the first commit carrying the multimodal C-ABI entry -# points (dllm_capi_generate_mm / dllm_capi_generate_stream_mm) the -# image-input path probes for; older libs still load, but image requests -# then fail with "library predates the multimodal entry points". -DLLM_VERSION?=e6dcf44cddd65845e3a0814a1c2282a5d90ee98a +# The pin below is the P5 performance-parity head (device-resident +# self-conditioning, full-GPU placement at ngl >= n_layer, graph reuse, +# device-side EB reductions: ~8x per-step on GB10, see dllm.cpp +# docs/validation.md section 10). C-ABI unchanged (still version 1). It +# also carries the multimodal entry points (dllm_capi_generate_mm / +# dllm_capi_generate_stream_mm) the image-input path probes for; older +# libs still load, but image requests then fail with "library predates +# the multimodal entry points". +DLLM_VERSION?=320b57756efc3460169b8ea9e8c782867198f2a5 DLLM_REPO?=https://github.com/mudler/dllm.cpp GOCMD?=go