From 969e7fdf73a139915d73e7e113a5d4e0c4eb431b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 12:42:37 -0800 Subject: [PATCH 1/5] added middleware implementations --- go/ai/document.go | 27 ++ go/ai/document_test.go | 106 ++++++ go/ai/generate.go | 103 +++--- go/ai/middleware.go | 4 +- go/ai/tools.go | 14 + go/genkit/genkit.go | 25 +- go/plugins/middleware/fallback.go | 148 +++++++++ go/plugins/middleware/fallback_test.go | 270 +++++++++++++++ go/plugins/middleware/plugin.go | 43 +++ go/plugins/middleware/retry.go | 171 ++++++++++ go/plugins/middleware/retry_test.go | 256 +++++++++++++++ go/plugins/middleware/subagent_design.md | 232 +++++++++++++ go/plugins/middleware/tool_approval.go | 161 +++++++++ go/plugins/middleware/tool_approval_test.go | 345 ++++++++++++++++++++ go/plugins/middleware/tool_error_handler.go | 62 ++++ 15 files changed, 1899 insertions(+), 68 deletions(-) create mode 100644 go/plugins/middleware/fallback.go create mode 100644 go/plugins/middleware/fallback_test.go create mode 100644 go/plugins/middleware/plugin.go create mode 100644 go/plugins/middleware/retry.go create mode 100644 go/plugins/middleware/retry_test.go create mode 100644 go/plugins/middleware/subagent_design.md create mode 100644 go/plugins/middleware/tool_approval.go create mode 100644 go/plugins/middleware/tool_approval_test.go create mode 100644 go/plugins/middleware/tool_error_handler.go diff --git a/go/ai/document.go b/go/ai/document.go index 1c6407ef17..3f113bdbb7 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -19,6 +19,8 @@ package ai import ( "encoding/json" "fmt" + "maps" + "slices" "strings" ) @@ -44,6 +46,31 @@ type Part struct { Metadata map[string]any `json:"metadata,omitempty"` // valid for all kinds } +// Clone returns a shallow copy of the Part with its own Metadata and Custom +// maps. Callers can add or remove map keys without mutating the original. 
+func (p *Part) Clone() *Part { + if p == nil { + return nil + } + cp := *p + cp.Custom = maps.Clone(p.Custom) + cp.Metadata = maps.Clone(p.Metadata) + return &cp +} + +// Clone returns a shallow copy of the Message with its own Content slice +// and Metadata map. Callers can replace parts or add metadata keys without +// mutating the original. +func (m *Message) Clone() *Message { + if m == nil { + return nil + } + cp := *m + cp.Content = slices.Clone(m.Content) + cp.Metadata = maps.Clone(m.Metadata) + return &cp +} + type PartKind int8 const ( diff --git a/go/ai/document_test.go b/go/ai/document_test.go index a5d4bc9bc0..3363e7eccd 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -17,6 +17,7 @@ package ai import ( + "bytes" "encoding/json" "reflect" "testing" @@ -411,3 +412,108 @@ func TestNewResponseForToolRequest(t *testing.T) { } }) } + +// TestPartClone verifies that Part.Clone produces an independent copy. +// Every Part field is populated so that adding a new field without updating +// this test (and Clone) causes a failure. +func TestPartClone(t *testing.T) { + orig := &Part{ + Kind: PartToolRequest, + ContentType: "application/json", + Text: "body", + ToolRequest: &ToolRequest{Name: "tool", Input: map[string]any{"a": 1}}, + // Normally a Part wouldn't have both ToolRequest and ToolResponse, + // but we populate everything to catch missing fields. + ToolResponse: &ToolResponse{Name: "tool", Output: "ok"}, + Resource: &ResourcePart{Uri: "res://x"}, + Custom: map[string]any{"ck": "cv"}, + Metadata: map[string]any{"sig": []byte{1, 2, 3}, "key": "val"}, + } + + // Guard: every field in the fixture must be non-zero. + // If someone adds a new field to Part this will fail, forcing them to + // add it here and verify Clone handles it. 
+ rv := reflect.ValueOf(orig).Elem() + for i := range rv.NumField() { + if rv.Field(i).IsZero() { + t.Fatalf("Part field %q is zero in test fixture — populate it and verify Clone handles it", rv.Type().Field(i).Name) + } + } + + cp := orig.Clone() + + // Values must match. + if !reflect.DeepEqual(orig, cp) { + t.Fatal("Clone() values differ from original") + } + + // Mutating clone's maps must not affect the original. + cp.Metadata["extra"] = true + if _, ok := orig.Metadata["extra"]; ok { + t.Error("mutating clone Metadata affected original") + } + + cp.Custom["extra"] = true + if _, ok := orig.Custom["extra"]; ok { + t.Error("mutating clone Custom affected original") + } + + // Go types in metadata (e.g. []byte) must be preserved, not string-ified. + sig, ok := cp.Metadata["sig"].([]byte) + if !ok { + t.Fatalf("Metadata[sig] type = %T, want []byte", cp.Metadata["sig"]) + } + if !bytes.Equal(sig, []byte{1, 2, 3}) { + t.Errorf("Metadata[sig] = %v, want [1 2 3]", sig) + } + + // nil Part.Clone() should return nil. + var nilPart *Part + if nilPart.Clone() != nil { + t.Error("nil Part.Clone() should return nil") + } +} + +// TestMessageClone verifies that Message.Clone produces an independent copy. +// Every Message field is populated so that adding a new field without updating +// this test (and Clone) causes a failure. +func TestMessageClone(t *testing.T) { + orig := &Message{ + Role: RoleModel, + Content: []*Part{NewTextPart("hello"), NewTextPart("world")}, + Metadata: map[string]any{"k": "v"}, + } + + // Guard: every field must be non-zero. + rv := reflect.ValueOf(orig).Elem() + for i := range rv.NumField() { + if rv.Field(i).IsZero() { + t.Fatalf("Message field %q is zero in test fixture — populate it and verify Clone handles it", rv.Type().Field(i).Name) + } + } + + cp := orig.Clone() + + // Values must match. 
+ if !reflect.DeepEqual(orig, cp) { + t.Fatal("Clone() values differ from original") + } + + // Mutating clone's Content slice must not affect the original. + cp.Content[0] = NewTextPart("replaced") + if orig.Content[0].Text != "hello" { + t.Error("mutating clone Content affected original") + } + + // Mutating clone's Metadata must not affect the original. + cp.Metadata["extra"] = true + if _, ok := orig.Metadata["extra"]; ok { + t.Error("mutating clone Metadata affected original") + } + + // nil Message.Clone() should return nil. + var nilMsg *Message + if nilMsg.Clone() != nil { + t.Error("nil Message.Clone() should return nil") + } +} diff --git a/go/ai/generate.go b/go/ai/generate.go index 7f8778664f..1b28f19c26 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -203,7 +203,7 @@ func LookupModel(r api.Registry, name string) Model { } // GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call. -func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActionOptions, mw []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { +func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActionOptions, mmws []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { if opts.Model == "" { if defaultModel, ok := r.LookupValue(api.DefaultModelKey).(string); ok && defaultModel != "" { opts.Model = defaultModel @@ -219,7 +219,27 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: model %q not found", opts.Model) } - resumeOutput, err := handleResumeOption(ctx, r, opts) + var mws []Middleware + if len(opts.Use) > 0 { + mws = make([]Middleware, 0, len(opts.Use)) + for _, ref := range opts.Use { + desc := LookupMiddleware(r, ref.Name) + if desc == nil { + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not 
found", ref.Name) + } + configJSON, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) + } + mw, err := desc.configFromJSON(configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) + } + mws = append(mws, mw) + } + } + + resumeOutput, err := handleResumeOption(ctx, r, opts, mws) if err != nil { return nil, err } @@ -315,26 +335,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - var middlewareHandlers []Middleware - if len(opts.Use) > 0 { - middlewareHandlers = make([]Middleware, 0, len(opts.Use)) - for _, ref := range opts.Use { - desc := LookupMiddleware(r, ref.Name) - if desc == nil { - return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) - } - configJSON, err := json.Marshal(ref.Config) - if err != nil { - return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) - } - handler, err := desc.configFromJSON(configJSON) - if err != nil { - return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) - } - middlewareHandlers = append(middlewareHandlers, handler) - } - } - fn := m.Generate if bm != nil { if cb != nil { @@ -343,11 +343,11 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi fn = backgroundModelToModelFn(bm.Start) } - if len(middlewareHandlers) > 0 { + if len(mws) > 0 { modelHook := func(next ModelFunc) ModelFunc { wrapped := next - for i := len(middlewareHandlers) - 1; i >= 0; i-- { - h := middlewareHandlers[i] + for i := len(mws) - 1; i >= 0; i-- { + h := mws[i] inner := wrapped wrapped = func(ctx context.Context, req *ModelRequest, cb 
ModelStreamCallback) (*ModelResponse, error) { return h.WrapModel(ctx, &ModelParams{Request: req, Callback: cb}, func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { @@ -357,9 +357,9 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } return wrapped } - mw = append([]ModelMiddleware{modelHook}, mw...) + mmws = append([]ModelMiddleware{modelHook}, mmws...) } - fn = core.ChainMiddleware(mw...)(fn) + fn = core.ChainMiddleware(mmws...)(fn) // Inline recursive helper function that captures variables from parent scope. var generate func(context.Context, *ModelRequest, int, int) (*ModelResponse, error) @@ -427,7 +427,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, mws) if err != nil { return nil, err } @@ -446,14 +446,14 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } // Wrap generate with the Generate hook chain from middleware. 
- if len(middlewareHandlers) > 0 { + if len(mws) > 0 { innerGenerate := generate generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { innerFn := func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { return innerGenerate(ctx, params.Request, currentTurn, messageIndex) } - for i := len(middlewareHandlers) - 1; i >= 0; i-- { - h := middlewareHandlers[i] + for i := len(mws) - 1; i >= 0; i-- { + h := mws[i] next := innerFn innerFn = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { return h.WrapGenerate(ctx, params, next) @@ -841,23 +841,6 @@ func ensureToolRequestRefs(msg *Message) { } // clone creates a deep copy of the provided object using JSON marshaling and unmarshaling. -func clone[T any](obj *T) *T { - if obj == nil { - return nil - } - - bytes, err := json.Marshal(obj) - if err != nil { - panic(fmt.Sprintf("clone: failed to marshal object: %v", err)) - } - - var newObj T - if err := json.Unmarshal(bytes, &newObj); err != nil { - panic(fmt.Sprintf("clone: failed to unmarshal object: %v", err)) - } - - return &newObj -} // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests @@ -870,7 +853,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resultChan := make(chan result[*MultipartToolResponse]) toolMsg := &Message{Role: RoleTool} - revisedMsg := clone(resp.Message) + revisedMsg := resp.Message.Clone() for i, part := range revisedMsg.Content { if !part.IsToolRequest() { @@ -891,7 +874,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, if errors.As(err, &tie) { logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata) - newPart := clone(p) + newPart := p.Clone() if newPart.Metadata == nil { newPart.Metadata = make(map[string]any) } @@ -911,7 +894,7 @@ 
func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - newPart := clone(p) + newPart := p.Clone() if newPart.Metadata == nil { newPart.Metadata = make(map[string]any) } @@ -1241,13 +1224,13 @@ func (m ModelRef) Config() any { // handleResumedToolRequest resolves a tool request from a previous, interrupted model turn, // when generation is being resumed. It determines the outcome of the tool request based on // pending output, or explicit 'respond' or 'restart' directives in the resume options. -func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part) (*resumedToolRequestOutput, error) { +func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, p *Part, middlewareHandlers []Middleware) (*resumedToolRequestOutput, error) { if p == nil || !p.IsToolRequest() { return nil, core.NewError(core.INVALID_ARGUMENT, "handleResumedToolRequest: part is not a tool request") } if pendingOutputVal, ok := p.Metadata["pendingOutput"]; ok { - newReqPart := clone(p) + newReqPart := p.Clone() delete(newReqPart.Metadata, "pendingOutput") newRespPart := NewResponseForToolRequest(p, pendingOutputVal) @@ -1266,7 +1249,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene if respondPart.ToolResponse != nil && respondPart.ToolResponse.Name == toolReq.Name && respondPart.ToolResponse.Ref == toolReq.Ref { - newToolReq := clone(p) + newToolReq := p.Clone() if interruptVal, ok := newToolReq.Metadata["interrupt"]; ok { delete(newToolReq.Metadata, "interrupt") newToolReq.Metadata["resolvedInterrupt"] = interruptVal @@ -1329,13 +1312,13 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene resumedCtx = origInputCtxKey.NewContext(resumedCtx, originalInputVal) } - output, err := tool.RunRaw(resumedCtx, restartPart.ToolRequest.Input) + multipartResp, err := runToolWithMiddleware(resumedCtx, tool, 
restartPart.ToolRequest, middlewareHandlers) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", restartPart.ToolRequest.Name, tie.Metadata) - interruptPart := clone(p) + interruptPart := p.Clone() if interruptPart.Metadata == nil { interruptPart.Metadata = make(map[string]any) } @@ -1349,7 +1332,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene return nil, core.NewError(core.INTERNAL, "tool %q failed: %v", restartPart.ToolRequest.Name, err) } - newToolReq := clone(p) + newToolReq := p.Clone() if interruptVal, ok := newToolReq.Metadata["interrupt"]; ok { delete(newToolReq.Metadata, "interrupt") newToolReq.Metadata["resolvedInterrupt"] = interruptVal @@ -1358,7 +1341,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene newToolResp := NewToolResponsePart(&ToolResponse{ Name: restartPart.ToolRequest.Name, Ref: restartPart.ToolRequest.Ref, - Output: output, + Output: multipartResp.Output, }) return &resumedToolRequestOutput{ @@ -1378,7 +1361,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene // handleResumeOption amends message history to handle `resume` arguments. // It returns the amended history. 
-func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions) (*resumeOptionOutput, error) { +func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateActionOptions, middlewareHandlers []Middleware) (*resumeOptionOutput, error) { if genOpts.Resume == nil || (len(genOpts.Resume.Respond) == 0 && len(genOpts.Resume.Restart) == 0) { return &resumeOptionOutput{revisedRequest: genOpts}, nil } @@ -1414,7 +1397,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc toolReqCount++ go func(idx int, p *Part) { - output, err := handleResumedToolRequest(ctx, r, genOpts, p) + output, err := handleResumedToolRequest(ctx, r, genOpts, p, middlewareHandlers) resultChan <- result[*resumedToolRequestOutput]{ index: idx, value: output, diff --git a/go/ai/middleware.go b/go/ai/middleware.go index aff3b063fc..8b6f7d4ee8 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -26,7 +26,7 @@ import ( ) // middlewareConfigFunc creates a Middleware instance from JSON config. -type middlewareConfigFunc = func([]byte) (Middleware, error) +type middlewareConfigFunc = func(configBytes []byte) (Middleware, error) // Middleware provides hooks for different stages of generation. type Middleware interface { @@ -110,7 +110,7 @@ func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDes return &MiddlewareDesc{ Name: prototype.Name(), Description: description, - ConfigSchema: core.InferSchemaMap(*new(T)), + ConfigSchema: core.InferSchemaMap(prototype), configFromJSON: func(configJSON []byte) (Middleware, error) { inst := prototype.New() if len(configJSON) > 0 { diff --git a/go/ai/tools.go b/go/ai/tools.go index 453ad779a3..b812c4b903 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -106,6 +106,13 @@ func IsToolInterruptError(err error) (bool, map[string]any) { return false, nil } +// NewToolInterruptError creates a tool interrupt error with the given metadata. 
+// This is intended for use in middleware that needs to interrupt tool execution +// without calling the tool itself. +func NewToolInterruptError(metadata map[string]any) error { + return &toolInterruptError{Metadata: metadata} +} + // InterruptOptions provides configuration for tool interruption. type InterruptOptions struct { Metadata map[string]any @@ -237,6 +244,13 @@ func (tc *ToolContext) IsResumed() bool { return tc.Resumed != nil } +// IsToolResumed reports whether the current context is a resumed tool execution. +// This is intended for use in middleware that needs to distinguish between +// first-time and restarted tool calls. +func IsToolResumed(ctx context.Context) bool { + return resumedCtxKey.FromContext(ctx) != nil +} + // ResumedValue retrieves a typed value from the Resumed metadata. // Returns the zero value and false if the key doesn't exist or the type doesn't match. func ResumedValue[T any](tc *ToolContext, key string) (T, bool) { diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 8fd32913c2..89b42de9fa 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -32,9 +32,21 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" ) +// genkitCtxKey is the context key for the Genkit instance. +var genkitCtxKey = base.NewContextKey[*Genkit]() + +// FromContext returns the [*Genkit] instance stored in the context. +// This is set automatically by [Generate] and related functions. +// Middleware implementations can use this to access the Genkit instance +// during generation. +func FromContext(ctx context.Context) *Genkit { + return genkitCtxKey.FromContext(ctx) +} + // Genkit encapsulates a Genkit instance, providing access to its registry, // configuration, and core functionalities. 
It serves as the central hub for // defining and managing Genkit resources like flows, models, tools, and prompts. @@ -962,7 +974,7 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate // // fmt.Println(resp.Text()) func Generate(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.ModelResponse, error) { - return ai.Generate(ctx, g.reg, opts...) + return ai.Generate(genkitCtxKey.NewContext(ctx, g), g.reg, opts...) } // GenerateStream generates a model response and streams the output. @@ -994,7 +1006,7 @@ func Generate(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.Mo // } // } func GenerateStream(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.ModelStreamValue, error] { - return ai.GenerateStream(ctx, g.reg, opts...) + return ai.GenerateStream(genkitCtxKey.NewContext(ctx, g), g.reg, opts...) } // GenerateOperation performs a model generation request using a flexible set of options @@ -1031,7 +1043,7 @@ func GenerateStream(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) i // // Get the result of the operation // fmt.Println(op.Output.Text()) func GenerateOperation(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.ModelOperation, error) { - return ai.GenerateOperation(ctx, g.reg, opts...) + return ai.GenerateOperation(genkitCtxKey.NewContext(ctx, g), g.reg, opts...) } // CheckModelOperation checks the status of a background model operation by looking up the model and calling its Check method. @@ -1056,7 +1068,7 @@ func CheckModelOperation(ctx context.Context, g *Genkit, op *ai.ModelOperation) // } // fmt.Println(joke) func GenerateText(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (string, error) { - return ai.GenerateText(ctx, g.reg, opts...) + return ai.GenerateText(genkitCtxKey.NewContext(ctx, g), g.reg, opts...) 
} // GenerateData performs a model generation request, expecting structured output @@ -1085,7 +1097,7 @@ func GenerateText(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (st // // log.Printf("Book: %+v\n", book) // Output: Book: {Title:The Hitchhiker's Guide to the Galaxy Author:Douglas Adams Year:1979} func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*Out, *ai.ModelResponse, error) { - return ai.GenerateData[Out](ctx, g.reg, opts...) + return ai.GenerateData[Out](genkitCtxKey.NewContext(ctx, g), g.reg, opts...) } // GenerateDataStream generates a model response with streaming and returns strongly-typed output. @@ -1124,7 +1136,7 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp // } // } func GenerateDataStream[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.StreamValue[Out, Out], error] { - return ai.GenerateDataStream[Out](ctx, g.reg, opts...) + return ai.GenerateDataStream[Out](genkitCtxKey.NewContext(ctx, g), g.reg, opts...) } // Retrieve performs a document retrieval request using a flexible set of options @@ -1585,3 +1597,4 @@ func ListResources(g *Genkit) []ai.Resource { } return resources } + diff --git a/go/plugins/middleware/fallback.go b/go/plugins/middleware/fallback.go new file mode 100644 index 0000000000..f5f15ad3dc --- /dev/null +++ b/go/plugins/middleware/fallback.go @@ -0,0 +1,148 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "encoding/json" + "errors" + "slices" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" +) + +// defaultFallbackStatuses are the status codes that trigger a fallback by default. +var defaultFallbackStatuses = []core.StatusName{ + core.UNAVAILABLE, + core.DEADLINE_EXCEEDED, + core.RESOURCE_EXHAUSTED, + core.ABORTED, + core.INTERNAL, + core.NOT_FOUND, + core.UNIMPLEMENTED, +} + +// Fallback is a middleware that tries alternative models when the primary model +// fails with a retryable error status. +// +// It only hooks the Model stage — when a model API call fails with a matching +// status, the request is forwarded to the next model in the list. +// +// Models are specified as [ai.ModelArg] values (model references, model instances, +// or strings via [ai.NewModelRef]) and resolved via the [genkit.Genkit] instance at call time. +// The Genkit instance is available via [genkit.FromContext] during generation. +// +// Usage: +// +// resp, err := genkit.Generate(ctx, g, +// ai.WithModel(primary), +// ai.WithPrompt("hello"), +// ai.WithUse(&middleware.Fallback{Models: []ai.ModelArg{backup1, backup2}}), +// ) +type Fallback struct { + ai.BaseMiddleware + // Models is the ordered list of fallback models to try. + // Each entry is an [ai.ModelArg] (e.g. an [ai.Model], [ai.ModelRef], etc.). + // These are tried in order after the primary model fails. + Models ModelList `json:"models,omitempty"` + // Statuses is the set of status codes that trigger a fallback. + // Only [core.GenkitError] errors with a matching status will trigger fallback; + // non-GenkitError errors propagate immediately. + // Defaults to [defaultFallbackStatuses]. 
+ Statuses []core.StatusName `json:"statuses,omitempty"` +} + +// ModelList is a list of [ai.ModelArg] values that marshals to/from JSON as +// a list of model name strings. +type ModelList []ai.ModelArg + +func (l ModelList) MarshalJSON() ([]byte, error) { + names := make([]string, len(l)) + for i, m := range l { + names[i] = m.Name() + } + return json.Marshal(names) +} + +func (l *ModelList) UnmarshalJSON(data []byte) error { + var names []string + if err := json.Unmarshal(data, &names); err != nil { + return err + } + *l = make(ModelList, len(names)) + for i, name := range names { + (*l)[i] = ai.NewModelRef(name, nil) + } + return nil +} + +func (f *Fallback) Name() string { return provider + "/fallback" } + +func (f *Fallback) New() ai.Middleware { + return &Fallback{ + Models: f.Models, + Statuses: f.Statuses, + } +} + +func (f *Fallback) statuses() []core.StatusName { + if len(f.Statuses) > 0 { + return f.Statuses + } + return defaultFallbackStatuses +} + +func (f *Fallback) WrapModel(ctx context.Context, params *ai.ModelParams, next ai.ModelNext) (*ai.ModelResponse, error) { + resp, err := next(ctx, params) + if err == nil { + return resp, nil + } + + if !isFallbackRetryable(err, f.statuses()) { + return nil, err + } + + lastErr := err + for _, ref := range f.Models { + name := ref.Name() + m := genkit.LookupModel(genkit.FromContext(ctx), name) + if m == nil { + return nil, core.NewError(core.NOT_FOUND, "fallback: model %q not found", name) + } + resp, err := m.Generate(ctx, params.Request, params.Callback) + if err == nil { + return resp, nil + } + lastErr = err + if !isFallbackRetryable(err, f.statuses()) { + return nil, err + } + } + return nil, lastErr +} + +// isFallbackRetryable reports whether err should trigger trying the next model. +// Only GenkitErrors with a matching status trigger fallback. 
+func isFallbackRetryable(err error, statuses []core.StatusName) bool { + var ge *core.GenkitError + if !errors.As(err, &ge) { + return false + } + return slices.Contains(statuses, ge.Status) +} diff --git a/go/plugins/middleware/fallback_test.go b/go/plugins/middleware/fallback_test.go new file mode 100644 index 0000000000..62c6e0f708 --- /dev/null +++ b/go/plugins/middleware/fallback_test.go @@ -0,0 +1,270 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" +) + +func newTestGenkit(t *testing.T) *genkit.Genkit { + t.Helper() + return genkit.Init(context.Background()) +} + +func defineTestModel(t *testing.T, g *genkit.Genkit, name string, fn ai.ModelFunc) ai.Model { + t.Helper() + return genkit.DefineModel(g, name, &ai.ModelOptions{ + Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}, + }, fn) +} + +func TestFallbackNotTriggeredOnSuccess(t *testing.T) { + g := newTestGenkit(t) + primaryCalls := 0 + secondaryCalls := 0 + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + primaryCalls++ + return &ai.ModelResponse{Message: ai.NewModelTextMessage("primary")}, nil + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + secondaryCalls++ + return &ai.ModelResponse{Message: ai.NewModelTextMessage("secondary")}, nil + }) + + fb := &Fallback{Models: ModelList{secondary}} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "primary" { + t.Errorf("got %q, want %q", resp.Text(), "primary") + } + if primaryCalls != 1 { + t.Errorf("primary called %d times, want 1", primaryCalls) + } + if secondaryCalls != 0 { + t.Errorf("secondary called %d times, want 0", secondaryCalls) + } +} + +func TestFallbackTriggeredOnRetryableError(t *testing.T) { + g := newTestGenkit(t) + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, 
"primary down") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{Message: ai.NewModelTextMessage("secondary ok")}, nil + }) + + fb := &Fallback{Models: ModelList{secondary}} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "secondary ok" { + t.Errorf("got %q, want %q", resp.Text(), "secondary ok") + } +} + +func TestFallbackTriesMultipleModels(t *testing.T) { + g := newTestGenkit(t) + var callOrder []string + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + callOrder = append(callOrder, "primary") + return nil, core.NewError(core.UNAVAILABLE, "primary down") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + callOrder = append(callOrder, "secondary") + return nil, core.NewError(core.RESOURCE_EXHAUSTED, "secondary exhausted") + }) + tertiary := defineTestModel(t, g, "test/tertiary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + callOrder = append(callOrder, "tertiary") + return &ai.ModelResponse{Message: ai.NewModelTextMessage("tertiary ok")}, nil + }) + + fb := &Fallback{Models: ModelList{secondary, tertiary}} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "tertiary ok" { + t.Errorf("got %q, want %q", resp.Text(), "tertiary ok") + } + want := []string{"primary", "secondary", "tertiary"} + if len(callOrder) != len(want) { + t.Fatalf("got call order %v, want %v", callOrder, want) + } + for i := range want { + if callOrder[i] != 
want[i] { + t.Errorf("callOrder[%d] = %q, want %q", i, callOrder[i], want[i]) + } + } +} + +func TestFallbackAllModelsFail(t *testing.T) { + g := newTestGenkit(t) + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, "primary down") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, "secondary down") + }) + + fb := &Fallback{Models: ModelList{secondary}} + + _, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "secondary down") { + t.Errorf("error %q does not contain %q", err.Error(), "secondary down") + } +} + +func TestFallbackDoesNotTriggerOnNonRetryableError(t *testing.T) { + g := newTestGenkit(t) + secondaryCalls := 0 + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.INVALID_ARGUMENT, "bad input") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + secondaryCalls++ + return &ai.ModelResponse{Message: ai.NewModelTextMessage("secondary")}, nil + }) + + fb := &Fallback{Models: ModelList{secondary}} + + _, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "bad input") { + t.Errorf("error %q does not contain %q", err.Error(), "bad input") + } + if secondaryCalls != 0 { + t.Errorf("secondary called %d times, want 0 (non-retryable error)", 
secondaryCalls) + } +} + +func TestFallbackDoesNotTriggerOnNonGenkitError(t *testing.T) { + g := newTestGenkit(t) + secondaryCalls := 0 + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, fmt.Errorf("plain error") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + secondaryCalls++ + return &ai.ModelResponse{Message: ai.NewModelTextMessage("secondary")}, nil + }) + + fb := &Fallback{Models: ModelList{secondary}} + + _, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err == nil { + t.Fatal("expected error, got nil") + } + if secondaryCalls != 0 { + t.Errorf("secondary called %d times, want 0 (non-GenkitError)", secondaryCalls) + } +} + +func TestFallbackStopsOnNonRetryableFallbackError(t *testing.T) { + g := newTestGenkit(t) + tertiaryCalls := 0 + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, "primary down") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.INVALID_ARGUMENT, "bad request from secondary") + }) + tertiary := defineTestModel(t, g, "test/tertiary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + tertiaryCalls++ + return &ai.ModelResponse{Message: ai.NewModelTextMessage("tertiary")}, nil + }) + + fb := &Fallback{Models: ModelList{secondary, tertiary}} + + _, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err == nil { + t.Fatal("expected error, got nil") + } + if 
!strings.Contains(err.Error(), "bad request from secondary") { + t.Errorf("error %q does not contain %q", err.Error(), "bad request from secondary") + } + if tertiaryCalls != 0 { + t.Errorf("tertiary called %d times, want 0", tertiaryCalls) + } +} + +func TestFallbackCustomStatuses(t *testing.T) { + g := newTestGenkit(t) + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.PERMISSION_DENIED, "forbidden") + }) + secondary := defineTestModel(t, g, "test/secondary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{Message: ai.NewModelTextMessage("secondary ok")}, nil + }) + + fb := &Fallback{ + Models: ModelList{secondary}, + Statuses: []core.StatusName{core.PERMISSION_DENIED}, + } + + resp, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "secondary ok" { + t.Errorf("got %q, want %q", resp.Text(), "secondary ok") + } +} + +func TestFallbackModelNotFound(t *testing.T) { + g := newTestGenkit(t) + + primary := defineTestModel(t, g, "test/primary", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, "primary down") + }) + + fb := &Fallback{Models: ModelList{ai.NewModelRef("test/nonexistent", nil)}} + + _, err := genkit.Generate(ctx, g, ai.WithModel(primary), ai.WithPrompt("hello"), ai.WithUse(fb)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error %q does not contain %q", err.Error(), "not found") + } +} diff --git a/go/plugins/middleware/plugin.go b/go/plugins/middleware/plugin.go new file mode 100644 index 0000000000..4ef05f44de --- /dev/null +++ b/go/plugins/middleware/plugin.go @@ 
-0,0 +1,43 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" +) + +const provider = "genkit-middleware" + +// Middleware provides the built-in middleware (Retry, Fallback, ToolApproval) +// as a Genkit plugin. Register it with [genkit.WithPlugins] during [genkit.Init]. +type Middleware struct{} + +func (p *Middleware) Name() string { return provider } + +func (p *Middleware) Init(ctx context.Context) []api.Action { return nil } + +func (p *Middleware) ListMiddleware(ctx context.Context) ([]*ai.MiddlewareDesc, error) { + return []*ai.MiddlewareDesc{ + ai.NewMiddleware("Retry failed model calls with exponential backoff", &Retry{}), + ai.NewMiddleware("Try alternative models when the primary model fails", &Fallback{}), + ai.NewMiddleware("Require explicit approval before executing tools", &ToolApproval{}), + ai.NewMiddleware("Return tool errors as text responses instead of failing", &ToolErrorHandler{}), + }, nil +} diff --git a/go/plugins/middleware/retry.go b/go/plugins/middleware/retry.go new file mode 100644 index 0000000000..570fab7ef3 --- /dev/null +++ b/go/plugins/middleware/retry.go @@ -0,0 +1,171 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance 
with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package middleware provides reusable middleware for Genkit model generation, +// including retry with exponential backoff and model fallback. +package middleware + +import ( + "context" + "errors" + "math" + "math/rand" + "slices" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" +) + +// defaultRetryStatuses are the status codes that trigger a retry by default. +var defaultRetryStatuses = []core.StatusName{ + core.UNAVAILABLE, + core.DEADLINE_EXCEEDED, + core.RESOURCE_EXHAUSTED, + core.ABORTED, + core.INTERNAL, +} + +// sleepFunc is the function used for delays. Overridable for testing. +var sleepFunc = time.Sleep + +// Retry is a middleware that retries failed model calls with exponential backoff. +// +// It only hooks the Model stage — individual model API calls are retried, +// not the entire generate loop. +// +// By default, retries occur for non-[core.GenkitError] errors (e.g. network failures) +// and for [core.GenkitError] errors whose status is one of UNAVAILABLE, DEADLINE_EXCEEDED, +// RESOURCE_EXHAUSTED, ABORTED, or INTERNAL_SERVER_ERROR. +// +// Usage: +// +// resp, err := ai.Generate(ctx, r, +// ai.WithModel(m), +// ai.WithPrompt("hello"), +// ai.WithUse(&middleware.Retry{MaxRetries: 3}), +// ) +type Retry struct { + ai.BaseMiddleware + // MaxRetries is the maximum number of retry attempts. Defaults to 3. 
+	MaxRetries int `json:"maxRetries,omitempty"`
+	// Statuses is the set of status codes that trigger a retry for [core.GenkitError] errors.
+	// Non-GenkitError errors are always retried regardless of this setting.
+	// Defaults to [defaultRetryStatuses].
+	Statuses []core.StatusName `json:"statuses,omitempty"`
+	// InitialDelayMs is the delay before the first retry, in milliseconds. Defaults to 1000.
+	InitialDelayMs int `json:"initialDelayMs,omitempty"`
+	// MaxDelayMs is the upper bound on retry delay, in milliseconds. Defaults to 60000.
+	MaxDelayMs int `json:"maxDelayMs,omitempty"`
+	// BackoffFactor is the multiplier applied to the delay after each retry. Defaults to 2.
+	BackoffFactor float64 `json:"backoffFactor,omitempty"`
+	// NoJitter disables random jitter on the delay. Jitter helps prevent
+	// thundering-herd problems when many clients retry simultaneously.
+	NoJitter bool `json:"noJitter,omitempty"`
+}
+
+// Name returns the registered name of this middleware.
+func (r *Retry) Name() string { return provider + "/retry" }
+
+// New returns a fresh Retry instance with the same configuration as r.
+func (r *Retry) New() ai.Middleware {
+	return &Retry{
+		MaxRetries:     r.MaxRetries,
+		Statuses:       r.Statuses,
+		InitialDelayMs: r.InitialDelayMs,
+		MaxDelayMs:     r.MaxDelayMs,
+		BackoffFactor:  r.BackoffFactor,
+		NoJitter:       r.NoJitter,
+	}
+}
+
+// maxRetries returns MaxRetries, defaulting to 3 when unset.
+func (r *Retry) maxRetries() int {
+	if r.MaxRetries > 0 {
+		return r.MaxRetries
+	}
+	return 3
+}
+
+// statuses returns Statuses, defaulting to defaultRetryStatuses when unset.
+func (r *Retry) statuses() []core.StatusName {
+	if len(r.Statuses) > 0 {
+		return r.Statuses
+	}
+	return defaultRetryStatuses
+}
+
+// initialDelay returns InitialDelayMs as a Duration, defaulting to 1s.
+func (r *Retry) initialDelay() time.Duration {
+	if r.InitialDelayMs > 0 {
+		return time.Duration(r.InitialDelayMs) * time.Millisecond
+	}
+	return time.Second
+}
+
+// maxDelay returns MaxDelayMs as a Duration, defaulting to 60s.
+func (r *Retry) maxDelay() time.Duration {
+	if r.MaxDelayMs > 0 {
+		return time.Duration(r.MaxDelayMs) * time.Millisecond
+	}
+	return 60 * time.Second
+}
+
+// backoffFactor returns BackoffFactor, defaulting to 2 when unset.
+func (r *Retry) backoffFactor() float64 {
+	if r.BackoffFactor > 0 {
+		return r.BackoffFactor
+	}
+	return 2
+}
+
+// WrapModel retries the wrapped model call with exponential backoff.
+//
+// It makes at most maxRetries()+1 attempts. Between attempts it sleeps for
+// the current backoff delay plus (unless NoJitter is set) up to 100%
+// proportional jitter; the jittered total is clamped to maxDelay() so the
+// configured ceiling is always respected. Non-retryable errors (per
+// isRetryable) abort immediately; the last error is returned once the
+// retry budget is exhausted.
+func (r *Retry) WrapModel(ctx context.Context, params *ai.ModelParams, next ai.ModelNext) (*ai.ModelResponse, error) {
+	maxRetries := r.maxRetries()
+	statuses := r.statuses()
+	currentDelay := r.initialDelay()
+
+	var lastErr error
+	for attempt := 0; attempt <= maxRetries; attempt++ {
+		resp, err := next(ctx, params)
+		if err == nil {
+			return resp, nil
+		}
+		lastErr = err
+
+		if attempt == maxRetries {
+			break
+		}
+
+		if !isRetryable(err, statuses) {
+			return nil, err
+		}
+
+		delay := currentDelay
+		if !r.NoJitter {
+			// Proportional jitter: add up to 100% of the current delay so
+			// concurrent clients don't retry in lockstep. Jitter scales with
+			// the configured delay rather than a fixed 2^attempt-seconds base
+			// (which ignored InitialDelayMs entirely), and the total is
+			// clamped so it never exceeds MaxDelayMs.
+			delay += time.Duration(rand.Float64() * float64(currentDelay))
+			delay = time.Duration(math.Min(float64(delay), float64(r.maxDelay())))
+		}
+
+		sleepFunc(delay)
+
+		currentDelay = min(time.Duration(float64(currentDelay)*r.backoffFactor()), r.maxDelay())
+	}
+	return nil, lastErr
+}
+
+// isRetryable reports whether err should trigger a retry.
+// Non-GenkitError errors are always retried. GenkitErrors are retried
+// only if their status is in the provided list.
+func isRetryable(err error, statuses []core.StatusName) bool {
+	var ge *core.GenkitError
+	if !errors.As(err, &ge) {
+		return true // unknown errors are retryable
+	}
+	return slices.Contains(statuses, ge.Status)
+}
diff --git a/go/plugins/middleware/retry_test.go b/go/plugins/middleware/retry_test.go
new file mode 100644
index 0000000000..d7925582c6
--- /dev/null
+++ b/go/plugins/middleware/retry_test.go
@@ -0,0 +1,256 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/registry" +) + +func newTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + r := registry.New() + ai.ConfigureFormats(r) + return r +} + +func defineModel(t *testing.T, r *registry.Registry, name string, fn ai.ModelFunc) ai.Model { + t.Helper() + return ai.DefineModel(r, name, &ai.ModelOptions{ + Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}, + }, fn) +} + +func init() { + // Disable real sleeping in tests. + sleepFunc = func(time.Duration) {} +} + +func TestRetrySucceedsOnFirstAttempt(t *testing.T) { + r := newTestRegistry(t) + calls := 0 + m := defineModel(t, r, "test/ok", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + calls++ + return &ai.ModelResponse{Message: ai.NewModelTextMessage("ok")}, nil + }) + + retry := &Retry{} + ai.DefineMiddleware(r, "retry", retry) + + resp, err := ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "ok" { + t.Errorf("got %q, want %q", resp.Text(), "ok") + } + if calls != 1 { + t.Errorf("got %d calls, want 1", calls) + } +} + +func TestRetryRecoversAfterTransientError(t *testing.T) { + r := newTestRegistry(t) + calls := 0 + m := defineModel(t, r, "test/transient", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + calls++ + if calls <= 2 { + return nil, core.NewError(core.UNAVAILABLE, "service down") + } + return &ai.ModelResponse{Message: ai.NewModelTextMessage("recovered")}, nil + }) + + retry := &Retry{} + ai.DefineMiddleware(r, "retry", retry) + + resp, err := ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + if err != nil { + t.Fatal(err) 
+ } + if resp.Text() != "recovered" { + t.Errorf("got %q, want %q", resp.Text(), "recovered") + } + if calls != 3 { + t.Errorf("got %d calls, want 3", calls) + } +} + +func TestRetryExhaustsMaxRetries(t *testing.T) { + r := newTestRegistry(t) + calls := 0 + m := defineModel(t, r, "test/alwaysfail", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + calls++ + return nil, core.NewError(core.UNAVAILABLE, "always failing") + }) + + retry := &Retry{MaxRetries: 2} + ai.DefineMiddleware(r, "retry", retry) + + _, err := ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + if err == nil { + t.Fatal("expected error, got nil") + } + // 1 initial + 2 retries = 3 calls + if calls != 3 { + t.Errorf("got %d calls, want 3", calls) + } +} + +func TestRetryDoesNotRetryNonMatchingGenkitError(t *testing.T) { + r := newTestRegistry(t) + calls := 0 + m := defineModel(t, r, "test/badarg", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + calls++ + return nil, core.NewError(core.INVALID_ARGUMENT, "bad input") + }) + + retry := &Retry{} + ai.DefineMiddleware(r, "retry", retry) + + _, err := ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + if err == nil { + t.Fatal("expected error, got nil") + } + if calls != 1 { + t.Errorf("got %d calls, want 1 (no retries for INVALID_ARGUMENT)", calls) + } +} + +func TestRetryRetriesNonGenkitErrors(t *testing.T) { + r := newTestRegistry(t) + calls := 0 + m := defineModel(t, r, "test/plainerr", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + calls++ + if calls == 1 { + return nil, fmt.Errorf("network timeout") + } + return &ai.ModelResponse{Message: ai.NewModelTextMessage("ok")}, nil + }) + + retry := &Retry{} + ai.DefineMiddleware(r, "retry", retry) + + resp, err := ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), 
ai.WithUse(retry)) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "ok" { + t.Errorf("got %q, want %q", resp.Text(), "ok") + } + if calls != 2 { + t.Errorf("got %d calls, want 2", calls) + } +} + +func TestRetryCustomStatuses(t *testing.T) { + r := newTestRegistry(t) + calls := 0 + m := defineModel(t, r, "test/forbidden", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + calls++ + return nil, core.NewError(core.PERMISSION_DENIED, "forbidden") + }) + + retry := &Retry{ + Statuses: []core.StatusName{core.PERMISSION_DENIED}, + MaxRetries: 1, + } + ai.DefineMiddleware(r, "retry", retry) + + _, err := ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + if err == nil { + t.Fatal("expected error, got nil") + } + // 1 initial + 1 retry = 2 calls + if calls != 2 { + t.Errorf("got %d calls, want 2", calls) + } +} + +func TestRetryBackoffDelays(t *testing.T) { + r := newTestRegistry(t) + m := defineModel(t, r, "test/delays", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, "down") + }) + + var delays []time.Duration + origSleep := sleepFunc + sleepFunc = func(d time.Duration) { delays = append(delays, d) } + defer func() { sleepFunc = origSleep }() + + retry := &Retry{ + MaxRetries: 3, + InitialDelayMs: 100, + BackoffFactor: 2, + NoJitter: true, + } + ai.DefineMiddleware(r, "retry", retry) + + _, _ = ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + + if len(delays) != 3 { + t.Fatalf("got %d delays, want 3", len(delays)) + } + // With no jitter and factor 2: 100ms, 200ms, 400ms + want := []time.Duration{100 * time.Millisecond, 200 * time.Millisecond, 400 * time.Millisecond} + for i, got := range delays { + if got != want[i] { + t.Errorf("delay[%d] = %v, want %v", i, got, want[i]) + } + } +} + +func TestRetryMaxDelayClamp(t *testing.T) { + r := 
newTestRegistry(t) + m := defineModel(t, r, "test/clamp", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return nil, core.NewError(core.UNAVAILABLE, "down") + }) + + var delays []time.Duration + origSleep := sleepFunc + sleepFunc = func(d time.Duration) { delays = append(delays, d) } + defer func() { sleepFunc = origSleep }() + + retry := &Retry{ + MaxRetries: 3, + InitialDelayMs: 500, + MaxDelayMs: 600, + BackoffFactor: 2, + NoJitter: true, + } + ai.DefineMiddleware(r, "retry", retry) + + _, _ = ai.Generate(ctx, r, ai.WithModel(m), ai.WithPrompt("hello"), ai.WithUse(retry)) + + if len(delays) != 3 { + t.Fatalf("got %d delays, want 3", len(delays)) + } + // 500ms, then 1000ms clamped to 600ms, then still 600ms (clamped again) + want := []time.Duration{500 * time.Millisecond, 600 * time.Millisecond, 600 * time.Millisecond} + for i, got := range delays { + if got != want[i] { + t.Errorf("delay[%d] = %v, want %v", i, got, want[i]) + } + } +} + +var ctx = context.Background() diff --git a/go/plugins/middleware/subagent_design.md b/go/plugins/middleware/subagent_design.md new file mode 100644 index 0000000000..63d724b944 --- /dev/null +++ b/go/plugins/middleware/subagent_design.md @@ -0,0 +1,232 @@ +# Subagent Middleware Design (Conceptual, Simplified) + +Status: proposal only. This document focuses on the critical path and intentionally drops non-essential complexity. + +## Why Not `WrapGenerate` for Dispatch + +Short answer: in current `ai.Generate` flow, `WrapGenerate` is the wrong interception point for model tool calls. + +In `go/ai/generate.go`, the model call and tool execution loop are inside the `generate(...)` function. `WrapGenerate` wraps each iteration of that function, but by the time `next(...)` returns to `WrapGenerate`, tool calls have already been handled (unless `ReturnToolRequests` is on, which changes global behavior). 
+ +So `WrapGenerate` can: + +- See/modify the request before model execution. +- See the final iteration response. + +But `WrapGenerate` cannot naturally own per-tool dispatch semantics without re-implementing the internal tool loop. + +The clean interception point for delegated subagent execution is still tool execution (`WrapTool`), even if we use a single synthetic tool like `call_subagent`. + +## Simplified Architecture + +Use one universal middleware-injected tool: + +- Tool name: `call_subagent` (configurable). +- Model asks for this tool with `{agent, messages}` payload. +- Middleware intercepts this tool in `WrapTool` and routes to the configured child agent. + +This keeps delegation explicit, simple, and model-visible. + +## Reference Session-Flow Contract + +Session flows in `/Users/alex/Developer/genkit-super/genkit-read-only/go/ai/x/session_flow.go` emit: + +- Input: `SessionFlowInput{Messages}` +- Stream chunk envelope: `{Chunk, Status, Artifact, SnapshotCreated, EndTurn}` +- Final output: `{State, SnapshotID}` + +For this design, we only fold: + +- `messages` +- `artifacts` + +No custom-state fold modes. + +## Proposed Middleware Surface + +```go +type Subagent struct { + ai.BaseMiddleware + + // Single universal delegation tool name. + ToolName string `json:"toolName,omitempty"` // default: "call_subagent" + + // Agent routing table keyed by logical agent id. + Agents map[string]SubagentTarget `json:"agents,omitempty"` + + // Forward child model chunks to parent stream. + StreamChildren bool `json:"streamChildren,omitempty"` + + // Recursion guard. 
+ MaxDepth int `json:"maxDepth,omitempty"` +} + +type SubagentTarget struct { + Kind string `json:"kind"` // "generate" | "prompt" | "sessionFlow" + Target string `json:"target"` // model/prompt/flow name +} + +type CallSubagentInput struct { + Agent string `json:"agent"` + Messages []*ai.Message `json:"messages,omitempty"` +} + +type SubagentFold struct { + Messages []*ai.Message `json:"messages,omitempty"` + Artifacts []*aix.SessionFlowArtifact `json:"artifacts,omitempty"` +} +``` + +## Critical Path + +1. Middleware injects one tool (`call_subagent`) via `Tools()`. +2. Parent model requests that tool with `{agent, messages}`. +3. `WrapTool` intercepts the tool call and executes selected child target. +4. Child emits stream chunks (optional passthrough to parent stream). +5. Middleware builds a fold object with only `messages` and `artifacts`. +6. Middleware returns `ToolResponse` where: + - `Output` contains child result + fold payload. + - `Content` contains compact natural-language summary for parent model. +7. Parent model continues normal tool loop with this tool response. 
+ +## Pseudocode (Dispatch) + +```go +func (m *Subagent) WrapTool(ctx context.Context, p *ai.ToolParams, next ai.ToolNext) (*ai.ToolResponse, error) { + toolName := m.ToolName + if toolName == "" { + toolName = "call_subagent" + } + if p.Request.Name != toolName { + return next(ctx, p) + } + + if overDepth(ctx, m.MaxDepth) { + return nil, core.NewError(core.ABORTED, "subagent: max depth exceeded") + } + + in := decodeCallSubagentInput(p.Request.Input) // {agent,messages} + target, ok := m.Agents[in.Agent] + if !ok { + return nil, core.NewError(core.NOT_FOUND, "subagent: unknown agent %q", in.Agent) + } + + result, fold, err := m.runChild(ctx, target, in, p) + if err != nil { + return nil, err + } + + return &ai.ToolResponse{ + Name: p.Request.Name, + Output: map[string]any{ + "agent": in.Agent, + "result": result, + "fold": fold, // messages + artifacts only + }, + Content: []*ai.Part{ + ai.NewTextPart(summaryForParent(result, fold)), + }, + }, nil +} +``` + +## Pseudocode (Session-Flow Child) + +```go +func (m *Subagent) runSessionFlowChild(...) 
(result any, fold *SubagentFold, err error) { + conn, err := lookupSessionFlow(target).StreamBidi(ctx) + if err != nil { + return nil, nil, err + } + defer conn.Close() + + _ = conn.Send(&aix.SessionFlowInput{Messages: in.Messages}) + + fold = &SubagentFold{} + for ch, err := range conn.Receive() { + if err != nil { + return nil, nil, err + } + if m.StreamChildren && ch.Chunk != nil { + _ = emitParentChunk(ctx, ch.Chunk) // requires runtime hook, see gaps + } + if ch.Artifact != nil { + fold.Artifacts = append(fold.Artifacts, ch.Artifact) + } + if ch.EndTurn { + break + } + } + + out, err := conn.Output() + if err != nil { + return nil, nil, err + } + fold.Messages = out.State.Messages + // artifacts in out.State.Artifacts can be used as canonical deduped final set + if len(out.State.Artifacts) > 0 { + fold.Artifacts = out.State.Artifacts + } + return out, fold, nil +} +``` + +## What Was Removed to Simplify + +Removed from prior design: + +- Per-delegate tool names. +- Fold mode matrix (`tool-output` vs `state-bridge`). +- Custom-state folding. +- Typed status/snapshot event streaming as first-class requirement. + +Kept: + +- One universal delegation tool. +- Route table (`agent` -> target). +- Optional child chunk passthrough. +- Canonical fold of messages + artifacts only. + +## Minimal Runtime Gaps to Close + +### 1) Stream passthrough from tool middleware + +Current `ToolParams` has no parent stream emitter. Add: + +```go +EmitChunk func(context.Context, *ai.ModelResponseChunk) error +``` + +Wire from `handleToolRequests` -> `runToolWithMiddleware`. + +Without this, child chunk passthrough cannot happen from `WrapTool`. + +### 2) Flow/session-flow lookup API for middleware + +Middleware currently can look up models via `genkit.FromContext(ctx)` + `LookupModel`, but no equivalent flow/session-flow lookup helper exists for middleware packages. + +Add a minimal public lookup path for flow/session-flow actions. 
+ +### 3) Bidi/runtime availability + +Current `genkit` core branch does not include the bidi/session-flow primitives present in `genkit-read-only`. Session-flow routing should be feature-gated until those primitives land. + +## Recommended Rollout + +### Phase 1 (works on current architecture) + +- Single `call_subagent` tool. +- Route only to `generate` and `prompt` child targets. +- Fold only messages/artifacts into tool output payload. +- No child stream passthrough yet. + +### Phase 2 + +- Enable session-flow routes once bidi primitives are in `genkit`. +- Add `EmitChunk` hook and stream child model chunks to parent stream. + +## Clarifying Questions + +1. Do you want `call_subagent` to accept raw `messages`, or a smaller `{goal, context}` payload that middleware converts to messages? +2. Should folded artifacts use final canonical state only (`out.State.Artifacts`) or include per-chunk artifact events as well? +3. For Phase 1, is no child stream passthrough acceptable until `EmitChunk` support is added? diff --git a/go/plugins/middleware/tool_approval.go b/go/plugins/middleware/tool_approval.go new file mode 100644 index 0000000000..4db79ceef3 --- /dev/null +++ b/go/plugins/middleware/tool_approval.go @@ -0,0 +1,161 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "fmt" + "maps" + "slices" + + "github.com/firebase/genkit/go/ai" +) + +const toolApprovalSource = "toolApproval" + +// ToolApprovalInterrupt is the typed interrupt metadata emitted by the +// [ToolApproval] middleware when a tool call is blocked for approval. +// +// Use [IsToolApprovalInterrupt] to check whether an interrupt came from this +// middleware and extract the metadata in one step. +type ToolApprovalInterrupt struct { + // Source identifies this interrupt as coming from the ToolApproval middleware. + Source string `json:"source"` + // Tool is the name of the tool that requires approval. + Tool string `json:"tool"` +} + +// IsToolApprovalInterrupt reports whether an interrupt [ai.Part] was emitted +// by the [ToolApproval] middleware. If so, it returns the typed metadata. +func IsToolApprovalInterrupt(p *ai.Part) (ToolApprovalInterrupt, bool) { + meta, ok := ai.InterruptAs[ToolApprovalInterrupt](p) + if !ok || meta.Source != toolApprovalSource { + return ToolApprovalInterrupt{}, false + } + return meta, true +} + +// ApproveInterrupt creates a restart [ai.Part] for a tool call that was +// blocked by [ToolApproval]. Pass the returned Part to [ai.WithToolRestarts] +// to re-execute the tool. +// +// Returns nil if p is not a tool request. +func ApproveInterrupt(p *ai.Part) *ai.Part { + if p == nil || !p.IsToolRequest() { + return nil + } + newMeta := maps.Clone(p.Metadata) + if newMeta == nil { + newMeta = make(map[string]any) + } + newMeta["resumed"] = true + delete(newMeta, "interrupt") + + part := ai.NewToolRequestPart(&ai.ToolRequest{ + Name: p.ToolRequest.Name, + Ref: p.ToolRequest.Ref, + Input: p.ToolRequest.Input, + }) + part.Metadata = newMeta + return part +} + +// DenyInterrupt creates a response [ai.Part] for a tool call that was blocked +// by [ToolApproval] and denied by the user. 
Pass the returned Part to +// [ai.WithToolResponses] to provide the denial message as the tool's output. +// +// Returns nil if p is not a tool request. +func DenyInterrupt(p *ai.Part, message string) *ai.Part { + if p == nil || !p.IsToolRequest() { + return nil + } + resp := ai.NewResponseForToolRequest(p, message) + resp.Metadata = map[string]any{ + "interruptResponse": true, + } + return resp +} + +// ToolApproval is a middleware that requires explicit approval for tool execution. +// +// AllowedTools and DeniedTools are mutually exclusive and control the default behavior: +// +// - AllowedTools: deny-by-default; only listed tools run, all others interrupt. +// - DeniedTools: allow-by-default; all tools run except listed ones, which interrupt. +// - Neither set: all tools interrupt (deny-all). +// +// Usage: +// +// resp, err := ai.Generate(ctx, r, +// ai.WithModel(m), +// ai.WithPrompt("do something"), +// ai.WithTools(toolA, toolB, toolC), +// ai.WithUse(&middleware.ToolApproval{AllowedTools: []string{"toolA"}}), +// ) +// // toolA runs automatically; toolB and toolC trigger interrupts. +// // Use resp.Interrupts() + WithToolRestarts() to approve and re-execute. +type ToolApproval struct { + ai.BaseMiddleware + // AllowedTools is the list of tool names that are pre-approved to run + // without interruption. Tools not in this list will trigger an interrupt. + // Mutually exclusive with DeniedTools. + AllowedTools []string `json:"allowedTools,omitempty"` + // DeniedTools is the list of tool names that will trigger an interrupt. + // Tools not in this list run immediately. + // Mutually exclusive with AllowedTools. 
+ DeniedTools []string `json:"deniedTools,omitempty"` +} + +func (t *ToolApproval) Name() string { return provider + "/toolApproval" } + +func (t *ToolApproval) New() ai.Middleware { + return &ToolApproval{ + AllowedTools: t.AllowedTools, + DeniedTools: t.DeniedTools, + } +} + +func (t *ToolApproval) WrapTool(ctx context.Context, params *ai.ToolParams, next ai.ToolNext) (*ai.ToolResponse, error) { + if len(t.AllowedTools) > 0 && len(t.DeniedTools) > 0 { + return nil, fmt.Errorf("toolApproval: AllowedTools and DeniedTools are mutually exclusive") + } + + // Resumed (restarted) tool calls have already been approved by the caller. + if ai.IsToolResumed(ctx) { + return next(ctx, params) + } + + name := params.Tool.Name() + + interrupt := map[string]any{ + "source": toolApprovalSource, + "tool": name, + } + + if len(t.DeniedTools) > 0 { + if slices.Contains(t.DeniedTools, name) { + return nil, ai.NewToolInterruptError(interrupt) + } + return next(ctx, params) + } + + // AllowedTools mode (or neither set — deny all). + if slices.Contains(t.AllowedTools, name) { + return next(ctx, params) + } + return nil, ai.NewToolInterruptError(interrupt) +} diff --git a/go/plugins/middleware/tool_approval_test.go b/go/plugins/middleware/tool_approval_test.go new file mode 100644 index 0000000000..9a14e052b7 --- /dev/null +++ b/go/plugins/middleware/tool_approval_test.go @@ -0,0 +1,345 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/registry" +) + +func defineToolModel(t *testing.T, r *registry.Registry, name string, fn ai.ModelFunc) ai.Model { + t.Helper() + return ai.DefineModel(r, name, &ai.ModelOptions{ + Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true, Tools: true}, + }, fn) +} + +func defineTool(t *testing.T, r api.Registry, name string) ai.Tool { + t.Helper() + return ai.DefineTool(r, name, "test tool", + func(ctx *ai.ToolContext, input struct { + V string `json:"v"` + }) (string, error) { + return "result:" + input.V, nil + }) +} + +// twoToolModelHandler returns a model handler that requests two tools on the first call, +// then returns a final text response when it sees tool responses. +func twoToolModelHandler(tool1, tool2 string) ai.ModelFunc { + return func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Check if we already have tool responses + for _, msg := range req.Messages { + for _, part := range msg.Content { + if part.IsToolResponse() { + return &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage("done"), + }, nil + } + } + } + // First call — request both tools + return &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ + ai.NewToolRequestPart(&ai.ToolRequest{Name: tool1, Input: map[string]any{"v": "1"}}), + ai.NewToolRequestPart(&ai.ToolRequest{Name: tool2, Input: map[string]any{"v": "2"}}), + }, + }, + }, nil + } +} + +func TestToolApprovalAllowsApprovedTools(t *testing.T) { + r := newTestRegistry(t) + + m := defineToolModel(t, r, "test/twotools", twoToolModelHandler("allowed", "alsoAllowed")) + allowed := defineTool(t, r, "allowed") + alsoAllowed := defineTool(t, r, "alsoAllowed") + + ta := 
&ToolApproval{AllowedTools: []string{"allowed", "alsoAllowed"}} + ai.DefineMiddleware(r, "toolApproval", ta) + + resp, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(allowed, alsoAllowed), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + // Both tools approved → no interrupt, model returns "done" + if resp.Text() != "done" { + t.Errorf("got %q, want %q", resp.Text(), "done") + } + if resp.FinishReason == "interrupted" { + t.Error("did not expect interrupted finish reason") + } +} + +func TestToolApprovalInterruptsUnapprovedTools(t *testing.T) { + r := newTestRegistry(t) + + m := defineToolModel(t, r, "test/twotools", twoToolModelHandler("safe", "dangerous")) + safe := defineTool(t, r, "safe") + dangerous := defineTool(t, r, "dangerous") + + ta := &ToolApproval{AllowedTools: []string{"safe"}} + ai.DefineMiddleware(r, "toolApproval", ta) + + resp, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(safe, dangerous), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + if resp.FinishReason != "interrupted" { + t.Errorf("got finish reason %q, want %q", resp.FinishReason, "interrupted") + } + + interrupts := resp.Interrupts() + if len(interrupts) == 0 { + t.Fatal("expected at least one interrupt") + } + + // Find the interrupt for "dangerous" using the typed helper. 
+ found := false + for _, p := range interrupts { + meta, ok := IsToolApprovalInterrupt(p) + if !ok { + continue + } + if meta.Tool == "dangerous" { + found = true + } + } + if !found { + t.Error("expected interrupt for 'dangerous' tool") + } +} + +func TestToolApprovalEmptyListInterruptsAll(t *testing.T) { + r := newTestRegistry(t) + + // Model requests a single tool + singleToolHandler := func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + for _, msg := range req.Messages { + for _, part := range msg.Content { + if part.IsToolResponse() { + return &ai.ModelResponse{Request: req, Message: ai.NewModelTextMessage("done")}, nil + } + } + } + return &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ + ai.NewToolRequestPart(&ai.ToolRequest{Name: "myTool", Input: map[string]any{"v": "1"}}), + }, + }, + }, nil + } + + m := defineToolModel(t, r, "test/singletool", singleToolHandler) + myTool := defineTool(t, r, "myTool") + + ta := &ToolApproval{} // empty allowed list + ai.DefineMiddleware(r, "toolApproval", ta) + + resp, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(myTool), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + if resp.FinishReason != "interrupted" { + t.Errorf("got finish reason %q, want %q", resp.FinishReason, "interrupted") + } +} + +func TestToolApprovalDeniedToolsInterruptsDenied(t *testing.T) { + r := newTestRegistry(t) + + m := defineToolModel(t, r, "test/twotools", twoToolModelHandler("safe", "dangerous")) + safe := defineTool(t, r, "safe") + dangerous := defineTool(t, r, "dangerous") + + ta := &ToolApproval{DeniedTools: []string{"dangerous"}} + ai.DefineMiddleware(r, "toolApproval", ta) + + resp, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(safe, dangerous), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + if resp.FinishReason != "interrupted" { 
+ t.Errorf("got finish reason %q, want %q", resp.FinishReason, "interrupted") + } + + interrupts := resp.Interrupts() + if len(interrupts) == 0 { + t.Fatal("expected at least one interrupt") + } + + found := false + for _, p := range interrupts { + meta, ok := IsToolApprovalInterrupt(p) + if !ok { + continue + } + if meta.Tool == "dangerous" { + found = true + } + if meta.Tool == "safe" { + t.Error("did not expect interrupt for 'safe' tool") + } + } + if !found { + t.Error("expected interrupt for 'dangerous' tool") + } +} + +func TestToolApprovalDeniedToolsAllowsOthers(t *testing.T) { + r := newTestRegistry(t) + + m := defineToolModel(t, r, "test/twotools", twoToolModelHandler("allowed1", "allowed2")) + allowed1 := defineTool(t, r, "allowed1") + allowed2 := defineTool(t, r, "allowed2") + + ta := &ToolApproval{DeniedTools: []string{"somethingElse"}} + ai.DefineMiddleware(r, "toolApproval", ta) + + resp, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(allowed1, allowed2), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "done" { + t.Errorf("got %q, want %q", resp.Text(), "done") + } + if resp.FinishReason == "interrupted" { + t.Error("did not expect interrupted finish reason") + } +} + +func TestToolApprovalMutualExclusionErrors(t *testing.T) { + r := newTestRegistry(t) + + m := defineToolModel(t, r, "test/twotools", twoToolModelHandler("a", "b")) + a := defineTool(t, r, "a") + b := defineTool(t, r, "b") + + ta := &ToolApproval{ + AllowedTools: []string{"a"}, + DeniedTools: []string{"b"}, + } + ai.DefineMiddleware(r, "toolApproval", ta) + + _, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(a, b), + ai.WithUse(ta), + ) + if err == nil { + t.Fatal("expected error when both AllowedTools and DeniedTools are set") + } +} + +func TestToolApprovalApproveAndDenyHelpers(t *testing.T) { + r := newTestRegistry(t) + + m := defineToolModel(t, r, "test/twotools", 
twoToolModelHandler("approvable", "deniable")) + approvable := defineTool(t, r, "approvable") + deniable := defineTool(t, r, "deniable") + + ta := &ToolApproval{} // deny all + ai.DefineMiddleware(r, "toolApproval", ta) + + resp, err := ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithPrompt("go"), + ai.WithTools(approvable, deniable), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + if resp.FinishReason != "interrupted" { + t.Fatalf("got finish reason %q, want %q", resp.FinishReason, "interrupted") + } + + var restarts, responses []*ai.Part + for _, interrupt := range resp.Interrupts() { + meta, ok := IsToolApprovalInterrupt(interrupt) + if !ok { + t.Fatal("expected tool approval interrupt") + } + switch meta.Tool { + case "approvable": + restarts = append(restarts, ApproveInterrupt(interrupt)) + case "deniable": + responses = append(responses, DenyInterrupt(interrupt, "denied by user")) + } + } + + if len(restarts) != 1 { + t.Fatalf("expected 1 restart, got %d", len(restarts)) + } + if len(responses) != 1 { + t.Fatalf("expected 1 response, got %d", len(responses)) + } + + // Resume with the approved and denied tools. + resp, err = ai.Generate(ctx, r, + ai.WithModel(m), + ai.WithMessages(resp.History()...), + ai.WithTools(approvable, deniable), + ai.WithToolRestarts(restarts...), + ai.WithToolResponses(responses...), + ai.WithUse(ta), + ) + if err != nil { + t.Fatal(err) + } + if resp.Text() != "done" { + t.Errorf("got %q, want %q", resp.Text(), "done") + } +} diff --git a/go/plugins/middleware/tool_error_handler.go b/go/plugins/middleware/tool_error_handler.go new file mode 100644 index 0000000000..99f0397a0e --- /dev/null +++ b/go/plugins/middleware/tool_error_handler.go @@ -0,0 +1,62 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/ai" +) + +// ToolErrorHandler catches non-interrupt tool execution errors and returns +// them as tool responses with a text part describing the failure. This allows +// the model to see and react to tool failures instead of aborting the entire +// generation loop. +// +// Usage: +// +// resp, err := ai.Generate(ctx, r, +// ai.WithModel(m), +// ai.WithPrompt("do something"), +// ai.WithTools(toolA, toolB), +// ai.WithUse(&middleware.ToolErrorHandler{}), +// ) +type ToolErrorHandler struct { + ai.BaseMiddleware +} + +func (t *ToolErrorHandler) Name() string { return provider + "/toolErrorHandler" } + +func (t *ToolErrorHandler) New() ai.Middleware { + return &ToolErrorHandler{} +} + +func (t *ToolErrorHandler) WrapTool(ctx context.Context, params *ai.ToolParams, next ai.ToolNext) (*ai.ToolResponse, error) { + resp, err := next(ctx, params) + if err != nil { + if isInterrupt, _ := ai.IsToolInterruptError(err); isInterrupt { + return nil, err + } + return &ai.ToolResponse{ + Content: []*ai.Part{ + ai.NewTextPart(fmt.Sprintf("tool %q failed: %v", params.Tool.Name(), err)), + }, + }, nil + } + return resp, nil +} From 95a803b30d4a7d3ba524dcab1c710249b8377f5b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 12:48:22 -0800 Subject: [PATCH 2/5] Update genkit.go --- go/genkit/genkit.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 89b42de9fa..98b33ae020 100644 
--- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -1597,4 +1597,3 @@ func ListResources(g *Genkit) []ai.Resource { } return resources } - From f614abdb02812d98253d9466f9f3fbdcd845b582 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 13:29:40 -0800 Subject: [PATCH 3/5] Delete subagent_design.md --- go/plugins/middleware/subagent_design.md | 232 ----------------------- 1 file changed, 232 deletions(-) delete mode 100644 go/plugins/middleware/subagent_design.md diff --git a/go/plugins/middleware/subagent_design.md b/go/plugins/middleware/subagent_design.md deleted file mode 100644 index 63d724b944..0000000000 --- a/go/plugins/middleware/subagent_design.md +++ /dev/null @@ -1,232 +0,0 @@ -# Subagent Middleware Design (Conceptual, Simplified) - -Status: proposal only. This document focuses on the critical path and intentionally drops non-essential complexity. - -## Why Not `WrapGenerate` for Dispatch - -Short answer: in current `ai.Generate` flow, `WrapGenerate` is the wrong interception point for model tool calls. - -In `go/ai/generate.go`, the model call and tool execution loop are inside the `generate(...)` function. `WrapGenerate` wraps each iteration of that function, but by the time `next(...)` returns to `WrapGenerate`, tool calls have already been handled (unless `ReturnToolRequests` is on, which changes global behavior). - -So `WrapGenerate` can: - -- See/modify the request before model execution. -- See the final iteration response. - -But `WrapGenerate` cannot naturally own per-tool dispatch semantics without re-implementing the internal tool loop. - -The clean interception point for delegated subagent execution is still tool execution (`WrapTool`), even if we use a single synthetic tool like `call_subagent`. - -## Simplified Architecture - -Use one universal middleware-injected tool: - -- Tool name: `call_subagent` (configurable). -- Model asks for this tool with `{agent, messages}` payload. 
-- Middleware intercepts this tool in `WrapTool` and routes to the configured child agent. - -This keeps delegation explicit, simple, and model-visible. - -## Reference Session-Flow Contract - -Session flows in `/Users/alex/Developer/genkit-super/genkit-read-only/go/ai/x/session_flow.go` emit: - -- Input: `SessionFlowInput{Messages}` -- Stream chunk envelope: `{Chunk, Status, Artifact, SnapshotCreated, EndTurn}` -- Final output: `{State, SnapshotID}` - -For this design, we only fold: - -- `messages` -- `artifacts` - -No custom-state fold modes. - -## Proposed Middleware Surface - -```go -type Subagent struct { - ai.BaseMiddleware - - // Single universal delegation tool name. - ToolName string `json:"toolName,omitempty"` // default: "call_subagent" - - // Agent routing table keyed by logical agent id. - Agents map[string]SubagentTarget `json:"agents,omitempty"` - - // Forward child model chunks to parent stream. - StreamChildren bool `json:"streamChildren,omitempty"` - - // Recursion guard. - MaxDepth int `json:"maxDepth,omitempty"` -} - -type SubagentTarget struct { - Kind string `json:"kind"` // "generate" | "prompt" | "sessionFlow" - Target string `json:"target"` // model/prompt/flow name -} - -type CallSubagentInput struct { - Agent string `json:"agent"` - Messages []*ai.Message `json:"messages,omitempty"` -} - -type SubagentFold struct { - Messages []*ai.Message `json:"messages,omitempty"` - Artifacts []*aix.SessionFlowArtifact `json:"artifacts,omitempty"` -} -``` - -## Critical Path - -1. Middleware injects one tool (`call_subagent`) via `Tools()`. -2. Parent model requests that tool with `{agent, messages}`. -3. `WrapTool` intercepts the tool call and executes selected child target. -4. Child emits stream chunks (optional passthrough to parent stream). -5. Middleware builds a fold object with only `messages` and `artifacts`. -6. Middleware returns `ToolResponse` where: - - `Output` contains child result + fold payload. 
- - `Content` contains compact natural-language summary for parent model. -7. Parent model continues normal tool loop with this tool response. - -## Pseudocode (Dispatch) - -```go -func (m *Subagent) WrapTool(ctx context.Context, p *ai.ToolParams, next ai.ToolNext) (*ai.ToolResponse, error) { - toolName := m.ToolName - if toolName == "" { - toolName = "call_subagent" - } - if p.Request.Name != toolName { - return next(ctx, p) - } - - if overDepth(ctx, m.MaxDepth) { - return nil, core.NewError(core.ABORTED, "subagent: max depth exceeded") - } - - in := decodeCallSubagentInput(p.Request.Input) // {agent,messages} - target, ok := m.Agents[in.Agent] - if !ok { - return nil, core.NewError(core.NOT_FOUND, "subagent: unknown agent %q", in.Agent) - } - - result, fold, err := m.runChild(ctx, target, in, p) - if err != nil { - return nil, err - } - - return &ai.ToolResponse{ - Name: p.Request.Name, - Output: map[string]any{ - "agent": in.Agent, - "result": result, - "fold": fold, // messages + artifacts only - }, - Content: []*ai.Part{ - ai.NewTextPart(summaryForParent(result, fold)), - }, - }, nil -} -``` - -## Pseudocode (Session-Flow Child) - -```go -func (m *Subagent) runSessionFlowChild(...) 
(result any, fold *SubagentFold, err error) { - conn, err := lookupSessionFlow(target).StreamBidi(ctx) - if err != nil { - return nil, nil, err - } - defer conn.Close() - - _ = conn.Send(&aix.SessionFlowInput{Messages: in.Messages}) - - fold = &SubagentFold{} - for ch, err := range conn.Receive() { - if err != nil { - return nil, nil, err - } - if m.StreamChildren && ch.Chunk != nil { - _ = emitParentChunk(ctx, ch.Chunk) // requires runtime hook, see gaps - } - if ch.Artifact != nil { - fold.Artifacts = append(fold.Artifacts, ch.Artifact) - } - if ch.EndTurn { - break - } - } - - out, err := conn.Output() - if err != nil { - return nil, nil, err - } - fold.Messages = out.State.Messages - // artifacts in out.State.Artifacts can be used as canonical deduped final set - if len(out.State.Artifacts) > 0 { - fold.Artifacts = out.State.Artifacts - } - return out, fold, nil -} -``` - -## What Was Removed to Simplify - -Removed from prior design: - -- Per-delegate tool names. -- Fold mode matrix (`tool-output` vs `state-bridge`). -- Custom-state folding. -- Typed status/snapshot event streaming as first-class requirement. - -Kept: - -- One universal delegation tool. -- Route table (`agent` -> target). -- Optional child chunk passthrough. -- Canonical fold of messages + artifacts only. - -## Minimal Runtime Gaps to Close - -### 1) Stream passthrough from tool middleware - -Current `ToolParams` has no parent stream emitter. Add: - -```go -EmitChunk func(context.Context, *ai.ModelResponseChunk) error -``` - -Wire from `handleToolRequests` -> `runToolWithMiddleware`. - -Without this, child chunk passthrough cannot happen from `WrapTool`. - -### 2) Flow/session-flow lookup API for middleware - -Middleware currently can look up models via `genkit.FromContext(ctx)` + `LookupModel`, but no equivalent flow/session-flow lookup helper exists for middleware packages. - -Add a minimal public lookup path for flow/session-flow actions. 
- -### 3) Bidi/runtime availability - -Current `genkit` core branch does not include the bidi/session-flow primitives present in `genkit-read-only`. Session-flow routing should be feature-gated until those primitives land. - -## Recommended Rollout - -### Phase 1 (works on current architecture) - -- Single `call_subagent` tool. -- Route only to `generate` and `prompt` child targets. -- Fold only messages/artifacts into tool output payload. -- No child stream passthrough yet. - -### Phase 2 - -- Enable session-flow routes once bidi primitives are in `genkit`. -- Add `EmitChunk` hook and stream child model chunks to parent stream. - -## Clarifying Questions - -1. Do you want `call_subagent` to accept raw `messages`, or a smaller `{goal, context}` payload that middleware converts to messages? -2. Should folded artifacts use final canonical state only (`out.State.Artifacts`) or include per-chunk artifact events as well? -3. For Phase 1, is no child stream passthrough acceptable until `EmitChunk` support is added? From fd0dec62767b20b5557d2d5c5a26a8433a80ab4d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 10:37:28 -0800 Subject: [PATCH 4/5] Apply suggestion from @apascal07 --- go/ai/generate.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 1b28f19c26..6c6921be39 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -840,7 +840,6 @@ func ensureToolRequestRefs(msg *Message) { } } -// clone creates a deep copy of the provided object using JSON marshaling and unmarshaling. 
// handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests From c391df8ba767cd64cc3f275e9f15683476e3cb06 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 10:40:38 -0800 Subject: [PATCH 5/5] Update generate.go --- go/ai/generate.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 6c6921be39..540ef56f5b 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -840,7 +840,6 @@ func ensureToolRequestRefs(msg *Message) { } } - // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling.