Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 83 additions & 40 deletions js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import {
stripUndefinedProps,
z,
type Action,
type ActionContext,
type ActionRunOptions,
type JSONSchema7,
} from '@genkit-ai/core';
Expand Down Expand Up @@ -280,56 +279,81 @@ export function toToolDefinition(
return out;
}

export interface ToolFnOptions extends ActionFnArg<never> {
/**
* Options passed to tool callbacks. Context is typed as C & ActionContext when using defineTool with a typed Genkit instance (genkit<AppContext>()).
*/
export interface ToolFnOptions<C extends object = object>
extends ActionFnArg<never, C> {
/**
* A function that can be called during tool execution that will result in the tool
* getting interrupted (immediately) and tool request returned to the upstream caller.
*/
interrupt: (metadata?: Record<string, any>) => never;

context: ActionContext;
}

export type ToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
export type ToolFn<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
> = (
input: z.infer<I>,
ctx: ToolFnOptions & ToolRunOptions
ctx: ToolFnOptions<C> & ToolRunOptions
) => Promise<z.infer<O>>;

export type MultipartToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
export type MultipartToolFn<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
> = (
input: z.infer<I>,
ctx: ToolFnOptions & ToolRunOptions
ctx: ToolFnOptions<C> & ToolRunOptions
) => Promise<{
output?: z.infer<O>;
content?: Part[];
}>;

export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export function defineTool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
registry: Registry,
config: { multipart: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O>
fn?: ToolFn<I, O, C>
): MultipartToolAction<I, O>;
export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export function defineTool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
registry: Registry,
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
fn?: ToolFn<I, O, C>
): ToolAction<I, O>;

/**
* Defines a tool.
*
* A tool is an action that can be passed to a model to be called automatically if it so chooses.
*/
export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export function defineTool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
registry: Registry,
config: { multipart?: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O> | MultipartToolFn<I, O>
fn?: ToolFn<I, O, C> | MultipartToolFn<I, O, C>
): ToolAction<I, O> | MultipartToolAction<I, O> {
const a = tool(config, fn);
delete a.__action.metadata.dynamic;
registry.registerAction(config.multipart ? 'tool.v2' : 'tool', a);
if (!config.multipart) {
// For non-multipart tools, we register a v2 tool action as well
registry.registerAction('tool.v2', basicToolV2(config, fn as ToolFn<I, O>));
registry.registerAction(
'tool.v2',
basicToolV2(config, fn as ToolFn<I, O, C>)
);
}
return a as ToolAction<I, O>;
}
Expand Down Expand Up @@ -469,30 +493,40 @@ function interruptTool(registry?: Registry) {
};
}

export function tool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export function tool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
config: { multipart: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O>
fn?: ToolFn<I, O, C>
): MultipartToolAction<I, O>;
export function tool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): ToolAction<I, O>;
export function tool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(config: ToolConfig<I, O>, fn?: ToolFn<I, O, C>): ToolAction<I, O>;

/**
* Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the
* Genkit registry and can be defined dynamically at runtime.
*/
export function tool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export function tool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
config: { multipart?: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O> | MultipartToolFn<I, O>
fn?: ToolFn<I, O, C> | MultipartToolFn<I, O, C>
): ToolAction<I, O> | MultipartToolAction<I, O> {
return config.multipart ? multipartTool(config, fn) : basicTool(config, fn);
}

function basicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): ToolAction<I, O> {
function basicTool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(config: ToolConfig<I, O>, fn?: ToolFn<I, O, C>): ToolAction<I, O> {
const a = action(
{
...config,
Expand All @@ -506,7 +540,7 @@ function basicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
...runOptions,
context: { ...runOptions.context },
interrupt,
});
} as unknown as ToolFnOptions<C> & ToolRunOptions);
}
return interrupt();
}
Expand All @@ -515,24 +549,32 @@ function basicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
return a;
}

