diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py index 781a14f9..e156ce9b 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py @@ -26,6 +26,7 @@ from durabletask.aio import client as aioclient from grpc.aio import AioRpcError +from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings @@ -68,6 +69,7 @@ def __init__( secure_channel=uri.tls, log_handler=options.log_handler, log_formatter=options.log_formatter, + interceptors=[DaprClientTimeoutInterceptorAsync()], ) async def schedule_new_workflow( diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 527ca4a6..a1a056d6 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -27,6 +27,7 @@ from grpc import RpcError from dapr.clients import DaprInternalError +from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint @@ -71,6 +72,7 @@ def __init__( secure_channel=uri.tls, log_handler=options.log_handler, log_formatter=options.log_formatter, + interceptors=[DaprClientTimeoutInterceptor()], ) def schedule_new_workflow( diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index e2bf50d4..4c47c566 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -27,6 +27,7 @@ from durabletask import task, worker from dapr.clients import DaprInternalError +from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint @@ -71,13 +72,17 @@ def __init__( raise DaprInternalError(f'{error}') from error options = self._logger.get_options() + all_interceptors = [] + if interceptors: + all_interceptors.extend(interceptors) + all_interceptors.append(DaprClientTimeoutInterceptor()) self.__worker = worker.TaskHubGrpcWorker( host_address=uri.endpoint, metadata=metadata, secure_channel=uri.tls, log_handler=options.log_handler, log_formatter=options.log_formatter, - interceptors=interceptors, + interceptors=all_interceptors, concurrency_options=worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=maximum_concurrent_activity_work_items, maximum_concurrent_orchestration_work_items=maximum_concurrent_orchestration_work_items, diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client.py b/ext/dapr-ext-workflow/tests/test_workflow_client.py index 7d66d68b..bc7f422f 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client.py @@ -19,11 +19,12 @@ from unittest import mock import durabletask.internal.orchestrator_service_pb2 as pb -from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from durabletask import client from grpc import RpcError +from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + mock_schedule_result = 'workflow001' mock_raise_event_result = 'event001' mock_terminate_result = 'terminate001' @@ -112,6 +113,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio ) +class WorkflowClientTimeoutInterceptorTest(unittest.TestCase): + def test_timeout_interceptor_is_passed_to_client(self): + with mock.patch('durabletask.client.TaskHubGrpcClient') as mock_client_cls: + DaprWorkflowClient() + mock_client_cls.assert_called_once() + call_kwargs = mock_client_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 1) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + + class WorkflowClientTest(unittest.TestCase): def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') @@ -184,3 +199,5 @@ def test_client_functions(self): actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id) assert actual_purge_result == mock_purge_result + actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id) + assert actual_purge_result == mock_purge_result diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py b/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py index d27047ce..136533d7 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py @@ -19,11 +19,12 @@ from unittest import mock import durabletask.internal.orchestrator_service_pb2 as pb -from dapr.ext.workflow.aio import DaprWorkflowClient -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from durabletask import client from grpc.aio import AioRpcError +from dapr.ext.workflow.aio import DaprWorkflowClient +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + mock_schedule_result = 'workflow001' mock_raise_event_result = 'event001' mock_terminate_result = 'terminate001' @@ -113,6 +114,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio ) +class WorkflowClientAioTimeoutInterceptorTest(unittest.IsolatedAsyncioTestCase): + async def test_timeout_interceptor_is_passed_to_client(self): + with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient') as mock_client_cls: + DaprWorkflowClient() + mock_client_cls.assert_called_once() + call_kwargs = mock_client_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 1) + from dapr.aio.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptorAsync + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptorAsync) + + class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase): def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') @@ -188,3 +203,5 @@ async def test_client_functions(self): actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id) assert actual_purge_result == mock_purge_result + actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id) + assert actual_purge_result == mock_purge_result diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index b3cadd4a..4f28a23a 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -17,6 +17,8 @@ from typing import List from unittest import mock +import grpc + from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name @@ -39,6 +41,59 @@ def add_named_activity(self, name: str, fn): self._activity_fns[name] = fn +class WorkflowRuntimeTimeoutInterceptorTest(unittest.TestCase): + def setUp(self): + listActivities.clear() + listOrchestrators.clear() + self._registry_patch = mock.patch( + 'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker() + ) + self._registry_patch.start() + + def tearDown(self): + mock.patch.stopall() + + def test_timeout_interceptor_is_prepended(self): + with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls: + WorkflowRuntime() + mock_worker_cls.assert_called_once() + call_kwargs = mock_worker_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 1) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + + def test_timeout_interceptor_with_custom_interceptors(self): + custom_interceptor = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls: + WorkflowRuntime(interceptors=[custom_interceptor]) + call_kwargs = mock_worker_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 2) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + self.assertIs(interceptors[1], custom_interceptor) + + def test_timeout_interceptor_preserves_custom_interceptor_order(self): + custom1 = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + custom2 = mock.MagicMock(spec=grpc.UnaryStreamClientInterceptor) + with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls: + WorkflowRuntime(interceptors=[custom1, custom2]) + call_kwargs = mock_worker_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 3) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + self.assertIs(interceptors[1], custom1) + self.assertIs(interceptors[2], custom2) + + class WorkflowRuntimeTest(unittest.TestCase): def setUp(self): listActivities.clear() @@ -618,3 +673,7 @@ def my_fn(ctx): with self.assertRaises(ValueError) as ctx: alternate_name(name='second')(my_fn) self.assertIn('already has an alternate name', str(ctx.exception)) + + with self.assertRaises(ValueError) as ctx: + alternate_name(name='second')(my_fn) + self.assertIn('already has an alternate name', str(ctx.exception))