Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lib/ruby_llm/active_record/chat_methods.rb
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ def complete(...)
raise e
end

# Performs a single agentic step by delegating to the underlying RubyLLM
# chat object. If the provider raises, partially-persisted records are
# cleaned up before the error is re-raised so the conversation stays valid.
def step(...)
  to_llm.step(...)
rescue RubyLLM::Error => error
  # Drop the assistant message if it was persisted but never got content.
  if @message&.persisted? && @message.content.blank?
    cleanup_failed_messages
  end
  cleanup_orphaned_tool_results
  raise error
end

private

def cleanup_failed_messages
Expand Down
61 changes: 38 additions & 23 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -136,35 +136,25 @@ def each(&)
messages.each(&)
end

def complete(&) # rubocop:disable Metrics/PerceivedComplexity
response = @provider.complete(
messages,
tools: @tools,
tool_prefs: @tool_prefs,
temperature: @temperature,
model: @model,
params: @params,
headers: @headers,
schema: @schema,
thinking: @thinking,
&wrap_streaming_block(&)
)
# Drives the conversation to completion: repeatedly performs single steps
# until the model produces a final answer (no tool calls) or a tool halts
# the loop via Tool::Halt.
#
# Iterative rather than self-recursive so that long tool-calling loops
# cannot exhaust the call stack (each tool round previously added a
# `complete` frame).
#
# @return [Message, Tool::Halt] the final assistant message, or the halt
#   result when a tool short-circuits the conversation
def complete(&)
  loop do
    response = step(&)
    return response if response.is_a?(Tool::Halt)
    return response unless response.tool_call?
    # Tool call handled inside #step — loop for the follow-up completion.
  end
end

def step(&)
response = provider_complete(&)

@on[:new_message]&.call unless block_given?

if @schema && response.content.is_a?(String) && !response.tool_call?
begin
response.content = JSON.parse(response.content)
rescue JSON::ParserError
# If parsing fails, keep content as string
end
end
normalize_schema_response(response)

add_message response
@on[:end_message]&.call(response)

if response.tool_call?
handle_tool_calls(response, &)
handle_tool_calls(response, continue_loop: false, &) || response
else
response
end
Expand All @@ -186,6 +176,31 @@ def instance_variables

private

# Sends the current conversation state to the provider for one completion
# request, wrapping any streaming block supplied by the caller.
def provider_complete(&)
  request_options = {
    tools: @tools,
    tool_prefs: @tool_prefs,
    temperature: @temperature,
    model: @model,
    params: @params,
    headers: @headers,
    schema: @schema,
    thinking: @thinking
  }
  @provider.complete(messages, **request_options, &wrap_streaming_block(&))
end

# When a JSON schema was requested, parses a string response body into a
# Hash in place. Tool-call responses, non-string content, and invalid JSON
# are left untouched.
def normalize_schema_response(response)
  return unless @schema
  return if response.tool_call? || !response.content.is_a?(String)

  response.content = JSON.parse(response.content)
rescue JSON::ParserError
  # Unparseable payload — keep the raw string so callers still see it.
end

def normalize_schema_payload(raw_schema)
return nil if raw_schema.nil?
return raw_schema unless raw_schema.is_a?(Hash)
Expand Down Expand Up @@ -231,7 +246,7 @@ def wrap_streaming_block(&block)
end
end

def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
def handle_tool_calls(response, continue_loop: true, &) # rubocop:disable Metrics/PerceivedComplexity
halt_result = nil

response.tool_calls.each_value do |tool_call|
Expand All @@ -248,7 +263,7 @@ def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
end

reset_tool_choice if forced_tool_choice?
halt_result || complete(&)
halt_result || (continue_loop ? complete(&) : nil)
end

def execute_tool(tool_call)
Expand Down
66 changes: 66 additions & 0 deletions spec/ruby_llm/active_record/acts_as_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,72 @@ def uploaded_file(path, type)
end
end

