Skip to content

Commit b8ef782

Browse files
authored
Merge pull request #147 from tkrevh/bugfix-142
BUGFIX-142
2 parents 7568829 + 1b18fb9 commit b8ef782

8 files changed

Lines changed: 208 additions & 125 deletions

File tree

agentstack/cli/cli.py

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional
2-
import os, sys
2+
import os
3+
import sys
34
import time
45
from datetime import datetime
5-
from pathlib import Path
66

77
import json
88
import shutil
@@ -26,8 +26,9 @@
2626
from agentstack import inputs
2727
from agentstack.agents import get_all_agents
2828
from agentstack.tasks import get_all_tasks
29-
from agentstack.utils import open_json_file, term_color, is_snake_case, get_framework
29+
from agentstack.utils import open_json_file, term_color, is_snake_case, get_framework, validator_not_empty
3030
from agentstack.proj_templates import TemplateConfig
31+
from agentstack.exceptions import ValidationError
3132

3233

3334
PREFERRED_MODELS = [
@@ -184,6 +185,75 @@ def ask_framework() -> str:
184185
return framework
185186

186187

188+
def get_validated_input(
189+
message: str,
190+
validate_func=None,
191+
min_length: int = 0,
192+
snake_case: bool = False,
193+
) -> str:
194+
"""Helper function to get validated input from user.
195+
196+
Args:
197+
message: The prompt message to display
198+
validate_func: Optional custom validation function
199+
min_length: Minimum length requirement (0 for no requirement)
200+
snake_case: Whether to enforce snake_case naming
201+
"""
202+
while True:
203+
try:
204+
value = inquirer.text(
205+
message=message,
206+
validate=validate_func or validator_not_empty(min_length) if min_length else None,
207+
)
208+
if snake_case and not is_snake_case(value):
209+
raise ValidationError("Input must be in snake_case")
210+
return value
211+
except ValidationError as e:
212+
print(term_color(f"Error: {str(e)}", 'red'))
213+
214+
215+
def ask_agent_details():
216+
agent = {}
217+
218+
agent['name'] = get_validated_input(
219+
"What's the name of this agent? (snake_case)", min_length=3, snake_case=True
220+
)
221+
222+
agent['role'] = get_validated_input("What role does this agent have?", min_length=3)
223+
224+
agent['goal'] = get_validated_input("What is the goal of the agent?", min_length=10)
225+
226+
agent['backstory'] = get_validated_input("Give your agent a backstory", min_length=10)
227+
228+
agent['model'] = inquirer.list_input(
229+
message="What LLM should this agent use?", choices=PREFERRED_MODELS, default=PREFERRED_MODELS[0]
230+
)
231+
232+
return agent
233+
234+
235+
def ask_task_details(agents: list[dict]) -> dict:
236+
task = {}
237+
238+
task['name'] = get_validated_input(
239+
"What's the name of this task? (snake_case)", min_length=3, snake_case=True
240+
)
241+
242+
task['description'] = get_validated_input("Describe the task in more detail", min_length=10)
243+
244+
task['expected_output'] = get_validated_input(
245+
"What do you expect the result to look like? (ex: A 5 bullet point summary of the email)",
246+
min_length=10,
247+
)
248+
249+
task['agent'] = inquirer.list_input(
250+
message="Which agent should be assigned this task?",
251+
choices=[a['name'] for a in agents],
252+
)
253+
254+
return task
255+
256+
187257
def ask_design() -> dict:
188258
use_wizard = inquirer.confirm(
189259
message="Would you like to use the CLI wizard to set up agents and tasks?",
@@ -208,39 +278,10 @@ def ask_design() -> dict:
208278
while make_agent:
209279
print('---')
210280
print(f"Agent #{len(agents)+1}")
211-
212-
agent_incomplete = True
213281
agent = None
214-
while agent_incomplete:
215-
agent = inquirer.prompt(
216-
[
217-
inquirer.Text("name", message="What's the name of this agent? (snake_case)"),
218-
inquirer.Text("role", message="What role does this agent have?"),
219-
inquirer.Text("goal", message="What is the goal of the agent?"),
220-
inquirer.Text("backstory", message="Give your agent a backstory"),
221-
# TODO: make a list - #2
222-
inquirer.Text(
223-
'model',
224-
message="What LLM should this agent use? (any LiteLLM provider)",
225-
default="openai/gpt-4",
226-
),
227-
# inquirer.List("model", message="What LLM should this agent use? (any LiteLLM provider)", choices=[
228-
# 'mixtral_llm',
229-
# 'mixtral_llm',
230-
# ]),
231-
]
232-
)
233-
234-
if not agent['name'] or agent['name'] == '':
235-
print(term_color("Error: Agent name is required - Try again", 'red'))
236-
agent_incomplete = True
237-
elif not is_snake_case(agent['name']):
238-
print(term_color("Error: Agent name must be snake case - Try again", 'red'))
239-
else:
240-
agent_incomplete = False
241-
242-
make_agent = inquirer.confirm(message="Create another agent?")
282+
agent = ask_agent_details()
243283
agents.append(agent)
284+
make_agent = inquirer.confirm(message="Create another agent?")
244285

245286
print('')
246287
for x in range(3):
@@ -257,35 +298,9 @@ def ask_design() -> dict:
257298
while make_task:
258299
print('---')
259300
print(f"Task #{len(tasks) + 1}")
260-
261-
task_incomplete = True
262-
task = None
263-
while task_incomplete:
264-
task = inquirer.prompt(
265-
[
266-
inquirer.Text("name", message="What's the name of this task? (snake_case)"),
267-
inquirer.Text("description", message="Describe the task in more detail"),
268-
inquirer.Text(
269-
"expected_output",
270-
message="What do you expect the result to look like? (ex: A 5 bullet point summary of the email)",
271-
),
272-
inquirer.List(
273-
"agent",
274-
message="Which agent should be assigned this task?",
275-
choices=[a['name'] for a in agents], # type: ignore
276-
),
277-
]
278-
)
279-
280-
if not task['name'] or task['name'] == '':
281-
print(term_color("Error: Task name is required - Try again", 'red'))
282-
elif not is_snake_case(task['name']):
283-
print(term_color("Error: Task name must be snake case - Try again", 'red'))
284-
else:
285-
task_incomplete = False
286-
287-
make_task = inquirer.confirm(message="Create another task?")
301+
task = ask_task_details(agents)
288302
tasks.append(task)
303+
make_task = inquirer.confirm(message="Create another task?")
289304

290305
print('')
291306
for x in range(3):

agentstack/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from typing import Optional, Union
2-
import os, sys
1+
import os
2+
import sys
33
import json
44
from ruamel.yaml import YAML
55
import re
66
from importlib.metadata import version
77
from pathlib import Path
88
import importlib.resources
99
from agentstack import conf
10+
from inquirer import errors as inquirer_errors
1011

1112

1213
def get_version(package: str = 'agentstack'):
@@ -108,3 +109,14 @@ def term_color(text: str, color: str) -> str:
108109

109110
def is_snake_case(string: str):
110111
return bool(re.match('^[a-z0-9_]+$', string))
112+
113+
114+
def validator_not_empty(min_length=1):
115+
def validator(_, answer):
116+
if len(answer) < min_length:
117+
raise inquirer_errors.ValidationError(
118+
'', reason=f"This field must be at least {min_length} characters long."
119+
)
120+
return True
121+
122+
return validator

tests/cli_test_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import os, sys
2+
import subprocess
3+
4+
CLI_ENTRY = [
5+
sys.executable,
6+
"-m",
7+
"agentstack.main",
8+
]
9+
10+
def run_cli(*args):
11+
"""Helper method to run the CLI with arguments. Cross-platform."""
12+
try:
13+
# Use shell=True on Windows to handle path issues
14+
if sys.platform == 'win32':
15+
# Add PYTHONIOENCODING to the environment
16+
env = os.environ.copy()
17+
env['PYTHONIOENCODING'] = 'utf-8'
18+
result = subprocess.run(
19+
" ".join(str(arg) for arg in CLI_ENTRY + list(args)),
20+
capture_output=True,
21+
text=True,
22+
shell=True,
23+
env=env,
24+
encoding='utf-8'
25+
)
26+
else:
27+
result = subprocess.run(
28+
[*CLI_ENTRY, *args],
29+
capture_output=True,
30+
text=True,
31+
encoding='utf-8'
32+
)
33+
34+
if result.returncode != 0:
35+
print(f"Command failed with code {result.returncode}")
36+
print(f"STDOUT: {result.stdout}")
37+
print(f"STDERR: {result.stderr}")
38+
39+
return result
40+
except Exception as e:
41+
print(f"Exception running command: {e}")
42+
raise

tests/test_cli_init.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,26 @@
1-
import subprocess
21
import os, sys
32
import unittest
43
from parameterized import parameterized
54
from pathlib import Path
65
import shutil
6+
from cli_test_utils import run_cli
77

88
BASE_PATH = Path(__file__).parent
9-
CLI_ENTRY = [
10-
sys.executable,
11-
"-m",
12-
"agentstack.main",
13-
]
14-
159

1610
class CLIInitTest(unittest.TestCase):
1711
def setUp(self):
1812
self.project_dir = Path(BASE_PATH / 'tmp/cli_init')
19-
os.makedirs(self.project_dir)
13+
os.chdir(BASE_PATH) # Change to parent directory first
14+
os.makedirs(self.project_dir, exist_ok=True)
2015
os.chdir(self.project_dir)
16+
# Force UTF-8 encoding for the test environment
17+
os.environ['PYTHONIOENCODING'] = 'utf-8'
2118

2219
def tearDown(self):
23-
shutil.rmtree(self.project_dir)
24-
25-
def _run_cli(self, *args):
26-
"""Helper method to run the CLI with arguments."""
27-
return subprocess.run([*CLI_ENTRY, *args], capture_output=True, text=True)
20+
shutil.rmtree(self.project_dir, ignore_errors=True)
2821

2922
def test_init_command(self):
3023
"""Test the 'init' command to create a project directory."""
31-
result = self._run_cli('init', 'test_project')
24+
result = run_cli('init', 'test_project')
3225
self.assertEqual(result.returncode, 0)
3326
self.assertTrue((self.project_dir / 'test_project').exists())

tests/test_cli_loads.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,16 @@
33
import unittest
44
from pathlib import Path
55
import shutil
6+
from cli_test_utils import run_cli
67

78
BASE_PATH = Path(__file__).parent
89

910

1011
class TestAgentStackCLI(unittest.TestCase):
11-
CLI_ENTRY = [
12-
sys.executable,
13-
"-m",
14-
"agentstack.main",
15-
]
16-
17-
def run_cli(self, *args):
18-
"""Helper method to run the CLI with arguments."""
19-
result = subprocess.run([*self.CLI_ENTRY, *args], capture_output=True, text=True)
20-
return result
2112

2213
def test_version(self):
2314
"""Test the --version command."""
24-
result = self.run_cli("--version")
15+
result = run_cli("--version")
2516
print(result.stdout)
2617
print(result.stderr)
2718
print(result.returncode)
@@ -30,27 +21,27 @@ def test_version(self):
3021

3122
def test_invalid_command(self):
3223
"""Test an invalid command gracefully exits."""
33-
result = self.run_cli("invalid_command")
24+
result = run_cli("invalid_command")
3425
self.assertNotEqual(result.returncode, 0)
3526
self.assertIn("usage:", result.stderr)
3627

3728
def test_run_command_invalid_project(self):
3829
"""Test the 'run' command on an invalid project."""
3930
test_dir = Path(BASE_PATH / 'tmp/test_project')
4031
if test_dir.exists():
41-
shutil.rmtree(test_dir)
32+
shutil.rmtree(test_dir, ignore_errors=True)
4233
os.makedirs(test_dir)
4334

4435
# Write a basic agentstack.json file
4536
with (test_dir / 'agentstack.json').open('w') as f:
4637
f.write(open(BASE_PATH / 'fixtures/agentstack.json', 'r').read())
4738

4839
os.chdir(test_dir)
49-
result = self.run_cli('run')
40+
result = run_cli('run')
5041
self.assertNotEqual(result.returncode, 0)
5142
self.assertIn("Project validation failed", result.stdout)
5243

53-
shutil.rmtree(test_dir)
44+
shutil.rmtree(test_dir, ignore_errors=True)
5445

5546

5647
if __name__ == "__main__":

0 commit comments

Comments
 (0)