Skip to content

Commit a8a0987

Browse files
authored
Merge pull request #260 from tcdent/validate-tasks-agents
Validate that YAML tasks & agents match methods in entrypoint
2 parents 1ef2fda + 0d62dd3 commit a8a0987

5 files changed

Lines changed: 111 additions & 9 deletions

File tree

agentstack/agents.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from ruamel.yaml import YAML, YAMLError
66
from ruamel.yaml.scalarstring import FoldedScalarString
77
from agentstack import conf, log
8-
from agentstack import frameworks
98
from agentstack.exceptions import ValidationError
109

1110

@@ -71,10 +70,12 @@ def __init__(self, name: str):
7170

7271
@property
7372
def provider(self) -> str:
73+
from agentstack import frameworks
7474
return frameworks.parse_llm(self.llm)[0]
7575

7676
@property
7777
def model(self) -> str:
78+
from agentstack import frameworks
7879
return frameworks.parse_llm(self.llm)[1]
7980

8081
@property

agentstack/frameworks/__init__.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from agentstack import conf
66
from agentstack.exceptions import ValidationError
77
from agentstack.utils import get_framework
8+
from agentstack.agents import AgentConfig, get_all_agent_names
9+
from agentstack.tasks import TaskConfig, get_all_task_names
810
from agentstack._tools import ToolConfig
911
from agentstack import graph
1012

1113
if TYPE_CHECKING:
1214
from agentstack.generation import InsertionPoint
13-
from agentstack.agents import AgentConfig
14-
from agentstack.tasks import TaskConfig
1515

1616

1717
CREWAI = 'crewai'
@@ -122,7 +122,28 @@ def validate_project():
122122
"""
123123
Validate that the user's project is ready to run.
124124
"""
125-
return get_framework_module(get_framework()).validate_project()
125+
framework = get_framework()
126+
entrypoint_path = get_entrypoint_path(framework)
127+
_module = get_framework_module(framework)
128+
129+
# Run framework-specific validation
130+
_module.validate_project()
131+
132+
# Verify that agents defined in agents.yaml are present in the codebase
133+
agent_method_names = _module.get_agent_method_names()
134+
for agent_name in get_all_agent_names():
135+
if agent_name not in agent_method_names:
136+
raise ValidationError(
137+
f"Agent `{agent_name}` is defined in agents.yaml but not in {entrypoint_path}"
138+
)
139+
140+
# Verify that tasks defined in tasks.yaml are present in the codebase
141+
task_method_names = _module.get_task_method_names()
142+
for task_name in get_all_task_names():
143+
if task_name not in task_method_names:
144+
raise ValidationError(
145+
f"Task `{task_name}` is defined in tasks.yaml but not in {entrypoint_path}"
146+
)
126147

127148

128149
def parse_llm(llm: str) -> tuple[str, str]:

tests/fixtures/frameworks/crewai/entrypoint_max.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@ class TestCrew:
99
def agent_name(self) -> Agent:
1010
return Agent(config=self.agents_config['agent_name'], tools=[], verbose=True)
1111

12+
@agent
13+
def second_agent_name(self) -> Agent:
14+
return Agent(config=self.agents_config['second_agent_name'], tools=[], verbose=True)
15+
1216
@task
1317
def task_name(self) -> Task:
14-
return Task(
15-
config=self.tasks_config['task_name'],
16-
)
18+
return Task(config=self.tasks_config['task_name'])
19+
20+
@task
21+
def task_name_two(self) -> Task:
22+
return Task(config=self.tasks_config['task_name_two'])
1723

1824
@crew
1925
def crew(self) -> Crew:

tests/fixtures/frameworks/langgraph/entrypoint_max.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ def agent_name(self, state: State):
3131
)
3232
return {'messages': [response, ]}
3333

34+
@agentstack.agent
35+
def second_agent_name(self, state: State):
36+
agent_config = agentstack.get_agent('second_agent_name')
37+
messages = ChatPromptTemplate.from_messages([
38+
("user", agent_config.prompt),
39+
])
40+
messages = messages.format_messages(**state['inputs'])
41+
agent = ChatOpenAI(model=agent_config.model)
42+
agent = agent.bind_tools([])
43+
response = agent.invoke(
44+
messages + state['messages'],
45+
)
46+
return {'messages': [response, ]}
47+
3448
@agentstack.task
3549
def task_name(self, state: State):
3650
task_config = agentstack.get_task('task_name')
@@ -40,6 +54,15 @@ def task_name(self, state: State):
4054
messages = messages.format_messages(**state['inputs'])
4155
return {'messages': messages + state['messages']}
4256

