From 0cf75bb210cb6ed4d87007f30e9a44726e2b132d Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 5 Jun 2026 15:01:24 -0700 Subject: [PATCH 1/5] Message and concurrent payload visitors --- temporal-sdk/build.gradle | 29 + .../payload/visitor/GeneratedVisitor.java | 11 + .../payload/visitor/MessageRegistryEntry.java | 18 + .../payload/visitor/MessageVisitor.java | 23 + .../visitor/MessageVisitorOptions.java | 59 ++ .../payload/visitor/MessageVisitors.java | 43 + .../payload/visitor/PayloadVisitor.java | 26 + .../visitor/PayloadVisitorContext.java | 33 + .../visitor/PayloadVisitorOptions.java | 141 +++ .../payload/visitor/PayloadVisitors.java | 54 + .../internal/payload/visitor/Traversal.java | 237 +++++ .../payload/visitor/VisitorException.java | 15 + .../visitor/gen/PayloadVisitorGenerator.java | 593 +++++++++++ .../payload/visitor/gen/ProtoClosure.java | 138 +++ .../payload/visitor/MessageVisitorTest.java | 192 ++++ .../payload/visitor/PayloadVisitorTest.java | 951 ++++++++++++++++++ .../payload/visitor/TestVisitorException.java | 12 + 17 files changed, 2575 insertions(+) create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java create mode 100644 temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java create mode 100644 temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/ProtoClosure.java create mode 100644 temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java create mode 100644 temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java create mode 100644 temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java diff --git a/temporal-sdk/build.gradle b/temporal-sdk/build.gradle index 9e914e31f4..6b54852ed7 100644 --- a/temporal-sdk/build.gradle +++ b/temporal-sdk/build.gradle @@ -65,6 +65,35 @@ dependencies { java21Implementation files(sourceSets.main.output.classesDirs) { builtBy compileJava } } +// --- Payload visitor code generation --- +// A build-time generator (compiled in its own source set against the proto classes from +// temporal-serviceclient) emits GeneratedPayloadVisitor.java, which knows how to walk every +// payload-bearing Temporal API message. The generated source is added to the main source set. +sourceSets { + payloadVisitorGenerator { + java { + srcDirs = ['src/payloadVisitorGenerator/java'] + } + } +} + +dependencies { + payloadVisitorGeneratorImplementation project(':temporal-serviceclient') +} + +def generatedPayloadVisitorDir = layout.buildDirectory.dir('generated/payloadvisitor/java') + +def generatePayloadVisitor = tasks.register('generatePayloadVisitor', JavaExec) { + dependsOn 'compilePayloadVisitorGeneratorJava' + classpath = sourceSets.payloadVisitorGenerator.runtimeClasspath + mainClass = 'io.temporal.internal.payload.visitor.gen.PayloadVisitorGenerator' + args generatedPayloadVisitorDir.get().asFile.absolutePath + inputs.files(sourceSets.payloadVisitorGenerator.runtimeClasspath) + outputs.dir(generatedPayloadVisitorDir) +} + +sourceSets.main.java.srcDir(generatePayloadVisitor) + tasks.named('compileJava17Java') { options.release = 17 } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java new file mode 100644 index 0000000000..4e8325ba54 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java @@ -0,0 +1,11 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; + +/** + * Generated traversal for one message type: visits the message's payload fields and recurses into + * its child messages. There is one per message type that can contain a payload. + */ +interface GeneratedVisitor { + void visit(Traversal traversal, Message.Builder builder); +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java new file mode 100644 index 0000000000..2510878c62 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java @@ -0,0 +1,18 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; +import java.util.function.Supplier; + +/** + * How to traverse one message type, and how to create an empty builder for it (used to unpack + * {@code google.protobuf.Any} values). + */ +final class MessageRegistryEntry { + final GeneratedVisitor visitor; + final Supplier newBuilder; + + MessageRegistryEntry(GeneratedVisitor visitor, Supplier newBuilder) { + this.visitor = visitor; + this.newBuilder = newBuilder; + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java new file mode 100644 index 0000000000..21268e41d7 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java @@ -0,0 +1,23 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.MessageOrBuilder; + +/** + * Callback invoked when traversal enters a proto message. The returned value becomes the contextual + * value in scope for that message and everything within it, and is restored to the enclosing value + * once traversal leaves the message. The message is provided as a builder and may be inspected or + * mutated. + * + * @param type of the contextual value + */ +@FunctionalInterface +interface MessageVisitor { + /** + * Handles a message being entered and returns the contextual value for it and its contents. + * + * @param current the contextual value in scope from the enclosing message + * @param message the message being entered + * @return the contextual value to use for this message and its contents + */ + C onEnter(C current, MessageOrBuilder message); +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java new file mode 100644 index 0000000000..350c75a793 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java @@ -0,0 +1,59 @@ +package io.temporal.internal.payload.visitor; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Options for visiting the messages of a proto message, without visiting individual payloads. + * + * @param type of the contextual value supplied to the visitor + */ +final class MessageVisitorOptions { + private final @Nonnull MessageVisitor messageVisitor; + private final @Nullable C initialContext; + + private MessageVisitorOptions(Builder b) { + this.messageVisitor = b.messageVisitor; + this.initialContext = b.initialContext; + } + + public static Builder newBuilder() { + return new Builder<>(); + } + + @Nonnull + public MessageVisitor getMessageVisitor() { + return messageVisitor; + } + + @Nullable + public C getInitialContext() { + return initialContext; + } + + public static final class Builder { + private MessageVisitor messageVisitor; + private C initialContext; + + private Builder() {} + + /** Required. The message visitor. */ + public Builder setMessageVisitor(@Nonnull MessageVisitor messageVisitor) { + this.messageVisitor = messageVisitor; + return this; + } + + /** Optional. The contextual value in scope before any message is entered. */ + public Builder setInitialContext(@Nullable C initialContext) { + this.initialContext = initialContext; + return this; + } + + public MessageVisitorOptions build() { + if (messageVisitor == null) { + throw new IllegalArgumentException("messageVisitor is required"); + } + return new MessageVisitorOptions<>(this); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java new file mode 100644 index 0000000000..5a4de476b0 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java @@ -0,0 +1,43 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; +import javax.annotation.Nonnull; + +/** + * Visits the messages within a proto message, invoking the message visitor on each, without + * visiting individual payloads. Only messages that can contain a payload are visited. + * + *

