Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions temporalio/client/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ async def start_operation(
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Expand Down Expand Up @@ -841,7 +841,7 @@ async def execute_operation(
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Expand Down
16 changes: 8 additions & 8 deletions temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
"""

from ._decorators import (
TemporalNexusOperationStartHandlerFunc,
TemporalOperationStartHandlerFunc,
temporal_operation,
workflow_run_operation,
)
from ._operation_context import (
Info,
LoggerAdapter,
NexusCallback,
TemporalNexusCancelOperationContext,
TemporalNexusStartOperationContext,
TemporalCancelOperationContext,
TemporalStartOperationContext,
WorkflowRunOperationContext,
client,
in_operation,
Expand All @@ -26,7 +26,7 @@
)
from ._operation_handlers import (
CancelWorkflowRunOptions,
TemporalNexusOperationHandler,
TemporalOperationHandler,
)
from ._temporal_client import TemporalNexusClient, TemporalOperationResult
from ._token import WorkflowHandle
Expand All @@ -38,8 +38,8 @@
"LoggerAdapter",
"NexusCallback",
"WorkflowRunOperationContext",
"TemporalNexusCancelOperationContext",
"TemporalNexusStartOperationContext",
"TemporalCancelOperationContext",
"TemporalStartOperationContext",
"client",
"in_operation",
"info",
Expand All @@ -50,8 +50,8 @@
"wait_for_worker_shutdown_sync",
"WorkflowHandle",
"TemporalNexusClient",
"TemporalNexusOperationStartHandlerFunc",
"TemporalNexusOperationHandler",
"TemporalOperationStartHandlerFunc",
"TemporalOperationHandler",
"TemporalOperationResult",
"temporal_operation",
)
40 changes: 19 additions & 21 deletions temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from temporalio.types import NexusServiceType

from ._operation_context import (
TemporalNexusStartOperationContext,
TemporalStartOperationContext,
WorkflowRunOperationContext,
)
from ._operation_handlers import (
TemporalNexusOperationHandler,
TemporalOperationHandler,
WorkflowRunOperationHandler,
)
from ._token import WorkflowHandle
Expand Down Expand Up @@ -145,10 +145,10 @@ async def _start(
return decorator(start)


TemporalNexusOperationStartHandlerFunc: TypeAlias = Callable[
TemporalOperationStartHandlerFunc: TypeAlias = Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalStartOperationContext,
TemporalNexusClient,
InputT,
],
Expand All @@ -158,30 +158,30 @@ async def _start(

@overload
def temporal_operation(
start: TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
) -> TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]: ...
start: TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
) -> TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]: ...


@overload
def temporal_operation(
*,
name: str | None = None,
) -> Callable[
[TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]],
TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
[TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]],
TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
]: ...


def temporal_operation(
start: None
| TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT] = None,
| TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT] = None,
*,
name: str | None = None,
) -> (
TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]
TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]
| Callable[
[TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]],
TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
[TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]],
TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
]
):
"""Decorator marking a method as the start method for an operation that interacts with Temporal.
Expand All @@ -191,10 +191,8 @@ def temporal_operation(
"""

def decorator(
start: TemporalNexusOperationStartHandlerFunc[
NexusServiceType, InputT, OutputT
],
) -> TemporalNexusOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]:
start: TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT],
) -> TemporalOperationStartHandlerFunc[NexusServiceType, InputT, OutputT]:
if not is_async_callable(start):
raise RuntimeError(
f"{start} is not an `async def` method. "
Expand All @@ -209,7 +207,7 @@ def operation_handler_factory(
self: NexusServiceType,
) -> OperationHandler[InputT, OutputT]:
async def _start(
ctx: TemporalNexusStartOperationContext,
ctx: TemporalStartOperationContext,
client: TemporalNexusClient,
input: InputT,
) -> TemporalOperationResult[OutputT]:
Expand All @@ -220,18 +218,18 @@ async def _start(
input,
)

class _TemporalNexusOperationHandler(TemporalNexusOperationHandler):
class _TemporalOperationHandler(TemporalOperationHandler):
@override
async def start_operation(
self,
ctx: TemporalNexusStartOperationContext,
ctx: TemporalStartOperationContext,
client: TemporalNexusClient,
input: InputT,
) -> TemporalOperationResult[OutputT]:
return await _start(ctx, client, input)

_TemporalNexusOperationHandler.start_operation.__doc__ = start.__doc__
return _TemporalNexusOperationHandler()
_TemporalOperationHandler.start_operation.__doc__ = start.__doc__
return _TemporalOperationHandler()

method_name = get_callable_name(start)
op = nexusrpc.Operation(
Expand Down
4 changes: 2 additions & 2 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def set(self) -> None:
_temporal_cancel_operation_context.set(self)


class TemporalNexusStartOperationContext(StartOperationContext):
class TemporalStartOperationContext(StartOperationContext):
"""Context received by a Temporal Nexus operation when it is started.

.. warning::
Expand All @@ -563,7 +563,7 @@ def _from_start_operation_context(cls, ctx: StartOperationContext) -> Self:
)


class TemporalNexusCancelOperationContext(CancelOperationContext):
class TemporalCancelOperationContext(CancelOperationContext):
"""Context received by a Temporal Nexus operation when it is canceled.

.. warning::
Expand Down
22 changes: 9 additions & 13 deletions temporalio/nexus/_operation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import temporalio.nexus
from temporalio.nexus._operation_context import (
TemporalNexusCancelOperationContext,
TemporalNexusStartOperationContext,
TemporalCancelOperationContext,
TemporalStartOperationContext,
_temporal_cancel_operation_context,
)
from temporalio.nexus._temporal_client import (
Expand Down Expand Up @@ -127,8 +127,8 @@ async def _cancel_workflow(
class CancelWorkflowRunOptions:
"""Options for cancelling the workflow backing a Nexus operation.

These options are built by :py:class:`TemporalNexusOperationHandler` and passed to
:py:meth:`TemporalNexusOperationHandler.cancel_workflow_run`.
These options are built by :py:class:`TemporalOperationHandler` and passed to
:py:meth:`TemporalOperationHandler.cancel_workflow_run`.

.. warning::
This API is experimental and unstable.
Expand All @@ -138,7 +138,7 @@ class CancelWorkflowRunOptions:
"""The ID of the workflow to cancel."""


class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT], ABC):
class TemporalOperationHandler(OperationHandler[InputT, OutputT], ABC):
"""Operation handler for Nexus operations that interact with Temporal.
Implementations override the start_operation method.

Expand All @@ -149,7 +149,7 @@ class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT], ABC):
@abstractmethod
async def start_operation(
self,
ctx: TemporalNexusStartOperationContext,
ctx: TemporalStartOperationContext,
client: TemporalNexusClient,
input: InputT,
) -> TemporalOperationResult[OutputT]:
Expand All @@ -165,9 +165,7 @@ async def start(
This API is experimental and unstable.
"""
nexus_client = _TemporalNexusClient()
start_ctx = TemporalNexusStartOperationContext._from_start_operation_context(
ctx
)
start_ctx = TemporalStartOperationContext._from_start_operation_context(ctx)
result = await self.start_operation(start_ctx, nexus_client, input)
return result._to_nexus_result()

