-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathparallel_agent.py
More file actions
66 lines (49 loc) · 2.49 KB
/
parallel_agent.py
File metadata and controls
66 lines (49 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from google.adk.agents import ParallelAgent as GoogleADKParallelAgent
from google.adk.agents.base_agent import BaseAgent
from pydantic import ConfigDict, Field
from typing_extensions import Any
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.prompts.agent_default_prompt import DEFAULT_DESCRIPTION, DEFAULT_INSTRUCTION
from veadk.tracing.base_tracer import BaseTracer
from veadk.utils.logger import get_logger
from veadk.utils.patches import patch_asyncio
# Apply veadk's asyncio patch once, at import time, before any agent is built.
# NOTE(review): what exactly is patched lives in veadk.utils.patches and is not
# visible here — confirm there before relying on specific behavior.
patch_asyncio()
# Module-level logger, obtained via veadk's logging helper.
logger = get_logger(__name__)
class ParallelAgent(GoogleADKParallelAgent):
    """VeADK parallel agent built on Google ADK's ``ParallelAgent``.

    Adds VeADK defaults (``name``, ``description``, ``instruction``) and
    VeADK-specific attributes (``tracers``, ``short_term_memory``) on top of
    :class:`google.adk.agents.ParallelAgent`. No run/execution method is
    overridden here, so sub-agent execution is inherited from the base class.

    Note:
        When ``tracers`` is non-empty, a warning is logged at construction
        time because tracing inside ParallelAgent may cause OpenTelemetry
        context errors (https://github.com/google/adk-python/issues/1670).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
    """The model config"""

    name: str = "veParallelAgent"
    """The name of the agent."""

    description: str = DEFAULT_DESCRIPTION
    """The description of the agent. This will be helpful in A2A scenario."""

    instruction: str = DEFAULT_INSTRUCTION
    """The instruction for the agent, such as principles of function calling."""

    sub_agents: list[BaseAgent] = Field(default_factory=list, exclude=True)
    """The sub agents provided to agent."""

    # Use a factory (like ``sub_agents`` above) rather than a shared ``[]``
    # literal so every instance gets its own list.
    tracers: list[BaseTracer] = Field(default_factory=list)
    """The tracers provided to agent."""

    short_term_memory: Optional[ShortTermMemory] = None
    """The short term memory provided to agent. This attribute is not used in agent directly, as it will be passed to Runner in VeADK."""

    def model_post_init(self, __context: Any) -> None:
        """Finish initialization after pydantic has validated the fields.

        Delegates to the Google ADK base class first (which performs its own
        sub_agents initialization), then warns if tracers are configured and
        logs that initialization is complete.

        Args:
            __context: The pydantic post-init context. It is forwarded
                unchanged to the base class instead of being replaced by None.
        """
        # Forward the real context; base class handles sub_agents init.
        super().model_post_init(__context)

        if self.tracers:
            logger.warning(
                "Enable tracing in ParallelAgent may cause OpenTelemetry context error. Issue see https://github.com/google/adk-python/issues/1670"
            )

        logger.info(f"{self.__class__.__name__} `{self.name}` init done.")