forked from ogx-ai/ogx-client-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
228 lines (205 loc) · 9.31 KB
/
agent.py
File metadata and controls
228 lines (205 loc) · 9.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Callable, List, Optional, Tuple, Union
from llama_stack_client import LlamaStackClient
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn_create_params import Toolgroup
from llama_stack_client.types.shared_params.agent_config import ToolConfig
from llama_stack_client.types.shared_params.response_format import ResponseFormat
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
from ...._types import Headers
from ..agent import Agent, AgentUtils
from ..client_tool import ClientTool
from ..tool_parser import ToolParser
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
from .tool_parser import ReActOutput, ReActToolParser
# Module-level logger; the ReAct agent uses it to emit deprecation and configuration warnings.
logger = logging.getLogger(__name__)
def get_tool_defs(
    client: LlamaStackClient, builtin_toolgroups: Tuple[Toolgroup] = (), client_tools: Tuple[ClientTool] = ()
):
    """Collect name/description/parameters definitions for every tool available to the agent.

    Server-side tools are fetched toolgroup by toolgroup via ``client.tools.list``. Client-side
    tools describe themselves through their ``get_*`` accessors. Server-side entries come first,
    then client tools, each in input order.

    :param client: Client used to list the tools registered under each toolgroup.
    :param builtin_toolgroups: Toolgroups, each given either as a toolgroup id string or as a dict
        whose ``"name"`` key holds the toolgroup id.
    :param client_tools: ClientTool instances to describe.
    :returns: A list of dicts with ``"name"``, ``"description"`` and ``"parameters"`` keys.
    """
    definitions = []

    for group in builtin_toolgroups:
        # A toolgroup may be referenced by its bare id, or by a dict that also carries args.
        group_id = group if isinstance(group, str) else group["name"]
        for server_tool in client.tools.list(toolgroup_id=group_id):
            definitions.append(
                {
                    "name": server_tool.identifier,
                    "description": server_tool.description,
                    "parameters": server_tool.parameters,
                }
            )

    for local_tool in client_tools:
        definitions.append(
            {
                "name": local_tool.get_name(),
                "description": local_tool.get_description(),
                "parameters": local_tool.get_params_definition(),
            }
        )

    return definitions
def get_default_react_instructions(
    client: LlamaStackClient, builtin_toolgroups: Tuple[str] = (), client_tools: Tuple[ClientTool] = ()
):
    """Render the default ReAct system prompt for the given tools.

    The ``<<tool_names>>`` placeholder in ``DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE`` is replaced
    with a comma-separated list of tool names. The ``<<tool_descriptions>>`` placeholder is replaced
    with one ``- name: {definition}`` bullet per tool.

    :param client: Client used to look up the tools in each builtin toolgroup.
    :param builtin_toolgroups: Toolgroup ids (or toolgroup dicts) whose tools should be described.
    :param client_tools: ClientTool instances to describe.
    :returns: The rendered instruction string.
    """
    tool_defs = get_tool_defs(client, builtin_toolgroups, client_tools)

    names_csv = ", ".join(tool_def["name"] for tool_def in tool_defs)
    # Each bullet shows the tool name followed by the str() of its whole definition dict.
    description_lines = "\n".join(f"- {tool_def['name']}: {tool_def}" for tool_def in tool_defs)

    prompt = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
    prompt = prompt.replace("<<tool_names>>", names_csv)
    prompt = prompt.replace("<<tool_descriptions>>", description_lines)
    return prompt
def get_agent_config_DEPRECATED(
    client: LlamaStackClient,
    model: str,
    builtin_toolgroups: Tuple[str] = (),
    client_tools: Tuple[ClientTool] = (),
    json_response_format: bool = False,
    custom_agent_config: Optional[AgentConfig] = None,
) -> AgentConfig:
    """Build the AgentConfig used by the deprecated ReActAgent constructor path.

    When ``custom_agent_config`` is None, a default config is built. It uses the default ReAct
    instructions rendered from the given tools, ``builtin_toolgroups`` as toolgroups, the client
    tools' definitions, ``tool_choice="auto"`` and ``system_message_behavior="replace"``, no
    shields, and session persistence disabled. Otherwise ``custom_agent_config`` is used as given.

    :param client: Client used to look up builtin toolgroup tools for the default instructions.
    :param model: Model id for the default config (unused when ``custom_agent_config`` is given).
    :param builtin_toolgroups: Toolgroups for the default config.
    :param client_tools: ClientTool instances for the default config.
    :param json_response_format: If true, set ``response_format`` to a JSON schema derived from
        ``ReActOutput``.
    :param custom_agent_config: Optional caller-supplied config. It is never mutated; a copy is
        returned when ``response_format`` has to be added.
    :returns: The agent config.
    """
    if custom_agent_config is None:
        instruction = get_default_react_instructions(client, builtin_toolgroups, client_tools)
        # user default toolgroups
        agent_config = AgentConfig(
            model=model,
            instructions=instruction,
            toolgroups=builtin_toolgroups,
            client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
            tool_config={
                "tool_choice": "auto",
                "system_message_behavior": "replace",
            },
            input_shields=[],
            output_shields=[],
            enable_session_persistence=False,
        )
    else:
        agent_config = custom_agent_config

    if json_response_format:
        # Build a new dict instead of assigning into agent_config, so a caller-supplied
        # custom_agent_config is not modified as a side effect.
        agent_config = {
            **agent_config,
            "response_format": {
                "type": "json_schema",
                "json_schema": ReActOutput.model_json_schema(),
            },
        }
    return agent_config
