Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flocks/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ async def lifespan(app: FastAPI):
await _run_startup_phase(log, "storage.init", Storage.init)
log.info("storage.initialized")

async def _recover_orphan_tool_parts() -> None:
from flocks.session.orphan_tools import abort_all_orphan_running_parts

repaired = await abort_all_orphan_running_parts()
if repaired:
log.info("session.orphan_tools.recovered", {"count": repaired})

_schedule_startup_phase(
app,
log,
"session.recover_orphan_tools",
_recover_orphan_tool_parts,
)

# Ensure default device room exists, then migrate legacy device API
# configs from flocks.json → device_integrations table.
try:
Expand Down
7 changes: 5 additions & 2 deletions flocks/server/routes/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,12 @@ async def get_session_messages(
_require_session_read_access(session, current_user)

try:
from flocks.session.orphan_tools import abort_orphan_running_parts_in_messages
from flocks.session.core.status import SessionStatus

messages_with_parts = await Message.list_with_parts(sessionID, include_archived=True)
if sessionID not in SessionStatus.get_busy_session_ids():
await abort_orphan_running_parts_in_messages(sessionID, messages_with_parts)
if limit:
messages_with_parts = messages_with_parts[-limit:]

Expand Down Expand Up @@ -3847,5 +3852,3 @@ async def clear_session(sessionID: str, http_request: Request):
except Exception as e:
log.error("session.clear.error", {"sessionID": sessionID, "error": str(e)})
raise HTTPException(status_code=500, detail=f"Failed to clear session: {str(e)}")


100 changes: 100 additions & 0 deletions flocks/session/orphan_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Recovery helpers for tool calls left running by interrupted processes."""

import time
from typing import Iterable, Optional

from flocks.session.message import Message, MessageWithParts, ToolPart, ToolStateError
from flocks.session.session import SessionInfo
from flocks.storage.storage import Storage
from flocks.utils.log import Log


log = Log.create(service="session.orphan_tools")


INTERRUPTED_TOOL_ERROR = "Interrupted by server restart"


def _build_interrupted_error_state(state: object, now_ms: int) -> ToolStateError:
"""Create the terminal error state used for recovered orphaned tools."""
time_info = getattr(state, "time", {}) or {}
start_ms = time_info.get("start", now_ms)

return ToolStateError(
status="error",
input=getattr(state, "input", {}),
error=INTERRUPTED_TOOL_ERROR,
metadata=getattr(state, "metadata", None),
time={"start": start_ms, "end": now_ms},
)


async def abort_orphan_running_parts_in_messages(
session_id: str,
messages_with_parts: Iterable[MessageWithParts],
) -> int:
"""Mark running tool parts as interrupted using preloaded message parts."""
now_ms = int(time.time() * 1000)
repaired = 0

for msg_with_parts in messages_with_parts:
message_id = msg_with_parts.info.id
for part in msg_with_parts.parts:
if not isinstance(part, ToolPart):
continue
state = part.state
if getattr(state, "status", None) != "running":
continue

part.state = _build_interrupted_error_state(state, now_ms)
await Message.store_part(session_id, message_id, part)
repaired += 1

if repaired:
log.info("session.orphan_tools.aborted", {
"session_id": session_id,
"count": repaired,
})
return repaired


async def abort_orphan_running_parts(session_id: str) -> int:
"""Mark persisted running tool parts as interrupted errors."""
messages_with_parts = await Message.list_with_parts(session_id)
return await abort_orphan_running_parts_in_messages(session_id, messages_with_parts)


async def abort_orphan_running_parts_for_sessions(
session_ids: Iterable[str],
*,
skip_busy: bool = False,
) -> int:
"""Best-effort recovery for a known set of sessions."""
total = 0
for session_id in dict.fromkeys(session_ids):
try:
if skip_busy:
from flocks.session.core.status import SessionStatus

if session_id in SessionStatus.get_busy_session_ids():
continue
total += await abort_orphan_running_parts(session_id)
except Exception as exc:
log.warn("session.orphan_tools.session_failed", {
"session_id": session_id,
"error": str(exc),
})
return total


async def abort_all_orphan_running_parts(*, limit: Optional[int] = None) -> int:
"""Best-effort startup recovery for all persisted sessions."""
entries = await Storage.list_entries(prefix="session:", model=SessionInfo)
session_ids = [
session.id
for _, session in entries
if getattr(session, "status", None) != "deleted"
]
if limit is not None:
session_ids = session_ids[:limit]
return await abort_orphan_running_parts_for_sessions(session_ids, skip_busy=True)
55 changes: 3 additions & 52 deletions flocks/session/session_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ async def run(
# Mark orphaned running tool parts as error (e.g. from server restart).
# Wrapped in try/except so cleanup failures never block the session loop.
try:
await cls._abort_orphan_running_parts(session_id)
from flocks.session.orphan_tools import abort_orphan_running_parts

await abort_orphan_running_parts(session_id)
except Exception as exc:
log.warn("loop.orphan_cleanup_failed", {
"session_id": session_id,
Expand Down Expand Up @@ -394,57 +396,6 @@ async def run(
except Exception as exc:
log.warn("loop.idle.event_error", {"error": str(exc)})

@classmethod
async def _abort_orphan_running_parts(cls, session_id: str) -> None:
"""Mark any tool parts stuck in 'running' status as error.

When the server restarts while a synchronous tool (e.g. delegate_task)
is executing, the tool part stays 'running' in storage forever. On the
next session loop start we know nothing is actually executing yet, so
any 'running' parts are orphans.
"""
import time as _time
from flocks.session.message import (
ToolPart, ToolStateError,
)

messages = await Message.list(session_id)
now_ms = int(_time.time() * 1000)
repaired = 0

for msg in messages:
parts = await Message.parts(msg.id, session_id)
for part in parts:
if not isinstance(part, ToolPart):
continue
state = part.state
if getattr(state, "status", None) != "running":
continue

time_info = getattr(state, "time", {}) or {}
start_ms = time_info.get("start", now_ms)

error_state = ToolStateError(
status="error",
input=getattr(state, "input", {}),
error="Interrupted by server restart",
time={"start": start_ms, "end": now_ms},
)
# Preserve metadata (e.g. sessionId) so the card still works
meta = getattr(state, "metadata", None)
if meta:
error_state.metadata = meta

part.state = error_state
await Message.store_part(session_id, msg.id, part)
repaired += 1

if repaired:
log.info("loop.orphan_parts_aborted", {
"session_id": session_id,
"count": repaired,
})

@staticmethod
async def _resolve_model(
session: Any,
Expand Down
93 changes: 93 additions & 0 deletions tests/server/routes/test_session_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
from fastapi import HTTPException, status
from httpx import AsyncClient
from flocks.auth.context import AuthUser
from flocks.session.core.status import SessionStatus, SessionStatusBusy
from flocks.session.message import (
Message,
MessageRole,
ToolPart,
ToolStateError,
ToolStateRunning,
)
from flocks.session.orphan_tools import INTERRUPTED_TOOL_ERROR
from flocks.session.session import Session

# ===========================================================================
Expand Down Expand Up @@ -178,6 +187,90 @@ async def test_send_message_noReply(self, client: AsyncClient, session_id: str):
for m in messages
)

@pytest.mark.asyncio
async def test_list_messages_keeps_running_tool_when_session_busy(
self,
client: AsyncClient,
session_id: str,
):
msg = await Message.create(session_id, MessageRole.ASSISTANT, "")
part = ToolPart(
id="part_busy_running",
sessionID=session_id,
messageID=msg.id,
callID="call_busy_running",
tool="bash",
state=ToolStateRunning(
input={"cmd": "sleep 60"},
time={"start": 1000},
),
)
await Message.store_part(session_id, msg.id, part)

SessionStatus.set(session_id, SessionStatusBusy())
try:
resp = await client.get(f"/api/session/{session_id}/message")
finally:
SessionStatus.clear(session_id)

assert resp.status_code == status.HTTP_200_OK
parts = await Message.parts(msg.id, session_id)
running_part = next(p for p in parts if p.id == "part_busy_running")
assert running_part.state.status == "running"

@pytest.mark.asyncio
async def test_list_messages_recovers_orphan_running_tool_when_session_idle(
self,
client: AsyncClient,
session_id: str,
):
msg = await Message.create(session_id, MessageRole.ASSISTANT, "")
part = ToolPart(
id="part_idle_running",
sessionID=session_id,
messageID=msg.id,
callID="call_idle_running",
tool="bash",
state=ToolStateRunning(
input={"cmd": "sleep 60"},
metadata={"sessionId": "ses_child"},
time={"start": 1000},
),
)
await Message.store_part(session_id, msg.id, part)

resp = await client.get(f"/api/session/{session_id}/message")

assert resp.status_code == status.HTTP_200_OK
parts = await Message.parts(msg.id, session_id)
repaired_part = next(p for p in parts if p.id == "part_idle_running")
assert isinstance(repaired_part.state, ToolStateError)
assert repaired_part.state.status == "error"
assert repaired_part.state.error == INTERRUPTED_TOOL_ERROR
assert repaired_part.state.metadata == {"sessionId": "ses_child"}
assert repaired_part.state.time["start"] == 1000
assert repaired_part.state.time["end"] >= 1000

@pytest.mark.asyncio
async def test_list_messages_uses_preloaded_orphan_recovery_path(
self,
client: AsyncClient,
session_id: str,
monkeypatch: pytest.MonkeyPatch,
):
from flocks.session import orphan_tools

preloaded_recovery = AsyncMock(return_value=0)
legacy_recovery = AsyncMock(side_effect=AssertionError("legacy recovery should not be called"))
monkeypatch.setattr(orphan_tools, "abort_orphan_running_parts_in_messages", preloaded_recovery)
monkeypatch.setattr(orphan_tools, "abort_orphan_running_parts", legacy_recovery)

resp = await client.get(f"/api/session/{session_id}/message")

assert resp.status_code == status.HTTP_200_OK
preloaded_recovery.assert_awaited_once()
legacy_recovery.assert_not_called()


# ===========================================================================
# Delete permissions (single-admin model)
Expand Down
Loading