Skip to content

Commit 923e350

Browse files
authored
Add AgentTask feature (#1045)
1 parent 0df9dc4 commit 923e350

25 files changed

Lines changed: 1838 additions & 442 deletions

.changeset/curvy-coins-boil.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
"@livekit/agents": patch
3+
"@livekit/agents-plugin-google": patch
4+
"@livekit/agents-plugin-livekit": patch
5+
"@livekit/agents-plugin-openai": patch
6+
"livekit-agents-examples": patch
7+
---
8+
9+
Implement AgentTask feature

agents/src/cli.ts

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ const runServer = async (args: CliArgs) => {
7777
* ```
7878
*/
7979
export const runApp = (opts: ServerOptions) => {
80+
const logLevelOption = (defaultLevel: string) =>
81+
new Option('--log-level <level>', 'Set the logging level')
82+
.choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal'])
83+
.default(defaultLevel)
84+
.env('LOG_LEVEL');
85+
8086
const program = new Command()
8187
.name('agents')
8288
.description('LiveKit Agents CLI')
8389
.version(version)
84-
.addOption(
85-
new Option('--log-level <level>', 'Set the logging level')
86-
.choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal'])
87-
.default('info')
88-
.env('LOG_LEVEL'),
89-
)
9090
.addOption(
9191
new Option('--url <string>', 'LiveKit server or Cloud project websocket URL').env(
9292
'LIVEKIT_URL',
@@ -120,13 +120,15 @@ export const runApp = (opts: ServerOptions) => {
120120
program
121121
.command('start')
122122
.description('Start the worker in production mode')
123-
.action(() => {
124-
const options = program.optsWithGlobals();
125-
opts.wsURL = options.url || opts.wsURL;
126-
opts.apiKey = options.apiKey || opts.apiKey;
127-
opts.apiSecret = options.apiSecret || opts.apiSecret;
128-
opts.logLevel = options.logLevel || opts.logLevel;
129-
opts.workerToken = options.workerToken || opts.workerToken;
123+
.addOption(logLevelOption('info'))
124+
.action((...[, command]) => {
125+
const globalOptions = program.optsWithGlobals();
126+
const commandOptions = command.opts();
127+
opts.wsURL = globalOptions.url || opts.wsURL;
128+
opts.apiKey = globalOptions.apiKey || opts.apiKey;
129+
opts.apiSecret = globalOptions.apiSecret || opts.apiSecret;
130+
opts.logLevel = commandOptions.logLevel;
131+
opts.workerToken = globalOptions.workerToken || opts.workerToken;
130132
runServer({
131133
opts,
132134
production: true,
@@ -137,19 +139,14 @@ export const runApp = (opts: ServerOptions) => {
137139
program
138140
.command('dev')
139141
.description('Start the worker in development mode')
140-
.addOption(
141-
new Option('--log-level <level>', 'Set the logging level')
142-
.choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal'])
143-
.default('debug')
144-
.env('LOG_LEVEL'),
145-
)
142+
.addOption(logLevelOption('debug'))
146143
.action((...[, command]) => {
147144
const globalOptions = program.optsWithGlobals();
148145
const commandOptions = command.opts();
149146
opts.wsURL = globalOptions.url || opts.wsURL;
150147
opts.apiKey = globalOptions.apiKey || opts.apiKey;
151148
opts.apiSecret = globalOptions.apiSecret || opts.apiSecret;
152-
opts.logLevel = commandOptions.logLevel || globalOptions.logLevel || opts.logLevel;
149+
opts.logLevel = commandOptions.logLevel;
153150
opts.workerToken = globalOptions.workerToken || opts.workerToken;
154151
runServer({
155152
opts,
@@ -163,19 +160,14 @@ export const runApp = (opts: ServerOptions) => {
163160
.description('Connect to a specific room')
164161
.requiredOption('--room <string>', 'Room name to connect to')
165162
.option('--participant-identity <string>', 'Identity of user to listen to')
166-
.addOption(
167-
new Option('--log-level <level>', 'Set the logging level')
168-
.choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal'])
169-
.default('debug')
170-
.env('LOG_LEVEL'),
171-
)
163+
.addOption(logLevelOption('info'))
172164
.action((...[, command]) => {
173165
const globalOptions = program.optsWithGlobals();
174166
const commandOptions = command.opts();
175167
opts.wsURL = globalOptions.url || opts.wsURL;
176168
opts.apiKey = globalOptions.apiKey || opts.apiKey;
177169
opts.apiSecret = globalOptions.apiSecret || opts.apiSecret;
178-
opts.logLevel = commandOptions.logLevel || globalOptions.logLevel || opts.logLevel;
170+
opts.logLevel = commandOptions.logLevel;
179171
opts.workerToken = globalOptions.workerToken || opts.workerToken;
180172
runServer({
181173
opts,
@@ -189,12 +181,7 @@ export const runApp = (opts: ServerOptions) => {
189181
program
190182
.command('download-files')
191183
.description('Download plugin dependency files')
192-
.addOption(
193-
new Option('--log-level <level>', 'Set the logging level')
194-
.choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal'])
195-
.default('debug')
196-
.env('LOG_LEVEL'),
197-
)
184+
.addOption(logLevelOption('debug'))
198185
.action((...[, command]) => {
199186
const commandOptions = command.opts();
200187
initializeLogger({ pretty: true, level: commandOptions.logLevel });

agents/src/ipc/job_proc_lazy_main.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ import type { IPCMessage } from './message.js';
1515

1616
const ORPHANED_TIMEOUT = 15 * 1000;
1717

18+
const safeSend = (msg: IPCMessage): boolean => {
19+
if (process.connected && process.send) {
20+
process.send(msg);
21+
return true;
22+
}
23+
return false;
24+
};
25+
1826
type JobTask = {
1927
ctx: JobContext;
2028
task: Promise<void>;
@@ -50,7 +58,10 @@ class InfClient implements InferenceExecutor {
5058

5159
async doInference(method: string, data: unknown): Promise<unknown> {
5260
const requestId = shortuuid('inference_job_');
53-
process.send!({ case: 'inferenceRequest', value: { requestId, method, data } });
61+
if (!safeSend({ case: 'inferenceRequest', value: { requestId, method, data } })) {
62+
throw new Error('IPC channel closed');
63+
}
64+
5465
this.#requests[requestId] = new PendingInference();
5566
const resp = await this.#requests[requestId]!.promise;
5667
if (resp.error) {
@@ -117,7 +128,7 @@ const startJob = (
117128
await once(closeEvent, 'close').then((close) => {
118129
logger.debug('shutting down');
119130
shutdown = true;
120-
process.send!({ case: 'exiting', value: { reason: close[1] } });
131+
safeSend({ case: 'exiting', value: { reason: close[1] } });
121132
});
122133

123134
// Close the primary agent session if it exists
@@ -139,7 +150,7 @@ const startJob = (
139150
logger.error({ error }, 'error while shutting down the job'),
140151
);
141152

142-
process.send!({ case: 'done' });
153+
safeSend({ case: 'done', value: undefined });
143154
joinFuture.resolve();
144155
})();
145156

@@ -199,7 +210,7 @@ const startJob = (
199210
logger.debug('initializing job runner');
200211
await agent.prewarm(proc);
201212
logger.debug('job runner initialized');
202-
process.send({ case: 'initializeResponse' });
213+
safeSend({ case: 'initializeResponse', value: undefined });
203214

204215
let job: JobTask | undefined = undefined;
205216
const closeEvent = new EventEmitter();
@@ -213,7 +224,7 @@ const startJob = (
213224
switch (msg.case) {
214225
case 'pingRequest': {
215226
orphanedTimeout.refresh();
216-
process.send!({
227+
safeSend({
217228
case: 'pongResponse',
218229
value: { lastTimestamp: msg.value.timestamp, timestamp: Date.now() },
219230
});

agents/src/llm/chat_context.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,41 @@ export class ChatContext {
510510
return new ChatContext(items);
511511
}
512512

513+
merge(
514+
other: ChatContext,
515+
options: {
516+
excludeFunctionCall?: boolean;
517+
excludeInstructions?: boolean;
518+
} = {},
519+
): ChatContext {
520+
const { excludeFunctionCall = false, excludeInstructions = false } = options;
521+
const existingIds = new Set(this._items.map((item) => item.id));
522+
523+
for (const item of other.items) {
524+
if (excludeFunctionCall && ['function_call', 'function_call_output'].includes(item.type)) {
525+
continue;
526+
}
527+
528+
if (
529+
excludeInstructions &&
530+
item.type === 'message' &&
531+
(item.role === 'system' || item.role === 'developer')
532+
) {
533+
continue;
534+
}
535+
536+
if (existingIds.has(item.id)) {
537+
continue;
538+
}
539+
540+
const idx = this.findInsertionIndex(item.createdAt);
541+
this._items.splice(idx, 0, item);
542+
existingIds.add(item.id);
543+
}
544+
545+
return this;
546+
}
547+
513548
truncate(maxItems: number): ChatContext {
514549
if (maxItems <= 0) return this;
515550

agents/src/llm/provider_format/utils.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ class ChatItemGroup {
5656
}
5757

5858
removeInvalidToolCalls() {
59-
if (this.toolCalls.length === this.toolOutputs.length) {
60-
return;
61-
}
62-
6359
const toolCallIds = new Set(this.toolCalls.map((call) => call.callId));
6460
const toolOutputIds = new Set(this.toolOutputs.map((output) => output.callId));
61+
const sameIds =
62+
toolCallIds.size === toolOutputIds.size &&
63+
[...toolCallIds].every((id) => toolOutputIds.has(id));
64+
if (this.toolCalls.length === this.toolOutputs.length && sameIds) {
65+
return;
66+
}
6567

6668
// intersection of tool call ids and tool output ids
6769
const validCallIds = intersection(toolCallIds, toolOutputIds);

agents/src/llm/realtime.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export interface RealtimeCapabilities {
4848
userTranscription: boolean;
4949
autoToolReplyGeneration: boolean;
5050
audioOutput: boolean;
51+
manualFunctionCalls: boolean;
5152
}
5253

5354
export interface InputTranscriptionCompleted {

agents/src/stream/deferred_stream.ts

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,17 @@ export class DeferredReadableStream<T> {
5959
throw new Error('Stream source already set');
6060
}
6161

62-
this.sourceReader = source.getReader();
63-
this.pump();
62+
const sourceReader = source.getReader();
63+
this.sourceReader = sourceReader;
64+
void this.pump(sourceReader);
6465
}
6566

66-
private async pump() {
67+
private async pump(sourceReader: ReadableStreamDefaultReader<T>) {
6768
let sourceError: unknown;
6869

6970
try {
7071
while (true) {
71-
const { done, value } = await this.sourceReader!.read();
72+
const { done, value } = await sourceReader.read();
7273
if (done) break;
7374
await this.writer.write(value);
7475
}
@@ -81,7 +82,7 @@ export class DeferredReadableStream<T> {
8182
// any other error from source will be propagated to the consumer
8283
if (sourceError) {
8384
try {
84-
this.writer.abort(sourceError);
85+
await this.writer.abort(sourceError);
8586
} catch (e) {
8687
// ignore if writer is already closed
8788
}
@@ -118,10 +119,20 @@ export class DeferredReadableStream<T> {
118119
return;
119120
}
120121

122+
const sourceReader = this.sourceReader!;
123+
// Clear source first so future setSource() calls can reattach cleanly.
124+
this.sourceReader = undefined;
125+
121126
// release lock will make any pending read() throw TypeError
122127
// which is expected, and we intentionally catch those errors
123128
// using isStreamReaderReleaseError
124129
// this will unblock any pending read() inside the pump() read loop
125-
this.sourceReader!.releaseLock();
130+
try {
131+
sourceReader.releaseLock();
132+
} catch (e) {
133+
if (!isStreamReaderReleaseError(e)) {
134+
throw e;
135+
}
136+
}
126137
}
127138
}

agents/src/utils.test.ts

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,93 @@ describe('utils', () => {
469469
expect((error as Error).name).toBe('TypeError');
470470
}
471471
});
472+
473+
it('should return undefined for Task.current outside task context', () => {
474+
expect(Task.current()).toBeUndefined();
475+
});
476+
477+
it('should preserve Task.current inside a task across awaits', async () => {
478+
const task = Task.from(
479+
async () => {
480+
const currentAtStart = Task.current();
481+
await delay(5);
482+
const currentAfterAwait = Task.current();
483+
484+
expect(currentAtStart).toBeDefined();
485+
expect(currentAfterAwait).toBe(currentAtStart);
486+
487+
return currentAtStart;
488+
},
489+
undefined,
490+
'current-context-test',
491+
);
492+
493+
const currentFromResult = await task.result;
494+
expect(currentFromResult).toBe(task);
495+
});
496+
497+
it('should isolate nested Task.current context and restore parent context', async () => {
498+
const parentTask = Task.from(
499+
async (controller) => {
500+
const parentCurrent = Task.current();
501+
expect(parentCurrent).toBeDefined();
502+
503+
const childTask = Task.from(
504+
async () => {
505+
const childCurrentStart = Task.current();
506+
await delay(5);
507+
const childCurrentAfterAwait = Task.current();
508+
509+
expect(childCurrentStart).toBeDefined();
510+
expect(childCurrentAfterAwait).toBe(childCurrentStart);
511+
expect(childCurrentStart).not.toBe(parentCurrent);
512+
513+
return childCurrentStart;
514+
},
515+
controller,
516+
'child-current-context-test',
517+
);
518+
519+
const childCurrent = await childTask.result;
520+
const parentCurrentAfterChild = Task.current();
521+
522+
expect(parentCurrentAfterChild).toBe(parentCurrent);
523+
524+
return { parentCurrent, childCurrent };
525+
},
526+
undefined,
527+
'parent-current-context-test',
528+
);
529+
530+
const { parentCurrent, childCurrent } = await parentTask.result;
531+
expect(parentCurrent).toBe(parentTask);
532+
expect(childCurrent).not.toBe(parentCurrent);
533+
expect(Task.current()).toBeUndefined();
534+
});
535+
536+
it('should always expose Task.current for concurrent task callbacks', async () => {
537+
const tasks = Array.from({ length: 25 }, (_, idx) =>
538+
Task.from(
539+
async () => {
540+
const currentAtStart = Task.current();
541+
await delay(1);
542+
const currentAfterAwait = Task.current();
543+
544+
expect(currentAtStart).toBeDefined();
545+
expect(currentAfterAwait).toBe(currentAtStart);
546+
547+
return currentAtStart;
548+
},
549+
undefined,
550+
`current-context-stress-${idx}`,
551+
),
552+
);
553+
554+
const currentTasks = await Promise.all(tasks.map((task) => task.result));
555+
currentTasks.forEach((currentTask, idx) => {
556+
expect(currentTask).toBe(tasks[idx]);
557+
});
558+
});
472559
});
473560

474561
describe('Event', () => {

0 commit comments

Comments
 (0)