Skip to content

Commit 49da5ec

Browse files
authored
feat: enable wf to start via just name like go sdk (#957)
Signed-off-by: Samantha Coyle <sam@diagrid.io>
1 parent 2047d6b commit 49da5ec

4 files changed

Lines changed: 45 additions & 20 deletions

File tree

ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import annotations
1717

1818
from datetime import datetime
19-
from typing import Any, Optional, TypeVar
19+
from typing import Any, Optional, TypeVar, Union
2020

2121
import durabletask.internal.orchestrator_service_pb2 as pb
2222
from dapr.ext.workflow.logger import Logger, LoggerOptions
@@ -72,7 +72,7 @@ def __init__(
7272

7373
async def schedule_new_workflow(
7474
self,
75-
workflow: Workflow,
75+
workflow: Union[Workflow, str],
7676
*,
7777
input: Optional[TInput] = None,
7878
instance_id: Optional[str] = None,
@@ -82,7 +82,7 @@ async def schedule_new_workflow(
8282
"""Schedules a new workflow instance for execution.
8383
8484
Args:
85-
workflow: The workflow to schedule.
85+
workflow: The workflow to schedule. Can be a workflow callable or a workflow name string.
8686
input: The optional input to pass to the scheduled workflow instance. This must be a
8787
serializable value.
8888
instance_id: The unique ID of the workflow instance to schedule. If not specified, a
@@ -96,11 +96,12 @@ async def schedule_new_workflow(
9696
Returns:
9797
The ID of the scheduled workflow instance.
9898
"""
99-
workflow_name = (
100-
workflow.__dict__['_dapr_alternate_name']
101-
if hasattr(workflow, '_dapr_alternate_name')
102-
else workflow.__name__
103-
)
99+
if isinstance(workflow, str):
100+
workflow_name = workflow
101+
elif hasattr(workflow, '_dapr_alternate_name'):
102+
workflow_name = workflow.__dict__['_dapr_alternate_name']
103+
else:
104+
workflow_name = workflow.__name__
104105
return await self.__obj.schedule_new_orchestration(
105106
workflow_name,
106107
input=input,

ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import annotations
1717

1818
from datetime import datetime
19-
from typing import Any, Optional, TypeVar
19+
from typing import Any, Optional, TypeVar, Union
2020

2121
import durabletask.internal.orchestrator_service_pb2 as pb
2222
from dapr.ext.workflow.logger import Logger, LoggerOptions
@@ -75,7 +75,7 @@ def __init__(
7575

7676
def schedule_new_workflow(
7777
self,
78-
workflow: Workflow,
78+
workflow: Union[Workflow, str],
7979
*,
8080
input: Optional[TInput] = None,
8181
instance_id: Optional[str] = None,
@@ -85,7 +85,7 @@ def schedule_new_workflow(
8585
"""Schedules a new workflow instance for execution.
8686
8787
Args:
88-
workflow: The workflow to schedule.
88+
workflow: The workflow to schedule. Can be a workflow callable or a workflow name string.
8989
input: The optional input to pass to the scheduled workflow instance. This must be a
9090
serializable value.
9191
instance_id: The unique ID of the workflow instance to schedule. If not specified, a
@@ -99,16 +99,14 @@ def schedule_new_workflow(
9999
Returns:
100100
The ID of the scheduled workflow instance.
101101
"""
102-
if hasattr(workflow, '_dapr_alternate_name'):
103-
return self.__obj.schedule_new_orchestration(
104-
workflow.__dict__['_dapr_alternate_name'],
105-
input=input,
106-
instance_id=instance_id,
107-
start_at=start_at,
108-
reuse_id_policy=reuse_id_policy,
109-
)
102+
if isinstance(workflow, str):
103+
workflow_name = workflow
104+
elif hasattr(workflow, '_dapr_alternate_name'):
105+
workflow_name = workflow.__dict__['_dapr_alternate_name']
106+
else:
107+
workflow_name = workflow.__name__
110108
return self.__obj.schedule_new_orchestration(
111-
workflow.__name__,
109+
workflow_name,
112110
input=input,
113111
instance_id=instance_id,
114112
start_at=start_at,

ext/dapr-ext-workflow/tests/test_workflow_client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def details(self):
4747

4848

4949
class FakeTaskHubGrpcClient:
50+
def __init__(self):
51+
self.last_scheduled_workflow_name = None
52+
5053
def schedule_new_orchestration(
5154
self,
5255
workflow,
@@ -55,6 +58,7 @@ def schedule_new_orchestration(
5558
start_at,
5659
reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None,
5760
):
61+
self.last_scheduled_workflow_name = workflow
5862
return mock_schedule_result
5963

6064
def get_orchestration_state(self, instance_id, fetch_payloads):
@@ -112,6 +116,14 @@ class WorkflowClientTest(unittest.TestCase):
112116
def mock_client_wf(ctx: DaprWorkflowContext, input):
113117
print(f'{input}')
114118

119+
def test_schedule_workflow_by_name_string(self):
120+
fake_client = FakeTaskHubGrpcClient()
121+
with mock.patch('durabletask.client.TaskHubGrpcClient', return_value=fake_client):
122+
wfClient = DaprWorkflowClient()
123+
result = wfClient.schedule_new_workflow(workflow='my_registered_workflow', input='data')
124+
assert result == mock_schedule_result
125+
assert fake_client.last_scheduled_workflow_name == 'my_registered_workflow'
126+
115127
def test_client_functions(self):
116128
with mock.patch(
117129
'durabletask.client.TaskHubGrpcClient', return_value=FakeTaskHubGrpcClient()

ext/dapr-ext-workflow/tests/test_workflow_client_aio.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def details(self):
4747

4848

4949
class FakeAsyncTaskHubGrpcClient:
50+
def __init__(self):
51+
self.last_scheduled_workflow_name = None
52+
5053
async def schedule_new_orchestration(
5154
self,
5255
workflow,
@@ -56,6 +59,7 @@ async def schedule_new_orchestration(
5659
start_at,
5760
reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None,
5861
):
62+
self.last_scheduled_workflow_name = workflow
5963
return mock_schedule_result
6064

6165
async def get_orchestration_state(self, instance_id, *, fetch_payloads):
@@ -113,6 +117,16 @@ class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase):
113117
def mock_client_wf(ctx: DaprWorkflowContext, input):
114118
print(f'{input}')
115119

120+
async def test_schedule_workflow_by_name_string(self):
121+
fake_client = FakeAsyncTaskHubGrpcClient()
122+
with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient', return_value=fake_client):
123+
wfClient = DaprWorkflowClient()
124+
result = await wfClient.schedule_new_workflow(
125+
workflow='my_registered_workflow', input='data'
126+
)
127+
assert result == mock_schedule_result
128+
assert fake_client.last_scheduled_workflow_name == 'my_registered_workflow'
129+
116130
async def test_client_functions(self):
117131
with mock.patch(
118132
'durabletask.aio.client.AsyncTaskHubGrpcClient',

0 commit comments

Comments
 (0)