function basicToolV2<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): MultipartToolAction<I, O> {
function basicToolV2<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(config: ToolConfig<I, O>, fn?: ToolFn<I, O, C>): MultipartToolAction<I, O> {
return multipartTool(config, async (input, ctx) => {
if (!fn) {
const interrupt = interruptTool(ctx.registry);
return interrupt();
}
return {
output: await fn(input, ctx),
output: await fn(
input,
ctx as unknown as ToolFnOptions<C> & ToolRunOptions
),
};
});
}

function multipartTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
function multipartTool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
config: ToolConfig<I, O>,
fn?: MultipartToolFn<I, O>
fn?: MultipartToolFn<I, O, C>
): MultipartToolAction<I, O> {
const a = action(
{
Expand All @@ -553,7 +595,7 @@ function multipartTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
...runOptions,
context: { ...runOptions.context },
interrupt,
});
} as unknown as ToolFnOptions<C> & ToolRunOptions);
}
return interrupt() as any; // we cast to any because `interrupt` throws.
}
Expand All @@ -568,10 +610,11 @@ function multipartTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
*
* @deprecated renamed to {@link tool}.
*/
export function dynamicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): DynamicToolAction<I, O> {
export function dynamicTool<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(config: ToolConfig<I, O>, fn?: ToolFn<I, O, C>): DynamicToolAction<I, O> {
const t = basicTool(config, fn) as DynamicToolAction<I, O>;
t.attach = (_: Registry) => t;
return t;
Expand Down
7 changes: 4 additions & 3 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ export interface ActionRunOptions<S> {

/**
* Options (side channel) data to pass to the model.
* @template C - Your app's context shape; ActionContext (auth, etc.) is merged in automatically, so you only declare your own fields.
*/
export interface ActionFnArg<S> {
export interface ActionFnArg<S, C extends object = object> {
/**
* Whether the caller of the action requested streaming.
*/
Expand All @@ -133,9 +134,9 @@ export interface ActionFnArg<S> {
sendChunk: StreamingCallback<S>;

/**
* Additional runtime context data (ex. auth context data).
* Runtime context (always set when the action runs). Typed as C & ActionContext so auth and other built-in context are always available.
*/
context?: ActionContext;
context: C & ActionContext;

/**
* Trace context containing trace and span IDs.
Expand Down
2 changes: 1 addition & 1 deletion js/core/src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { UserFacingError } from './error.js';
const contextAlsKey = 'core.auth.context';

/**
* Action side channel data, like auth and other invocation context infromation provided by the invoker.
* Action side channel data, like auth and other invocation context information provided by the invoker.
*/
export interface ActionContext {
/** Information about the currently authenticated user if provided. */
Expand Down
30 changes: 19 additions & 11 deletions js/core/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import type { z } from 'zod';
import { ActionFnArg, action, type Action } from './action.js';
import type { ActionContext } from './context.js';
import { Registry, type HasRegistry } from './registry.js';
import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js';

Expand Down Expand Up @@ -53,22 +54,25 @@ export interface FlowConfig<
* side-channel context data. The context itself is a function, a short-cut
* for streaming callback.
*/
export interface FlowSideChannel<S> extends ActionFnArg<S> {
export interface FlowSideChannel<S, C extends object = object>
extends ActionFnArg<S, C> {
(chunk: S): void;
}

/**
* Function to be executed in the flow.
* @template C - Your app's context shape; ActionContext is merged in automatically in the callback's context.
*/
export type FlowFn<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
> = (
/** Input to the flow. */
input: z.infer<I>,
/** Callback for streaming functions only. */
streamingCallback: FlowSideChannel<z.infer<S>>
streamingCallback: FlowSideChannel<z.infer<S>, C>
) => Promise<z.infer<O>> | z.infer<O>;

/**
Expand All @@ -78,7 +82,8 @@ export function flow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(config: FlowConfig<I, O, S> | string, fn: FlowFn<I, O, S>): Flow<I, O, S> {
C extends object = object,
>(config: FlowConfig<I, O, S> | string, fn: FlowFn<I, O, S, C>): Flow<I, O, S> {
const resolvedConfig: FlowConfig<I, O, S> =
typeof config === 'string' ? { name: config } : config;

Expand All @@ -92,10 +97,11 @@ export function defineFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
C extends object = object,
>(
registry: Registry,
config: FlowConfig<I, O, S> | string,
fn: FlowFn<I, O, S>
fn: FlowFn<I, O, S, C>
): Flow<I, O, S> {
const f = flow(config, fn);

Expand All @@ -111,7 +117,8 @@ function flowAction<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(config: FlowConfig<I, O, S>, fn: FlowFn<I, O, S>): Flow<I, O, S> {
C extends object = object,
>(config: FlowConfig<I, O, S>, fn: FlowFn<I, O, S, C>): Flow<I, O, S> {
return action(
{
actionType: 'flow',
Expand All @@ -126,13 +133,14 @@ function flowAction<
{ sendChunk, context, trace, abortSignal, streamingRequested }
) => {
const ctx = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).sendChunk = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).context = context;
(ctx as FlowSideChannel<z.infer<S>>).trace = trace;
(ctx as FlowSideChannel<z.infer<S>>).abortSignal = abortSignal;
(ctx as FlowSideChannel<z.infer<S>>).streamingRequested =
(ctx as FlowSideChannel<z.infer<S>, C>).sendChunk = sendChunk;
(ctx as FlowSideChannel<z.infer<S>, C>).context = context as C &
ActionContext;
(ctx as FlowSideChannel<z.infer<S>, C>).trace = trace;
(ctx as FlowSideChannel<z.infer<S>, C>).abortSignal = abortSignal;
(ctx as FlowSideChannel<z.infer<S>, C>).streamingRequested =
streamingRequested;
return fn(input, ctx as FlowSideChannel<z.infer<S>>);
return fn(input, ctx as FlowSideChannel<z.infer<S>, C>);
}
);
}
Expand Down
Loading