class ReActAgent(Agent):
    """ReAct agent.

    Simple wrapper around Agent that prepares the prompts for creating a ReAct agent from a list of tools.
    """

    def __init__(
        self,
        client: LlamaStackClient,
        model: str,
        tool_parser: ToolParser = ReActToolParser(),
        instructions: Optional[str] = None,
        tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
        tool_config: Optional[ToolConfig] = None,
        sampling_params: Optional[SamplingParams] = None,
        max_infer_iters: Optional[int] = None,
        input_shields: Optional[List[str]] = None,
        output_shields: Optional[List[str]] = None,
        response_format: Optional[ResponseFormat] = None,
        enable_session_persistence: Optional[bool] = None,
        json_response_format: bool = False,
        builtin_toolgroups: Tuple[str] = (),  # DEPRECATED
        client_tools: Tuple[ClientTool] = (),  # DEPRECATED
        custom_agent_config: Optional[AgentConfig] = None,  # DEPRECATED
        extra_headers: Headers | None = None,
    ):
        """Construct an Agent with the given parameters.

        :param client: The LlamaStackClient instance.
        :param custom_agent_config: The AgentConfig instance.
            ::deprecated: use other parameters instead
        :param client_tools: A tuple of ClientTool instances.
            ::deprecated: use tools instead
        :param builtin_toolgroups: A tuple of Toolgroup instances.
            ::deprecated: use tools instead
        :param tool_parser: Custom logic that parses tool calls from a message. The default is a
            single ReActToolParser instance, created when the class is defined and shared by every
            agent that does not pass its own.
        :param model: The model to use for the agent.
        :param instructions: The instructions for the agent. When empty, default ReAct instructions
            are generated from the tools.
        :param tools: A list of tools for the agent. Values can be one of the following:
            - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
            - a python function with a docstring. See @client_tool for more details.
            - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
            - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
            - an instance of ClientTool: A client tool object.
        :param tool_config: The tool configuration for the agent. Defaults to tool_choice="auto" and
            system_message_behavior="replace".
        :param sampling_params: The sampling parameters for the agent.
        :param max_infer_iters: The maximum number of inference iterations.
        :param input_shields: The input shields for the agent.
        :param output_shields: The output shields for the agent.
        :param response_format: The response format for the agent.
        :param enable_session_persistence: Whether to enable session persistence.
        :param json_response_format: Whether to use the json response format with default ReAct output schema.
            Overrides ``response_format`` when set.
            ::deprecated: use response_format instead
        :param extra_headers: Extra headers to add to all requests sent by the agent.
        """
        # Only non-empty deprecated arguments select the deprecated path. An explicitly passed
        # empty list/tuple no longer silently discards the new-style parameters.
        use_deprecated_params = False
        if custom_agent_config is not None:
            logger.warning("`custom_agent_config` is deprecated. Use inlined parameters instead.")
            use_deprecated_params = True
        if client_tools:
            logger.warning("`client_tools` is deprecated. Use `tools` instead.")
            use_deprecated_params = True
        if builtin_toolgroups:
            logger.warning("`builtin_toolgroups` is deprecated. Use `tools` instead.")
            use_deprecated_params = True

        if use_deprecated_params:
            # Deprecated path: the whole config comes from get_agent_config_DEPRECATED. The
            # new-style parameters (instructions, tools, tool_config, sampling_params, shields,
            # response_format, ...) are not used here.
            agent_config = get_agent_config_DEPRECATED(
                client=client,
                model=model,
                builtin_toolgroups=builtin_toolgroups,
                client_tools=client_tools,
                json_response_format=json_response_format,
                # Forward the caller's config; without this it was warned about but then ignored.
                custom_agent_config=custom_agent_config,
            )
            super().__init__(
                client=client,
                agent_config=agent_config,
                client_tools=client_tools,
                tool_parser=tool_parser,
                extra_headers=extra_headers,
            )
            return

        if not tool_config:
            tool_config = {
                "tool_choice": "auto",
                "system_message_behavior": "replace",
            }

        if json_response_format:
            if instructions is not None:
                logger.warning(
                    "Using a custom instructions, but json_response_format is set. Please make sure instructions are "
                    "compatible with the default ReAct output format."
                )
            response_format = {
                "type": "json_schema",
                "json_schema": ReActOutput.model_json_schema(),
            }

        # build REACT instructions
        # `tools` is Optional; treat None as "no tools" rather than failing on iteration.
        tool_list = tools or []
        react_client_tools = AgentUtils.get_client_tools(tool_list)
        react_toolgroups = [x for x in tool_list if isinstance(x, (str, dict))]
        if not instructions:
            instructions = get_default_react_instructions(client, react_toolgroups, react_client_tools)

        super().__init__(
            client=client,
            model=model,
            tool_parser=tool_parser,
            instructions=instructions,
            tools=tools,
            tool_config=tool_config,
            sampling_params=sampling_params,
            max_infer_iters=max_infer_iters,
            input_shields=input_shields,
            output_shields=output_shields,
            response_format=response_format,
            enable_session_persistence=enable_session_persistence,
            extra_headers=extra_headers,
        )