WebAgent/WebSailor/src/react_agent.py (49 additions & 48 deletions)
@@ -38,8 +38,8 @@ def __init__(self,
 
     def call_server(self, msgs, max_tries=10):
         # Set OpenAI API key and base URL using vLLM API server
-        openai_api_key = "EMPTY"
-        openai_api_base = "http://127.0.0.1:6001/v1"
+        openai_api_key = os.getenv("OPENAI_API_KEY", "EMPTY")
+        openai_api_base = os.getenv("OPENAI_API_BASE", "http://127.0.0.1:6001/v1")
 
         client = OpenAI(
             api_key=openai_api_key,
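Making the endpoint configurable lets a deployment target any OpenAI-compatible server without editing the source. A minimal sketch of the intended usage, assuming an OpenAI-compatible vLLM server; the remote address below is illustrative, not part of the patch:

```python
import os
from openai import OpenAI

# Hypothetical usage: point the agent at a remote vLLM server before it
# builds its client. Unset variables fall back to the old hardcoded values.
os.environ["OPENAI_API_BASE"] = "http://10.0.0.5:8000/v1"  # illustrative address
os.environ["OPENAI_API_KEY"] = "EMPTY"  # vLLM servers typically ignore the key

client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", "EMPTY"),
    base_url=os.getenv("OPENAI_API_BASE", "http://127.0.0.1:6001/v1"),
)
```

Leaving both variables unset preserves the previous behavior exactly.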
@@ -77,15 +77,53 @@ def count_tokens(self, messages, model="gpt-4o"):
 
         return len(tokenizer.encode(full_prompt))
 
+    def _process_tool_call(self, content, messages):
+        if '<tool_call>' in content and '</tool_call>' in content:
+            tool_call = content.split('<tool_call>')[1].split('</tool_call>')[0]
+            try:
+                tool_call = json.loads(tool_call)
+                tool_name = tool_call.get('name', '')
+                tool_args = tool_call.get('arguments', {})
+                result = self._call_tool(tool_name, tool_args)
+            except:
+                result = 'Error: Tool call is not a valid JSON. Tool call must contain a valid "name" and "arguments" field.'
+            result = "<tool_response>\n" + result + "\n</tool_response>"
+            messages.append({"role": "user", "content": result})
+        return messages
+
+    def _handle_token_limit(self, messages, question, answer, rollout_id):
+        print(f"Token count exceeds limit")
+
+        messages[-1]['content'] = "You have now reached the maximum context length you can handle. You should stop making tool calls and, based on all the information above, think again and provide what you consider the most likely answer in the following format:<think>your final thinking</think>\n<answer>your answer</answer>"
+        content = self.call_server(messages)
+        messages.append({"role": "assistant", "content": content.strip()})
+        if '<answer>' in content and '</answer>' in content:
+            prediction = messages[-1]['content'].split('<answer>')[1].split('</answer>')[0]
+            termination = 'generate an answer as token limit reached'
+        else:
+            prediction = messages[-1]['content']
+            termination = 'format error: generate an answer as token limit reached'
+        return self._generate_result(question, answer, rollout_id, messages, prediction, termination)
+
+    def _generate_result(self, question, answer, rollout_id, messages, prediction, termination):
+        return {
+            "question": question,
+            "answer": answer,
+            "rollout_id": rollout_id,
+            "messages": messages,
+            "prediction": prediction,
+            "termination": termination
+        }
+
     def _run(self, data: str, model: str, user_prompt: str, **kwargs) -> List[List[Message]]:
         self.model=model
-        try:
-            question = data['item']['question']
-        except:
-            raw_msg = data['item']['messages'][1]["content"]
-            question = raw_msg.split("User:")[1].strip() if "User:" in raw_msg else raw_msg
+        question = data.get('item', {}).get('question', '')
+        if not question:
+            raw_msg = data.get('item', {}).get('messages', [{}, {}])[1].get("content", "")
+            question = raw_msg.split("User:")[1].strip() if "User:" in raw_msg else raw_msg
 
-        answer = data['item']['answer']
+        answer = data.get('item', {}).get('answer', '')
+        rollout_id = data.get('rollout_id', '')
         self.user_prompt = user_prompt
         self.user_prompt = self.user_prompt + question
         messages = [{"role": "system", "content": self.system_message}, {"role": "user", "content": self.user_prompt}]
@@ -100,17 +138,7 @@ def _run(self, data: str, model: str, user_prompt: str, **kwargs) -> List[List[Message]]:
                 pos = content.find('<tool_response>')
                 content = content[:pos]
             messages.append({"role": "assistant", "content": content.strip()})
-            if '<tool_call>' in content and '</tool_call>' in content:
-                tool_call = content.split('<tool_call>')[1].split('</tool_call>')[0]
-                try:
-                    tool_call = json.loads(tool_call)
-                    tool_name = tool_call.get('name', '')
-                    tool_args = tool_call.get('arguments', {})
-                    result = self._call_tool(tool_name, tool_args)
-                except:
-                    result = 'Error: Tool call is not a valid JSON. Tool call must contain a valid "name" and "arguments" field.'
-                result = "<tool_response>\n" + result + "\n</tool_response>"
-                messages.append({"role": "user", "content": result})
+            messages = self._process_tool_call(content, messages)
             if '<answer>' in content and '</answer>' in content:
                 termination = 'answer'
                 break
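For reference, `_process_tool_call` expects the model to wrap a JSON object in `<tool_call>` tags, and it answers with a `<tool_response>` user turn. A standalone sketch of the parsing step; the tool name and arguments are invented for illustration:

```python
import json

# An assistant turn in the expected format (tool name and arguments
# are illustrative, not a real tool in this repository).
content = (
    "<think>I need to search the web first.</think>\n"
    "<tool_call>\n"
    '{"name": "search", "arguments": {"query": "WebSailor agent"}}\n'
    "</tool_call>"
)

# The same split the helper performs before json.loads().
payload = content.split('<tool_call>')[1].split('</tool_call>')[0]
call = json.loads(payload)
assert call["name"] == "search"
assert call["arguments"] == {"query": "WebSailor agent"}
```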
@@ -122,26 +150,7 @@ def _run(self, data: str, model: str, user_prompt: str, **kwargs) -> List[List[Message]]:
             print(f"round: {round}, token count: {token_count}")
 
             if token_count > max_tokens:
-                print(f"Token count exceeds limit: {token_count} > {max_tokens}")
-
-                messages[-1]['content'] = "You have now reached the maximum context length you can handle. You should stop making tool calls and, based on all the information above, think again and provide what you consider the most likely answer in the following format:<think>your final thinking</think>\n<answer>your answer</answer>"
-                content = self.call_server(messages)
-                messages.append({"role": "assistant", "content": content.strip()})
-                if '<answer>' in content and '</answer>' in content:
-                    prediction = messages[-1]['content'].split('<answer>')[1].split('</answer>')[0]
-                    termination = 'generate an answer as token limit reached'
-                else:
-                    prediction = messages[-1]['content']
-                    termination = 'format error: generate an answer as token limit reached'
-                result = {
-                    "question": question,
-                    "answer": answer,
-                    "rollout_id": data['rollout_id'],
-                    "messages": messages,
-                    "prediction": prediction,
-                    "termination": termination
-                }
-                return result
+                return self._handle_token_limit(messages, question, answer, rollout_id)
 
         if '<answer>' in messages[-1]['content']:
             prediction = messages[-1]['content'].split('<answer>')[1].split('</answer>')[0]
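Worth noting: `_handle_token_limit` overwrites `messages[-1]['content']` rather than appending a new turn. The last user message at this point is typically a large `<tool_response>`, and replacing it in place is what brings the final prompt back under the limit. A schematic trace, with contents abridged and values invented for illustration:

```python
# Schematic message state at the moment the limit trips.
messages = [
    {"role": "system", "content": "You are a web agent..."},
    {"role": "assistant", "content": "<tool_call>...</tool_call>"},
    {"role": "user", "content": "<tool_response> ...very long page text... </tool_response>"},
]

# The helper replaces the oversized tool response in place, so the
# follow-up call_server() request fits in the remaining context window.
messages[-1]['content'] = "You have now reached the maximum context length ..."
```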
@@ -151,12 +160,4 @@ def _run(self, data: str, model: str, user_prompt: str, **kwargs) -> List[List[Message]]:
             termination = 'answer not found'
         if num_llm_calls_available == 0:
             termination = 'exceed available llm calls'
-        result = {
-            "question": question,
-            "answer": answer,
-            "rollout_id": data['rollout_id'],
-            "messages": messages,
-            "prediction": prediction,
-            "termination": termination
-        }
-        return result
+        return self._generate_result(question, answer, rollout_id, messages, prediction, termination)
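With every exit path funneled through `_generate_result`, downstream evaluation code can rely on a single record schema. A sketch of the invariant a caller might assert; `agent`, `data`, and the model name are placeholders standing in for the real harness:

```python
# Hypothetical driver code (placeholders, not part of the patch).
result = agent._run(data, model="served-model", user_prompt="...")

# Every termination path (answer found, answer not found, call budget
# exhausted, token limit reached) now returns the same six fields.
assert set(result) == {
    "question", "answer", "rollout_id",
    "messages", "prediction", "termination",
}
```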