diff --git a/pyproject.toml b/pyproject.toml index b44459fa3..3775eb137 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,14 +99,17 @@ format = [ { cmd = "cargo fmt", cwd = "temporalio/bridge" }, ] gen-docs = "uv run scripts/gen_docs.py" +gen-nexus-system-api = "uv run scripts/gen_nexus_system_api.py" gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, + { ref = "gen-nexus-system-api" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, { ref = "format" }, ] gen-protos-docker = [ { cmd = "uv run scripts/gen_protos_docker.py" }, + { ref = "gen-nexus-system-api" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, { ref = "format" }, @@ -170,7 +173,7 @@ exclude = [ [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 -match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match_dir = "^(?!(docs|scripts|tests|api|proto|system|\\.)).*" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. # https://github.com/PyCQA/pydocstyle/issues/184 diff --git a/scripts/_nexus/deps/nexus-temporal-types/model.wit b/scripts/_nexus/deps/nexus-temporal-types/model.wit new file mode 100644 index 000000000..a951a1827 --- /dev/null +++ b/scripts/_nexus/deps/nexus-temporal-types/model.wit @@ -0,0 +1,152 @@ +/// @nexus.support +/// python="python/temporal_model_converters.py" +/// typescript="typescript/temporal_model_converters.ts" +package nexus:temporal-types@1.0.0; + +interface model { + /// String-shaped placeholder for semantic types that generators reinterpret. + type placeholder = string; + + /// @nexus.proto "temporal.api.common.v1.Payload" typescript-package="@temporalio/proto" + /// @nexus.type python="typing.Any" typescript="common.Payload" typescript-package="@temporalio/common" + type payload = placeholder; + + /// @nexus.proto "temporal.api.common.v1.Payloads" + /// typescript-package="@temporalio/proto" + type payloads = list; + + /// Callable result annotation for workflow functions. + /// @nexus.type + /// python="collections.abc.Awaitable[WorkflowResult]" + /// typescript="Promise" + type workflow-result = placeholder; + + /// Receiver/context argument for workflow callable method forms. + /// @nexus.type python="typing.Any" typescript="any" + type callable-prefix = placeholder; + + /// @nexus.function-args + /// varargs=true + /// param="args" + /// typescript-drop-prefix=true + workflow-call: async func(callable-prefix: callable-prefix, args: payloads) -> workflow-result; + + /// Callable result annotation for signal functions. + /// @nexus.type python="None | collections.abc.Awaitable[None]" typescript="void" + type signal-result = placeholder; + + /// @nexus.function-args + /// varargs=true + /// param="signal-args" + /// typescript-drop-prefix=true + signal-call: func(callable-prefix: callable-prefix, signal-args: payloads) -> signal-result; + + /// @nexus.proto "temporal.api.common.v1.WorkflowType" typescript-package="@temporalio/proto" + /// @nexus.type python="str" typescript="string" + type workflow-type = placeholder; + + /// @nexus.function + /// primary=true + /// signature="workflow-call" + /// args-field="input" + /// result-type-parameter="WorkflowResult" + /// alternate-type="workflow-type" + /// @nexus.add-rpc-compatible-with "workflow-type" + type workflow-function = placeholder; + + /// @nexus.function + /// signature="signal-call" + /// args-field="signal-input" + /// alternate-type="string" + /// python-converter="signal_function_to_proto" + /// typescript-converter="signalFunctionToProto" + /// @nexus.add-rpc-compatible-with "string" + /// @nexus.typescript-with-arguments + /// signature="signal-call" + /// args-field="signal-input" + /// alternate-type="string" + /// value-type="workflow.SignalDefinition" + /// args-type="Value extends workflow.SignalDefinition ? Args : never" + /// name-expr="value.name" + /// typescript-package="@temporalio/workflow" + type signal-function = placeholder; + + /// @nexus.proto "temporal.api.common.v1.RetryPolicy" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="temporalio.common.RetryPolicy" + /// typescript="common.RetryPolicy" + /// typescript-package="@temporalio/common" + type retry-policy = placeholder; + + /// @nexus.proto "temporal.api.taskqueue.v1.TaskQueue" typescript-package="@temporalio/proto" + /// @nexus.type python="str" typescript="string" + type task-queue = placeholder; + + /// @nexus.proto "temporal.api.common.v1.Memo" typescript-package="@temporalio/proto" + /// @nexus.type python="collections.abc.Mapping[str, typing.Any]" typescript="Record" + type memo = placeholder; + + /// @nexus.proto "temporal.api.common.v1.SearchAttributes" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="temporalio.common.TypedSearchAttributes" + /// typescript="common.TypedSearchAttributes" + /// typescript-package="@temporalio/common" + type search-attributes = placeholder; + + /// @nexus.proto "temporal.api.common.v1.Priority" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="temporalio.common.Priority" + /// typescript="common.Priority" + /// typescript-package="@temporalio/common" + type priority = placeholder; + + /// @nexus.proto "temporal.api.workflow.v1.VersioningOverride" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="temporalio.common.VersioningOverride" + /// typescript="common.VersioningOverride" + /// typescript-package="@temporalio/common" + type versioning-override = placeholder; + + /// @nexus.proto "google.protobuf.Duration" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="datetime.timedelta" + /// typescript="common.Duration" + /// typescript-package="@temporalio/common" + type duration = placeholder; + + /// @nexus.proto "temporal.api.enums.v1.WorkflowIdReusePolicy" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="temporalio.common.WorkflowIDReusePolicy" + /// typescript="common.WorkflowIdReusePolicy" + /// typescript-package="@temporalio/common" + enum workflow-id-reuse-policy { + allow-duplicate, + allow-duplicate-failed-only, + reject-duplicate, + terminate-if-running, + } + + /// @nexus.proto "temporal.api.enums.v1.WorkflowIdConflictPolicy" typescript-package="@temporalio/proto" + /// @nexus.type + /// python="temporalio.common.WorkflowIDConflictPolicy" + /// typescript="common.WorkflowIdConflictPolicy" + /// typescript-package="@temporalio/common" + enum workflow-id-conflict-policy { + fail, + use-existing, + terminate-existing, + } + + /// @nexus.proto "temporal.api.sdk.v1.UserMetadata" typescript-package="@temporalio/proto" + /// @nexus.flatten-in-api + record user-metadata { + /// @nexus.doc "Single-line fixed summary for the workflow execution that may appear in UI and CLI. This can be in single-line Temporal Markdown format." + /// @nexus.proto-field "summary" + /// @nexus.flattened-type python="str" typescript="string" + static-summary: option, + /// @nexus.doc "General fixed details for the workflow execution that may appear in UI and CLI. This can be in Temporal Markdown format and can span multiple lines. This value is fixed on the workflow execution and cannot be updated." + /// @nexus.proto-field "details" + /// @nexus.flattened-type python="str" typescript="string" + static-details: option, + } +} diff --git a/scripts/_nexus/deps/nexus-temporal-types/python/temporal_model_converters.py b/scripts/_nexus/deps/nexus-temporal-types/python/temporal_model_converters.py new file mode 100644 index 000000000..d98be0494 --- /dev/null +++ b/scripts/_nexus/deps/nexus-temporal-types/python/temporal_model_converters.py @@ -0,0 +1,197 @@ +# pyright: reportAny=false, reportExplicitAny=false + +import collections.abc +import typing +from datetime import timedelta + +import google.protobuf.duration_pb2 + +import temporalio.api.common.v1.message_pb2 as common_pb2 +import temporalio.api.enums.v1.workflow_pb2 as workflow_enums_pb2 +import temporalio.api.taskqueue.v1.message_pb2 as taskqueue_pb2 +import temporalio.api.workflow.v1 +import temporalio.common +import temporalio.converter + + +def retry_policy_from_proto( + proto: common_pb2.RetryPolicy, +) -> temporalio.common.RetryPolicy: + return temporalio.common.RetryPolicy.from_proto(proto) + + +def retry_policy_to_proto( + retry_policy: temporalio.common.RetryPolicy, +) -> common_pb2.RetryPolicy: + proto = common_pb2.RetryPolicy() + retry_policy.apply_to_proto(proto) + return proto + + +def workflow_function_name( + value: str | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> str: + from temporalio.workflow import _Definition # pyright: ignore[reportPrivateUsage] + + name, _result_type = _Definition.get_name_and_result_type(value) + return name + + +def signal_function_to_proto( + value: str | collections.abc.Callable[..., typing.Any], +) -> str: + from temporalio.workflow import ( + _SignalDefinition, # pyright: ignore[reportPrivateUsage] + ) + + return _SignalDefinition.must_name_from_fn_or_str(value) # pyright: ignore[reportUnknownMemberType] + + +def workflow_type_to_proto( + workflow_type: str + | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> common_pb2.WorkflowType: + return common_pb2.WorkflowType(name=workflow_function_name(workflow_type)) + + +def task_queue_from_proto( + proto: taskqueue_pb2.TaskQueue, +) -> str: + return proto.name + + +def task_queue_to_proto( + task_queue: str, +) -> taskqueue_pb2.TaskQueue: + return taskqueue_pb2.TaskQueue(name=task_queue) + + +def workflow_namespace() -> str: + from temporalio.workflow import info + + return info().namespace + + +def payloads_to_proto( + values: collections.abc.Sequence[typing.Any], +) -> common_pb2.Payloads: + from temporalio.workflow import payload_converter + + return payload_converter().to_payloads_wrapper(values) + + +def _clone_payload(payload: common_pb2.Payload) -> common_pb2.Payload: + clone = common_pb2.Payload() + clone.CopyFrom(payload) + return clone + + +def _value_to_payload(value: object | common_pb2.Payload) -> common_pb2.Payload: + if isinstance(value, common_pb2.Payload): + return _clone_payload(value) + from temporalio.workflow import payload_converter + + payloads = payload_converter().to_payloads_wrapper([value]) + return _clone_payload(payloads.payloads[0]) + + +def _payload_to_value(payload: common_pb2.Payload) -> object: + wrapper = common_pb2.Payloads() + wrapper.payloads.add().CopyFrom(payload) + from temporalio.workflow import payload_converter + + return typing.cast( + object, + payload_converter().from_payloads_wrapper(wrapper)[0], + ) + + +def payload_from_proto( + proto: common_pb2.Payload, +) -> object: + return _payload_to_value(proto) + + +def payload_to_proto( + payload: object, +) -> common_pb2.Payload: + return _value_to_payload(payload) + + +def memo_from_proto( + proto: common_pb2.Memo, +) -> collections.abc.Mapping[str, object]: + return {key: _payload_to_value(value) for key, value in proto.fields.items()} + + +def memo_to_proto( + memo: collections.abc.Mapping[str, object], +) -> common_pb2.Memo: + message = common_pb2.Memo() + for key, value in memo.items(): + message.fields[key].CopyFrom(_value_to_payload(value)) + return message + + +def duration_from_proto(proto: google.protobuf.duration_pb2.Duration) -> timedelta: + return proto.ToTimedelta() + + +def duration_to_proto( + duration: timedelta, +) -> google.protobuf.duration_pb2.Duration: + proto = google.protobuf.duration_pb2.Duration() + proto.FromTimedelta(duration) + return proto + + +def workflow_id_reuse_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, +) -> temporalio.common.WorkflowIDReusePolicy: + return temporalio.common.WorkflowIDReusePolicy(int(policy)) + + +def workflow_id_reuse_policy_to_proto( + policy: temporalio.common.WorkflowIDReusePolicy, +) -> workflow_enums_pb2.WorkflowIdReusePolicy.ValueType: + return typing.cast(workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, int(policy)) + + +def workflow_id_conflict_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, +) -> temporalio.common.WorkflowIDConflictPolicy: + return temporalio.common.WorkflowIDConflictPolicy(int(policy)) + + +def workflow_id_conflict_policy_to_proto( + policy: temporalio.common.WorkflowIDConflictPolicy, +) -> workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType: + return typing.cast( + workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, int(policy) + ) + + +def search_attributes_to_proto( + search_attributes: temporalio.common.TypedSearchAttributes, +) -> common_pb2.SearchAttributes: + proto = common_pb2.SearchAttributes() + temporalio.converter.encode_search_attributes(search_attributes, proto) + return proto + + +def priority_from_proto( + proto: common_pb2.Priority, +) -> temporalio.common.Priority: + return temporalio.common.Priority._from_proto(proto) # pyright: ignore[reportPrivateUsage] + + +def priority_to_proto( + priority: temporalio.common.Priority, +) -> common_pb2.Priority: + return priority._to_proto() # pyright: ignore[reportPrivateUsage] + + +def versioning_override_to_proto( + versioning_override: temporalio.common.VersioningOverride, +) -> temporalio.api.workflow.v1.VersioningOverride: + return versioning_override._to_proto() # pyright: ignore[reportPrivateUsage] diff --git a/scripts/_nexus/deps/nexus-temporal-types/typescript/temporal_model_converters.ts b/scripts/_nexus/deps/nexus-temporal-types/typescript/temporal_model_converters.ts new file mode 100644 index 000000000..2180bdb92 --- /dev/null +++ b/scripts/_nexus/deps/nexus-temporal-types/typescript/temporal_model_converters.ts @@ -0,0 +1,347 @@ +import * as common from "@temporalio/common"; +import type { google, temporal } from "@temporalio/proto"; +import * as workflow from "@temporalio/workflow"; +import type Long from "long"; + +function int64ToNumber(value: Long | number | string | object | null | undefined): number { + if (value == null) { + return 0; + } + if (typeof value === "number") { + return value; + } + if (typeof value === "string") { + return Number(value); + } + if ("toNumber" in value && typeof value.toNumber === "function") { + return value.toNumber(); + } + if ("low" in value && "high" in value) { + const longValue = value as { low: number; high: number; unsigned?: boolean }; + const low = longValue.low >>> 0; + return longValue.high * 4_294_967_296 + low; + } + throw new TypeError("unsupported int64 value"); +} + +function durationToMillis( + proto: google.protobuf.IDuration | null | undefined, +): number | undefined { + if (proto == null) { + return undefined; + } + return int64ToNumber(proto.seconds) * 1000 + Math.floor((proto.nanos ?? 0) / 1_000_000); +} + +export function retryPolicyFromProto( + proto: temporal.api.common.v1.IRetryPolicy, +): common.RetryPolicy { + return { + backoffCoefficient: proto.backoffCoefficient ?? undefined, + maximumAttempts: proto.maximumAttempts ?? undefined, + maximumInterval: durationToMillis(proto.maximumInterval), + initialInterval: durationToMillis(proto.initialInterval), + nonRetryableErrorTypes: proto.nonRetryableErrorTypes ?? undefined, + }; +} + +export function retryPolicyToProto( + retryPolicy: common.RetryPolicy, +): temporal.api.common.v1.IRetryPolicy { + return common.compileRetryPolicy(retryPolicy); +} + +export function workflowTypeFromProto( + proto: temporal.api.common.v1.IWorkflowType, +): string | common.Workflow { + return proto.name ?? ""; +} + +export function workflowTypeToProto( + workflowType: string | common.Workflow, +): temporal.api.common.v1.IWorkflowType { + return { name: workflowFunctionName(workflowType) }; +} + +export function workflowFunctionName( + value: string | common.Workflow, +): string { + return typeof value === "string" ? value : common.extractWorkflowType(value); +} + +export function signalFunctionToProto( + value: string | workflow.SignalDefinition, +): string { + return typeof value === "string" ? value : value.name; +} + +export function taskQueueFromProto( + proto: temporal.api.taskqueue.v1.ITaskQueue, +): string { + return proto.name ?? ""; +} + +export function taskQueueToProto( + taskQueue: string, +): temporal.api.taskqueue.v1.ITaskQueue { + return { name: taskQueue }; +} + +export function workflowNamespace(): string { + return workflow.workflowInfo().namespace; +} + +export function payloadFromProto( + payload: temporal.api.common.v1.IPayload, +): common.Payload { + return payload; +} + +export function payloadToProto( + payload: common.Payload, +): temporal.api.common.v1.IPayload { + return payload; +} + +function configuredPayloadConverter(): common.PayloadConverter { + const activator = ( + globalThis as typeof globalThis & { + __TEMPORAL_ACTIVATOR__?: { + payloadConverter?: common.PayloadConverter; + }; + } + ).__TEMPORAL_ACTIVATOR__; + if (activator?.payloadConverter == null) { + throw new Error("payload converter is unavailable outside workflow context"); + } + return activator.payloadConverter; +} + +export function memoFromProto( + proto: temporal.api.common.v1.IMemo, +): Record { + return ( + common.mapFromPayloads( + configuredPayloadConverter(), + proto.fields ?? undefined, + ) ?? {} + ); +} + +export function memoToProto( + memo: Record, +): temporal.api.common.v1.IMemo { + return { + fields: common.mapToPayloads(configuredPayloadConverter(), memo), + }; +} + +export function durationFromProto( + proto: google.protobuf.IDuration, +): common.Duration { + return durationToMillis(proto)!; +} + +export function durationToProto( + duration: common.Duration, +): google.protobuf.IDuration { + return common.msToTs(duration); +} + +function typedSearchAttributePayload( + value: unknown, + type: common.SearchAttributeType, +): common.Payload { + const payload = configuredPayloadConverter().toPayload(value); + payload.metadata ??= {}; + payload.metadata.type = common.u8( + common.TypedSearchAttributes.toMetadataType(type), + ); + return payload; +} + +function isValidSearchAttributeValue( + type: common.SearchAttributeType, + value: unknown, +): boolean { + switch (type) { + case common.SearchAttributeType.TEXT: + case common.SearchAttributeType.KEYWORD: + return typeof value === "string"; + case common.SearchAttributeType.INT: + return Number.isInteger(value); + case common.SearchAttributeType.DOUBLE: + return typeof value === "number"; + case common.SearchAttributeType.BOOL: + return typeof value === "boolean"; + case common.SearchAttributeType.DATETIME: + return value instanceof Date; + case common.SearchAttributeType.KEYWORD_LIST: + return ( + Array.isArray(value) && + value.every((item) => typeof item === "string") + ); + default: + return false; + } +} + +function typedSearchAttributePairFromPayload( + name: string, + payload: common.Payload, +): common.SearchAttributePair | undefined { + const metadataType = payload.metadata?.type; + if (metadataType == null) { + return undefined; + } + const type = common.TypedSearchAttributes.toSearchAttributeType( + common.str(metadataType), + ); + if (type == null) { + return undefined; + } + let value: unknown = configuredPayloadConverter().fromPayload(payload); + if ( + type !== common.SearchAttributeType.KEYWORD_LIST && + Array.isArray(value) + ) { + if (value.length !== 1) { + return undefined; + } + value = value[0]; + } + if (type === common.SearchAttributeType.DATETIME && value != null) { + value = new Date(value as string); + } + if (!isValidSearchAttributeValue(type, value)) { + return undefined; + } + return { + key: { name, type }, + value, + } as common.SearchAttributePair; +} + +export function searchAttributesFromProto( + proto: temporal.api.common.v1.ISearchAttributes, +): common.TypedSearchAttributes { + const indexedFields = proto.indexedFields ?? {}; + const typedPairs: common.SearchAttributePair[] = []; + for (const [name, payload] of Object.entries(indexedFields)) { + const pair = typedSearchAttributePairFromPayload(name, payload); + if (pair == null) { + throw new TypeError( + `search attribute ${name} cannot be decoded as a typed search attribute`, + ); + } + typedPairs.push(pair); + } + return new common.TypedSearchAttributes(typedPairs); +} + +export function searchAttributesToProto( + searchAttributes: common.TypedSearchAttributes, +): temporal.api.common.v1.ISearchAttributes { + return { + indexedFields: Object.fromEntries( + searchAttributes.getAll().map((pair): [string, common.Payload] => [ + pair.key.name, + typedSearchAttributePayload(pair.value, pair.key.type), + ]), + ), + }; +} + +export function priorityFromProto( + proto: temporal.api.common.v1.IPriority, +): common.Priority { + return common.decodePriority(proto); +} + +export function priorityToProto( + priority: common.Priority, +): temporal.api.common.v1.IPriority { + return common.compilePriority(priority); +} + +const VERSIONING_BEHAVIOR_PINNED = 1; +const VERSIONING_BEHAVIOR_AUTO_UPGRADE = 2; +const PINNED_OVERRIDE_BEHAVIOR_PINNED = 1; + +export function versioningOverrideFromProto( + proto: temporal.api.workflow.v1.IVersioningOverride, +): common.VersioningOverride | undefined { + if ( + proto.autoUpgrade || + proto.behavior === VERSIONING_BEHAVIOR_AUTO_UPGRADE + ) { + return "AUTO_UPGRADE"; + } + const pinnedVersion = proto.pinned?.version; + if (pinnedVersion?.deploymentName != null && pinnedVersion.buildId != null) { + return { + pinnedTo: { + deploymentName: pinnedVersion.deploymentName, + buildId: pinnedVersion.buildId, + }, + }; + } + if ( + proto.deployment?.seriesName != null && + proto.deployment.buildId != null + ) { + return { + pinnedTo: { + deploymentName: proto.deployment.seriesName, + buildId: proto.deployment.buildId, + }, + }; + } + return undefined; +} + +export function versioningOverrideToProto( + versioningOverride: common.VersioningOverride, +): temporal.api.workflow.v1.IVersioningOverride { + if (versioningOverride === "AUTO_UPGRADE") { + return { + behavior: VERSIONING_BEHAVIOR_AUTO_UPGRADE, + autoUpgrade: true, + }; + } + return { + behavior: VERSIONING_BEHAVIOR_PINNED, + pinnedVersion: common.toCanonicalString(versioningOverride.pinnedTo), + pinned: { + behavior: PINNED_OVERRIDE_BEHAVIOR_PINNED, + version: { + deploymentName: versioningOverride.pinnedTo.deploymentName, + buildId: versioningOverride.pinnedTo.buildId, + }, + }, + }; +} + +export function workflowIdReusePolicyFromProto( + policy: temporal.api.enums.v1.WorkflowIdReusePolicy, +): common.WorkflowIdReusePolicy | undefined { + return common.decodeWorkflowIdReusePolicy(policy); +} + +export function workflowIdReusePolicyToProto( + policy: common.WorkflowIdReusePolicy, +): temporal.api.enums.v1.WorkflowIdReusePolicy | undefined { + return common.encodeWorkflowIdReusePolicy(policy); +} + +export function workflowIdConflictPolicyFromProto( + policy: temporal.api.enums.v1.WorkflowIdConflictPolicy, +): common.WorkflowIdConflictPolicy | undefined { + return common.decodeWorkflowIdConflictPolicy(policy); +} + +export function workflowIdConflictPolicyToProto( + policy: common.WorkflowIdConflictPolicy, +): temporal.api.enums.v1.WorkflowIdConflictPolicy | undefined { + return common.encodeWorkflowIdConflictPolicy(policy); +} diff --git a/scripts/_nexus/temporal-system.wit b/scripts/_nexus/temporal-system.wit new file mode 100644 index 000000000..1b7f10656 --- /dev/null +++ b/scripts/_nexus/temporal-system.wit @@ -0,0 +1,119 @@ +package temporal:nexus@1.0.0; + +world system { + export workflow-service; +} + +/// @nexus.endpoint "__temporal_system" +/// @nexus.service-name "temporal.api.workflowservice.v1.WorkflowService" +/// @nexus.delay-load-temporalio-workflow +/// @nexus.experimental +interface workflow-service { + use nexus:temporal-types/model@1.0.0.{ + duration, + memo, + payloads, + placeholder, + priority, + retry-policy, + search-attributes, + signal-function, + task-queue, + user-metadata, + versioning-override, + workflow-function, + workflow-id-conflict-policy, + workflow-id-reuse-policy, + }; + + /// @nexus.doc "Request fields for signaling a workflow, starting it first if needed." + /// @nexus.experimental + /// @nexus.proto "temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" typescript-package="@temporalio/proto" + record signal-with-start-workflow-request { + /// @nexus.doc + /// python="Workflow type name or callable identifying the workflow to start." + /// typescript="Workflow type name or workflow function identifying the workflow to start." + /// @nexus.proto-field "workflow_type" + workflow: workflow-function, + /// @nexus.doc "Unique identifier for the workflow execution." + /// @nexus.proto-field "workflow_id" + id: string, + /// @nexus.doc "Task queue to run the workflow on." + task-queue: task-queue, + /// @nexus.doc + /// python="Signal name or callable to send with the start request." + /// typescript="Signal name or signal definition to send with the start request." + /// @nexus.proto-field "signal_name" + signal: signal-function, + /// @nexus.doc "Total workflow execution timeout, including retries and continue-as-new." + /// @nexus.proto-field "workflow_execution_timeout" + execution-timeout: option, + /// @nexus.doc "Timeout of a single workflow run." + /// @nexus.proto-field "workflow_run_timeout" + run-timeout: option, + /// @nexus.doc "Timeout of a single workflow task." + /// @nexus.proto-field "workflow_task_timeout" + task-timeout: option, + /// @nexus.omit + identity: placeholder, + /// @nexus.doc "Request ID used to deduplicate workflow start requests." + request-id: option, + /// @nexus.doc "Behavior when a closed workflow with the same ID exists. Default is allow-duplicate." + /// @nexus.proto-field "workflow_id_reuse_policy" + /// @nexus.default "allow-duplicate" + id-reuse-policy: workflow-id-reuse-policy, + /// @nexus.doc "Behavior when a workflow is currently running with the same ID. Set to use-existing for idempotent deduplication on workflow ID. Cannot be set if id-reuse-policy is terminate-if-running." + /// @nexus.proto-field "workflow_id_conflict_policy" + id-conflict-policy: option, + /// @nexus.doc "Retry policy for the workflow." + retry-policy: option, + /// @nexus.doc "Cron schedule for recurring workflow executions. See https://docs.temporal.io/cron-job." + cron-schedule: option, + /// @nexus.doc "Memo for the workflow." + memo: option, + /// @nexus.doc "Typed search attributes for the workflow." + search-attributes: option, + /// @nexus.doc "Priority of the workflow execution." + priority: option, + /// @nexus.doc "Override for workflow versioning behavior." + versioning-override: option, + /// @nexus.doc "Amount of time to wait before starting the workflow. This does not work with cron-schedule." + /// @nexus.proto-field "workflow_start_delay" + start-delay: option, + user-metadata: option, + /// @nexus.source python="workflow_namespace" typescript="workflowNamespace" + namespace: string, + /// @nexus.omit + control: placeholder, + /// @nexus.omit + header: placeholder, + /// @nexus.omit + links: placeholder, + /// @nexus.omit + time-skipping-config: placeholder, + } + + /// @nexus.experimental + /// @nexus.proto "temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse" typescript-package="@temporalio/proto" + record signal-with-start-workflow-response { + run-id: option, + started: option, + /// @nexus.omit + signal-link: placeholder, + } + + /// @nexus.doc + /// "Signal a workflow, starting it first if needed." + /// returns="A workflow handle to the started workflow." + /// @nexus.output-transform + /// python-type="temporalio.workflow.ExternalWorkflowHandle[WorkflowResult]" + /// python="temporalio.workflow.get_external_workflow_handle(request.id, run_id=result.run_id)" + /// typescript-type="workflow.ExternalWorkflowHandle" + /// typescript="workflow.getExternalWorkflowHandle(request.id, result.runId ?? undefined)" + /// typescript-package="@temporalio/workflow" + /// @nexus.operation name="SignalWithStartWorkflowExecution" + /// @nexus.experimental + signal-with-start-workflow: func( + request: signal-with-start-workflow-request, + ) -> signal-with-start-workflow-response; +} diff --git a/scripts/gen_nexus_system_api.py b/scripts/gen_nexus_system_api.py new file mode 100644 index 000000000..cad236014 --- /dev/null +++ b/scripts/gen_nexus_system_api.py @@ -0,0 +1,159 @@ +import os +import shutil +import subprocess +import sys +import tempfile +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from typing import cast + +import gen_protos + +base_dir = Path(__file__).parent.parent +sys.path.insert(0, str(base_dir)) +wit_input_dir = base_dir / "scripts" / "_nexus" +wit_path = wit_input_dir / "temporal-system.wit" +wit_deps_dir = wit_input_dir / "deps" +output_dir = base_dir / "temporalio" / "nexus" / "system" / "workflow_service" +workflow_init_path = base_dir / "temporalio" / "workflow" / "__init__.py" +workflowservice_request_response_proto = ( + gen_protos.api_proto_dir + / "temporal" + / "api" + / "workflowservice" + / "v1" + / "request_response.proto" +) + + +def nex_gen_command() -> list[str]: + if bin_path := os.environ.get("NEX_GEN_BIN"): + return [bin_path] + + if shutil.which("nex-gen") is None: + subprocess.check_call(["cargo", "install", "--locked", "nex-gen"]) + return ["nex-gen"] + + +def build_descriptor_set(descriptor_path: Path) -> None: + subprocess.check_call( + [ + sys.executable, + "-mgrpc_tools.protoc", + f"--proto_path={gen_protos.api_proto_dir}", + f"--proto_path={gen_protos.proto_dir}", + "--include_imports", + f"--descriptor_set_out={descriptor_path}", + str(workflowservice_request_response_proto), + ] + ) + + +def strip_unsupported_pyright_comments() -> None: + for path in output_dir.rglob("*.py"): + content = path.read_text() + content = content.replace("# pyright: reportAny=false\n", "") + content = content.replace( + "# pyright: reportAny=false, reportExplicitAny=false\n", "" + ) + path.write_text(content) + + +def generate_workflow_exports() -> None: + spec = spec_from_file_location( + "temporalio_nexus_system_workflow_service_exports", + output_dir / "__init__.py", + submodule_search_locations=[str(output_dir)], + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load generated workflow service from {output_dir}") + module = module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + exports = cast(list[str], module.__all__) + + import_block = [ + "# BEGIN GENERATED NEXUS SYSTEM EXPORTS\n", + "from temporalio.nexus.system.workflow_service import (\n", + *[f" {export},\n" for export in exports], + ")\n", + "# END GENERATED NEXUS SYSTEM EXPORTS\n", + ] + all_block = [ + " # BEGIN GENERATED NEXUS SYSTEM __ALL__\n", + *[f' "{export}",\n' for export in exports], + " # END GENERATED NEXUS SYSTEM __ALL__\n", + ] + content = workflow_init_path.read_text() + start = content.index("# BEGIN GENERATED NEXUS SYSTEM EXPORTS") + end = content.index("# END GENERATED NEXUS SYSTEM EXPORTS", start) + end = content.index("\n", end) + 1 + content = content[:start] + "".join(import_block) + content[end:] + start = content.index(" # BEGIN GENERATED NEXUS SYSTEM __ALL__") + end = content.index(" # END GENERATED NEXUS SYSTEM __ALL__", start) + end = content.index("\n", end) + 1 + workflow_init_path.write_text(content[:start] + "".join(all_block) + content[end:]) + + +def generate_nexus_system_api() -> None: + if not wit_path.exists(): + raise RuntimeError(f"missing WIT source: {wit_path}") + if not wit_deps_dir.exists(): + raise RuntimeError(f"missing WIT dependency directory: {wit_deps_dir}") + + with tempfile.TemporaryDirectory(dir=base_dir) as temp_dir: + descriptor_path = Path(temp_dir) / "temporal_api.bin" + build_descriptor_set(descriptor_path) + command = nex_gen_command() + + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.parent.mkdir(parents=True, exist_ok=True) + subprocess.check_call( + [ + *command, + "generate", + "--lang", + "python", + "--input", + str(wit_path), + "--input", + str(wit_deps_dir), + "--descriptors", + str(descriptor_path), + "--output", + str(output_dir), + ] + ) + + (output_dir.parent / "__init__.py").touch() + strip_unsupported_pyright_comments() + generate_workflow_exports() + subprocess.check_call( + [ + sys.executable, + "-m", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_dir), + str(workflow_init_path), + ] + ) + subprocess.check_call( + [ + sys.executable, + "-m", + "ruff", + "format", + str(output_dir), + str(workflow_init_path), + ] + ) + + +if __name__ == "__main__": + print("Generating Nexus system API...", file=sys.stderr) + generate_nexus_system_api() + print("Done", file=sys.stderr) diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index 928be03e5..4e0d780aa 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -1,9 +1,16 @@ import subprocess import sys +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path +from typing import cast +import google.protobuf.message +import nexusrpc from google.protobuf.descriptor import Descriptor, FieldDescriptor +base_dir = Path(__file__).parent.parent +sys.path.insert(0, str(base_dir)) + from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( WorkflowActivation, @@ -12,7 +19,36 @@ WorkflowActivationCompletion, ) -base_dir = Path(__file__).parent.parent + +def discover_system_nexus_roots() -> list[Descriptor]: + module_path = ( + base_dir / "temporalio" / "nexus" / "system" / "workflow_service" / "service.py" + ) + spec = spec_from_file_location( + "temporalio_nexus_system_workflow_service", module_path + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load generated system service from {module_path}") + module = module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + roots: list[Descriptor] = [] + for operation in vars(module.WorkflowService).values(): + if not isinstance(operation, nexusrpc.Operation): + continue + for proto_type in (operation.input_type, operation.output_type): + if isinstance(proto_type, type) and issubclass( + proto_type, google.protobuf.message.Message + ): + roots.append(cast(Descriptor, proto_type.DESCRIPTOR)) + deduped: list[Descriptor] = [] + seen: set[str] = set() + for root in roots: + if root.full_name not in seen: + seen.add(root.full_name) + deduped.append(root) + return deduped def name_for(desc: Descriptor) -> str: @@ -80,28 +116,15 @@ def generate(self, roots: list[Descriptor]) -> str: self.walk(r) header = """ +from __future__ import annotations + # This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc import asyncio -from typing import Any, MutableSequence +from typing import Any +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload - - -class VisitorFunctions(abc.ABC): - \"\"\"Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - \"\"\" - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - \"\"\"Called when encountering a single payload.\"\"\" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - \"\"\"Called when encountering multiple payloads together.\"\"\" - raise NotImplementedError() +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions class _BoundedVisitorFunctions(VisitorFunctions): @@ -126,7 +149,7 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: await self._sem.acquire() async def _run() -> None: @@ -137,6 +160,9 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + await self._inner.visit_system_nexus_envelope(payload) + async def drain(self) -> None: \"\"\"Wait for all in-flight background tasks to complete. @@ -199,6 +225,28 @@ async def visit( finally: await bounded.drain() + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + """ return header + "\n".join(self.methods) @@ -212,15 +260,15 @@ def __init__(self): self.in_progress: set[str] = set() self.methods: list[str] = [ """\ - async def _visit_temporal_api_common_v1_Payload(self, fs, o): + async def _visit_temporal_api_common_v1_Payload(self, fs: VisitorFunctions, o: Payload): await fs.visit_payload(o) """, """\ - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + async def _visit_temporal_api_common_v1_Payloads(self, fs: VisitorFunctions, o: Any): await fs.visit_payloads(o.payloads) """, """\ - async def _visit_payload_container(self, fs, o): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) """, ] @@ -275,6 +323,22 @@ def walk(self, desc: Descriptor) -> bool: # Process regular fields first for field in regular_fields: + if ( + desc.full_name == "coresdk.workflow_commands.ScheduleNexusOperation" + and field.name == "input" + ): + has_payload = True + emit_items.append( + ( + "system_nexus", + field.name, + "o.service", + "o.operation", + "o.input", + ) + ) + continue + # Repeated fields (including maps which are represented as repeated messages) if field.label == FieldDescriptor.LABEL_REPEATED: if ( @@ -359,7 +423,10 @@ def walk(self, desc: Descriptor) -> bool: self.in_progress.discard(key) if has_payload: - lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] + lines: list[str] = [ + f" async def _visit_{name_for(desc)}" + "(self, fs: VisitorFunctions, o: Any):" + ] if is_search_attrs: lines.append(" if self.skip_search_attributes:") lines.append(" return") @@ -375,6 +442,14 @@ def walk(self, desc: Descriptor) -> bool: field_name, access_expr, child_method, presence_word ) ) + elif item[0] == "system_nexus": + _, field_name, service_expr, operation_expr, payload_expr = item + lines.append( + f' if o.HasField("{field_name}"):\n' + " await self._visit_system_nexus_payload(\n" + f" fs, {service_expr}, {operation_expr}, {payload_expr}\n" + " )" + ) else: # oneof_group for field_name, access_expr, child_method, presence_word in item[1]: lines.append( @@ -387,8 +462,7 @@ def walk(self, desc: Descriptor) -> bool: return has_payload -def write_generated_visitors_into_visitor_generated_py() -> None: - """Write the generated visitor code into _visitor.py.""" +def write_bridge_visitors() -> None: out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, @@ -402,7 +476,41 @@ def write_generated_visitors_into_visitor_generated_py() -> None: out_path.write_text(code) +def write_system_nexus_payload_visitors() -> None: + out_path = base_dir / "temporalio" / "nexus" / "system" / "_payload_visitor.py" + code = VisitorGenerator().generate(discover_system_nexus_roots()) + out_path.write_text(code) + + if __name__ == "__main__": print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr) - write_generated_visitors_into_visitor_generated_py() - subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"]) + write_bridge_visitors() + print("Generating temporalio/nexus/system/_payload_visitor.py...", file=sys.stderr) + write_system_nexus_payload_visitors() + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 0f030ac01..40bbceb8f 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,25 +1,12 @@ +from __future__ import annotations + # This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc import asyncio -from typing import Any, MutableSequence +from typing import Any +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload - - -class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - """ - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - """Called when encountering a single payload.""" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - """Called when encountering multiple payloads together.""" - raise NotImplementedError() +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions class _BoundedVisitorFunctions(VisitorFunctions): @@ -44,7 +31,7 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: await self._sem.acquire() async def _run() -> None: @@ -55,6 +42,9 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + await self._inner.visit_system_nexus_envelope(payload) + async def drain(self) -> None: """Wait for all in-flight background tasks to complete. @@ -117,36 +107,72 @@ async def visit(self, fs: VisitorFunctions, root: Any) -> None: finally: await bounded.drain() - async def _visit_temporal_api_common_v1_Payload(self, fs, o): + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): await fs.visit_payload(o) - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): await fs.visit_payloads(o.payloads) - async def _visit_payload_container(self, fs, o): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) - async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_CanceledFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_Failure(self, fs, o): + async def _visit_temporal_api_failure_v1_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("encoded_attributes"): await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) if o.HasField("cause"): @@ -168,17 +194,21 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): fs, o.reset_workflow_failure_info ) - async def _visit_temporal_api_common_v1_Memo(self, fs, o): + async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: Any + ): if self.skip_search_attributes: return for v in o.indexed_fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_InitializeWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): @@ -196,31 +226,43 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): fs, o.search_attributes ) - async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_QueryWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_SignalWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_activity_result_Success(self, fs, o): + async def _visit_coresdk_activity_result_Success( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_activity_result_Failure(self, fs, o): + async def _visit_coresdk_activity_result_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_Cancellation(self, fs, o): + async def _visit_coresdk_activity_result_Cancellation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): + async def _visit_coresdk_activity_result_ActivityResolution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_coresdk_activity_result_Success(fs, o.completed) elif o.HasField("failed"): @@ -228,37 +270,43 @@ async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): elif o.HasField("cancelled"): await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveActivity( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("cancelled"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( fs, o.cancelled ) - async def _visit_coresdk_child_workflow_Success(self, fs, o): + async def _visit_coresdk_child_workflow_Success(self, fs: VisitorFunctions, o: Any): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_child_workflow_Failure(self, fs, o): + async def _visit_coresdk_child_workflow_Failure(self, fs: VisitorFunctions, o: Any): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_Cancellation(self, fs, o): + async def _visit_coresdk_child_workflow_Cancellation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): + async def _visit_coresdk_child_workflow_ChildWorkflowResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_coresdk_child_workflow_Success(fs, o.completed) elif o.HasField("failed"): @@ -267,36 +315,40 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): + async def _visit_coresdk_workflow_activation_DoUpdate( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): + async def _visit_coresdk_nexus_NexusOperationResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) elif o.HasField("failed"): @@ -306,11 +358,15 @@ async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): elif o.HasField("timed_out"): await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) - async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) - async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivationJob( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("initialize_workflow"): await self._visit_coresdk_workflow_activation_InitializeWorkflow( fs, o.initialize_workflow @@ -354,42 +410,56 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): fs, o.resolve_nexus_operation ) - async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivation( + self, fs: VisitorFunctions, o: Any + ): for v in o.jobs: await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) - async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("summary"): await self._visit_temporal_api_common_v1_Payload(fs, o.summary) if o.HasField("details"): await self._visit_temporal_api_common_v1_Payload(fs, o.details) - async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleActivity( + self, fs: VisitorFunctions, o: Any + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): + async def _visit_coresdk_workflow_commands_QuerySuccess( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("response"): await self._visit_temporal_api_common_v1_Payload(fs, o.response) - async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): + async def _visit_coresdk_workflow_commands_QueryResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("succeeded"): await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) elif o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_FailWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): @@ -402,7 +472,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): @@ -415,42 +487,52 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, ) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): await self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( + self, fs: VisitorFunctions, o: Any + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("search_attributes"): await self._visit_temporal_api_common_v1_SearchAttributes( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("upserted_memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) - async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): + async def _visit_coresdk_workflow_commands_UpdateResponse( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("rejected"): await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) elif o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) - async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("input"): - await self._visit_temporal_api_common_v1_Payload(fs, o.input) + await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input) - async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): + async def _visit_coresdk_workflow_commands_WorkflowCommand( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): @@ -502,16 +584,20 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): fs, o.schedule_nexus_operation ) - async def _visit_coresdk_workflow_completion_Success(self, fs, o): + async def _visit_coresdk_workflow_completion_Success( + self, fs: VisitorFunctions, o: Any + ): for v in o.commands: await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) - async def _visit_coresdk_workflow_completion_Failure(self, fs, o): + async def _visit_coresdk_workflow_completion_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) diff --git a/temporalio/bridge/_visitor_functions.py b/temporalio/bridge/_visitor_functions.py new file mode 100644 index 000000000..6014f8e75 --- /dev/null +++ b/temporalio/bridge/_visitor_functions.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Protocol + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer + +from temporalio.api.common.v1.message_pb2 import Payload + +PayloadSequence = list[Payload] | RepeatedCompositeFieldContainer[Payload] + + +class VisitorFunctions(Protocol): + """Functions invoked by generated payload visitors.""" + + async def visit_payload(self, payload: Payload) -> None: + """Visit a single payload.""" + ... + + async def visit_payloads(self, payloads: PayloadSequence) -> None: + """Visit a sequence of payloads together.""" + ... + + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + """Visit a recognized system Nexus envelope payload.""" + return None diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index a9c857373..e1e23dd89 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -5,7 +5,7 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, MutableSequence, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from typing import ( TypeAlias, @@ -22,7 +22,7 @@ import temporalio.converter import temporalio.converter._extstore from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.bridge._visitor import VisitorFunctions +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) @@ -281,15 +281,20 @@ async def finalize_shutdown(self) -> None: class _Visitor(VisitorFunctions): - def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]]): + def __init__( + self, + f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + visit_system_nexus_envelope: Callable[[Payload], Awaitable[None]] | None = None, + ): self._f = f + self._visit_system_nexus_envelope = visit_system_nexus_envelope async def visit_payload(self, payload: Payload) -> None: new_payload = (await self._f([payload]))[0] if new_payload is not payload: payload.CopyFrom(new_payload) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: if len(payloads) == 0: return new_payloads = await self._f(payloads) @@ -298,6 +303,10 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: del payloads[:] payloads.extend(new_payloads) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + if self._visit_system_nexus_envelope is not None: + await self._visit_system_nexus_envelope(payload) + async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, @@ -339,10 +348,20 @@ async def encode_completion( Returns: Metrics from any external storage store operations that occurred. """ + + async def _validate_system_nexus_envelope(payload: Payload) -> None: + data_converter._validate_payload_limits([payload]) + await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers, - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + ).visit( + _Visitor( + data_converter._encode_payload_sequence, + visit_system_nexus_envelope=_validate_system_nexus_envelope, + ), + completion, + ) async def _store_and_validate( payloads: Sequence[Payload], @@ -357,6 +376,12 @@ async def _store_and_validate( skip_search_attributes=True, skip_headers=not encode_headers, concurrency_limit=storage_concurrency_limit, - ).visit(_Visitor(_store_and_validate), completion) + ).visit( + _Visitor( + _store_and_validate, + visit_system_nexus_envelope=_validate_system_nexus_envelope, + ), + completion, + ) return metrics diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py new file mode 100644 index 000000000..0187e0291 --- /dev/null +++ b/temporalio/nexus/system/__init__.py @@ -0,0 +1,74 @@ +"""System Nexus operation helpers.""" + +from __future__ import annotations + +import typing + +import google.protobuf.message +import nexusrpc + +import temporalio.api.common.v1 +import temporalio.converter +from temporalio.bridge._visitor_functions import VisitorFunctions +from temporalio.converter import BinaryProtoPayloadConverter, CompositePayloadConverter +from temporalio.nexus.system import workflow_service + + +class SystemNexusPayloadConverter(CompositePayloadConverter): + """Payload converter for system Nexus outer envelopes.""" + + def __init__(self) -> None: + """Create a payload converter for system Nexus outer envelopes.""" + super().__init__(BinaryProtoPayloadConverter()) + + +def _operation( + service: str, operation: str +) -> nexusrpc.Operation[typing.Any, typing.Any] | None: + return workflow_service.__nexus_operation_registry__.get((service, operation)) + + +async def visit_payload( + service: str, + operation: str, + payload: temporalio.api.common.v1.Payload, + visitor_functions: VisitorFunctions, + skip_search_attributes: bool, +) -> temporalio.api.common.v1.Payload | None: + """Visit nested payloads inside a recognized system Nexus envelope.""" + operation_def = _operation(service, operation) + if operation_def is None: + return None + input_type = operation_def.input_type + if not ( + isinstance(input_type, type) + and issubclass(input_type, google.protobuf.message.Message) + ): + return None + + payload_converter = get_payload_converter() + value = payload_converter.from_payload(payload, input_type) + from ._payload_visitor import PayloadVisitor + + await PayloadVisitor(skip_search_attributes=skip_search_attributes).visit( + visitor_functions, value + ) + return payload_converter.to_payload(value) + + +def is_system_operation(service: str, operation: str) -> bool: + """Return whether a Nexus operation uses a generated system envelope.""" + return _operation(service, operation) is not None + + +def get_payload_converter() -> temporalio.converter.PayloadConverter: + """Return the fixed payload converter for system Nexus outer envelopes.""" + return SystemNexusPayloadConverter() + + +__all__ = [ + "get_payload_converter", + "is_system_operation", + "SystemNexusPayloadConverter", + "visit_payload", +] diff --git a/temporalio/nexus/system/_payload_visitor.py b/temporalio/nexus/system/_payload_visitor.py new file mode 100644 index 000000000..ecc51e2c4 --- /dev/null +++ b/temporalio/nexus/system/_payload_visitor.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +# This file is generated by gen_payload_visitor.py. Changes should be made there. +import asyncio +from typing import Any + +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions + + +class _BoundedVisitorFunctions(VisitorFunctions): + """Wraps VisitorFunctions to cap concurrent payload visits via a semaphore. + + After the full traversal, call drain() to await all in-flight tasks. + """ + + def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None: + self._inner = inner + self._sem = sem + self._tasks: list[asyncio.Task[None]] = [] + + async def visit_payload(self, payload: Payload) -> None: + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payload(payload) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) + + async def visit_payloads(self, payloads: PayloadSequence) -> None: + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payloads(payloads) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) + + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + await self._inner.visit_system_nexus_envelope(payload) + + async def drain(self) -> None: + """Wait for all in-flight background tasks to complete. + + On cancellation or error, cancels all remaining tasks and awaits + them so their finally blocks run before this coroutine returns. + """ + if not self._tasks: + return + try: + await asyncio.gather(*self._tasks) + except BaseException: + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) + raise + + +class PayloadVisitor: + """A visitor for payloads. + Applies a function to every payload in a tree of messages. + """ + + def __init__( + self, + *, + skip_search_attributes: bool = False, + skip_headers: bool = False, + concurrency_limit: int = 1, + ): + """Creates a new payload visitor. + + Args: + skip_search_attributes: If True, search attributes are not visited. + skip_headers: If True, headers are not visited. + concurrency_limit: Maximum number of payload visits that may run + concurrently during a single call to visit(). Defaults to 1 + (sequential). + """ + if concurrency_limit < 1: + raise ValueError("concurrency_limit must be positive") + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + self._concurrency_limit = concurrency_limit + + async def visit(self, fs: VisitorFunctions, root: Any) -> None: + """Visits the given root message with the given function.""" + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is None: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + if self._concurrency_limit == 1: + await method(fs, root) + return + + bounded = _BoundedVisitorFunctions( + fs, asyncio.Semaphore(self._concurrency_limit) + ) + try: + await method(bounded, root) + finally: + await bounded.drain() + + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): + await fs.visit_payload(o) + + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): + await fs.visit_payloads(o.payloads) + + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): + await fs.visit_payloads(o) + + async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: Any + ): + if self.skip_search_attributes: + return + for v in o.indexed_fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_Header(self, fs: VisitorFunctions, o: Any): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, fs: VisitorFunctions, o: Any + ): + if o.HasField("summary"): + await self._visit_temporal_api_common_v1_Payload(fs, o.summary) + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payload(fs, o.details) + + async def _visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest( + self, fs: VisitorFunctions, o: Any + ): + if o.HasField("input"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.input) + if o.HasField("signal_input"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.signal_input) + if o.HasField("memo"): + await self._visit_temporal_api_common_v1_Memo(fs, o.memo) + if o.HasField("search_attributes"): + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) + if o.HasField("header"): + await self._visit_temporal_api_common_v1_Header(fs, o.header) + if o.HasField("user_metadata"): + await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) diff --git a/temporalio/nexus/system/workflow_service/__init__.py b/temporalio/nexus/system/workflow_service/__init__.py new file mode 100644 index 000000000..7c24fa125 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/__init__.py @@ -0,0 +1,18 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +from . import service as _service +from .operations.signal_with_start_workflow import signal_with_start_workflow + +__all__ = [ + "signal_with_start_workflow", +] + + +__nexus_operation_registry__ = { + ( + "temporal.api.workflowservice.v1.WorkflowService", + "SignalWithStartWorkflowExecution", + ): _service.WorkflowService.signal_with_start_workflow, +} diff --git a/temporalio/nexus/system/workflow_service/_resources/__init__.py b/temporalio/nexus/system/workflow_service/_resources/__init__.py new file mode 100644 index 000000000..373efbd33 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/_resources/__init__.py @@ -0,0 +1,5 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +__all__ = [] diff --git a/temporalio/nexus/system/workflow_service/_support/__init__.py b/temporalio/nexus/system/workflow_service/_support/__init__.py new file mode 100644 index 000000000..166261d16 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/_support/__init__.py @@ -0,0 +1,5 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +from .temporal_model_converters import * # noqa: F401,F403 diff --git a/temporalio/nexus/system/workflow_service/_support/temporal_model_converters.py b/temporalio/nexus/system/workflow_service/_support/temporal_model_converters.py new file mode 100644 index 000000000..58b51e263 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/_support/temporal_model_converters.py @@ -0,0 +1,195 @@ +import collections.abc +import typing +from datetime import timedelta + +import google.protobuf.duration_pb2 + +import temporalio.api.common.v1.message_pb2 as common_pb2 +import temporalio.api.enums.v1.workflow_pb2 as workflow_enums_pb2 +import temporalio.api.taskqueue.v1.message_pb2 as taskqueue_pb2 +import temporalio.api.workflow.v1 +import temporalio.common +import temporalio.converter + + +def retry_policy_from_proto( + proto: common_pb2.RetryPolicy, +) -> temporalio.common.RetryPolicy: + return temporalio.common.RetryPolicy.from_proto(proto) + + +def retry_policy_to_proto( + retry_policy: temporalio.common.RetryPolicy, +) -> common_pb2.RetryPolicy: + proto = common_pb2.RetryPolicy() + retry_policy.apply_to_proto(proto) + return proto + + +def workflow_function_name( + value: str | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> str: + from temporalio.workflow import _Definition # pyright: ignore[reportPrivateUsage] + + name, _result_type = _Definition.get_name_and_result_type(value) + return name + + +def signal_function_to_proto( + value: str | collections.abc.Callable[..., typing.Any], +) -> str: + from temporalio.workflow import ( + _SignalDefinition, # pyright: ignore[reportPrivateUsage] + ) + + return _SignalDefinition.must_name_from_fn_or_str(value) # pyright: ignore[reportUnknownMemberType] + + +def workflow_type_to_proto( + workflow_type: str + | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> common_pb2.WorkflowType: + return common_pb2.WorkflowType(name=workflow_function_name(workflow_type)) + + +def task_queue_from_proto( + proto: taskqueue_pb2.TaskQueue, +) -> str: + return proto.name + + +def task_queue_to_proto( + task_queue: str, +) -> taskqueue_pb2.TaskQueue: + return taskqueue_pb2.TaskQueue(name=task_queue) + + +def workflow_namespace() -> str: + from temporalio.workflow import info + + return info().namespace + + +def payloads_to_proto( + values: collections.abc.Sequence[typing.Any], +) -> common_pb2.Payloads: + from temporalio.workflow import payload_converter + + return payload_converter().to_payloads_wrapper(values) + + +def _clone_payload(payload: common_pb2.Payload) -> common_pb2.Payload: + clone = common_pb2.Payload() + clone.CopyFrom(payload) + return clone + + +def _value_to_payload(value: object | common_pb2.Payload) -> common_pb2.Payload: + if isinstance(value, common_pb2.Payload): + return _clone_payload(value) + from temporalio.workflow import payload_converter + + payloads = payload_converter().to_payloads_wrapper([value]) + return _clone_payload(payloads.payloads[0]) + + +def _payload_to_value(payload: common_pb2.Payload) -> object: + wrapper = common_pb2.Payloads() + wrapper.payloads.add().CopyFrom(payload) + from temporalio.workflow import payload_converter + + return typing.cast( + object, + payload_converter().from_payloads_wrapper(wrapper)[0], + ) + + +def payload_from_proto( + proto: common_pb2.Payload, +) -> object: + return _payload_to_value(proto) + + +def payload_to_proto( + payload: object, +) -> common_pb2.Payload: + return _value_to_payload(payload) + + +def memo_from_proto( + proto: common_pb2.Memo, +) -> collections.abc.Mapping[str, object]: + return {key: _payload_to_value(value) for key, value in proto.fields.items()} + + +def memo_to_proto( + memo: collections.abc.Mapping[str, object], +) -> common_pb2.Memo: + message = common_pb2.Memo() + for key, value in memo.items(): + message.fields[key].CopyFrom(_value_to_payload(value)) + return message + + +def duration_from_proto(proto: google.protobuf.duration_pb2.Duration) -> timedelta: + return proto.ToTimedelta() + + +def duration_to_proto( + duration: timedelta, +) -> google.protobuf.duration_pb2.Duration: + proto = google.protobuf.duration_pb2.Duration() + proto.FromTimedelta(duration) + return proto + + +def workflow_id_reuse_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, +) -> temporalio.common.WorkflowIDReusePolicy: + return temporalio.common.WorkflowIDReusePolicy(int(policy)) + + +def workflow_id_reuse_policy_to_proto( + policy: temporalio.common.WorkflowIDReusePolicy, +) -> workflow_enums_pb2.WorkflowIdReusePolicy.ValueType: + return typing.cast(workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, int(policy)) + + +def workflow_id_conflict_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, +) -> temporalio.common.WorkflowIDConflictPolicy: + return temporalio.common.WorkflowIDConflictPolicy(int(policy)) + + +def workflow_id_conflict_policy_to_proto( + policy: temporalio.common.WorkflowIDConflictPolicy, +) -> workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType: + return typing.cast( + workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, int(policy) + ) + + +def search_attributes_to_proto( + search_attributes: temporalio.common.TypedSearchAttributes, +) -> common_pb2.SearchAttributes: + proto = common_pb2.SearchAttributes() + temporalio.converter.encode_search_attributes(search_attributes, proto) + return proto + + +def priority_from_proto( + proto: common_pb2.Priority, +) -> temporalio.common.Priority: + return temporalio.common.Priority._from_proto(proto) # pyright: ignore[reportPrivateUsage] + + +def priority_to_proto( + priority: temporalio.common.Priority, +) -> common_pb2.Priority: + return priority._to_proto() # pyright: ignore[reportPrivateUsage] + + +def versioning_override_to_proto( + versioning_override: temporalio.common.VersioningOverride, +) -> temporalio.api.workflow.v1.VersioningOverride: + return versioning_override._to_proto() # pyright: ignore[reportPrivateUsage] diff --git a/temporalio/nexus/system/workflow_service/models.py b/temporalio/nexus/system/workflow_service/models.py new file mode 100644 index 000000000..88d77c35b --- /dev/null +++ b/temporalio/nexus/system/workflow_service/models.py @@ -0,0 +1,141 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +import collections.abc +import dataclasses +import datetime +import typing + +import temporalio.api.sdk.v1.user_metadata_pb2 +import temporalio.api.workflowservice.v1.request_response_pb2 +import temporalio.common + +from ._support import ( + duration_to_proto, + memo_to_proto, + payload_from_proto, + payload_to_proto, + payloads_to_proto, + priority_to_proto, + retry_policy_to_proto, + search_attributes_to_proto, + signal_function_to_proto, + task_queue_to_proto, + versioning_override_to_proto, + workflow_id_conflict_policy_to_proto, + workflow_id_reuse_policy_to_proto, + workflow_namespace, + workflow_type_to_proto, +) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class SignalWithStartWorkflowRequest: + """ + .. warning:: + This API is experimental and subject to change. + """ + + workflow: str | collections.abc.Callable[..., collections.abc.Awaitable[typing.Any]] + args: list[typing.Any] | None = None + id: str + task_queue: str + signal: str | collections.abc.Callable[..., None | collections.abc.Awaitable[None]] + signal_args: list[typing.Any] | None = None + execution_timeout: datetime.timedelta | None = None + run_timeout: datetime.timedelta | None = None + task_timeout: datetime.timedelta | None = None + request_id: str | None = None + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ( + temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE + ) + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = None + retry_policy: temporalio.common.RetryPolicy | None = None + cron_schedule: str | None = None + memo: collections.abc.Mapping[str, typing.Any] | None = None + search_attributes: temporalio.common.TypedSearchAttributes | None = None + priority: temporalio.common.Priority | None = None + versioning_override: temporalio.common.VersioningOverride | None = None + start_delay: datetime.timedelta | None = None + user_metadata: UserMetadata | None = None + + def to_proto( + self, + ) -> temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest: + message = temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest() + message.workflow_type.CopyFrom(workflow_type_to_proto(self.workflow)) + if self.args is not None: + message.input.CopyFrom(payloads_to_proto(self.args)) + message.workflow_id = self.id + message.task_queue.CopyFrom(task_queue_to_proto(self.task_queue)) + message.signal_name = signal_function_to_proto(self.signal) + if self.signal_args is not None: + message.signal_input.CopyFrom(payloads_to_proto(self.signal_args)) + if self.execution_timeout is not None: + message.workflow_execution_timeout.CopyFrom( + duration_to_proto(self.execution_timeout) + ) + if self.run_timeout is not None: + message.workflow_run_timeout.CopyFrom(duration_to_proto(self.run_timeout)) + if self.task_timeout is not None: + message.workflow_task_timeout.CopyFrom(duration_to_proto(self.task_timeout)) + if self.request_id is not None: + message.request_id = self.request_id + message.workflow_id_reuse_policy = workflow_id_reuse_policy_to_proto( + self.id_reuse_policy + ) + if self.id_conflict_policy is not None: + message.workflow_id_conflict_policy = workflow_id_conflict_policy_to_proto( + self.id_conflict_policy + ) + if self.retry_policy is not None: + message.retry_policy.CopyFrom(retry_policy_to_proto(self.retry_policy)) + if self.cron_schedule is not None: + message.cron_schedule = self.cron_schedule + if self.memo is not None: + message.memo.CopyFrom(memo_to_proto(self.memo)) + if self.search_attributes is not None: + message.search_attributes.CopyFrom( + search_attributes_to_proto(self.search_attributes) + ) + if self.priority is not None: + message.priority.CopyFrom(priority_to_proto(self.priority)) + if self.versioning_override is not None: + message.versioning_override.CopyFrom( + versioning_override_to_proto(self.versioning_override) + ) + if self.start_delay is not None: + message.workflow_start_delay.CopyFrom(duration_to_proto(self.start_delay)) + if self.user_metadata is not None: + message.user_metadata.CopyFrom(self.user_metadata.to_proto()) + message.namespace = workflow_namespace() + return message + + +@dataclasses.dataclass(slots=True) +class UserMetadata: + static_summary: typing.Any | None = None + static_details: typing.Any | None = None + + @classmethod + def from_proto( + cls, + proto: temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata, + ) -> UserMetadata: + return cls( + static_summary=payload_from_proto(proto.summary) + if proto.HasField("summary") + else None, + static_details=payload_from_proto(proto.details) + if proto.HasField("details") + else None, + ) + + def to_proto(self) -> temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata: + message = temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata() + if self.static_summary is not None: + message.summary.CopyFrom(payload_to_proto(self.static_summary)) + if self.static_details is not None: + message.details.CopyFrom(payload_to_proto(self.static_details)) + return message diff --git a/temporalio/nexus/system/workflow_service/operations/__init__.py b/temporalio/nexus/system/workflow_service/operations/__init__.py new file mode 100644 index 000000000..67c9cc56b --- /dev/null +++ b/temporalio/nexus/system/workflow_service/operations/__init__.py @@ -0,0 +1,3 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations diff --git a/temporalio/nexus/system/workflow_service/operations/signal_with_start_workflow.py b/temporalio/nexus/system/workflow_service/operations/signal_with_start_workflow.py new file mode 100644 index 000000000..ded496898 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/operations/signal_with_start_workflow.py @@ -0,0 +1,507 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +import collections.abc +import datetime +import typing + +import typing_extensions + +import temporalio.api.workflowservice.v1.request_response_pb2 +import temporalio.common + +if typing.TYPE_CHECKING: + from temporalio.workflow import ExternalWorkflowHandle + +from ..models import ( + SignalWithStartWorkflowRequest, + UserMetadata, +) + +SignalArg = typing.TypeVar("SignalArg") +WorkflowResult = typing.TypeVar("WorkflowResult") +WorkflowArgs = typing_extensions.TypeVarTuple("WorkflowArgs") + + +async def _signal_with_start_workflow( + request: SignalWithStartWorkflowRequest, +) -> ExternalWorkflowHandle[typing.Any]: + from temporalio.workflow import ( + create_nexus_client, + get_external_workflow_handle, + ) + + request_proto = request.to_proto() + nexus_client = create_nexus_client( + service="temporal.api.workflowservice.v1.WorkflowService", + endpoint="__temporal_system", + ) + handle = await nexus_client.start_operation( + operation="SignalWithStartWorkflowExecution", + input=request_proto, + output_type=temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionResponse, + ) + result = await handle + return get_external_workflow_handle(request.id, run_id=result.run_id) + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: str, + signal_args: list[typing.Any] | None = ..., + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: str, + signal_args: list[typing.Any] | None = ..., + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: str, + signal_args: list[typing.Any] | None = ..., + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any], None | collections.abc.Awaitable[None] + ], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any], None | collections.abc.Awaitable[None] + ], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any], None | collections.abc.Awaitable[None] + ], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any, SignalArg], None | collections.abc.Awaitable[None] + ], + signal_args: SignalArg, + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any, SignalArg], None | collections.abc.Awaitable[None] + ], + signal_args: SignalArg, + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any, SignalArg], None | collections.abc.Awaitable[None] + ], + signal_args: SignalArg, + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: list[typing.Any], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: list[typing.Any], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: list[typing.Any], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +async def signal_with_start_workflow( + workflow: str + | collections.abc.Callable[..., collections.abc.Awaitable[typing.Any]], + *positional_args: object, + args: list[typing.Any] | None = None, + id: str, + task_queue: str, + signal: str | collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: object | list[typing.Any] | None = None, + execution_timeout: datetime.timedelta | None = None, + run_timeout: datetime.timedelta | None = None, + task_timeout: datetime.timedelta | None = None, + request_id: str | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ( + temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE + ), + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = None, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str | None = None, + memo: collections.abc.Mapping[str, typing.Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + priority: temporalio.common.Priority | None = None, + versioning_override: temporalio.common.VersioningOverride | None = None, + start_delay: datetime.timedelta | None = None, + static_summary: str | None = None, + static_details: str | None = None, +) -> ExternalWorkflowHandle[typing.Any]: + """Signal a workflow, starting it first if needed. + + .. warning:: + This API is experimental and subject to change. + + Args: + workflow: Workflow type name or callable identifying the workflow to start. + positional_args: Positional arguments for workflow. Cannot be set if args is + set. + args: List-form arguments for workflow. Cannot be set if positional_args are + set. For typed workflow callables, list contents are not statically + typechecked; pass workflow arguments positionally for precise typechecking. + id: Unique identifier for the workflow execution. + task_queue: Task queue to run the workflow on. + signal: Signal name or callable to send with the start request. + signal_args: Argument value, or list of argument values, for signal. For typed + single-argument signals, scalar signal_args values are statically + typechecked. List-form signal_args values are not precisely typechecked. To + pass a single signal argument that is itself a list, wrap it in another + list; otherwise the list is interpreted as multiple signal arguments. + execution_timeout: Total workflow execution timeout, including retries and + continue-as-new. + run_timeout: Timeout of a single workflow run. + task_timeout: Timeout of a single workflow task. + request_id: Request ID used to deduplicate workflow start requests. + id_reuse_policy: Behavior when a closed workflow with the same ID exists. + Default is allow-duplicate. + id_conflict_policy: Behavior when a workflow is currently running with the same + ID. Set to use-existing for idempotent deduplication on workflow ID. Cannot + be set if id-reuse-policy is terminate-if-running. + retry_policy: Retry policy for the workflow. + cron_schedule: Cron schedule for recurring workflow executions. See + https://docs.temporal.io/cron-job. + memo: Memo for the workflow. + search_attributes: Typed search attributes for the workflow. + priority: Priority of the workflow execution. + versioning_override: Override for workflow versioning behavior. + start_delay: Amount of time to wait before starting the workflow. This does not + work with cron-schedule. + static_summary: Single-line fixed summary for the workflow execution that may + appear in UI and CLI. This can be in single-line Temporal Markdown format. + static_details: General fixed details for the workflow execution that may appear + in UI and CLI. This can be in Temporal Markdown format and can span multiple + lines. This value is fixed on the workflow execution and cannot be updated. + + Returns: + A workflow handle to the started workflow. + """ + normalized_signal_args: list[typing.Any] | None + if signal_args is None: + normalized_signal_args = None + elif isinstance(signal_args, list): + normalized_signal_args = typing.cast(list[typing.Any], signal_args) + else: + normalized_signal_args = [signal_args] + if positional_args and args is not None: + raise TypeError("cannot specify both positional arguments and args") + normalized_args: list[typing.Any] | None = ( + list(positional_args) if positional_args else args + ) + user_metadata = ( + None + if static_summary is None and static_details is None + else UserMetadata( + static_summary=static_summary, + static_details=static_details, + ) + ) + request = SignalWithStartWorkflowRequest( + workflow=workflow, + args=normalized_args, + id=id, + task_queue=task_queue, + signal=signal, + signal_args=normalized_signal_args, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + request_id=request_id, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + priority=priority, + versioning_override=versioning_override, + start_delay=start_delay, + user_metadata=user_metadata, + ) + return await _signal_with_start_workflow(request) diff --git a/temporalio/nexus/system/workflow_service/service.py b/temporalio/nexus/system/workflow_service/service.py new file mode 100644 index 000000000..7ce5849ca --- /dev/null +++ b/temporalio/nexus/system/workflow_service/service.py @@ -0,0 +1,21 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +from nexusrpc import Operation, service + +import temporalio.api.workflowservice.v1.request_response_pb2 + + +@service(name="temporal.api.workflowservice.v1.WorkflowService") +class WorkflowService: + """ + .. warning:: + This API is experimental and subject to change. + """ + + # .. warning:: This API is experimental and subject to change. + signal_with_start_workflow: Operation[ + temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest, + temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionResponse, + ] = Operation(name="SignalWithStartWorkflowExecution") diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index f77bea042..500fc4db5 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -6,7 +6,8 @@ from dataclasses import dataclass from temporalio.api.enums.v1.command_type_pb2 import CommandType -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge._visitor_functions import VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( ResolveActivity, ResolveChildWorkflowExecution, diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 76ccdb2e3..deefb5ad3 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.system import temporalio.workflow from temporalio.converter import StorageDriverStoreContext, StorageDriverWorkflowInfo from temporalio.service import __version__ @@ -2085,8 +2086,19 @@ async def operation_handle_fn() -> OutputT: ): t.uncancel() # type: ignore[union-attr] + payload_converter = ( + temporalio.nexus.system.get_payload_converter() + if temporalio.nexus.system.is_system_operation( + input.service, input.operation_name + ) + else self._context_free_payload_converter + ) handle = _NexusOperationHandle( - self, self._next_seq("nexus_operation"), input, operation_handle_fn() + self, + self._next_seq("nexus_operation"), + input, + operation_handle_fn(), + payload_converter, ) handle._apply_schedule_command() self._pending_nexus_operations[handle._seq] = handle @@ -3453,6 +3465,7 @@ def __init__( seq: int, input: StartNexusOperationInput[Any, OutputT], fn: Coroutine[Any, Any, OutputT], + payload_converter: temporalio.converter.PayloadConverter, ): self._instance = instance self._seq = seq @@ -3460,7 +3473,7 @@ def __init__( self._task = asyncio.Task(fn) self._start_fut: asyncio.Future[str | None] = instance.create_future() self._result_fut: asyncio.Future[OutputT | None] = instance.create_future() - self._payload_converter = self._instance._context_free_payload_converter + self._payload_converter = payload_converter self._failure_converter = self._instance._context_free_failure_converter @property diff --git a/temporalio/workflow/__init__.py b/temporalio/workflow/__init__.py index 8b8b0fb6f..fedc31b10 100644 --- a/temporalio/workflow/__init__.py +++ b/temporalio/workflow/__init__.py @@ -1,5 +1,7 @@ """Utilities that can decorate or be called inside workflows.""" +# ruff: noqa: I001 + from __future__ import annotations from ..types import ( @@ -167,6 +169,12 @@ start_child_workflow, ) +# BEGIN GENERATED NEXUS SYSTEM EXPORTS +from temporalio.nexus.system.workflow_service import ( + signal_with_start_workflow, +) +# END GENERATED NEXUS SYSTEM EXPORTS + __all__ = [ "ActivityCancellationType", "ActivityConfig", @@ -314,4 +322,7 @@ "ProtocolReturnType", "ReturnType", "SelfType", + # BEGIN GENERATED NEXUS SYSTEM __ALL__ + "signal_with_start_workflow", + # END GENERATED NEXUS SYSTEM __ALL__ ] diff --git a/tests/__init__.py b/tests/__init__.py index d62129b39..4725d3a7e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -DEV_SERVER_DOWNLOAD_VERSION = "v1.7.1-standalone-nexus-operations" +DEV_SERVER_DOWNLOAD_VERSION = "v1.7.1-system-nexus-operations" diff --git a/tests/conftest.py b/tests/conftest.py index 1e1db3730..9eaa1ff47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,6 +136,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "nexusoperation.enableStandalone=true", "--dynamic-config-value", 'system.system.refreshNexusEndpointsMinWait="0s"', + "--dynamic-config-value", + "history.enableSignalWithStartFromWorkflow=true", ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py new file mode 100644 index 000000000..c7d9319ca --- /dev/null +++ b/tests/nexus/test_temporal_system_nexus.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import dataclasses +import uuid +from collections.abc import Sequence +from datetime import timedelta +from typing import Any, cast + +import pytest +from google.protobuf.descriptor import FieldDescriptor +from google.protobuf.message import Message + +import temporalio.api.common.v1 +import temporalio.api.workflowservice.v1.request_response_pb2 as workflowservice_pb2 +import temporalio.converter +import temporalio.nexus.system as nexus_system +from temporalio import workflow +from temporalio.client import Client +from temporalio.converter import ExternalStorage, PayloadCodec +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import ( + Interceptor, + StartNexusOperationInput, + Worker, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner +from tests.test_extstore import InMemoryTestDriver + +interceptor_traces: list[tuple[str, object]] = [] + + +@workflow.defn +class ExternalHandleSignalWithStartWorkflowCaller: + @workflow.run + async def run(self, task_queue: str) -> str: + started_handle = await workflow.signal_with_start_workflow( + "test-workflow", + "workflow-input", + id="system-nexus-workflow-id", + task_queue=task_queue, + signal="test-signal", + signal_args=["signal-input"], + memo={"memo-key": "memo-value"}, + static_summary="summary-value", + static_details="details-value", + ) + return started_handle.id + + +class RejectOuterSystemNexusCodec(PayloadCodec): + def __init__(self) -> None: + self.encode_count = 0 + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + encoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + if ( + payload.metadata.get("encoding") == b"binary/protobuf" + and payload.metadata.get("messageType") + == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" + ): + raise RuntimeError( + "outer system nexus envelope should not be codec encoded" + ) + self.encode_count += 1 + encoded.append( + temporalio.api.common.v1.Payload( + metadata={**payload.metadata, "test-codec": b"true"}, + data=payload.data, + ) + ) + return encoded + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + decoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + if ( + payload.metadata.get("encoding") == b"binary/protobuf" + and payload.metadata.get("messageType") + == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" + ): + raise RuntimeError( + "outer system nexus envelope should not be codec decoded" + ) + decoded.append(payload) + return decoded + + +class TracingWorkflowInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> type[WorkflowInboundInterceptor] | None: + return _TracingWorkflowInboundInterceptor + + +class _TracingWorkflowInboundInterceptor(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + super().init(_TracingWorkflowOutboundInterceptor(outbound)) + + +class _TracingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): + async def start_nexus_operation( + self, input: StartNexusOperationInput[Any, Any] + ) -> workflow.NexusOperationHandle[Any]: + interceptor_traces.append(("workflow.start_nexus_operation", input)) + return await super().start_nexus_operation(input) + + +def _assert_stored_payloads_include( + driver: InMemoryTestDriver, expected_payload_data: set[bytes] +) -> None: + stored_payload_data: set[bytes] = set() + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data.add(stored_payload.data) + assert expected_payload_data.issubset(stored_payload_data) + + +def _assert_start_nexus_operation_interceptor_trace() -> None: + assert len(interceptor_traces) == 1 + trace_name, trace_value = interceptor_traces.pop() + assert trace_name == "workflow.start_nexus_operation" + trace_input = cast(StartNexusOperationInput[Any, Any], trace_value) + request = cast( + workflowservice_pb2.SignalWithStartWorkflowExecutionRequest, + trace_input.input, + ) + assert request.workflow_id == "system-nexus-workflow-id" + assert request.signal_name == "test-signal" + assert request.workflow_type.name == "test-workflow" + + +def _build_proto_sample(message_type: type[Message]) -> Message: + message = message_type() + _populate_proto_sample(message) + return message + + +def _populate_proto_sample(message: Message, *, path: str = "value") -> None: + seen_oneofs: set[str] = set() + for field in message.DESCRIPTOR.fields: + if field.containing_oneof is not None: + if field.containing_oneof.name in seen_oneofs: + continue + seen_oneofs.add(field.containing_oneof.name) + if field.label == FieldDescriptor.LABEL_REPEATED: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): + _populate_proto_map_entry(message, field, path=path) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name).add(), + path=f"{path}.{field.name}[0]", + ) + else: + getattr(message, field.name).append( + _proto_scalar_sample(field, path=f"{path}.{field.name}[0]") + ) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name), + path=f"{path}.{field.name}", + ) + else: + setattr( + message, + field.name, + _proto_scalar_sample(field, path=f"{path}.{field.name}"), + ) + + +def _populate_proto_map_entry( + message: Message, + field: FieldDescriptor, + *, + path: str, +) -> None: + key_field = field.message_type.fields_by_name["key"] + value_field = field.message_type.fields_by_name["value"] + key = _proto_scalar_sample(key_field, path=f"{path}.{field.name}.key") + container = getattr(message, field.name) + if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + container[key], + path=f"{path}.{field.name}[{key!r}]", + ) + else: + container[key] = _proto_scalar_sample( + value_field, + path=f"{path}.{field.name}[{key!r}]", + ) + + +def _proto_scalar_sample(field: FieldDescriptor, *, path: str) -> Any: + if field.type == FieldDescriptor.TYPE_BYTES: + return b"test" + if field.cpp_type == FieldDescriptor.CPPTYPE_STRING: + return f"{path}-value" + if field.cpp_type == FieldDescriptor.CPPTYPE_BOOL: + return True + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_INT32, + FieldDescriptor.CPPTYPE_INT64, + FieldDescriptor.CPPTYPE_UINT32, + FieldDescriptor.CPPTYPE_UINT64, + ): + return 1 + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_FLOAT, + FieldDescriptor.CPPTYPE_DOUBLE, + ): + return 1.5 + if field.cpp_type == FieldDescriptor.CPPTYPE_ENUM: + for enum_value in field.enum_type.values: + if enum_value.number != 0: + return enum_value.number + return field.enum_type.values[0].number + raise TypeError(f"Unhandled proto scalar sample at {path}: {field!r}") + + +@pytest.mark.parametrize( + "message_type", + [ + workflowservice_pb2.SignalWithStartWorkflowExecutionRequest, + workflowservice_pb2.SignalWithStartWorkflowExecutionResponse, + ], +) +def test_system_nexus_proto_roundtrip(message_type: type[Message]) -> None: + payload_converter = nexus_system.get_payload_converter() + proto_value = _build_proto_sample(message_type) + payload = payload_converter.to_payload(proto_value) + assert payload is not None + assert payload.metadata["encoding"] == b"binary/protobuf" + assert payload.metadata["messageType"] == message_type.DESCRIPTOR.full_name.encode() + roundtripped = payload_converter.from_payload(payload, message_type) + assert isinstance(roundtripped, message_type) + assert roundtripped == proto_value + + +async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( + env: WorkflowEnvironment, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + interceptor_traces.clear() + driver = InMemoryTestDriver() + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), + ) + caller_client = Client(**caller_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) + + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, + workflows=[ExternalHandleSignalWithStartWorkflowCaller], + workflow_runner=UnsandboxedWorkflowRunner(), + interceptors=[TracingWorkflowInterceptor()], + ) + + async with caller_worker: + result = await caller_client.execute_workflow( + ExternalHandleSignalWithStartWorkflowCaller.run, + args=[handler_task_queue], + id=str(uuid.uuid4()), + task_queue=caller_task_queue, + execution_timeout=timedelta(seconds=5), + ) + + assert result == "system-nexus-workflow-id" + assert codec.encode_count >= 5 + _assert_stored_payloads_include( + driver, + { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"summary-value"', + b'"details-value"', + }, + ) + _assert_start_nexus_operation_interceptor_trace() diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 876387393..a815f3135 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -15,7 +15,8 @@ SearchAttributes, ) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge._visitor_functions import VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, WorkflowActivation,