-
Notifications
You must be signed in to change notification settings - Fork 215
Expand file tree
/
Copy pathproj_templates.py
More file actions
155 lines (132 loc) · 4.77 KB
/
proj_templates.py
File metadata and controls
155 lines (132 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from typing import Literal
import os
from pathlib import Path
import pydantic
import requests
import json
from agentstack.exceptions import ValidationError
from agentstack.utils import get_package_path
class TemplateConfig_v1(pydantic.BaseModel):
    """
    Legacy (version 1) template schema.

    Agents, tasks and tools are stored as plain dicts and ``inputs`` is a
    plain list of input names. Use ``to_v2`` to upgrade an instance to the
    current ``TemplateConfig`` schema.
    """

    name: str
    description: str
    template_version: Literal[1]
    framework: str
    method: str
    agents: list[dict]
    tasks: list[dict]
    tools: list[dict]
    inputs: list[str]

    def to_v2(self) -> 'TemplateConfig':
        """
        Upgrade this legacy template to a version-2 ``TemplateConfig``.

        Each agent/task/tool dict is passed as keyword arguments to the
        matching nested ``TemplateConfig`` model (so it is validated there),
        and every input name is mapped to an empty string.
        """
        upgraded_agents = [TemplateConfig.Agent(**raw) for raw in self.agents]
        upgraded_tasks = [TemplateConfig.Task(**raw) for raw in self.tasks]
        upgraded_tools = [TemplateConfig.Tool(**raw) for raw in self.tools]
        blank_inputs = dict.fromkeys(self.inputs, "")
        return TemplateConfig(
            name=self.name,
            description=self.description,
            template_version=2,
            framework=self.framework,
            method=self.method,
            agents=upgraded_agents,
            tasks=upgraded_tasks,
            tools=upgraded_tools,
            inputs=blank_inputs,
        )
class TemplateConfig(pydantic.BaseModel):
"""
Interface for interacting with template configuration files.
Templates are read-only.
Template Schema
-------------
name: str
The name of the project.
description: str
A description of the template.
template_version: int
The version of the template.
framework: str
The framework the template is for.
method: str
The method used by the project. ie. "sequential"
agents: list[TemplateConfig.Agent]
A list of agents used by the project.
tasks: list[TemplateConfig.Task]
A list of tasks used by the project.
tools: list[TemplateConfig.Tool]
A list of tools used by the project.
inputs: list[str]
A list of inputs used by the project.
"""
class Agent(pydantic.BaseModel):
name: str
role: str
goal: str
backstory: str
model: str
class Task(pydantic.BaseModel):
name: str
description: str
expected_output: str
agent: str
class Tool(pydantic.BaseModel):
name: str
agents: list[str]
name: str
description: str
template_version: Literal[2]
framework: str
method: str
agents: list[Agent]
tasks: list[Task]
tools: list[Tool]
inputs: dict[str, str]
def write_to_file(self, filename: Path):
if not filename.suffix == '.json':
filename = filename.with_suffix('.json')
with open(filename, 'w') as f:
model_dump = self.model_dump()
f.write(json.dumps(model_dump, indent=4))
@classmethod
def from_template_name(cls, name: str) -> 'TemplateConfig':
path = get_package_path() / f'templates/proj_templates/{name}.json'
if name not in get_all_template_names():
raise ValidationError(f"Template {name} not bundled with agentstack.")
return cls.from_file(path)
@classmethod
def from_file(cls, path: Path) -> 'TemplateConfig':
if not os.path.exists(path):
raise ValidationError(f"Template {path} not found.")
with open(path, 'r') as f:
return cls.from_json(json.load(f))
@classmethod
def from_url(cls, url: str) -> 'TemplateConfig':
if not url.startswith("https://"):
raise ValidationError(f"Invalid URL: {url}")
response = requests.get(url)
if response.status_code != 200:
raise ValidationError(f"Failed to fetch template from {url}")
return cls.from_json(response.json())
@classmethod
def from_json(cls, data: dict) -> 'TemplateConfig':
try:
match data.get('template_version'):
case 1:
return TemplateConfig_v1(**data).to_v2()
case 2:
return cls(**data) # current version
case _:
raise ValidationError(f"Unsupported template version: {data.get('template_version')}")
except pydantic.ValidationError as e:
err_msg = "Error validating template config JSON:\n"
for error in e.errors():
err_msg += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n"
raise ValidationError(err_msg)
except json.JSONDecodeError as e:
raise ValidationError(f"Error decoding template JSON.\n{e}")
def get_all_template_paths() -> list[Path]:
    """
    Return the paths of all bundled project templates (``.json`` files).

    The result is sorted, so callers get the same order on every platform
    regardless of the filesystem's directory-listing order.
    """
    templates_dir = get_package_path() / 'templates/proj_templates'
    return sorted(path for path in templates_dir.iterdir() if path.suffix == '.json')
def get_all_template_names() -> list[str]:
    """Return the names (file stems) of all bundled project templates."""
    template_paths = get_all_template_paths()
    return [template_path.stem for template_path in template_paths]
def get_all_templates() -> list[TemplateConfig]:
    """
    Load every bundled project template via ``TemplateConfig.from_file``.

    Any exception raised while loading an individual template propagates
    to the caller unchanged.
    """
    load_template = TemplateConfig.from_file
    return [load_template(template_path) for template_path in get_all_template_paths()]