diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fea940f214..76aba1538c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -119,7 +119,10 @@ jobs: --dynamic-config-value nexusoperation.enableStandalone=true \ --dynamic-config-value history.enableChasm=true \ --dynamic-config-value history.enableCHASMSignalBacklinks=true \ - --dynamic-config-value history.enableTransitionHistory=true & + --dynamic-config-value history.enableTransitionHistory=true \ + --dynamic-config-value frontend.enableCancelWorkerPollsOnShutdown=true \ + --dynamic-config-value frontend.workerCommandsEnabled=true \ + --dynamic-config-value system.enableCancelActivityWorkerCommand=true & sleep 10s # Can't actually run tests against Java 8 because Mockito 5 requires Java 11+. diff --git a/AGENTS.md b/AGENTS.md index 8d449963c0..0a13d6abf4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,7 +17,7 @@ - The SDK code is written for Java 8. ## Building and Testing -1. Format the code before committing: +1. Format the code before committing (and don't bother running spotlessCheck, just run apply): ```bash ./gradlew --offline spotlessApply ``` diff --git a/temporal-sdk/src/main/java/io/temporal/activity/ActivityCancellationToken.java b/temporal-sdk/src/main/java/io/temporal/activity/ActivityCancellationToken.java new file mode 100644 index 0000000000..26211d201d --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/activity/ActivityCancellationToken.java @@ -0,0 +1,49 @@ +package io.temporal.activity; + +import io.temporal.client.ActivityCanceledException; +import java.util.concurrent.CompletableFuture; + +/** Token that allows an Activity implementation to observe cancellation requests. 
*/ +public interface ActivityCancellationToken { + + ActivityCancellationToken NONE = + new ActivityCancellationToken() { + @Override + public boolean isCancellationRequested() { + return false; + } + + @Override + public void throwIfCancellationRequested() throws ActivityCanceledException {} + + @Override + public CompletableFuture getCancellationRequest() { + return new CompletableFuture<>(); + } + }; + + /** + * Returns true after cancellation has been requested for this Activity Execution. + * + *

<p>If this method returns true, the Activity implementation should stop its work and usually + * call {@link #throwIfCancellationRequested()} to report successful cancellation to Temporal. + */ + boolean isCancellationRequested(); + + /** + * Throws {@link ActivityCanceledException} if cancellation has been requested for this Activity + * Execution. + * + *

<p>Rethrowing this exception from Activity code reports successful cancellation to Temporal. + */ + void throwIfCancellationRequested() throws ActivityCanceledException; + + /** + * Future that completes when cancellation has been requested for this Activity Execution. + * + *

<p>The future completes normally. Activity code should still call {@link + * #throwIfCancellationRequested()} or otherwise report cancellation if it wants the Activity + * Execution to complete as canceled. + */ + CompletableFuture getCancellationRequest(); +} diff --git a/temporal-sdk/src/main/java/io/temporal/activity/ActivityExecutionContext.java b/temporal-sdk/src/main/java/io/temporal/activity/ActivityExecutionContext.java index 8918effc4b..2fd7dfa474 100644 --- a/temporal-sdk/src/main/java/io/temporal/activity/ActivityExecutionContext.java +++ b/temporal-sdk/src/main/java/io/temporal/activity/ActivityExecutionContext.java @@ -89,10 +89,18 @@ public interface ActivityExecutionContext { */ byte[] getTaskToken(); + /** + * Returns a token that can be used by Activity code to observe cancellation requests without + * recording Heartbeats. + */ + default ActivityCancellationToken getCancellationToken() { + return ActivityCancellationToken.NONE; + } + /** * If this method is called during an Activity Execution then the Activity Execution is not going - * to complete when it's method returns. It is expected to be completed asynchronously using - * {@link io.temporal.client.ActivityCompletionClient}. + * to complete when its method returns. It is expected to be completed asynchronously using {@link + * io.temporal.client.ActivityCompletionClient}. * *

<p>Async Activity Executions that have {@link #isUseLocalManualCompletion()} set to false will * not respect the limit defined by {@link WorkerOptions#getMaxConcurrentActivityExecutionSize()}. diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/ActivityExecutionContextBase.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/ActivityExecutionContextBase.java index 73adde3784..3ce86cec38 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/ActivityExecutionContextBase.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/ActivityExecutionContextBase.java @@ -1,6 +1,7 @@ package io.temporal.common.interceptors; import com.uber.m3.tally.Scope; +import io.temporal.activity.ActivityCancellationToken; import io.temporal.activity.ActivityExecutionContext; import io.temporal.activity.ActivityInfo; import io.temporal.activity.ManualActivityCompletionClient; @@ -52,6 +53,11 @@ public byte[] getTaskToken() { return next.getTaskToken(); } + @Override + public ActivityCancellationToken getCancellationToken() { + return next.getCancellationToken(); + } + @Override public void doNotCompleteOnReturn() { next.doNotCompleteOnReturn(); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityCancellationTokenImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityCancellationTokenImpl.java new file mode 100644 index 0000000000..d0e8bb9377 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityCancellationTokenImpl.java @@ -0,0 +1,33 @@ +package io.temporal.internal.activity; + +import io.temporal.activity.ActivityCancellationToken; +import io.temporal.client.ActivityCanceledException; +import java.util.concurrent.CompletableFuture; + +final class ActivityCancellationTokenImpl implements ActivityCancellationToken { + private final CompletableFuture cancellationRequest = new CompletableFuture<>(); + private volatile ActivityCanceledException
cancellationException; + + @Override + public boolean isCancellationRequested() { + return cancellationException != null; + } + + @Override + public void throwIfCancellationRequested() throws ActivityCanceledException { + ActivityCanceledException exception = cancellationException; + if (exception != null) { + throw exception; + } + } + + @Override + public CompletableFuture getCancellationRequest() { + return cancellationRequest; + } + + void requestCancel(ActivityCanceledException exception) { + cancellationException = exception; + cancellationRequest.complete(null); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactory.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactory.java index cc0f5ee279..4e4003d963 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactory.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactory.java @@ -5,4 +5,14 @@ public interface ActivityExecutionContextFactory { InternalActivityExecutionContext createContext( ActivityInfoInternal info, Object activity, Scope metricsScope); + + /** + * Removes a context for a currently running activity identified by task token and optionally + * requests cancellation. + * + * @return true if the activity was found and cleaned up. 
+ */ + default boolean cleanupContext(byte[] taskToken, boolean cancel) { + return false; + } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactoryImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactoryImpl.java index c3df217721..4acc1d17dd 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactoryImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextFactoryImpl.java @@ -4,8 +4,12 @@ import io.temporal.client.WorkflowClient; import io.temporal.common.converter.DataConverter; import io.temporal.internal.client.external.ManualActivityCompletionClientFactory; +import java.nio.ByteBuffer; import java.time.Duration; +import java.util.Arrays; import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; public class ActivityExecutionContextFactoryImpl implements ActivityExecutionContextFactory { @@ -17,6 +21,8 @@ public class ActivityExecutionContextFactoryImpl implements ActivityExecutionCon private final DataConverter dataConverter; private final ScheduledExecutorService heartbeatExecutor; private final ManualActivityCompletionClientFactory manualCompletionClientFactory; + private final ConcurrentMap activeContexts = + new ConcurrentHashMap<>(); public ActivityExecutionContextFactoryImpl( WorkflowClient client, @@ -42,18 +48,39 @@ public ActivityExecutionContextFactoryImpl( @Override public InternalActivityExecutionContext createContext( ActivityInfoInternal info, Object activity, Scope metricsScope) { - return new ActivityExecutionContextImpl( - client, - namespace, - activity, - info, - dataConverter, - heartbeatExecutor, - manualCompletionClientFactory, - info.getCompletionHandle(), - metricsScope, - identity, - maxHeartbeatThrottleInterval, - defaultHeartbeatThrottleInterval); + 
ByteBuffer taskToken = taskTokenKey(info.getTaskToken()); + ActivityExecutionContextImpl context = + new ActivityExecutionContextImpl( + client, + namespace, + activity, + info, + dataConverter, + heartbeatExecutor, + manualCompletionClientFactory, + info.getCompletionHandle(), + metricsScope, + identity, + maxHeartbeatThrottleInterval, + defaultHeartbeatThrottleInterval, + () -> cleanupContext(info.getTaskToken(), false)); + activeContexts.put(taskToken, context); + return context; + } + + @Override + public boolean cleanupContext(byte[] taskToken, boolean cancel) { + ActivityExecutionContextImpl context = activeContexts.remove(taskTokenKey(taskToken)); + if (context == null) { + return false; + } + if (cancel) { + context.cancelFromWorkerCommand(); + } + return true; + } + + private static ByteBuffer taskTokenKey(byte[] taskToken) { + return ByteBuffer.wrap(Arrays.copyOf(taskToken, taskToken.length)).asReadOnlyBuffer(); } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextImpl.java index 101ca4c047..db8943138c 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityExecutionContextImpl.java @@ -1,6 +1,7 @@ package io.temporal.internal.activity; import com.uber.m3.tally.Scope; +import io.temporal.activity.ActivityCancellationToken; import io.temporal.activity.ActivityExecutionContext; import io.temporal.activity.ActivityInfo; import io.temporal.activity.ManualActivityCompletionClient; @@ -32,6 +33,7 @@ class ActivityExecutionContextImpl implements InternalActivityExecutionContext { private final ManualActivityCompletionClientFactory manualCompletionClientFactory; private final Functions.Proc completionHandle; private final HeartbeatContext heartbeatContext; + private final Functions.Proc closeCallback; private 
final Scope metricsScope; private final ActivityInfo info; @@ -51,12 +53,14 @@ class ActivityExecutionContextImpl implements InternalActivityExecutionContext { Scope metricsScope, String identity, Duration maxHeartbeatThrottleInterval, - Duration defaultHeartbeatThrottleInterval) { + Duration defaultHeartbeatThrottleInterval, + Functions.Proc closeCallback) { this.client = client; this.activity = activity; this.metricsScope = metricsScope; this.info = info; this.completionHandle = completionHandle; + this.closeCallback = closeCallback; this.manualCompletionClientFactory = manualCompletionClientFactory; this.heartbeatContext = new HeartbeatContextImpl( @@ -105,6 +109,11 @@ public byte[] getTaskToken() { return info.getTaskToken(); } + @Override + public ActivityCancellationToken getCancellationToken() { + return heartbeatContext.getCancellationToken(); + } + @Override public void doNotCompleteOnReturn() { lock.lock(); @@ -170,6 +179,16 @@ public Object getLastHeartbeatValue() { @Override public void cancelOutstandingHeartbeat() { heartbeatContext.cancelOutstandingHeartbeat(); + closeCallback.apply(); + } + + @Override + public void asyncCompletionStarted() { + heartbeatContext.asyncCompletionStarted(); + } + + void cancelFromWorkerCommand() { + heartbeatContext.cancelFromWorkerCommand(); } @Override diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskExecutors.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskExecutors.java index a77838193c..7789bd2716 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskExecutors.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskExecutors.java @@ -129,8 +129,13 @@ public ActivityTaskHandler.Result execute(ActivityInfoInternal info, Scope metri local, dataConverterWithActivityContext); } finally { - if (!context.isDoNotCompleteOnReturn()) { - // if the activity is not completed, we need to cancel the heartbeat + if 
(context.isDoNotCompleteOnReturn()) { + if (!context.isUseLocalManualCompletion()) { + context.asyncCompletionStarted(); + } + executionContextFactory.cleanupContext(info.getTaskToken(), false); + } else { + // if the activity is completed, we need to cancel the heartbeat // to avoid sending it after the activity is completed context.cancelOutstandingHeartbeat(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskHandlerImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskHandlerImpl.java index 312e7c728a..48e9dbfabf 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskHandlerImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/ActivityTaskHandlerImpl.java @@ -83,6 +83,10 @@ public boolean isTypeSupported(String type) { return activities.get(type) != null || dynamicActivity != null; } + public boolean requestCancel(byte[] taskToken) { + return executionContextFactory.cleanupContext(taskToken, true); + } + public void registerActivityImplementations(Object[] activitiesImplementation) { for (Object activity : activitiesImplementation) { registerActivityImplementation(activity); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContext.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContext.java index f87f3c637f..c960f70ba9 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContext.java @@ -1,5 +1,6 @@ package io.temporal.internal.activity; +import io.temporal.activity.ActivityCancellationToken; import io.temporal.client.ActivityCompletionException; import java.lang.reflect.Type; import java.util.Optional; @@ -23,6 +24,14 @@ interface HeartbeatContext { Object getLatestHeartbeatDetails(); + ActivityCancellationToken getCancellationToken(); + + /** Mark this activity as canceled by an external 
worker command. */ + void cancelFromWorkerCommand(); + + /** Mark this activity as returned for async completion. */ + void asyncCompletionStarted(); + /** Cancel any pending heartbeat and discard cached heartbeat details. */ void cancelOutstandingHeartbeat(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContextImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContextImpl.java index 48993d0da1..871c44e122 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContextImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/HeartbeatContextImpl.java @@ -3,6 +3,7 @@ import com.uber.m3.tally.Scope; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import io.temporal.activity.ActivityCancellationToken; import io.temporal.activity.ActivityExecutionContext; import io.temporal.activity.ActivityInfo; import io.temporal.api.common.v1.Payloads; @@ -72,8 +73,11 @@ static long getLocalHeartbeatTimeoutBufferMillis() { // 0 means no local timeout is active. 
private long heartbeatTimeoutDeadlineNanos; private boolean heartbeatTimedOut; + private boolean rejectNewHeartbeats; private ActivityCompletionException lastException; + private final ActivityCancellationTokenImpl cancellationToken = + new ActivityCancellationTokenImpl(); public HeartbeatContextImpl( WorkflowServiceStubs service, @@ -149,6 +153,10 @@ public void heartbeat(V details) throws ActivityCompletionException { lock.lock(); try { checkHeartbeatTimeoutDeadlineLocked(); + if (rejectNewHeartbeats) { + cancellationToken.throwIfCancellationRequested(); + throw new ActivityCanceledException(info); + } receivedAHeartbeat = true; lastDetails = details; hasOutstandingHeartbeat = true; @@ -159,6 +167,7 @@ public void heartbeat(V details) throws ActivityCompletionException { if (lastException != null) { throw lastException; } + cancellationToken.throwIfCancellationRequested(); } finally { lock.unlock(); } @@ -228,6 +237,32 @@ public void cancelOutstandingHeartbeat() { } } + @Override + public void cancelFromWorkerCommand() { + lock.lock(); + try { + requestCancelLocked(); + } finally { + lock.unlock(); + } + } + + @Override + public void asyncCompletionStarted() { + lock.lock(); + try { + requestCancelLocked(); + rejectNewHeartbeats = true; + } finally { + lock.unlock(); + } + } + + @Override + public ActivityCancellationToken getCancellationToken() { + return cancellationToken; + } + private void doHeartBeatLocked(Object details) { long nextHeartbeatDelay; try { @@ -307,7 +342,7 @@ private void sendHeartbeatRequest(Object details) { dataConverterWithActivityContext.toPayloads(details), metricsScope); if (status.getCancelRequested()) { - lastException = new ActivityCanceledException(info); + requestCancelLocked(); } else if (status.getActivityReset()) { lastException = new ActivityResetException(info); } else if (status.getActivityPaused()) { @@ -327,6 +362,12 @@ private void sendHeartbeatRequest(Object details) { } } + private void requestCancelLocked() { + 
ActivityCanceledException exception = new ActivityCanceledException(info); + lastException = exception; + cancellationToken.requestCancel(exception); + } + private static long getHeartbeatIntervalMs( Duration activityHeartbeatTimeout, Duration maxHeartbeatThrottleInterval, diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/InternalActivityExecutionContext.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/InternalActivityExecutionContext.java index 1b65e32cd7..8b6fbe997e 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/InternalActivityExecutionContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/InternalActivityExecutionContext.java @@ -10,6 +10,9 @@ public interface InternalActivityExecutionContext extends ActivityExecutionConte /** Get the latest value of {@link ActivityExecutionContext#heartbeat(Object)}. */ Object getLastHeartbeatValue(); + /** Mark this context as returned for async completion. */ + void asyncCompletionStarted(); + /** Cancel any pending heartbeat and discard cached heartbeat details. 
*/ void cancelOutstandingHeartbeat(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/LocalActivityExecutionContextImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/LocalActivityExecutionContextImpl.java index 78b82135a4..9a910a3918 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/LocalActivityExecutionContextImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/LocalActivityExecutionContextImpl.java @@ -1,6 +1,7 @@ package io.temporal.internal.activity; import com.uber.m3.tally.Scope; +import io.temporal.activity.ActivityCancellationToken; import io.temporal.activity.ActivityInfo; import io.temporal.activity.ManualActivityCompletionClient; import io.temporal.client.ActivityCompletionException; @@ -57,6 +58,12 @@ public byte[] getTaskToken() { throw new UnsupportedOperationException("getTaskToken is not supported for local activities"); } + @Override + public ActivityCancellationToken getCancellationToken() { + throw new UnsupportedOperationException( + "getCancellationToken is not supported for local activities"); + } + @Override public void doNotCompleteOnReturn() { throw new UnsupportedOperationException( @@ -89,6 +96,11 @@ public Object getLastHeartbeatValue() { return null; } + @Override + public void asyncCompletionStarted() { + // Ignored + } + @Override public void cancelOutstandingHeartbeat() { // Ignored diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java index 1dceb67fb0..f0d3e649f0 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java @@ -43,7 +43,8 @@ public ActivityPollTask( @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, - @Nonnull PollerTracker pollerTracker) { + @Nonnull 
PollerTracker pollerTracker, + String workerControlTaskQueue) { this.service = Objects.requireNonNull(service); this.slotSupplier = slotSupplier; this.metricsScope = Objects.requireNonNull(metricsScope); @@ -55,6 +56,9 @@ public ActivityPollTask( .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); pollRequest.setWorkerInstanceKey(workerInstanceKey); + if (workerControlTaskQueue != null) { + pollRequest.setWorkerControlTaskQueue(workerControlTaskQueue); + } if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java index d2fddde3f3..e1888dc59e 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java @@ -111,7 +111,8 @@ public boolean start() { this.slotSupplier, workerMetricsScope, service.getServerCapabilities(), - pollerTracker), + pollerTracker, + workerControlTaskQueue()), this.pollTaskExecutor, pollerOptions, namespaceCapabilities, @@ -132,7 +133,8 @@ public boolean start() { this.slotSupplier, workerMetricsScope, service.getServerCapabilities(), - pollerTracker), + pollerTracker, + workerControlTaskQueue()), this.pollTaskExecutor, pollerOptions, workerMetricsScope, @@ -146,6 +148,12 @@ public boolean start() { } } + private String workerControlTaskQueue() { + return namespaceCapabilities.isWorkerHeartbeats() && namespaceCapabilities.isWorkerCommands() + ? 
options.getWorkerControlTaskQueue() + : null; + } + @Override public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptTasks) { String supplierName = this + "#executorSlots"; diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java index b23d161845..1e8791bd02 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java @@ -49,7 +49,8 @@ public AsyncActivityPollTask( @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, - @Nonnull PollerTracker pollerTracker) { + @Nonnull PollerTracker pollerTracker, + String workerControlTaskQueue) { this.service = service; this.slotSupplier = slotSupplier; this.metricsScope = metricsScope; @@ -61,6 +62,9 @@ public AsyncActivityPollTask( .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); pollRequest.setWorkerInstanceKey(workerInstanceKey); + if (workerControlTaskQueue != null) { + pollRequest.setWorkerControlTaskQueue(workerControlTaskQueue); + } if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java index 1ba3b84d15..d83bda0be2 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java @@ -6,6 +6,7 @@ import com.uber.m3.tally.Scope; import io.grpc.Context; import io.temporal.api.common.v1.WorkerVersionCapabilities; +import io.temporal.api.enums.v1.TaskQueueKind; import io.temporal.api.taskqueue.v1.TaskQueue; import 
io.temporal.api.workflowservice.v1.GetSystemInfoResponse; import io.temporal.api.workflowservice.v1.PollNexusTaskQueueRequest; @@ -47,6 +48,33 @@ public AsyncNexusPollTask( @Nonnull Supplier serverCapabilities, TrackingSlotSupplier slotSupplier, @Nonnull PollerTracker pollerTracker) { + this( + service, + namespace, + taskQueue, + identity, + workerInstanceKey, + versioningOptions, + metricsScope, + serverCapabilities, + slotSupplier, + pollerTracker, + false); + } + + @SuppressWarnings("deprecation") + public AsyncNexusPollTask( + @Nonnull WorkflowServiceStubs service, + @Nonnull String namespace, + @Nonnull String taskQueue, + @Nonnull String identity, + @Nonnull String workerInstanceKey, + @Nonnull WorkerVersioningOptions versioningOptions, + @Nonnull Scope metricsScope, + @Nonnull Supplier serverCapabilities, + TrackingSlotSupplier slotSupplier, + @Nonnull PollerTracker pollerTracker, + boolean workerCommandsTaskQueue) { this.service = Objects.requireNonNull(service); this.metricsScope = Objects.requireNonNull(metricsScope); this.slotSupplier = slotSupplier; @@ -56,7 +84,13 @@ public AsyncNexusPollTask( PollNexusTaskQueueRequest.newBuilder() .setNamespace(namespace) .setIdentity(identity) - .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + .setTaskQueue( + TaskQueue.newBuilder() + .setName(taskQueue) + .setKind( + workerCommandsTaskQueue + ? 
TaskQueueKind.TASK_QUEUE_KIND_WORKER_COMMANDS + : TaskQueueKind.TASK_QUEUE_KIND_NORMAL)); pollRequest.setWorkerInstanceKey(workerInstanceKey); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java index c56111c02e..9a376a25c6 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java @@ -305,7 +305,7 @@ public void run() { if (shouldTerminate()) { pollerBalancer.removePoller(asyncTaskPoller.getLabel()); abort = true; - log.info( + log.debug( "Poll loop is terminated: {} - {}", AsyncPoller.this.getClass().getSimpleName(), asyncTaskPoller.getLabel()); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java index 3bfa796a30..97ed165ed2 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java @@ -57,7 +57,8 @@ public AsyncWorkflowPollTask( @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, - @Nonnull PollerTracker pollerTracker) { + @Nonnull PollerTracker pollerTracker, + String workerControlTaskQueue) { this.service = service; this.slotSupplier = slotSupplier; this.metricsScope = metricsScope; @@ -69,6 +70,9 @@ public AsyncWorkflowPollTask( .setIdentity(Objects.requireNonNull(identity)); pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); + if (workerControlTaskQueue != null) { + pollRequestBuilder.setWorkerControlTaskQueue(workerControlTaskQueue); + } if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java 
b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java index 855145b317..a8a77d680f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java @@ -52,7 +52,7 @@ public boolean isTerminated() { @Override public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptTasks) { - log.info("shutdown: {}", this); + log.debug("shutdown: {}", this); WorkerLifecycleState lifecycleState = getLifecycleState(); switch (lifecycleState) { case NOT_STARTED: diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java index 7fe0335b15..e82d162665 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java @@ -68,7 +68,7 @@ public MultiThreadedPoller( @Override public boolean start() { - log.info("start: {}", this); + log.debug("start: {}", this); if (pollerOptions.getMaximumPollRatePerSecond() > 0.0) { pollRateThrottler = @@ -193,7 +193,7 @@ public void run() { // Resubmit itself back to pollExecutor pollExecutor.execute(this); } else { - log.info( + log.debug( "poll loop is terminated: {}", MultiThreadedPoller.this.pollTask.getClass().getSimpleName()); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java index 8c9f23270f..ed4ac3935f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java @@ -12,6 +12,7 @@ public final class NamespaceCapabilities { private final AtomicBoolean pollerAutoscaling = new AtomicBoolean(false); private final AtomicBoolean 
gracefulPollShutdown = new AtomicBoolean(false); private final AtomicBoolean workerHeartbeats = new AtomicBoolean(false); + private final AtomicBoolean workerCommands = new AtomicBoolean(false); public void setFromCapabilities(Capabilities capabilities) { if (capabilities.getPollerAutoscaling()) { @@ -23,6 +24,9 @@ public void setFromCapabilities(Capabilities capabilities) { if (capabilities.getWorkerHeartbeats()) { workerHeartbeats.set(true); } + if (capabilities.getWorkerCommands()) { + workerCommands.set(true); + } } public boolean isPollerAutoscaling() { @@ -44,4 +48,12 @@ public boolean isWorkerHeartbeats() { public void setWorkerHeartbeats(boolean value) { workerHeartbeats.set(value); } + + public boolean isWorkerCommands() { + return workerCommands.get(); + } + + public void setWorkerCommands(boolean value) { + workerCommands.set(value); + } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java index 0ccab59443..b53546cfd2 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java @@ -5,6 +5,7 @@ import com.google.protobuf.Timestamp; import com.uber.m3.tally.Scope; import io.temporal.api.common.v1.WorkerVersionCapabilities; +import io.temporal.api.enums.v1.TaskQueueKind; import io.temporal.api.taskqueue.v1.TaskQueue; import io.temporal.api.workflowservice.v1.*; import io.temporal.internal.common.ProtobufTimeUtils; @@ -40,6 +41,33 @@ public NexusPollTask( @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, @Nonnull PollerTracker pollerTracker) { + this( + service, + namespace, + taskQueue, + identity, + workerInstanceKey, + versioningOptions, + slotSupplier, + metricsScope, + serverCapabilities, + pollerTracker, + false); + } + + @SuppressWarnings("deprecation") + public NexusPollTask( + @Nonnull WorkflowServiceStubs service, + 
@Nonnull String namespace, + @Nonnull String taskQueue, + @Nonnull String identity, + @Nonnull String workerInstanceKey, + @Nonnull WorkerVersioningOptions versioningOptions, + @Nonnull TrackingSlotSupplier slotSupplier, + @Nonnull Scope metricsScope, + @Nonnull Supplier serverCapabilities, + @Nonnull PollerTracker pollerTracker, + boolean workerCommandsTaskQueue) { this.service = Objects.requireNonNull(service); this.slotSupplier = slotSupplier; this.metricsScope = Objects.requireNonNull(metricsScope); @@ -49,7 +77,13 @@ public NexusPollTask( PollNexusTaskQueueRequest.newBuilder() .setNamespace(namespace) .setIdentity(identity) - .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + .setTaskQueue( + TaskQueue.newBuilder() + .setName(taskQueue) + .setKind( + workerCommandsTaskQueue + ? TaskQueueKind.TASK_QUEUE_KIND_WORKER_COMMANDS + : TaskQueueKind.TASK_QUEUE_KIND_NORMAL)); pollRequest.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java index a09993c037..1fd9cf9148 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java @@ -54,6 +54,7 @@ final class NexusWorker implements SuspendableWorker { private final TrackingSlotSupplier slotSupplier; private final NamespaceCapabilities namespaceCapabilities; private final boolean forceOldFailureFormat; + private final boolean workerCommandsTaskQueue; private final TaskCounter taskCounter = new TaskCounter(); private final PollerTracker pollerTracker = new PollerTracker(); @@ -66,6 +67,28 @@ public NexusWorker( @Nonnull DataConverter dataConverter, @Nonnull SlotSupplier slotSupplier, @Nonnull NamespaceCapabilities namespaceCapabilities) { + this( + service, + namespace, + taskQueue, + options, + handler, + 
dataConverter, + slotSupplier, + namespaceCapabilities, + false); + } + + public NexusWorker( + @Nonnull WorkflowServiceStubs service, + @Nonnull String namespace, + @Nonnull String taskQueue, + @Nonnull SingleWorkerOptions options, + @Nonnull NexusTaskHandler handler, + @Nonnull DataConverter dataConverter, + @Nonnull SlotSupplier slotSupplier, + @Nonnull NamespaceCapabilities namespaceCapabilities, + boolean workerCommandsTaskQueue) { this.service = Objects.requireNonNull(service); this.namespace = Objects.requireNonNull(namespace); this.taskQueue = Objects.requireNonNull(taskQueue); @@ -82,6 +105,7 @@ public NexusWorker( this.slotSupplier = new TrackingSlotSupplier<>(slotSupplier, this.workerMetricsScope); this.namespaceCapabilities = namespaceCapabilities; + this.workerCommandsTaskQueue = workerCommandsTaskQueue; // Allow tests to force old format for backward compatibility testing String forceOldFormat = System.getProperty("temporal.nexus.forceOldFailureFormat"); this.forceOldFailureFormat = "true".equalsIgnoreCase(forceOldFormat); @@ -116,7 +140,8 @@ public boolean start() { workerMetricsScope, service.getServerCapabilities(), this.slotSupplier, - pollerTracker), + pollerTracker, + workerCommandsTaskQueue), this.pollTaskExecutor, pollerOptions, namespaceCapabilities, @@ -135,7 +160,8 @@ public boolean start() { this.slotSupplier, workerMetricsScope, service.getServerCapabilities(), - pollerTracker), + pollerTracker, + workerCommandsTaskQueue), this.pollTaskExecutor, pollerOptions, workerMetricsScope, diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java index f53802f489..3e84dc750f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java @@ -42,6 +42,7 @@ public static final class Builder { private 
WorkerDeploymentOptions deploymentOptions; private String workerInstanceKey; private boolean allowActivityHeartbeatDuringShutdown; + private String workerControlTaskQueue; private Builder() {} @@ -68,6 +69,7 @@ private Builder(SingleWorkerOptions options) { this.deploymentOptions = options.getDeploymentOptions(); this.workerInstanceKey = options.getWorkerInstanceKey(); this.allowActivityHeartbeatDuringShutdown = options.getAllowActivityHeartbeatDuringShutdown(); + this.workerControlTaskQueue = options.getWorkerControlTaskQueue(); } public Builder setIdentity(String identity) { @@ -170,6 +172,11 @@ public Builder setAllowActivityHeartbeatDuringShutdown( return this; } + public Builder setWorkerControlTaskQueue(String workerControlTaskQueue) { + this.workerControlTaskQueue = workerControlTaskQueue; + return this; + } + public SingleWorkerOptions build() { PollerOptions pollerOptions = this.pollerOptions; if (pollerOptions == null) { @@ -210,7 +217,8 @@ public SingleWorkerOptions build() { usingVirtualThreads, this.deploymentOptions, this.workerInstanceKey, - this.allowActivityHeartbeatDuringShutdown); + this.allowActivityHeartbeatDuringShutdown, + this.workerControlTaskQueue); } } @@ -233,6 +241,7 @@ public SingleWorkerOptions build() { private final WorkerDeploymentOptions deploymentOptions; private final String workerInstanceKey; private final boolean allowActivityHeartbeatDuringShutdown; + private final String workerControlTaskQueue; private SingleWorkerOptions( String identity, @@ -253,7 +262,8 @@ private SingleWorkerOptions( boolean usingVirtualThreads, WorkerDeploymentOptions deploymentOptions, String workerInstanceKey, - boolean allowActivityHeartbeatDuringShutdown) { + boolean allowActivityHeartbeatDuringShutdown, + String workerControlTaskQueue) { this.identity = identity; this.binaryChecksum = binaryChecksum; this.buildId = buildId; @@ -273,6 +283,7 @@ private SingleWorkerOptions( this.deploymentOptions = deploymentOptions; this.workerInstanceKey = 
workerInstanceKey; this.allowActivityHeartbeatDuringShutdown = allowActivityHeartbeatDuringShutdown; + this.workerControlTaskQueue = workerControlTaskQueue; } public String getIdentity() { @@ -362,6 +373,10 @@ public String getWorkerInstanceKey() { return workerInstanceKey; } + public String getWorkerControlTaskQueue() { + return workerControlTaskQueue; + } + public WorkerVersioningOptions getWorkerVersioningOptions() { return new WorkerVersioningOptions( this.getBuildId(), this.isUsingBuildIdForVersioning(), this.getDeploymentOptions()); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java index 4eafdb38cf..94d2f5dee3 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java @@ -171,6 +171,10 @@ public boolean isAnyTypeSupported() { return taskHandler.isAnyTypeSupported(); } + public boolean requestCancelActivity(byte[] taskToken) { + return taskHandler.requestCancel(taskToken); + } + public TrackingSlotSupplier getSlotSupplier() { return worker.getSlotSupplier(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkerCommandTaskHandler.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkerCommandTaskHandler.java new file mode 100644 index 0000000000..6410f515d5 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkerCommandTaskHandler.java @@ -0,0 +1,129 @@ +package io.temporal.internal.worker; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.uber.m3.tally.Scope; +import io.temporal.api.common.v1.Payload; +import io.temporal.api.nexus.v1.Response; +import io.temporal.api.nexus.v1.StartOperationResponse; +import io.temporal.api.nexusservices.workerservice.v1.ExecuteCommandsRequest; +import
io.temporal.api.nexusservices.workerservice.v1.ExecuteCommandsResponse; +import io.temporal.api.worker.v1.CancelActivityResult; +import io.temporal.api.worker.v1.WorkerCommand; +import io.temporal.api.worker.v1.WorkerCommandResult; +import io.temporal.common.converter.DataConverter; +import io.temporal.common.converter.GlobalDataConverter; +import io.temporal.serviceclient.Version; +import io.temporal.serviceclient.WorkflowServiceStubs; +import io.temporal.worker.tuning.FixedSizeSlotSupplier; +import io.temporal.worker.tuning.NexusSlotInfo; +import io.temporal.worker.tuning.PollerBehaviorSimpleMaximum; +import java.util.Objects; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles server-to-worker commands delivered on the worker command Nexus task queue. */ +public final class WorkerCommandTaskHandler implements NexusTaskHandler { + private static final Logger log = LoggerFactory.getLogger(WorkerCommandTaskHandler.class); + /** First path segment of every worker command task queue name. */ private static final String TASK_QUEUE_PREFIX = "temporal-sys/worker-commands"; + + /** Receives the task token of each cancel-activity command; a result other than TRUE is only logged at debug level. */ private final Function<byte[], Boolean> activityCancelCallback; + + /** Creates a handler that forwards cancel-activity task tokens to the given callback, which must not be null. */ public WorkerCommandTaskHandler(Function<byte[], Boolean> activityCancelCallback) { + this.activityCancelCallback = Objects.requireNonNull(activityCancelCallback); + } + + /** Returns the task queue name in the form TASK_QUEUE_PREFIX/namespace/workerGroupingKey. */ public static String workerControlTaskQueue(String namespace, String workerGroupingKey) { + return String.format("%s/%s/%s", TASK_QUEUE_PREFIX, namespace, workerGroupingKey); + } + + /** Builds a NexusWorker for the worker command task queue of the given namespace and grouping key, configured with a simple-maximum poller behavior of 1, the global data converter, a fixed-size slot supplier of 5, and the workerCommandsTaskQueue flag set to true. */ public static SuspendableWorker newWorkerCommandWorker( + @Nonnull WorkflowServiceStubs service, + @Nonnull String namespace, + @Nonnull String identity, + @Nonnull String workerGroupingKey, + @Nonnull Function<byte[], Boolean> activityCancelCallback, + @Nonnull Scope metricsScope, + @Nonnull NamespaceCapabilities namespaceCapabilities) { + String taskQueue = workerControlTaskQueue(namespace, workerGroupingKey); + DataConverter dataConverter =
GlobalDataConverter.get(); + SingleWorkerOptions options = + SingleWorkerOptions.newBuilder() + .setIdentity(identity) + .setBuildId(Version.LIBRARY_VERSION) + .setWorkerInstanceKey(workerGroupingKey) + .setDataConverter(dataConverter) + .setMetricsScope(metricsScope) + .setPollerOptions( + PollerOptions.newBuilder() + .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) + .setPollThreadNamePrefix("WorkerCommandNexusPoller") + .build()) + .build(); + return new NexusWorker( + service, + namespace, + taskQueue, + options, + new WorkerCommandTaskHandler(activityCancelCallback), + dataConverter, + new FixedSizeSlotSupplier<NexusSlotInfo>(5), + namespaceCapabilities, + true); + } + + /** Always returns true. */ @Override + public boolean start() { + return true; + } + + /** Decodes the ExecuteCommandsRequest carried by the task, handles each command in list order, and returns a sync-success start-operation response whose payload data is the serialized ExecuteCommandsResponse. */ @Override + public Result handle(NexusTask task, Scope metricsScope) throws TimeoutException { + ExecuteCommandsRequest request = decodeRequest(task); + ExecuteCommandsResponse.Builder response = ExecuteCommandsResponse.newBuilder(); + for (WorkerCommand command : request.getCommandsList()) { + response.addResults(handleCommand(command)); + } + return new Result( + Response.newBuilder() + .setStartOperation( + StartOperationResponse.newBuilder() + .setSyncSuccess( + StartOperationResponse.Sync.newBuilder() + .setPayload( + Payload.newBuilder().setData(response.build().toByteString())))) + .build()); + } + + /** Parses the start-operation payload data as an ExecuteCommandsRequest; throws IllegalArgumentException when the payload is absent or cannot be parsed. */ private ExecuteCommandsRequest decodeRequest(NexusTask task) { + if (!task.getResponse().hasRequest() + || !task.getResponse().getRequest().hasStartOperation() + || !task.getResponse().getRequest().getStartOperation().hasPayload()) { + throw new IllegalArgumentException( + "Worker command Nexus task missing ExecuteCommands payload"); + } + try { + return ExecuteCommandsRequest.parseFrom( + task.getResponse().getRequest().getStartOperation().getPayload().getData()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Failed to decode ExecuteCommandsRequest", e); + } + } + + /** Handles one command: a cancel-activity command passes its task token to the callback and yields a CancelActivityResult; any other command is logged as a warning and yields an empty result. */ private WorkerCommandResult
handleCommand(WorkerCommand command) { + WorkerCommandResult.Builder result = WorkerCommandResult.newBuilder(); + if (command.hasCancelActivity()) { + byte[] taskToken = command.getCancelActivity().getTaskToken().toByteArray(); + Boolean found = activityCancelCallback.apply(taskToken); + if (!Boolean.TRUE.equals(found)) { + log.debug("Activity task token from worker command was not found"); + } + result.setCancelActivity(CancelActivityResult.newBuilder()); + } else { + log.warn("Worker command has no supported type set"); + } + return result.build(); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java index 18607b5d1e..1b6c8cf7dc 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java @@ -54,7 +54,8 @@ public WorkflowPollTask( @Nonnull Scope workerMetricsScope, @Nonnull Supplier serverCapabilities, @Nonnull PollerTracker pollerTracker, - @Nonnull PollerTracker stickyPollerTracker) { + @Nonnull PollerTracker stickyPollerTracker, + String workerControlTaskQueue) { this.slotSupplier = Objects.requireNonNull(slotSupplier); this.stickyQueueBalancer = Objects.requireNonNull(stickyQueueBalancer); this.metricsScope = Objects.requireNonNull(workerMetricsScope); @@ -75,6 +76,9 @@ public WorkflowPollTask( .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); + if (workerControlTaskQueue != null) { + pollRequestBuilder.setWorkerControlTaskQueue(workerControlTaskQueue); + } if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java
index d6aa835a29..3d256bb2ba 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java @@ -126,7 +126,8 @@ public boolean start() { slotSupplier, workerMetricsScope, service.getServerCapabilities(), - pollerTracker); + pollerTracker, + workerControlTaskQueue()); pollers = Arrays.asList( new AsyncWorkflowPollTask( @@ -140,7 +141,8 @@ public boolean start() { slotSupplier, workerMetricsScope, service.getServerCapabilities(), - stickyPollerTracker), + stickyPollerTracker, + workerControlTaskQueue()), normalPoller); this.stickyQueueBalancer = normalPoller; } else { @@ -157,7 +159,8 @@ public boolean start() { slotSupplier, workerMetricsScope, service.getServerCapabilities(), - pollerTracker)); + pollerTracker, + workerControlTaskQueue())); } poller = new AsyncPoller<>( @@ -191,7 +194,8 @@ public boolean start() { workerMetricsScope, service.getServerCapabilities(), pollerTracker, - stickyPollerTracker), + stickyPollerTracker, + workerControlTaskQueue()), pollTaskExecutor, pollerOptions, workerMetricsScope, @@ -647,6 +651,10 @@ private RespondWorkflowTaskCompletedResponse sendTaskCompleted( .setIdentity(options.getIdentity()) .setNamespace(namespace) .setTaskToken(taskToken); + String workerControlTaskQueue = workerControlTaskQueue(); + if (workerControlTaskQueue != null) { + taskCompleted.setWorkerControlTaskQueue(workerControlTaskQueue); + } if (options.getDeploymentOptions() != null) { taskCompleted.setDeploymentOptions( @@ -755,4 +763,10 @@ private Failure grpcMessageTooLargeFailure( .exceptionToFailure(applicationFailure); } } + + private String workerControlTaskQueue() { + return namespaceCapabilities.isWorkerHeartbeats() && namespaceCapabilities.isWorkerCommands() + ? 
options.getWorkerControlTaskQueue() + : null; + } } diff --git a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java index 6355e5a75a..bab7bdd390 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java @@ -110,6 +110,7 @@ private static final class TaskSnapshot { WorkflowThreadExecutor workflowThreadExecutor, List contextPropagators, @Nonnull List plugins, + @Nonnull String workerGroupingKey, @Nonnull NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(client, "client should not be null"); @@ -126,6 +127,10 @@ private static final class TaskSnapshot { WorkflowClientOptions clientOptions = client.getOptions(); String namespace = clientOptions.getNamespace(); this.namespace = namespace; + String workerControlTaskQueue = + clientOptions.getWorkerHeartbeatInterval().isNegative() + ? null + : WorkerCommandTaskHandler.workerControlTaskQueue(namespace, workerGroupingKey); Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); Scope taggedScope = metricsScope.tagged(tags); @@ -136,7 +141,8 @@ private static final class TaskSnapshot { clientOptions, contextPropagators, taggedScope, - workerInstanceKey); + workerInstanceKey, + workerControlTaskQueue); if (this.options.isLocalActivityWorkerOnly()) { activityWorker = null; } else { @@ -169,7 +175,8 @@ private static final class TaskSnapshot { clientOptions, contextPropagators, taggedScope, - workerInstanceKey); + workerInstanceKey, + workerControlTaskQueue); SlotSupplier nexusSlotSupplier = this.options.getWorkerTuner() == null ? 
new FixedSizeSlotSupplier<>(this.options.getMaxConcurrentNexusExecutionSize()) @@ -188,7 +195,8 @@ private static final class TaskSnapshot { taskQueue, contextPropagators, taggedScope, - workerInstanceKey); + workerInstanceKey, + workerControlTaskQueue); SingleWorkerOptions localActivityOptions = toLocalActivityOptions( factoryOptions, @@ -196,7 +204,8 @@ private static final class TaskSnapshot { clientOptions, contextPropagators, taggedScope, - workerInstanceKey); + workerInstanceKey, + workerControlTaskQueue); SlotSupplier workflowSlotSupplier = this.options.getWorkerTuner() == null @@ -671,6 +680,10 @@ Supplier buildHeartbeatCallback(String workerGroupingKey) { }; } + boolean requestCancelActivity(byte[] taskToken) { + return activityWorker != null && activityWorker.requestCancelActivity(taskToken); + } + private WorkerSlotsInfo buildSlotsInfo( String key, TrackingSlotSupplier tracker, TaskCounter taskCounter) { int maxSlots = tracker.maximumSlots().orElse(-1); @@ -889,9 +902,15 @@ private static SingleWorkerOptions toActivityOptions( WorkflowClientOptions clientOptions, List contextPropagators, Scope metricsScope, - String workerInstanceKey) { + String workerInstanceKey, + String workerControlTaskQueue) { return toSingleWorkerOptions( - factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) + factoryOptions, + options, + clientOptions, + contextPropagators, + workerInstanceKey, + workerControlTaskQueue) .setUsingVirtualThreads(options.isUsingVirtualThreadsOnActivityWorker()) .setAllowActivityHeartbeatDuringShutdown(options.getAllowActivityHeartbeatDuringShutdown()) .setPollerOptions( @@ -914,9 +933,15 @@ private static SingleWorkerOptions toNexusOptions( WorkflowClientOptions clientOptions, List contextPropagators, Scope metricsScope, - String workerInstanceKey) { + String workerInstanceKey, + String workerControlTaskQueue) { return toSingleWorkerOptions( - factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) + 
factoryOptions, + options, + clientOptions, + contextPropagators, + workerInstanceKey, + workerControlTaskQueue) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -938,7 +963,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( String taskQueue, List contextPropagators, Scope metricsScope, - String workerInstanceKey) { + String workerInstanceKey, + String workerControlTaskQueue) { Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); @@ -968,7 +994,12 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( } return toSingleWorkerOptions( - factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) + factoryOptions, + options, + clientOptions, + contextPropagators, + workerInstanceKey, + workerControlTaskQueue) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -991,9 +1022,15 @@ private static SingleWorkerOptions toLocalActivityOptions( WorkflowClientOptions clientOptions, List contextPropagators, Scope metricsScope, - String workerInstanceKey) { + String workerInstanceKey, + String workerControlTaskQueue) { return toSingleWorkerOptions( - factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) + factoryOptions, + options, + clientOptions, + contextPropagators, + workerInstanceKey, + workerControlTaskQueue) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -1011,7 +1048,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - String workerInstanceKey) { + String workerInstanceKey, + String workerControlTaskQueue) { String buildId = null; if (options.getBuildId() != null) { buildId = options.getBuildId(); @@ -1035,7 +1073,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( .setMaxHeartbeatThrottleInterval(options.getMaxHeartbeatThrottleInterval()) 
.setDefaultHeartbeatThrottleInterval(options.getDefaultHeartbeatThrottleInterval()) .setDeploymentOptions(options.getDeploymentOptions()) - .setWorkerInstanceKey(workerInstanceKey); + .setWorkerInstanceKey(workerInstanceKey) + .setWorkerControlTaskQueue(workerControlTaskQueue); } /** diff --git a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java index a87a36fb02..08522da969 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java @@ -17,6 +17,8 @@ import io.temporal.internal.worker.HeartbeatManager; import io.temporal.internal.worker.NamespaceCapabilities; import io.temporal.internal.worker.ShutdownManager; +import io.temporal.internal.worker.SuspendableWorker; +import io.temporal.internal.worker.WorkerCommandTaskHandler; import io.temporal.internal.worker.WorkflowExecutorCache; import io.temporal.internal.worker.WorkflowRunLockManager; import io.temporal.serviceclient.MetricsTag; @@ -64,6 +66,8 @@ public final class WorkerFactory { /** Namespace capabilities populated during start() from DescribeNamespace response. */ private final NamespaceCapabilities namespaceCapabilities = new NamespaceCapabilities(); + private SuspendableWorker workerCommandWorker; + private State state = State.Initial; private final String statusErrorMessage = @@ -201,6 +205,7 @@ public synchronized Worker newWorker(String taskQueue, WorkerOptions options) { workflowThreadExecutor, workflowClient.getOptions().getContextPropagators(), plugins, + ((WorkflowClientInternal) workflowClient.getInternal()).getWorkerGroupingKey(), namespaceCapabilities); workers.put(taskQueue, worker); @@ -285,6 +290,33 @@ public synchronized void start() { /** Internal method that actually starts the workers. Called from the plugin chain. 
*/ private void doStart() { + // Start the internal nexus worker if enabled + WorkflowClientInternal clientInternal = (WorkflowClientInternal) workflowClient.getInternal(); + String namespace = workflowClient.getOptions().getNamespace(); + String workerGroupingKey = clientInternal.getWorkerGroupingKey(); + HeartbeatManager hbManager = clientInternal.getHeartbeatManager(); + if (hbManager != null + && namespaceCapabilities.isWorkerHeartbeats() + && namespaceCapabilities.isWorkerCommands()) { + workerCommandWorker = + WorkerCommandTaskHandler.newWorkerCommandWorker( + workflowClient.getWorkflowServiceStubs(), + namespace, + workflowClient.getOptions().getIdentity(), + workerGroupingKey, + taskToken -> { + for (Worker worker : workers.values()) { + if (worker.requestCancelActivity(taskToken)) { + return true; + } + } + return false; + }, + metricsScope, + namespaceCapabilities); + workerCommandWorker.start(); + } + // Start each worker with plugin hooks for (Map.Entry entry : workers.entrySet()) { String taskQueue = entry.getKey(); @@ -303,11 +335,7 @@ private void doStart() { } // Register heartbeat callbacks after workers are started. 
- WorkflowClientInternal clientInternal = (WorkflowClientInternal) workflowClient.getInternal(); - HeartbeatManager hbManager = clientInternal.getHeartbeatManager(); if (hbManager != null && namespaceCapabilities.isWorkerHeartbeats()) { - String namespace = workflowClient.getOptions().getNamespace(); - String workerGroupingKey = clientInternal.getWorkerGroupingKey(); for (Worker worker : workers.values()) { Supplier heartbeatSupplier = worker.buildHeartbeatCallback(workerGroupingKey); @@ -338,6 +366,9 @@ public synchronized boolean isTerminated() { if (state != State.Shutdown) { return false; } + if (workerCommandWorker != null && !workerCommandWorker.isTerminated()) { + return false; + } for (Worker worker : workers.values()) { if (!worker.isTerminated()) { return false; @@ -366,7 +397,7 @@ public WorkflowClient getWorkflowClient() { * Invocation has no additional effect if already shut down. */ public synchronized void shutdown() { - log.info("shutdown: {}", this); + log.debug("shutdown: {}", this); shutdownInternal(false); } @@ -432,6 +463,11 @@ private void doShutdown(boolean interruptUserTasks) { shutdownFutures.add(futureHolder[0]); } } + if (workerCommandWorker != null) { + // TODO: Should be able to pass `interruptUserTasks` here when + // https://github.com/temporalio/api/pull/784 is in + shutdownFutures.add(workerCommandWorker.shutdown(shutdownManager, true)); + } CompletableFuture.allOf(shutdownFutures.toArray(new CompletableFuture[0])) .thenApply( @@ -450,6 +486,7 @@ private void doShutdown(boolean interruptUserTasks) { } cache.invalidateAll(); workflowThreadPool.shutdownNow(); + workerCommandWorker = null; return null; }) .whenComplete( @@ -468,7 +505,7 @@ private void doShutdown(boolean interruptUserTasks) { * occurs. 
*/ public void awaitTermination(long timeout, TimeUnit unit) { - log.info("awaitTermination begin: {}", this); + log.debug("awaitTermination begin: {}", this); long timeoutMillis = unit.toMillis(timeout); for (Worker worker : workers.values()) { long t = timeoutMillis; // closure needs immutable value @@ -476,7 +513,13 @@ public void awaitTermination(long timeout, TimeUnit unit) { ShutdownManager.runAndGetRemainingTimeoutMs( t, () -> worker.awaitTermination(t, TimeUnit.MILLISECONDS)); } - log.info("awaitTermination done: {}", this); + /* Read the field once: the completion callback in doShutdown sets it to null, and this method is not synchronized, so a second read inside the lambda could observe null. */ SuspendableWorker commandWorker = workerCommandWorker; + if (commandWorker != null) { + long t = timeoutMillis; + ShutdownManager.runAndGetRemainingTimeoutMs( + t, () -> commandWorker.awaitTermination(t, TimeUnit.MILLISECONDS)); + } + log.debug("awaitTermination done: {}", this); } // TODO we should hide an actual implementation of WorkerFactory under WorkerFactory interface and @@ -496,6 +539,9 @@ public synchronized void suspendPolling() { for (Worker worker : workers.values()) { worker.suspendPolling(); } + if (workerCommandWorker != null) { + workerCommandWorker.suspendPolling(); + } } public synchronized void resumePolling() { @@ -508,6 +554,9 @@ public synchronized void resumePolling() { for (Worker worker : workers.values()) { worker.resumePolling(); } + if (workerCommandWorker != null) { + workerCommandWorker.resumePolling(); + } } @Override diff --git a/temporal-sdk/src/test/java/io/temporal/internal/activity/HeartbeatContextImplTest.java b/temporal-sdk/src/test/java/io/temporal/internal/activity/HeartbeatContextImplTest.java index 686bc566f8..ac60202740 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/activity/HeartbeatContextImplTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/activity/HeartbeatContextImplTest.java @@ -9,10 +9,12 @@ import io.grpc.StatusRuntimeException; import io.temporal.activity.ActivityInfo; import io.temporal.api.enums.v1.TimeoutType; +import io.temporal.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest; import
io.temporal.api.workflowservice.v1.RecordActivityTaskHeartbeatResponse; import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc; import io.temporal.client.ActivityCanceledException; import io.temporal.client.ActivityCompletionException; +import io.temporal.client.WorkflowClient; import io.temporal.common.converter.GlobalDataConverter; import io.temporal.failure.TimeoutFailure; import io.temporal.serviceclient.WorkflowServiceStubs; @@ -25,6 +27,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; public class HeartbeatContextImplTest { @@ -183,7 +186,141 @@ public void heartbeatTimeoutPersistsAcrossMultipleCalls() { ctx.cancelOutstandingHeartbeat(); } + @Test + public void workerCommandCancelStillSendsHeartbeatDetails() { + when(blockingStub.recordActivityTaskHeartbeat(any())) + .thenReturn(RecordActivityTaskHeartbeatResponse.getDefaultInstance()); + + ActivityInfo info = activityInfoWithHeartbeatTimeout(Duration.ofSeconds(10)); + HeartbeatContextImpl ctx = + createHeartbeatContext(info, Duration.ofMillis(100), Duration.ofMillis(100)); + + assertFalse(ctx.getCancellationToken().isCancellationRequested()); + assertFalse(ctx.getCancellationToken().getCancellationRequest().isDone()); + + ctx.heartbeat("before-cancel"); + ctx.cancelFromWorkerCommand(); + + assertTrue(ctx.getCancellationToken().isCancellationRequested()); + assertTrue(ctx.getCancellationToken().getCancellationRequest().isDone()); + assertThrows( + ActivityCanceledException.class, + () -> ctx.getCancellationToken().throwIfCancellationRequested()); + + try { + ctx.heartbeat("after-cancel"); + fail("Expected ActivityCanceledException"); + } catch (ActivityCanceledException e) { + assertNull(e.getCause()); + } + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(RecordActivityTaskHeartbeatRequest.class); + verify(blockingStub, timeout(1000).times(2)) + .recordActivityTaskHeartbeat(requestCaptor.capture()); + String details = + 
GlobalDataConverter.get() + .fromPayloads( + 0, + Optional.of(requestCaptor.getAllValues().get(1).getDetails()), + String.class, + String.class); + assertEquals("after-cancel", details); + ctx.cancelOutstandingHeartbeat(); + } + + @Test + public void asyncCompletionRejectsNewHeartbeatsAndFlushesQueuedHeartbeat() { + when(blockingStub.recordActivityTaskHeartbeat(any())) + .thenReturn(RecordActivityTaskHeartbeatResponse.getDefaultInstance()); + + ActivityInfo info = activityInfoWithHeartbeatTimeout(Duration.ofSeconds(10)); + HeartbeatContextImpl ctx = + createHeartbeatContext(info, Duration.ofMillis(100), Duration.ofMillis(100)); + + ctx.heartbeat("sent-before-return"); + ctx.heartbeat("queued-before-return"); + ctx.asyncCompletionStarted(); + + assertTrue(ctx.getCancellationToken().isCancellationRequested()); + assertTrue(ctx.getCancellationToken().getCancellationRequest().isDone()); + assertThrows(ActivityCanceledException.class, () -> ctx.heartbeat("after-return")); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(RecordActivityTaskHeartbeatRequest.class); + verify(blockingStub, timeout(1000).times(2)) + .recordActivityTaskHeartbeat(requestCaptor.capture()); + String details = + GlobalDataConverter.get() + .fromPayloads( + 0, + Optional.of(requestCaptor.getAllValues().get(1).getDetails()), + String.class, + String.class); + assertEquals("queued-before-return", details); + ctx.cancelOutstandingHeartbeat(); + } + + @Test + public void heartbeatCancelCompletesCancellationToken() { + when(blockingStub.recordActivityTaskHeartbeat(any())) + .thenReturn( + RecordActivityTaskHeartbeatResponse.newBuilder().setCancelRequested(true).build()); + + ActivityInfo info = activityInfoWithHeartbeatTimeout(Duration.ofSeconds(10)); + HeartbeatContextImpl ctx = createHeartbeatContext(info); + + assertFalse(ctx.getCancellationToken().isCancellationRequested()); + assertFalse(ctx.getCancellationToken().getCancellationRequest().isDone()); + + 
assertThrows(ActivityCanceledException.class, () -> ctx.heartbeat("details")); + + assertTrue(ctx.getCancellationToken().isCancellationRequested()); + assertTrue(ctx.getCancellationToken().getCancellationRequest().isDone()); + assertThrows( + ActivityCanceledException.class, + () -> ctx.getCancellationToken().throwIfCancellationRequested()); + + ctx.cancelOutstandingHeartbeat(); + } + + @Test + public void factoryCancelByTaskTokenCompletesCancellationToken() { + WorkflowClient client = mock(WorkflowClient.class); + when(client.getWorkflowServiceStubs()).thenReturn(service); + + ActivityExecutionContextFactoryImpl factory = + new ActivityExecutionContextFactoryImpl( + client, + "test-identity", + "test-namespace", + Duration.ofSeconds(60), + Duration.ofSeconds(30), + GlobalDataConverter.get(), + heartbeatExecutor); + + ActivityInfoInternal info = activityInfoWithHeartbeatTimeout(Duration.ofSeconds(10)); + InternalActivityExecutionContext context = + factory.createContext(info, new Object(), new NoopScope()); + + assertFalse(context.getCancellationToken().isCancellationRequested()); + assertFalse(factory.cleanupContext(new byte[] {9, 8, 7}, true)); + assertTrue(factory.cleanupContext(new byte[] {1, 2, 3}, true)); + assertTrue(context.getCancellationToken().isCancellationRequested()); + assertTrue(context.getCancellationToken().getCancellationRequest().isDone()); + + context.cancelOutstandingHeartbeat(); + assertFalse(factory.cleanupContext(new byte[] {1, 2, 3}, true)); + } + private HeartbeatContextImpl createHeartbeatContext(ActivityInfo info) { + return createHeartbeatContext(info, Duration.ofSeconds(60), Duration.ofSeconds(30)); + } + + private HeartbeatContextImpl createHeartbeatContext( + ActivityInfo info, + Duration maxHeartbeatThrottleInterval, + Duration defaultHeartbeatThrottleInterval) { return new HeartbeatContextImpl( service, "test-namespace", @@ -192,13 +329,13 @@ private HeartbeatContextImpl createHeartbeatContext(ActivityInfo info) { 
heartbeatExecutor, new NoopScope(), "test-identity", - Duration.ofSeconds(60), - Duration.ofSeconds(30), + maxHeartbeatThrottleInterval, + defaultHeartbeatThrottleInterval, TEST_BUFFER_MILLIS); } - private static ActivityInfo activityInfoWithHeartbeatTimeout(Duration heartbeatTimeout) { - ActivityInfo info = mock(ActivityInfo.class); + private static ActivityInfoInternal activityInfoWithHeartbeatTimeout(Duration heartbeatTimeout) { + ActivityInfoInternal info = mock(ActivityInfoInternal.class); when(info.getHeartbeatTimeout()).thenReturn(heartbeatTimeout); when(info.getTaskToken()).thenReturn(new byte[] {1, 2, 3}); when(info.getWorkflowId()).thenReturn("test-workflow-id"); @@ -208,6 +345,7 @@ private static ActivityInfo activityInfoWithHeartbeatTimeout(Duration heartbeatT when(info.getActivityId()).thenReturn("test-activity-id"); when(info.isLocal()).thenReturn(false); when(info.getHeartbeatDetails()).thenReturn(Optional.empty()); + when(info.getCompletionHandle()).thenReturn(() -> {}); return info; } } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java index c6f11a61a1..1017a76431 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java @@ -87,7 +87,8 @@ public void supplierIsCalledAppropriately() { metricsScope, () -> GetSystemInfoResponse.Capabilities.newBuilder().build(), new PollerTracker(), - new PollerTracker()); + new PollerTracker(), + null); PollWorkflowTaskQueueResponse pollResponse = PollWorkflowTaskQueueResponse.newBuilder() @@ -178,7 +179,8 @@ public void asyncPollerSupplierIsCalledAppropriately() throws Exception { trackingSS, metricsScope, () -> GetSystemInfoResponse.Capabilities.newBuilder().build(), - new PollerTracker()); + new PollerTracker(), + null); SlotPermit permit = new SlotPermit(); diff --git 
a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java index ab806c960b..5a29a74054 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java @@ -75,7 +75,8 @@ public void stickyQueueBacklogResetTest() { metricsScope, () -> GetSystemInfoResponse.Capabilities.newBuilder().build(), new PollerTracker(), - new PollerTracker()); + new PollerTracker(), + null); PollWorkflowTaskQueueResponse pollResponse = PollWorkflowTaskQueueResponse.newBuilder() diff --git a/temporal-sdk/src/test/java/io/temporal/worker/WorkerHeartbeatIntegrationTest.java b/temporal-sdk/src/test/java/io/temporal/worker/WorkerHeartbeatIntegrationTest.java index 2684180542..d49a8d70f6 100644 --- a/temporal-sdk/src/test/java/io/temporal/worker/WorkerHeartbeatIntegrationTest.java +++ b/temporal-sdk/src/test/java/io/temporal/worker/WorkerHeartbeatIntegrationTest.java @@ -748,6 +748,7 @@ private List listWorkersForQueue(String taskQueue) { .setNamespace( testWorkflowRule.getWorkflowClient().getOptions().getNamespace()) .setQuery("TaskQueue = \"" + taskQueue + "\"") + .setIncludeSystemWorkers(false) .setPageSize(200) .build()); return resp.getWorkersInfoList().stream() diff --git a/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java index e9f4c9a361..23a63cda8b 100644 --- a/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java +++ b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java @@ -118,6 +118,7 @@ public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { wfThreadExecutor, Collections.emptyList(), Collections.emptyList(), + "test-worker-group", new NamespaceCapabilities()); // Register types AFTER worker construction. 
The request built by shutdown should reflect diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityCompleteWithErrorTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityCompleteWithErrorTest.java index d999909c7f..dd43ded763 100644 --- a/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityCompleteWithErrorTest.java +++ b/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityCompleteWithErrorTest.java @@ -12,17 +12,21 @@ import io.temporal.workflow.WorkflowMethod; import java.time.Duration; import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; public class AsyncActivityCompleteWithErrorTest { + private final AsyncActivityWithManualCompletion activities = + new AsyncActivityWithManualCompletion(); @Rule public SDKTestWorkflowRule testWorkflowRule = SDKTestWorkflowRule.newBuilder() .setWorkflowTypes(TestWorkflowImpl.class) - .setActivityImplementations(new AsyncActivityWithManualCompletion()) + .setActivityImplementations(activities) .build(); @WorkflowInterface @@ -42,6 +46,7 @@ public String execute(String taskQueue) { ActivityOptions.newBuilder() .setScheduleToStartTimeout(Duration.ofSeconds(1)) .setScheduleToCloseTimeout(Duration.ofSeconds(1)) + .setHeartbeatTimeout(Duration.ofSeconds(1)) .setRetryOptions(RetryOptions.newBuilder().setMaximumAttempts(1).build()) .build()); Promise promise = Async.function(activity::execute); @@ -64,15 +69,31 @@ public interface TestActivity { } public static class AsyncActivityWithManualCompletion implements TestActivity { + private final AtomicBoolean postReturnHeartbeatSucceeded = new AtomicBoolean(); + private final AtomicBoolean postReturnTokenCanceled = new AtomicBoolean(); + private final AtomicReference postReturnHeartbeatFailure = new 
AtomicReference<>(); + @Override public int execute() { ActivityExecutionContext context = Activity.getExecutionContext(); ManualActivityCompletionClient completionClient = context.useLocalManualCompletion(); - ForkJoinPool.commonPool().execute(() -> asyncActivityFn(completionClient)); + ForkJoinPool.commonPool().execute(() -> asyncActivityFn(context, completionClient)); return 0; } - private void asyncActivityFn(ManualActivityCompletionClient completionClient) { + private void asyncActivityFn( + ActivityExecutionContext context, ManualActivityCompletionClient completionClient) { + try { + Thread.sleep(100); + postReturnTokenCanceled.set(context.getCancellationToken().isCancellationRequested()); + context.heartbeat("after-local-manual-return"); + postReturnHeartbeatSucceeded.set(true); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + postReturnHeartbeatFailure.set(e); + } catch (Throwable e) { + postReturnHeartbeatFailure.set(e); + } completionClient.fail( ApplicationFailure.newFailure("simulated failure", "test", "some details")); } @@ -84,5 +105,8 @@ public void verifyActivityCompletionClientCompleteExceptionally() { TestWorkflow workflow = testWorkflowRule.newWorkflowStub(TestWorkflow.class); String result = workflow.execute(taskQueue); Assert.assertEquals("success", result); + Assert.assertNull(activities.postReturnHeartbeatFailure.get()); + Assert.assertTrue(activities.postReturnHeartbeatSucceeded.get()); + Assert.assertFalse(activities.postReturnTokenCanceled.get()); } } diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityWithCompletionClientTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityWithCompletionClientTest.java index ef666618d5..3e9e48362d 100644 --- a/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityWithCompletionClientTest.java +++ 
b/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/AsyncActivityWithCompletionClientTest.java @@ -33,12 +33,16 @@ public void tearDown() throws Exception { @Test public void testAsyncActivity() { + completionClientActivitiesImpl.activity1AsyncCompletionTokenCanceled.set(false); + completionClientActivitiesImpl.activity1PostReturnHeartbeatRejected.set(false); completionClientActivitiesImpl.completionClient = testWorkflowRule.getWorkflowClient().newActivityCompletionClient(); TestWorkflow1 client = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class); String result = client.execute(testWorkflowRule.getTaskQueue()); Assert.assertEquals("workflow", result); Assert.assertEquals("activity1", completionClientActivitiesImpl.invocations.get(0)); + Assert.assertTrue(completionClientActivitiesImpl.activity1AsyncCompletionTokenCanceled.get()); + Assert.assertTrue(completionClientActivitiesImpl.activity1PostReturnHeartbeatRejected.get()); } public static class TestAsyncActivityWorkflowImpl implements TestWorkflow1 { diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/cancellation/ActivityCancellationTokenIntegrationTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/cancellation/ActivityCancellationTokenIntegrationTest.java new file mode 100644 index 0000000000..a120455337 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/workflow/activityTests/cancellation/ActivityCancellationTokenIntegrationTest.java @@ -0,0 +1,158 @@ +package io.temporal.workflow.activityTests.cancellation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; + +import io.temporal.activity.Activity; +import io.temporal.activity.ActivityCancellationType; +import io.temporal.activity.ActivityExecutionContext; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityOptions; +import io.temporal.api.workflowservice.v1.DescribeNamespaceRequest; +import 
io.temporal.api.workflowservice.v1.DescribeNamespaceResponse; +import io.temporal.client.ActivityCanceledException; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.failure.ActivityFailure; +import io.temporal.failure.CanceledFailure; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.workflow.Async; +import io.temporal.workflow.CancellationScope; +import io.temporal.workflow.Promise; +import io.temporal.workflow.SignalMethod; +import io.temporal.workflow.Workflow; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +public class ActivityCancellationTokenIntegrationTest { + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setTestTimeoutSeconds(30) + .setWorkflowClientOptions( + WorkflowClientOptions.newBuilder() + .setWorkerHeartbeatInterval(Duration.ofSeconds(1)) + .build()) + .setWorkflowTypes(TestCancellationWorkflowImpl.class) + .setActivityImplementations(new NonHeartbeatingActivityImpl()) + .build(); + + @Before + public void checkServerSupportsWorkerCommands() { + assumeTrue( + "Requires real server with worker command support", SDKTestWorkflowRule.useExternalService); + + DescribeNamespaceResponse response = + testWorkflowRule + .getWorkflowClient() + .getWorkflowServiceStubs() + .blockingStub() + .describeNamespace( + DescribeNamespaceRequest.newBuilder() + .setNamespace(testWorkflowRule.getWorkflowClient().getOptions().getNamespace()) + .build()); + assumeTrue( + "Server does not support worker heartbeats", + response.getNamespaceInfo().getCapabilities().getWorkerHeartbeats()); + assumeTrue( + "Server does not support worker commands", + 
response.getNamespaceInfo().getCapabilities().getWorkerCommands()); + } + + @Test + public void activityObservesCancellationWithoutHeartbeat() { + TestCancellationWorkflow workflow = + testWorkflowRule.newWorkflowStub(TestCancellationWorkflow.class); + + assertEquals("cancelled", workflow.execute(testWorkflowRule.getTaskQueue())); + } + + @WorkflowInterface + public interface TestCancellationWorkflow { + @WorkflowMethod + String execute(String taskQueue); + + @SignalMethod + void activityStarted(); + } + + @ActivityInterface + public interface NonHeartbeatingActivity { + String waitForCancellation(); + } + + public static class TestCancellationWorkflowImpl implements TestCancellationWorkflow { + private boolean activityStarted; + + @Override + public String execute(String taskQueue) { + NonHeartbeatingActivity activity = + Workflow.newActivityStub( + NonHeartbeatingActivity.class, + ActivityOptions.newBuilder() + .setTaskQueue(taskQueue) + .setScheduleToCloseTimeout(Duration.ofSeconds(20)) + .setStartToCloseTimeout(Duration.ofSeconds(20)) + .setCancellationType(ActivityCancellationType.WAIT_CANCELLATION_COMPLETED) + .setDisableEagerExecution(true) + .build()); + + List> activityResults = new ArrayList<>(); + CancellationScope cancellationScope = + Workflow.newCancellationScope( + () -> activityResults.add(Async.function(activity::waitForCancellation))); + + cancellationScope.run(); + Workflow.await(() -> activityStarted); + cancellationScope.cancel(); + + try { + activityResults.get(0).get(); + return "completed"; + } catch (ActivityFailure e) { + if (e.getCause() instanceof CanceledFailure) { + return "cancelled"; + } + throw e; + } + } + + @Override + public void activityStarted() { + activityStarted = true; + } + } + + public static class NonHeartbeatingActivityImpl implements NonHeartbeatingActivity { + @Override + public String waitForCancellation() { + ActivityExecutionContext context = Activity.getExecutionContext(); + context + .getWorkflowClient() + 
.newWorkflowStub(TestCancellationWorkflow.class, context.getInfo().getWorkflowId()) + .activityStarted(); + + try { + context.getCancellationToken().getCancellationRequest().get(20, TimeUnit.SECONDS); + context.getCancellationToken().throwIfCancellationRequested(); + return "not-cancelled"; + } catch (ActivityCanceledException e) { + throw e; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException | TimeoutException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/shared/TestActivities.java b/temporal-sdk/src/test/java/io/temporal/workflow/shared/TestActivities.java index d80a7ece8f..2c68c576f1 100644 --- a/temporal-sdk/src/test/java/io/temporal/workflow/shared/TestActivities.java +++ b/temporal-sdk/src/test/java/io/temporal/workflow/shared/TestActivities.java @@ -25,6 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; public class TestActivities { @@ -407,6 +408,8 @@ public int getLastAttempt() { public static class CompletionClientActivitiesImpl implements CompletionClientActivities, Closeable { public final List invocations = Collections.synchronizedList(new ArrayList<>()); + public final AtomicBoolean activity1AsyncCompletionTokenCanceled = new AtomicBoolean(); + public final AtomicBoolean activity1PostReturnHeartbeatRejected = new AtomicBoolean(); private final ThreadPoolExecutor executor = new ThreadPoolExecutor(0, 100, 1, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); public ActivityCompletionClient completionClient; @@ -422,13 +425,31 @@ public void assertInvocations(String... 
expected) { @Override public String activity1(String a1) { Preconditions.checkNotNull(completionClient, "completionClient"); - byte[] taskToken = Activity.getExecutionContext().getInfo().getTaskToken(); + ActivityExecutionContext ctx = Activity.getExecutionContext(); + byte[] taskToken = ctx.getInfo().getTaskToken(); executor.execute( () -> { invocations.add("activity1"); + long deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(5); + while (!ctx.getCancellationToken().isCancellationRequested() + && System.currentTimeMillis() < deadline) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + activity1AsyncCompletionTokenCanceled.set( + ctx.getCancellationToken().isCancellationRequested()); + try { + ctx.heartbeat("after-async-return"); + } catch (ActivityCanceledException e) { + activity1PostReturnHeartbeatRejected.set(true); + } completionClient.complete(taskToken, a1); }); - Activity.getExecutionContext().doNotCompleteOnReturn(); + ctx.doNotCompleteOnReturn(); return "ignored"; } diff --git a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/CloudServiceStubsImpl.java b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/CloudServiceStubsImpl.java index be5874160b..0f126cad9c 100644 --- a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/CloudServiceStubsImpl.java +++ b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/CloudServiceStubsImpl.java @@ -39,7 +39,7 @@ final class CloudServiceStubsImpl implements CloudServiceStubs { .setInternalErrorDifferentiation(true) .build()); - log.info("Created CloudServiceStubs for channel: {}", channelManager.getRawChannel()); + log.debug("Created CloudServiceStubs for channel: {}", channelManager.getRawChannel()); this.blockingStub = CloudServiceGrpc.newBlockingStub(channelManager.getInterceptedChannel()); this.futureStub = CloudServiceGrpc.newFutureStub(channelManager.getInterceptedChannel()); 
diff --git a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/OperatorServiceStubsImpl.java b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/OperatorServiceStubsImpl.java index ad13548786..8ee6c6d01d 100644 --- a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/OperatorServiceStubsImpl.java +++ b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/OperatorServiceStubsImpl.java @@ -32,7 +32,7 @@ final class OperatorServiceStubsImpl implements OperatorServiceStubs { this.channelManager = new ChannelManager(options, Collections.singletonList(deadlineInterceptor)); - log.info("Created OperatorServiceStubs for channel: {}", channelManager.getRawChannel()); + log.debug("Created OperatorServiceStubs for channel: {}", channelManager.getRawChannel()); this.blockingStub = OperatorServiceGrpc.newBlockingStub(channelManager.getInterceptedChannel()); this.futureStub = OperatorServiceGrpc.newFutureStub(channelManager.getInterceptedChannel()); diff --git a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/WorkflowServiceStubsImpl.java b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/WorkflowServiceStubsImpl.java index bfeb3b533e..4edc11bd4e 100644 --- a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/WorkflowServiceStubsImpl.java +++ b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/WorkflowServiceStubsImpl.java @@ -63,7 +63,7 @@ final class WorkflowServiceStubsImpl implements WorkflowServiceStubs { this.channelManager = new ChannelManager(this.options, Collections.singletonList(deadlineInterceptor)); - log.info( + log.debug( String.format( "Created WorkflowServiceStubs for channel: %s", channelManager.getRawChannel())); diff --git a/temporal-test-server/src/main/java/io/temporal/serviceclient/TestServiceStubsImpl.java b/temporal-test-server/src/main/java/io/temporal/serviceclient/TestServiceStubsImpl.java index e847118013..9f62a6089f 100644 --- 
a/temporal-test-server/src/main/java/io/temporal/serviceclient/TestServiceStubsImpl.java +++ b/temporal-test-server/src/main/java/io/temporal/serviceclient/TestServiceStubsImpl.java @@ -32,7 +32,7 @@ public class TestServiceStubsImpl implements TestServiceStubs { this.channelManager = new ChannelManager(options, Collections.singletonList(deadlineInterceptor)); - log.info("Created TestServiceStubs for channel: {}", channelManager.getRawChannel()); + log.debug("Created TestServiceStubs for channel: {}", channelManager.getRawChannel()); this.blockingStub = TestServiceGrpc.newBlockingStub(channelManager.getInterceptedChannel()); this.futureStub = TestServiceGrpc.newFutureStub(channelManager.getInterceptedChannel());