describe 'step' do
  # Verifies that #step performs exactly one provider round-trip: the tool
  # result is persisted, but no follow-up completion request is issued.
  it 'executes a single tool-calling iteration without recursing' do
    chat = Chat.create!(model: model).with_tool(Calculator)
    # Reach into the underlying RubyLLM chat to stub the provider call.
    provider = chat.to_llm.instance_variable_get(:@provider)
    tool_call = RubyLLM::ToolCall.new(
      id: 'call_1',
      name: 'calculator',
      arguments: { 'expression' => '2 + 2' }
    )

    # Provider always answers with a tool call; without the stub being hit a
    # second time, the loop cannot continue past one iteration.
    allow(provider).to receive(:complete).and_return(
      RubyLLM::Message.new(
        role: :assistant,
        content: '',
        tool_calls: { tool_call.id => tool_call }
      )
    )

    chat.add_message(role: :user, content: 'What is 2 + 2?')

    response = chat.step

    # Single iteration: one provider call, tool executed, result persisted.
    expect(response).to be_a(RubyLLM::Message)
    expect(response.tool_call?).to be(true)
    expect(provider).to have_received(:complete).once
    expect(chat.messages.order(:id).pluck(:role)).to eq(%w[user assistant tool])
    expect(chat.messages.order(:id).last.content).to eq('4')
  end

  # Verifies that a halting tool short-circuits #step, returning Tool::Halt
  # while still persisting the tool-result message.
  it 'returns Halt when a tool halts' do
    stub_const('HaltingTool', Class.new(RubyLLM::Tool) do
      description 'A tool that halts'

      def execute
        halt('Task completed successfully')
      end
    end)

    chat = Chat.create!(model: model).with_tool(HaltingTool)
    provider = chat.to_llm.instance_variable_get(:@provider)
    tool_call = RubyLLM::ToolCall.new(
      id: 'call_1',
      name: 'halting',
      arguments: {}
    )

    allow(provider).to receive(:complete).and_return(
      RubyLLM::Message.new(
        role: :assistant,
        content: '',
        tool_calls: { tool_call.id => tool_call }
      )
    )

    chat.add_message(role: :user, content: 'Execute the halting tool')

    response = chat.step

    # Halt is surfaced to the caller; only one provider round-trip occurred.
    expect(response).to be_a(RubyLLM::Tool::Halt)
    expect(response.content).to eq('Task completed successfully')
    expect(provider).to have_received(:complete).once
    expect(chat.messages.order(:id).pluck(:role)).to eq(%w[user assistant tool])
    expect(chat.messages.order(:id).last.content).to eq('Task completed successfully')
  end
end

describe 'error recovery' do
it 'does not clean up complete tool interactions when error occurs after tool execution' do
chat = Chat.create!(model: model)
Expand Down
54 changes: 54 additions & 0 deletions spec/ruby_llm/chat_tools_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,60 @@ def tool_result_message_for(chat, tool_call)
expect(response.content).to eq('Task completed successfully')
end

# Verifies that Chat#step performs exactly one provider round-trip: the tool
# runs and its result is appended, but no follow-up completion is requested.
it 'step executes a single tool-calling iteration without recursing' do
  chat = RubyLLM.chat.with_tool(Weather)
  # Stub the provider so we can count completion calls precisely.
  provider = chat.instance_variable_get(:@provider)
  tool_call = RubyLLM::ToolCall.new(
    id: 'call_1',
    name: 'weather',
    arguments: { 'latitude' => 52.52, 'longitude' => 13.405 }
  )

  # Provider always returns a tool call; a recursing implementation would
  # therefore hit the stub more than once.
  allow(provider).to receive(:complete).and_return(
    RubyLLM::Message.new(
      role: :assistant,
      content: '',
      tool_calls: { tool_call.id => tool_call }
    )
  )

  chat.add_message(role: :user, content: "What's the weather in Berlin?")

  response = chat.step

  # One iteration only: provider hit once, tool result appended in memory.
  expect(response).to be_a(RubyLLM::Message)
  expect(response.tool_call?).to be(true)
  expect(provider).to have_received(:complete).once
  expect(chat.messages.map(&:role)).to eq(%i[user assistant tool])
  expect(chat.messages.last.content).to include('15')
end

# Verifies that Chat#step surfaces a Tool::Halt when the invoked tool halts,
# after a single provider round-trip.
it 'step returns Halt when a tool halts' do
  chat = RubyLLM.chat.with_tool(HaltingTool)
  provider = chat.instance_variable_get(:@provider)
  tool_call = RubyLLM::ToolCall.new(
    id: 'call_1',
    name: 'halting',
    arguments: {}
  )

  allow(provider).to receive(:complete).and_return(
    RubyLLM::Message.new(
      role: :assistant,
      content: '',
      tool_calls: { tool_call.id => tool_call }
    )
  )

  chat.add_message(role: :user, content: 'Execute the halting tool')

  response = chat.step

  # The halt result is returned directly; no follow-up completion occurs.
  expect(response).to be_a(RubyLLM::Tool::Halt)
  expect(response.content).to eq('Task completed successfully')
  expect(provider).to have_received(:complete).once
end

it 'does not continue conversation after halt' do
call_count = 0
original_complete = described_class.instance_method(:complete)
Expand Down