This repository was archived by the owner on Apr 26, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathassistants.py
More file actions
177 lines (144 loc) · 5.88 KB
/
assistants.py
File metadata and controls
177 lines (144 loc) · 5.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Assistants API implementation
Note that this API is (will soon?) be deprecated in favor of the responses API.
It does not support the gpt-5 series of models.
* https://platform.openai.com/docs/assistants/tools/function-calling
* https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps
* https://platform.openai.com/docs/api-reference/assistants-streaming/events
* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py
"""
from collections.abc import Sequence
import logging
from pathlib import PurePosixPath
from typing import Any, Self, TypedDict, override
import openai
from ...common import JSONObject
from ...instructions import SYSTEM_PROMPT
from ..common import ActionSummary, Bot, Goal, UserFeedback, Worktree
from .common import ToolHandler, ToolsFactory, new_client
_logger = logging.getLogger(__name__)
def new_threads_bot(
api_key: str | None = None,
base_url: str | None = None,
model: str = "gpt-4o",
) -> Bot:
"""Beta bot, uses assistant threads with function calling"""
return _ThreadsBot(new_client(api_key, base_url), model)
class _ThreadsBot(Bot):
def __init__(self, client: openai.OpenAI, model: str) -> None:
self._client = client
self._model = model
def _load_assistant_id(self) -> str:
kwargs: JSONObject = dict(
model=self._model,
instructions=SYSTEM_PROMPT,
tools=ToolsFactory(strict=True).params(),
)
path = self.state_folder_path(ensure_exists=True) / "ASSISTANT_ID"
try:
with open(path) as f:
assistant_id = f.read()
self._client.beta.assistants.update(assistant_id, **kwargs)
except (FileNotFoundError, openai.NotFoundError):
assistant = self._client.beta.assistants.create(**kwargs)
assistant_id = assistant.id
with open(path, "w") as f:
f.write(assistant_id)
return assistant_id
async def act(
self, goal: Goal, tree: Worktree, feedback: UserFeedback
) -> ActionSummary:
assistant_id = self._load_assistant_id()
thread = self._client.beta.threads.create()
self._client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=goal.prompt,
)
# We intentionally do not count the two requests above, to focus on
# "data requests" only.
action = ActionSummary(turn_count=0, token_count=0)
with self._client.beta.threads.runs.stream(
thread_id=thread.id,
assistant_id=assistant_id,
event_handler=_EventHandler(self._client, tree, feedback, action),
) as stream:
stream.until_done()
return action
class _EventHandler(openai.AssistantEventHandler):
def __init__(
self,
client: openai.Client,
tree: Worktree,
feedback: UserFeedback,
action: ActionSummary,
) -> None:
super().__init__()
self._client = client
self._tree = tree
self._feedback = feedback
self._action = action
self._action.increment_turn_count()
def _clone(self) -> Self:
return self.__class__(
self._client, self._tree, self._feedback, self._action
)
@override
def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None:
if event.event == "thread.run.requires_action":
run_id = event.data.id # Retrieve the run ID from the event data
self._handle_action(run_id, event.data)
elif event.event == "thread.run.completed":
_logger.info("Threads run completed. [usage=%s]", event.data.usage)
else:
_logger.debug("Threads event: %s", event)
@override
def on_run_step_done(
self, run_step: openai.types.beta.threads.runs.RunStep
) -> None:
usage = run_step.usage
if usage:
_logger.debug("Threads run step usage: %s", usage)
self._action.increment_token_count(usage.total_tokens)
else:
_logger.warning("Missing usage in threads run step")
def _handle_action(self, _run_id: str, data: Any) -> None:
tool_outputs = list[Any]()
for tool in data.required_action.submit_tool_outputs.tool_calls:
handler = _ThreadToolHandler(self._tree, self._feedback, tool.id)
tool_outputs.append(handler.handle_function(tool.function))
run = self.current_run
assert run, "No ongoing run"
with self._client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=tool_outputs,
event_handler=self._clone(),
) as stream:
stream.until_done()
class _ToolOutput(TypedDict):
tool_call_id: str
output: str
class _ThreadToolHandler(ToolHandler[_ToolOutput]):
def __init__(
self, tree: Worktree, feedback: UserFeedback, call_id: str
) -> None:
super().__init__(tree, feedback)
self._call_id = call_id
def _wrap(self, output: str) -> _ToolOutput:
return _ToolOutput(tool_call_id=self._call_id, output=output)
def _on_ask_user(self, response: str) -> _ToolOutput:
return self._wrap(response)
def _on_read_file(
self, _path: PurePosixPath, contents: str | None
) -> _ToolOutput:
return self._wrap(contents or "")
def _on_write_file(self, _path: PurePosixPath) -> _ToolOutput:
return self._wrap("OK")
def _on_delete_file(self, _path: PurePosixPath) -> _ToolOutput:
return self._wrap("OK")
def _on_rename_file(
self, _src_path: PurePosixPath, _dst_path: PurePosixPath
) -> _ToolOutput:
return self._wrap("OK")
def _on_list_files(self, paths: Sequence[PurePosixPath]) -> _ToolOutput:
return self._wrap("\n".join(str(p) for p in paths))