From 893ef6ce1b972fe6bfbf0e849165a568d7f4c9a7 Mon Sep 17 00:00:00 2001 From: MestreY0d4-Uninter <241404605+MestreY0d4-Uninter@users.noreply.github.com> Date: Wed, 22 Apr 2026 14:32:46 +0000 Subject: [PATCH] fix: prevent cross-request contamination in CitationRegistry with thread-local storage - Replace class-level _instances dict with thread-local storage - CitationRegistry and SurveyCPMCitationRegistry now isolate state per thread - Fixes race condition where concurrent requests could corrupt citation IDs - Add test case demonstrating thread isolation Fixes: #394 --- servers/custom/src/custom.py | 49 ++++++++++++---- test_citation_registry_fix.py | 107 ++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 11 deletions(-) create mode 100644 test_citation_registry_fix.py diff --git a/servers/custom/src/custom.py b/servers/custom/src/custom.py index f7e61a86..07a9bea4 100644 --- a/servers/custom/src/custom.py +++ b/servers/custom/src/custom.py @@ -399,21 +399,39 @@ def assign_citation_ids( } +# Thread-local storage for citation registry to prevent cross-request contamination +import threading +_citation_registry_local = threading.local() + + class CitationRegistry: - _instances: Dict[int, Dict[str, Any]] = {} + """Per-request citation registry to prevent cross-request contamination. + + Uses thread-local storage to isolate citation state between concurrent requests. + Each request gets its own isolated registry instance. 
+ """ @classmethod def reset(cls): - cls._instances = {} + """Initialize or reset the citation registry for the current thread.""" + if not hasattr(_citation_registry_local, '_instances'): + _citation_registry_local._instances = {} + else: + _citation_registry_local._instances.clear() @classmethod def get_or_create(cls, query_index: int) -> Dict[str, Any]: - if query_index not in cls._instances: - cls._instances[query_index] = {"registry": {}, "counter": 0} - return cls._instances[query_index] + """Get or create registry entry for the current thread.""" + if not hasattr(_citation_registry_local, '_instances'): + _citation_registry_local._instances = {} + + if query_index not in _citation_registry_local._instances: + _citation_registry_local._instances[query_index] = {"registry": {}, "counter": 0} + return _citation_registry_local._instances[query_index] @classmethod def assign_id(cls, query_index: int, doc_text: str) -> int: + """Assign unique citation ID to a document within the current request.""" state = cls.get_or_create(query_index) doc_hash = doc_text.strip() @@ -471,21 +489,30 @@ def assign_citation_ids_stateful( class SurveyCPMCitationRegistry: """Citation registry for SurveyCPM pipeline. + + Uses thread-local storage to isolate citation state between concurrent requests. + Each request gets its own isolated registry instance. Maintains unique citation IDs across multiple search rounds for each query. 
""" - _instances: Dict[int, Dict[str, Any]] = {} - @classmethod def reset(cls): - cls._instances = {} + """Initialize or reset the citation registry for the current thread.""" + if not hasattr(_citation_registry_local, 'survey_instances'): + _citation_registry_local.survey_instances = {} + else: + _citation_registry_local.survey_instances.clear() @classmethod def get_or_create(cls, query_index: int) -> Dict[str, Any]: - if query_index not in cls._instances: - cls._instances[query_index] = {"registry": {}, "counter": 0} - return cls._instances[query_index] + """Get or create registry entry for the current thread.""" + if not hasattr(_citation_registry_local, 'survey_instances'): + _citation_registry_local.survey_instances = {} + + if query_index not in _citation_registry_local.survey_instances: + _citation_registry_local.survey_instances[query_index] = {"registry": {}, "counter": 0} + return _citation_registry_local.survey_instances[query_index] @classmethod def assign_id(cls, query_index: int, doc_text: str) -> str: diff --git a/test_citation_registry_fix.py b/test_citation_registry_fix.py new file mode 100644 index 00000000..adf27838 --- /dev/null +++ b/test_citation_registry_fix.py @@ -0,0 +1,107 @@ +""" +Test to verify CitationRegistry race condition fix. + +This test demonstrates that the CitationRegistry now uses thread-local storage +to prevent cross-request contamination in concurrent scenarios. 
+ +Bug: https://github.com/OpenBMB/UltraRAG/issues/394 +""" + +import threading +import sys +sys.path.insert(0, '/tmp/ultrarag-worktree/servers/custom/src') + +from custom import CitationRegistry, SurveyCPMCitationRegistry + + +def test_citation_registry_thread_isolation(): + """Test that CitationRegistry isolates state between threads.""" + results = {} + + def worker(thread_id): + """Simulate a request that initializes and uses citation registry.""" + # Each thread resets and uses the registry + CitationRegistry.reset() + + # Assign some citations + id1 = CitationRegistry.assign_id(0, "document 1") + id2 = CitationRegistry.assign_id(0, "document 2") + id3 = CitationRegistry.assign_id(0, "document 1") # Should return same as id1 + + results[thread_id] = { + 'id1': id1, + 'id2': id2, + 'id3': id3, + } + + # Run multiple threads concurrently + threads = [] + for i in range(5): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Verify each thread got consistent results + print("=== Thread Results ===") + for thread_id, result in sorted(results.items()): + print(f"Thread {thread_id}: id1={result['id1']}, id2={result['id2']}, id3={result['id3']}") + assert result['id1'] == 1, f"Thread {thread_id}: First doc should be ID 1" + assert result['id2'] == 2, f"Thread {thread_id}: Second doc should be ID 2" + assert result['id3'] == 1, f"Thread {thread_id}: Duplicate doc should return ID 1" + + print("\n✅ All threads got consistent, isolated results!") + print("✅ No cross-request contamination detected!") + return True + + +def test_surveycpm_registry_thread_isolation(): + """Test that SurveyCPMCitationRegistry isolates state between threads.""" + results = {} + + def worker(thread_id): + """Simulate a request that initializes and uses SurveyCPM registry.""" + SurveyCPMCitationRegistry.reset() + + id1 = SurveyCPMCitationRegistry.assign_id(0, "survey doc 1") + id2 = 
SurveyCPMCitationRegistry.assign_id(0, "survey doc 2") + + results[thread_id] = { + 'id1': id1, + 'id2': id2, + } + + threads = [] + for i in range(3): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + print("\n=== SurveyCPM Thread Results ===") + for thread_id, result in sorted(results.items()): + print(f"Thread {thread_id}: id1={result['id1']}, id2={result['id2']}") + assert result['id1'] == 'textid1', f"Thread {thread_id}: First doc should be textid1" + assert result['id2'] == 'textid2', f"Thread {thread_id}: Second doc should be textid2" + + print("\n✅ SurveyCPM registry also thread-safe!") + return True + + +if __name__ == "__main__": + print("=" * 60) + print("Testing CitationRegistry thread isolation fix") + print("Bug: https://github.com/OpenBMB/UltraRAG/issues/394") + print("=" * 60) + + test_citation_registry_thread_isolation() + test_surveycpm_registry_thread_isolation() + + print("\n" + "=" * 60) + print("✅ ALL TESTS PASSED!") + print("=" * 60)