Expand All @@ -185,9 +183,7 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
type=HandlerErrorType.INTERNAL,
) from err

cancel_ctx = TemporalNexusCancelOperationContext._from_cancel_operation_context(
ctx
)
cancel_ctx = TemporalCancelOperationContext._from_cancel_operation_context(ctx)
match operation_token.type:
case OperationTokenType.WORKFLOW:
options = CancelWorkflowRunOptions(
Expand All @@ -197,7 +193,7 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:

async def cancel_workflow_run(
self,
ctx: TemporalNexusCancelOperationContext, # pyright: ignore[reportUnusedParameter]
ctx: TemporalCancelOperationContext, # pyright: ignore[reportUnusedParameter]
options: CancelWorkflowRunOptions,
) -> None:
"""Cancels the workflow backing the Nexus operation.
Expand Down
6 changes: 3 additions & 3 deletions temporalio/nexus/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

from temporalio.nexus._operation_context import (
TemporalNexusStartOperationContext,
TemporalStartOperationContext,
WorkflowRunOperationContext,
)
from temporalio.nexus._temporal_client import (
Expand Down Expand Up @@ -55,7 +55,7 @@ def get_temporal_operation_start_method_input_and_output_type_annotations(
start: Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalStartOperationContext,
TemporalNexusClient,
InputT,
],
Expand All @@ -73,7 +73,7 @@ def get_temporal_operation_start_method_input_and_output_type_annotations(
return _get_wrapped_start_method_input_and_output_type_annotations(
start,
expected_param_types=(
TemporalNexusStartOperationContext,
TemporalStartOperationContext,
TemporalNexusClient,
),
expected_return_origin=TemporalOperationResult,
Expand Down
4 changes: 2 additions & 2 deletions temporalio/workflow/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ async def start_operation(
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Expand Down Expand Up @@ -390,7 +390,7 @@ async def execute_operation(
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Expand Down
4 changes: 2 additions & 2 deletions tests/nexus/test_handler_operation_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ def test_unsafe_narrow_context_annotations_warn_and_drop_input_type():

with pytest.warns(
UserWarning,
match="Expected parameter 1 .* TemporalNexusStartOperationContext",
match="Expected parameter 1 .* TemporalStartOperationContext",
):

class MyTemporalOpCtx(nexus.TemporalNexusStartOperationContext):
class MyTemporalOpCtx(nexus.TemporalStartOperationContext):
def custom_method(self):
raise NotImplementedError

Expand Down
16 changes: 8 additions & 8 deletions tests/nexus/test_nexus_type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import temporalio.nexus
from temporalio import workflow
from temporalio.client import Client, NexusOperationHandle
from temporalio.nexus import TemporalNexusOperationStartHandlerFunc
from temporalio.nexus import TemporalOperationStartHandlerFunc
from temporalio.service import ServiceClient


Expand Down Expand Up @@ -100,7 +100,7 @@ async def my_workflow_run_operation(
@temporalio.nexus.temporal_operation
async def my_temporal_operation(
self,
_ctx: temporalio.nexus.TemporalNexusStartOperationContext,
_ctx: temporalio.nexus.TemporalStartOperationContext,
client: temporalio.nexus.TemporalNexusClient,
input: int,
) -> temporalio.nexus.TemporalOperationResult[None]:
Expand Down Expand Up @@ -180,7 +180,7 @@ async def my_workflow_run_operation(
@temporalio.nexus.temporal_operation
async def my_temporal_operation(
self,
_ctx: temporalio.nexus.TemporalNexusStartOperationContext,
_ctx: temporalio.nexus.TemporalStartOperationContext,
_client: temporalio.nexus.TemporalNexusClient,
_input: int,
) -> temporalio.nexus.TemporalOperationResult[None]:
Expand All @@ -204,26 +204,26 @@ async def my_workflow_run_operation(
@temporalio.nexus.temporal_operation
async def my_temporal_operation(
self,
_ctx: temporalio.nexus.TemporalNexusStartOperationContext,
_ctx: temporalio.nexus.TemporalStartOperationContext,
_client: temporalio.nexus.TemporalNexusClient,
_input: int,
) -> temporalio.nexus.TemporalOperationResult[None]:
raise NotImplementedError


_handler: TemporalNexusOperationStartHandlerFunc[
_handler: TemporalOperationStartHandlerFunc[
MyServiceHandler,
int,
None,
] = MyServiceHandler.my_temporal_operation

_BadHandler: TypeAlias = temporalio.nexus.TemporalNexusOperationStartHandlerFunc[
_BadHandler: TypeAlias = temporalio.nexus.TemporalOperationStartHandlerFunc[
MyServiceHandler,
str,
None,
]

_bad_handler: TemporalNexusOperationStartHandlerFunc[
_bad_handler: TemporalOperationStartHandlerFunc[
MyServiceHandler,
str,
None,
Expand All @@ -235,7 +235,7 @@ class MyUnsafeContextAnnotationServiceHandler:
# A temporal operation receives TemporalStartOperationContext at runtime, so
# requiring an arbitrary user subclass is not safe.
class MyCustomTemporalStartOperationContext(
temporalio.nexus.TemporalNexusStartOperationContext
temporalio.nexus.TemporalStartOperationContext
):
def custom_state(self) -> str:
raise NotImplementedError
Expand Down
Loading
Loading