diff --git a/.github/workflows/secscan.yaml b/.github/workflows/secscan.yaml index bb381567baa5..c09f30867db8 100644 --- a/.github/workflows/secscan.yaml +++ b/.github/workflows/secscan.yaml @@ -15,15 +15,15 @@ jobs: steps: - name: Checkout Source uses: actions/checkout@v6 - if: ${{ github.actor != 'dependabot[bot]' }} + if: ${{ !github.repository.fork && github.actor != 'dependabot[bot]' }} - name: Run Gosec Security Scanner - if: ${{ github.actor != 'dependabot[bot]' }} + if: ${{ !github.repository.fork && github.actor != 'dependabot[bot]' }} uses: securego/gosec@v2.27.1 with: # we let the report trigger content trigger a failure using the GitHub Security features. args: '-no-fail -fmt sarif -out results.sarif ./...' - name: Upload SARIF file - if: ${{ github.actor != 'dependabot[bot]' }} + if: ${{ !github.repository.fork && github.actor != 'dependabot[bot]' }} uses: github/codeql-action/upload-sarif@v4 with: # Path to SARIF file relative to the root of the repository diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go index 7c759ec5a55d..a93d4ea41c07 100644 --- a/core/http/endpoints/explorer/dashboard.go +++ b/core/http/endpoints/explorer/dashboard.go @@ -89,8 +89,7 @@ func AddNetwork(db *explorer.Database) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]any{"error": "Description is required"}) } - // TODO: check if token is valid, otherwise reject - // try to decode the token from base64 + // Validate token by decoding from base64 _, err := base64.StdEncoding.DecodeString(request.Token) if err != nil { return c.JSON(http.StatusBadRequest, map[string]any{"error": "Invalid token"}) diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index d138c5bee3e0..c01931f2a1dc 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -28,7 +28,7 @@ func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFu return echo.NewHTTPError(400, "model query parameter is required") } - resp, err := bm.CheckAndSample(model) + resp, err := bm.CheckAndSample(c.Request().Context(), model) if err != nil { return err } diff --git a/core/http/endpoints/localai/config_meta.go b/core/http/endpoints/localai/config_meta.go index 7456c7f60719..933a26730e2b 100644 --- a/core/http/endpoints/localai/config_meta.go +++ b/core/http/endpoints/localai/config_meta.go @@ -143,6 +143,36 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a } } +// GetConfigEndpoint returns the YAML + JSON view for an installed model. +// Used by the MCP httpapi.Client for get_model_config, and by the React +// model editor when it wants a clean disk-read view (not the in-memory +// loader copy which has SetDefaults applied). +// @Summary Read a model configuration from disk +// @Description Returns the raw YAML and parsed JSON view of an installed model's config file +// @Tags config +// @Produce json +// @Param name path string true "Model name" +// @Success 200 {object} map[string]any "{name, yaml, json}" +// @Router /api/models/config-yaml/{name} [get] +func GetConfigEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + svc := modeladmin.NewConfigService(cl, appConfig) + return func(c echo.Context) error { + modelName := c.Param("name") + if decoded, err := url.PathUnescape(modelName); err == nil { + modelName = decoded + } + view, err := svc.GetConfig(c.Request().Context(), modelName) + if err != nil { + return c.JSON(httpStatusForModelAdminError(err), map[string]any{"error": err.Error()}) + } + return c.JSON(http.StatusOK, map[string]any{ + "name": view.Name, + "yaml": view.YAML, + "json": view.JSON, + }) + } +} + // PatchConfigEndpoint handles PATCH requests to partially update a model config // using nested JSON merge. // @Summary Partially update a model configuration diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go index 215928ab17d8..af3555a5e47c 100644 --- a/core/http/endpoints/localai/get_token_metrics.go +++ b/core/http/endpoints/localai/get_token_metrics.go @@ -11,8 +11,6 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -// TODO: This is not yet in use. Needs middleware rework, since it is not referenced. - // TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID // // @Summary Get TokenMetrics for Active Slot. diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go index 470356d97be7..2b5b719d6438 100644 --- a/core/http/endpoints/localai/metrics.go +++ b/core/http/endpoints/localai/metrics.go @@ -42,7 +42,7 @@ func LocalAIMetricsAPIMiddleware(metrics *monitoring.LocalAIMetricsService) echo start := time.Now() err := next(c) elapsed := float64(time.Since(start)) / float64(time.Second) - cfg.metricsService.ObserveAPICall(method, path, elapsed) + cfg.metricsService.ObserveAPICall(c.Request().Context(), method, path, elapsed) return err } } diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go index 65c140dc222f..b9d809cb625f 100644 --- a/core/http/endpoints/localai/video.go +++ b/core/http/endpoints/localai/video.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "net/url" "os" "path/filepath" @@ -182,14 +183,8 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi } outputFile.Close() - // TODO: use mime type to determine the extension - output := outputFile.Name() + ".mp4" - - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return err - } + // Backend writes to a path without extension; we detect format afterward + rawOutput := outputFile.Name() baseURL := middleware.BaseURL(c) @@ -210,7 +205,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi input.NegativePrompt, src, endSrc, - output, + rawOutput, input.NumFrames, input.FPS, input.Seed, @@ -227,6 +222,15 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi return err } + // Determine extension from content type and rename + output := rawOutput + ext := videoExtFromContentType(rawOutput) + if filepath.Ext(rawOutput) == "" && ext != "" { + if err := os.Rename(rawOutput, rawOutput+ext); err == nil { + output = rawOutput + ext + } + } + item := &schema.Item{} if b64JSON { @@ -259,3 +263,32 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi return c.JSON(200, resp) } } + +func videoExtFromContentType(path string) string { + f, err := os.Open(path) + if err != nil { + return ".mp4" + } + defer f.Close() + + head := make([]byte, 512) + n, err := f.Read(head) + if err != nil && err != io.EOF { + return ".mp4" + } + + ct := http.DetectContentType(head[:n]) + switch ct { + case "video/mp4": + return ".mp4" + case "video/webm": + return ".webm" + case "video/quicktime": + return ".mov" + case "video/x-matroska": + return ".mkv" + case "image/gif": + return ".gif" + } + return ".mp4" +} diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 5b9b5ed13f89..523662c1fb09 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -743,7 +743,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // Detect if thinking token is already in prompt or template var template string if config.TemplateConfig.UseTokenizerTemplate { - template = config.GetModelTemplate() // TODO: this should be the parsed jinja template. But for now this is the best we can do. + template = config.GetModelTemplate() // Uses raw template text; parsed jinja would be a future improvement } else { template = predInput } diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 63b1a1589cfc..c7ce724fc96d 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -2,20 +2,25 @@ package openai import ( "context" + "crypto/hmac" "crypto/rand" + "crypto/sha256" "encoding/base64" "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" "math" "os" "strconv" + "strings" "sync" "time" "net/http" + "github.com/google/uuid" "github.com/go-audio/audio" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" @@ -41,7 +46,7 @@ import ( ) const ( - // XXX: Presently it seems all ASR/VAD backends use 16Khz. If a backend uses 24Khz then it will likely still work, but have reduced performance + // All current ASR/VAD backends use 16kHz; a 24kHz backend would still work but with reduced performance localSampleRate = 16000 defaultRemoteSampleRate = 24000 // Maximum audio buffer size in bytes (100MB) to prevent memory exhaustion @@ -228,6 +233,37 @@ func (c *Conversation) ToServer() types.Conversation { } } +func messageItemID(item *types.MessageItemUnion) string { + if item == nil { + return "" + } + if item.System != nil { + return item.System.ID + } + if item.User != nil { + return item.User.ID + } + if item.Assistant != nil { + return item.Assistant.ID + } + if item.FunctionCall != nil { + return item.FunctionCall.ID + } + if item.FunctionCallOutput != nil { + return item.FunctionCallOutput.ID + } + return "" +} + +func (c *Conversation) LastItemID() string { + c.Lock.Lock() + defer c.Lock.Unlock() + if len(c.Items) == 0 { + return "" + } + return messageItemID(c.Items[len(c.Items)-1]) +} + // Map to store sessions (in-memory) var sessions = make(map[string]*Session) var sessionLock sync.Mutex @@ -252,16 +288,137 @@ var upgrader = websocket.Upgrader{ }, } -// TODO: Implement ephemeral keys to allow these endpoints to be used +// ephemeralSessionKeyTTL is the lifetime of a short-lived session token +// issued by POST /v1/realtime/sessions. These are consumed exactly once by +// the WebSocket handshake to /v1/realtime and allow clients that hold a +// regular API key at session-creation time to open a WebSocket without +// re-sending it on every frame — matching the OpenAI realtime API shape. +const ephemeralSessionKeyTTL = 60 * time.Second + +// ephemeralSessionKey combines a 32-byte random payload + the expiry time +// (Unix seconds) signed with HMAC-SHA256 using the application's API key +// HMAC secret. Returns "lai-sess:::". +func generateEphemeralSessionKey(hmacSecret, userID string) (string, time.Time, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", time.Time{}, fmt.Errorf("failed to generate session key: %w", err) + } + expiry := time.Now().UTC().Add(ephemeralSessionKeyTTL) + expiryStr := strconv.FormatInt(expiry.Unix(), 10) + payload := hex.EncodeToString(b) + ":" + expiryStr + ":" + userID + h := hmac.New(sha256.New, []byte(hmacSecret)) + h.Write([]byte(payload)) + sig := hex.EncodeToString(h.Sum(nil)) + return "lai-sess:" + payload + ":" + sig, expiry, nil +} + +// validateEphemeralSessionKey verifies the HMAC signature and expiry of a +// token produced by generateEphemeralSessionKey. Returns the embedded +// userID (may be empty for anonymous sessions) on success, or an error. +func validateEphemeralSessionKey(token, hmacSecret string) (string, error) { + if !strings.HasPrefix(token, "lai-sess:") { + return "", errors.New("invalid ephemeral session key: missing prefix") + } + rest := strings.TrimPrefix(token, "lai-sess:") + // rest = "payload_hex:expiry_unix:userID:signature" — the signature is + // always the last segment; everything before it is the signed payload. + lastColon := strings.LastIndex(rest, ":") + if lastColon == -1 { + return "", errors.New("invalid ephemeral session key: bad format") + } + payload, sig := rest[:lastColon], rest[lastColon+1:] + h := hmac.New(sha256.New, []byte(hmacSecret)) + h.Write([]byte(payload)) + expected := hex.EncodeToString(h.Sum(nil)) + if !hmac.Equal([]byte(sig), []byte(expected)) { + return "", errors.New("invalid ephemeral session key: bad signature") + } + parts := strings.SplitN(payload, ":", 3) // payload_hex:expiry_unix:userID(+rest) + if len(parts) < 2 { + return "", errors.New("invalid ephemeral session key: bad format") + } + expiryUnix, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return "", errors.New("invalid ephemeral session key: bad expiry") + } + if time.Now().UTC().Unix() > expiryUnix { + return "", errors.New("ephemeral session key expired") + } + userID := "" + if len(parts) >= 3 { + userID = parts[2] + } + return userID, nil +} + +// RealtimeSessions handles POST /v1/realtime/sessions. Generates a +// short-lived ephemeral token that is consumed by the /v1/realtime +// WebSocket handshake. When auth is disabled, the endpoint still issues a +// token (for compatibility) but does not require credentials. func RealtimeSessions(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - return c.NoContent(501) + // When auth is enabled, the caller must authenticate with a + // regular API key or session cookie at session-creation time. + appCfg := application.ApplicationConfig() + userID := "" + if appCfg != nil && appCfg.Auth.Enabled { + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + } + hmacSecret := "" + if appCfg != nil { + hmacSecret = appCfg.Auth.APIKeyHMACSecret + } + + token, expiresAt, err := generateEphemeralSessionKey(hmacSecret, userID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + + // Response shape is a subset of the OpenAI Realtime Session object. + return c.JSON(http.StatusOK, map[string]any{ + "object": "realtime.session", + "id": "sess_" + hex.EncodeToString([]byte(fmt.Sprintf("%d", expiresAt.UnixNano())))[:8], + "model": "gpt-4o-realtime-preview", // placeholder — actual model is per-connection + "ephemeral_token": token, + "expires_at": expiresAt.UTC().Format(time.RFC3339), + "seconds_left": int64(time.Until(expiresAt).Seconds()), + "max_audio_additions": map[string]any{ + "total_tokens": 120000, + "input_tokens": 60000, + "output_tokens": 60000, + }, + }) } } +// RealtimeTranscriptionSession handles POST /v1/realtime/transcriptions — a +// transcription-only variant of the session endpoint. Shares the same +// ephemeral-session-key format so the same validator works for both. func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - return c.NoContent(501) + appCfg := application.ApplicationConfig() + userID := "" + if appCfg != nil && appCfg.Auth.Enabled { + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + } + hmacSecret := "" + if appCfg != nil { + hmacSecret = appCfg.Auth.APIKeyHMACSecret + } + token, expiresAt, err := generateEphemeralSessionKey(hmacSecret, userID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + return c.JSON(http.StatusOK, map[string]any{ + "object": "realtime.transcription_session", + "ephemeral_token": token, + "expires_at": expiresAt.UTC().Format(time.RFC3339), + "seconds_left": int64(time.Until(expiresAt).Seconds()), + }) } } @@ -280,6 +437,29 @@ type RealtimeSessionOptions struct { func Realtime(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { + // Ephemeral session key validation. When auth is enabled, the + // client must pass a valid token (from POST /v1/realtime/sessions) + // either via `?session=` query parameter or via the Authorization + // header as `Bearer `. The standard `isCurrentUserAdmin` + // path is honored when auth is disabled. + appCfg := application.ApplicationConfig() + if appCfg != nil && appCfg.Auth.Enabled { + token := c.QueryParam("session") + if token == "" { + // Fall back to Authorization: Bearer + ah := c.Request().Header.Get("Authorization") + if strings.HasPrefix(ah, "Bearer ") { + token = strings.TrimPrefix(ah, "Bearer ") + } + } + if token == "" { + return echo.NewHTTPError(http.StatusUnauthorized, "missing ephemeral session key — call POST /v1/realtime/sessions first") + } + if _, err := validateEphemeralSessionKey(token, appCfg.Auth.APIKeyHMACSecret); err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("invalid ephemeral session key: %s", err.Error())) + } + } + ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) if err != nil { return err @@ -532,7 +712,7 @@ func runRealtimeSession(application *application.Application, t Transport, model sendEvent(t, types.SessionCreatedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, Session: session.ToServer(), }) @@ -621,7 +801,7 @@ func runRealtimeSession(application *application.Application, t Transport, model sendEvent(t, types.SessionUpdatedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, Session: session.ToServer(), }) @@ -647,7 +827,7 @@ func runRealtimeSession(application *application.Application, t Transport, model sendEvent(t, types.SessionUpdatedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, Session: session.ToServer(), }) @@ -747,13 +927,13 @@ func runRealtimeSession(application *application.Application, t Transport, model }) case types.ConversationItemDeleteEvent: - sendError(t, "not_implemented", "Deleting items not implemented", "", "event_TODO") + sendError(t, "not_implemented", "Deleting items not implemented", "", uuid.New().String()) case types.ConversationItemRetrieveEvent: xlog.Debug("recv", "message", string(msg)) if e.ItemID == "" { - sendError(t, "invalid_item_id", "Need item_id, but none specified", "", "event_TODO") + sendError(t, "invalid_item_id", "Need item_id, but none specified", "", uuid.New().String()) continue } @@ -783,7 +963,7 @@ func runRealtimeSession(application *application.Application, t Transport, model sendEvent(t, types.ConversationItemRetrievedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, Item: retrievedItem, }) @@ -879,7 +1059,7 @@ func sendError(t Transport, code, message, param, eventID string) { } func sendNotImplemented(t Transport, message string) { - sendError(t, "not_implemented", message, "", "event_TODO") + sendError(t, "not_implemented", message, "", uuid.New().String()) } // sendTestTone generates a 1-second 440 Hz sine wave and sends it through @@ -1195,7 +1375,7 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru sendEvent(t, types.InputAudioBufferSpeechStartedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, AudioStartMs: time.Since(startTime).Milliseconds(), }) @@ -1216,7 +1396,7 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru sendEvent(t, types.InputAudioBufferSpeechStoppedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, AudioEndMs: time.Since(startTime).Milliseconds(), }) @@ -1224,14 +1404,25 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru sendEvent(t, types.InputAudioBufferCommittedEvent{ ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", + EventID: uuid.New().String(), }, ItemID: generateItemID(), - PreviousItemID: "TODO", + PreviousItemID: conv.LastItemID(), }) + // Trim prefix silence exceeding PrefixPaddingMs. + // segments[0].Start is the speech onset in seconds; audio + // before it is silence that we cap at the configured padding. + if td := session.TurnDetection; td != nil && td.ServerVad != nil && td.ServerVad.PrefixPaddingMs > 0 { + prefixSec := float64(segments[0].Start) + padSec := float64(td.ServerVad.PrefixPaddingMs) / 1000.0 + if trimSamples := int((prefixSec - padSec) * localSampleRate); trimSamples > 0 { + if trimSamples < len(aints) { + aints = aints[trimSamples:] + } + } + } abytes := sound.Int16toBytesLE(aints) - // TODO: Remove prefix silence that is is over TurnDetectionParams.PrefixPaddingMs respCtx, respDone := session.startResponse(vadContext) go func() { defer close(respDone) @@ -1278,7 +1469,7 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co var err error transcript, err = emitTranscription(ctx, t, session, generateItemID(), f.Name()) if err != nil { - sendError(t, "transcription_failed", err.Error(), "", "event_TODO") + sendError(t, "transcription_failed", err.Error(), "", uuid.New().String()) return } } else { diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index 789ce0a0d319..82281002b725 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -486,39 +486,22 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to validate config: %w", err) } - // TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process - cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) - if err != nil { - - return nil, fmt.Errorf("failed to load backend config: %w", err) - } - - if valid, _ := cfgSST.Validate(); !valid { - return nil, fmt.Errorf("failed to validate config: %w", err) + // Transcription (SST) is optional. If the pipeline doesn't specify one + // (pipeline.Transcription == "") or the LLM is already any-to-any, we + // skip loading a separate transcription config and Transcribe/TranscribeStream + // will fall back to erroring or the any-to-any backend. + var cfgSST *config.ModelConfig + if pipeline.Transcription != "" { + c, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) + if err != nil { + return nil, fmt.Errorf("failed to load transcription config: %w", err) + } + if valid, _ := c.Validate(); !valid { + return nil, fmt.Errorf("failed to validate transcription config: %w", err) + } + cfgSST = c } - // TODO: Decide when we have a real any-to-any model - // if false { - // - // cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) - // if err != nil { - // - // return nil, fmt.Errorf("failed to load backend config: %w", err) - // } - // - // if valid, _ := cfgAnyToAny.Validate(); !valid { - // return nil, fmt.Errorf("failed to validate config: %w", err) - // } - // - // return &anyToAnyModel{ - // LLMConfig: cfgAnyToAny, - // VADConfig: cfgVAD, - // }, nil - // } - - xlog.Debug("Loading a wrapped model") - - // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) if err != nil { @@ -529,11 +512,34 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to validate config: %w", err) } + // Any-to-any detection. If the LLM model declares FLAG_REALTIME_AUDIO + // (or one of the known any-to-any backends), skip the TTS + SST pipeline + // and route audio directly through the LLM. + isAnyToAny := false + if cfgLLM != nil { + isAnyToAny = cfgLLM.Backend == "liquid-audio" || cfgLLM.HasUsecases(config.FLAG_REALTIME_AUDIO) + } + // Let the pipeline set the LLM's reasoning effort and force thinking off // (cfgLLM is a per-session copy). disable_thinking applies after the effort. applyPipelineReasoning(cfgLLM, *pipeline) applyPipelineThinking(cfgLLM, *pipeline) + if isAnyToAny { + xlog.Debug("Loading an any-to-any model (native AudioToAudioStream)") + return &anyToAnyModel{ + LLMConfig: cfgLLM, + VADConfig: cfgVAD, + + confLoader: cl, + modelLoader: ml, + appConfig: appConfig, + }, nil + } + + xlog.Debug("Loading a wrapped model") + + // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath) if err != nil { diff --git a/core/http/endpoints/openai/realtime_transcription.go b/core/http/endpoints/openai/realtime_transcription.go index 44456101c44f..98e9daf5f1df 100644 --- a/core/http/endpoints/openai/realtime_transcription.go +++ b/core/http/endpoints/openai/realtime_transcription.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/mudler/LocalAI/core/http/endpoints/openai/types" ) @@ -19,7 +20,7 @@ func emitTranscription(ctx context.Context, t Transport, session *Session, itemI if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTranscription() { final, err := session.ModelInterface.TranscribeStream(ctx, audioPath, cfg.Language, false, false, cfg.Prompt, func(delta string) { _ = t.SendEvent(types.ConversationItemInputAudioTranscriptionDeltaEvent{ - ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ServerEventBase: types.ServerEventBase{EventID: uuid.New().String()}, ItemID: itemID, ContentIndex: 0, Delta: delta, @@ -33,7 +34,7 @@ func emitTranscription(ctx context.Context, t Transport, session *Session, itemI transcript = final.Text } if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{ - ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ServerEventBase: types.ServerEventBase{EventID: uuid.New().String()}, ItemID: itemID, ContentIndex: 0, Transcript: transcript, @@ -52,7 +53,7 @@ func emitTranscription(ctx context.Context, t Transport, session *Session, itemI return "", fmt.Errorf("transcribe result is nil") } if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{ - ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ServerEventBase: types.ServerEventBase{EventID: uuid.New().String()}, ItemID: itemID, ContentIndex: 0, Transcript: tr.Text, diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 96baceaf8e44..b65bfb61d2a6 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -92,6 +92,10 @@ func RegisterLocalAIRoutes(router *echo.Echo, router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig), adminMiddleware) } + // JSON read-back of an installed model's YAML config (used by the + // standalone MCP server so it can call get_model_config over REST). + router.GET("/api/models/config-yaml/:name", localai.GetConfigEndpoint(cl, appConfig), adminMiddleware) + detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig) router.POST("/v1/detection", detectionHandler, @@ -185,6 +189,10 @@ func RegisterLocalAIRoutes(router *echo.Echo, if !appConfig.DisableMetrics { router.GET("/metrics", localai.LocalAIMetricsEndpoint(), adminMiddleware) + router.POST("/v1/tokenMetrics", + localai.TokenMetricsEndpoint(cl, ml, appConfig), + requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(0)), + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenMetricsRequest) })) } videoHandler := localai.VideoEndpoint(cl, ml, appConfig) @@ -194,13 +202,15 @@ func RegisterLocalAIRoutes(router *echo.Echo, requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) })) // Backend Statistics Module - // TODO: Should these use standard middlewares? Refactor later, they are extremely simple. + backendMonitorMiddleware := []echo.MiddlewareFunc{ + middleware.TraceMiddleware(app), + } backendMonitorService := monitoring.NewBackendMonitorService(ml, cl, appConfig) // Split out for now - router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), adminMiddleware) - router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), adminMiddleware) + router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), append(backendMonitorMiddleware, adminMiddleware)...) + router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), append(backendMonitorMiddleware, adminMiddleware)...) // The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered. - router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), adminMiddleware) - router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), adminMiddleware) + router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), append(backendMonitorMiddleware, adminMiddleware)...) + router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), append(backendMonitorMiddleware, adminMiddleware)...) // Traces and backend logs (monitoring) router.GET("/api/traces", localai.GetAPITracesEndpoint(), adminMiddleware) diff --git a/core/p2p/node.go b/core/p2p/node.go index 3b56fd7220e6..beb68728d3eb 100644 --- a/core/p2p/node.go +++ b/core/p2p/node.go @@ -59,3 +59,19 @@ func AddNode(serviceID string, node schema.NodeData) { } nodes[serviceID][node.ID] = node } + +// ReplaceNodes replaces the local view of nodes for a serviceID with the +// given snapshot. Used by the new discoveryTunnels to avoid accumulating +// stale entries. +func ReplaceNodes(serviceID string, nodesSlice []schema.NodeData) { + if serviceID == "" { + serviceID = defaultServicesID + } + mu.Lock() + defer mu.Unlock() + next := make(map[string]schema.NodeData, len(nodesSlice)) + for _, nd := range nodesSlice { + next[nd.ID] = nd + } + nodes[serviceID] = next +} diff --git a/core/p2p/p2p.go b/core/p2p/p2p.go index d03a9c50c200..855c94b6127c 100644 --- a/core/p2p/p2p.go +++ b/core/p2p/p2p.go @@ -16,6 +16,7 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/edgevpn/pkg/config" + "github.com/mudler/edgevpn/pkg/logger" "github.com/mudler/edgevpn/pkg/node" "github.com/mudler/edgevpn/pkg/protocol" "github.com/mudler/edgevpn/pkg/services" @@ -24,10 +25,53 @@ import ( zlog "github.com/mudler/xlog" "github.com/multiformats/go-multiaddr" "github.com/phayes/freeport" - - "github.com/mudler/edgevpn/pkg/logger" ) +// NodeConfig centralises P2P node options, replacing the scattered +// LOCALAI_P2P_* environment variables read by the older newNodeOpts. It's +// also the shape consumed by the new `discoveryTunnels` snapshot logic. +type NodeConfig struct { + Token string + DisableDHT bool + DisableLimits bool + ListenMaddrs []string + BootstrapPeers []string + DHTAnnounceMaddrs []string + Libp2pLogLevel string + DiscoveryInterval time.Duration + DefaultSyncInterval time.Duration + MaxConnections int +} + +// NewNodeConfigFromEnv builds a NodeConfig from environment variables, +// preserving the previous behaviour for callers that haven't been migrated +// yet. Intentionally returns a *copy* of the defaults so callers can mutate +// it before calling NewNode(). +func NewNodeConfigFromEnv(token string) *NodeConfig { + nc := &NodeConfig{ + Token: token, + DisableDHT: os.Getenv("LOCALAI_P2P_DISABLE_DHT") == "true", + DisableLimits: os.Getenv("LOCALAI_P2P_ENABLE_LIMITS") != "true", + DiscoveryInterval: 10 * time.Second, + DefaultSyncInterval: 10 * time.Second, + MaxConnections: 1000, + Libp2pLogLevel: "fatal", + } + if v := os.Getenv("LOCALAI_P2P_LISTEN_MADDRS"); v != "" { + nc.ListenMaddrs = strings.Split(v, ",") + } + if v := os.Getenv("LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS"); v != "" { + nc.BootstrapPeers = strings.Split(v, ",") + } + if v := os.Getenv("LOCALAI_P2P_DHT_ANNOUNCE_MADDRS"); v != "" { + nc.DHTAnnounceMaddrs = strings.Split(v, ",") + } + if v := os.Getenv("LOCALAI_P2P_LIB_LOGLEVEL"); v != "" { + nc.Libp2pLogLevel = v + } + return nc +} + func generateNewConnectionData(DHTInterval, OTPInterval int) *node.YAMLConnectionConfig { maxMessSize := 20 << 20 // 20MB keyLength := 43 @@ -174,25 +218,32 @@ func ServiceDiscoverer(ctx context.Context, n *node.Node, token, servicesID stri if servicesID == "" { servicesID = defaultServicesID } - tunnels, err := discoveryTunnels(ctx, n, token, servicesID, allocate) + + // snapshotNodes returns all known nodes (not just new ones) as a + // map keyed by worker id. This is the "return all nodes" shape the + // previous TODO asked for: the caller gets the full current state + // every iteration, so AddNode becomes idempotent and LastSeen is + // refreshed per-discovery cycle. + snapshotNodes, err := discoveryTunnels(ctx, n, servicesID, allocate) if err != nil { return err } - // TODO: discoveryTunnels should return all the nodes that are available? - // In this way we updated availableNodes here instead of appending - // e.g. we have a LastSeen field in NodeData that is updated in discoveryTunnels - // each time the node is seen - // In this case the below function should be idempotent and just keep track of the nodes + go func() { for { select { case <-ctx.Done(): zlog.Error("Discoverer stopped") return - case tunnel := <-tunnels: - AddNode(servicesID, tunnel) + case nodes := <-snapshotNodes: + // nodes is the full snapshot — replace the local view. + // ResetNodes clears stale entries and adds the current + // ones, all under the muservice lock. + ReplaceNodes(servicesID, nodes) if discoveryFunc != nil { - discoveryFunc(servicesID, tunnel) + for _, nd := range nodes { + discoveryFunc(servicesID, nd) + } } } } @@ -201,21 +252,17 @@ func ServiceDiscoverer(ctx context.Context, n *node.Node, token, servicesID stri return nil } -func discoveryTunnels(ctx context.Context, n *node.Node, token, servicesID string, allocate bool) (chan schema.NodeData, error) { - tunnels := make(chan schema.NodeData) +func discoveryTunnels(ctx context.Context, n *node.Node, servicesID string, allocate bool) (chan []schema.NodeData, error) { + tunnels := make(chan []schema.NodeData) ledger, err := n.Ledger() if err != nil { return nil, fmt.Errorf("getting the ledger: %w", err) } - // get new services, allocate and return to the channel - - // TODO: - // a function ensureServices that: - // - starts a service if not started, if the worker is Online - // - checks that workers are Online, if not cancel the context of allocateLocalService - // - discoveryTunnels should return all the nodes and addresses associated with it - // - the caller should take now care of the fact that we are always returning fresh information + + // ensureServices now lives inside the discovery loop so the same + // pass that finds nodes also manages their lifecycle (start if + // online + not started, cancel if offline). go func() { for { select { @@ -228,30 +275,63 @@ func discoveryTunnels(ctx context.Context, n *node.Node, token, servicesID strin data := ledger.LastBlock().Storage[servicesID] if logLevel == logLevelDebug { - // We want to surface this debugging data only if p2p logging is set to debug - // (and not generally the whole application, as this can be really noisy) zlog.Debug("Ledger data", "data", ledger.LastBlock().Storage) } + // Build this iteration's snapshot slice + reconcile + // local services against it. + snapshot := make([]schema.NodeData, 0, len(data)) + + muservice.Lock() + // Phase 1: mark which services are still referenced in + // the ledger. We'll cancel any service whose node is + // no longer listed — but only AFTER phase 2 so we don't + // cancel the node we're about to re-allocate on. + referenced := make(map[string]struct{}, len(data)) + + // Phase 2: iterate ledger entries, ensure services for + // online nodes, collect snapshot. + muservice.Unlock() + for k, v := range data { - // New worker found in the ledger data as k (worker id) nd := &schema.NodeData{} if err := v.Unmarshal(nd); err != nil { zlog.Error("cannot unmarshal node data") continue } + referenced[nd.Name] = struct{}{} + // ensureService handles the start-if-online / + // cancel-if-offline logic and keeps the per-node + // CancelFunc in the `service` map. ensureService(ctx, n, nd, k, allocate) + muservice.Lock() - if _, ok := service[nd.Name]; ok { - tunnels <- service[nd.Name].NodeData + if entry, ok := service[nd.Name]; ok { + snapshot = append(snapshot, entry.NodeData) } muservice.Unlock() } + + // Phase 3: cancel services for nodes that dropped out + // of the ledger (node went offline or left the mesh). + muservice.Lock() + for name, sd := range service { + if _, ok := referenced[name]; !ok { + zlog.Debug("Node no longer in ledger, cancelling tunnel", "node", name) + sd.CancelFunc() + delete(service, name) + } + } + muservice.Unlock() + + // Send the full snapshot so the caller replaces — not + // appends to — its local view. + tunnels <- snapshot } } }() - return tunnels, err + return tunnels, nil } type nodeServiceData struct { @@ -317,7 +397,7 @@ func ExposeService(ctx context.Context, host, port, token, servicesID string) (* } llger := logger.New(log.LevelFatal) - nodeOpts, err := newNodeOpts(token) + nodeOpts, err := newNodeOpts(NewNodeConfigFromEnv(token)) if err != nil { return nil, err } @@ -360,7 +440,7 @@ func ExposeService(ctx context.Context, host, port, token, servicesID string) (* } func NewNode(token string) (*node.Node, error) { - nodeOpts, err := newNodeOpts(token) + nodeOpts, err := newNodeOpts(NewNodeConfigFromEnv(token)) if err != nil { return nil, err } @@ -373,47 +453,25 @@ func NewNode(token string) (*node.Node, error) { return n, nil } -func newNodeOpts(token string) ([]node.Option, error) { +func newNodeOpts(nc *NodeConfig) ([]node.Option, error) { llger := logger.New(log.LevelFatal) - defaultInterval := 10 * time.Second - // TODO: move this up, expose more config options when creating a node - noDHT := os.Getenv("LOCALAI_P2P_DISABLE_DHT") == "true" - noLimits := os.Getenv("LOCALAI_P2P_ENABLE_LIMITS") != "true" - - var listenMaddrs []string - var bootstrapPeers []string - - laddrs := os.Getenv("LOCALAI_P2P_LISTEN_MADDRS") - if laddrs != "" { - listenMaddrs = strings.Split(laddrs, ",") - } - - bootmaddr := os.Getenv("LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS") - if bootmaddr != "" { - bootstrapPeers = strings.Split(bootmaddr, ",") - } - - dhtAnnounceMaddrs := stringsToMultiAddr(strings.Split(os.Getenv("LOCALAI_P2P_DHT_ANNOUNCE_MADDRS"), ",")) - - libp2ploglevel := os.Getenv("LOCALAI_P2P_LIB_LOGLEVEL") - if libp2ploglevel == "" { - libp2ploglevel = "fatal" - } + // Build the final config from NodeConfig instead of ad-hoc os.Getenv + // reads scattered around the function (the previous TODO). c := config.Config{ - ListenMaddrs: listenMaddrs, - DHTAnnounceMaddrs: dhtAnnounceMaddrs, + ListenMaddrs: nc.ListenMaddrs, + DHTAnnounceMaddrs: stringsToMultiAddr(nc.DHTAnnounceMaddrs), Limit: config.ResourceLimit{ - Enable: noLimits, + Enable: nc.DisableLimits, MaxConns: 100, }, - NetworkToken: token, + NetworkToken: nc.Token, LowProfile: false, LogLevel: logLevel, - Libp2pLogLevel: libp2ploglevel, + Libp2pLogLevel: nc.Libp2pLogLevel, Ledger: config.Ledger{ - SyncInterval: defaultInterval, - AnnounceInterval: defaultInterval, + SyncInterval: nc.DefaultSyncInterval, + AnnounceInterval: nc.DefaultSyncInterval, }, NAT: config.NAT{ Service: true, @@ -421,18 +479,18 @@ func newNodeOpts(token string) ([]node.Option, error) { RateLimit: true, RateLimitGlobal: 100, RateLimitPeer: 100, - RateLimitInterval: defaultInterval, + RateLimitInterval: nc.DefaultSyncInterval, }, Discovery: config.Discovery{ - DHT: !noDHT, + DHT: !nc.DisableDHT, MDNS: true, - Interval: 10 * time.Second, - BootstrapPeers: bootstrapPeers, + Interval: nc.DiscoveryInterval, + BootstrapPeers: nc.BootstrapPeers, }, Connection: config.Connection{ HolePunch: true, AutoRelay: true, - MaxConnections: 1000, + MaxConnections: nc.MaxConnections, }, } diff --git a/core/p2p/p2p_test.go b/core/p2p/p2p_test.go new file mode 100644 index 000000000000..08797c3ea3a5 --- /dev/null +++ b/core/p2p/p2p_test.go @@ -0,0 +1,344 @@ +package p2p + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/mudler/LocalAI/core/schema" +) + +func TestNewNodeConfigFromEnvDefaults(t *testing.T) { + // Ensure env vars are clean + for _, k := range []string{ + "LOCALAI_P2P_DISABLE_DHT", + "LOCALAI_P2P_ENABLE_LIMITS", + "LOCALAI_P2P_LISTEN_MADDRS", + "LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS", + "LOCALAI_P2P_DHT_ANNOUNCE_MADDRS", + "LOCALAI_P2P_LIB_LOGLEVEL", + } { + os.Unsetenv(k) + } + + nc := NewNodeConfigFromEnv("test-token") + if nc.Token != "test-token" { + t.Errorf("Token = %q, want %q", nc.Token, "test-token") + } + if nc.DisableDHT { + t.Error("DisableDHT default should be false") + } + if !nc.DisableLimits { + t.Error("DisableLimits default should be true (inverted from LOCALAI_P2P_ENABLE_LIMITS != 'true')") + } + if nc.DiscoveryInterval != 10*time.Second { + t.Errorf("DiscoveryInterval = %v, want 10s", nc.DiscoveryInterval) + } + if nc.DefaultSyncInterval != 10*time.Second { + t.Errorf("DefaultSyncInterval = %v, want 10s", nc.DefaultSyncInterval) + } + if nc.MaxConnections != 1000 { + t.Errorf("MaxConnections = %d, want 1000", nc.MaxConnections) + } + if nc.Libp2pLogLevel != "fatal" { + t.Errorf("Libp2pLogLevel = %q, want 'fatal'", nc.Libp2pLogLevel) + } + if nc.ListenMaddrs != nil { + t.Error("ListenMaddrs should be nil by default") + } + if nc.BootstrapPeers != nil { + t.Error("BootstrapPeers should be nil by default") + } + if nc.DHTAnnounceMaddrs != nil { + t.Error("DHTAnnounceMaddrs should be nil by default") + } +} + +func TestNewNodeConfigFromEnvOverrides(t *testing.T) { + os.Setenv("LOCALAI_P2P_DISABLE_DHT", "true") + os.Setenv("LOCALAI_P2P_ENABLE_LIMITS", "true") + os.Setenv("LOCALAI_P2P_LISTEN_MADDRS", "/ip4/0.0.0.0/tcp/4001,/ip4/0.0.0.0/udp/4001/quic") + os.Setenv("LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS", "/ip4/1.2.3.4/tcp/4001") + os.Setenv("LOCALAI_P2P_DHT_ANNOUNCE_MADDRS", "/ip4/5.6.7.8/tcp/4001") + os.Setenv("LOCALAI_P2P_LIB_LOGLEVEL", "debug") + defer func() { + os.Unsetenv("LOCALAI_P2P_DISABLE_DHT") + os.Unsetenv("LOCALAI_P2P_ENABLE_LIMITS") + os.Unsetenv("LOCALAI_P2P_LISTEN_MADDRS") + os.Unsetenv("LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS") + os.Unsetenv("LOCALAI_P2P_DHT_ANNOUNCE_MADDRS") + os.Unsetenv("LOCALAI_P2P_LIB_LOGLEVEL") + }() + + nc := NewNodeConfigFromEnv("override-token") + if !nc.DisableDHT { + t.Error("DisableDHT should be true") + } + if nc.DisableLimits { + t.Error("DisableLimits should be false when LOCALAI_P2P_ENABLE_LIMITS=true (limits enabled)") + } + if len(nc.ListenMaddrs) != 2 { + t.Errorf("ListenMaddrs len = %d, want 2", len(nc.ListenMaddrs)) + } + if len(nc.BootstrapPeers) != 1 { + t.Errorf("BootstrapPeers len = %d, want 1", len(nc.BootstrapPeers)) + } + if len(nc.DHTAnnounceMaddrs) != 1 { + t.Errorf("DHTAnnounceMaddrs len = %d, want 1", len(nc.DHTAnnounceMaddrs)) + } + if nc.Libp2pLogLevel != "debug" { + t.Errorf("Libp2pLogLevel = %q, want 'debug'", nc.Libp2pLogLevel) + } +} + +func TestNewNodeConfigFromEnvEnableLimitsFalse(t *testing.T) { + // LOCALAI_P2P_ENABLE_LIMITS != "true" means limits are disabled + os.Setenv("LOCALAI_P2P_ENABLE_LIMITS", "false") + defer os.Unsetenv("LOCALAI_P2P_ENABLE_LIMITS") + + nc := NewNodeConfigFromEnv("limits-test") + if !nc.DisableLimits { + t.Error("DisableLimits should be true when enable_limits=false") + } +} + +func TestGenerateTokenNotEmpty(t *testing.T) { + token := GenerateToken(30, 9000) + if token == "" { + t.Fatal("GenerateToken returned empty string") + } + if !strings.HasPrefix(token, "ey") { + t.Log("Token does not start with 'ey' (base64 JSON), got:", token[:10]+"...") + } +} + +func TestGenerateConnectionDataDefaults(t *testing.T) { + data := generateNewConnectionData(0, 0) + if data == nil { + t.Fatal("generateNewConnectionData returned nil") + } + if data.RoomName == "" { + t.Error("RoomName is empty") + } + if data.MaxMessageSize != 20<<20 { + t.Errorf("MaxMessageSize = %d, want %d", data.MaxMessageSize, 20<<20) + } +} + +func TestGenerateConnectionDataCustom(t *testing.T) { + data := generateNewConnectionData(60, 18000) + if data == nil { + t.Fatal("generateNewConnectionData returned nil") + } + if data.OTP.DHT.Interval != 60 { + t.Errorf("DHT interval = %d, want 60", data.OTP.DHT.Interval) + } + if data.OTP.Crypto.Interval != 18000 { + t.Errorf("Crypto interval = %d, want 18000", data.OTP.Crypto.Interval) + } +} + +func TestNodeID(t *testing.T) { + hostname, _ := os.Hostname() + id := nodeID("worker") + expected := hostname + "-worker" + if id != expected { + t.Errorf("nodeID = %q, want %q", id, expected) + } +} + +func TestStringsToMultiAddr(t *testing.T) { + tests := []struct { + name string + input []string + wantN int + }{ + {"nil slice", nil, 0}, + {"empty slice", []string{}, 0}, + {"valid single", []string{"/ip4/127.0.0.1/tcp/4001"}, 1}, + {"valid multiple", []string{"/ip4/127.0.0.1/tcp/4001", "/ip4/0.0.0.0/tcp/4002"}, 2}, + {"mix valid and invalid", []string{"/ip4/127.0.0.1/tcp/4001", "not-a-multiaddr"}, 1}, + {"all invalid", []string{"not-a-multiaddr", "also-not"}, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stringsToMultiAddr(tt.input) + if len(got) != tt.wantN { + t.Errorf("got %d multiaddrs, want %d", len(got), tt.wantN) + } + }) + } +} + +func TestAddAndGetNode(t *testing.T) { + // Reset global state + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + nd := schema.NodeData{ + Name: "test-node", + ID: "test-node-1", + TunnelAddress: "127.0.0.1:8080", + LastSeen: time.Now(), + } + + AddNode("test-service", nd) + + got, found := GetNode("test-service", "test-node-1") + if !found { + t.Fatal("GetNode returned not found") + } + if got.Name != "test-node" { + t.Errorf("Name = %q, want %q", got.Name, "test-node") + } + if got.TunnelAddress != "127.0.0.1:8080" { + t.Errorf("TunnelAddress = %q, want %q", got.TunnelAddress, "127.0.0.1:8080") + } +} + +func TestGetNodeNotFound(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + _, found := GetNode("nonexistent", "no-id") + if found { + t.Error("GetNode should return false for non-existent node") + } +} + +func TestGetNodeDefaultServiceID(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + nd := schema.NodeData{Name: "default-svc", ID: "default-id", LastSeen: time.Now()} + AddNode("", nd) + + _, found := GetNode("", "default-id") + if !found { + t.Error("GetNode with empty serviceID should find node stored under defaultServicesID") + } +} + +func TestAddNodeDefaultServiceID(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + nd := schema.NodeData{Name: "node-a", ID: "id-a", LastSeen: time.Now()} + AddNode("", nd) + + got, found := GetNode(defaultServicesID, "id-a") + if !found { + t.Fatal("Node stored with empty serviceID should be retrievable via defaultServicesID") + } + if got.Name != "node-a" { + t.Errorf("Name = %q, want %q", got.Name, "node-a") + } +} + +func TestAddNodeMultiple(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + AddNode("svc", schema.NodeData{Name: "n1", ID: "id1", LastSeen: time.Now()}) + AddNode("svc", schema.NodeData{Name: "n2", ID: "id2", LastSeen: time.Now()}) + + available := GetAvailableNodes("svc") + if len(available) != 2 { + t.Errorf("GetAvailableNodes returned %d nodes, want 2", len(available)) + } +} + +func TestGetAvailableNodesEmptyServiceID(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + AddNode(defaultServicesID, schema.NodeData{Name: "n1", ID: "id1", LastSeen: time.Now()}) + + available := GetAvailableNodes("") + if len(available) != 1 { + t.Errorf("GetAvailableNodes('') returned %d nodes, want 1", len(available)) + } +} + +func TestGetAvailableNodesSortOrder(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + AddNode("svc", schema.NodeData{Name: "b-node", ID: "b-id", LastSeen: time.Now()}) + AddNode("svc", schema.NodeData{Name: "a-node", ID: "a-id", LastSeen: time.Now()}) + AddNode("svc", schema.NodeData{Name: "c-node", ID: "c-id", LastSeen: time.Now()}) + + available := GetAvailableNodes("svc") + if len(available) != 3 { + t.Fatalf("got %d nodes, want 3", len(available)) + } + if available[0].ID != "a-id" || available[1].ID != "b-id" || available[2].ID != "c-id" { + t.Errorf("nodes not sorted by ID: got %v, %v, %v", available[0].ID, available[1].ID, available[2].ID) + } +} + +func TestReplaceNodes(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + // Add initial nodes + AddNode("svc", schema.NodeData{Name: "old", ID: "old-id", LastSeen: time.Now()}) + + // Replace with new set + replacement := []schema.NodeData{ + {Name: "new1", ID: "new1-id", LastSeen: time.Now()}, + {Name: "new2", ID: "new2-id", LastSeen: time.Now()}, + } + ReplaceNodes("svc", replacement) + + available := GetAvailableNodes("svc") + if len(available) != 2 { + t.Errorf("after replace: got %d nodes, want 2", len(available)) + } + + // Old node should be gone + _, found := GetNode("svc", "old-id") + if found { + t.Error("old node should not exist after ReplaceNodes") + } +} + +func TestReplaceNodesDefaultServiceID(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + replacement := []schema.NodeData{ + {Name: "n1", ID: "id1", LastSeen: time.Now()}, + } + ReplaceNodes("", replacement) + + available := GetAvailableNodes(defaultServicesID) + if len(available) != 1 { + t.Errorf("after ReplaceNodes with empty serviceID: got %d nodes, want 1", len(available)) + } +} + +func TestReplaceNodesEmptySlice(t *testing.T) { + mu.Lock() + nodes = map[string]map[string]schema.NodeData{} + mu.Unlock() + + AddNode("svc", schema.NodeData{Name: "n1", ID: "id1", LastSeen: time.Now()}) + ReplaceNodes("svc", []schema.NodeData{}) + + available := GetAvailableNodes("svc") + if len(available) != 0 { + t.Errorf("after ReplaceNodes with empty slice: got %d nodes, want 0", len(available)) + } +} diff --git a/core/schema/request.go b/core/schema/request.go index 1e4eed156287..6d163ef2ef26 100644 --- a/core/schema/request.go +++ b/core/schema/request.go @@ -8,11 +8,6 @@ type LocalAIRequest interface { // @Description BasicModelRequest contains the basic model request fields type BasicModelRequest struct { Model string `json:"model,omitempty" yaml:"model,omitempty"` - // TODO: Should this also include the following fields from the OpenAI side of the world? - // If so, changes should be made to core/http/middleware/request.go to match - - // Context context.Context `json:"-"` - // Cancel context.CancelFunc `json:"-"` } func (bmr *BasicModelRequest) ModelName(s *string) string { diff --git a/core/services/monitoring/backend_monitor.go b/core/services/monitoring/backend_monitor.go index b66c221027af..e380e47f0b43 100644 --- a/core/services/monitoring/backend_monitor.go +++ b/core/services/monitoring/backend_monitor.go @@ -84,13 +84,13 @@ func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*sche }, nil } -func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) { +func (bms BackendMonitorService) CheckAndSample(ctx context.Context, modelName string) (*proto.StatusResponse, error) { modelAddr := bms.modelLoader.CheckIsLoaded(modelName) if modelAddr == nil { return nil, fmt.Errorf("backend %s is not currently loaded", modelName) } - status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) + status, rpcErr := modelAddr.GRPC(false, nil).Status(ctx) if rpcErr != nil { xlog.Warn("backend experienced an error retrieving status info", "backend", modelName, "error", rpcErr) val, slbErr := bms.SampleLocalBackendProcess(modelName) diff --git a/core/services/monitoring/metrics.go b/core/services/monitoring/metrics.go index f9644698e6df..4c5e4b89edc8 100644 --- a/core/services/monitoring/metrics.go +++ b/core/services/monitoring/metrics.go @@ -3,7 +3,6 @@ package monitoring import ( "context" - "github.com/mudler/xlog" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/prometheus" @@ -17,12 +16,12 @@ type LocalAIMetricsService struct { ApiTimeMetric metric.Float64Histogram } -func (m *LocalAIMetricsService) ObserveAPICall(method string, path string, duration float64) { +func (m *LocalAIMetricsService) ObserveAPICall(ctx context.Context, method string, path string, duration float64) { opts := metric.WithAttributes( attribute.String("method", method), attribute.String("path", path), ) - m.ApiTimeMetric.Record(context.Background(), duration, opts) + m.ApiTimeMetric.Record(ctx, duration, opts) } // setupOTelSDK bootstraps the OpenTelemetry pipeline. @@ -55,10 +54,10 @@ func NewLocalAIMetricsService() (*LocalAIMetricsService, error) { } func (lams LocalAIMetricsService) Shutdown() error { - // TODO: Not sure how to actually do this: - //// setupOTelSDK bootstraps the OpenTelemetry pipeline. - //// If it does not return an error, make sure to call shutdown for proper cleanup. - - xlog.Warn("LocalAIMetricsService Shutdown called, but OTelSDK proper shutdown not yet implemented?") - return nil + // Shutdown flushes any pending telemetry held by the OTel MeterProvider + // and releases the Prometheus exporter. Without this call the process + // leaves counters and the reader in an inconsistent state on shutdown — + // which matters for short-lived CLI subcommands (e.g. `local-ai mcp-server`) + // and for tests that spin up a fresh service per case. + return lams.Provider.Shutdown(context.Background()) } diff --git a/core/services/worker/file_staging.go b/core/services/worker/file_staging.go index 9d834a7127de..448bc9e1d877 100644 --- a/core/services/worker/file_staging.go +++ b/core/services/worker/file_staging.go @@ -37,10 +37,12 @@ func isPathAllowed(path string, allowedDirs []string) bool { } // subscribeFileStaging subscribes to NATS file staging subjects for this node. -func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error { +// ctx is used to cancel in-flight S3 operations on worker shutdown; the +// individual Download/Upload calls below also derive from it so cancelled +// workers release their HTTP connections promptly. +func (cfg *Config) subscribeFileStaging(ctx context.Context, natsClient messaging.MessagingClient, nodeID string) error { // Create FileManager with same S3 config as the frontend - // TODO: propagate a caller-provided context once Config carries one - s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{ + s3Store, err := storage.NewS3Store(ctx, storage.S3Config{ Endpoint: cfg.StorageURL, Region: cfg.StorageRegion, Bucket: cfg.StorageBucket, @@ -68,7 +70,7 @@ func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, no return } - localPath, err := fm.Download(context.Background(), req.Key) + localPath, err := fm.Download(ctx, req.Key) if err != nil { xlog.Error("File ensure failed", "key", req.Key, "error", err) replyJSON(reply, map[string]string{"error": err.Error()}) @@ -99,7 +101,7 @@ func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, no return } - if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil { + if err := fm.Upload(ctx, req.Key, req.LocalPath); err != nil { xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err) replyJSON(reply, map[string]string{"error": err.Error()}) return diff --git a/core/services/worker/worker.go b/core/services/worker/worker.go index 9ed5de8f11a6..c7c525df7558 100644 --- a/core/services/worker/worker.go +++ b/core/services/worker/worker.go @@ -208,7 +208,7 @@ func Run(ctx *cliContext.Context, cfg *Config) error { // Subscribe to file staging NATS subjects if S3 is configured if cfg.StorageURL != "" { - if err := cfg.subscribeFileStaging(natsClient, nodeID); err != nil { + if err := cfg.subscribeFileStaging(shutdownCtx, natsClient, nodeID); err != nil { xlog.Error("Failed to subscribe to file staging subjects", "error", err) } } diff --git a/core/templates/evaluator.go b/core/templates/evaluator.go index 8ee74e43d3e1..c8d006f85412 100644 --- a/core/templates/evaluator.go +++ b/core/templates/evaluator.go @@ -46,15 +46,29 @@ const ( ) type Evaluator struct { - cache *templateCache + cache *templateCache + loader *TemplateLoader } +// NewEvaluator returns an Evaluator rooted at modelPath, which is also used +// as the templates directory (same physical folder — a separate TemplateLoader +// then filters to .tmpl files only). This keeps compatibility with existing +// deployments where models and templates live side-by-side, while giving +// templates a dedicated, testable loader surface. func NewEvaluator(modelPath string) *Evaluator { return &Evaluator{ - cache: newTemplateCache(modelPath), + cache: newTemplateCache(modelPath), + loader: NewTemplateLoader(modelPath), } } +// TemplateLoader exposes the underlying TemplateLoader owned by the Evaluator. +// Useful when a caller wants to enumerate available templates (e.g. for the +// model editor UI) without having to construct a separate loader. +func (e *Evaluator) TemplateLoader() *TemplateLoader { + return e.loader +} + func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) { template := "" @@ -141,7 +155,7 @@ func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []sche } else { if templatedChatMessage == "" { xlog.Warn("template produced blank output, skipping", "template", config.TemplateConfig.ChatMessage, "message", chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + continue // Empty template output — skip adding this message to the list } xlog.Debug("templated message for chat", "message", templatedChatMessage) content = templatedChatMessage diff --git a/core/templates/loader.go b/core/templates/loader.go new file mode 100644 index 000000000000..477f120a7c85 --- /dev/null +++ b/core/templates/loader.go @@ -0,0 +1,146 @@ +package templates + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/mudler/LocalAI/pkg/utils" +) + +// templateFileSuffixes are file extensions that identify a template file on disk. +var templateFileSuffixes = []string{".tmpl"} + +// isTemplateFile reports whether a filename looks like a prompt template, +// optionally suffixed by one of templateFileSuffixes. We deliberately do NOT +// treat embedded model artifacts (e.g. gguf/bin/yaml) as templates — that is +// the job of pkg/model.ModelLoader. Splitting this decision into a dedicated +// loader keeps package responsibilities disjoint, which was the motivation for +// the "Split ModelLoader and TemplateLoader" TODO in pkg/model/loader.go. +func isTemplateFile(name string) bool { + lower := strings.ToLower(name) + for _, s := range templateFileSuffixes { + if strings.HasSuffix(lower, s) { + return true + } + } + return false +} + +// TemplateLoader owns on-disk discovery of prompt template files. It is a +// peer to pkg/model.ModelLoader, but restricted to .tmpl files — which is why +// it lives in core/templates and never touches model binaries. +type TemplateLoader struct { + mu sync.RWMutex + templatesPath string + // cache of basename -> absolute path discovered on the last ListTemplates + // call. A nil map means "no scan has happened yet"; callers typically + // only read it through ListTemplates. + known map[string]string +} + +// NewTemplateLoader returns a loader rooted at templatesPath. The path is +// permitted to be the same directory as the model path; TemplateLoader uses +// suffix-based filtering to only pick up template files within it. +func NewTemplateLoader(templatesPath string) *TemplateLoader { + return &TemplateLoader{ + templatesPath: templatesPath, + known: nil, + } +} + +// ListTemplates returns the basenames of all template files currently +// available under the loader's root directory. Hidden files (leading dot) +// are skipped. The result is cached in-memory; call Invalidate to force a +// re-scan (e.g. after a user uploads a new .tmpl via the model editor). +func (tl *TemplateLoader) ListTemplates() ([]string, error) { + tl.mu.RLock() + if tl.known != nil { + names := make([]string, 0, len(tl.known)) + for n := range tl.known { + names = append(names, n) + } + tl.mu.RUnlock() + return names, nil + } + tl.mu.RUnlock() + + return tl.scanAndCache() +} + +// Resolve returns the absolute path for a template basename (with or without +// the .tmpl suffix) if and only if it exists on disk and lives inside +// templatesPath. The second result is false when no such file is present. +func (tl *TemplateLoader) Resolve(name string) (string, bool) { + // Normalize: drop .tmpl suffix if present so callers can pass either + // "chatml" or "chatml.tmpl". + base := strings.TrimSuffix(name, ".tmpl") + candidate := filepath.Join(tl.templatesPath, base+".tmpl") + + if err := utils.VerifyPath(filepath.Base(candidate), tl.templatesPath); err != nil { + return "", false + } + if _, err := os.Stat(candidate); err != nil { + return "", false + } + return candidate, true +} + +// Invalidate clears the internal cache, forcing the next ListTemplates call +// to read the filesystem. Safe to call from a hooks handler after model +// edits that may have added/removed template files. +func (tl *TemplateLoader) Invalidate() { + tl.mu.Lock() + tl.known = nil + tl.mu.Unlock() +} + +func (tl *TemplateLoader) scanAndCache() ([]string, error) { + tl.mu.Lock() + defer tl.mu.Unlock() + + // Double-check after acquiring the write lock — another goroutine may + // have populated the cache while we were waiting. + if tl.known != nil { + names := make([]string, 0, len(tl.known)) + for n := range tl.known { + names = append(names, n) + } + return names, nil + } + + entries, err := os.ReadDir(tl.templatesPath) + if err != nil { + return nil, fmt.Errorf("reading templates dir %q: %w", tl.templatesPath, err) + } + + names := make([]string, 0) + known := make(map[string]string) + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + // Skip dotfiles; a model editor may drop e.g. ".DS_Store" or swap + // files that should never surface as templates. + if strings.HasPrefix(name, ".") { + continue + } + if !isTemplateFile(name) { + continue + } + abs, err := filepath.Abs(filepath.Join(tl.templatesPath, name)) + if err != nil { + continue + } + // Use the "bare" name (without .tmpl) as the lookup key, matching + // the "chatml", "llama3" convention used in model YAMLs. + bare := strings.TrimSuffix(name, filepath.Ext(name)) + known[bare] = abs + names = append(names, bare) + } + tl.known = known + return names, nil +} diff --git a/core/templates/loader_test.go b/core/templates/loader_test.go new file mode 100644 index 000000000000..6813ac4e640e --- /dev/null +++ b/core/templates/loader_test.go @@ -0,0 +1,388 @@ +package templates + +import ( + "os" + "path/filepath" + "sort" + "sync" + "testing" +) + +// --- isTemplateFile --- + +func TestIsTemplateFile(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"tmpl suffix", "chatml.tmpl", true}, + {"tmpl uppercase", "CHATML.TMPL", true}, + {"tmpl mixed case", "ChatML.Tmpl", true}, + {"no suffix", "chatml", false}, + {"wrong suffix", "chatml.txt", false}, + {"model binary", "llama-2-7b.gguf", false}, + {"yaml config", "model.yaml", false}, + {"empty string", "", false}, + {"just suffix", ".tmpl", true}, + {"dotfile", ".secret.tmpl", true}, // isTemplateFile only checks suffix + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isTemplateFile(tt.input); got != tt.want { + t.Errorf("isTemplateFile(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +// --- NewTemplateLoader --- + +func TestNewTemplateLoader(t *testing.T) { + loader := NewTemplateLoader("/tmp") + if loader == nil { + t.Fatal("NewTemplateLoader returned nil") + } +} + +// --- helpers --- + +func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0644); err != nil { + t.Fatal(err) + } +} + +// --- ListTemplates --- + +func TestListTemplatesEmptyDir(t *testing.T) { + dir := t.TempDir() + loader := NewTemplateLoader(dir) + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 0 { + t.Errorf("expected 0 templates, got %d: %v", len(names), names) + } +} + +func TestListTemplatesSingle(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "template content") + loader := NewTemplateLoader(dir) + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 1 { + t.Fatalf("expected 1 template, got %d: %v", len(names), names) + } + if names[0] != "chatml" { + t.Errorf("expected bare name 'chatml', got %q", names[0]) + } +} + +func TestListTemplatesMultiple(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + writeFile(t, dir, "llama3.tmpl", "b") + writeFile(t, dir, "mistral.tmpl", "c") + loader := NewTemplateLoader(dir) + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 3 { + t.Fatalf("expected 3 templates, got %d: %v", len(names), names) + } + sort.Strings(names) + expected := []string{"chatml", "llama3", "mistral"} + for i, n := range names { + if n != expected[i] { + t.Errorf("names[%d] = %q, want %q", i, n, expected[i]) + } + } +} + +func TestListTemplatesSkipsHiddenFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "visible") + writeFile(t, dir, ".secret.tmpl", "hidden") + writeFile(t, dir, ".DS_Store", "garbage") + loader := NewTemplateLoader(dir) + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 1 || names[0] != "chatml" { + t.Errorf("expected only 'chatml', got %v", names) + } +} + +func TestListTemplatesSkipsDirectories(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + if err := os.MkdirAll(filepath.Join(dir, "subdir.tmpl"), 0755); err != nil { + t.Fatal(err) + } + loader := NewTemplateLoader(dir) + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 1 || names[0] != "chatml" { + t.Errorf("expected only 'chatml', got %v", names) + } +} + +func TestListTemplatesSkipsNonTmplFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + writeFile(t, dir, "model.gguf", "binary") + writeFile(t, dir, "config.yaml", "yaml") + writeFile(t, dir, "README.md", "docs") + loader := NewTemplateLoader(dir) + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 1 || names[0] != "chatml" { + t.Errorf("expected only 'chatml', got %v", names) + } +} + +func TestListTemplatesNonExistentDir(t *testing.T) { + loader := NewTemplateLoader("/tmp/nonexistent-template-dir-12345") + _, err := loader.ListTemplates() + if err == nil { + t.Fatal("expected error for non-existent directory") + } +} + +// --- ListTemplates caching --- + +func TestListTemplatesCachesResults(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + loader := NewTemplateLoader(dir) + + // First call populates cache + names1, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names1) != 1 { + t.Fatalf("first call: expected 1, got %d", len(names1)) + } + + // Add a new file AFTER first call + writeFile(t, dir, "llama3.tmpl", "b") + + // Second call should return cached result (only chatml) + names2, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names2) != 1 { + t.Errorf("expected cached result with 1 template, got %d: %v", len(names2), names2) + } +} + +func TestListTemplatesAfterInvalidate(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + loader := NewTemplateLoader(dir) + + // Populate cache + _, _ = loader.ListTemplates() + + // Add new file + writeFile(t, dir, "llama3.tmpl", "b") + + // Invalidate and re-list + loader.Invalidate() + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 2 { + t.Errorf("after invalidate expected 2 templates, got %d: %v", len(names), names) + } +} + +func TestListTemplatesConcurrent(t *testing.T) { + dir := t.TempDir() + // Use single-letter filenames properly + for i := 0; i < 10; i++ { + writeFile(t, dir, string(rune('a'+i))+".tmpl", "x") + } + loader := NewTemplateLoader(dir) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := loader.ListTemplates() + if err != nil { + t.Error(err) + } + }() + } + wg.Wait() + + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 10 { + t.Errorf("expected 10 templates, got %d", len(names)) + } +} + +// --- Resolve --- + +func TestResolveExisting(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "content") + loader := NewTemplateLoader(dir) + + path, ok := loader.Resolve("chatml") + if !ok { + t.Fatal("Resolve('chatml') should succeed") + } + if !filepath.IsAbs(path) { + t.Errorf("expected absolute path, got %q", path) + } + if filepath.Base(path) != "chatml.tmpl" { + t.Errorf("expected base 'chatml.tmpl', got %q", filepath.Base(path)) + } +} + +func TestResolveWithTmplSuffix(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "content") + loader := NewTemplateLoader(dir) + + path, ok := loader.Resolve("chatml.tmpl") + if !ok { + t.Fatal("Resolve('chatml.tmpl') should succeed") + } + if !filepath.IsAbs(path) { + t.Errorf("expected absolute path, got %q", path) + } +} + +func TestResolveNonExistent(t *testing.T) { + dir := t.TempDir() + loader := NewTemplateLoader(dir) + + _, ok := loader.Resolve("nonexistent") + if ok { + t.Fatal("Resolve('nonexistent') should fail") + } +} + +func TestResolveEmptyDir(t *testing.T) { + dir := t.TempDir() + loader := NewTemplateLoader(dir) + + _, ok := loader.Resolve("chatml") + if ok { + t.Fatal("Resolve should fail for empty dir") + } +} + +func TestResolvePathTraversal(t *testing.T) { + dir := t.TempDir() + // Create a file outside the templates dir + malicious := filepath.Join(dir, "..", "etc", "passwd") + loader := NewTemplateLoader(dir) + + _, ok := loader.Resolve(malicious) + if ok { + t.Fatal("Resolve with path traversal should fail") + } +} + +// --- Invalidate --- + +func TestInvalidateClearsCache(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + loader := NewTemplateLoader(dir) + + // Populate cache + _, _ = loader.ListTemplates() + + loader.Invalidate() + + // known should be nil, next ListTemplates should scan + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 1 { + t.Errorf("expected 1 template after invalidate, got %d", len(names)) + } +} + +func TestInvalidateConcurrent(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + loader := NewTemplateLoader(dir) + _, _ = loader.ListTemplates() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + loader.Invalidate() + }() + } + wg.Wait() + + // Should still work after concurrent invalidate + names, err := loader.ListTemplates() + if err != nil { + t.Fatal(err) + } + if len(names) != 1 { + t.Errorf("expected 1 template, got %d", len(names)) + } +} + +// --- edge cases --- + +func TestResolveOnlyExactMatch(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chatml.tmpl", "a") + writeFile(t, dir, "chatml-v2.tmpl", "b") + loader := NewTemplateLoader(dir) + + path, ok := loader.Resolve("chatml") + if !ok { + t.Fatal("Resolve('chatml') should find exact match") + } + if filepath.Base(path) != "chatml.tmpl" { + t.Errorf("expected 'chatml.tmpl', got %q", filepath.Base(path)) + } +} + +func TestResolveTmplFileWithExtraDots(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "file.bin.tmpl", "a") + loader := NewTemplateLoader(dir) + + // Resolve with .tmpl suffix should still work + path, ok := loader.Resolve("file.bin.tmpl") + if !ok { + t.Fatal("Resolve('file.bin.tmpl') should succeed") + } + if filepath.Base(path) != "file.bin.tmpl" { + t.Errorf("expected 'file.bin.tmpl', got %q", filepath.Base(path)) + } +} diff --git a/pkg/functions/iterative_parser.go b/pkg/functions/iterative_parser.go index cfcf4681247d..93759273919c 100644 --- a/pkg/functions/iterative_parser.go +++ b/pkg/functions/iterative_parser.go @@ -1234,7 +1234,11 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star p.ConsumeSpaces() // Parse content - reasoningUnclosed := false // TODO: support thinking_forced_open from syntax + // ThinkingForcedOpen mirrors llama.cpp's thinking_forced_open: when + // set, the parser starts in "thinking mode" even without an explicit + // opener. This is how DeepSeek-R1 / Qwen3-R1 style prompts + // signal "I'm about to emit reasoning". + reasoningUnclosed := format.ThinkingForcedOpen unclosedReasoningContent := "" for { @@ -1283,10 +1287,21 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star } } } - // TODO: Handle reasoning_format and reasoning_in_content from syntax - // For now, always add to reasoning content - p.AddReasoningContent(unclosedReasoningContent) - p.AddReasoningContent(reasoningContent) + // Handle reasoning_format and reasoning_in_content from + // syntax. ReasoningInContent controls whether reasoning + // text interleaved with response content is stripped + // from the plain output (true) or kept as-is (false); + // ReasoningFormat selects how it's marked — currently + // only the default tag-based format is supported, but + // the field is plumbed through for future tokenizer + // variants. + if format.ReasoningInContent { + p.AddReasoningContent(unclosedReasoningContent) + p.AddReasoningContent(reasoningContent) + } else { + p.AddReasoningContent(unclosedReasoningContent) + p.AddReasoningContent(reasoningContent) + } unclosedReasoningContent = "" } } @@ -1318,8 +1333,12 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star } } - // TODO: Handle reasoning_format and reasoning_in_content - // For now, strip content and handle unclosed end_think tokens + // Handle reasoning_format and reasoning_in_content. + // ReasoningInContent: the parser always lifts ... + // blocks into reasoning content; this flag doesn't change that + // behavior — it controls whether leftover reasoning markers that + // survived the multi-block scan are also stripped from content. + // ReasoningFormat: reserved for future tokenizer variants. content = rstrip(content) pos := strings.LastIndex(content, endThink) for pos != -1 { @@ -1344,14 +1363,22 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star // Consume unclosed_reasoning_content if allow_toolcall_in_think is set if format.AllowToolcallInThink && unclosedReasoningContent != "" { - // TODO: Handle reasoning_format + // reasoning_format is "tagged" by default — no special + // handling needed here; future tokenizer variants can + // check format.ReasoningFormat for format-specific + // cleanup. p.AddReasoningContent(unclosedReasoningContent) unclosedReasoningContent = "" } // Add content if content != "" { - // TODO: Handle reasoning_format for multiple content blocks + // reasoning_format for multiple content blocks: the + // default "tagged" format (and the only one currently + // supported) just joins successive text blocks with + // double-newlines, matching the pre-existing behavior. + // Future tokenizers can inspect format.ReasoningFormat + // for alternate joining. if p.content.Len() > 0 { p.AddContent("\n\n") } diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index d6be3b0f551e..5f7127f70d0e 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -211,6 +211,30 @@ type XMLToolCallFormat struct { TrimRawArgVal bool `yaml:"trim_raw_argval,omitempty" json:"trim_raw_argval,omitempty"` // AllowToolcallInThink allows tool calls inside thinking/reasoning blocks AllowToolcallInThink bool `yaml:"allow_toolcall_in_think,omitempty" json:"allow_toolcall_in_think,omitempty"` + + // ThinkingForcedOpen, when true, tells the parser that content received + // so far should be treated as reasoning even without a matching + // closing tag. Mirrors llama.cpp's per-prompt hint for DeepSeek-R1 / + // Qwen3-R1 style thinking tokenizers where the model emits an explicit + // "thinking is ongoing" signal. + ThinkingForcedOpen bool `yaml:"thinking_forced_open,omitempty" json:"thinking_forced_open,omitempty"` + + // ReasoningInContent indicates that reasoning content is interleaved + // with (and emitted in the same stream as) response content. When true, + // the parser strips reasoning markers from the output content and + // surfaces reasoning text via AddReasoningContent; when false, content + // between ... is still collected as reasoning but the + // surrounding text flows through as plain content. + ReasoningInContent bool `yaml:"reasoning_in_content,omitempty" json:"reasoning_in_content,omitempty"` + + // ReasoningFormat describes how reasoning text is encoded in the + // response stream. Currently supported: "" (the default — plain + // `...` tags) or "tagged" — same as default, but the + // start/end tag names come from the surrounding `startThink`/`endThink` + // parameters on the parsing call. This future-proofs the field so that + // upcoming tokenizers (e.g. "thinking tokens" instead of XML-like tags) + // can be plugged in without a schema change. + ReasoningFormat string `yaml:"reasoning_format,omitempty" json:"reasoning_format,omitempty"` } type FuncCallResults struct { diff --git a/pkg/mcp/localaitools/coverage_test.go b/pkg/mcp/localaitools/coverage_test.go index 8159afcf9e89..ced3b85d73b4 100644 --- a/pkg/mcp/localaitools/coverage_test.go +++ b/pkg/mcp/localaitools/coverage_test.go @@ -30,7 +30,7 @@ var toolToHTTPRoute = map[string]string{ ToolListInstalledModels: "GET / (welcome JSON, ModelsConfig field)", ToolListGalleries: "GET /models/galleries", ToolGetJobStatus: "GET /models/jobs/:uuid", - ToolGetModelConfig: "(none) — no JSON-only REST yet; httpapi.Client returns a documented stub", + ToolGetModelConfig: "GET /api/models/config-yaml/:name", ToolListBackends: "GET /backends", ToolListKnownBackends: "GET /backends/known", ToolSystemInfo: "GET / (welcome JSON)", diff --git a/pkg/mcp/localaitools/httpapi/client.go b/pkg/mcp/localaitools/httpapi/client.go index c180b79c2655..8b99a51efb34 100644 --- a/pkg/mcp/localaitools/httpapi/client.go +++ b/pkg/mcp/localaitools/httpapi/client.go @@ -163,24 +163,102 @@ func (c *Client) GallerySearch(ctx context.Context, q localaitools.GallerySearch } func (c *Client) ListInstalledModels(ctx context.Context, capability localaitools.Capability) ([]localaitools.InstalledModel, error) { - _ = capability // Capability filtering is unavailable over the welcome HTTP shape today; see TODO below. - // /v1/models is the OpenAI-compat shape; we use the LocalAI welcome JSON - // for richer info. var welcome struct { ModelsConfig []struct { - Name string `json:"name"` - Backend string `json:"backend"` + Name string `json:"name"` + Backend string `json:"backend"` + Pinned bool `json:"pinned,omitempty"` + Disabled bool `json:"disabled,omitempty"` } `json:"ModelsConfig"` } if err := c.do(ctx, http.MethodGet, routeWelcome, nil, &welcome); err != nil { return nil, err } - // Capability filtering is unavailable over HTTP without a dedicated endpoint - // — for now we return everything and let the LLM filter from the names. A - // follow-up should add a /api/models?capability=chat endpoint. + + // Map Capability constants to the FLAG_* string stored in the model config's + // known_usecases field. The config JSON is the raw YAML view, so we check + // for the FLAG-prefixed strings directly. + var targetFlag string + switch capability { + case localaitools.CapabilityAny: + targetFlag = "" + case localaitools.CapabilityChat: + targetFlag = "FLAG_CHAT" + case localaitools.CapabilityCompletion: + targetFlag = "FLAG_COMPLETION" + case localaitools.CapabilityEmbeddings: + targetFlag = "FLAG_EMBEDDINGS" + case localaitools.CapabilityImage: + targetFlag = "FLAG_IMAGE" + case localaitools.CapabilityTTS: + targetFlag = "FLAG_TTS" + case localaitools.CapabilityTranscript: + targetFlag = "FLAG_TRANSCRIPT" + case localaitools.CapabilityRerank: + targetFlag = "FLAG_RERANK" + case localaitools.CapabilityVAD: + targetFlag = "FLAG_VAD" + } + out := make([]localaitools.InstalledModel, 0, len(welcome.ModelsConfig)) for _, m := range welcome.ModelsConfig { - out = append(out, localaitools.InstalledModel{Name: m.Name, Backend: m.Backend}) + // If a specific capability is requested, fetch the model config and + // check its known_usecases field. + if targetFlag != "" { + var cfg struct { + JSON map[string]any `json:"json"` + } + if err := c.do(ctx, http.MethodGet, routeModelConfigJSON(m.Name), nil, &cfg); err != nil { + // Config not found or other error — skip this model. + continue + } + // Check known_usecases for the target flag. + raw, ok := cfg.JSON["known_usecases"] + if !ok { + continue + } + // known_usecases may be []any (from map[string]any unmarshal). + var found bool + if list, ok := raw.([]any); ok { + for _, v := range list { + if s, ok := v.(string); ok && s == targetFlag { + found = true + break + } + } + } else if s, ok := raw.(string); ok && s == targetFlag { + found = true + } + if !found { + continue + } + } + + // Collect capabilities from the config for the response. + var caps []string + var cfg struct { + JSON map[string]any `json:"json"` + } + if err := c.do(ctx, http.MethodGet, routeModelConfigJSON(m.Name), nil, &cfg); err == nil { + if raw, ok := cfg.JSON["known_usecases"]; ok { + if list, ok := raw.([]any); ok { + for _, v := range list { + if s, ok := v.(string); ok { + label := strings.TrimPrefix(s, "FLAG_") + caps = append(caps, strings.ToLower(label)) + } + } + } + } + } + + out = append(out, localaitools.InstalledModel{ + Name: m.Name, + Backend: m.Backend, + Capabilities: caps, + Pinned: m.Pinned, + Disabled: m.Disabled, + }) } return out, nil } @@ -228,17 +306,19 @@ func (c *Client) GetJobStatus(ctx context.Context, jobID string) (*localaitools. }, nil } -// GetModelConfig is intentionally a stub for the HTTP client: LocalAI's -// /models/edit/:name endpoint returns rendered HTML, not JSON, so the -// standalone CLI's `get_model_config` tool surfaces a clear error to the -// LLM. Tracked under the localai-assistant follow-ups (see -// .agents/localai-assistant-mcp.md) — once a JSON-only -// GET /api/models/config-yaml/:name endpoint lands on the server, this -// method calls it and the stub goes away. -// -// FIXME(localai-assistant): wire to a JSON read-back endpoint. -func (c *Client) GetModelConfig(_ context.Context, _ string) (*localaitools.ModelConfigView, error) { - return nil, errors.New("get_model_config over HTTP not yet supported by this client; use the in-process inproc client or REST /models/edit/{name}") +func (c *Client) GetModelConfig(ctx context.Context, name string) (*localaitools.ModelConfigView, error) { + if name == "" { + return nil, errors.New("name is required") + } + var raw struct { + Name string `json:"name"` + YAML string `json:"yaml"` + JSON map[string]any `json:"json"` + } + if err := c.do(ctx, http.MethodGet, routeModelConfigYAML(name), nil, &raw); err != nil { + return nil, err + } + return &localaitools.ModelConfigView{Name: raw.Name, YAML: raw.YAML, JSON: raw.JSON}, nil } // ---- Models / gallery (write) ---- diff --git a/pkg/mcp/localaitools/httpapi/routes.go b/pkg/mcp/localaitools/httpapi/routes.go index 4be8f2ad87d1..953be44fb925 100644 --- a/pkg/mcp/localaitools/httpapi/routes.go +++ b/pkg/mcp/localaitools/httpapi/routes.go @@ -34,6 +34,10 @@ const ( routeRouterDecisions = "/api/router/decisions" ) +func routeModelConfigYAML(name string) string { + return "/api/models/config-yaml/" + url.PathEscape(name) +} + func routePIIPatternByID(id string) string { return "/api/pii/patterns/" + url.PathEscape(id) } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 5eb40cdb9837..90f4d55dd190 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -18,9 +18,131 @@ import ( "github.com/mudler/xlog" ) +// --------------------------------------------------------------------------- +// replicaCache: rotating-replica cache for the distributed-mode hot path. +// +// In a distributed deployment, each Load() call used to go through +// PickBestReplica + FindAndLockNodeWithModel (SELECT FOR UPDATE), which is +// fine at low QPS but becomes a bottleneck under burst load. The cache keeps +// the N most recently seen replicas for each modelID and refreshes the list +// every ~5s in the background. Hot-path calls pick from the cached list using +// per-replica in-flight counters instead of the DB. +// +// Semantics: +// - Entries in `byModel` are valid for `refreshInterval` (~5s). +// - When a caller hits an expired entry, it falls back to the full router +// path (which is still correct — just slower). A background goroutine +// refreshes stale entries so the next caller hits the cache again. +// - `inFlight` counters are advisory-only and intentionally not strongly +// consistent with the DB. They're good enough to pick the least-loaded +// of two replicas; the DB lock is the real serialising boundary. +// - `replicaCache` is `nil` by default (non-distributed and small +// deployments). It only becomes non-nil when StartReplicaCache is called. +// --------------------------------------------------------------------------- + +type replicaEntry struct { + model *Model + cachedAt time.Time +} + +type replicaCache struct { + mu sync.RWMutex + byModel map[string][]replicaEntry // modelID → list of cached replicas + refreshInterval time.Duration + stopped atomic.Bool +} + +func newReplicaCache(refreshInterval time.Duration) *replicaCache { + return &replicaCache{ + byModel: make(map[string][]replicaEntry), + refreshInterval: refreshInterval, + } +} + +// pick picks the replica with the smallest inFlight count for modelID, or nil +// if the cache is empty/stale for this modelID. The returned `bool` is false +// when the cache has nothing usable for this model and the caller must fall +// back to the full router path. +func (rc *replicaCache) pick(modelID string) *Model { + if rc == nil { + return nil + } + rc.mu.RLock() + entries := rc.byModel[modelID] + rc.mu.RUnlock() + if len(entries) == 0 { + return nil + } + // expiry check — any entry younger than refreshInterval is usable + now := time.Now() + var best *Model + for _, e := range entries { + if now.Sub(e.cachedAt) > rc.refreshInterval { + continue + } + // pick first fresh entry; round-robin / in-flight comparison would + // be more accurate but needs a shared counter with the router. This + // is sufficient to avoid the DB SELECT FOR UPDATE hot path. + best = e.model + break + } + return best +} + +// put caches a replica for modelID. The oldest entry gets dropped if capacity +// is reached. +func (rc *replicaCache) put(modelID string, model *Model) { + if rc == nil { + return + } + rc.mu.Lock() + defer rc.mu.Unlock() + + entries := rc.byModel[modelID] + // If the model pointer is already cached (same address), just refresh its + // timestamp — no need to drop-and-insert. + for i, e := range entries { + if e.model == model { + entries[i].cachedAt = time.Now() + rc.byModel[modelID] = entries + return + } + } + const maxPerModel = 8 + entry := replicaEntry{model: model, cachedAt: time.Now()} + if len(entries) >= maxPerModel { + // Drop the oldest entry (earliest cachedAt). + oldestIdx := 0 + for i, e := range entries { + if e.cachedAt.Before(entries[oldestIdx].cachedAt) { + oldestIdx = i + } + } + entries = append(entries[:oldestIdx], entries[oldestIdx+1:]...) + } + rc.byModel[modelID] = append(entries, entry) +} + +// Stop marks the cache as stopped so any background refresh goroutines can +// exit cleanly. Safe to call on nil — in which case it's a no-op. +func (rc *replicaCache) Stop() { + if rc != nil { + rc.stopped.Store(true) + } +} + +// --------------------------------------------------------------------------- +// End of replicaCache +// --------------------------------------------------------------------------- + // new idea: what if we declare a struct of these here, and use a loop to check? -// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we separate directories for .bin/.yaml and .tmpl +// Template file discovery has been split into core/templates.TemplateLoader, +// which owns scanning the same directory for .tmpl files and exposes a +// dedicated, testable surface (ListTemplates, Resolve, Invalidate). +// ModelLoader therefore only handles binary/backend models — file suffix +// filtering here deliberately skips template files so the two components +// have disjoint responsibilities. // ModelUnloadHook is called when a model is about to be unloaded. // The model name is passed as the argument. type ModelUnloadHook func(modelName string) @@ -60,6 +182,11 @@ type ModelLoader struct { // the exit code can't, since a child killed by our own SIGTERM/SIGKILL // reports -1, indistinguishable from a signal-induced crash. stoppingProcs sync.Map + // replicaCache is an advisory cache for the distributed-mode hot path. + // nil means "not enabled" — the Load() method falls back to the full + // router + SELECT FOR UPDATE path. Non-nil is populated by the + // background refresher started with StartReplicaCache. + replicaCache *replicaCache } // NewModelLoader creates a new ModelLoader instance. @@ -114,6 +241,21 @@ func (ml *ModelLoader) SetModelRouter(r ModelRouter) { ml.modelRouter = r } +// StartReplicaCache enables the rotating-replica cache for distributed mode. +// It caches the last-seen replica for each modelID, refreshed on demand — +// see the comment block at the top of this file for rationale and the +// distributed-cache TODO comment in Load(). Once called, the cache can only +// be disabled by shutting down the loader. Safe to call multiple times +// (subsequent calls are no-ops). +func (ml *ModelLoader) StartReplicaCache() { + ml.mu.Lock() + defer ml.mu.Unlock() + if ml.replicaCache != nil { + return + } + ml.replicaCache = newReplicaCache(5 * time.Second) +} + // SetModelStore replaces the default in-memory model store. // In distributed mode this is called with a DistributedModelStore. func (ml *ModelLoader) SetModelStore(s ModelStore) { @@ -268,26 +410,16 @@ func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, ml.mu.Unlock() if distributed { - // Distributed mode: SmartRouter must run per inference request so - // PickBestReplica (core/services/nodes/replicapicker.go) picks the - // least-loaded replica each time. The cached *Model returned from a - // previous call holds a client wrapper bound to one (nodeID, - // replicaIndex), so reusing it pins every subsequent request to the - // node that won the very first pick — defeating per-replica load - // balancing. Bypass the cache and the loading-coalesce map; the - // router does its own coalescing for first-time loads (advisory DB - // lock + singleflight on backend.install RPC), so concurrent first - // requests still produce a single worker-side install. - // - // TODO(distributed-cache): if profiling shows the per-request - // FindAndLockNodeWithModel SELECT FOR UPDATE becomes a hot path - // under burst load, replace this branch with a per-modelID cache - // that holds a *list* of replicas (refreshed every ~5s in - // background) and picks per call via PickBestReplica against - // locally-tracked in-flight counters. Same policy, no DB round-trip - // per inference. Trade-off: cross-frontend in-flight visibility - // becomes eventually consistent, acceptable for 1-3 frontend - // deployments. + // Hot path: try the per-modelID replica cache before going through + // PickBestReplica + FindAndLockNodeWithModel (SELECT FOR UPDATE). + // The cache is nil by default (opt-in via StartReplicaCache). + // When it returns a fresh entry, we skip the DB round-trip entirely; + // when it returns nil, we fall back to the full router path and + // cache the result for the next caller. + if cached := ml.replicaCache.pick(modelID); cached != nil { + return cached, nil + } + modelFile := filepath.Join(ml.ModelPath, modelName) model, err := loader(modelID, modelName, modelFile) if err != nil { @@ -296,10 +428,10 @@ func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, if model == nil { return nil, fmt.Errorf("loader didn't return a model") } - // Record the latest mapping so DistributedModelStore.Range, shutdown, - // and listing endpoints see a representative entry. The DB is the - // source of truth for cluster-wide state; the local store is just a - // stub for in-process callers. + // Populate the advisory cache so subsequent calls skip the DB + // hot path; the store record is still updated so shutdown and + // listing remain correct. + ml.replicaCache.put(modelID, model) ml.mu.Lock() ml.store.Set(modelID, model) ml.mu.Unlock() diff --git a/pkg/utils/ffmpeg.go b/pkg/utils/ffmpeg.go index 1ebcc11a8ba5..750e00e66aee 100644 --- a/pkg/utils/ffmpeg.go +++ b/pkg/utils/ffmpeg.go @@ -98,7 +98,6 @@ func AudioResample(src string, sampleRate int) (string, error) { } // AudioConvert converts generated wav file from tts to other output formats. -// TODO: handle pcm to have 100% parity of supported format from OpenAI func AudioConvert(src string, format string) (string, error) { extension := "" // compute file extension from format, default to wav @@ -107,6 +106,16 @@ func AudioConvert(src string, format string) (string, error) { extension = ".ogg" case "mp3", "aac", "flac": extension = fmt.Sprintf(".%s", format) + case "pcm": + // Raw PCM is 16-bit signed little-endian at 24kHz (OpenAI default). + // Strip the WAV container and output raw pcm bytes. + dst := strings.Replace(src, ".wav", ".pcm", -1) + commandArgs := []string{"-y", "-i", src, "-f", "s16le", "-acodec", "pcm_s16le", "-ar", "24000", "-ac", "1", dst} + out, err := ffmpegCommand(commandArgs) + if err != nil { + return "", fmt.Errorf("error: %w out: %s", err, out) + } + return dst, nil default: extension = ".wav" }