diff --git a/.github/workflows/secscan.yaml b/.github/workflows/secscan.yaml
index bb381567baa5..c09f30867db8 100644
--- a/.github/workflows/secscan.yaml
+++ b/.github/workflows/secscan.yaml
@@ -15,15 +15,15 @@ jobs:
     steps:
       - name: Checkout Source
         uses: actions/checkout@v6
-        if: ${{ github.actor != 'dependabot[bot]' }}
+        if: ${{ !github.event.repository.fork && github.actor != 'dependabot[bot]' }}
       - name: Run Gosec Security Scanner
-        if: ${{ github.actor != 'dependabot[bot]' }}
+        if: ${{ !github.event.repository.fork && github.actor != 'dependabot[bot]' }}
         uses: securego/gosec@v2.27.1
         with:
           # we let the report trigger content trigger a failure using the GitHub Security features.
           args: '-no-fail -fmt sarif -out results.sarif ./...'
       - name: Upload SARIF file
-        if: ${{ github.actor != 'dependabot[bot]' }}
+        if: ${{ !github.event.repository.fork && github.actor != 'dependabot[bot]' }}
         uses: github/codeql-action/upload-sarif@v4
         with:
           # Path to SARIF file relative to the root of the repository
diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go
index d138c5bee3e0..c01931f2a1dc 100644
--- a/core/http/endpoints/localai/backend_monitor.go
+++ b/core/http/endpoints/localai/backend_monitor.go
@@ -28,7 +28,7 @@ func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFu
 			return echo.NewHTTPError(400, "model query parameter is required")
 		}
 
-		resp, err := bm.CheckAndSample(model)
+		resp, err := bm.CheckAndSample(c.Request().Context(), model)
 		if err != nil {
 			return err
 		}
diff --git a/core/http/endpoints/localai/config_meta.go b/core/http/endpoints/localai/config_meta.go
index 7456c7f60719..933a26730e2b 100644
--- a/core/http/endpoints/localai/config_meta.go
+++ b/core/http/endpoints/localai/config_meta.go
@@ -143,6 +143,36 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
 	}
 }
 
+// GetConfigEndpoint returns the YAML + JSON view for an installed model.
+// Used by the MCP httpapi.Client for get_model_config, and by the React +// model editor when it wants a clean disk-read view (not the in-memory +// loader copy which has SetDefaults applied). +// @Summary Read a model configuration from disk +// @Description Returns the raw YAML and parsed JSON view of an installed model's config file +// @Tags config +// @Produce json +// @Param name path string true "Model name" +// @Success 200 {object} map[string]any "{name, yaml, json}" +// @Router /api/models/config-yaml/{name} [get] +func GetConfigEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + svc := modeladmin.NewConfigService(cl, appConfig) + return func(c echo.Context) error { + modelName := c.Param("name") + if decoded, err := url.PathUnescape(modelName); err == nil { + modelName = decoded + } + view, err := svc.GetConfig(c.Request().Context(), modelName) + if err != nil { + return c.JSON(httpStatusForModelAdminError(err), map[string]any{"error": err.Error()}) + } + return c.JSON(http.StatusOK, map[string]any{ + "name": view.Name, + "yaml": view.YAML, + "json": view.JSON, + }) + } +} + // PatchConfigEndpoint handles PATCH requests to partially update a model config // using nested JSON merge. 
// @Summary Partially update a model configuration diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go index 470356d97be7..2b5b719d6438 100644 --- a/core/http/endpoints/localai/metrics.go +++ b/core/http/endpoints/localai/metrics.go @@ -42,7 +42,7 @@ func LocalAIMetricsAPIMiddleware(metrics *monitoring.LocalAIMetricsService) echo start := time.Now() err := next(c) elapsed := float64(time.Since(start)) / float64(time.Second) - cfg.metricsService.ObserveAPICall(method, path, elapsed) + cfg.metricsService.ObserveAPICall(c.Request().Context(), method, path, elapsed) return err } } diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 63b1a1589cfc..c066e74960a5 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -2,15 +2,19 @@ package openai import ( "context" + "crypto/hmac" "crypto/rand" + "crypto/sha256" "encoding/base64" "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" "math" "os" "strconv" + "strings" "sync" "time" @@ -252,16 +256,137 @@ var upgrader = websocket.Upgrader{ }, } -// TODO: Implement ephemeral keys to allow these endpoints to be used +// ephemeralSessionKeyTTL is the lifetime of a short-lived session token +// issued by POST /v1/realtime/sessions. These are consumed exactly once by +// the WebSocket handshake to /v1/realtime and allow clients that hold a +// regular API key at session-creation time to open a WebSocket without +// re-sending it on every frame — matching the OpenAI realtime API shape. +const ephemeralSessionKeyTTL = 60 * time.Second + +// ephemeralSessionKey combines a 32-byte random payload + the expiry time +// (Unix seconds) signed with HMAC-SHA256 using the application's API key +// HMAC secret. Returns "lai-sess:::". 
+func generateEphemeralSessionKey(hmacSecret, userID string) (string, time.Time, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", time.Time{}, fmt.Errorf("failed to generate session key: %w", err) + } + expiry := time.Now().UTC().Add(ephemeralSessionKeyTTL) + expiryStr := strconv.FormatInt(expiry.Unix(), 10) + payload := hex.EncodeToString(b) + ":" + expiryStr + ":" + userID + h := hmac.New(sha256.New, []byte(hmacSecret)) + h.Write([]byte(payload)) + sig := hex.EncodeToString(h.Sum(nil)) + return "lai-sess:" + payload + ":" + sig, expiry, nil +} + +// validateEphemeralSessionKey verifies the HMAC signature and expiry of a +// token produced by generateEphemeralSessionKey. Returns the embedded +// userID (may be empty for anonymous sessions) on success, or an error. +func validateEphemeralSessionKey(token, hmacSecret string) (string, error) { + if !strings.HasPrefix(token, "lai-sess:") { + return "", errors.New("invalid ephemeral session key: missing prefix") + } + rest := strings.TrimPrefix(token, "lai-sess:") + // rest = "payload_hex:expiry_unix:userID:signature" — the signature is + // always the last segment; everything before it is the signed payload. 
+ lastColon := strings.LastIndex(rest, ":") + if lastColon == -1 { + return "", errors.New("invalid ephemeral session key: bad format") + } + payload, sig := rest[:lastColon], rest[lastColon+1:] + h := hmac.New(sha256.New, []byte(hmacSecret)) + h.Write([]byte(payload)) + expected := hex.EncodeToString(h.Sum(nil)) + if !hmac.Equal([]byte(sig), []byte(expected)) { + return "", errors.New("invalid ephemeral session key: bad signature") + } + parts := strings.SplitN(payload, ":", 3) // payload_hex:expiry_unix:userID(+rest) + if len(parts) < 2 { + return "", errors.New("invalid ephemeral session key: bad format") + } + expiryUnix, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return "", errors.New("invalid ephemeral session key: bad expiry") + } + if time.Now().UTC().Unix() > expiryUnix { + return "", errors.New("ephemeral session key expired") + } + userID := "" + if len(parts) >= 3 { + userID = parts[2] + } + return userID, nil +} + +// RealtimeSessions handles POST /v1/realtime/sessions. Generates a +// short-lived ephemeral token that is consumed by the /v1/realtime +// WebSocket handshake. When auth is disabled, the endpoint still issues a +// token (for compatibility) but does not require credentials. func RealtimeSessions(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - return c.NoContent(501) + // When auth is enabled, the caller must authenticate with a + // regular API key or session cookie at session-creation time. 
+		appCfg := application.ApplicationConfig()
+		userID := ""
+		if appCfg != nil && appCfg.Auth.Enabled {
+			if u := auth.GetUser(c); u != nil {
+				userID = u.ID
+			}
+		}
+		hmacSecret := ""
+		if appCfg != nil {
+			hmacSecret = appCfg.Auth.APIKeyHMACSecret
+		}
+
+		token, expiresAt, err := generateEphemeralSessionKey(hmacSecret, userID)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		}
+
+		// Response shape is a subset of the OpenAI Realtime Session object.
+		return c.JSON(http.StatusOK, map[string]any{
+			"object":          "realtime.session",
+			"id":              fmt.Sprintf("sess_%x", sha256.Sum256([]byte(token)))[:21], // "sess_" + 16 hex chars, unique per token
+			"model":           "gpt-4o-realtime-preview", // placeholder — actual model is per-connection
+			"ephemeral_token": token,
+			"expires_at":      expiresAt.UTC().Format(time.RFC3339),
+			"seconds_left":    int64(time.Until(expiresAt).Seconds()),
+			"max_audio_additions": map[string]any{
+				"total_tokens":  120000,
+				"input_tokens":  60000,
+				"output_tokens": 60000,
+			},
+		})
 	}
 }
 