This is an SDK-internal utility; it is not part of the public API. + */ +final class MessageVisitors { + private MessageVisitors() {} + + /** Visits the messages in {@code builder} in place. */ + public static void visit( + @Nonnull Message.Builder builder, @Nonnull MessageVisitorOptions options) { + Traversal traversal = + new Traversal( + null, + options.getMessageVisitor(), + options.getInitialContext(), + /* skipSearchAttributes= */ false, + /* skipHeaders= */ false, + 1, + null, + GeneratedPayloadVisitor.REGISTRY); + traversal.dispatch(builder); + traversal.execute(); + } + + /** + * Visits the messages in {@code message}, returning a copy with any changes applied; the input is + * unchanged. + */ + @SuppressWarnings("unchecked") + public static T visit( + @Nonnull T message, @Nonnull MessageVisitorOptions options) { + Message.Builder builder = message.toBuilder(); + visit(builder, options); + return (T) builder.build(); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java new file mode 100644 index 0000000000..2f791aeb36 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java @@ -0,0 +1,26 @@ +package io.temporal.internal.payload.visitor; + +import io.temporal.api.common.v1.Payload; +import java.util.List; + +/** + * Callback for a sequence of payloads found in a proto message. The returned list replaces those + * payloads; return the same list to leave them unchanged. + * + *

When the visited field holds a single payload the list has one element and the visitor must + * return exactly one payload. With a concurrency limit greater than one, visits may run on multiple + * threads, so implementations must be thread-safe. + * + * @param type of the contextual value supplied to each visit + */ +@FunctionalInterface +interface PayloadVisitor { + /** + * Visits a sequence of payloads and returns their replacements. + * + * @param context the location of these payloads and the contextual value in scope + * @param payloads the payloads found at this location + * @return the replacement payloads + */ + List visit(PayloadVisitorContext context, List payloads); +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java new file mode 100644 index 0000000000..756b1404d7 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java @@ -0,0 +1,33 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.MessageOrBuilder; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * The context for one payload visitor call: the contextual value in scope and the message that + * contains the payloads being visited. + * + * @param type of the contextual value + */ +final class PayloadVisitorContext { + private final @Nullable C context; + private final @Nonnull MessageOrBuilder parent; + + PayloadVisitorContext(@Nullable C context, @Nonnull MessageOrBuilder parent) { + this.context = context; + this.parent = parent; + } + + /** The contextual value in scope at this location, or {@code null} if none. */ + @Nullable + public C getContext() { + return context; + } + + /** The message that directly contains the payloads being visited. */ + @Nonnull + public MessageOrBuilder getParent() { + return parent; + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java new file mode 100644 index 0000000000..fc9f6b9ecd --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java @@ -0,0 +1,141 @@ +package io.temporal.internal.payload.visitor; + +import java.util.concurrent.Executor; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Options for visiting the payloads of a proto message. + * + * @param type of the contextual value supplied to the visitor + */ +final class PayloadVisitorOptions { + private final @Nonnull PayloadVisitor payloadVisitor; + private final @Nullable MessageVisitor messageVisitor; + private final @Nullable C initialContext; + private final boolean skipSearchAttributes; + private final boolean skipHeaders; + private final int concurrency; + private final @Nullable Executor executor; + + private PayloadVisitorOptions(Builder b) { + this.payloadVisitor = b.payloadVisitor; + this.messageVisitor = b.messageVisitor; + this.initialContext = b.initialContext; + this.skipSearchAttributes = b.skipSearchAttributes; + this.skipHeaders = b.skipHeaders; + this.concurrency = b.concurrency; + this.executor = b.executor; + } + + public static Builder newBuilder() { + return new Builder<>(); + } + + @Nonnull + public PayloadVisitor getPayloadVisitor() { + return payloadVisitor; + } + + @Nullable + public MessageVisitor getMessageVisitor() { + return messageVisitor; + } + + @Nullable + public C getInitialContext() { + return initialContext; + } + + /** Whether search attribute payloads are skipped. */ + public boolean isSkipSearchAttributes() { + return skipSearchAttributes; + } + + /** Whether header payloads are skipped. */ + public boolean isSkipHeaders() { + return skipHeaders; + } + + /** Maximum number of visits that may run concurrently; {@code 1} is sequential. */ + public int getConcurrency() { + return concurrency; + } + + /** Executor for concurrent visits; {@code null} when concurrency is {@code 1}. */ + @Nullable + public Executor getExecutor() { + return executor; + } + + public static final class Builder { + private PayloadVisitor payloadVisitor; + private MessageVisitor messageVisitor; + private C initialContext; + private boolean skipSearchAttributes; + private boolean skipHeaders; + private int concurrency = 1; + private Executor executor; + + private Builder() {} + + /** Required. The payload visitor. */ + public Builder setPayloadVisitor(@Nonnull PayloadVisitor payloadVisitor) { + this.payloadVisitor = payloadVisitor; + return this; + } + + /** Optional. A callback invoked when entering each message. */ + public Builder setMessageVisitor(@Nullable MessageVisitor messageVisitor) { + this.messageVisitor = messageVisitor; + return this; + } + + /** Optional. The contextual value in scope before any message is entered. */ + public Builder setInitialContext(@Nullable C initialContext) { + this.initialContext = initialContext; + return this; + } + + /** Whether to skip search attribute payloads. */ + public Builder setSkipSearchAttributes(boolean skipSearchAttributes) { + this.skipSearchAttributes = skipSearchAttributes; + return this; + } + + /** Whether to skip header payloads. */ + public Builder setSkipHeaders(boolean skipHeaders) { + this.skipHeaders = skipHeaders; + return this; + } + + /** + * Maximum number of concurrent visits; must be at least {@code 1} (the default, sequential). A + * value greater than {@code 1} requires an executor (see {@link #setExecutor}). + */ + public Builder setConcurrency(int concurrency) { + this.concurrency = concurrency; + return this; + } + + /** Executor for concurrent visits. Required when concurrency is greater than {@code 1}. */ + public Builder setExecutor(@Nullable Executor executor) { + this.executor = executor; + return this; + } + + public PayloadVisitorOptions build() { + if (payloadVisitor == null) { + throw new IllegalArgumentException("payloadVisitor is required"); + } + if (concurrency < 1) { + throw new IllegalArgumentException("concurrency must be at least 1, got " + concurrency); + } + if (concurrency > 1 && executor == null) { + throw new IllegalArgumentException( + "executor is required when concurrency is greater than 1"); + } + return new PayloadVisitorOptions<>(this); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java new file mode 100644 index 0000000000..b810976ef0 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java @@ -0,0 +1,54 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; +import javax.annotation.Nonnull; + +/** + * Visits every payload within a proto message. A message with no payloads is returned unchanged. + * + *

This is an SDK-internal utility; it is not part of the public API. + * + *

{@code
+ * RespondWorkflowTaskCompletedRequest result =
+ *     PayloadVisitors.visit(
+ *         request,
+ *         PayloadVisitorOptions.newBuilder()
+ *             .setPayloadVisitor((ctx, payloads) -> encode(ctx.getContext(), payloads))
+ *             .setMessageVisitor((current, msg) -> msg instanceof Command.Builder
+ *                 ? CommandInfo.of((Command.Builder) msg) : current)
+ *             .setConcurrency(4)
+ *             .build());
+ * }
+ */ +final class PayloadVisitors { + private PayloadVisitors() {} + + /** Visits the payloads in {@code builder} in place. */ + public static void visit( + @Nonnull Message.Builder builder, @Nonnull PayloadVisitorOptions options) { + Traversal traversal = + new Traversal( + options.getPayloadVisitor(), + options.getMessageVisitor(), + options.getInitialContext(), + options.isSkipSearchAttributes(), + options.isSkipHeaders(), + options.getConcurrency(), + options.getExecutor(), + GeneratedPayloadVisitor.REGISTRY); + traversal.dispatch(builder); + traversal.execute(); + } + + /** + * Visits the payloads in {@code message}, returning a copy with replacements applied; the input + * is unchanged. + */ + @SuppressWarnings("unchecked") + public static T visit( + @Nonnull T message, @Nonnull PayloadVisitorOptions options) { + Message.Builder builder = message.toBuilder(); + visit(builder, options); + return (T) builder.build(); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java new file mode 100644 index 0000000000..a9922f64b1 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java @@ -0,0 +1,237 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.MessageOrBuilder; +import io.temporal.api.common.v1.Payload; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +/** + * Mutable state for one traversal, called into by the generated per-message visitors. + * + *

A single-threaded walk records a visit job and a write-back for each payload sequence it + * finds; {@link #execute()} then runs the visitor calls (optionally with bounded concurrency) and + * finally applies the write-backs in walk order, so the non-thread-safe builders are never mutated + * concurrently. + */ +final class Traversal { + // The payload visitor is null for a message-only traversal (see MessageVisitors); in that case + // the payload seams are skipped and only the per-message MessageVisitor fires. + private final PayloadVisitor payloadVisitor; + private final MessageVisitor messageVisitor; + private final Map registry; + final boolean skipSearchAttributes; + final boolean skipHeaders; + private final int concurrency; + private final Executor executor; + + private final List jobs = new ArrayList<>(); + private final List writeBacks = new ArrayList<>(); + private Object currentContext; + + @SuppressWarnings("unchecked") + Traversal( + PayloadVisitor payloadVisitor, + MessageVisitor messageVisitor, + Object initialContext, + boolean skipSearchAttributes, + boolean skipHeaders, + int concurrency, + Executor executor, + Map registry) { + if (concurrency < 1) { + throw new IllegalArgumentException("concurrency must be at least 1, got " + concurrency); + } + if (concurrency > 1 && executor == null) { + throw new IllegalArgumentException("executor is required when concurrency is greater than 1"); + } + this.payloadVisitor = (PayloadVisitor) payloadVisitor; + this.messageVisitor = (MessageVisitor) messageVisitor; + this.currentContext = initialContext; + this.skipSearchAttributes = skipSearchAttributes; + this.skipHeaders = skipHeaders; + this.concurrency = concurrency; + this.executor = executor; + this.registry = registry; + } + + // --- Structural walk: called by generated code --- + + /** Dispatch to the generated visitor for {@code builder}'s type; no-op if it has no payloads. */ + void dispatch(Message.Builder builder) { + MessageRegistryEntry entry = registry.get(builder.getDescriptorForType().getFullName()); + if (entry != null) { + entry.visitor.visit(this, builder); + } + } + + /** + * Run the message visitor for {@code message}, narrowing the scoped context; returns the value to + * restore. + */ + Object enter(MessageOrBuilder message) { + Object previous = currentContext; + if (messageVisitor != null) { + currentContext = messageVisitor.onEnter(previous, message); + } + return previous; + } + + /** Restore the scoped context to {@code previous} when leaving a message's subtree. */ + void exit(Object previous) { + currentContext = previous; + } + + /** Record a visit of a payload sequence ({@code Payloads} or {@code repeated Payload}). */ + void payloads(MessageOrBuilder parent, List batch, Consumer> writeBack) { + if (payloadVisitor == null) { + return; // message-only traversal: payload seams are inert + } + LeafJob job = new LeafJob(batch, currentContext, parent, false); + jobs.add(job); + writeBacks.add(() -> writeBack.accept(job.result)); + } + + /** + * Record a visit of a singular payload field. The visitor must return exactly one payload for + * such a field (enforced in {@link #runJob}), which the consumer writes back. + */ + void singlePayload(MessageOrBuilder parent, Payload value, Consumer writeBack) { + if (payloadVisitor == null) { + return; // message-only traversal: payload seams are inert + } + LeafJob job = new LeafJob(Collections.singletonList(value), currentContext, parent, true); + jobs.add(job); + writeBacks.add(() -> writeBack.accept(job.result.get(0))); + } + + /** Append a deferred write-back, applied (single-threaded) after all visits and in walk order. */ + void deferWriteBack(Runnable writeBack) { + writeBacks.add(writeBack); + } + + /** Unpack a {@code google.protobuf.Any}, traverse its contents, and re-pack it after visits. */ + void any(Any.Builder anyBuilder) { + String typeUrl = anyBuilder.getTypeUrl(); + int slash = typeUrl.lastIndexOf('/'); + String fullName = slash >= 0 ? typeUrl.substring(slash + 1) : typeUrl; + MessageRegistryEntry entry = registry.get(fullName); + if (entry == null) { + // Unknown type, or a type with no payloads; leave the Any untouched. + return; + } + Message.Builder inner = entry.newBuilder.get(); + try { + inner.mergeFrom(anyBuilder.getValue()); + } catch (InvalidProtocolBufferException e) { + throw new VisitorException("failed to unpack Any of type " + fullName, e); + } + entry.visitor.visit(this, inner); + deferWriteBack(() -> anyBuilder.setValue(inner.build().toByteString())); + } + + // --- Execution: visitor calls (phase 2) then write-backs (phase 3) --- + + void execute() { + if (jobs.isEmpty()) { + return; + } + if (concurrency <= 1 || jobs.size() == 1) { + for (LeafJob job : jobs) { + runJob(job); + } + } else { + executeConcurrently(); + } + for (Runnable writeBack : writeBacks) { + writeBack.run(); + } + } + + private void executeConcurrently() { + // concurrency > 1 and a non-null executor are guaranteed by the constructor's validation. + Executor pool = executor; + Semaphore semaphore = new Semaphore(concurrency); + AtomicReference firstError = new AtomicReference<>(); + List> futures = new ArrayList<>(jobs.size()); + for (LeafJob job : jobs) { + if (firstError.get() != null) { + break; + } + try { + semaphore.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + firstError.compareAndSet(null, e); + break; + } + if (firstError.get() != null) { + semaphore.release(); + break; + } + futures.add( + CompletableFuture.runAsync( + () -> { + try { + runJob(job); + } catch (Throwable t) { + firstError.compareAndSet(null, t); + } finally { + semaphore.release(); + } + }, + pool)); + } + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + Throwable error = firstError.get(); + if (error instanceof RuntimeException) { + throw (RuntimeException) error; + } + if (error instanceof Error) { + throw (Error) error; + } + if (error != null) { + // The only checked exception that can reach here is an InterruptedException from acquiring + // the semaphore. + throw new VisitorException("payload visit interrupted", error); + } + } + + private void runJob(LeafJob job) { + List result = + payloadVisitor.visit(new PayloadVisitorContext<>(job.context, job.parent), job.input); + if (result == null) { + throw new IllegalStateException("payload visitor returned null"); + } + if (job.single && result.size() != 1) { + throw new IllegalStateException( + "single-payload field requires exactly 1 returned payload, got " + result.size()); + } + job.result = result; + } + + /** A single recorded visitor call and the slot its result is written into. */ + private static final class LeafJob { + final List input; + final Object context; + final MessageOrBuilder parent; + final boolean single; + volatile List result; + + LeafJob(List input, Object context, MessageOrBuilder parent, boolean single) { + this.input = input; + this.context = context; + this.parent = parent; + this.single = single; + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java new file mode 100644 index 0000000000..fd37d4ce05 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java @@ -0,0 +1,15 @@ +package io.temporal.internal.payload.visitor; + +/** + * Thrown when visiting the payloads or messages of a proto message fails. The original failure, if + * any, is available via {@link #getCause()}. + */ +final class VisitorException extends RuntimeException { + VisitorException(String message, Throwable cause) { + super(message, cause); + } + + VisitorException(String message) { + super(message); + } +} diff --git a/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java new file mode 100644 index 0000000000..de435aa81e --- /dev/null +++ b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java @@ -0,0 +1,593 @@ +package io.temporal.internal.payload.visitor.gen; + +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +/** + * Build-time generator that emits {@code GeneratedPayloadVisitor}. + * + *

Starting from the WorkflowService and OperatorService file descriptors, it walks the proto + * closure, determines which message types can transitively contain a {@code Payload} (or a {@code + * google.protobuf.Any}, treated conservatively as payload-bearing), and emits one {@code visit_*} + * method per such type plus a registry keyed by descriptor full name. + * + *

Usage: {@code PayloadVisitorGenerator }. + */ +public final class PayloadVisitorGenerator { + + static final String PAYLOAD = "temporal.api.common.v1.Payload"; + static final String PAYLOADS = "temporal.api.common.v1.Payloads"; + static final String ANY = "google.protobuf.Any"; + static final String SEARCH_ATTRIBUTES = "temporal.api.common.v1.SearchAttributes"; + static final String HEADER = "temporal.api.common.v1.Header"; + + static final String OUTPUT_PACKAGE = "io.temporal.internal.payload.visitor"; + static final String OUTPUT_CLASS = "GeneratedPayloadVisitor"; + static final String PAYLOADS_FQN = "io.temporal.api.common.v1.Payloads"; + static final int REGISTER_CHUNK = 40; + + enum Kind { + SINGLE_PAYLOAD, + REPEATED_PAYLOAD, + PAYLOADS_SINGLE, + PAYLOADS_REPEATED, + MAP_PAYLOAD, + MAP_PAYLOADS, + ANY_SINGLE, + ANY_REPEATED, + MAP_ANY, + MESSAGE_SINGLE, + MESSAGE_REPEATED, + MAP_MESSAGE, + IGNORE + } + + /** Classification of a field: how it should be traversed, and its child message type if any. */ + static final class FieldPlan { + final Kind kind; + final Descriptor child; // child message descriptor for MESSAGE_* / MAP_MESSAGE, else null + + FieldPlan(Kind kind, Descriptor child) { + this.kind = kind; + this.child = child; + } + } + + public static void main(String[] args) throws Exception { + if (args.length < 1) { + throw new IllegalArgumentException("usage: PayloadVisitorGenerator "); + } + new PayloadVisitorGenerator().run(Paths.get(args[0])); + } + + private ProtoClosure closure; + + void run(Path outputRoot) throws IOException { + List seeds = + Arrays.asList( + io.temporal.api.workflowservice.v1.ServiceProto.getDescriptor(), + io.temporal.api.operatorservice.v1.ServiceProto.getDescriptor()); + + this.closure = ProtoClosure.of(seeds); + + // Deterministic output order, keyed by descriptor full name. + Map emitted = new TreeMap<>(); + for (Descriptor d : closure.allMessages) { + if (reaches(d)) { + emitted.put(d.getFullName(), d); + } + } + + verifyAccessors(emitted.values()); + + String source = emit(emitted); + + Path dir = outputRoot; + for (String part : OUTPUT_PACKAGE.split("\\.", -1)) { + dir = dir.resolve(part); + } + Files.createDirectories(dir); + Path out = dir.resolve(OUTPUT_CLASS + ".java"); + Files.write(out, source.getBytes(StandardCharsets.UTF_8)); + System.out.println("PayloadVisitorGenerator: wrote " + emitted.size() + " visitors to " + out); + } + + // --- Reachability + classification --- + + /** Whether {@code d} can transitively contain a payload; delegates to the shared closure. */ + private boolean reaches(Descriptor d) { + return closure.reaches(d); + } + + static FieldPlan classify(FieldDescriptor f) { + if (f.isMapField()) { + FieldDescriptor value = f.getMessageType().findFieldByNumber(2); + if (value.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { + String name = value.getMessageType().getFullName(); + if (PAYLOAD.equals(name)) { + return new FieldPlan(Kind.MAP_PAYLOAD, null); + } + if (PAYLOADS.equals(name)) { + return new FieldPlan(Kind.MAP_PAYLOADS, null); + } + if (ANY.equals(name)) { + return new FieldPlan(Kind.MAP_ANY, null); + } + if (isTemporal(value.getMessageType())) { + return new FieldPlan(Kind.MAP_MESSAGE, value.getMessageType()); + } + return new FieldPlan(Kind.IGNORE, null); + } + return new FieldPlan(Kind.IGNORE, null); + } + if (f.getJavaType() != FieldDescriptor.JavaType.MESSAGE) { + return new FieldPlan(Kind.IGNORE, null); + } + String name = f.getMessageType().getFullName(); + boolean repeated = f.isRepeated(); + if (PAYLOAD.equals(name)) { + return new FieldPlan(repeated ? Kind.REPEATED_PAYLOAD : Kind.SINGLE_PAYLOAD, null); + } + if (PAYLOADS.equals(name)) { + return new FieldPlan(repeated ? Kind.PAYLOADS_REPEATED : Kind.PAYLOADS_SINGLE, null); + } + if (ANY.equals(name)) { + return new FieldPlan(repeated ? Kind.ANY_REPEATED : Kind.ANY_SINGLE, null); + } + if (!isTemporal(f.getMessageType())) { + // Non-Temporal messages (google well-known types, etc.) never carry Temporal payloads + // except inside an Any, which is handled separately. + return new FieldPlan(Kind.IGNORE, null); + } + return new FieldPlan( + repeated ? Kind.MESSAGE_REPEATED : Kind.MESSAGE_SINGLE, f.getMessageType()); + } + + static boolean isTemporal(Descriptor d) { + return d.getFullName().startsWith("temporal."); + } + + // --- Java naming --- + + /** Mirrors protoc's UnderscoresToCamelCase used to derive Java accessor names. */ + static String camel(String input, boolean capNext) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < input.length(); i++) { + char c = input.charAt(i); + if (c >= 'a' && c <= 'z') { + sb.append(capNext ? Character.toUpperCase(c) : c); + capNext = false; + } else if (c >= 'A' && c <= 'Z') { + if (i == 0 && !capNext) { + sb.append(Character.toLowerCase(c)); + } else { + sb.append(c); + } + capNext = false; + } else if (c >= '0' && c <= '9') { + sb.append(c); + capNext = true; + } else { + capNext = true; + } + } + return sb.toString(); + } + + /** Capitalized accessor base, e.g. {@code schedule_activity} -> {@code ScheduleActivity}. */ + static String base(FieldDescriptor f) { + return camel(f.getName(), true); + } + + static String javaPackage(Descriptor d) { + String pkg = d.getFile().getOptions().getJavaPackage(); + if (pkg == null || pkg.isEmpty()) { + throw new IllegalStateException("message " + d.getFullName() + " has no java_package option"); + } + return pkg; + } + + /** + * Source-form class name, e.g. {@code io.temporal.api.common.v1.Payload.ExternalPayloadDetails}. + */ + static String sourceClassName(Descriptor d) { + Deque names = new ArrayDeque<>(); + for (Descriptor c = d; c != null; c = c.getContainingType()) { + names.addFirst(c.getName()); + } + return javaPackage(d) + "." + String.join(".", names); + } + + /** Binary class name (nested types joined with {@code $}) for reflective verification. */ + static String binaryClassName(Descriptor d) { + Deque names = new ArrayDeque<>(); + for (Descriptor c = d; c != null; c = c.getContainingType()) { + names.addFirst(c.getName()); + } + return javaPackage(d) + "." + String.join("$", names); + } + + static String methodName(String full) { + return "visit_" + full.replace('.', '_'); + } + + // --- Accessor verification (build-time safety net for the naming rules) --- + + private void verifyAccessors(Iterable descriptors) { + for (Descriptor d : descriptors) { + Class builder; + try { + builder = Class.forName(binaryClassName(d) + "$Builder"); + } catch (ClassNotFoundException e) { + throw new IllegalStateException("no builder class for " + d.getFullName(), e); + } + Set methods = new HashSet<>(); + for (java.lang.reflect.Method m : builder.getMethods()) { + methods.add(m.getName()); + } + for (FieldDescriptor f : d.getFields()) { + for (String required : requiredMethods(classify(f).kind, base(f))) { + if (!methods.contains(required)) { + throw new IllegalStateException( + "expected builder method " + + builder.getName() + + "#" + + required + + " for field " + + d.getFullName() + + "." + + f.getName() + + " (" + + classify(f).kind + + ")"); + } + } + } + } + } + + static List requiredMethods(Kind kind, String base) { + switch (kind) { + case SINGLE_PAYLOAD: + return Arrays.asList("has" + base, "get" + base, "set" + base); + case REPEATED_PAYLOAD: + return Arrays.asList("get" + base + "List", "clear" + base, "addAll" + base); + case PAYLOADS_SINGLE: + return Arrays.asList("has" + base, "get" + base, "set" + base); + case PAYLOADS_REPEATED: + return Arrays.asList("get" + base, "get" + base + "Count", "set" + base); + case MAP_PAYLOAD: + return Arrays.asList("get" + base + "Map", "put" + base); + case MAP_PAYLOADS: + case MAP_ANY: + case MAP_MESSAGE: + return Arrays.asList("get" + base + "Map", "put" + base); + case ANY_SINGLE: + case MESSAGE_SINGLE: + return Arrays.asList("has" + base, "get" + base + "Builder"); + case ANY_REPEATED: + case MESSAGE_REPEATED: + return Arrays.asList("get" + base + "BuilderList"); + default: + return Arrays.asList(); + } + } + + // --- Emission --- + + private String emit(Map emitted) { + StringBuilder sb = new StringBuilder(); + sb.append("// Code generated by PayloadVisitorGenerator; DO NOT EDIT.\n"); + sb.append("package ").append(OUTPUT_PACKAGE).append(";\n\n"); + sb.append("import java.util.ArrayList;\n"); + sb.append("import java.util.HashMap;\n"); + sb.append("import java.util.Map;\n\n"); + sb.append("@SuppressWarnings(\"deprecation\")\n"); + sb.append("final class ").append(OUTPUT_CLASS).append(" {\n"); + sb.append(" private ").append(OUTPUT_CLASS).append("() {}\n\n"); + + List list = new ArrayList<>(emitted.values()); + + sb.append(" static final Map REGISTRY = buildRegistry();\n\n"); + sb.append(" private static Map buildRegistry() {\n"); + sb.append(" Map m = new HashMap<>(") + .append(Math.max(16, list.size() * 2)) + .append(");\n"); + int chunks = (list.size() + REGISTER_CHUNK - 1) / REGISTER_CHUNK; + for (int i = 0; i < chunks; i++) { + sb.append(" register").append(i).append("(m);\n"); + } + sb.append(" return m;\n"); + sb.append(" }\n\n"); + + for (int i = 0; i < chunks; i++) { + sb.append(" private static void register") + .append(i) + .append("(Map m) {\n"); + int start = i * REGISTER_CHUNK; + int end = Math.min(start + REGISTER_CHUNK, list.size()); + for (int j = start; j < end; j++) { + Descriptor d = list.get(j); + String src = sourceClassName(d); + String mn = methodName(d.getFullName()); + sb.append(" m.put(\"") + .append(d.getFullName()) + .append("\", new MessageRegistryEntry((t, b) -> ") + .append(mn) + .append("(t, (") + .append(src) + .append(".Builder) b), ") + .append(src) + .append("::newBuilder));\n"); + } + sb.append(" }\n\n"); + } + + for (Descriptor d : list) { + emitVisitMethod(sb, d); + } + + sb.append("}\n"); + return sb.toString(); + } + + private void emitVisitMethod(StringBuilder sb, Descriptor d) { + String src = sourceClassName(d); + sb.append(" static void ") + .append(methodName(d.getFullName())) + .append("(Traversal t, ") + .append(src) + .append(".Builder b) {\n"); + sb.append(" Object __c = t.enter(b);\n"); + int fi = 0; + for (FieldDescriptor f : d.getFields()) { + FieldPlan plan = classify(f); + if (plan.kind == Kind.IGNORE) { + continue; + } + if ((plan.kind == Kind.MESSAGE_SINGLE + || plan.kind == Kind.MESSAGE_REPEATED + || plan.kind == Kind.MAP_MESSAGE) + && !reaches(plan.child)) { + continue; + } + emitField(sb, f, plan, fi++); + } + sb.append(" t.exit(__c);\n"); + sb.append(" }\n\n"); + } + + private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int fi) { + String B = base(f); + String k = "__key" + fi; + String v = "__v" + fi; + switch (plan.kind) { + case SINGLE_PAYLOAD: + sb.append(" if (b.has").append(B).append("()) {\n"); + sb.append(" t.singlePayload(b, b.get") + .append(B) + .append("(), p -> b.set") + .append(B) + .append("(p));\n"); + sb.append(" }\n"); + break; + case REPEATED_PAYLOAD: + sb.append(" t.payloads(b, b.get").append(B).append("List(), pl -> {\n"); + sb.append(" b.clear").append(B).append("();\n"); + sb.append(" b.addAll").append(B).append("(pl);\n"); + sb.append(" });\n"); + break; + case PAYLOADS_SINGLE: + sb.append(" if (b.has").append(B).append("()) {\n"); + sb.append(" t.payloads(b, b.get").append(B).append("().getPayloadsList(),\n"); + sb.append(" pl -> b.set") + .append(B) + .append("(") + .append(PAYLOADS_FQN) + .append(".newBuilder().addAllPayloads(pl).build()));\n"); + sb.append(" }\n"); + break; + case PAYLOADS_REPEATED: + sb.append(" for (int ") + .append(v) + .append(" = 0; ") + .append(v) + .append(" < b.get") + .append(B) + .append("Count(); ") + .append(v) + .append("++) {\n"); + sb.append(" final int ").append(k).append(" = ").append(v).append(";\n"); + sb.append(" t.payloads(b, b.get") + .append(B) + .append("(") + .append(k) + .append(").getPayloadsList(),\n"); + sb.append(" pl -> b.set") + .append(B) + .append("(") + .append(k) + .append(", ") + .append(PAYLOADS_FQN) + .append(".newBuilder().addAllPayloads(pl).build()));\n"); + sb.append(" }\n"); + break; + case MAP_PAYLOAD: + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" t.singlePayload(b, b.get") + .append(B) + .append("Map().get(") + .append(v) + .append("), p -> b.put") + .append(B) + .append("(") + .append(v) + .append(", p));\n"); + sb.append(" }\n"); + break; + case MAP_PAYLOADS: + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" t.payloads(b, b.get") + .append(B) + .append("Map().get(") + .append(v) + .append(").getPayloadsList(),\n"); + sb.append(" pl -> b.put") + .append(B) + .append("(") + .append(v) + .append(", ") + .append(PAYLOADS_FQN) + .append(".newBuilder().addAllPayloads(pl).build()));\n"); + sb.append(" }\n"); + break; + case ANY_SINGLE: + sb.append(" if (b.has").append(B).append("()) {\n"); + sb.append(" t.any(b.get").append(B).append("Builder());\n"); + sb.append(" }\n"); + break; + case ANY_REPEATED: + sb.append(" for (com.google.protobuf.Any.Builder ") + .append(v) + .append(" : b.get") + .append(B) + .append("BuilderList()) {\n"); + sb.append(" t.any(").append(v).append(");\n"); + sb.append(" }\n"); + break; + case MAP_ANY: + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" com.google.protobuf.Any.Builder ab") + .append(fi) + .append(" = b.get") + .append(B) + .append("Map().get(") + .append(v) + .append(").toBuilder();\n"); + sb.append(" t.any(ab").append(fi).append(");\n"); + sb.append(" t.deferWriteBack(() -> b.put") + .append(B) + .append("(") + .append(v) + .append(", ab") + .append(fi) + .append(".build()));\n"); + sb.append(" }\n"); + break; + case MESSAGE_SINGLE: + { + String guard = childGuard(plan.child); + sb.append(" if (").append(guard).append("b.has").append(B).append("()) {\n"); + sb.append(" ") + .append(methodName(plan.child.getFullName())) + .append("(t, b.get") + .append(B) + .append("Builder());\n"); + sb.append(" }\n"); + } + break; + case MESSAGE_REPEATED: + { + String childSrc = sourceClassName(plan.child); + String guard = childGuard(plan.child); + if (!guard.isEmpty()) { + sb.append(" if (").append(guard.substring(0, guard.length() - 4)).append(") {\n "); + } + sb.append(" for (") + .append(childSrc) + .append(".Builder ") + .append(v) + .append(" : b.get") + .append(B) + .append("BuilderList()) {\n"); + sb.append(" ") + .append(methodName(plan.child.getFullName())) + .append("(t, ") + .append(v) + .append(");\n"); + sb.append(" }\n"); + if (!guard.isEmpty()) { + sb.append(" }\n"); + } + } + break; + case MAP_MESSAGE: + { + String childSrc = sourceClassName(plan.child); + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" ") + .append(childSrc) + .append(".Builder vb") + .append(fi) + .append(" = b.get") + .append(B) + .append("Map().get(") + .append(v) + .append(").toBuilder();\n"); + sb.append(" ") + .append(methodName(plan.child.getFullName())) + .append("(t, vb") + .append(fi) + .append(");\n"); + sb.append(" t.deferWriteBack(() -> b.put") + .append(B) + .append("(") + .append(v) + .append(", vb") + .append(fi) + .append(".build()));\n"); + sb.append(" }\n"); + } + break; + default: + break; + } + } + + /** Optional {@code &&}-terminated guard expression for SearchAttributes/Header skipping. */ + private String childGuard(Descriptor child) { + String name = child.getFullName(); + if (SEARCH_ATTRIBUTES.equals(name)) { + return "!t.skipSearchAttributes && "; + } + if (HEADER.equals(name)) { + return "!t.skipHeaders && "; + } + return ""; + } +} diff --git a/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/ProtoClosure.java b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/ProtoClosure.java new file mode 100644 index 0000000000..cea64996d1 --- /dev/null +++ b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/ProtoClosure.java @@ -0,0 +1,138 @@ +package io.temporal.internal.payload.visitor.gen; + +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import io.temporal.internal.payload.visitor.gen.PayloadVisitorGenerator.FieldPlan; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Shared proto-descriptor model for the build-time generators: the message closure reachable from a + * set of seed services, and which of those messages can transitively contain a {@code Payload}. + */ +final class ProtoClosure { + + /** All non-map-entry messages in the closure, in discovery order. */ + final List allMessages; + + private final Set reaches; + + private ProtoClosure(List allMessages, Set reaches) { + this.allMessages = allMessages; + this.reaches = reaches; + } + + /** Whether {@code d} can transitively contain a payload. */ + boolean reaches(Descriptor d) { + return reaches.contains(d.getFullName()); + } + + /** Builds the closure and payload-reachability set from the given seed file descriptors. */ + static ProtoClosure of(List seeds) { + List all = collectMessages(fileClosure(seeds)); + return new ProtoClosure(all, computeReachability(all)); + } + + // --- Descriptor discovery --- + + private static Set fileClosure(List seeds) { + Set seen = new LinkedHashSet<>(); + Deque queue = new ArrayDeque<>(seeds); + while (!queue.isEmpty()) { + FileDescriptor f = queue.poll(); + if (seen.add(f)) { + queue.addAll(f.getDependencies()); + } + } + return seen; + } + + private static List collectMessages(Set files) { + List result = new ArrayList<>(); + for (FileDescriptor f : files) { + for (Descriptor d : f.getMessageTypes()) { + collectMessages(d, result); + } + } + return result; + } + + private static void collectMessages(Descriptor d, List out) { + if (d.getOptions().getMapEntry()) { + return; // synthetic map entry type; handled via the owning map field + } + out.add(d); + for (Descriptor nested : d.getNestedTypes()) { + collectMessages(nested, out); + } + } + + // --- Reachability --- + + /** + * Least-fixpoint reachability over the message-reference graph. A message reaches a payload if it + * has a direct payload/Any field, or it references (via a message or map-message field) another + * message that does. Iterating to a fixpoint handles cycles (e.g. {@code Failure.cause}) + * correctly without over-approximating payload-free cycles. + */ + private static Set computeReachability(List all) { + Set reaches = new HashSet<>(); + Map> children = new HashMap<>(); + for (Descriptor d : all) { + boolean direct = false; + List refs = new ArrayList<>(); + for (FieldDescriptor f : d.getFields()) { + FieldPlan plan = PayloadVisitorGenerator.classify(f); + switch (plan.kind) { + case SINGLE_PAYLOAD: + case REPEATED_PAYLOAD: + case PAYLOADS_SINGLE: + case PAYLOADS_REPEATED: + case MAP_PAYLOAD: + case MAP_PAYLOADS: + case ANY_SINGLE: + case ANY_REPEATED: + case MAP_ANY: + direct = true; + break; + case MESSAGE_SINGLE: + case MESSAGE_REPEATED: + case MAP_MESSAGE: + refs.add(plan.child); + break; + default: + break; + } + } + if (direct) { + reaches.add(d.getFullName()); + } + children.put(d.getFullName(), refs); + } + boolean changed = true; + while (changed) { + changed = false; + for (Descriptor d : all) { + if (reaches.contains(d.getFullName())) { + continue; + } + for (Descriptor c : children.get(d.getFullName())) { + if (reaches.contains(c.getFullName())) { + reaches.add(d.getFullName()); + changed = true; + break; + } + } + } + } + return reaches; + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java new file mode 100644 index 0000000000..d6239ac744 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java @@ -0,0 +1,192 @@ +package io.temporal.internal.payload.visitor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.google.protobuf.ByteString; +import io.temporal.api.command.v1.Command; +import io.temporal.api.command.v1.CompleteWorkflowExecutionCommandAttributes; +import io.temporal.api.command.v1.RecordMarkerCommandAttributes; +import io.temporal.api.command.v1.ScheduleActivityTaskCommandAttributes; +import io.temporal.api.common.v1.Memo; +import io.temporal.api.common.v1.Payload; +import io.temporal.api.common.v1.Payloads; +import io.temporal.api.enums.v1.CommandType; +import io.temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; + +/** Tests for {@link MessageVisitors}: message traversal with scoped context, and validation. */ +public class MessageVisitorTest { + + static Payload p(String s) { + return Payload.newBuilder().setData(ByteString.copyFromUtf8(s)).build(); + } + + static Command activity(String id, String... inputs) { + Payloads.Builder in = Payloads.newBuilder(); + for (String s : inputs) { + in.addPayloads(p(s)); + } + return Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK) + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setActivityId(id).setInput(in)) + .build(); + } + + @Test + public void visitsBuilderInPlace() { + Memo.Builder builder = Memo.newBuilder().putFields("k", p("v")); + List entered = new ArrayList<>(); + MessageVisitors.visit( + builder, + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + entered.add(msg.getDescriptorForType().getFullName()); + return current; + }) + .build()); + assertEquals(Arrays.asList("temporal.api.common.v1.Memo"), entered); + } + + @Test + public void messageVisitorMutatesInPlace() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder().addCommands(activity("orig", "x")).build(); + + RespondWorkflowTaskCompletedRequest result = + MessageVisitors.visit( + request, + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + if (msg instanceof ScheduleActivityTaskCommandAttributes.Builder) { + ((ScheduleActivityTaskCommandAttributes.Builder) msg) + .setActivityId("rewritten"); + } + return current; + }) + .build()); + + assertEquals( + "rewritten", + result.getCommands(0).getScheduleActivityTaskCommandAttributes().getActivityId()); + } + + @Test + public void visitsEachMessageWithScopedContext() { + // Three commands with distinct types exercise per-command scoping and scope restoration + // between siblings. + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(activity("a", "x")) + .addCommands( + Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION) + .setCompleteWorkflowExecutionCommandAttributes( + CompleteWorkflowExecutionCommandAttributes.newBuilder()) + .build()) + .addCommands( + Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_RECORD_MARKER) + .setRecordMarkerCommandAttributes( + RecordMarkerCommandAttributes.newBuilder().setMarkerName("m")) + .build()) + .build(); + + // MessageVisitors traversal is single-threaded, so the entered messages have a stable order. + List entered = new ArrayList<>(); + List contextOnEnter = new ArrayList<>(); + + MessageVisitorOptions opts = + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + entered.add(msg.getDescriptorForType().getFullName()); + contextOnEnter.add(current); + return msg instanceof Command.Builder + ? ((Command.Builder) msg).getCommandType() + : current; + }) + .build(); + + MessageVisitors.visit(request, opts); + + // Exact order: the root, then each (repeated) command followed by its oneof attributes message. + assertEquals( + Arrays.asList( + "temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest", + "temporal.api.command.v1.Command", + "temporal.api.command.v1.ScheduleActivityTaskCommandAttributes", + "temporal.api.command.v1.Command", + "temporal.api.command.v1.CompleteWorkflowExecutionCommandAttributes", + "temporal.api.command.v1.Command", + "temporal.api.command.v1.RecordMarkerCommandAttributes"), + entered); + // Each command is entered with scope reset to null (restored between siblings), then its own + // type flows down into its attributes message. + assertEquals( + Arrays.asList( + null, + null, + CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + null, + CommandType.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, + null, + CommandType.COMMAND_TYPE_RECORD_MARKER), + contextOnEnter); + } + + @Test + public void messageOnlyVisitorValidatesPerMessageType() { + int maxMemoFields = 2; + MessageVisitorOptions opts = + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + if (msg instanceof Memo.Builder + && ((Memo.Builder) msg).getFieldsCount() > maxMemoFields) { + throw new TestVisitorException("too many memo fields"); + } + return current; + }) + .build(); + + Memo ok = Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2")).build(); + MessageVisitors.visit(ok, opts); // no throw + + Memo tooMany = + Memo.newBuilder() + .putFields("a", p("1")) + .putFields("b", p("2")) + .putFields("c", p("3")) + .build(); + assertThrows(TestVisitorException.class, () -> MessageVisitors.visit(tooMany, opts)); + } + + @Test + public void initialContextObservedAtRoot() { + Memo memo = Memo.newBuilder().putFields("k", p("v")).build(); + List observed = new ArrayList<>(); + MessageVisitors.visit( + memo, + MessageVisitorOptions.newBuilder() + .setInitialContext("root") + .setMessageVisitor( + (current, msg) -> { + observed.add(current); + return current; + }) + .build()); + assertEquals(Arrays.asList("root"), observed); + } + + @Test + public void rejectsMissingMessageVisitor() { + assertThrows(IllegalArgumentException.class, () -> MessageVisitorOptions.newBuilder().build()); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java new file mode 100644 index 0000000000..7870a518c3 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java @@ -0,0 +1,951 @@ +package io.temporal.internal.payload.visitor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import io.temporal.api.command.v1.Command; +import io.temporal.api.command.v1.RecordMarkerCommandAttributes; +import io.temporal.api.command.v1.ScheduleActivityTaskCommandAttributes; +import io.temporal.api.command.v1.ScheduleNexusOperationCommandAttributes; +import io.temporal.api.command.v1.StartChildWorkflowExecutionCommandAttributes; +import io.temporal.api.command.v1.UpsertWorkflowSearchAttributesCommandAttributes; +import io.temporal.api.common.v1.Header; +import io.temporal.api.common.v1.Memo; +import io.temporal.api.common.v1.Payload; +import io.temporal.api.common.v1.Payloads; +import io.temporal.api.common.v1.SearchAttributes; +import io.temporal.api.enums.v1.CommandType; +import io.temporal.api.failure.v1.ApplicationFailureInfo; +import io.temporal.api.failure.v1.Failure; +import io.temporal.api.protocol.v1.Message; +import io.temporal.api.query.v1.WorkflowQueryResult; +import io.temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse; +import io.temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class PayloadVisitorTest { + + static Payload p(String s) { + return Payload.newBuilder().setData(ByteString.copyFromUtf8(s)).build(); + } + + static String data(Payload p) { + return p.getData().toStringUtf8(); + } + + static Payloads payloads(String... values) { + Payloads.Builder b = Payloads.newBuilder(); + for (String v : values) { + b.addPayloads(p(v)); + } + return b.build(); + } + + /** + * Records every payload seen (in order) and the number of visit calls, leaving payloads + * unchanged. + */ + static final class CollectingVisitor implements PayloadVisitor { + final List seen = Collections.synchronizedList(new ArrayList<>()); + final AtomicInteger visits = new AtomicInteger(); + + @Override + public List visit(PayloadVisitorContext ctx, List payloads) { + visits.incrementAndGet(); + for (Payload p : payloads) { + seen.add(data(p)); + } + return payloads; + } + } + + static PayloadVisitorOptions options(PayloadVisitor visitor) { + return PayloadVisitorOptions.newBuilder().setPayloadVisitor(visitor).build(); + } + + static Command activity(String activityId, Payloads input) { + return Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK) + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder() + .setActivityId(activityId) + .setInput(input)) + .build(); + } + + /** Executor supplied to the concurrent visits (unused by the single-threaded tests). */ + private ExecutorService executor; + + @Before + public void setUpExecutor() { + executor = Executors.newCachedThreadPool(); + } + + @After + public void tearDownExecutor() { + executor.shutdownNow(); + } + + @Test + public void visitsAndMutatesAllPayloads() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(activity("a", payloads("one", "two"))) + .addCommands(activity("b", payloads("three"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + RespondWorkflowTaskCompletedRequest unchanged = + PayloadVisitors.visit(request, options(counter)); + assertEquals(java.util.Arrays.asList("one", "two", "three"), counter.seen); + // Two Payloads sequences (one per command's input): two visits, three payloads. + assertEquals(2, counter.visits.get()); + assertEquals(request, unchanged); + + // Mutating: uppercase every payload's data. + RespondWorkflowTaskCompletedRequest mutated = + PayloadVisitors.visit( + request, + options( + (ctx, pls) -> + pls.stream() + .map( + p -> + p.toBuilder() + .setData(ByteString.copyFromUtf8(data(p).toUpperCase())) + .build()) + .collect(Collectors.toList()))); + assertEquals( + payloads("ONE", "TWO"), + mutated.getCommands(0).getScheduleActivityTaskCommandAttributes().getInput()); + assertEquals( + payloads("THREE"), + mutated.getCommands(1).getScheduleActivityTaskCommandAttributes().getInput()); + } + + @Test + public void visitsSinglePayloadField() { + Command command = + Command.newBuilder() + .setScheduleNexusOperationCommandAttributes( + ScheduleNexusOperationCommandAttributes.newBuilder().setInput(p("nexus"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + Command result = PayloadVisitors.visit(command, options(counter)); + assertEquals(Collections.singletonList("nexus"), counter.seen); + + // A single-payload field can be replaced with one payload. + Command observed = + PayloadVisitors.visit( + command, options((ctx, pls) -> Collections.singletonList(p("replaced")))); + assertEquals( + "replaced", + observed.getScheduleNexusOperationCommandAttributes().getInput().getData().toStringUtf8()); + assertEquals(Collections.singletonList("nexus"), counter.seen); + assertEquals(command, result); + } + + @Test + public void singlePayloadFieldRequiresExactlyOnePayload() { + Command command = + Command.newBuilder() + .setScheduleNexusOperationCommandAttributes( + ScheduleNexusOperationCommandAttributes.newBuilder().setInput(p("nexus"))) + .build(); + + // Returning zero payloads for a single-payload field is rejected. + assertThrows( + IllegalStateException.class, + () -> PayloadVisitors.visit(command, options((ctx, pls) -> Collections.emptyList()))); + + // Returning more than one payload for a single-payload field is rejected. + assertThrows( + IllegalStateException.class, + () -> + PayloadVisitors.visit( + command, options((ctx, pls) -> java.util.Arrays.asList(p("a"), p("b"))))); + } + + @Test + public void visitsMapOfPayloads() { + Command command = + Command.newBuilder() + .setUpsertWorkflowSearchAttributesCommandAttributes( + UpsertWorkflowSearchAttributesCommandAttributes.newBuilder() + .setSearchAttributes( + SearchAttributes.newBuilder() + .putIndexedFields("k1", p("v1")) + .putIndexedFields("k2", p("v2")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(command, options(counter)); + // A map is visited once per entry; map iteration order is unspecified, so + // assert the exact visit count and the value set rather than positional offsets. + assertEquals(2, counter.visits.get()); + assertEquals(new HashSet<>(java.util.Arrays.asList("v1", "v2")), new HashSet<>(counter.seen)); + + Command mutated = + PayloadVisitors.visit( + command, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); + Map fields = + mutated + .getUpsertWorkflowSearchAttributesCommandAttributes() + .getSearchAttributes() + .getIndexedFieldsMap(); + assertEquals("v1!", data(fields.get("k1"))); + assertEquals("v2!", data(fields.get("k2"))); + } + + @Test + public void visitsMapOfPayloadsSequences() { + Command command = + Command.newBuilder() + .setRecordMarkerCommandAttributes( + RecordMarkerCommandAttributes.newBuilder() + .setMarkerName("m") + .putDetails("d1", payloads("x", "y"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(command, options(counter)); + // A single map entry is one sequence: one visit, two payloads. + assertEquals(1, counter.visits.get()); + assertEquals(java.util.Arrays.asList("x", "y"), counter.seen); + } + + @Test + public void visitsMapOfMessages() { + // RespondWorkflowTaskCompletedRequest.query_results is map, whose + // values carry payloads: exercises the map-of-messages path (rebuild value + write back). + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .putQueryResults( + "q1", WorkflowQueryResult.newBuilder().setAnswer(payloads("a")).build()) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(request, options(counter)); + assertEquals(Collections.singletonList("a"), counter.seen); + + RespondWorkflowTaskCompletedRequest mutated = + PayloadVisitors.visit( + request, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); + assertEquals(payloads("a!"), mutated.getQueryResultsMap().get("q1").getAnswer()); + } + + @Test + public void visitsRepeatedPayloadField() { + // CountWorkflowExecutionsResponse.AggregationGroup.group_values is a bare repeated Payload. + CountWorkflowExecutionsResponse response = + CountWorkflowExecutionsResponse.newBuilder() + .addGroups( + CountWorkflowExecutionsResponse.AggregationGroup.newBuilder() + .addGroupValues(p("g1")) + .addGroupValues(p("g2"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(response, options(counter)); + // A repeated Payload is one sequence: one visit, two payloads. + assertEquals(1, counter.visits.get()); + assertEquals(java.util.Arrays.asList("g1", "g2"), counter.seen); + + CountWorkflowExecutionsResponse mutated = + PayloadVisitors.visit( + response, + options( + (ctx, pls) -> + pls.stream().map(pl -> p(data(pl) + "!")).collect(Collectors.toList()))); + assertEquals("g1!", data(mutated.getGroups(0).getGroupValues(0))); + assertEquals("g2!", data(mutated.getGroups(0).getGroupValues(1))); + } + + @Test + public void visitsPayloadsAsRoot() { + Payloads root = payloads("a", "b"); + + CollectingVisitor counter = new CollectingVisitor(); + Payloads unchanged = PayloadVisitors.visit(root, options(counter)); + // The repeated Payload inside Payloads is one sequence: one visit, two payloads. + assertEquals(1, counter.visits.get()); + assertEquals(java.util.Arrays.asList("a", "b"), counter.seen); + assertEquals(root, unchanged); + + Payloads mutated = + PayloadVisitors.visit(root, options((ctx, pls) -> Collections.singletonList(p("x")))); + assertEquals(payloads("x"), mutated); + } + + @Test + public void visitsBuilderInPlace() { + RespondWorkflowTaskCompletedRequest.Builder builder = + RespondWorkflowTaskCompletedRequest.newBuilder().addCommands(activity("a", payloads("x"))); + + PayloadVisitors.visit(builder, options((ctx, pls) -> Collections.singletonList(p("y")))); + + assertEquals( + payloads("y"), + builder.getCommands(0).getScheduleActivityTaskCommandAttributes().getInput()); + } + + @Test + public void visitCountDistinguishesSequencesFromMapEntries() { + // A Memo with two fields is visited once per entry: two visits, two payloads. + Memo memo = Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2")).build(); + CollectingVisitor memoVisitor = new CollectingVisitor(); + PayloadVisitors.visit(memo, options(memoVisitor)); + // Memo fields are a map (unspecified order): assert visit count and the value set. + assertEquals(2, memoVisitor.visits.get()); + assertEquals(new HashSet<>(java.util.Arrays.asList("1", "2")), new HashSet<>(memoVisitor.seen)); + + // An activity command with two inputs is one Payloads sequence: one visit, two payloads. + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("1", "2"))) + .build(); + CollectingVisitor inputVisitor = new CollectingVisitor(); + PayloadVisitors.visit(command, options(inputVisitor)); + // A Payloads sequence preserves order, so assert the exact ordered values. + assertEquals(1, inputVisitor.visits.get()); + assertEquals(java.util.Arrays.asList("1", "2"), inputVisitor.seen); + } + + @Test + public void visitsHeaders() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder() + .setInput(payloads("in")) + .setHeader(Header.newBuilder().putFields("h", p("hv")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(command, options(counter)); + // With headers not skipped (the default), the header payload is visited too. + assertEquals(new HashSet<>(java.util.Arrays.asList("in", "hv")), new HashSet<>(counter.seen)); + } + + @Test + public void visitsSearchAttributes() { + Command command = + Command.newBuilder() + .setStartChildWorkflowExecutionCommandAttributes( + StartChildWorkflowExecutionCommandAttributes.newBuilder() + .setInput(payloads("in")) + .setSearchAttributes( + SearchAttributes.newBuilder().putIndexedFields("k", p("v")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(command, options(counter)); + // With search attributes not skipped (the default), the search attribute payload is visited. + assertEquals(new HashSet<>(java.util.Arrays.asList("in", "v")), new HashSet<>(counter.seen)); + } + + @Test + public void skipsHeaders() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder() + .setInput(payloads("in")) + .setHeader(Header.newBuilder().putFields("h", p("hv")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit( + command, + PayloadVisitorOptions.newBuilder().setPayloadVisitor(counter).setSkipHeaders(true).build()); + // The header payload is skipped; other payloads are still visited. + assertEquals(Collections.singletonList("in"), counter.seen); + } + + @Test + public void skipsSearchAttributes() { + Command command = + Command.newBuilder() + .setStartChildWorkflowExecutionCommandAttributes( + StartChildWorkflowExecutionCommandAttributes.newBuilder() + .setInput(payloads("in")) + .setSearchAttributes( + SearchAttributes.newBuilder().putIndexedFields("k", p("v")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit( + command, + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor(counter) + .setSkipSearchAttributes(true) + .build()); + // The search attribute payload is skipped; other payloads are still visited. + assertEquals(Collections.singletonList("in"), counter.seen); + } + + @Test + public void contextScopesPerCommand() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(activity("a", payloads("act"))) + .addCommands( + Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION) + .setStartChildWorkflowExecutionCommandAttributes( + StartChildWorkflowExecutionCommandAttributes.newBuilder() + .setInput(payloads("child"))) + .build()) + .build(); + + // Default concurrency (1) visits the two repeated commands in declaration order. + List dataOrder = new ArrayList<>(); + List contextOrder = new ArrayList<>(); + PayloadVisitorOptions opts = + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor( + (ctx, pls) -> { + for (Payload p : pls) { + dataOrder.add(data(p)); + contextOrder.add(ctx.getContext()); + } + return pls; + }) + .setMessageVisitor( + (current, msg) -> + msg instanceof Command.Builder + ? ((Command.Builder) msg).getCommandType() + : current) + .build(); + + PayloadVisitors.visit(request, opts); + assertEquals(java.util.Arrays.asList("act", "child"), dataOrder); + assertEquals( + java.util.Arrays.asList( + CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION), + contextOrder); + } + + @Test + public void initialContextUsedWhenNoMessageVisitor() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) + .build(); + List observed = new ArrayList<>(); + PayloadVisitorOptions opts = + PayloadVisitorOptions.newBuilder() + .setInitialContext("root") + .setPayloadVisitor( + (ctx, pls) -> { + observed.add(ctx.getContext()); + return pls; + }) + .build(); + PayloadVisitors.visit(command, opts); + assertEquals(Collections.singletonList("root"), observed); + } + + @Test + public void limitsStyleValidatorComposesBothSeams() { + // The payload-limits feature is a read-only validator using both seams of PayloadVisitors: + // - per-payload (blob size) on the payload seam + // - per-message (e.g. memo field count) on the message seam + int blobLimit = 8; + int maxMemoFields = 2; + + PayloadVisitorOptions validator = + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor( + (ctx, pls) -> { + for (Payload pl : pls) { + if (pl.getData().size() > blobLimit) { + throw new TestVisitorException("blob too large"); + } + } + return pls; // read-only + }) + .setMessageVisitor( + (current, msg) -> { + if (msg instanceof Memo.Builder + && ((Memo.Builder) msg).getFieldsCount() > maxMemoFields) { + throw new TestVisitorException("too many memo fields"); + } + return current; + }) + .build(); + + // Within both limits (small input, small memo): both seams run, neither trips. + RespondWorkflowTaskCompletedRequest ok = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands( + Command.newBuilder() + .setStartChildWorkflowExecutionCommandAttributes( + StartChildWorkflowExecutionCommandAttributes.newBuilder() + .setInput(payloads("small")) + .setMemo( + Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2"))))) + .build(); + PayloadVisitors.visit(ok, validator); + + // Oversized blob trips the payload seam. + RespondWorkflowTaskCompletedRequest bigBlob = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(activity("a", payloads("way-too-large-payload"))) + .build(); + TestVisitorException blobError = + assertThrows(TestVisitorException.class, () -> PayloadVisitors.visit(bigBlob, validator)); + assertEquals("blob too large", blobError.getMessage()); + + // Too many memo fields trips the message seam (its payloads are individually small). + RespondWorkflowTaskCompletedRequest bigMemo = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands( + Command.newBuilder() + .setStartChildWorkflowExecutionCommandAttributes( + StartChildWorkflowExecutionCommandAttributes.newBuilder() + .setMemo( + Memo.newBuilder() + .putFields("a", p("1")) + .putFields("b", p("2")) + .putFields("c", p("3"))))) + .build(); + TestVisitorException memoError = + assertThrows(TestVisitorException.class, () -> PayloadVisitors.visit(bigMemo, validator)); + assertEquals("too many memo fields", memoError.getMessage()); + } + + @Test + public void visitsNestedFailureCauses() { + // Failure.cause is itself a Failure, so the visitor recurses into its own type; payloads at + // each level of the cause chain must be visited. + Failure failure = + Failure.newBuilder() + .setMessage("outer") + .setApplicationFailureInfo( + ApplicationFailureInfo.newBuilder().setDetails(payloads("d1"))) + .setCause( + Failure.newBuilder() + .setMessage("inner") + .setApplicationFailureInfo( + ApplicationFailureInfo.newBuilder().setDetails(payloads("d2")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(failure, options(counter)); + assertEquals(2, counter.seen.size()); + assertTrue(counter.seen.contains("d1")); + assertTrue(counter.seen.contains("d2")); + } + + @Test + public void roundTripsPayloadInsideAny() throws Exception { + Memo memo = Memo.newBuilder().putFields("k", p("inside-any")).build(); + Message message = Message.newBuilder().setBody(Any.pack(memo)).build(); + + CollectingVisitor counter = new CollectingVisitor(); + Message result = PayloadVisitors.visit(message, options(counter)); + assertEquals(Collections.singletonList("inside-any"), counter.seen); + + // Mutating through the Any re-packs correctly. + Message mutated = + PayloadVisitors.visit( + message, options((ctx, pls) -> Collections.singletonList(p("changed")))); + Memo unpacked = mutated.getBody().unpack(Memo.class); + assertEquals("changed", data(unpacked.getFieldsMap().get("k"))); + // Unrelated content unchanged. + assertEquals(result.getBody().getTypeUrl(), mutated.getBody().getTypeUrl()); + } + + @Test + public void leavesUnknownAnyUntouched() throws Exception { + // An Any whose type is not in the registry is left as-is. + Message message = + Message.newBuilder() + .setBody( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/some.unknown.Type") + .setValue(ByteString.copyFromUtf8("opaque"))) + .build(); + CollectingVisitor counter = new CollectingVisitor(); + Message result = PayloadVisitors.visit(message, options(counter)); + assertTrue(counter.seen.isEmpty()); + assertEquals(message, result); + } + + @Test + public void messageWithoutPayloadsReturnedUnchanged() { + Command command = + Command.newBuilder() + .setCancelWorkflowExecutionCommandAttributes( + io.temporal.api.command.v1.CancelWorkflowExecutionCommandAttributes.newBuilder()) + .build(); + CollectingVisitor counter = new CollectingVisitor(); + Command result = PayloadVisitors.visit(command, options(counter)); + assertTrue(counter.seen.isEmpty()); + assertEquals(command, result); + } + + @Test + public void propagatesVisitorError() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(activity("a", payloads("x"))) + .build(); + TestVisitorException boom = new TestVisitorException("boom"); + TestVisitorException thrown = + assertThrows( + TestVisitorException.class, + () -> + PayloadVisitors.visit( + request, + options( + (ctx, pls) -> { + throw boom; + }))); + assertSame(boom, thrown); + } + + @Test + public void messageVisitorErrorPropagates() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) + .build(); + TestVisitorException boom = new TestVisitorException("message visitor boom"); + TestVisitorException thrown = + assertThrows( + TestVisitorException.class, + () -> + PayloadVisitors.visit( + command, + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor((ctx, pls) -> pls) + .setMessageVisitor( + (current, msg) -> { + throw boom; + }) + .build())); + assertSame(boom, thrown); + } + + @Test + public void registryCoversPayloadBearingTypesAndExcludesOthers() { + Map registry = GeneratedPayloadVisitor.REGISTRY; + // Representative payload-bearing types must be present. + for (String fullName : + new String[] { + "temporal.api.command.v1.Command", + "temporal.api.command.v1.ScheduleActivityTaskCommandAttributes", + "temporal.api.command.v1.RecordMarkerCommandAttributes", + "temporal.api.failure.v1.Failure", + "temporal.api.common.v1.Memo", + "temporal.api.common.v1.Header", + "temporal.api.common.v1.SearchAttributes", + "temporal.api.common.v1.Payloads", + "temporal.api.protocol.v1.Message", + "temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest" + }) { + assertTrue("missing visitor for " + fullName, registry.containsKey(fullName)); + } + // Types without reachable payloads must be excluded. + assertFalse(registry.containsKey("temporal.api.common.v1.Payload")); + assertFalse(registry.containsKey("temporal.api.common.v1.WorkflowExecution")); + assertFalse(registry.containsKey("google.protobuf.DescriptorProto")); + } + + @Test + public void rejectsMissingVisitor() { + assertThrows(IllegalArgumentException.class, () -> PayloadVisitorOptions.newBuilder().build()); + } + + @Test + public void nullReturnFromVisitorFails() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) + .build(); + assertThrows( + IllegalStateException.class, + () -> PayloadVisitors.visit(command, options((ctx, pls) -> null))); + } + + // --- Concurrency and executor --- + + @Test + public void rejectsConcurrencyBelowOne() { + assertThrows( + IllegalArgumentException.class, + () -> + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor((ctx, pls) -> pls) + .setConcurrency(0) + .build()); + } + + @Test + public void rejectsConcurrencyAboveOneWithoutExecutor() { + assertThrows( + IllegalArgumentException.class, + () -> + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor((ctx, pls) -> pls) + .setConcurrency(2) + .build()); + } + + /** + * A request with {@code n} activity commands, each carrying one distinct single-payload input. + */ + static RespondWorkflowTaskCompletedRequest requestWithInputs(int n) { + RespondWorkflowTaskCompletedRequest.Builder b = + RespondWorkflowTaskCompletedRequest.newBuilder(); + for (int i = 0; i < n; i++) { + b.addCommands(activity("a" + i, payloads("p" + i))); + } + return b.build(); + } + + @Test + public void concurrencyEqualToWorkAllowsFullOverlap() throws Exception { + int n = 4; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + CyclicBarrier barrier = new CyclicBarrier(n); + + // Each of the n visits must reach the barrier simultaneously, proving n concurrent visits. + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(n) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + try { + barrier.await(5, TimeUnit.SECONDS); + } catch (InterruptedException | BrokenBarrierException | TimeoutException e) { + throw new RuntimeException(e); + } + return pls; + }) + .build()); + } + + @Test + public void boundedConcurrencyNeverExceedsLimit() { + int n = 8; + int limit = 3; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + + AtomicInteger inFlight = new AtomicInteger(); + AtomicInteger maxInFlight = new AtomicInteger(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(limit) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + int now = inFlight.incrementAndGet(); + maxInFlight.accumulateAndGet(now, Math::max); + try { + Thread.sleep(20); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + inFlight.decrementAndGet(); + return pls; + }) + .build()); + + assertTrue( + "max in-flight " + maxInFlight.get() + " > limit " + limit, maxInFlight.get() <= limit); + assertTrue("expected some overlap, got " + maxInFlight.get(), maxInFlight.get() > 1); + } + + @Test + public void sequentialConcurrencyVisitsOneAtATimeInOrder() { + int n = 5; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + + AtomicInteger inFlight = new AtomicInteger(); + AtomicInteger maxInFlight = new AtomicInteger(); + List order = new ArrayList<>(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(1) + .setPayloadVisitor( + (ctx, pls) -> { + int now = inFlight.incrementAndGet(); + maxInFlight.accumulateAndGet(now, Math::max); + order.add(pls.get(0).getData().toStringUtf8()); + inFlight.decrementAndGet(); + return pls; + }) + .build()); + + assertEquals(1, maxInFlight.get()); + List expected = new ArrayList<>(); + for (int i = 0; i < n; i++) { + expected.add("p" + i); + } + assertEquals(expected, order); + } + + @Test + public void concurrentVisitorErrorPropagates() { + RespondWorkflowTaskCompletedRequest request = requestWithInputs(8); + TestVisitorException boom = new TestVisitorException("boom"); + TestVisitorException thrown = + assertThrows( + TestVisitorException.class, + () -> + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(4) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + if (pls.get(0).getData().toStringUtf8().equals("p5")) { + throw boom; + } + return pls; + }) + .build())); + assertSame(boom, thrown); + } + + @Test + public void mutationsAppliedCorrectlyUnderConcurrency() { + int n = 16; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + RespondWorkflowTaskCompletedRequest mutated = + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(8) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + Payload p = pls.get(0); + return Collections.singletonList( + p.toBuilder() + .setData(ByteString.copyFromUtf8(p.getData().toStringUtf8() + "!")) + .build()); + }) + .build()); + for (int i = 0; i < n; i++) { + assertEquals( + "p" + i + "!", + mutated + .getCommands(i) + .getScheduleActivityTaskCommandAttributes() + .getInput() + .getPayloads(0) + .getData() + .toStringUtf8()); + } + } + + @Test + public void concurrentVisitsRunOnProvidedExecutor() throws InterruptedException { + AtomicInteger threadSeq = new AtomicInteger(); + ThreadFactory factory = + r -> { + Thread t = new Thread(r); + t.setName("pv-exec-" + threadSeq.incrementAndGet()); + return t; + }; + ExecutorService pool = Executors.newFixedThreadPool(4, factory); + try { + RespondWorkflowTaskCompletedRequest request = requestWithInputs(12); + Set visitThreads = ConcurrentHashMap.newKeySet(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(4) + .setExecutor(pool) + .setPayloadVisitor( + (ctx, pls) -> { + visitThreads.add(Thread.currentThread().getName()); + return pls; + }) + .build()); + + assertTrue("no visits recorded", !visitThreads.isEmpty()); + for (String name : visitThreads) { + assertTrue("visit ran off the provided executor: " + name, name.startsWith("pv-exec-")); + } + } finally { + pool.shutdownNow(); + pool.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + public void sequentialConcurrencyIgnoresExecutorAndRunsOnCallingThread() { + // With concurrency == 1 the executor must never be touched and visits run inline. + AtomicInteger submittedToExecutor = new AtomicInteger(); + Executor tripwire = + command -> { + submittedToExecutor.incrementAndGet(); + command.run(); + }; + + RespondWorkflowTaskCompletedRequest request = requestWithInputs(5); + String callingThread = Thread.currentThread().getName(); + AtomicReference sawDifferentThread = new AtomicReference<>(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(1) + .setExecutor(tripwire) + .setPayloadVisitor( + (ctx, pls) -> { + if (!Thread.currentThread().getName().equals(callingThread)) { + sawDifferentThread.set(Thread.currentThread().getName()); + } + return pls; + }) + .build()); + + assertEquals("executor was used for sequential traversal", 0, submittedToExecutor.get()); + assertEquals("visit ran off the calling thread", null, sawDifferentThread.get()); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java new file mode 100644 index 0000000000..ecbfab7829 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java @@ -0,0 +1,12 @@ +package io.temporal.internal.payload.visitor; + +/** + * Exception thrown only by test visitor/message callbacks. Using a dedicated type keeps "the + * visitor threw" assertions from being satisfied by an unrelated {@link IllegalStateException} that + * production code might raise. + */ +class TestVisitorException extends RuntimeException { + TestVisitorException(String message) { + super(message); + } +} From 4955ebff78061d475a057efdb27452934ab73d6e Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 15 Jun 2026 12:32:44 -0700 Subject: [PATCH 2/5] pass context directly instead of PayloadVisitorContext --- .../payload/visitor/PayloadVisitor.java | 4 +-- .../visitor/PayloadVisitorContext.java | 33 ------------------- .../payload/visitor/PayloadVisitors.java | 2 +- .../internal/payload/visitor/Traversal.java | 15 ++++----- .../visitor/gen/PayloadVisitorGenerator.java | 12 +++---- .../payload/visitor/PayloadVisitorTest.java | 6 ++-- 6 files changed, 18 insertions(+), 54 deletions(-) delete mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java index 2f791aeb36..38d231432d 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java @@ -18,9 +18,9 @@ interface PayloadVisitor { /** * Visits a sequence of payloads and returns their replacements. * - * @param context the location of these payloads and the contextual value in scope + * @param context the contextual value in scope * @param payloads the payloads found at this location * @return the replacement payloads */ - List visit(PayloadVisitorContext context, List payloads); + List visit(C context, List payloads); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java deleted file mode 100644 index 756b1404d7..0000000000 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.temporal.internal.payload.visitor; - -import com.google.protobuf.MessageOrBuilder; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -/** - * The context for one payload visitor call: the contextual value in scope and the message that - * contains the payloads being visited. - * - * @param type of the contextual value - */ -final class PayloadVisitorContext { - private final @Nullable C context; - private final @Nonnull MessageOrBuilder parent; - - PayloadVisitorContext(@Nullable C context, @Nonnull MessageOrBuilder parent) { - this.context = context; - this.parent = parent; - } - - /** The contextual value in scope at this location, or {@code null} if none. */ - @Nullable - public C getContext() { - return context; - } - - /** The message that directly contains the payloads being visited. */ - @Nonnull - public MessageOrBuilder getParent() { - return parent; - } -} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java index b810976ef0..9cc828688b 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java @@ -13,7 +13,7 @@ * PayloadVisitors.visit( * request, * PayloadVisitorOptions.newBuilder() - * .setPayloadVisitor((ctx, payloads) -> encode(ctx.getContext(), payloads)) + * .setPayloadVisitor((ctx, payloads) -> encode(ctx, payloads)) * .setMessageVisitor((current, msg) -> msg instanceof Command.Builder * ? CommandInfo.of((Command.Builder) msg) : current) * .setConcurrency(4) diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java index a9922f64b1..985f02b02e 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java @@ -92,11 +92,11 @@ void exit(Object previous) { } /** Record a visit of a payload sequence ({@code Payloads} or {@code repeated Payload}). */ - void payloads(MessageOrBuilder parent, List batch, Consumer> writeBack) { + void payloads(List batch, Consumer> writeBack) { if (payloadVisitor == null) { return; // message-only traversal: payload seams are inert } - LeafJob job = new LeafJob(batch, currentContext, parent, false); + LeafJob job = new LeafJob(batch, currentContext, false); jobs.add(job); writeBacks.add(() -> writeBack.accept(job.result)); } @@ -105,11 +105,11 @@ void payloads(MessageOrBuilder parent, List batch, Consumer writeBack) { + void singlePayload(Payload value, Consumer writeBack) { if (payloadVisitor == null) { return; // message-only traversal: payload seams are inert } - LeafJob job = new LeafJob(Collections.singletonList(value), currentContext, parent, true); + LeafJob job = new LeafJob(Collections.singletonList(value), currentContext, true); jobs.add(job); writeBacks.add(() -> writeBack.accept(job.result.get(0))); } @@ -207,8 +207,7 @@ private void executeConcurrently() { } private void runJob(LeafJob job) { - List result = - payloadVisitor.visit(new PayloadVisitorContext<>(job.context, job.parent), job.input); + List result = payloadVisitor.visit(job.context, job.input); if (result == null) { throw new IllegalStateException("payload visitor returned null"); } @@ -223,14 +222,12 @@ private void runJob(LeafJob job) { private static final class LeafJob { final List input; final Object context; - final MessageOrBuilder parent; final boolean single; volatile List result; - LeafJob(List input, Object context, MessageOrBuilder parent, boolean single) { + LeafJob(List input, Object context, boolean single) { this.input = input; this.context = context; - this.parent = parent; this.single = single; } } diff --git a/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java index de435aa81e..efe522aea8 100644 --- a/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java +++ b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java @@ -379,7 +379,7 @@ private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int switch (plan.kind) { case SINGLE_PAYLOAD: sb.append(" if (b.has").append(B).append("()) {\n"); - sb.append(" t.singlePayload(b, b.get") + sb.append(" t.singlePayload(b.get") .append(B) .append("(), p -> b.set") .append(B) @@ -387,14 +387,14 @@ private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int sb.append(" }\n"); break; case REPEATED_PAYLOAD: - sb.append(" t.payloads(b, b.get").append(B).append("List(), pl -> {\n"); + sb.append(" t.payloads(b.get").append(B).append("List(), pl -> {\n"); sb.append(" b.clear").append(B).append("();\n"); sb.append(" b.addAll").append(B).append("(pl);\n"); sb.append(" });\n"); break; case PAYLOADS_SINGLE: sb.append(" if (b.has").append(B).append("()) {\n"); - sb.append(" t.payloads(b, b.get").append(B).append("().getPayloadsList(),\n"); + sb.append(" t.payloads(b.get").append(B).append("().getPayloadsList(),\n"); sb.append(" pl -> b.set") .append(B) .append("(") @@ -413,7 +413,7 @@ private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int .append(v) .append("++) {\n"); sb.append(" final int ").append(k).append(" = ").append(v).append(";\n"); - sb.append(" t.payloads(b, b.get") + sb.append(" t.payloads(b.get") .append(B) .append("(") .append(k) @@ -434,7 +434,7 @@ private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int .append(B) .append("Map().keySet())) {\n"); sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); - sb.append(" t.singlePayload(b, b.get") + sb.append(" t.singlePayload(b.get") .append(B) .append("Map().get(") .append(v) @@ -452,7 +452,7 @@ private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int .append(B) .append("Map().keySet())) {\n"); sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); - sb.append(" t.payloads(b, b.get") + sb.append(" t.payloads(b.get") .append(B) .append("Map().get(") .append(v) diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java index 7870a518c3..7e011a3577 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java @@ -75,7 +75,7 @@ static final class CollectingVisitor implements PayloadVisitor { final AtomicInteger visits = new AtomicInteger(); @Override - public List visit(PayloadVisitorContext ctx, List payloads) { + public List visit(Object ctx, List payloads) { visits.incrementAndGet(); for (Payload p : payloads) { seen.add(data(p)); @@ -434,7 +434,7 @@ public void contextScopesPerCommand() { (ctx, pls) -> { for (Payload p : pls) { dataOrder.add(data(p)); - contextOrder.add(ctx.getContext()); + contextOrder.add(ctx); } return pls; }) @@ -467,7 +467,7 @@ public void initialContextUsedWhenNoMessageVisitor() { .setInitialContext("root") .setPayloadVisitor( (ctx, pls) -> { - observed.add(ctx.getContext()); + observed.add(ctx); return pls; }) .build(); From 4e084d4e880a9a58dab75c6447de7e72acb9c9f4 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 15 Jun 2026 13:26:15 -0700 Subject: [PATCH 3/5] make visitor required parameter of builders --- .../visitor/MessageVisitorOptions.java | 18 ++-- .../visitor/PayloadVisitorOptions.java | 18 ++-- .../payload/visitor/PayloadVisitors.java | 3 +- .../payload/visitor/MessageVisitorTest.java | 21 ++--- .../payload/visitor/PayloadVisitorTest.java | 88 +++++++------------ 5 files changed, 53 insertions(+), 95 deletions(-) diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java index 350c75a793..681a9cd120 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java @@ -1,5 +1,6 @@ package io.temporal.internal.payload.visitor; +import java.util.Objects; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -17,8 +18,8 @@ private MessageVisitorOptions(Builder b) { this.initialContext = b.initialContext; } - public static Builder newBuilder() { - return new Builder<>(); + public static Builder newBuilder(@Nonnull MessageVisitor messageVisitor) { + return new Builder<>(messageVisitor); } @Nonnull @@ -32,15 +33,11 @@ public C getInitialContext() { } public static final class Builder { - private MessageVisitor messageVisitor; + private final @Nonnull MessageVisitor messageVisitor; private C initialContext; - private Builder() {} - - /** Required. The message visitor. */ - public Builder setMessageVisitor(@Nonnull MessageVisitor messageVisitor) { - this.messageVisitor = messageVisitor; - return this; + private Builder(@Nonnull MessageVisitor messageVisitor) { + this.messageVisitor = Objects.requireNonNull(messageVisitor, "messageVisitor"); } /** Optional. The contextual value in scope before any message is entered. */ @@ -50,9 +47,6 @@ public Builder setInitialContext(@Nullable C initialContext) { } public MessageVisitorOptions build() { - if (messageVisitor == null) { - throw new IllegalArgumentException("messageVisitor is required"); - } return new MessageVisitorOptions<>(this); } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java index fc9f6b9ecd..5212b3a91b 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java @@ -1,5 +1,6 @@ package io.temporal.internal.payload.visitor; +import java.util.Objects; import java.util.concurrent.Executor; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -28,8 +29,8 @@ private PayloadVisitorOptions(Builder b) { this.executor = b.executor; } - public static Builder newBuilder() { - return new Builder<>(); + public static Builder newBuilder(@Nonnull PayloadVisitor payloadVisitor) { + return new Builder<>(payloadVisitor); } @Nonnull @@ -69,7 +70,7 @@ public Executor getExecutor() { } public static final class Builder { - private PayloadVisitor payloadVisitor; + private final @Nonnull PayloadVisitor payloadVisitor; private MessageVisitor messageVisitor; private C initialContext; private boolean skipSearchAttributes; @@ -77,12 +78,8 @@ public static final class Builder { private int concurrency = 1; private Executor executor; - private Builder() {} - - /** Required. The payload visitor. */ - public Builder setPayloadVisitor(@Nonnull PayloadVisitor payloadVisitor) { - this.payloadVisitor = payloadVisitor; - return this; + private Builder(@Nonnull PayloadVisitor payloadVisitor) { + this.payloadVisitor = Objects.requireNonNull(payloadVisitor, "payloadVisitor"); } /** Optional. A callback invoked when entering each message. */ @@ -125,9 +122,6 @@ public Builder setExecutor(@Nullable Executor executor) { } public PayloadVisitorOptions build() { - if (payloadVisitor == null) { - throw new IllegalArgumentException("payloadVisitor is required"); - } if (concurrency < 1) { throw new IllegalArgumentException("concurrency must be at least 1, got " + concurrency); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java index 9cc828688b..731bf8fb1b 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java @@ -12,8 +12,7 @@ * RespondWorkflowTaskCompletedRequest result = * PayloadVisitors.visit( * request, - * PayloadVisitorOptions.newBuilder() - * .setPayloadVisitor((ctx, payloads) -> encode(ctx, payloads)) + * PayloadVisitorOptions.newBuilder((ctx, payloads) -> encode(ctx, payloads)) * .setMessageVisitor((current, msg) -> msg instanceof Command.Builder * ? CommandInfo.of((Command.Builder) msg) : current) * .setConcurrency(4) diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java index d6239ac744..8cc7cc7bd6 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java @@ -43,8 +43,7 @@ public void visitsBuilderInPlace() { List entered = new ArrayList<>(); MessageVisitors.visit( builder, - MessageVisitorOptions.newBuilder() - .setMessageVisitor( + MessageVisitorOptions.newBuilder( (current, msg) -> { entered.add(msg.getDescriptorForType().getFullName()); return current; @@ -61,8 +60,7 @@ public void messageVisitorMutatesInPlace() { RespondWorkflowTaskCompletedRequest result = MessageVisitors.visit( request, - MessageVisitorOptions.newBuilder() - .setMessageVisitor( + MessageVisitorOptions.newBuilder( (current, msg) -> { if (msg instanceof ScheduleActivityTaskCommandAttributes.Builder) { ((ScheduleActivityTaskCommandAttributes.Builder) msg) @@ -103,8 +101,7 @@ public void visitsEachMessageWithScopedContext() { List contextOnEnter = new ArrayList<>(); MessageVisitorOptions opts = - MessageVisitorOptions.newBuilder() - .setMessageVisitor( + MessageVisitorOptions.newBuilder( (current, msg) -> { entered.add(msg.getDescriptorForType().getFullName()); contextOnEnter.add(current); @@ -145,8 +142,7 @@ public void visitsEachMessageWithScopedContext() { public void messageOnlyVisitorValidatesPerMessageType() { int maxMemoFields = 2; MessageVisitorOptions opts = - MessageVisitorOptions.newBuilder() - .setMessageVisitor( + MessageVisitorOptions.newBuilder( (current, msg) -> { if (msg instanceof Memo.Builder && ((Memo.Builder) msg).getFieldsCount() > maxMemoFields) { @@ -174,19 +170,18 @@ public void initialContextObservedAtRoot() { List observed = new ArrayList<>(); MessageVisitors.visit( memo, - MessageVisitorOptions.newBuilder() - .setInitialContext("root") - .setMessageVisitor( + MessageVisitorOptions.newBuilder( (current, msg) -> { observed.add(current); return current; }) + .setInitialContext("root") .build()); assertEquals(Arrays.asList("root"), observed); } @Test - public void rejectsMissingMessageVisitor() { - assertThrows(IllegalArgumentException.class, () -> MessageVisitorOptions.newBuilder().build()); + public void rejectsNullMessageVisitor() { + assertThrows(NullPointerException.class, () -> MessageVisitorOptions.newBuilder(null)); } } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java index 7e011a3577..4ca2204663 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java @@ -85,7 +85,7 @@ public List visit(Object ctx, List payloads) { } static PayloadVisitorOptions options(PayloadVisitor visitor) { - return PayloadVisitorOptions.newBuilder().setPayloadVisitor(visitor).build(); + return PayloadVisitorOptions.newBuilder(visitor).build(); } static Command activity(String activityId, Payloads input) { @@ -383,8 +383,7 @@ public void skipsHeaders() { CollectingVisitor counter = new CollectingVisitor(); PayloadVisitors.visit( - command, - PayloadVisitorOptions.newBuilder().setPayloadVisitor(counter).setSkipHeaders(true).build()); + command, PayloadVisitorOptions.newBuilder(counter).setSkipHeaders(true).build()); // The header payload is skipped; other payloads are still visited. assertEquals(Collections.singletonList("in"), counter.seen); } @@ -402,11 +401,7 @@ public void skipsSearchAttributes() { CollectingVisitor counter = new CollectingVisitor(); PayloadVisitors.visit( - command, - PayloadVisitorOptions.newBuilder() - .setPayloadVisitor(counter) - .setSkipSearchAttributes(true) - .build()); + command, PayloadVisitorOptions.newBuilder(counter).setSkipSearchAttributes(true).build()); // The search attribute payload is skipped; other payloads are still visited. assertEquals(Collections.singletonList("in"), counter.seen); } @@ -429,8 +424,7 @@ public void contextScopesPerCommand() { List dataOrder = new ArrayList<>(); List contextOrder = new ArrayList<>(); PayloadVisitorOptions opts = - PayloadVisitorOptions.newBuilder() - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { for (Payload p : pls) { dataOrder.add(data(p)); @@ -463,13 +457,12 @@ public void initialContextUsedWhenNoMessageVisitor() { .build(); List observed = new ArrayList<>(); PayloadVisitorOptions opts = - PayloadVisitorOptions.newBuilder() - .setInitialContext("root") - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { observed.add(ctx); return pls; }) + .setInitialContext("root") .build(); PayloadVisitors.visit(command, opts); assertEquals(Collections.singletonList("root"), observed); @@ -484,8 +477,7 @@ public void limitsStyleValidatorComposesBothSeams() { int maxMemoFields = 2; PayloadVisitorOptions validator = - PayloadVisitorOptions.newBuilder() - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { for (Payload pl : pls) { if (pl.getData().size() > blobLimit) { @@ -649,8 +641,7 @@ public void messageVisitorErrorPropagates() { () -> PayloadVisitors.visit( command, - PayloadVisitorOptions.newBuilder() - .setPayloadVisitor((ctx, pls) -> pls) + PayloadVisitorOptions.newBuilder((ctx, pls) -> pls) .setMessageVisitor( (current, msg) -> { throw boom; @@ -685,8 +676,8 @@ public void registryCoversPayloadBearingTypesAndExcludesOthers() { } @Test - public void rejectsMissingVisitor() { - assertThrows(IllegalArgumentException.class, () -> PayloadVisitorOptions.newBuilder().build()); + public void rejectsNullPayloadVisitor() { + assertThrows(NullPointerException.class, () -> PayloadVisitorOptions.newBuilder(null)); } @Test @@ -707,22 +698,14 @@ public void nullReturnFromVisitorFails() { public void rejectsConcurrencyBelowOne() { assertThrows( IllegalArgumentException.class, - () -> - PayloadVisitorOptions.newBuilder() - .setPayloadVisitor((ctx, pls) -> pls) - .setConcurrency(0) - .build()); + () -> PayloadVisitorOptions.newBuilder((ctx, pls) -> pls).setConcurrency(0).build()); } @Test public void rejectsConcurrencyAboveOneWithoutExecutor() { assertThrows( IllegalArgumentException.class, - () -> - PayloadVisitorOptions.newBuilder() - .setPayloadVisitor((ctx, pls) -> pls) - .setConcurrency(2) - .build()); + () -> PayloadVisitorOptions.newBuilder((ctx, pls) -> pls).setConcurrency(2).build()); } /** @@ -746,10 +729,7 @@ public void concurrencyEqualToWorkAllowsFullOverlap() throws Exception { // Each of the n visits must reach the barrier simultaneously, proving n concurrent visits. PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(n) - .setExecutor(executor) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { try { barrier.await(5, TimeUnit.SECONDS); @@ -758,6 +738,8 @@ public void concurrencyEqualToWorkAllowsFullOverlap() throws Exception { } return pls; }) + .setConcurrency(n) + .setExecutor(executor) .build()); } @@ -772,10 +754,7 @@ public void boundedConcurrencyNeverExceedsLimit() { PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(limit) - .setExecutor(executor) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { int now = inFlight.incrementAndGet(); maxInFlight.accumulateAndGet(now, Math::max); @@ -787,6 +766,8 @@ public void boundedConcurrencyNeverExceedsLimit() { inFlight.decrementAndGet(); return pls; }) + .setConcurrency(limit) + .setExecutor(executor) .build()); assertTrue( @@ -805,9 +786,7 @@ public void sequentialConcurrencyVisitsOneAtATimeInOrder() { PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(1) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { int now = inFlight.incrementAndGet(); maxInFlight.accumulateAndGet(now, Math::max); @@ -815,6 +794,7 @@ public void sequentialConcurrencyVisitsOneAtATimeInOrder() { inFlight.decrementAndGet(); return pls; }) + .setConcurrency(1) .build()); assertEquals(1, maxInFlight.get()); @@ -835,16 +815,15 @@ public void concurrentVisitorErrorPropagates() { () -> PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(4) - .setExecutor(executor) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { if (pls.get(0).getData().toStringUtf8().equals("p5")) { throw boom; } return pls; }) + .setConcurrency(4) + .setExecutor(executor) .build())); assertSame(boom, thrown); } @@ -856,10 +835,7 @@ public void mutationsAppliedCorrectlyUnderConcurrency() { RespondWorkflowTaskCompletedRequest mutated = PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(8) - .setExecutor(executor) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { Payload p = pls.get(0); return Collections.singletonList( @@ -867,6 +843,8 @@ public void mutationsAppliedCorrectlyUnderConcurrency() { .setData(ByteString.copyFromUtf8(p.getData().toStringUtf8() + "!")) .build()); }) + .setConcurrency(8) + .setExecutor(executor) .build()); for (int i = 0; i < n; i++) { assertEquals( @@ -897,14 +875,13 @@ public void concurrentVisitsRunOnProvidedExecutor() throws InterruptedException PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(4) - .setExecutor(pool) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { visitThreads.add(Thread.currentThread().getName()); return pls; }) + .setConcurrency(4) + .setExecutor(pool) .build()); assertTrue("no visits recorded", !visitThreads.isEmpty()); @@ -933,16 +910,15 @@ public void sequentialConcurrencyIgnoresExecutorAndRunsOnCallingThread() { PayloadVisitors.visit( request, - PayloadVisitorOptions.newBuilder() - .setConcurrency(1) - .setExecutor(tripwire) - .setPayloadVisitor( + PayloadVisitorOptions.newBuilder( (ctx, pls) -> { if (!Thread.currentThread().getName().equals(callingThread)) { sawDifferentThread.set(Thread.currentThread().getName()); } return pls; }) + .setConcurrency(1) + .setExecutor(tripwire) .build()); assertEquals("executor was used for sequential traversal", 0, submittedToExecutor.get()); From aff81757d97304e43bc57c5630700decf12ef84e Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 16 Jun 2026 14:13:21 -0700 Subject: [PATCH 4/5] async using CompletableFuture and AsyncSemaphore --- .../internal/common/AsyncSemaphore.java | 76 ++++ .../visitor/MessageVisitorOptions.java | 2 +- .../payload/visitor/MessageVisitors.java | 10 +- .../payload/visitor/PayloadVisitor.java | 22 +- .../visitor/PayloadVisitorOptions.java | 33 +- .../payload/visitor/PayloadVisitors.java | 40 +- .../internal/payload/visitor/Traversal.java | 207 +++++----- .../worker/tuning/FixedSizeSlotSupplier.java | 74 +--- .../payload/visitor/PayloadVisitorTest.java | 368 +++++++++--------- 9 files changed, 399 insertions(+), 433 deletions(-) create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/common/AsyncSemaphore.java diff --git a/temporal-sdk/src/main/java/io/temporal/internal/common/AsyncSemaphore.java b/temporal-sdk/src/main/java/io/temporal/internal/common/AsyncSemaphore.java new file mode 100644 index 0000000000..c44ab3d4e4 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/common/AsyncSemaphore.java @@ -0,0 +1,76 @@ +package io.temporal.internal.common; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.ReentrantLock; + +/** + * A simple async semaphore. Unfortunately there's not any readily available properly licensed + * library I could find for this which is a bit shocking, but this implementation should be suitable + * for our needs. + */ +public final class AsyncSemaphore { + private final ReentrantLock lock = new ReentrantLock(); + private final Queue> waiters = new ArrayDeque<>(); + private int permits; + + public AsyncSemaphore(int initialPermits) { + this.permits = initialPermits; + } + + /** + * Acquire a permit asynchronously. If a permit is available, returns a completed future, + * otherwise returns a future that will be completed when a permit is released. + */ + public CompletableFuture acquire() { + lock.lock(); + try { + if (permits > 0) { + permits--; + return CompletableFuture.completedFuture(null); + } else { + CompletableFuture waiter = new CompletableFuture<>(); + waiters.add(waiter); + return waiter; + } + } finally { + lock.unlock(); + } + } + + public boolean tryAcquire() { + lock.lock(); + try { + if (permits > 0) { + permits--; + return true; + } + return false; + } finally { + lock.unlock(); + } + } + + /** + * Release a permit. If there are waiting futures, completes the next one instead of incrementing + * the permit count. + */ + public void release() { + lock.lock(); + try { + CompletableFuture waiter = waiters.poll(); + if (waiter != null) { + if (!waiter.complete(null) && waiter.isCancelled()) { + // If this waiter was cancelled, we need to release another permit, since this waiter + // is now useless + release(); + } + } else { + permits++; + } + } finally { + lock.unlock(); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java index 681a9cd120..2389a5dcbc 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java @@ -40,7 +40,7 @@ private Builder(@Nonnull MessageVisitor messageVisitor) { this.messageVisitor = Objects.requireNonNull(messageVisitor, "messageVisitor"); } - /** Optional. The contextual value in scope before any message is entered. */ + /** The contextual value in scope before any message is entered. */ public Builder setInitialContext(@Nullable C initialContext) { this.initialContext = initialContext; return this; diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java index 5a4de476b0..c749415382 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java @@ -23,16 +23,14 @@ public static void visit( /* skipSearchAttributes= */ false, /* skipHeaders= */ false, 1, - null, GeneratedPayloadVisitor.REGISTRY); traversal.dispatch(builder); - traversal.execute(); + // No payload visits, so execute() completes inline; join() returns at once. Message-visitor + // errors throw from dispatch above. + traversal.execute().join(); } - /** - * Visits the messages in {@code message}, returning a copy with any changes applied; the input is - * unchanged. - */ + /** Returns a copy with any changes applied; {@code message} is unchanged. */ @SuppressWarnings("unchecked") public static T visit( @Nonnull T message, @Nonnull MessageVisitorOptions options) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java index 38d231432d..8872a44cde 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java @@ -2,25 +2,21 @@ import io.temporal.api.common.v1.Payload; import java.util.List; +import java.util.concurrent.CompletableFuture; /** - * Callback for a sequence of payloads found in a proto message. The returned list replaces those - * payloads; return the same list to leave them unchanged. + * Callback completing with the list that replaces {@code payloads}; complete with the same list to + * leave them unchanged. Asynchronous so I/O-backed implementations (e.g. external storage) compose + * without blocking a thread per call; a synchronous one returns {@link + * CompletableFuture#completedFuture}. * - *

When the visited field holds a single payload the list has one element and the visitor must - * return exactly one payload. With a concurrency limit greater than one, visits may run on multiple - * threads, so implementations must be thread-safe. + *

For a single-payload field the visitor must complete with exactly one payload. With + * concurrency greater than one, several visits may be in flight at once, so implementations must be + * thread-safe. * * @param type of the contextual value supplied to each visit */ @FunctionalInterface interface PayloadVisitor { - /** - * Visits a sequence of payloads and returns their replacements. - * - * @param context the contextual value in scope - * @param payloads the payloads found at this location - * @return the replacement payloads - */ - List visit(C context, List payloads); + CompletableFuture> visit(C context, List payloads); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java index 5212b3a91b..c834c2ec8e 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java @@ -1,7 +1,6 @@ package io.temporal.internal.payload.visitor; import java.util.Objects; -import java.util.concurrent.Executor; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -17,7 +16,6 @@ final class PayloadVisitorOptions { private final boolean skipSearchAttributes; private final boolean skipHeaders; private final int concurrency; - private final @Nullable Executor executor; private PayloadVisitorOptions(Builder b) { this.payloadVisitor = b.payloadVisitor; @@ -26,7 +24,6 @@ private PayloadVisitorOptions(Builder b) { this.skipSearchAttributes = b.skipSearchAttributes; this.skipHeaders = b.skipHeaders; this.concurrency = b.concurrency; - this.executor = b.executor; } public static Builder newBuilder(@Nonnull PayloadVisitor payloadVisitor) { @@ -48,27 +45,18 @@ public C getInitialContext() { return initialContext; } - /** Whether search attribute payloads are skipped. */ public boolean isSkipSearchAttributes() { return skipSearchAttributes; } - /** Whether header payloads are skipped. */ public boolean isSkipHeaders() { return skipHeaders; } - /** Maximum number of visits that may run concurrently; {@code 1} is sequential. */ public int getConcurrency() { return concurrency; } - /** Executor for concurrent visits; {@code null} when concurrency is {@code 1}. */ - @Nullable - public Executor getExecutor() { - return executor; - } - public static final class Builder { private final @Nonnull PayloadVisitor payloadVisitor; private MessageVisitor messageVisitor; @@ -76,59 +64,42 @@ public static final class Builder { private boolean skipSearchAttributes; private boolean skipHeaders; private int concurrency = 1; - private Executor executor; private Builder(@Nonnull PayloadVisitor payloadVisitor) { this.payloadVisitor = Objects.requireNonNull(payloadVisitor, "payloadVisitor"); } - /** Optional. A callback invoked when entering each message. */ public Builder setMessageVisitor(@Nullable MessageVisitor messageVisitor) { this.messageVisitor = messageVisitor; return this; } - /** Optional. The contextual value in scope before any message is entered. */ + /** The contextual value in scope before any message is entered. */ public Builder setInitialContext(@Nullable C initialContext) { this.initialContext = initialContext; return this; } - /** Whether to skip search attribute payloads. */ public Builder setSkipSearchAttributes(boolean skipSearchAttributes) { this.skipSearchAttributes = skipSearchAttributes; return this; } - /** Whether to skip header payloads. */ public Builder setSkipHeaders(boolean skipHeaders) { this.skipHeaders = skipHeaders; return this; } - /** - * Maximum number of concurrent visits; must be at least {@code 1} (the default, sequential). A - * value greater than {@code 1} requires an executor (see {@link #setExecutor}). - */ + /** At least {@code 1} (sequential). Bounds outstanding visit futures; no executor needed. */ public Builder setConcurrency(int concurrency) { this.concurrency = concurrency; return this; } - /** Executor for concurrent visits. Required when concurrency is greater than {@code 1}. */ - public Builder setExecutor(@Nullable Executor executor) { - this.executor = executor; - return this; - } - public PayloadVisitorOptions build() { if (concurrency < 1) { throw new IllegalArgumentException("concurrency must be at least 1, got " + concurrency); } - if (concurrency > 1 && executor == null) { - throw new IllegalArgumentException( - "executor is required when concurrency is greater than 1"); - } return new PayloadVisitorOptions<>(this); } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java index 731bf8fb1b..67c3fb5950 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java @@ -1,29 +1,17 @@ package io.temporal.internal.payload.visitor; import com.google.protobuf.Message; +import java.util.concurrent.CompletableFuture; import javax.annotation.Nonnull; /** - * Visits every payload within a proto message. A message with no payloads is returned unchanged. - * - *

This is an SDK-internal utility; it is not part of the public API. - * - *

{@code
- * RespondWorkflowTaskCompletedRequest result =
- *     PayloadVisitors.visit(
- *         request,
- *         PayloadVisitorOptions.newBuilder((ctx, payloads) -> encode(ctx, payloads))
- *             .setMessageVisitor((current, msg) -> msg instanceof Command.Builder
- *                 ? CommandInfo.of((Command.Builder) msg) : current)
- *             .setConcurrency(4)
- *             .build());
- * }
+ * Visits every payload within a proto message. */ final class PayloadVisitors { private PayloadVisitors() {} /** Visits the payloads in {@code builder} in place. */ - public static void visit( + public static CompletableFuture visit( @Nonnull Message.Builder builder, @Nonnull PayloadVisitorOptions options) { Traversal traversal = new Traversal( @@ -33,21 +21,23 @@ public static void visit( options.isSkipSearchAttributes(), options.isSkipHeaders(), options.getConcurrency(), - options.getExecutor(), GeneratedPayloadVisitor.REGISTRY); - traversal.dispatch(builder); - traversal.execute(); + try { + traversal.dispatch(builder); + } catch (Throwable t) { + // Surface a walk failure through the future, so all failures reach the caller the same way. + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(t); + return failed; + } + return traversal.execute(); } - /** - * Visits the payloads in {@code message}, returning a copy with replacements applied; the input - * is unchanged. - */ + /** Completes with a copy that has the replacements applied; {@code message} is unchanged. */ @SuppressWarnings("unchecked") - public static T visit( + public static CompletableFuture visit( @Nonnull T message, @Nonnull PayloadVisitorOptions options) { Message.Builder builder = message.toBuilder(); - visit(builder, options); - return (T) builder.build(); + return visit(builder, options).thenApply(v -> (T) builder.build()); } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java index 985f02b02e..ffbdd712ce 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java @@ -5,34 +5,32 @@ import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import io.temporal.api.common.v1.Payload; +import io.temporal.internal.common.AsyncSemaphore; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; -import java.util.concurrent.Semaphore; +import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; /** * Mutable state for one traversal, called into by the generated per-message visitors. * - *

A single-threaded walk records a visit job and a write-back for each payload sequence it - * finds; {@link #execute()} then runs the visitor calls (optionally with bounded concurrency) and - * finally applies the write-backs in walk order, so the non-thread-safe builders are never mutated - * concurrently. + *

The single-threaded walk only records a job and a write-back per payload sequence; {@link + * #execute()} runs the visits and applies the write-backs afterward in walk order, so the + * non-thread-safe builders are never mutated concurrently. Visits are asynchronous, so the engine + * needs no executor — it only bounds how many of their futures are outstanding. */ final class Traversal { - // The payload visitor is null for a message-only traversal (see MessageVisitors); in that case - // the payload seams are skipped and only the per-message MessageVisitor fires. + // Null for a message-only traversal: payload seams are skipped, only the MessageVisitor fires. private final PayloadVisitor payloadVisitor; private final MessageVisitor messageVisitor; private final Map registry; final boolean skipSearchAttributes; final boolean skipHeaders; private final int concurrency; - private final Executor executor; private final List jobs = new ArrayList<>(); private final List writeBacks = new ArrayList<>(); @@ -46,27 +44,22 @@ final class Traversal { boolean skipSearchAttributes, boolean skipHeaders, int concurrency, - Executor executor, Map registry) { if (concurrency < 1) { throw new IllegalArgumentException("concurrency must be at least 1, got " + concurrency); } - if (concurrency > 1 && executor == null) { - throw new IllegalArgumentException("executor is required when concurrency is greater than 1"); - } this.payloadVisitor = (PayloadVisitor) payloadVisitor; this.messageVisitor = (MessageVisitor) messageVisitor; this.currentContext = initialContext; this.skipSearchAttributes = skipSearchAttributes; this.skipHeaders = skipHeaders; this.concurrency = concurrency; - this.executor = executor; this.registry = registry; } // --- Structural walk: called by generated code --- - /** Dispatch to the generated visitor for {@code builder}'s type; no-op if it has no payloads. */ + /** No-op for a type with no payloads. */ void dispatch(Message.Builder builder) { MessageRegistryEntry entry = registry.get(builder.getDescriptorForType().getFullName()); if (entry != null) { @@ -74,10 +67,7 @@ void dispatch(Message.Builder builder) { } } - /** - * Run the message visitor for {@code message}, narrowing the scoped context; returns the value to - * restore. - */ + /** Narrows the scoped context; returns the value {@link #exit} restores. */ Object enter(MessageOrBuilder message) { Object previous = currentContext; if (messageVisitor != null) { @@ -86,47 +76,43 @@ Object enter(MessageOrBuilder message) { return previous; } - /** Restore the scoped context to {@code previous} when leaving a message's subtree. */ void exit(Object previous) { currentContext = previous; } - /** Record a visit of a payload sequence ({@code Payloads} or {@code repeated Payload}). */ + /** Record a visit of a payload sequence (a {@code Payloads} or {@code repeated Payload}). */ void payloads(List batch, Consumer> writeBack) { if (payloadVisitor == null) { - return; // message-only traversal: payload seams are inert + return; } LeafJob job = new LeafJob(batch, currentContext, false); jobs.add(job); writeBacks.add(() -> writeBack.accept(job.result)); } - /** - * Record a visit of a singular payload field. The visitor must return exactly one payload for - * such a field (enforced in {@link #runJob}), which the consumer writes back. - */ + /** The visitor must return exactly one payload (checked in {@link #record}). */ void singlePayload(Payload value, Consumer writeBack) { if (payloadVisitor == null) { - return; // message-only traversal: payload seams are inert + return; } LeafJob job = new LeafJob(Collections.singletonList(value), currentContext, true); jobs.add(job); writeBacks.add(() -> writeBack.accept(job.result.get(0))); } - /** Append a deferred write-back, applied (single-threaded) after all visits and in walk order. */ + /** Applied after all visits, single-threaded, in walk order. */ void deferWriteBack(Runnable writeBack) { writeBacks.add(writeBack); } - /** Unpack a {@code google.protobuf.Any}, traverse its contents, and re-pack it after visits. */ + /** Unpack a {@code google.protobuf.Any}, traverse its contents, and re-pack after visits. */ void any(Any.Builder anyBuilder) { String typeUrl = anyBuilder.getTypeUrl(); int slash = typeUrl.lastIndexOf('/'); String fullName = slash >= 0 ? typeUrl.substring(slash + 1) : typeUrl; MessageRegistryEntry entry = registry.get(fullName); if (entry == null) { - // Unknown type, or a type with no payloads; leave the Any untouched. + // Unknown or payload-free type: leave the Any untouched. return; } Message.Builder inner = entry.newBuilder.get(); @@ -139,86 +125,109 @@ void any(Any.Builder anyBuilder) { deferWriteBack(() -> anyBuilder.setValue(inner.build().toByteString())); } - // --- Execution: visitor calls (phase 2) then write-backs (phase 3) --- + // --- Execution: visits, then write-backs --- - void execute() { - if (jobs.isEmpty()) { - return; - } - if (concurrency <= 1 || jobs.size() == 1) { - for (LeafJob job : jobs) { - runJob(job); - } - } else { - executeConcurrently(); - } - for (Runnable writeBack : writeBacks) { - writeBack.run(); - } + /** + * Completes the returned future once the visits and write-backs are done. Blocks no thread of its + * own: the caller decides how to wait and which executor to chain on. Write-backs run on whatever + * thread completes the last visit (inline if the visits are synchronous). A visit failure aborts + * the traversal — remaining visits unstarted, write-backs skipped — and completes the future + * exceptionally with the original throwable. + */ + CompletableFuture execute() { + CompletableFuture visitsDone = + jobs.isEmpty() ? CompletableFuture.completedFuture(null) : runVisits(); + CompletableFuture result = new CompletableFuture<>(); + visitsDone.whenComplete( + (v, err) -> { + if (err != null) { + result.completeExceptionally(unwrap(err)); + return; + } + try { + for (Runnable writeBack : writeBacks) { + writeBack.run(); + } + result.complete(null); + } catch (Throwable t) { + result.completeExceptionally(t); + } + }); + return result; } - private void executeConcurrently() { - // concurrency > 1 and a non-null executor are guaranteed by the constructor's validation. - Executor pool = executor; - Semaphore semaphore = new Semaphore(concurrency); + /** Run the visits with at most {@code concurrency} outstanding; fails with the first error. */ + private CompletableFuture runVisits() { + AsyncSemaphore permits = new AsyncSemaphore(concurrency); AtomicReference firstError = new AtomicReference<>(); - List> futures = new ArrayList<>(jobs.size()); + List> all = new ArrayList<>(jobs.size()); for (LeafJob job : jobs) { - if (firstError.get() != null) { - break; - } - try { - semaphore.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - firstError.compareAndSet(null, e); - break; - } - if (firstError.get() != null) { - semaphore.release(); - break; - } - futures.add( - CompletableFuture.runAsync( - () -> { - try { - runJob(job); - } catch (Throwable t) { - firstError.compareAndSet(null, t); - } finally { - semaphore.release(); - } - }, - pool)); - } - CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); - Throwable error = firstError.get(); - if (error instanceof RuntimeException) { - throw (RuntimeException) error; - } - if (error instanceof Error) { - throw (Error) error; - } - if (error != null) { - // The only checked exception that can reach here is an InterruptedException from acquiring - // the semaphore. - throw new VisitorException("payload visit interrupted", error); - } + all.add(permits.acquire().thenCompose(p -> runJob(job, permits, firstError))); + } + return CompletableFuture.allOf(all.toArray(new CompletableFuture[0])) + .thenCompose( + v -> { + Throwable error = firstError.get(); + if (error == null) { + return CompletableFuture.completedFuture(null); + } + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(error); + return failed; + }); } - private void runJob(LeafJob job) { - List result = payloadVisitor.visit(job.context, job.input); - if (result == null) { - throw new IllegalStateException("payload visitor returned null"); + /** Runs one job under a held permit, releasing it exactly once when the visit settles. */ + private CompletableFuture runJob( + LeafJob job, AsyncSemaphore permits, AtomicReference firstError) { + if (firstError.get() != null) { + permits.release(); + return CompletableFuture.completedFuture(null); } - if (job.single && result.size() != 1) { - throw new IllegalStateException( - "single-payload field requires exactly 1 returned payload, got " + result.size()); + CompletableFuture> visit; + try { + visit = payloadVisitor.visit(job.context, job.input); + } catch (RuntimeException | Error t) { + firstError.compareAndSet(null, t); + permits.release(); + return CompletableFuture.completedFuture(null); + } + if (visit == null) { + firstError.compareAndSet(null, new IllegalStateException("payload visitor returned null")); + permits.release(); + return CompletableFuture.completedFuture(null); + } + return visit + .handle( + (result, err) -> { + record(job, result, err, firstError); + return (Void) null; + }) + .whenComplete((v, e) -> permits.release()); + } + + private void record( + LeafJob job, List result, Throwable err, AtomicReference firstError) { + if (err != null) { + firstError.compareAndSet(null, unwrap(err)); + } else if (result == null) { + firstError.compareAndSet(null, new IllegalStateException("payload visitor returned null")); + } else if (job.single && result.size() != 1) { + firstError.compareAndSet( + null, + new IllegalStateException( + "single-payload field requires exactly 1 returned payload, got " + result.size())); + } else { + job.result = result; } - job.result = result; } - /** A single recorded visitor call and the slot its result is written into. */ + /** Strip the {@link CompletionException} a dependent stage wraps around its cause. */ + private static Throwable unwrap(Throwable t) { + return (t instanceof CompletionException && t.getCause() != null) ? t.getCause() : t; + } + + /** A recorded visit and the slot its result lands in. */ private static final class LeafJob { final List input; final Object context; diff --git a/temporal-sdk/src/main/java/io/temporal/worker/tuning/FixedSizeSlotSupplier.java b/temporal-sdk/src/main/java/io/temporal/worker/tuning/FixedSizeSlotSupplier.java index b62b8ec8d2..d376242788 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/tuning/FixedSizeSlotSupplier.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/tuning/FixedSizeSlotSupplier.java @@ -1,11 +1,9 @@ package io.temporal.worker.tuning; import com.google.common.base.Preconditions; -import java.util.ArrayDeque; +import io.temporal.internal.common.AsyncSemaphore; import java.util.Optional; -import java.util.Queue; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.locks.ReentrantLock; /** * This implementation of {@link SlotSupplier} provides a fixed number of slots backed by a @@ -17,76 +15,6 @@ public class FixedSizeSlotSupplier implements SlotSupplier< private final int numSlots; private final AsyncSemaphore executorSlotsSemaphore; - /** - * A simple version of an async semaphore. Unfortunately there's not any readily available - * properly licensed library I could find for this which is a bit shocking, but this - * implementation should be suitable for our needs - */ - static class AsyncSemaphore { - private final ReentrantLock lock = new ReentrantLock(); - private final Queue> waiters = new ArrayDeque<>(); - private int permits; - - AsyncSemaphore(int initialPermits) { - this.permits = initialPermits; - } - - /** - * Acquire a permit asynchronously. If a permit is available, returns a completed future, - * otherwise returns a future that will be completed when a permit is released. - */ - public CompletableFuture acquire() { - lock.lock(); - try { - if (permits > 0) { - permits--; - return CompletableFuture.completedFuture(null); - } else { - CompletableFuture waiter = new CompletableFuture<>(); - waiters.add(waiter); - return waiter; - } - } finally { - lock.unlock(); - } - } - - public boolean tryAcquire() { - lock.lock(); - try { - if (permits > 0) { - permits--; - return true; - } - return false; - } finally { - lock.unlock(); - } - } - - /** - * Release a permit. If there are waiting futures, completes the next one instead of - * incrementing the permit count. - */ - public void release() { - lock.lock(); - try { - CompletableFuture waiter = waiters.poll(); - if (waiter != null) { - if (!waiter.complete(null) && waiter.isCancelled()) { - // If this waiter was cancelled, we need to release another permit, since this waiter - // is now useless - release(); - } - } else { - permits++; - } - } finally { - lock.unlock(); - } - } - } - public FixedSizeSlotSupplier(int numSlots) { Preconditions.checkArgument(numSlots > 0, "FixedSizeSlotSupplier must have at least one slot"); this.numSlots = numSlots; diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java index 4ca2204663..d0e36bfc6e 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java @@ -31,18 +31,15 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.BrokenBarrierException; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; @@ -66,11 +63,21 @@ static Payloads payloads(String... values) { return b.build(); } + /** A synchronous visit; {@link #toAsync} adapts it to the asynchronous {@link PayloadVisitor}. */ + @FunctionalInterface + interface SyncPayloadVisitor { + List visit(Object ctx, List payloads); + } + + static PayloadVisitor toAsync(SyncPayloadVisitor visitor) { + return (ctx, pls) -> CompletableFuture.completedFuture(visitor.visit(ctx, pls)); + } + /** * Records every payload seen (in order) and the number of visit calls, leaving payloads * unchanged. */ - static final class CollectingVisitor implements PayloadVisitor { + static final class CollectingVisitor implements SyncPayloadVisitor { final List seen = Collections.synchronizedList(new ArrayList<>()); final AtomicInteger visits = new AtomicInteger(); @@ -84,8 +91,35 @@ public List visit(Object ctx, List payloads) { } } - static PayloadVisitorOptions options(PayloadVisitor visitor) { - return PayloadVisitorOptions.newBuilder(visitor).build(); + static PayloadVisitorOptions options(SyncPayloadVisitor visitor) { + return PayloadVisitorOptions.newBuilder(toAsync(visitor)).build(); + } + + /** + * Blocks and unwraps the {@link CompletionException} {@code join} adds, exposing the original. + */ + static T visit( + T message, PayloadVisitorOptions options) { + return join(PayloadVisitors.visit(message, options)); + } + + static void visit(com.google.protobuf.Message.Builder builder, PayloadVisitorOptions options) { + join(PayloadVisitors.visit(builder, options)); + } + + private static V join(CompletableFuture future) { + try { + return future.join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } + if (cause instanceof Error) { + throw (Error) cause; + } + throw e; + } } static Command activity(String activityId, Payloads input) { @@ -98,7 +132,7 @@ static Command activity(String activityId, Payloads input) { .build(); } - /** Executor supplied to the concurrent visits (unused by the single-threaded tests). */ + /** Backs the simulated async visitors in the concurrency tests; the engine needs no executor. */ private ExecutorService executor; @Before @@ -120,8 +154,7 @@ public void visitsAndMutatesAllPayloads() { .build(); CollectingVisitor counter = new CollectingVisitor(); - RespondWorkflowTaskCompletedRequest unchanged = - PayloadVisitors.visit(request, options(counter)); + RespondWorkflowTaskCompletedRequest unchanged = visit(request, options(counter)); assertEquals(java.util.Arrays.asList("one", "two", "three"), counter.seen); // Two Payloads sequences (one per command's input): two visits, three payloads. assertEquals(2, counter.visits.get()); @@ -129,7 +162,7 @@ public void visitsAndMutatesAllPayloads() { // Mutating: uppercase every payload's data. RespondWorkflowTaskCompletedRequest mutated = - PayloadVisitors.visit( + visit( request, options( (ctx, pls) -> @@ -157,13 +190,12 @@ public void visitsSinglePayloadField() { .build(); CollectingVisitor counter = new CollectingVisitor(); - Command result = PayloadVisitors.visit(command, options(counter)); + Command result = visit(command, options(counter)); assertEquals(Collections.singletonList("nexus"), counter.seen); // A single-payload field can be replaced with one payload. Command observed = - PayloadVisitors.visit( - command, options((ctx, pls) -> Collections.singletonList(p("replaced")))); + visit(command, options((ctx, pls) -> Collections.singletonList(p("replaced")))); assertEquals( "replaced", observed.getScheduleNexusOperationCommandAttributes().getInput().getData().toStringUtf8()); @@ -182,14 +214,12 @@ public void singlePayloadFieldRequiresExactlyOnePayload() { // Returning zero payloads for a single-payload field is rejected. assertThrows( IllegalStateException.class, - () -> PayloadVisitors.visit(command, options((ctx, pls) -> Collections.emptyList()))); + () -> visit(command, options((ctx, pls) -> Collections.emptyList()))); // Returning more than one payload for a single-payload field is rejected. assertThrows( IllegalStateException.class, - () -> - PayloadVisitors.visit( - command, options((ctx, pls) -> java.util.Arrays.asList(p("a"), p("b"))))); + () -> visit(command, options((ctx, pls) -> java.util.Arrays.asList(p("a"), p("b"))))); } @Test @@ -205,15 +235,14 @@ public void visitsMapOfPayloads() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(command, options(counter)); + visit(command, options(counter)); // A map is visited once per entry; map iteration order is unspecified, so // assert the exact visit count and the value set rather than positional offsets. assertEquals(2, counter.visits.get()); assertEquals(new HashSet<>(java.util.Arrays.asList("v1", "v2")), new HashSet<>(counter.seen)); Command mutated = - PayloadVisitors.visit( - command, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); + visit(command, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); Map fields = mutated .getUpsertWorkflowSearchAttributesCommandAttributes() @@ -234,7 +263,7 @@ public void visitsMapOfPayloadsSequences() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(command, options(counter)); + visit(command, options(counter)); // A single map entry is one sequence: one visit, two payloads. assertEquals(1, counter.visits.get()); assertEquals(java.util.Arrays.asList("x", "y"), counter.seen); @@ -251,12 +280,11 @@ public void visitsMapOfMessages() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(request, options(counter)); + visit(request, options(counter)); assertEquals(Collections.singletonList("a"), counter.seen); RespondWorkflowTaskCompletedRequest mutated = - PayloadVisitors.visit( - request, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); + visit(request, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); assertEquals(payloads("a!"), mutated.getQueryResultsMap().get("q1").getAnswer()); } @@ -272,13 +300,13 @@ public void visitsRepeatedPayloadField() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(response, options(counter)); + visit(response, options(counter)); // A repeated Payload is one sequence: one visit, two payloads. assertEquals(1, counter.visits.get()); assertEquals(java.util.Arrays.asList("g1", "g2"), counter.seen); CountWorkflowExecutionsResponse mutated = - PayloadVisitors.visit( + visit( response, options( (ctx, pls) -> @@ -292,14 +320,13 @@ public void visitsPayloadsAsRoot() { Payloads root = payloads("a", "b"); CollectingVisitor counter = new CollectingVisitor(); - Payloads unchanged = PayloadVisitors.visit(root, options(counter)); + Payloads unchanged = visit(root, options(counter)); // The repeated Payload inside Payloads is one sequence: one visit, two payloads. assertEquals(1, counter.visits.get()); assertEquals(java.util.Arrays.asList("a", "b"), counter.seen); assertEquals(root, unchanged); - Payloads mutated = - PayloadVisitors.visit(root, options((ctx, pls) -> Collections.singletonList(p("x")))); + Payloads mutated = visit(root, options((ctx, pls) -> Collections.singletonList(p("x")))); assertEquals(payloads("x"), mutated); } @@ -308,7 +335,7 @@ public void visitsBuilderInPlace() { RespondWorkflowTaskCompletedRequest.Builder builder = RespondWorkflowTaskCompletedRequest.newBuilder().addCommands(activity("a", payloads("x"))); - PayloadVisitors.visit(builder, options((ctx, pls) -> Collections.singletonList(p("y")))); + visit(builder, options((ctx, pls) -> Collections.singletonList(p("y")))); assertEquals( payloads("y"), @@ -320,7 +347,7 @@ public void visitCountDistinguishesSequencesFromMapEntries() { // A Memo with two fields is visited once per entry: two visits, two payloads. Memo memo = Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2")).build(); CollectingVisitor memoVisitor = new CollectingVisitor(); - PayloadVisitors.visit(memo, options(memoVisitor)); + visit(memo, options(memoVisitor)); // Memo fields are a map (unspecified order): assert visit count and the value set. assertEquals(2, memoVisitor.visits.get()); assertEquals(new HashSet<>(java.util.Arrays.asList("1", "2")), new HashSet<>(memoVisitor.seen)); @@ -332,7 +359,7 @@ public void visitCountDistinguishesSequencesFromMapEntries() { ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("1", "2"))) .build(); CollectingVisitor inputVisitor = new CollectingVisitor(); - PayloadVisitors.visit(command, options(inputVisitor)); + visit(command, options(inputVisitor)); // A Payloads sequence preserves order, so assert the exact ordered values. assertEquals(1, inputVisitor.visits.get()); assertEquals(java.util.Arrays.asList("1", "2"), inputVisitor.seen); @@ -349,7 +376,7 @@ public void visitsHeaders() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(command, options(counter)); + visit(command, options(counter)); // With headers not skipped (the default), the header payload is visited too. assertEquals(new HashSet<>(java.util.Arrays.asList("in", "hv")), new HashSet<>(counter.seen)); } @@ -366,7 +393,7 @@ public void visitsSearchAttributes() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(command, options(counter)); + visit(command, options(counter)); // With search attributes not skipped (the default), the search attribute payload is visited. assertEquals(new HashSet<>(java.util.Arrays.asList("in", "v")), new HashSet<>(counter.seen)); } @@ -382,8 +409,7 @@ public void skipsHeaders() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit( - command, PayloadVisitorOptions.newBuilder(counter).setSkipHeaders(true).build()); + visit(command, PayloadVisitorOptions.newBuilder(toAsync(counter)).setSkipHeaders(true).build()); // The header payload is skipped; other payloads are still visited. assertEquals(Collections.singletonList("in"), counter.seen); } @@ -400,8 +426,9 @@ public void skipsSearchAttributes() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit( - command, PayloadVisitorOptions.newBuilder(counter).setSkipSearchAttributes(true).build()); + visit( + command, + PayloadVisitorOptions.newBuilder(toAsync(counter)).setSkipSearchAttributes(true).build()); // The search attribute payload is skipped; other payloads are still visited. assertEquals(Collections.singletonList("in"), counter.seen); } @@ -430,7 +457,7 @@ public void contextScopesPerCommand() { dataOrder.add(data(p)); contextOrder.add(ctx); } - return pls; + return CompletableFuture.completedFuture(pls); }) .setMessageVisitor( (current, msg) -> @@ -439,7 +466,7 @@ public void contextScopesPerCommand() { : current) .build(); - PayloadVisitors.visit(request, opts); + visit(request, opts); assertEquals(java.util.Arrays.asList("act", "child"), dataOrder); assertEquals( java.util.Arrays.asList( @@ -460,11 +487,11 @@ public void initialContextUsedWhenNoMessageVisitor() { PayloadVisitorOptions.newBuilder( (ctx, pls) -> { observed.add(ctx); - return pls; + return CompletableFuture.completedFuture(pls); }) .setInitialContext("root") .build(); - PayloadVisitors.visit(command, opts); + visit(command, opts); assertEquals(Collections.singletonList("root"), observed); } @@ -484,7 +511,7 @@ public void limitsStyleValidatorComposesBothSeams() { throw new TestVisitorException("blob too large"); } } - return pls; // read-only + return CompletableFuture.completedFuture(pls); // read-only }) .setMessageVisitor( (current, msg) -> { @@ -507,7 +534,7 @@ public void limitsStyleValidatorComposesBothSeams() { .setMemo( Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2"))))) .build(); - PayloadVisitors.visit(ok, validator); + visit(ok, validator); // Oversized blob trips the payload seam. RespondWorkflowTaskCompletedRequest bigBlob = @@ -515,7 +542,7 @@ public void limitsStyleValidatorComposesBothSeams() { .addCommands(activity("a", payloads("way-too-large-payload"))) .build(); TestVisitorException blobError = - assertThrows(TestVisitorException.class, () -> PayloadVisitors.visit(bigBlob, validator)); + assertThrows(TestVisitorException.class, () -> visit(bigBlob, validator)); assertEquals("blob too large", blobError.getMessage()); // Too many memo fields trips the message seam (its payloads are individually small). @@ -532,7 +559,7 @@ public void limitsStyleValidatorComposesBothSeams() { .putFields("c", p("3"))))) .build(); TestVisitorException memoError = - assertThrows(TestVisitorException.class, () -> PayloadVisitors.visit(bigMemo, validator)); + assertThrows(TestVisitorException.class, () -> visit(bigMemo, validator)); assertEquals("too many memo fields", memoError.getMessage()); } @@ -553,7 +580,7 @@ public void visitsNestedFailureCauses() { .build(); CollectingVisitor counter = new CollectingVisitor(); - PayloadVisitors.visit(failure, options(counter)); + visit(failure, options(counter)); assertEquals(2, counter.seen.size()); assertTrue(counter.seen.contains("d1")); assertTrue(counter.seen.contains("d2")); @@ -565,13 +592,12 @@ public void roundTripsPayloadInsideAny() throws Exception { Message message = Message.newBuilder().setBody(Any.pack(memo)).build(); CollectingVisitor counter = new CollectingVisitor(); - Message result = PayloadVisitors.visit(message, options(counter)); + Message result = visit(message, options(counter)); assertEquals(Collections.singletonList("inside-any"), counter.seen); // Mutating through the Any re-packs correctly. Message mutated = - PayloadVisitors.visit( - message, options((ctx, pls) -> Collections.singletonList(p("changed")))); + visit(message, options((ctx, pls) -> Collections.singletonList(p("changed")))); Memo unpacked = mutated.getBody().unpack(Memo.class); assertEquals("changed", data(unpacked.getFieldsMap().get("k"))); // Unrelated content unchanged. @@ -589,7 +615,7 @@ public void leavesUnknownAnyUntouched() throws Exception { .setValue(ByteString.copyFromUtf8("opaque"))) .build(); CollectingVisitor counter = new CollectingVisitor(); - Message result = PayloadVisitors.visit(message, options(counter)); + Message result = visit(message, options(counter)); assertTrue(counter.seen.isEmpty()); assertEquals(message, result); } @@ -602,7 +628,7 @@ public void messageWithoutPayloadsReturnedUnchanged() { io.temporal.api.command.v1.CancelWorkflowExecutionCommandAttributes.newBuilder()) .build(); CollectingVisitor counter = new CollectingVisitor(); - Command result = PayloadVisitors.visit(command, options(counter)); + Command result = visit(command, options(counter)); assertTrue(counter.seen.isEmpty()); assertEquals(command, result); } @@ -618,7 +644,7 @@ public void propagatesVisitorError() { assertThrows( TestVisitorException.class, () -> - PayloadVisitors.visit( + visit( request, options( (ctx, pls) -> { @@ -639,9 +665,9 @@ public void messageVisitorErrorPropagates() { assertThrows( TestVisitorException.class, () -> - PayloadVisitors.visit( + visit( command, - PayloadVisitorOptions.newBuilder((ctx, pls) -> pls) + PayloadVisitorOptions.newBuilder(toAsync((ctx, pls) -> pls)) .setMessageVisitor( (current, msg) -> { throw boom; @@ -687,25 +713,20 @@ public void nullReturnFromVisitorFails() { .setScheduleActivityTaskCommandAttributes( ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) .build(); - assertThrows( - IllegalStateException.class, - () -> PayloadVisitors.visit(command, options((ctx, pls) -> null))); + assertThrows(IllegalStateException.class, () -> visit(command, options((ctx, pls) -> null))); } - // --- Concurrency and executor --- + // --- Concurrency --- + // + // These tests simulate an I/O-backed visitor with futures completed on a test-local thread pool, + // so concurrency produces real overlap. @Test public void rejectsConcurrencyBelowOne() { assertThrows( IllegalArgumentException.class, - () -> PayloadVisitorOptions.newBuilder((ctx, pls) -> pls).setConcurrency(0).build()); - } - - @Test - public void rejectsConcurrencyAboveOneWithoutExecutor() { - assertThrows( - IllegalArgumentException.class, - () -> PayloadVisitorOptions.newBuilder((ctx, pls) -> pls).setConcurrency(2).build()); + () -> + PayloadVisitorOptions.newBuilder(toAsync((ctx, pls) -> pls)).setConcurrency(0).build()); } /** @@ -721,25 +742,29 @@ static RespondWorkflowTaskCompletedRequest requestWithInputs(int n) { } @Test - public void concurrencyEqualToWorkAllowsFullOverlap() throws Exception { + public void concurrencyEqualToWorkAllowsFullOverlap() { int n = 4; RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); CyclicBarrier barrier = new CyclicBarrier(n); - // Each of the n visits must reach the barrier simultaneously, proving n concurrent visits. - PayloadVisitors.visit( + // All n visits must reach the barrier at once; with fewer than n in flight it would time out. + visit( request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - try { - barrier.await(5, TimeUnit.SECONDS); - } catch (InterruptedException | BrokenBarrierException | TimeoutException e) { - throw new RuntimeException(e); - } - return pls; - }) + PayloadVisitorOptions.newBuilder( + (ctx, pls) -> + CompletableFuture.supplyAsync( + () -> { + try { + barrier.await(5, TimeUnit.SECONDS); + } catch (InterruptedException + | BrokenBarrierException + | TimeoutException e) { + throw new RuntimeException(e); + } + return pls; + }, + executor)) .setConcurrency(n) - .setExecutor(executor) .build()); } @@ -752,22 +777,24 @@ public void boundedConcurrencyNeverExceedsLimit() { AtomicInteger inFlight = new AtomicInteger(); AtomicInteger maxInFlight = new AtomicInteger(); - PayloadVisitors.visit( + visit( request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - int now = inFlight.incrementAndGet(); - maxInFlight.accumulateAndGet(now, Math::max); - try { - Thread.sleep(20); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - inFlight.decrementAndGet(); - return pls; - }) + PayloadVisitorOptions.newBuilder( + (ctx, pls) -> + CompletableFuture.supplyAsync( + () -> { + int now = inFlight.incrementAndGet(); + maxInFlight.accumulateAndGet(now, Math::max); + try { + Thread.sleep(20); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + inFlight.decrementAndGet(); + return pls; + }, + executor)) .setConcurrency(limit) - .setExecutor(executor) .build()); assertTrue( @@ -782,18 +809,22 @@ public void sequentialConcurrencyVisitsOneAtATimeInOrder() { AtomicInteger inFlight = new AtomicInteger(); AtomicInteger maxInFlight = new AtomicInteger(); - List order = new ArrayList<>(); + List order = Collections.synchronizedList(new ArrayList<>()); - PayloadVisitors.visit( + // Concurrency 1 awaits each visit before the next, so even async visits run one at a time. + visit( request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - int now = inFlight.incrementAndGet(); - maxInFlight.accumulateAndGet(now, Math::max); - order.add(pls.get(0).getData().toStringUtf8()); - inFlight.decrementAndGet(); - return pls; - }) + PayloadVisitorOptions.newBuilder( + (ctx, pls) -> + CompletableFuture.supplyAsync( + () -> { + int now = inFlight.incrementAndGet(); + maxInFlight.accumulateAndGet(now, Math::max); + order.add(pls.get(0).getData().toStringUtf8()); + inFlight.decrementAndGet(); + return pls; + }, + executor)) .setConcurrency(1) .build()); @@ -809,21 +840,24 @@ public void sequentialConcurrencyVisitsOneAtATimeInOrder() { public void concurrentVisitorErrorPropagates() { RespondWorkflowTaskCompletedRequest request = requestWithInputs(8); TestVisitorException boom = new TestVisitorException("boom"); + // The failure arrives as an exceptionally-completed future; it must surface unchanged. TestVisitorException thrown = assertThrows( TestVisitorException.class, () -> - PayloadVisitors.visit( + visit( request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - if (pls.get(0).getData().toStringUtf8().equals("p5")) { - throw boom; - } - return pls; - }) + PayloadVisitorOptions.newBuilder( + (ctx, pls) -> + CompletableFuture.supplyAsync( + () -> { + if (pls.get(0).getData().toStringUtf8().equals("p5")) { + throw boom; + } + return pls; + }, + executor)) .setConcurrency(4) - .setExecutor(executor) .build())); assertSame(boom, thrown); } @@ -832,19 +866,23 @@ public void concurrentVisitorErrorPropagates() { public void mutationsAppliedCorrectlyUnderConcurrency() { int n = 16; RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + // Visits complete out of order, but write-backs apply in walk order, so each input is correct. RespondWorkflowTaskCompletedRequest mutated = - PayloadVisitors.visit( + visit( request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - Payload p = pls.get(0); - return Collections.singletonList( - p.toBuilder() - .setData(ByteString.copyFromUtf8(p.getData().toStringUtf8() + "!")) - .build()); - }) + PayloadVisitorOptions.newBuilder( + (ctx, pls) -> + CompletableFuture.supplyAsync( + () -> { + Payload p = pls.get(0); + return Collections.singletonList( + p.toBuilder() + .setData( + ByteString.copyFromUtf8(p.getData().toStringUtf8() + "!")) + .build()); + }, + executor)) .setConcurrency(8) - .setExecutor(executor) .build()); for (int i = 0; i < n; i++) { assertEquals( @@ -860,68 +898,28 @@ public void mutationsAppliedCorrectlyUnderConcurrency() { } @Test - public void concurrentVisitsRunOnProvidedExecutor() throws InterruptedException { - AtomicInteger threadSeq = new AtomicInteger(); - ThreadFactory factory = - r -> { - Thread t = new Thread(r); - t.setName("pv-exec-" + threadSeq.incrementAndGet()); - return t; - }; - ExecutorService pool = Executors.newFixedThreadPool(4, factory); - try { - RespondWorkflowTaskCompletedRequest request = requestWithInputs(12); - Set visitThreads = ConcurrentHashMap.newKeySet(); - - PayloadVisitors.visit( - request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - visitThreads.add(Thread.currentThread().getName()); - return pls; - }) - .setConcurrency(4) - .setExecutor(pool) - .build()); - - assertTrue("no visits recorded", !visitThreads.isEmpty()); - for (String name : visitThreads) { - assertTrue("visit ran off the provided executor: " + name, name.startsWith("pv-exec-")); - } - } finally { - pool.shutdownNow(); - pool.awaitTermination(5, TimeUnit.SECONDS); - } - } + public void entryPointReturnsPendingFutureWhileVisitInFlight() throws Exception { + RespondWorkflowTaskCompletedRequest request = requestWithInputs(1); + // A visit future we complete by hand, to observe the traversal future's state meanwhile. + CompletableFuture> gate = new CompletableFuture<>(); - @Test - public void sequentialConcurrencyIgnoresExecutorAndRunsOnCallingThread() { - // With concurrency == 1 the executor must never be touched and visits run inline. - AtomicInteger submittedToExecutor = new AtomicInteger(); - Executor tripwire = - command -> { - submittedToExecutor.incrementAndGet(); - command.run(); - }; - - RespondWorkflowTaskCompletedRequest request = requestWithInputs(5); - String callingThread = Thread.currentThread().getName(); - AtomicReference sawDifferentThread = new AtomicReference<>(); + CompletableFuture result = + PayloadVisitors.visit( + request, PayloadVisitorOptions.newBuilder((ctx, pls) -> gate).build()); - PayloadVisitors.visit( - request, - PayloadVisitorOptions.newBuilder( - (ctx, pls) -> { - if (!Thread.currentThread().getName().equals(callingThread)) { - sawDifferentThread.set(Thread.currentThread().getName()); - } - return pls; - }) - .setConcurrency(1) - .setExecutor(tripwire) - .build()); + // The caller is not blocked: the traversal future is pending while the visit is outstanding. + assertFalse(result.isDone()); - assertEquals("executor was used for sequential traversal", 0, submittedToExecutor.get()); - assertEquals("visit ran off the calling thread", null, sawDifferentThread.get()); + gate.complete(Collections.singletonList(p("done"))); + RespondWorkflowTaskCompletedRequest mutated = result.get(5, TimeUnit.SECONDS); + assertEquals( + "done", + mutated + .getCommands(0) + .getScheduleActivityTaskCommandAttributes() + .getInput() + .getPayloads(0) + .getData() + .toStringUtf8()); } } From 34c3c318b668cf0ccce6a30f7f50684f158bda7e Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 16 Jun 2026 20:31:29 -0700 Subject: [PATCH 5/5] fix comment format --- .../io/temporal/internal/payload/visitor/PayloadVisitors.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java index 67c3fb5950..69e9924123 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java @@ -4,9 +4,7 @@ import java.util.concurrent.CompletableFuture; import javax.annotation.Nonnull; -/** - * Visits every payload within a proto message. - */ +/** Visits every payload within a proto message. */ final class PayloadVisitors { private PayloadVisitors() {}