Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from durabletask.aio import client as aioclient
from grpc.aio import AioRpcError

from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync
from dapr.clients import DaprInternalError
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=[DaprClientTimeoutInterceptorAsync()],
)

async def schedule_new_workflow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from grpc import RpcError

from dapr.clients import DaprInternalError
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=[DaprClientTimeoutInterceptor()],
)

def schedule_new_workflow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from durabletask import task, worker

from dapr.clients import DaprInternalError
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -71,13 +72,17 @@ def __init__(
raise DaprInternalError(f'{error}') from error

options = self._logger.get_options()
all_interceptors = []
if interceptors:
all_interceptors.extend(interceptors)
all_interceptors.append(DaprClientTimeoutInterceptor())
self.__worker = worker.TaskHubGrpcWorker(
host_address=uri.endpoint,
metadata=metadata,
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=interceptors,
interceptors=all_interceptors,
concurrency_options=worker.ConcurrencyOptions(
maximum_concurrent_activity_work_items=maximum_concurrent_activity_work_items,
maximum_concurrent_orchestration_work_items=maximum_concurrent_orchestration_work_items,
Expand Down
21 changes: 19 additions & 2 deletions ext/dapr-ext-workflow/tests/test_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from unittest import mock

import durabletask.internal.orchestrator_service_pb2 as pb
from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from durabletask import client
from grpc import RpcError

from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext

mock_schedule_result = 'workflow001'
mock_raise_event_result = 'event001'
mock_terminate_result = 'terminate001'
Expand Down Expand Up @@ -112,6 +113,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio
)


class WorkflowClientTimeoutInterceptorTest(unittest.TestCase):
def test_timeout_interceptor_is_passed_to_client(self):
with mock.patch('durabletask.client.TaskHubGrpcClient') as mock_client_cls:
DaprWorkflowClient()
mock_client_cls.assert_called_once()
call_kwargs = mock_client_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)


class WorkflowClientTest(unittest.TestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')
Expand Down Expand Up @@ -184,3 +199,5 @@ def test_client_functions(self):

actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
21 changes: 19 additions & 2 deletions ext/dapr-ext-workflow/tests/test_workflow_client_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from unittest import mock

import durabletask.internal.orchestrator_service_pb2 as pb
from dapr.ext.workflow.aio import DaprWorkflowClient
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from durabletask import client
from grpc.aio import AioRpcError

from dapr.ext.workflow.aio import DaprWorkflowClient
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext

mock_schedule_result = 'workflow001'
mock_raise_event_result = 'event001'
mock_terminate_result = 'terminate001'
Expand Down Expand Up @@ -113,6 +114,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio
)


class WorkflowClientAioTimeoutInterceptorTest(unittest.IsolatedAsyncioTestCase):
async def test_timeout_interceptor_is_passed_to_client(self):
with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient') as mock_client_cls:
DaprWorkflowClient()
mock_client_cls.assert_called_once()
call_kwargs = mock_client_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.aio.clients.grpc.interceptors import \
DaprClientTimeoutInterceptorAsync

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptorAsync)


class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')
Expand Down Expand Up @@ -188,3 +203,5 @@ async def test_client_functions(self):

actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
59 changes: 59 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from typing import List
from unittest import mock

import grpc

from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext
from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name
Expand All @@ -39,6 +41,59 @@ def add_named_activity(self, name: str, fn):
self._activity_fns[name] = fn


class WorkflowRuntimeTimeoutInterceptorTest(unittest.TestCase):
def setUp(self):
listActivities.clear()
listOrchestrators.clear()
self._registry_patch = mock.patch(
'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()
)
self._registry_patch.start()

def tearDown(self):
mock.patch.stopall()

def test_timeout_interceptor_is_prepended(self):
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime()
mock_worker_cls.assert_called_once()
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)

def test_timeout_interceptor_with_custom_interceptors(self):
custom_interceptor = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime(interceptors=[custom_interceptor])
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 2)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)
self.assertIs(interceptors[1], custom_interceptor)

def test_timeout_interceptor_preserves_custom_interceptor_order(self):
custom1 = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
custom2 = mock.MagicMock(spec=grpc.UnaryStreamClientInterceptor)
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime(interceptors=[custom1, custom2])
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 3)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)
self.assertIs(interceptors[1], custom1)
self.assertIs(interceptors[2], custom2)


class WorkflowRuntimeTest(unittest.TestCase):
def setUp(self):
listActivities.clear()
Expand Down Expand Up @@ -618,3 +673,7 @@ def my_fn(ctx):
with self.assertRaises(ValueError) as ctx:
alternate_name(name='second')(my_fn)
self.assertIn('already has an alternate name', str(ctx.exception))

with self.assertRaises(ValueError) as ctx:
alternate_name(name='second')(my_fn)
self.assertIn('already has an alternate name', str(ctx.exception))
Loading