Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions python/packages/core/agent_framework/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,25 @@ def __init__(
if service_thread_id is not None and chat_message_store_state is not None:
raise AgentThreadException("A thread cannot have both a service_thread_id and a chat_message_store.")
self.service_thread_id = service_thread_id
self.chat_message_store_state: ChatMessageStoreState | None = None
self.chat_message_store_state: MutableMapping[str, Any] | ChatMessageStoreState | None = None
if chat_message_store_state is not None:
if isinstance(chat_message_store_state, dict):
self.chat_message_store_state = ChatMessageStoreState.from_dict(chat_message_store_state)
# Determine if this is a custom store that needs to preserve extra fields
# Standard ChatMessageStoreState has 'type' (from SerializationMixin) and 'messages'
# Create a temporary instance to get the expected fields dynamically
temp_state = ChatMessageStoreState()
standard_state_dict = temp_state.to_dict()
standard_fields = set(standard_state_dict.keys())
# Check if input has fields beyond what standard ChatMessageStoreState would have
extra_fields = set(chat_message_store_state.keys()) - standard_fields
if extra_fields:
# Custom store with additional fields (e.g., redis_url, thread_id, key_prefix)
# Preserve as dict to retain all custom configuration
self.chat_message_store_state = chat_message_store_state
else:
# Standard ChatMessageStoreState - convert for backward compatibility
# This handles both {"messages": []} and {"type": "...", "messages": [...]}
self.chat_message_store_state = ChatMessageStoreState.from_dict(chat_message_store_state)
elif isinstance(chat_message_store_state, ChatMessageStoreState):
self.chat_message_store_state = chat_message_store_state
else:
Expand Down Expand Up @@ -464,13 +479,38 @@ async def deserialize(
return cls()

if message_store is not None:
# Handle custom message stores (e.g., Redis) that need full state deserialization
try:
await message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
if isinstance(state.chat_message_store_state, dict):
# Custom store: use update_from_state method
await message_store.update_from_state(state.chat_message_store_state, **kwargs)
elif isinstance(state.chat_message_store_state, ChatMessageStoreState):
# Legacy ChatMessageStoreState: extract messages
await message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
except Exception as ex:
raise AgentThreadException("Failed to deserialize the provided message store.") from ex
return cls(message_store=message_store)

# No message_store provided, try to deserialize based on state type
try:
message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
if isinstance(state.chat_message_store_state, dict):
# Try to determine the store type from the serialized state
store_state_dict = state.chat_message_store_state
if "messages" in store_state_dict:
parsed_state = ChatMessageStoreState.from_dict(store_state_dict)
message_store = ChatMessageStore(messages=parsed_state.messages, **kwargs)
else:
raise AgentThreadException(
"Cannot deserialize custom message store without providing a message_store instance. "
"Please provide a message_store parameter to deserialize()."
)
elif isinstance(state.chat_message_store_state, ChatMessageStoreState):
# Legacy ChatMessageStoreState object
message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
else:
raise AgentThreadException("Invalid chat_message_store_state type.")
except AgentThreadException:
raise
except Exception as ex:
raise AgentThreadException("Failed to deserialize the message store.") from ex
return cls(message_store=message_store)
Expand Down Expand Up @@ -498,8 +538,35 @@ async def update_from_thread_state(
if state.chat_message_store_state is None:
return
if self.message_store is not None:
await self.message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
# If we don't have a chat message store yet, create an in-memory one.
# Handle custom message stores (e.g., Redis) that need full state deserialization
try:
if isinstance(state.chat_message_store_state, dict):
# Custom store: use update_from_state method
await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
elif isinstance(state.chat_message_store_state, ChatMessageStoreState):
# Legacy ChatMessageStoreState: extract messages
await self.message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
except Exception as ex:
raise AgentThreadException("Failed to update message store from state.") from ex
return
# Create the message store from the default.
self.message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)

# No message_store exists, create one from the state
try:
if isinstance(state.chat_message_store_state, dict):
if "messages" in state.chat_message_store_state:
parsed_state = ChatMessageStoreState.from_dict(state.chat_message_store_state)
self.message_store = ChatMessageStore(messages=parsed_state.messages, **kwargs)
else:
raise AgentThreadException(
"Cannot create custom message store from state. "
"Please create a message store first and then call update_from_state()."
)
elif isinstance(state.chat_message_store_state, ChatMessageStoreState):
# Legacy ChatMessageStoreState object
self.message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
else:
raise AgentThreadException("Invalid chat_message_store_state type.")
except AgentThreadException:
raise
except Exception as ex:
raise AgentThreadException("Failed to create message store from state.") from ex