Skip to content

Commit f74c758

Browse files
committed
Add task updates to adk
1 parent 017c5bb commit f74c758

File tree

8 files changed

+1243
-0
lines changed

8 files changed

+1243
-0
lines changed

src/agentex/lib/adk/_modules/tasks.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from agentex.lib.core.temporal.activities.adk.tasks_activities import (
1313
DeleteTaskParams,
1414
GetTaskParams,
15+
QueryWorkflowParams,
1516
TasksActivityName,
17+
TaskStatusTransitionParams,
18+
UpdateTaskParams,
1619
)
1720
from agentex.lib.core.tracing.tracer import AsyncTracer
1821
from agentex.types.task import Task
@@ -128,3 +131,301 @@ async def delete(
128131
trace_id=trace_id,
129132
parent_span_id=parent_span_id,
130133
)
134+
135+
async def cancel(
136+
self,
137+
*,
138+
task_id: str,
139+
reason: str | None = None,
140+
trace_id: str | None = None,
141+
parent_span_id: str | None = None,
142+
start_to_close_timeout: timedelta = timedelta(seconds=5),
143+
heartbeat_timeout: timedelta = timedelta(seconds=5),
144+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
145+
) -> Task:
146+
"""
147+
Mark a running task as canceled.
148+
Args:
149+
task_id: The ID of the task to cancel.
150+
reason: Optional reason for cancellation.
151+
Returns:
152+
The updated task entry.
153+
"""
154+
params = TaskStatusTransitionParams(
155+
task_id=task_id,
156+
reason=reason,
157+
trace_id=trace_id,
158+
parent_span_id=parent_span_id,
159+
)
160+
if in_temporal_workflow():
161+
return await ActivityHelpers.execute_activity(
162+
activity_name=TasksActivityName.CANCEL_TASK,
163+
request=params,
164+
response_type=Task,
165+
start_to_close_timeout=start_to_close_timeout,
166+
retry_policy=retry_policy,
167+
heartbeat_timeout=heartbeat_timeout,
168+
)
169+
else:
170+
return await self._tasks_service.cancel_task(
171+
task_id=task_id,
172+
reason=reason,
173+
trace_id=trace_id,
174+
parent_span_id=parent_span_id,
175+
)
176+
177+
async def complete(
178+
self,
179+
*,
180+
task_id: str,
181+
reason: str | None = None,
182+
trace_id: str | None = None,
183+
parent_span_id: str | None = None,
184+
start_to_close_timeout: timedelta = timedelta(seconds=5),
185+
heartbeat_timeout: timedelta = timedelta(seconds=5),
186+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
187+
) -> Task:
188+
"""
189+
Mark a running task as completed.
190+
Args:
191+
task_id: The ID of the task to complete.
192+
reason: Optional reason for completion.
193+
Returns:
194+
The updated task entry.
195+
"""
196+
params = TaskStatusTransitionParams(
197+
task_id=task_id,
198+
reason=reason,
199+
trace_id=trace_id,
200+
parent_span_id=parent_span_id,
201+
)
202+
if in_temporal_workflow():
203+
return await ActivityHelpers.execute_activity(
204+
activity_name=TasksActivityName.COMPLETE_TASK,
205+
request=params,
206+
response_type=Task,
207+
start_to_close_timeout=start_to_close_timeout,
208+
retry_policy=retry_policy,
209+
heartbeat_timeout=heartbeat_timeout,
210+
)
211+
else:
212+
return await self._tasks_service.complete_task(
213+
task_id=task_id,
214+
reason=reason,
215+
trace_id=trace_id,
216+
parent_span_id=parent_span_id,
217+
)
218+
219+
async def fail(
220+
self,
221+
*,
222+
task_id: str,
223+
reason: str | None = None,
224+
trace_id: str | None = None,
225+
parent_span_id: str | None = None,
226+
start_to_close_timeout: timedelta = timedelta(seconds=5),
227+
heartbeat_timeout: timedelta = timedelta(seconds=5),
228+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
229+
) -> Task:
230+
"""
231+
Mark a running task as failed.
232+
Args:
233+
task_id: The ID of the task to fail.
234+
reason: Optional reason for failure.
235+
Returns:
236+
The updated task entry.
237+
"""
238+
params = TaskStatusTransitionParams(
239+
task_id=task_id,
240+
reason=reason,
241+
trace_id=trace_id,
242+
parent_span_id=parent_span_id,
243+
)
244+
if in_temporal_workflow():
245+
return await ActivityHelpers.execute_activity(
246+
activity_name=TasksActivityName.FAIL_TASK,
247+
request=params,
248+
response_type=Task,
249+
start_to_close_timeout=start_to_close_timeout,
250+
retry_policy=retry_policy,
251+
heartbeat_timeout=heartbeat_timeout,
252+
)
253+
else:
254+
return await self._tasks_service.fail_task(
255+
task_id=task_id,
256+
reason=reason,
257+
trace_id=trace_id,
258+
parent_span_id=parent_span_id,
259+
)
260+
261+
async def terminate(
262+
self,
263+
*,
264+
task_id: str,
265+
reason: str | None = None,
266+
trace_id: str | None = None,
267+
parent_span_id: str | None = None,
268+
start_to_close_timeout: timedelta = timedelta(seconds=5),
269+
heartbeat_timeout: timedelta = timedelta(seconds=5),
270+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
271+
) -> Task:
272+
"""
273+
Mark a running task as terminated.
274+
Args:
275+
task_id: The ID of the task to terminate.
276+
reason: Optional reason for termination.
277+
Returns:
278+
The updated task entry.
279+
"""
280+
params = TaskStatusTransitionParams(
281+
task_id=task_id,
282+
reason=reason,
283+
trace_id=trace_id,
284+
parent_span_id=parent_span_id,
285+
)
286+
if in_temporal_workflow():
287+
return await ActivityHelpers.execute_activity(
288+
activity_name=TasksActivityName.TERMINATE_TASK,
289+
request=params,
290+
response_type=Task,
291+
start_to_close_timeout=start_to_close_timeout,
292+
retry_policy=retry_policy,
293+
heartbeat_timeout=heartbeat_timeout,
294+
)
295+
else:
296+
return await self._tasks_service.terminate_task(
297+
task_id=task_id,
298+
reason=reason,
299+
trace_id=trace_id,
300+
parent_span_id=parent_span_id,
301+
)
302+
303+
async def timeout(
304+
self,
305+
*,
306+
task_id: str,
307+
reason: str | None = None,
308+
trace_id: str | None = None,
309+
parent_span_id: str | None = None,
310+
start_to_close_timeout: timedelta = timedelta(seconds=5),
311+
heartbeat_timeout: timedelta = timedelta(seconds=5),
312+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
313+
) -> Task:
314+
"""
315+
Mark a running task as timed out.
316+
Args:
317+
task_id: The ID of the task to time out.
318+
reason: Optional reason for timeout.
319+
Returns:
320+
The updated task entry.
321+
"""
322+
params = TaskStatusTransitionParams(
323+
task_id=task_id,
324+
reason=reason,
325+
trace_id=trace_id,
326+
parent_span_id=parent_span_id,
327+
)
328+
if in_temporal_workflow():
329+
return await ActivityHelpers.execute_activity(
330+
activity_name=TasksActivityName.TIMEOUT_TASK,
331+
request=params,
332+
response_type=Task,
333+
start_to_close_timeout=start_to_close_timeout,
334+
retry_policy=retry_policy,
335+
heartbeat_timeout=heartbeat_timeout,
336+
)
337+
else:
338+
return await self._tasks_service.timeout_task(
339+
task_id=task_id,
340+
reason=reason,
341+
trace_id=trace_id,
342+
parent_span_id=parent_span_id,
343+
)
344+
345+
async def update(
346+
self,
347+
*,
348+
task_id: str | None = None,
349+
task_name: str | None = None,
350+
task_metadata: dict[str, object] | None = None,
351+
trace_id: str | None = None,
352+
parent_span_id: str | None = None,
353+
start_to_close_timeout: timedelta = timedelta(seconds=5),
354+
heartbeat_timeout: timedelta = timedelta(seconds=5),
355+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
356+
) -> Task:
357+
"""
358+
Update mutable fields for a task by ID or name.
359+
Args:
360+
task_id: The ID of the task to update.
361+
task_name: The name of the task to update.
362+
task_metadata: Metadata to update on the task.
363+
Returns:
364+
The updated task entry.
365+
"""
366+
params = UpdateTaskParams(
367+
task_id=task_id,
368+
task_name=task_name,
369+
task_metadata=task_metadata,
370+
trace_id=trace_id,
371+
parent_span_id=parent_span_id,
372+
)
373+
if in_temporal_workflow():
374+
return await ActivityHelpers.execute_activity(
375+
activity_name=TasksActivityName.UPDATE_TASK,
376+
request=params,
377+
response_type=Task,
378+
start_to_close_timeout=start_to_close_timeout,
379+
retry_policy=retry_policy,
380+
heartbeat_timeout=heartbeat_timeout,
381+
)
382+
else:
383+
return await self._tasks_service.update_task(
384+
task_id=task_id,
385+
task_name=task_name,
386+
task_metadata=task_metadata,
387+
trace_id=trace_id,
388+
parent_span_id=parent_span_id,
389+
)
390+
391+
async def query_workflow(
392+
self,
393+
*,
394+
task_id: str,
395+
query_name: str,
396+
trace_id: str | None = None,
397+
parent_span_id: str | None = None,
398+
start_to_close_timeout: timedelta = timedelta(seconds=5),
399+
heartbeat_timeout: timedelta = timedelta(seconds=5),
400+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
401+
) -> dict[str, object]:
402+
"""
403+
Query a Temporal workflow associated with a task for its current state.
404+
Args:
405+
task_id: The ID of the task whose workflow to query.
406+
query_name: The name of the query to execute.
407+
Returns:
408+
The query result.
409+
"""
410+
params = QueryWorkflowParams(
411+
task_id=task_id,
412+
query_name=query_name,
413+
trace_id=trace_id,
414+
parent_span_id=parent_span_id,
415+
)
416+
if in_temporal_workflow():
417+
return await ActivityHelpers.execute_activity(
418+
activity_name=TasksActivityName.QUERY_WORKFLOW,
419+
request=params,
420+
response_type=dict,
421+
start_to_close_timeout=start_to_close_timeout,
422+
retry_policy=retry_policy,
423+
heartbeat_timeout=heartbeat_timeout,
424+
)
425+
else:
426+
return await self._tasks_service.query_workflow(
427+
task_id=task_id,
428+
query_name=query_name,
429+
trace_id=trace_id,
430+
parent_span_id=parent_span_id,
431+
)

0 commit comments

Comments
 (0)