Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions servers/custom/src/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,21 +399,39 @@ def assign_citation_ids(
}


# Thread-local storage for citation registry to prevent cross-request contamination
import threading
_citation_registry_local = threading.local()


class CitationRegistry:
_instances: Dict[int, Dict[str, Any]] = {}
"""Per-request citation registry to prevent cross-request contamination.

Uses thread-local storage to isolate citation state between concurrent requests.
Each request gets its own isolated registry instance.
"""

@classmethod
def reset(cls):
cls._instances = {}
"""Initialize or reset the citation registry for the current thread."""
if not hasattr(_citation_registry_local, '_instances'):
_citation_registry_local._instances = {}
else:
_citation_registry_local._instances.clear()

@classmethod
def get_or_create(cls, query_index: int) -> Dict[str, Any]:
if query_index not in cls._instances:
cls._instances[query_index] = {"registry": {}, "counter": 0}
return cls._instances[query_index]
"""Get or create registry entry for the current thread."""
if not hasattr(_citation_registry_local, '_instances'):
_citation_registry_local._instances = {}

if query_index not in _citation_registry_local._instances:
_citation_registry_local._instances[query_index] = {"registry": {}, "counter": 0}
return _citation_registry_local._instances[query_index]

@classmethod
def assign_id(cls, query_index: int, doc_text: str) -> int:
"""Assign unique citation ID to a document within the current request."""
state = cls.get_or_create(query_index)
doc_hash = doc_text.strip()

Expand Down Expand Up @@ -471,21 +489,30 @@ def assign_citation_ids_stateful(

class SurveyCPMCitationRegistry:
"""Citation registry for SurveyCPM pipeline.

Uses thread-local storage to isolate citation state between concurrent requests.
Each request gets its own isolated registry instance.

Maintains unique citation IDs across multiple search rounds for each query.
"""

_instances: Dict[int, Dict[str, Any]] = {}

@classmethod
def reset(cls):
cls._instances = {}
"""Initialize or reset the citation registry for the current thread."""
if not hasattr(_citation_registry_local, 'survey_instances'):
_citation_registry_local.survey_instances = {}
else:
_citation_registry_local.survey_instances.clear()

@classmethod
def get_or_create(cls, query_index: int) -> Dict[str, Any]:
if query_index not in cls._instances:
cls._instances[query_index] = {"registry": {}, "counter": 0}
return cls._instances[query_index]
"""Get or create registry entry for the current thread."""
if not hasattr(_citation_registry_local, 'survey_instances'):
_citation_registry_local.survey_instances = {}

if query_index not in _citation_registry_local.survey_instances:
_citation_registry_local.survey_instances[query_index] = {"registry": {}, "counter": 0}
return _citation_registry_local.survey_instances[query_index]

@classmethod
def assign_id(cls, query_index: int, doc_text: str) -> str:
Expand Down
107 changes: 107 additions & 0 deletions test_citation_registry_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Test to verify CitationRegistry race condition fix.

This test demonstrates that the CitationRegistry now uses thread-local storage
to prevent cross-request contamination in concurrent scenarios.

Bug: https://github.com/OpenBMB/UltraRAG/issues/394
"""

import threading
import sys
sys.path.insert(0, '/tmp/ultrarag-worktree/servers/custom/src')

from custom import CitationRegistry, SurveyCPMCitationRegistry


def test_citation_registry_thread_isolation():
"""Test that CitationRegistry isolates state between threads."""
results = {}

def worker(thread_id):
"""Simulate a request that initializes and uses citation registry."""
# Each thread resets and uses the registry
CitationRegistry.reset()

# Assign some citations
id1 = CitationRegistry.assign_id(0, "document 1")
id2 = CitationRegistry.assign_id(0, "document 2")
id3 = CitationRegistry.assign_id(0, "document 1") # Should return same as id1

results[thread_id] = {
'id1': id1,
'id2': id2,
'id3': id3,
}

# Run multiple threads concurrently
threads = []
for i in range(5):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()

# Wait for all threads to complete
for t in threads:
t.join()

# Verify each thread got consistent results
print("=== Thread Results ===")
for thread_id, result in sorted(results.items()):
print(f"Thread {thread_id}: id1={result['id1']}, id2={result['id2']}, id3={result['id3']}")
assert result['id1'] == 1, f"Thread {thread_id}: First doc should be ID 1"
assert result['id2'] == 2, f"Thread {thread_id}: Second doc should be ID 2"
assert result['id3'] == 1, f"Thread {thread_id}: Duplicate doc should return ID 1"

print("\n✅ All threads got consistent, isolated results!")
print("✅ No cross-request contamination detected!")
return True


def test_surveycpm_registry_thread_isolation():
"""Test that SurveyCPMCitationRegistry isolates state between threads."""
results = {}

def worker(thread_id):
"""Simulate a request that initializes and uses SurveyCPM registry."""
SurveyCPMCitationRegistry.reset()

id1 = SurveyCPMCitationRegistry.assign_id(0, "survey doc 1")
id2 = SurveyCPMCitationRegistry.assign_id(0, "survey doc 2")

results[thread_id] = {
'id1': id1,
'id2': id2,
}

threads = []
for i in range(3):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()

for t in threads:
t.join()

print("\n=== SurveyCPM Thread Results ===")
for thread_id, result in sorted(results.items()):
print(f"Thread {thread_id}: id1={result['id1']}, id2={result['id2']}")
assert result['id1'] == 'textid1', f"Thread {thread_id}: First doc should be textid1"
assert result['id2'] == 'textid2', f"Thread {thread_id}: Second doc should be textid2"

print("\n✅ SurveyCPM registry also thread-safe!")
return True


if __name__ == "__main__":
print("=" * 60)
print("Testing CitationRegistry thread isolation fix")
print("Bug: https://github.com/OpenBMB/UltraRAG/issues/394")
print("=" * 60)

test_citation_registry_thread_isolation()
test_surveycpm_registry_thread_isolation()

print("\n" + "=" * 60)
print("✅ ALL TESTS PASSED!")
print("=" * 60)