Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 3618423

Browse files
authored
Merge pull request #42 from sicoyle/notify-when-stream-ready
fix: signal when stream reader thread is ready + logs + try/catch blocks
2 parents 92af27e + 3414d5a commit 3618423

1 file changed

Lines changed: 76 additions & 29 deletions

File tree

durabletask/worker.py

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
TInput = TypeVar("TInput")
2727
TOutput = TypeVar("TOutput")
2828

29+
2930
class VersionNotRegisteredException(Exception):
3031
pass
3132

33+
3234
def _log_all_threads(logger: logging.Logger, context: str = ""):
3335
"""Helper function to log all currently active threads for debugging."""
3436
active_threads = threading.enumerate()
@@ -100,15 +102,23 @@ def __init__(self):
100102
self.latest_versioned_orchestrators_version_name = {}
101103
self.activities = {}
102104

103-
def add_orchestrator(self, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> str:
105+
def add_orchestrator(
106+
self, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False
107+
) -> str:
104108
if fn is None:
105109
raise ValueError("An orchestrator function argument is required.")
106110

107111
name = task.get_name(fn)
108112
self.add_named_orchestrator(name, fn, version_name, is_latest)
109113
return name
110114

111-
def add_named_orchestrator(self, name: str, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> None:
115+
def add_named_orchestrator(
116+
self,
117+
name: str,
118+
fn: task.Orchestrator,
119+
version_name: Optional[str] = None,
120+
is_latest: bool = False,
121+
) -> None:
112122
if not name:
113123
raise ValueError("A non-empty orchestrator name is required.")
114124

@@ -120,12 +130,16 @@ def add_named_orchestrator(self, name: str, fn: task.Orchestrator, version_name:
120130
if name not in self.versioned_orchestrators:
121131
self.versioned_orchestrators[name] = {}
122132
if version_name in self.versioned_orchestrators[name]:
123-
raise ValueError(f"The version '{version_name}' of '{name}' orchestrator already exists.")
133+
raise ValueError(
134+
f"The version '{version_name}' of '{name}' orchestrator already exists."
135+
)
124136
self.versioned_orchestrators[name][version_name] = fn
125137
if is_latest:
126138
self.latest_versioned_orchestrators_version_name[name] = version_name
127139

128-
def get_orchestrator(self, name: str, version_name: Optional[str] = None) -> Optional[tuple[task.Orchestrator, str]]:
140+
def get_orchestrator(
141+
self, name: str, version_name: Optional[str] = None
142+
) -> Optional[tuple[task.Orchestrator, str]]:
129143
if name in self.orchestrators:
130144
return self.orchestrators.get(name), None
131145

@@ -282,7 +296,7 @@ def __init__(
282296
self._channel_options = channel_options
283297
self._stop_timeout = stop_timeout
284298
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup
285-
299+
self._stream_ready = threading.Event()
286300
# Use provided concurrency options or create default ones
287301
self._concurrency_options = (
288302
concurrency_options if concurrency_options is not None else ConcurrencyOptions()
@@ -298,7 +312,7 @@ def __init__(
298312
else:
299313
self._interceptors = None
300314

301-
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
315+
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
302316

303317
@property
304318
def concurrency_options(self) -> ConcurrencyOptions:
@@ -323,6 +337,9 @@ def add_activity(self, fn: task.Activity) -> str:
323337
raise RuntimeError("Activities cannot be added while the worker is running.")
324338
return self._registry.add_activity(fn)
325339

340+
def is_worker_ready(self) -> bool:
341+
return self._stream_ready.is_set() and self._is_running
342+
326343
def start(self):
327344
"""Starts the worker on a background thread and begins listening for work items."""
328345
if self._is_running:
@@ -336,6 +353,8 @@ def run_loop():
336353
self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
337354
self._runLoop = Thread(target=run_loop, name="WorkerRunLoop")
338355
self._runLoop.start()
356+
if not self._stream_ready.wait(timeout=10):
357+
raise RuntimeError("Failed to establish work item stream connection within 10 seconds")
339358
self._is_running = True
340359

341360
# TODO: refactor this to be more readable and maintainable.
@@ -446,10 +465,13 @@ def should_invalidate_connection(rpc_error):
446465
maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
447466
maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
448467
)
449-
self._response_stream = stub.GetWorkItems(get_work_items_request)
450-
self._logger.info(
451-
f"Successfully connected to {self._host_address}. Waiting for work items..."
452-
)
468+
try:
469+
self._response_stream = stub.GetWorkItems(get_work_items_request)
470+
self._logger.info(
471+
f"Successfully connected to {self._host_address}. Waiting for work items..."
472+
)
473+
except Exception:
474+
raise
453475

454476
# Use a thread to read from the blocking gRPC stream and forward to asyncio
455477
import queue
@@ -460,12 +482,15 @@ def should_invalidate_connection(rpc_error):
460482
# NOTE: This is equivalent to the Durabletask Go goroutine calling stream.Recv() in worker_grpc.go StartWorkItemListener()
461483
def stream_reader():
462484
try:
485+
if self._response_stream is None:
486+
return
463487
stream = self._response_stream
464488

465489
# Use next() to allow shutdown check between items
466490
# This matches Go's pattern: check ctx.Err() after each stream.Recv()
467491
while True:
468492
if self._shutdown.is_set():
493+
self._logger.debug("Stream reader: shutdown detected, exiting loop")
469494
break
470495

471496
try:
@@ -502,15 +527,26 @@ def stream_reader():
502527
self._logger.debug(
503528
f"Stream reader: exception during shutdown: {type(stream_error).__name__}: {stream_error}"
504529
)
530+
break
505531
# Other stream errors - put in queue for async loop to handle
506-
self._logger.warning(
507-
f"Stream reader: unexpected error: {stream_error}"
532+
self._logger.error(
533+
f"Stream reader: unexpected error: {type(stream_error).__name__}: {stream_error}",
534+
exc_info=True,
508535
)
509536
raise
510537

511538
except Exception as e:
539+
self._logger.exception(
540+
f"Stream reader: fatal exception in stream_reader: {type(e).__name__}: {e}",
541+
exc_info=True,
542+
)
512543
if not self._shutdown.is_set():
513-
work_item_queue.put(e)
544+
try:
545+
work_item_queue.put(e)
546+
except Exception as queue_error:
547+
self._logger.error(
548+
f"Stream reader: failed to put exception in queue: {queue_error}"
549+
)
514550
finally:
515551
# signal that the stream reader is done (ie matching Go's context cancellation)
516552
try:
@@ -519,16 +555,20 @@ def stream_reader():
519555
# queue might be closed so ignore this
520556
pass
521557

522-
import threading
523-
524558
# Use non-daemon thread (daemon=False) to ensure proper resource cleanup.
525559
# Daemon threads exit immediately when the main program exits, which prevents
526560
# cleanup of gRPC channel resources and OTel interceptors. Non-daemon threads
527561
# block shutdown until they complete, ensuring all resources are properly closed.
528562
current_reader_thread = threading.Thread(
529563
target=stream_reader, daemon=False, name="StreamReader"
530564
)
531-
current_reader_thread.start()
565+
566+
try:
567+
current_reader_thread.start()
568+
self._stream_ready.set()
569+
except Exception:
570+
raise
571+
532572
loop = asyncio.get_running_loop()
533573

534574
# NOTE: This is a blocking call that will wait for a work item to become available or the shutdown sentinel
@@ -760,7 +800,6 @@ def _execute_orchestrator(
760800
version = version or pb.OrchestrationVersion()
761801
version.patches.extend(result.patches)
762802

763-
764803
res = pb.OrchestratorResponse(
765804
instanceId=req.instanceId,
766805
actions=result.actions,
@@ -932,14 +971,12 @@ def set_failed(self, ex: Exception):
932971
)
933972
self._pending_actions[action.id] = action
934973

935-
936974
def set_version_not_registered(self):
937975
self._pending_actions.clear()
938976
self._completion_status = pb.ORCHESTRATION_STATUS_STALLED
939977
action = ph.new_orchestrator_version_not_available_action(self.next_sequence_number())
940978
self._pending_actions[action.id] = action
941979

942-
943980
def set_continued_as_new(self, new_input: Any, save_events: bool):
944981
if self._is_complete:
945982
return
@@ -1150,7 +1187,6 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
11501187

11511188
self.set_continued_as_new(new_input, save_events)
11521189

1153-
11541190
def is_patched(self, patch_name: str) -> bool:
11551191
is_patched = self._is_patched(patch_name)
11561192
if is_patched:
@@ -1178,7 +1214,13 @@ class ExecutionResults:
11781214
version_name: Optional[str]
11791215
patches: Optional[list[str]]
11801216

1181-
def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str], version_name: Optional[str] = None, patches: Optional[list[str]] = None):
1217+
def __init__(
1218+
self,
1219+
actions: list[pb.OrchestratorAction],
1220+
encoded_custom_status: Optional[str],
1221+
version_name: Optional[str] = None,
1222+
patches: Optional[list[str]] = None,
1223+
):
11821224
self.actions = actions
11831225
self.encoded_custom_status = encoded_custom_status
11841226
self.version_name = version_name
@@ -1254,8 +1296,8 @@ def execute(
12541296
return ExecutionResults(
12551297
actions=actions,
12561298
encoded_custom_status=ctx._encoded_custom_status,
1257-
version_name=getattr(ctx, '_version_name', None),
1258-
patches=ctx._encountered_patches
1299+
version_name=getattr(ctx, "_version_name", None),
1300+
patches=ctx._encountered_patches,
12591301
)
12601302

12611303
def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
@@ -1283,9 +1325,10 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
12831325
if ctx._orchestrator_version_name:
12841326
version_name = ctx._orchestrator_version_name
12851327

1286-
12871328
# TODO: Check if we already started the orchestration
1288-
fn, version_used = self._registry.get_orchestrator(event.executionStarted.name, version_name=version_name)
1329+
fn, version_used = self._registry.get_orchestrator(
1330+
event.executionStarted.name, version_name=version_name
1331+
)
12891332

12901333
if fn is None:
12911334
raise OrchestratorNotRegisteredError(
@@ -1693,7 +1736,7 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
16931736

16941737

16951738
class _AsyncWorkerManager:
1696-
def __init__(self, concurrency_options: ConcurrencyOptions):
1739+
def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
16971740
self.concurrency_options = concurrency_options
16981741
self.activity_semaphore = None
16991742
self.orchestration_semaphore = None
@@ -1709,14 +1752,16 @@ def __init__(self, concurrency_options: ConcurrencyOptions):
17091752
thread_name_prefix="DurableTask",
17101753
)
17111754
self._shutdown = False
1755+
self._logger = logger
17121756

17131757
def _ensure_queues_for_current_loop(self):
17141758
"""Ensure queues are bound to the current event loop."""
17151759
try:
17161760
current_loop = asyncio.get_running_loop()
17171761
if current_loop.is_closed():
17181762
return
1719-
except RuntimeError:
1763+
except RuntimeError as e:
1764+
self._logger.exception(f"Failed to get event loop {e}")
17201765
# No event loop running, can't create queues
17211766
return
17221767

@@ -1735,14 +1780,16 @@ def _ensure_queues_for_current_loop(self):
17351780
try:
17361781
while not self.activity_queue.empty():
17371782
existing_activity_items.append(self.activity_queue.get_nowait())
1738-
except Exception:
1783+
except Exception as e:
1784+
self._logger.debug(f"Failed to append to the activity queue {e}")
17391785
pass
17401786

17411787
if self.orchestration_queue is not None:
17421788
try:
17431789
while not self.orchestration_queue.empty():
17441790
existing_orchestration_items.append(self.orchestration_queue.get_nowait())
1745-
except Exception:
1791+
except Exception as e:
1792+
self._logger.debug(f"Failed to append to the orchestration queue {e}")
17461793
pass
17471794

17481795
# Create fresh queues for the current event loop

0 commit comments

Comments (0)