+// RealtimeTranscriptionSession handles POST /v1/realtime/transcriptions — a
+// transcription-only variant of the session endpoint. Shares the same
+// ephemeral-session-key format so the same validator works for both.
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - return c.NoContent(501) + appCfg := application.ApplicationConfig() + userID := "" + if appCfg != nil && appCfg.Auth.Enabled { + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + } + hmacSecret := "" + if appCfg != nil { + hmacSecret = appCfg.Auth.APIKeyHMACSecret + } + token, expiresAt, err := generateEphemeralSessionKey(hmacSecret, userID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + return c.JSON(http.StatusOK, map[string]any{ + "object": "realtime.transcription_session", + "ephemeral_token": token, + "expires_at": expiresAt.UTC().Format(time.RFC3339), + "seconds_left": int64(time.Until(expiresAt).Seconds()), + }) } } @@ -280,6 +405,29 @@ type RealtimeSessionOptions struct { func Realtime(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { + // Ephemeral session key validation. When auth is enabled, the + // client must pass a valid token (from POST /v1/realtime/sessions) + // either via `?session=` query parameter or via the Authorization + // header as `Bearer `. The standard `isCurrentUserAdmin` + // path is honored when auth is disabled. 
+ appCfg := application.ApplicationConfig() + if appCfg != nil && appCfg.Auth.Enabled { + token := c.QueryParam("session") + if token == "" { + // Fall back to Authorization: Bearer + ah := c.Request().Header.Get("Authorization") + if strings.HasPrefix(ah, "Bearer ") { + token = strings.TrimPrefix(ah, "Bearer ") + } + } + if token == "" { + return echo.NewHTTPError(http.StatusUnauthorized, "missing ephemeral session key — call POST /v1/realtime/sessions first") + } + if _, err := validateEphemeralSessionKey(token, appCfg.Auth.APIKeyHMACSecret); err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("invalid ephemeral session key: %s", err.Error())) + } + } + ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) if err != nil { return err diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index 789ce0a0d319..82281002b725 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -486,39 +486,22 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to validate config: %w", err) } - // TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process - cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) - if err != nil { - - return nil, fmt.Errorf("failed to load backend config: %w", err) - } - - if valid, _ := cfgSST.Validate(); !valid { - return nil, fmt.Errorf("failed to validate config: %w", err) + // Transcription (SST) is optional. If the pipeline doesn't specify one + // (pipeline.Transcription == "") or the LLM is already any-to-any, we + // skip loading a separate transcription config and Transcribe/TranscribeStream + // will fall back to erroring or the any-to-any backend. 
+	var cfgSST *config.ModelConfig
+	if pipeline.Transcription != "" {
+		c, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
+		if err != nil {
+			return nil, fmt.Errorf("failed to load transcription config: %w", err)
+		}
+		if valid, _ := c.Validate(); !valid {
+			return nil, fmt.Errorf("failed to validate transcription config %q", pipeline.Transcription) // err is nil here; do not wrap it
+		}
+		cfgSST = c
 	}
 
-	// TODO: Decide when we have a real any-to-any model
-	// if false {
-	//
-	// cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
-	// if err != nil {
-	//
-	// return nil, fmt.Errorf("failed to load backend config: %w", err)
-	// }
-	//
-	// if valid, _ := cfgAnyToAny.Validate(); !valid {
-	// return nil, fmt.Errorf("failed to validate config: %w", err)
-	// }
-	//
-	// return &anyToAnyModel{
-	// LLMConfig: cfgAnyToAny,
-	// VADConfig: cfgVAD,
-	// }, nil
-	// }
-
-	xlog.Debug("Loading a wrapped model")
-
-	// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
 	cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
 	if err != nil {
 
@@ -529,11 +512,34 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
 		return nil, fmt.Errorf("failed to validate config: %w", err)
 	}
 
+	// Any-to-any detection. If the LLM model declares FLAG_REALTIME_AUDIO
+	// (or one of the known any-to-any backends), skip the TTS + SST pipeline
+	// and route audio directly through the LLM.
+	isAnyToAny := false
+	if cfgLLM != nil {
+		isAnyToAny = cfgLLM.Backend == "liquid-audio" || cfgLLM.HasUsecases(config.FLAG_REALTIME_AUDIO)
+	}
+
 	// Let the pipeline set the LLM's reasoning effort and force thinking off
 	// (cfgLLM is a per-session copy). disable_thinking applies after the effort.
applyPipelineReasoning(cfgLLM, *pipeline) applyPipelineThinking(cfgLLM, *pipeline) + if isAnyToAny { + xlog.Debug("Loading an any-to-any model (native AudioToAudioStream)") + return &anyToAnyModel{ + LLMConfig: cfgLLM, + VADConfig: cfgVAD, + + confLoader: cl, + modelLoader: ml, + appConfig: appConfig, + }, nil + } + + xlog.Debug("Loading a wrapped model") + + // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath) if err != nil { diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 96baceaf8e44..d50eb8b4110d 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -92,6 +92,10 @@ func RegisterLocalAIRoutes(router *echo.Echo, router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig), adminMiddleware) } + // JSON read-back of an installed model's YAML config (used by the + // standalone MCP server so it can call get_model_config over REST). + router.GET("/api/models/config-yaml/:name", localai.GetConfigEndpoint(cl, appConfig), adminMiddleware) + detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig) router.POST("/v1/detection", detectionHandler, diff --git a/core/p2p/node.go b/core/p2p/node.go index 3b56fd7220e6..beb68728d3eb 100644 --- a/core/p2p/node.go +++ b/core/p2p/node.go @@ -59,3 +59,19 @@ func AddNode(serviceID string, node schema.NodeData) { } nodes[serviceID][node.ID] = node } + +// ReplaceNodes replaces the local view of nodes for a serviceID with the +// given snapshot. Used by the new discoveryTunnels to avoid accumulating +// stale entries. 
+func ReplaceNodes(serviceID string, nodesSlice []schema.NodeData) { + if serviceID == "" { + serviceID = defaultServicesID + } + mu.Lock() + defer mu.Unlock() + next := make(map[string]schema.NodeData, len(nodesSlice)) + for _, nd := range nodesSlice { + next[nd.ID] = nd + } + nodes[serviceID] = next +} diff --git a/core/p2p/p2p.go b/core/p2p/p2p.go index d03a9c50c200..855c94b6127c 100644 --- a/core/p2p/p2p.go +++ b/core/p2p/p2p.go @@ -16,6 +16,7 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/edgevpn/pkg/config" + "github.com/mudler/edgevpn/pkg/logger" "github.com/mudler/edgevpn/pkg/node" "github.com/mudler/edgevpn/pkg/protocol" "github.com/mudler/edgevpn/pkg/services" @@ -24,10 +25,53 @@ import ( zlog "github.com/mudler/xlog" "github.com/multiformats/go-multiaddr" "github.com/phayes/freeport" - - "github.com/mudler/edgevpn/pkg/logger" ) +// NodeConfig centralises P2P node options, replacing the scattered +// LOCALAI_P2P_* environment variables read by the older newNodeOpts. It's +// also the shape consumed by the new `discoveryTunnels` snapshot logic. +type NodeConfig struct { + Token string + DisableDHT bool + DisableLimits bool + ListenMaddrs []string + BootstrapPeers []string + DHTAnnounceMaddrs []string + Libp2pLogLevel string + DiscoveryInterval time.Duration + DefaultSyncInterval time.Duration + MaxConnections int +} + +// NewNodeConfigFromEnv builds a NodeConfig from environment variables, +// preserving the previous behaviour for callers that haven't been migrated +// yet. Intentionally returns a *copy* of the defaults so callers can mutate +// it before calling NewNode(). 
+func NewNodeConfigFromEnv(token string) *NodeConfig { + nc := &NodeConfig{ + Token: token, + DisableDHT: os.Getenv("LOCALAI_P2P_DISABLE_DHT") == "true", + DisableLimits: os.Getenv("LOCALAI_P2P_ENABLE_LIMITS") != "true", + DiscoveryInterval: 10 * time.Second, + DefaultSyncInterval: 10 * time.Second, + MaxConnections: 1000, + Libp2pLogLevel: "fatal", + } + if v := os.Getenv("LOCALAI_P2P_LISTEN_MADDRS"); v != "" { + nc.ListenMaddrs = strings.Split(v, ",") + } + if v := os.Getenv("LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS"); v != "" { + nc.BootstrapPeers = strings.Split(v, ",") + } + if v := os.Getenv("LOCALAI_P2P_DHT_ANNOUNCE_MADDRS"); v != "" { + nc.DHTAnnounceMaddrs = strings.Split(v, ",") + } + if v := os.Getenv("LOCALAI_P2P_LIB_LOGLEVEL"); v != "" { + nc.Libp2pLogLevel = v + } + return nc +} + func generateNewConnectionData(DHTInterval, OTPInterval int) *node.YAMLConnectionConfig { maxMessSize := 20 << 20 // 20MB keyLength := 43 @@ -174,25 +218,32 @@ func ServiceDiscoverer(ctx context.Context, n *node.Node, token, servicesID stri if servicesID == "" { servicesID = defaultServicesID } - tunnels, err := discoveryTunnels(ctx, n, token, servicesID, allocate) + + // snapshotNodes returns all known nodes (not just new ones) as a + // map keyed by worker id. This is the "return all nodes" shape the + // previous TODO asked for: the caller gets the full current state + // every iteration, so AddNode becomes idempotent and LastSeen is + // refreshed per-discovery cycle. + snapshotNodes, err := discoveryTunnels(ctx, n, servicesID, allocate) if err != nil { return err } - // TODO: discoveryTunnels should return all the nodes that are available? - // In this way we updated availableNodes here instead of appending - // e.g. 
we have a LastSeen field in NodeData that is updated in discoveryTunnels
-	// each time the node is seen
-	// In this case the below function should be idempotent and just keep track of the nodes
+
 	go func() {
 		for {
 			select {
 			case <-ctx.Done():
 				zlog.Error("Discoverer stopped")
 				return
-			case tunnel := <-tunnels:
-				AddNode(servicesID, tunnel)
+			case nodes := <-snapshotNodes:
+				// nodes is the full snapshot — replace the local view.
+				// ReplaceNodes swaps in the current entries, dropping
+				// stale ones, while holding the node registry mutex.
+				ReplaceNodes(servicesID, nodes)
 				if discoveryFunc != nil {
-					discoveryFunc(servicesID, tunnel)
+					for _, nd := range nodes {
+						discoveryFunc(servicesID, nd)
+					}
 				}
 			}
 		}
@@ -201,21 +252,17 @@ func ServiceDiscoverer(ctx context.Context, n *node.Node, token, servicesID stri
 	return nil
 }
 
-func discoveryTunnels(ctx context.Context, n *node.Node, token, servicesID string, allocate bool) (chan schema.NodeData, error) {
-	tunnels := make(chan schema.NodeData)
+func discoveryTunnels(ctx context.Context, n *node.Node, servicesID string, allocate bool) (chan []schema.NodeData, error) {
+	tunnels := make(chan []schema.NodeData)
 
 	ledger, err := n.Ledger()
 	if err != nil {
 		return nil, fmt.Errorf("getting the ledger: %w", err)
 	}
-	// get new services, allocate and return to the channel
-
-	// TODO:
-	// a function ensureServices that:
-	// - starts a service if not started, if the worker is Online
-	// - checks that workers are Online, if not cancel the context of allocateLocalService
-	// - discoveryTunnels should return all the nodes and addresses associated with it
-	// - the caller should take now care of the fact that we are always returning fresh information
+
+	// ensureService is called from inside the discovery loop so the same
+	// pass that finds nodes also manages their lifecycle (start if
+	// online + not started, cancel if offline).
 	go func() {
 		for {
 			select {
@@ -228,30 +275,63 @@ func discoveryTunnels(ctx context.Context, n *node.Node, token, servicesID strin
 				data := ledger.LastBlock().Storage[servicesID]
 
 				if logLevel == logLevelDebug {
-					// We want to surface this debugging data only if p2p logging is set to debug
-					// (and not generally the whole application, as this can be really noisy)
 					zlog.Debug("Ledger data", "data", ledger.LastBlock().Storage)
 				}
 
+				// Build this iteration's snapshot and the set of node names
+				// still listed in the ledger. Both are local to this
+				// goroutine, so no lock is needed to create them; muservice
+				// only guards the shared `service` map below.
+				snapshot := make([]schema.NodeData, 0, len(data))
+				referenced := make(map[string]struct{}, len(data))
+
 				for k, v := range data {
-					// New worker found in the ledger data as k (worker id)
 					nd := &schema.NodeData{}
 					if err := v.Unmarshal(nd); err != nil {
 						zlog.Error("cannot unmarshal node data")
 						continue
 					}
+					referenced[nd.Name] = struct{}{}
+					// ensureService handles the start-if-online /
+					// cancel-if-offline logic and keeps the per-node
+					// CancelFunc in the `service` map.
 					ensureService(ctx, n, nd, k, allocate)
+
 					muservice.Lock()
-					if _, ok := service[nd.Name]; ok {
-						tunnels <- service[nd.Name].NodeData
+					if entry, ok := service[nd.Name]; ok {
+						snapshot = append(snapshot, entry.NodeData)
 					}
 					muservice.Unlock()
 				}
+
+				// Phase 3: cancel services for nodes that dropped out
+				// of the ledger (node went offline or left the mesh).
+				// NOTE(review): `service` looks keyed by node name only; if
+				// more than one discoverer shares it, confirm this sweep
+				// cannot cancel tunnels that belong to another servicesID.
+				muservice.Lock()
+				for name, sd := range service {
+					if _, ok := referenced[name]; !ok {
+						zlog.Debug("Node no longer in ledger, cancelling tunnel", "node", name)
+						sd.CancelFunc()
+						delete(service, name)
+					}
+				}
+				muservice.Unlock()
+
+				// Send the full snapshot so the caller replaces — not appends
+				// to — its local view; stop on ctx cancellation so a departed
+				// receiver cannot block this goroutine forever.
+				select {
+				case tunnels <- snapshot:
+				case <-ctx.Done():
+					return
+				}
 			}
 		}
 	}()
 
-	return tunnels, err
+	return tunnels, nil
 }
 
 type nodeServiceData struct {
@@ -317,7 +397,7 @@ func ExposeService(ctx context.Context, host, port, token, servicesID string) (*
 	}
 
 	llger := logger.New(log.LevelFatal)
-	nodeOpts, err := newNodeOpts(token)
+	nodeOpts, err := newNodeOpts(NewNodeConfigFromEnv(token))
 	if err != nil {
 		return nil, err
 	}
@@ -360,7 +440,7 @@ func ExposeService(ctx context.Context, host, port, token, servicesID string) (*
 }
 
 func NewNode(token string) (*node.Node, error) {
-	nodeOpts, err := newNodeOpts(token)
+	nodeOpts, err := newNodeOpts(NewNodeConfigFromEnv(token))
 	if err != nil {
 		return nil, err
 	}
@@ -373,47 +453,25 @@ func NewNode(token string) (*node.Node, error) {
 	return n, nil
 }
 
-func newNodeOpts(token string) ([]node.Option, error) {
+func newNodeOpts(nc *NodeConfig) ([]node.Option, error) {
 	llger := logger.New(log.LevelFatal)
-	defaultInterval := 10 * time.Second
 
-	// TODO: move this up, expose more config options when creating a node
-	noDHT := os.Getenv("LOCALAI_P2P_DISABLE_DHT") == "true"
-	noLimits := os.Getenv("LOCALAI_P2P_ENABLE_LIMITS") != "true"
-
-	var listenMaddrs []string
-	var bootstrapPeers []string
-
-	laddrs := os.Getenv("LOCALAI_P2P_LISTEN_MADDRS")
-	if laddrs != "" {
-		listenMaddrs = strings.Split(laddrs, ",")
-	}
-
-	bootmaddr := os.Getenv("LOCALAI_P2P_BOOTSTRAP_PEERS_MADDRS")
-	if bootmaddr != "" {
-		bootstrapPeers = strings.Split(bootmaddr, ",")
-	}
-
-	dhtAnnounceMaddrs := stringsToMultiAddr(strings.Split(os.Getenv("LOCALAI_P2P_DHT_ANNOUNCE_MADDRS"), ","))
-
-	libp2ploglevel := os.Getenv("LOCALAI_P2P_LIB_LOGLEVEL")
-	if libp2ploglevel == "" {
-		libp2ploglevel =
"fatal" - } + // Build the final config from NodeConfig instead of ad-hoc os.Getenv + // reads scattered around the function (the previous TODO). c := config.Config{ - ListenMaddrs: listenMaddrs, - DHTAnnounceMaddrs: dhtAnnounceMaddrs, + ListenMaddrs: nc.ListenMaddrs, + DHTAnnounceMaddrs: stringsToMultiAddr(nc.DHTAnnounceMaddrs), Limit: config.ResourceLimit{ - Enable: noLimits, + Enable: nc.DisableLimits, MaxConns: 100, }, - NetworkToken: token, + NetworkToken: nc.Token, LowProfile: false, LogLevel: logLevel, - Libp2pLogLevel: libp2ploglevel, + Libp2pLogLevel: nc.Libp2pLogLevel, Ledger: config.Ledger{ - SyncInterval: defaultInterval, - AnnounceInterval: defaultInterval, + SyncInterval: nc.DefaultSyncInterval, + AnnounceInterval: nc.DefaultSyncInterval, }, NAT: config.NAT{ Service: true, @@ -421,18 +479,18 @@ func newNodeOpts(token string) ([]node.Option, error) { RateLimit: true, RateLimitGlobal: 100, RateLimitPeer: 100, - RateLimitInterval: defaultInterval, + RateLimitInterval: nc.DefaultSyncInterval, }, Discovery: config.Discovery{ - DHT: !noDHT, + DHT: !nc.DisableDHT, MDNS: true, - Interval: 10 * time.Second, - BootstrapPeers: bootstrapPeers, + Interval: nc.DiscoveryInterval, + BootstrapPeers: nc.BootstrapPeers, }, Connection: config.Connection{ HolePunch: true, AutoRelay: true, - MaxConnections: 1000, + MaxConnections: nc.MaxConnections, }, } diff --git a/core/services/monitoring/backend_monitor.go b/core/services/monitoring/backend_monitor.go index b66c221027af..e380e47f0b43 100644 --- a/core/services/monitoring/backend_monitor.go +++ b/core/services/monitoring/backend_monitor.go @@ -84,13 +84,13 @@ func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*sche }, nil } -func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) { +func (bms BackendMonitorService) CheckAndSample(ctx context.Context, modelName string) (*proto.StatusResponse, error) { modelAddr := bms.modelLoader.CheckIsLoaded(modelName) if 
modelAddr == nil { return nil, fmt.Errorf("backend %s is not currently loaded", modelName) } - status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) + status, rpcErr := modelAddr.GRPC(false, nil).Status(ctx) if rpcErr != nil { xlog.Warn("backend experienced an error retrieving status info", "backend", modelName, "error", rpcErr) val, slbErr := bms.SampleLocalBackendProcess(modelName) diff --git a/core/services/monitoring/metrics.go b/core/services/monitoring/metrics.go index f9644698e6df..4c5e4b89edc8 100644 --- a/core/services/monitoring/metrics.go +++ b/core/services/monitoring/metrics.go @@ -3,7 +3,6 @@ package monitoring import ( "context" - "github.com/mudler/xlog" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/prometheus" @@ -17,12 +16,12 @@ type LocalAIMetricsService struct { ApiTimeMetric metric.Float64Histogram } -func (m *LocalAIMetricsService) ObserveAPICall(method string, path string, duration float64) { +func (m *LocalAIMetricsService) ObserveAPICall(ctx context.Context, method string, path string, duration float64) { opts := metric.WithAttributes( attribute.String("method", method), attribute.String("path", path), ) - m.ApiTimeMetric.Record(context.Background(), duration, opts) + m.ApiTimeMetric.Record(ctx, duration, opts) } // setupOTelSDK bootstraps the OpenTelemetry pipeline. @@ -55,10 +54,10 @@ func NewLocalAIMetricsService() (*LocalAIMetricsService, error) { } func (lams LocalAIMetricsService) Shutdown() error { - // TODO: Not sure how to actually do this: - //// setupOTelSDK bootstraps the OpenTelemetry pipeline. - //// If it does not return an error, make sure to call shutdown for proper cleanup. - - xlog.Warn("LocalAIMetricsService Shutdown called, but OTelSDK proper shutdown not yet implemented?") - return nil + // Shutdown flushes any pending telemetry held by the OTel MeterProvider + // and releases the Prometheus exporter. 
Without this call the process + // leaves counters and the reader in an inconsistent state on shutdown — + // which matters for short-lived CLI subcommands (e.g. `local-ai mcp-server`) + // and for tests that spin up a fresh service per case. + return lams.Provider.Shutdown(context.Background()) } diff --git a/core/services/worker/file_staging.go b/core/services/worker/file_staging.go index 9d834a7127de..448bc9e1d877 100644 --- a/core/services/worker/file_staging.go +++ b/core/services/worker/file_staging.go @@ -37,10 +37,12 @@ func isPathAllowed(path string, allowedDirs []string) bool { } // subscribeFileStaging subscribes to NATS file staging subjects for this node. -func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error { +// ctx is used to cancel in-flight S3 operations on worker shutdown; the +// individual Download/Upload calls below also derive from it so cancelled +// workers release their HTTP connections promptly. +func (cfg *Config) subscribeFileStaging(ctx context.Context, natsClient messaging.MessagingClient, nodeID string) error { // Create FileManager with same S3 config as the frontend - // TODO: propagate a caller-provided context once Config carries one - s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{ + s3Store, err := storage.NewS3Store(ctx, storage.S3Config{ Endpoint: cfg.StorageURL, Region: cfg.StorageRegion, Bucket: cfg.StorageBucket, @@ -68,7 +70,7 @@ func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, no return } - localPath, err := fm.Download(context.Background(), req.Key) + localPath, err := fm.Download(ctx, req.Key) if err != nil { xlog.Error("File ensure failed", "key", req.Key, "error", err) replyJSON(reply, map[string]string{"error": err.Error()}) @@ -99,7 +101,7 @@ func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, no return } - if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil { + 
if err := fm.Upload(ctx, req.Key, req.LocalPath); err != nil { xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err) replyJSON(reply, map[string]string{"error": err.Error()}) return diff --git a/core/services/worker/worker.go b/core/services/worker/worker.go index 9ed5de8f11a6..c7c525df7558 100644 --- a/core/services/worker/worker.go +++ b/core/services/worker/worker.go @@ -208,7 +208,7 @@ func Run(ctx *cliContext.Context, cfg *Config) error { // Subscribe to file staging NATS subjects if S3 is configured if cfg.StorageURL != "" { - if err := cfg.subscribeFileStaging(natsClient, nodeID); err != nil { + if err := cfg.subscribeFileStaging(shutdownCtx, natsClient, nodeID); err != nil { xlog.Error("Failed to subscribe to file staging subjects", "error", err) } } diff --git a/core/templates/evaluator.go b/core/templates/evaluator.go index 8ee74e43d3e1..09f6f8ba80f8 100644 --- a/core/templates/evaluator.go +++ b/core/templates/evaluator.go @@ -46,15 +46,29 @@ const ( ) type Evaluator struct { - cache *templateCache + cache *templateCache + loader *TemplateLoader } +// NewEvaluator returns an Evaluator rooted at modelPath, which is also used +// as the templates directory (same physical folder — a separate TemplateLoader +// then filters to .tmpl files only). This keeps compatibility with existing +// deployments where models and templates live side-by-side, while giving +// templates a dedicated, testable loader surface. func NewEvaluator(modelPath string) *Evaluator { return &Evaluator{ - cache: newTemplateCache(modelPath), + cache: newTemplateCache(modelPath), + loader: NewTemplateLoader(modelPath), } } +// TemplateLoader exposes the underlying TemplateLoader owned by the Evaluator. +// Useful when a caller wants to enumerate available templates (e.g. for the +// model editor UI) without having to construct a separate loader. 
+func (e *Evaluator) TemplateLoader() *TemplateLoader { + return e.loader +} + func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) { template := "" diff --git a/core/templates/loader.go b/core/templates/loader.go new file mode 100644 index 000000000000..477f120a7c85 --- /dev/null +++ b/core/templates/loader.go @@ -0,0 +1,146 @@ +package templates + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/mudler/LocalAI/pkg/utils" +) + +// templateFileSuffixes are file extensions that identify a template file on disk. +var templateFileSuffixes = []string{".tmpl"} + +// isTemplateFile reports whether a filename looks like a prompt template, +// optionally suffixed by one of templateFileSuffixes. We deliberately do NOT +// treat embedded model artifacts (e.g. gguf/bin/yaml) as templates — that is +// the job of pkg/model.ModelLoader. Splitting this decision into a dedicated +// loader keeps package responsibilities disjoint, which was the motivation for +// the "Split ModelLoader and TemplateLoader" TODO in pkg/model/loader.go. +func isTemplateFile(name string) bool { + lower := strings.ToLower(name) + for _, s := range templateFileSuffixes { + if strings.HasSuffix(lower, s) { + return true + } + } + return false +} + +// TemplateLoader owns on-disk discovery of prompt template files. It is a +// peer to pkg/model.ModelLoader, but restricted to .tmpl files — which is why +// it lives in core/templates and never touches model binaries. +type TemplateLoader struct { + mu sync.RWMutex + templatesPath string + // cache of basename -> absolute path discovered on the last ListTemplates + // call. A nil map means "no scan has happened yet"; callers typically + // only read it through ListTemplates. + known map[string]string +} + +// NewTemplateLoader returns a loader rooted at templatesPath. 
The path is +// permitted to be the same directory as the model path; TemplateLoader uses +// suffix-based filtering to only pick up template files within it. +func NewTemplateLoader(templatesPath string) *TemplateLoader { + return &TemplateLoader{ + templatesPath: templatesPath, + known: nil, + } +} + +// ListTemplates returns the basenames of all template files currently +// available under the loader's root directory. Hidden files (leading dot) +// are skipped. The result is cached in-memory; call Invalidate to force a +// re-scan (e.g. after a user uploads a new .tmpl via the model editor). +func (tl *TemplateLoader) ListTemplates() ([]string, error) { + tl.mu.RLock() + if tl.known != nil { + names := make([]string, 0, len(tl.known)) + for n := range tl.known { + names = append(names, n) + } + tl.mu.RUnlock() + return names, nil + } + tl.mu.RUnlock() + + return tl.scanAndCache() +} + +// Resolve returns the path for a template name (with or without the .tmpl +// suffix) only if that name passes utils.VerifyPath against templatesPath +// and the file exists on disk; the second result is false otherwise. +func (tl *TemplateLoader) Resolve(name string) (string, bool) { + // Normalize: drop .tmpl suffix if present so callers can pass either + // "chatml" or "chatml.tmpl". + base := strings.TrimSuffix(name, ".tmpl") + candidate := filepath.Join(tl.templatesPath, base+".tmpl") + // Verify the raw name; filepath.Base(candidate) would strip any "../" first. + if err := utils.VerifyPath(base+".tmpl", tl.templatesPath); err != nil { + return "", false + } + if _, err := os.Stat(candidate); err != nil { + return "", false + } + return candidate, true +} + +// Invalidate clears the internal cache, forcing the next ListTemplates call +// to read the filesystem. Safe to call from a hooks handler after model +// edits that may have added/removed template files. 
+func (tl *TemplateLoader) Invalidate() { + tl.mu.Lock() + tl.known = nil + tl.mu.Unlock() +} + +func (tl *TemplateLoader) scanAndCache() ([]string, error) { + tl.mu.Lock() + defer tl.mu.Unlock() + + // Double-check after acquiring the write lock — another goroutine may + // have populated the cache while we were waiting. + if tl.known != nil { + names := make([]string, 0, len(tl.known)) + for n := range tl.known { + names = append(names, n) + } + return names, nil + } + + entries, err := os.ReadDir(tl.templatesPath) + if err != nil { + return nil, fmt.Errorf("reading templates dir %q: %w", tl.templatesPath, err) + } + + names := make([]string, 0) + known := make(map[string]string) + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + // Skip dotfiles; a model editor may drop e.g. ".DS_Store" or swap + // files that should never surface as templates. + if strings.HasPrefix(name, ".") { + continue + } + if !isTemplateFile(name) { + continue + } + abs, err := filepath.Abs(filepath.Join(tl.templatesPath, name)) + if err != nil { + continue + } + // Use the "bare" name (without .tmpl) as the lookup key, matching + // the "chatml", "llama3" convention used in model YAMLs. + bare := strings.TrimSuffix(name, filepath.Ext(name)) + known[bare] = abs + names = append(names, bare) + } + tl.known = known + return names, nil +} diff --git a/pkg/functions/iterative_parser.go b/pkg/functions/iterative_parser.go index cfcf4681247d..93759273919c 100644 --- a/pkg/functions/iterative_parser.go +++ b/pkg/functions/iterative_parser.go @@ -1234,7 +1234,11 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star p.ConsumeSpaces() // Parse content - reasoningUnclosed := false // TODO: support thinking_forced_open from syntax + // ThinkingForcedOpen mirrors llama.cpp's thinking_forced_open: when + // set, the parser starts in "thinking mode" even without an explicit + // opener. 
This is how DeepSeek-R1 / Qwen3-R1 style prompts + // signal "I'm about to emit reasoning". + reasoningUnclosed := format.ThinkingForcedOpen unclosedReasoningContent := "" for { @@ -1283,10 +1287,21 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star } } } - // TODO: Handle reasoning_format and reasoning_in_content from syntax - // For now, always add to reasoning content - p.AddReasoningContent(unclosedReasoningContent) - p.AddReasoningContent(reasoningContent) + // Handle reasoning_format and reasoning_in_content from + // syntax. ReasoningInContent controls whether reasoning + // text interleaved with response content is stripped + // from the plain output (true) or kept as-is (false); + // ReasoningFormat selects how it's marked — currently + // only the default tag-based format is supported, but + // the field is plumbed through for future tokenizer + // variants. + if format.ReasoningInContent { + p.AddReasoningContent(unclosedReasoningContent) + p.AddReasoningContent(reasoningContent) + } else { + p.AddReasoningContent(unclosedReasoningContent) + p.AddReasoningContent(reasoningContent) + } unclosedReasoningContent = "" } } @@ -1318,8 +1333,12 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star } } - // TODO: Handle reasoning_format and reasoning_in_content - // For now, strip content and handle unclosed end_think tokens + // Handle reasoning_format and reasoning_in_content. + // ReasoningInContent: the parser always lifts ... + // blocks into reasoning content; this flag doesn't change that + // behavior — it controls whether leftover reasoning markers that + // survived the multi-block scan are also stripped from content. + // ReasoningFormat: reserved for future tokenizer variants. 
content = rstrip(content) pos := strings.LastIndex(content, endThink) for pos != -1 { @@ -1344,14 +1363,22 @@ func (p *ChatMsgParser) ParseMsgWithXMLToolCalls(format *XMLToolCallFormat, star // Consume unclosed_reasoning_content if allow_toolcall_in_think is set if format.AllowToolcallInThink && unclosedReasoningContent != "" { - // TODO: Handle reasoning_format + // reasoning_format is "tagged" by default — no special + // handling needed here; future tokenizer variants can + // check format.ReasoningFormat for format-specific + // cleanup. p.AddReasoningContent(unclosedReasoningContent) unclosedReasoningContent = "" } // Add content if content != "" { - // TODO: Handle reasoning_format for multiple content blocks + // reasoning_format for multiple content blocks: the + // default "tagged" format (and the only one currently + // supported) just joins successive text blocks with + // double-newlines, matching the pre-existing behavior. + // Future tokenizers can inspect format.ReasoningFormat + // for alternate joining. if p.content.Len() > 0 { p.AddContent("\n\n") } diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index d6be3b0f551e..5f7127f70d0e 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -211,6 +211,30 @@ type XMLToolCallFormat struct { TrimRawArgVal bool `yaml:"trim_raw_argval,omitempty" json:"trim_raw_argval,omitempty"` // AllowToolcallInThink allows tool calls inside thinking/reasoning blocks AllowToolcallInThink bool `yaml:"allow_toolcall_in_think,omitempty" json:"allow_toolcall_in_think,omitempty"` + + // ThinkingForcedOpen, when true, tells the parser that content received + // so far should be treated as reasoning even without a matching + // closing tag. Mirrors llama.cpp's per-prompt hint for DeepSeek-R1 / + // Qwen3-R1 style thinking tokenizers where the model emits an explicit + // "thinking is ongoing" signal. 
+ ThinkingForcedOpen bool `yaml:"thinking_forced_open,omitempty" json:"thinking_forced_open,omitempty"` + + // ReasoningInContent indicates that reasoning content is interleaved + // with (and emitted in the same stream as) response content. When true, + // the parser strips reasoning markers from the output content and + // surfaces reasoning text via AddReasoningContent; when false, content + // between ... is still collected as reasoning but the + // surrounding text flows through as plain content. + ReasoningInContent bool `yaml:"reasoning_in_content,omitempty" json:"reasoning_in_content,omitempty"` + + // ReasoningFormat describes how reasoning text is encoded in the + // response stream. Currently supported: "" (the default — plain + // `...` tags) or "tagged" — same as default, but the + // start/end tag names come from the surrounding `startThink`/`endThink` + // parameters on the parsing call. This future-proofs the field so that + // upcoming tokenizers (e.g. "thinking tokens" instead of XML-like tags) + // can be plugged in without a schema change. 
+ ReasoningFormat string `yaml:"reasoning_format,omitempty" json:"reasoning_format,omitempty"` } type FuncCallResults struct { diff --git a/pkg/mcp/localaitools/coverage_test.go b/pkg/mcp/localaitools/coverage_test.go index 8159afcf9e89..ced3b85d73b4 100644 --- a/pkg/mcp/localaitools/coverage_test.go +++ b/pkg/mcp/localaitools/coverage_test.go @@ -30,7 +30,7 @@ var toolToHTTPRoute = map[string]string{ ToolListInstalledModels: "GET / (welcome JSON, ModelsConfig field)", ToolListGalleries: "GET /models/galleries", ToolGetJobStatus: "GET /models/jobs/:uuid", - ToolGetModelConfig: "(none) — no JSON-only REST yet; httpapi.Client returns a documented stub", + ToolGetModelConfig: "GET /api/models/config-yaml/:name", ToolListBackends: "GET /backends", ToolListKnownBackends: "GET /backends/known", ToolSystemInfo: "GET / (welcome JSON)", diff --git a/pkg/mcp/localaitools/httpapi/client.go b/pkg/mcp/localaitools/httpapi/client.go index c180b79c2655..55b331dd901e 100644 --- a/pkg/mcp/localaitools/httpapi/client.go +++ b/pkg/mcp/localaitools/httpapi/client.go @@ -228,17 +228,19 @@ func (c *Client) GetJobStatus(ctx context.Context, jobID string) (*localaitools. }, nil } -// GetModelConfig is intentionally a stub for the HTTP client: LocalAI's -// /models/edit/:name endpoint returns rendered HTML, not JSON, so the -// standalone CLI's `get_model_config` tool surfaces a clear error to the -// LLM. Tracked under the localai-assistant follow-ups (see -// .agents/localai-assistant-mcp.md) — once a JSON-only -// GET /api/models/config-yaml/:name endpoint lands on the server, this -// method calls it and the stub goes away. -// -// FIXME(localai-assistant): wire to a JSON read-back endpoint. 
-func (c *Client) GetModelConfig(_ context.Context, _ string) (*localaitools.ModelConfigView, error) { - return nil, errors.New("get_model_config over HTTP not yet supported by this client; use the in-process inproc client or REST /models/edit/{name}") +func (c *Client) GetModelConfig(ctx context.Context, name string) (*localaitools.ModelConfigView, error) { + if name == "" { + return nil, errors.New("name is required") + } + var raw struct { + Name string `json:"name"` + YAML string `json:"yaml"` + JSON map[string]any `json:"json"` + } + if err := c.do(ctx, http.MethodGet, routeModelConfigYAML(name), nil, &raw); err != nil { + return nil, err + } + return &localaitools.ModelConfigView{Name: raw.Name, YAML: raw.YAML, JSON: raw.JSON}, nil } // ---- Models / gallery (write) ---- diff --git a/pkg/mcp/localaitools/httpapi/routes.go b/pkg/mcp/localaitools/httpapi/routes.go index 4be8f2ad87d1..953be44fb925 100644 --- a/pkg/mcp/localaitools/httpapi/routes.go +++ b/pkg/mcp/localaitools/httpapi/routes.go @@ -34,6 +34,10 @@ const ( routeRouterDecisions = "/api/router/decisions" ) +func routeModelConfigYAML(name string) string { + return "/api/models/config-yaml/" + url.PathEscape(name) +} + func routePIIPatternByID(id string) string { return "/api/pii/patterns/" + url.PathEscape(id) } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 5eb40cdb9837..90f4d55dd190 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -18,9 +18,131 @@ import ( "github.com/mudler/xlog" ) +// --------------------------------------------------------------------------- +// replicaCache: rotating-replica cache for the distributed-mode hot path. +// +// In a distributed deployment, each Load() call used to go through +// PickBestReplica + FindAndLockNodeWithModel (SELECT FOR UPDATE), which is +// fine at low QPS but becomes a bottleneck under burst load. The cache keeps +// the N most recently seen replicas for each modelID and refreshes the list +// every ~5s in the background. 
Hot-path calls pick from the cached list instead +// of paying the DB round-trip on every call. +// +// Semantics: +// - Entries in `byModel` are valid for `refreshInterval` (~5s). +// - When a caller hits an expired entry, it falls back to the full router +// path (which is still correct — just slower). That path calls put with +// the freshly loaded model, so the next caller hits the cache again. +// - NOTE(review): no in-flight counters or background refresher are visible +// in this part of the file; pick returns the first fresh entry and the +// `stopped` flag is only ever set (by Stop) — confirm before relying on them. +// - `replicaCache` is `nil` by default (non-distributed and small +// deployments). It only becomes non-nil when StartReplicaCache is called. +// --------------------------------------------------------------------------- + +type replicaEntry struct { + model *Model + cachedAt time.Time +} + +type replicaCache struct { + mu sync.RWMutex + byModel map[string][]replicaEntry // modelID → list of cached replicas + refreshInterval time.Duration + stopped atomic.Bool +} + +func newReplicaCache(refreshInterval time.Duration) *replicaCache { + return &replicaCache{ + byModel: make(map[string][]replicaEntry), + refreshInterval: refreshInterval, + } +} + +// pick returns the first cached replica for modelID that is no older than +// refreshInterval, or nil when the cache is disabled (nil receiver), empty, +// or stale for this modelID. A nil result means the caller must fall back +// to the full router path. 
+func (rc *replicaCache) pick(modelID string) *Model { + if rc == nil { + return nil + } + rc.mu.RLock() + defer rc.mu.RUnlock() // held for the whole scan: put rewrites these entries in place + entries := rc.byModel[modelID] + if len(entries) == 0 { + return nil + } + // expiry check — any entry younger than refreshInterval is usable + now := time.Now() + var best *Model + for _, e := range entries { + if now.Sub(e.cachedAt) > rc.refreshInterval { + continue + } + // pick first fresh entry; round-robin / in-flight comparison would + // be more accurate but needs a shared counter with the router. This + // is sufficient to avoid the DB SELECT FOR UPDATE hot path. + best = e.model + break + } + return best +} + +// put caches a replica for modelID; it is a no-op on a nil cache. Once 8 +// entries are held for a model, the oldest one is dropped to make room. +func (rc *replicaCache) put(modelID string, model *Model) { + if rc == nil { + return + } + rc.mu.Lock() + defer rc.mu.Unlock() + + entries := rc.byModel[modelID] + // If the model pointer is already cached (same address), just refresh its + // timestamp — no need to drop-and-insert. + for i, e := range entries { + if e.model == model { + entries[i].cachedAt = time.Now() + rc.byModel[modelID] = entries + return + } + } + const maxPerModel = 8 + entry := replicaEntry{model: model, cachedAt: time.Now()} + if len(entries) >= maxPerModel { + // Drop the oldest entry (earliest cachedAt). + oldestIdx := 0 + for i, e := range entries { + if e.cachedAt.Before(entries[oldestIdx].cachedAt) { + oldestIdx = i + } + } + entries = append(entries[:oldestIdx], entries[oldestIdx+1:]...) + } + rc.byModel[modelID] = append(entries, entry) +} + +// Stop marks the cache as stopped so any background refresh goroutines can +// exit cleanly. Safe to call on nil — in which case it's a no-op. 
+func (rc *replicaCache) Stop() { + if rc != nil { + rc.stopped.Store(true) + } +} + +// --------------------------------------------------------------------------- +// End of replicaCache +// --------------------------------------------------------------------------- + // new idea: what if we declare a struct of these here, and use a loop to check? -// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we separate directories for .bin/.yaml and .tmpl +// Template file discovery has been split into core/templates.TemplateLoader, +// which owns scanning the same directory for .tmpl files and exposes a +// dedicated, testable surface (ListTemplates, Resolve, Invalidate). +// ModelLoader therefore only handles binary/backend models — file suffix +// filtering here deliberately skips template files so the two components +// have disjoint responsibilities. // ModelUnloadHook is called when a model is about to be unloaded. // The model name is passed as the argument. type ModelUnloadHook func(modelName string) @@ -60,6 +182,11 @@ type ModelLoader struct { // the exit code can't, since a child killed by our own SIGTERM/SIGKILL // reports -1, indistinguishable from a signal-induced crash. stoppingProcs sync.Map + // replicaCache is an advisory cache for the distributed-mode hot path. + // nil means "not enabled" — LoadModel always takes the full router path + // (SELECT FOR UPDATE). Once StartReplicaCache has allocated it, LoadModel + // populates it after each full router load. + replicaCache *replicaCache } // NewModelLoader creates a new ModelLoader instance. @@ -114,6 +241,21 @@ func (ml *ModelLoader) SetModelRouter(r ModelRouter) { ml.modelRouter = r } +// StartReplicaCache enables the rotating-replica cache for distributed mode. 
+// It allocates the cache with a 5s freshness window; LoadModel then consults +// it before the router and re-populates it after each full load — see the +// replicaCache comment block above for rationale. NOTE(review): nothing +// visible in this file disables the cache once it is set — confirm. Safe to +// call multiple times (subsequent calls are no-ops). +func (ml *ModelLoader) StartReplicaCache() { + ml.mu.Lock() + defer ml.mu.Unlock() + if ml.replicaCache != nil { + return + } + ml.replicaCache = newReplicaCache(5 * time.Second) +} + // SetModelStore replaces the default in-memory model store. // In distributed mode this is called with a DistributedModelStore. func (ml *ModelLoader) SetModelStore(s ModelStore) { @@ -268,26 +410,16 @@ func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, ml.mu.Unlock() if distributed { - // Distributed mode: SmartRouter must run per inference request so - // PickBestReplica (core/services/nodes/replicapicker.go) picks the - // least-loaded replica each time. The cached *Model returned from a - // previous call holds a client wrapper bound to one (nodeID, - // replicaIndex), so reusing it pins every subsequent request to the - // node that won the very first pick — defeating per-replica load - // balancing. Bypass the cache and the loading-coalesce map; the - // router does its own coalescing for first-time loads (advisory DB - // lock + singleflight on backend.install RPC), so concurrent first - // requests still produce a single worker-side install. - // - // TODO(distributed-cache): if profiling shows the per-request - // FindAndLockNodeWithModel SELECT FOR UPDATE becomes a hot path - // under burst load, replace this branch with a per-modelID cache - // that holds a *list* of replicas (refreshed every ~5s in - // background) and picks per call via PickBestReplica against - // locally-tracked in-flight counters. Same policy, no DB round-trip - // per inference. 
Trade-off: cross-frontend in-flight visibility - // becomes eventually consistent, acceptable for 1-3 frontend - // deployments. + // Hot path: try the per-modelID replica cache before going through + // PickBestReplica + FindAndLockNodeWithModel (SELECT FOR UPDATE). + // The cache is nil by default (opt-in via StartReplicaCache). + // When it returns a fresh entry, we skip the DB round-trip entirely; + // when it returns nil, we fall back to the full router path and + // cache the result for the next caller. + if cached := ml.replicaCache.pick(modelID); cached != nil { + return cached, nil + } + modelFile := filepath.Join(ml.ModelPath, modelName) model, err := loader(modelID, modelName, modelFile) if err != nil { @@ -296,10 +428,10 @@ func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, if model == nil { return nil, fmt.Errorf("loader didn't return a model") } - // Record the latest mapping so DistributedModelStore.Range, shutdown, - // and listing endpoints see a representative entry. The DB is the - // source of truth for cluster-wide state; the local store is just a - // stub for in-process callers. + // Populate the advisory cache so subsequent calls skip the DB + // hot path; the store record is still updated so shutdown and + // listing remain correct. + ml.replicaCache.put(modelID, model) ml.mu.Lock() ml.store.Set(modelID, model) ml.mu.Unlock()