57+
@agentstack.task
58+
def task_name_two(self, state: State):
59+
task_config = agentstack.get_task('task_name_two')
60+
messages = ChatPromptTemplate.from_messages([
61+
("user", task_config.prompt),
62+
])
63+
messages = messages.format_messages(**state['inputs'])
64+
return {'messages': messages + state['messages']}
65+
4366
def run(self, inputs: list[str]):
4467
self.graph = StateGraph(State)
4568
tools = ToolNode([])
@@ -49,11 +72,19 @@ def run(self, inputs: list[str]):
4972
self.graph.add_edge("agent_name", "tools")
5073
self.graph.add_conditional_edges("agent_name", tools_condition)
5174

75+
self.graph.add_node("second_agent_name", self.agent_name)
76+
self.graph.add_edge("second_agent_name", "tools")
77+
self.graph.add_conditional_edges("second_agent_name", tools_condition)
78+
5279
self.graph.add_node("task_name", self.task_name)
80+
self.graph.add_node("task_name_two", self.task_name)
5381

5482
self.graph.add_edge(START, "task_name")
83+
self.graph.add_edge(START, "task_name_two")
5584
self.graph.add_edge("task_name", "agent_name")
85+
self.graph.add_edge("task_name_two", "second_agent_name")
5686
self.graph.add_edge("agent_name", END)
87+
self.graph.add_edge("second_agent_name", END)
5788

5889
app = self.graph.compile()
5990
result = app.invoke({

tests/test_frameworks.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,21 @@ def _populate_max_entrypoint(self):
4545
"""This entrypoint has tools and agents."""
4646
entrypoint_path = frameworks.get_entrypoint_path(self.framework)
4747
shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path)
48+
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
49+
shutil.copy(BASE_PATH / 'fixtures/tasks_max.yaml', self.project_dir / TASKS_FILENAME)
4850

4951
def _get_test_agent(self) -> AgentConfig:
50-
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
5152
return AgentConfig('agent_name')
5253

54+
def _get_test_agent_alternate(self) -> AgentConfig:
55+
return AgentConfig('second_agent_name')
56+
5357
def _get_test_task(self) -> TaskConfig:
54-
shutil.copy(BASE_PATH / 'fixtures/tasks_max.yaml', self.project_dir / TASKS_FILENAME)
5558
return TaskConfig('task_name')
5659

60+
def _get_test_task_alternate(self) -> TaskConfig:
61+
return TaskConfig('task_name_two')
62+
5763
def _get_test_tool(self) -> ToolConfig:
5864
return ToolConfig(name='test_tool', category='test', tools=['test_tool'])
5965

@@ -88,6 +94,8 @@ def test_validate_project_invalid(self):
8894

8995
def test_validate_project_has_agent_no_task_invalid(self):
9096
self._populate_min_entrypoint()
97+
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
98+
9199
frameworks.add_agent(self._get_test_agent())
92100
with self.assertRaises(ValidationError) as context:
93101
frameworks.validate_project()
@@ -98,6 +106,38 @@ def test_validate_project_has_task_no_agent_invalid(self):
98106
with self.assertRaises(ValidationError) as context:
99107
frameworks.validate_project()
100108

109+
def test_validate_project_missing_agent_method_invalid(self):
110+
"""Ensure that all agents have a method defined in the entrypoint."""
111+
self._populate_max_entrypoint()
112+
# add an extra entry to agents.yaml
113+
with open(self.project_dir / AGENTS_FILENAME, 'a') as f:
114+
f.write("""\nextra_agent:
115+
role: >-
116+
role
117+
goal: >-
118+
this is a goal
119+
backstory: >-
120+
this is a backstory
121+
llm: openai/gpt-4o""")
122+
with self.assertRaises(ValidationError) as context:
123+
frameworks.validate_project()
124+
125+
def test_validate_project_missing_task_method_invalid(self):
126+
"""Ensure that all tasks have a method defined in the entrypoint."""
127+
self._populate_max_entrypoint()
128+
# add an extra entry to tasks.yaml
129+
with open(self.project_dir / TASKS_FILENAME, 'a') as f:
130+
f.write("""\nextra_task:
131+
description: >-
132+
Add your description here
133+
expected_output: >-
134+
Add your expected output here
135+
agent: >-
136+
default_agent""")
137+
138+
with self.assertRaises(ValidationError) as context:
139+
frameworks.validate_project()
140+
101141
def test_get_agent_tool_names(self):
102142
self._populate_max_entrypoint()
103143
frameworks.add_tool(self._get_test_tool(), 'agent_name')
@@ -167,6 +207,9 @@ def test_get_tool_callables(self, tool_config):
167207

168208
def test_get_graph(self):
169209
self._populate_max_entrypoint()
210+
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
211+
shutil.copy(BASE_PATH / 'fixtures/tasks_max.yaml', self.project_dir / TASKS_FILENAME)
212+
170213
self._get_test_agent()
171214
self._get_test_task()
172215

0 commit comments

Comments
 (0)