diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx index 45700a0b0f81..7075ef47017d 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx +++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx @@ -217,7 +217,13 @@ cdef class ScopedState(object): @property def nsecs(self): - return self._nsecs + cdef pythread.PyThread_type_lock lock = self.sampler.lock + cdef int64_t val + with nogil: + pythread.PyThread_acquire_lock(lock, pythread.WAIT_LOCK) + val = self._nsecs + pythread.PyThread_release_lock(lock) + return val def sampled_seconds(self): return 1e-9 * self.nsecs diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py index 0d0ce1d2c8dc..0495dc507da0 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_test.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py @@ -19,6 +19,7 @@ # pytype: skip-file import logging +import threading import time import unittest from unittest import mock @@ -312,6 +313,48 @@ def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs): actual_value, state_duration_ms * (1.0 - margin_of_error)) _LOGGER.info("Exception test finished successfully.") + def test_concurrent_nsecs_reads(self): + """Verify that concurrent reads of nsecs behave correctly under thread contention. + + This test runs state transitions on the main thread and reads `nsecs` properties + from a secondary Python thread, while the background sampler thread is concurrently + updating counter states. + """ + if not statesampler.FAST_SAMPLER: + self.skipTest('test_concurrent_nsecs_reads requires FAST_SAMPLER') + + counter_factory = CounterFactory() + sampler = statesampler.StateSampler( + 'concurrent', counter_factory, sampling_period_ms=1) + + sampler.start() + reader_thread = None + try: + state_a = sampler.scoped_state('step1', 'statea') + state_b = sampler.scoped_state('step1', 'stateb') + + stop_signal = False + + def read_nsecs_loop(): + while not stop_signal: + _ = state_a.nsecs + _ = state_b.nsecs + time.sleep(0.001) + + reader_thread = threading.Thread(target=read_nsecs_loop) + reader_thread.start() + + for _ in range(100): + with state_a: + time.sleep(0.001) + with state_b: + time.sleep(0.001) + finally: + if reader_thread is not None: + stop_signal = True + reader_thread.join() + sampler.stop() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)