From 86dfa9f69500a86b4467f5e776302cebcc79b2f8 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 12 Jun 2026 12:53:54 +0200 Subject: [PATCH 1/3] [FnApi Java] Add support for separate named data streams to provide bundle isolation. This is advertised to the runner via a new NAMED_DATA_STREAMS protocol capability. The runner is then free to assign bundles to named data streams as it chooses to isolate bundle processing from each other. Instead of a single data stream from the sdk, the sdk will create a data stream for each name. The benefit of doing so is isolation: the multiplexing currently performed on data stream messages being received allows a slow bundle to fill up buffers and block the shared stream. With separate named streams, bundles on other data streams have separate grpc flow control from the blocked stream and are not affected. --- .../model/fn_execution/v1/beam_fn_api.proto | 34 +++++- .../model/pipeline/v1/beam_runner_api.proto | 4 + .../fn/data/BeamFnDataGrpcMultiplexer.java | 2 +- .../fn/data/BeamFnDataOutboundAggregator.java | 59 ++++++---- .../sdk/util/construction/Environments.java | 1 + .../BeamFnDataOutboundAggregatorTest.java | 103 ++++++++---------- .../org/apache/beam/fn/harness/FnHarness.java | 54 +++++---- .../harness/control/ProcessBundleHandler.java | 40 ++++--- .../fn/harness/data/BeamFnDataClient.java | 28 ++--- .../fn/harness/data/BeamFnDataGrpcClient.java | 91 +++++++++++----- .../fn/harness/BeamFnDataWriteRunnerTest.java | 41 ++++--- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 2 +- .../PTransformRunnerFactoryTestContext.java | 12 +- .../control/ProcessBundleHandlerTest.java | 47 ++++---- .../data/BeamFnDataGrpcClientTest.java | 41 +++---- 15 files changed, 307 insertions(+), 252 deletions(-) diff --git a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index ecef3f2e7a94..6543f7d16c18 100644 --- 
a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -120,6 +120,9 @@ message RemoteGrpcPort { service BeamFnControl { // Instructions sent by the runner to the SDK requesting different types // of work. + // + // Header metadata has the following key/value pairs: + // - "worker_id": the id of the sdk rpc Control( // A stream of responses to instructions the SDK was asked to be // performed. @@ -130,6 +133,9 @@ service BeamFnControl { // Used to get the full process bundle descriptors for bundles one // is asked to process. + // + // Header metadata has the following key/value pairs: + // - "worker_id": the id of the sdk rpc GetProcessBundleDescriptor(GetProcessBundleDescriptorRequest) returns ( ProcessBundleDescriptor) {} } @@ -416,14 +422,21 @@ message ProcessBundleRequest { // at https://s.apache.org/beam-fn-api-control-data-embedding. Elements elements = 3; - // indicates that the runner has no stare for the keys in this bundle + // Indicates that the runner has no state for the keys in this bundle // so SDk can safely begin stateful processing with a locally-generated - // initial empty state + // initial empty state. bool has_no_state = 4; - // indicates that the runner will never process another bundle for the keys + // Indicates that the runner will never process another bundle for the keys // in this bundle so state need not be included in the bundle commit. bool only_bundle_for_keys = 5; + + // (Optional) If non-empty, the ID of the data stream to use for this bundle. + // See comments at BeamFnData.Data for more details. + // + // The runner should only populate this field if the sdk advertises the + // beam:protocol:named_data_streams:v1 capability. + string data_stream_id = 6; } message ProcessBundleResponse { @@ -835,6 +848,11 @@ message Elements { // Stable service BeamFnData { // Used to send data between harnesses. 
+ // + // Header metadata has the following key/value pairs: + // - "worker_id": value is the id of the sdk + // - "data_stream_id": value is the id of the data stream, distinguishing it from other data streams from the same + // sdk. This field should only be populated if requested in a received ProcessBundleRequest from the runner. rpc Data( // A stream of data representing input. stream Elements) @@ -900,6 +918,9 @@ message StateResponse { service BeamFnState { // Used to get/append/clear state stored by the runner on behalf of the SDK. + // + // Header metadata has the following key/value pairs: + // - "worker_id": the id of the sdk rpc State( // A stream of state instructions requested of the runner. stream StateRequest) @@ -1295,6 +1316,11 @@ message LogControl {} service BeamFnLogging { // Allows for the SDK to emit log entries which the runner can // associate with the active job. + // + // Log entries are sent by the SDK as a stream of batched lists. + // + // Header metadata has the following key/value pairs: + // - "worker_id": the id of the sdk rpc Logging( // A stream of log entries batched into lists emitted by the SDK harness. stream LogEntry.List) @@ -1356,6 +1382,8 @@ message WorkerStatusResponse { // API for SDKs to report debug-related statuses to runner during pipeline execution. 
service BeamFnWorkerStatus { + // Header metadata has the following key/value pairs: + // - "worker_id": the id of the sdk rpc WorkerStatus (stream WorkerStatusResponse) returns (stream WorkerStatusRequest) {} } diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index 67df8b9e8003..5824c9bf4b73 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -1689,6 +1689,10 @@ message StandardProtocols { // Indicates whether the SDK supports multimap state. MULTIMAP_STATE = 12 [(beam_urn) = "beam:protocol:multimap_state:v1"]; + + // Indicates whether the SDK supports data stream ids being requested by the runner in + // ProcessBundleRequests. + NAMED_DATA_STREAMS = 13 [(beam_urn) = "beam:protocol:named_data_streams:v1"]; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index 8fec8b455cce..0b9d6adab4f0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -63,7 +63,7 @@ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private final Cache poisonedInstructionIds; private static class PoisonedException extends RuntimeException { - public PoisonedException() { + private PoisonedException() { super("Instruction poisoned"); } }; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java index 9b9603706b48..d2c3ada870c3 100644 --- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.fn.data; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -28,7 +30,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; @@ -56,7 +57,7 @@ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -// The calling thread that invokes sendBufferedDataAndFinishOutboundStreams synchronizes on +// The calling thread that invokes sendOrCollectBufferedDataAndFinishOutboundStreams synchronizes on // flushLock effectively making the periodic flushing no longer read or mutate hasFlushedForBundle // and allowing the calling thread to read and mutate hasFlushedForBundle safely without needing to // create another memory barrier. 
Also note that flush is always invoked when synchronizing on @@ -72,33 +73,40 @@ public class BeamFnDataOutboundAggregator { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataOutboundAggregator.class); private final int sizeLimit; private final long timeLimit; - private final Supplier processBundleRequestIdSupplier; + private String instructionId; @VisibleForTesting final Map> outputDataReceivers; @VisibleForTesting final Map> outputTimersReceivers; - private final StreamObserver outboundObserver; + @Nullable private StreamObserver outboundObserver; @Nullable @VisibleForTesting ScheduledFuture flushFuture; private long bytesWrittenSinceFlush; private final Object flushLock; private final boolean collectElementsIfNoFlushes; private boolean hasFlushedForBundle; - public BeamFnDataOutboundAggregator( - PipelineOptions options, - Supplier processBundleRequestIdSupplier, - StreamObserver outboundObserver, - boolean collectElementsIfNoFlushes) { + public BeamFnDataOutboundAggregator(PipelineOptions options, boolean collectElementsIfNoFlushes) { this.sizeLimit = getSizeLimit(options); this.timeLimit = getTimeLimit(options); this.collectElementsIfNoFlushes = collectElementsIfNoFlushes; this.outputDataReceivers = new HashMap<>(); this.outputTimersReceivers = new HashMap<>(); - this.outboundObserver = outboundObserver; - this.processBundleRequestIdSupplier = processBundleRequestIdSupplier; this.bytesWrittenSinceFlush = 0L; this.flushLock = new Object(); this.hasFlushedForBundle = false; } + public void prepareForInstruction( + String instructionId, StreamObserver outboundObserver) { + if (timeLimit > 0) { + synchronized (flushLock) { + this.instructionId = instructionId; + this.outboundObserver = outboundObserver; + } + } else { + this.instructionId = instructionId; + this.outboundObserver = outboundObserver; + } + } + /** Starts the flushing daemon thread if data_buffer_time_limit_ms is set. 
*/ public void start() { if (timeLimit > 0 && this.flushFuture == null) { @@ -166,7 +174,7 @@ private void flushInternal() { } Elements.Builder elements = convertBufferForTransmission(); if (elements.getDataCount() > 0 || elements.getTimersCount() > 0) { - outboundObserver.onNext(elements.build()); + checkNotNull(outboundObserver).onNext(elements.build()); } hasFlushedForBundle = true; } @@ -177,6 +185,7 @@ private void flushInternal() { * collectElementsIfNoFlushes=true, and there was no previous flush in this bundle, otherwise * returns null. */ + @Nullable public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { if (outputTimersReceivers.isEmpty() && outputDataReceivers.isEmpty()) { return null; @@ -189,16 +198,17 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { } else { bufferedElements = convertBufferForTransmission(); } + checkNotNull(instructionId); LOG.debug( "Closing streams for instruction {} and outbound data {} and timers {}.", - processBundleRequestIdSupplier.get(), + instructionId, outputDataReceivers, outputTimersReceivers); for (Map.Entry> entry : outputDataReceivers.entrySet()) { String pTransformId = entry.getKey(); bufferedElements .addDataBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(pTransformId) .setIsLast(true); entry.getValue().resetStats(); @@ -207,34 +217,37 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { TimerEndpoint timerKey = entry.getKey(); bufferedElements .addTimersBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(timerKey.pTransformId) .setTimerFamilyId(timerKey.timerFamilyId) .setIsLast(true); entry.getValue().resetStats(); } + // This is the end of the bundle so we reset state to prepare for future bundles. 
+ instructionId = null; if (collectElementsIfNoFlushes && !hasFlushedForBundle) { + outboundObserver = null; return bufferedElements.build(); } - outboundObserver.onNext(bufferedElements.build()); - // This is now at the end of a bundle, so we reset hasFlushedForBundle to prepare for new - // bundles. + checkNotNull(outboundObserver).onNext(bufferedElements.build()); + outboundObserver = null; hasFlushedForBundle = false; return null; } // Send the elements to the StreamObserver associated with this aggregator. public void sendElements(Elements elements) { - outboundObserver.onNext(elements); + checkNotNull(outboundObserver).onNext(elements); } public void discard() { if (flushFuture != null) { - flushFuture.cancel(true); + flushFuture.cancel(false); } } private Elements.Builder convertBufferForTransmission() { + checkNotNull(instructionId); Elements.Builder bufferedElements = Elements.newBuilder(); for (Map.Entry> entry : outputDataReceivers.entrySet()) { if (!entry.getValue().hasBufferedOutput()) { @@ -243,7 +256,7 @@ private Elements.Builder convertBufferForTransmission() { ByteString bytes = entry.getValue().toByteStringAndResetBuffer(); bufferedElements .addDataBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(entry.getKey()) .setData(bytes); } @@ -254,7 +267,7 @@ private Elements.Builder convertBufferForTransmission() { ByteString bytes = entry.getValue().toByteStringAndResetBuffer(); bufferedElements .addTimersBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(entry.getKey().pTransformId) .setTimerFamilyId(entry.getKey().timerFamilyId) .setTimers(bytes); @@ -353,10 +366,12 @@ public void accept(T input) throws Exception { } } + @VisibleForTesting public long getByteCount() { return perBundleByteCount; } + @VisibleForTesting public long getElementCount() { return perBundleElementCount; } diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java index 969bda88d07f..c3b1a7a5235e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java @@ -522,6 +522,7 @@ public static Set getJavaCapabilities() { capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.SDK_CONSUMING_RECEIVED_DATA)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.ORDERED_LIST_STATE)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.MULTIMAP_STATE)); + capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.NAMED_DATA_STREAMS)); return capabilities.build(); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java index 092ba200c94b..a5143eb530d2 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java @@ -75,13 +75,12 @@ public void testWithDefaultBuffer() throws Exception { final List values = new ArrayList<>(); final AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - PipelineOptionsFactory.create(), - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer 
size is surpassed. FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -124,14 +123,12 @@ public void testConfiguredBufferLimit() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_size_limit=100")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); aggregator.start(); @@ -187,18 +184,16 @@ public void testConfiguredTimeLimit() throws Exception { .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_time_limit_ms=1")); final CountDownLatch waitForFlush = new CountDownLatch(1); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - values.add(e); - waitForFlush.countDown(); - }) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + values.add(e); + waitForFlush.countDown(); + }) + .build()); // Test that it emits when time passed the time limit FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -214,17 +209,15 @@ public void testConfiguredTimeLimitExceptionPropagation() throws Exception { options .as(ExperimentalOptions.class) 
.setExperiments(Arrays.asList("data_buffer_time_limit_ms=1")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - throw new RuntimeException(""); - }) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); // Test that it emits when time passed the time limit FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -243,17 +236,15 @@ public void testConfiguredTimeLimitExceptionPropagation() throws Exception { // expected } - aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - throw new RuntimeException(""); - }) - .build(), - false); + aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); aggregator.start(); dataReceiver.accept(new byte[1]); @@ -279,14 +270,12 @@ public void testConfiguredBufferLimitMultipleEndpoints() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_size_limit=100")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + 
.withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. LogicalEndpoint additionalEndpoint = LogicalEndpoint.data( diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 703e726739a0..009de1b44e60 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -17,6 +17,8 @@ */ package org.apache.beam.fn.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.charset.StandardCharsets; @@ -30,7 +32,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import javax.annotation.Nullable; import org.apache.beam.fn.harness.control.BeamFnControlClient; import org.apache.beam.fn.harness.control.ExecutionStateSampler; import org.apache.beam.fn.harness.control.FinalizeBundleHandler; @@ -72,6 +73,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,9 +94,6 @@ * for further details. 
* */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class FnHarness { private static final String HARNESS_ID = "HARNESS_ID"; private static final String CONTROL_API_SERVICE_DESCRIPTOR = "CONTROL_API_SERVICE_DESCRIPTOR"; @@ -138,22 +137,30 @@ private static void removeKeyRecursively(JsonNode node, String keyToRemove) { } public static void main(String[] args) throws Exception { - main(System::getenv); + Function environmentVarGetter = System::getenv; + main(environmentVarGetter); } @VisibleForTesting - public static void main(Function environmentVarGetter) throws Exception { + public static void main(Function environmentVarGetter) + throws Exception { JvmInitializers.runOnStartup(); Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor = - getApiServiceDescriptor(environmentVarGetter.apply(LOGGING_API_SERVICE_DESCRIPTOR)); + getApiServiceDescriptor( + checkNotNull( + environmentVarGetter.apply(LOGGING_API_SERVICE_DESCRIPTOR), + "LOGGING_API_SERVICE_DESCRIPTOR env var be set.")); Endpoints.ApiServiceDescriptor controlApiServiceDescriptor = - getApiServiceDescriptor(environmentVarGetter.apply(CONTROL_API_SERVICE_DESCRIPTOR)); + getApiServiceDescriptor( + checkNotNull( + environmentVarGetter.apply(CONTROL_API_SERVICE_DESCRIPTOR), + "CONTROL_API_SERVICE_DESCRIPTOR env var be set.")); + + @Nullable String envVar = environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR); Endpoints.ApiServiceDescriptor statusApiServiceDescriptor = - environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR) == null - ? null - : getApiServiceDescriptor(environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR)); - String id = environmentVarGetter.apply(HARNESS_ID); + (envVar == null) ? 
null : getApiServiceDescriptor(envVar); + String id = checkNotNull(environmentVarGetter.apply(HARNESS_ID), "HARNESS_ID env var be set."); System.out.format("SDK Fn Harness started%n"); System.out.format("Harness ID %s%n", id); @@ -161,11 +168,11 @@ public static void main(Function environmentVarGetter) throws Ex System.out.format("Control location %s%n", controlApiServiceDescriptor); System.out.format("Status location %s%n", statusApiServiceDescriptor); - String pipelineOptionsJson = environmentVarGetter.apply(PIPELINE_OPTIONS); + @Nullable String pipelineOptionsJson = environmentVarGetter.apply(PIPELINE_OPTIONS); // Try looking for a file first. If that exists it should override PIPELINE_OPTIONS to avoid // maxing out the kernel's environment space try { - String pipelineOptionsPath = environmentVarGetter.apply(PIPELINE_OPTIONS_FILE); + @Nullable String pipelineOptionsPath = environmentVarGetter.apply(PIPELINE_OPTIONS_FILE); System.out.format("Pipeline Options File %s%n", pipelineOptionsPath); if (pipelineOptionsPath != null) { Path filePath = Paths.get(pipelineOptionsPath); @@ -179,12 +186,13 @@ public static void main(Function environmentVarGetter) throws Ex } catch (Exception e) { System.out.format("Problem loading pipeline options from file: %s%n", e.getMessage()); } - - System.out.format("Pipeline options %s%n", pipelineOptionsJson); - // TODO: https://github.com/apache/beam/issues/30301 - pipelineOptionsJson = removeNestedKey(pipelineOptionsJson, "impersonateServiceAccount"); - - PipelineOptions options = PipelineOptionsTranslation.fromJson(pipelineOptionsJson); + if (pipelineOptionsJson != null) { + System.out.format("Pipeline options %s%n", pipelineOptionsJson); + // TODO: https://github.com/apache/beam/issues/30301 + pipelineOptionsJson = removeNestedKey(pipelineOptionsJson, "impersonateServiceAccount"); + } + PipelineOptions options = + PipelineOptionsTranslation.fromJson(pipelineOptionsJson == null ? 
"" : pipelineOptionsJson); String runnerCapabilitesOrNull = environmentVarGetter.apply(RUNNER_CAPABILITIES); Set runnerCapabilites = @@ -219,7 +227,7 @@ public static void main( Set runnerCapabilities, Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor, Endpoints.ApiServiceDescriptor controlApiServiceDescriptor, - @Nullable Endpoints.ApiServiceDescriptor statusApiServiceDescriptor) + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor) throws Exception { ManagedChannelFactory channelFactory; if (ExperimentalOptions.hasExperiment(options, "beam_fn_api_epoll")) { @@ -263,7 +271,7 @@ public static void main( Set runnerCapabilites, Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor, Endpoints.ApiServiceDescriptor controlApiServiceDescriptor, - Endpoints.ApiServiceDescriptor statusApiServiceDescriptor, + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor, ManagedChannelFactory channelFactory, OutboundObserverFactory outboundObserverFactory, Cache processWideCache) @@ -318,7 +326,7 @@ public static void main( BeamFnControlGrpc.newBlockingStub(channel); BeamFnDataGrpcClient beamFnDataMultiplexer = - new BeamFnDataGrpcClient(options, channelFactory::forDescriptor, outboundObserverFactory); + new BeamFnDataGrpcClient(channelFactory::forDescriptor, outboundObserverFactory); BeamFnStateGrpcClientCache beamFnStateGrpcClientCache = new BeamFnStateGrpcClientCache(idGenerator, channelFactory, outboundObserverFactory); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 5a57b137bf6b..621e60c50452 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -378,9 +378,8 @@ public FnDataReceiver addOutgoingDataEndpoint( 
outboundAggregatorMap.computeIfAbsent( apiServiceDescriptor, asd -> - queueingClient.createOutboundAggregator( - asd, - processBundleInstructionId, + new BeamFnDataOutboundAggregator( + options, runnerCapabilities.contains( BeamUrns.getUrn( StandardRunnerProtocols.Enum @@ -391,21 +390,19 @@ public FnDataReceiver addOutgoingDataEndpoint( @Override public FnDataReceiver> addOutgoingTimersEndpoint( String timerFamilyId, org.apache.beam.sdk.coders.Coder> coder) { - BeamFnDataOutboundAggregator aggregator; if (!processBundleDescriptor.hasTimerApiServiceDescriptor()) { throw new IllegalStateException( String.format( - "Timers are unsupported because the " - + "ProcessBundleRequest %s does not provide a timer ApiServiceDescriptor.", + "Timers are unsupported because the ProcessBundleRequest %s does not" + + " provide a timer ApiServiceDescriptor.", processBundleInstructionId.get())); } - aggregator = + BeamFnDataOutboundAggregator aggregator = outboundAggregatorMap.computeIfAbsent( processBundleDescriptor.getTimerApiServiceDescriptor(), asd -> - queueingClient.createOutboundAggregator( - asd, - processBundleInstructionId, + new BeamFnDataOutboundAggregator( + options, runnerCapabilities.contains( BeamUrns.getUrn( StandardRunnerProtocols.Enum @@ -499,6 +496,8 @@ public BundleFinalizer getBundleFinalizer() { */ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { + String instructionId = request.getInstructionId(); + String dataStreamId = request.getProcessBundle().getDataStreamId(); @Nullable BundleProcessor bundleProcessor = null; try { bundleProcessor = @@ -515,13 +514,20 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } })); + for (Map.Entry entry : + bundleProcessor.getOutboundAggregators().entrySet()) { + BeamFnDataOutboundAggregator aggregator = entry.getValue(); + aggregator.prepareForInstruction( + instructionId, beamFnDataClient.getOutboundObserver(entry.getKey(), 
dataStreamId)); + } + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { - stateTracker.start(request.getInstructionId()); + stateTracker.start(instructionId); try { // Already in reverse topological order so we don't need to do anything. for (ThrowingRunnable startFunction : startFunctionRegistry.getFunctions()) { @@ -545,12 +551,14 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } else if (!bundleProcessor.getInboundEndpointApiServiceDescriptors().isEmpty()) { BeamFnDataInboundObserver observer = bundleProcessor.getInboundObserver(); beamFnDataClient.registerReceiver( - request.getInstructionId(), + instructionId, + dataStreamId, bundleProcessor.getInboundEndpointApiServiceDescriptors(), observer); observer.awaitCompletion(); beamFnDataClient.unregisterReceiver( - request.getInstructionId(), + instructionId, + dataStreamId, bundleProcessor.getInboundEndpointApiServiceDescriptors()); } @@ -581,7 +589,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) { finalizeBundleHandler.registerCallbacks( - bundleProcessor.getInstructionId(), + instructionId, ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations())); response.setRequiresFinalization(true); } @@ -599,7 +607,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } catch (Exception e) { LOG.debug( "Error processing bundle {} with bundleProcessor for {} after exception", - request.getInstructionId(), + instructionId, 
request.getProcessBundle().getProcessBundleDescriptorId(), e); if (bundleProcessor != null) { @@ -607,7 +615,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re bundleProcessorCache.discard(bundleProcessor); } // Ensure that if more data arrives for the instruction it is discarded. - beamFnDataClient.poisonInstructionId(request.getInstructionId()); + beamFnDataClient.poisonInstructionId(instructionId); throw e; } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 94d59d0fcb62..1a50f5b448c5 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -18,13 +18,12 @@ package org.apache.beam.fn.harness.data; import java.util.List; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; -import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; /** * The {@link BeamFnDataClient} is able to forward inbound elements to a {@link FnDataReceiver} and @@ -47,6 +46,7 @@ public interface BeamFnDataClient { */ void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver); @@ -58,7 +58,8 @@ void registerReceiver( * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via * a call to {@link #poisonInstructionId}. 
*/ - void unregisterReceiver(String instructionId, List apiServiceDescriptors); + void unregisterReceiver( + String instructionId, String dataStreamId, List apiServiceDescriptors); /** * Poisons the instruction id, indicating that future data arriving for it should be discarded. @@ -68,22 +69,7 @@ void registerReceiver( */ void poisonInstructionId(String instructionId); - /** - * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and - * timers over the data plane. It is important that {@link - * BeamFnDataOutboundAggregator#sendOrCollectBufferedDataAndFinishOutboundStreams()} is called on - * the returned BeamFnDataOutboundAggregator at the end of each bundle. If - * collectElementsIfNoFlushes is set to true, {@link - * BeamFnDataOutboundAggregator#sendOrCollectBufferedDataAndFinishOutboundStreams()} returns the - * buffered elements instead of sending it through the outbound StreamObserver if there's no - * previous flush. - * - *

Closing the returned aggregator signals the end of the streams. - * - *

The returned aggregator is not thread safe. - */ - BeamFnDataOutboundAggregator createOutboundAggregator( - Endpoints.ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes); + /** Get the outbound observer for the specified apiServiceDescriptor and dataStreamId. */ + StreamObserver getOutboundObserver( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 499d816f8cc0..2f2a6b0fc660 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -18,20 +18,22 @@ package org.apache.beam.fn.harness.data; import java.util.List; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.sdk.fn.data.BeamFnDataGrpcMultiplexer; -import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.MetadataUtils; +import 
org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,17 +46,42 @@ public class BeamFnDataGrpcClient implements BeamFnDataClient { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcClient.class); - private final ConcurrentMap - multiplexerCache; + private static class MultiplexerKey { + private final Endpoints.ApiServiceDescriptor apiServiceDescriptor; + private final String dataStreamId; + + private MultiplexerKey( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + this.apiServiceDescriptor = apiServiceDescriptor; + this.dataStreamId = dataStreamId; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof MultiplexerKey)) { + return false; + } + MultiplexerKey that = (MultiplexerKey) o; + return Objects.equals(dataStreamId, that.dataStreamId) + && Objects.equals(apiServiceDescriptor, that.apiServiceDescriptor); + } + + @Override + public int hashCode() { + return Objects.hash(apiServiceDescriptor, dataStreamId); + } + } + + private final ConcurrentMap multiplexerCache; private final Function channelFactory; private final OutboundObserverFactory outboundObserverFactory; - private final PipelineOptions options; public BeamFnDataGrpcClient( - PipelineOptions options, Function channelFactory, OutboundObserverFactory outboundObserverFactory) { - this.options = options; this.channelFactory = channelFactory; this.outboundObserverFactory = outboundObserverFactory; this.multiplexerCache = new ConcurrentHashMap<>(); @@ -63,21 +90,22 @@ public BeamFnDataGrpcClient( @Override public void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver) { LOG.debug("Registering consumer for {}", instructionId); for (int i = 0, size = apiServiceDescriptors.size(); i < size; i++) { - BeamFnDataGrpcMultiplexer client = 
getClientFor(apiServiceDescriptors.get(i)); + BeamFnDataGrpcMultiplexer client = getMultiplexer(apiServiceDescriptors.get(i), dataStreamId); client.registerConsumer(instructionId, receiver); } } @Override public void unregisterReceiver( - String instructionId, List apiServiceDescriptors) { + String instructionId, String dataStreamId, List apiServiceDescriptors) { LOG.debug("Unregistering consumer for {}", instructionId); for (int i = 0, size = apiServiceDescriptors.size(); i < size; i++) { - BeamFnDataGrpcMultiplexer client = getClientFor(apiServiceDescriptors.get(i)); + BeamFnDataGrpcMultiplexer client = getMultiplexer(apiServiceDescriptors.get(i), dataStreamId); client.unregisterConsumer(instructionId); } } @@ -91,25 +119,32 @@ public void poisonInstructionId(String instructionId) { } @Override - public BeamFnDataOutboundAggregator createOutboundAggregator( - ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes) { - return new BeamFnDataOutboundAggregator( - options, - processBundleRequestIdSupplier, - getClientFor(apiServiceDescriptor).getOutboundObserver(), - collectElementsIfNoFlushes); + public StreamObserver getOutboundObserver( + ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + return getMultiplexer(apiServiceDescriptor, dataStreamId).getOutboundObserver(); } - private BeamFnDataGrpcMultiplexer getClientFor( - Endpoints.ApiServiceDescriptor apiServiceDescriptor) { + private BeamFnDataGrpcMultiplexer getMultiplexer( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + MultiplexerKey key = new MultiplexerKey(apiServiceDescriptor, dataStreamId); return multiplexerCache.computeIfAbsent( - apiServiceDescriptor, - (Endpoints.ApiServiceDescriptor descriptor) -> - new BeamFnDataGrpcMultiplexer( - descriptor, - outboundObserverFactory, - BeamFnDataGrpc.newStub(channelFactory.apply(apiServiceDescriptor))::data)); + key, + k -> { + 
OutboundObserverFactory.BasicFactory baseOutboundObserverFactory = + inboundObserver -> { + BeamFnDataGrpc.BeamFnDataStub stub = + BeamFnDataGrpc.newStub(channelFactory.apply(apiServiceDescriptor)); + if (dataStreamId != null && !dataStreamId.isEmpty()) { + Metadata headers = new Metadata(); + headers.put( + Metadata.Key.of("data_stream_id", Metadata.ASCII_STRING_MARSHALLER), + dataStreamId); + stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers)); + } + return stub.data(inboundObserver); + }; + return new BeamFnDataGrpcMultiplexer( + apiServiceDescriptor, outboundObserverFactory, baseOutboundObserverFactory); + }); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index 70a894e7b375..6c8abdbb3adf 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -32,7 +32,6 @@ import java.util.Map; import java.util.ServiceLoader; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -107,19 +106,29 @@ public void setUp() { MockitoAnnotations.initMocks(this); } - private BeamFnDataOutboundAggregator createRecordingAggregator( - Map>> output, Supplier bundleId) { + @Test + public void testReuseForMultipleBundles() throws Exception { + AtomicReference bundleId = new AtomicReference<>("0"); + String localInputId = "inputPC"; + RunnerApi.PTransform pTransform = + RemoteGrpcPortWrite.writeToPort(localInputId, PORT_SPEC).toPTransform(); + + List> output0 = new ArrayList<>(); + List> output1 = new ArrayList<>(); + Map aggregators = new HashMap<>(); + 
PipelineOptions options = PipelineOptionsFactory.create(); options.as(ExperimentalOptions.class).setExperiments(Arrays.asList("data_buffer_size_limit=0")); - return new BeamFnDataOutboundAggregator( - options, - bundleId, + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + + Map>> outputs = ImmutableMap.of("0", output0, "1", output1); + StreamObserver observer = new StreamObserver() { @Override public void onNext(Elements elements) { for (Data data : elements.getDataList()) { try { - output.get(bundleId.get()).add(WIRE_CODER.decode(data.getData().newInput())); + outputs.get(bundleId.get()).add(WIRE_CODER.decode(data.getData().newInput())); } catch (IOException e) { throw new RuntimeException("Failed to decode output."); } @@ -131,22 +140,9 @@ public void onError(Throwable throwable) {} @Override public void onCompleted() {} - }, - false); - } + }; - @Test - public void testReuseForMultipleBundles() throws Exception { - AtomicReference bundleId = new AtomicReference<>("0"); - String localInputId = "inputPC"; - RunnerApi.PTransform pTransform = - RemoteGrpcPortWrite.writeToPort(localInputId, PORT_SPEC).toPTransform(); - - List> output0 = new ArrayList<>(); - List> output1 = new ArrayList<>(); - Map aggregators = new HashMap<>(); - BeamFnDataOutboundAggregator aggregator = - createRecordingAggregator(ImmutableMap.of("0", output0, "1", output1), bundleId::get); + aggregator.prepareForInstruction(bundleId.get(), observer); aggregators.put(PORT_SPEC.getApiServiceDescriptor(), aggregator); PTransformRunnerFactoryTestContext context = @@ -179,6 +175,7 @@ public void testReuseForMultipleBundles() throws Exception { // Process for bundle id 1 bundleId.set("1"); + aggregator.prepareForInstruction(bundleId.get(), observer); pCollectionConsumer.accept(valueInGlobalWindow("GHI")); pCollectionConsumer.accept(valueInGlobalWindow("JKL")); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index 50a2fec0b5a2..2aa555e83cd5 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -839,7 +839,7 @@ private class TestBeamFnDataOutboundAggregator extends BeamFnDataOutboundAggrega private Supplier processBundleRequestIdSupplier; public TestBeamFnDataOutboundAggregator(Supplier bundleIdSupplier) { - super(PipelineOptionsFactory.create(), bundleIdSupplier, null, false); + super(PipelineOptionsFactory.create(), false); this.timers = new HashMap<>(); this.dataOutput = new HashMap<>(); this.processBundleRequestIdSupplier = bundleIdSupplier; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 7b4387738a4c..51e49953b406 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.construction.Timer; import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.joda.time.Instant; /** @@ -74,6 +75,7 @@ public static Builder builder(String pTransformId, RunnerApi.PTransform pTransfo @Override public void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver) { throw new UnsupportedOperationException("Unexpected call during test."); @@ -81,15 +83,15 @@ public void registerReceiver( @Override public void unregisterReceiver( - String instructionId, List apiServiceDescriptors) { + String 
instructionId, + String dataStreamId, + List apiServiceDescriptors) { throw new UnsupportedOperationException("Unexpected call during test."); } @Override - public BeamFnDataOutboundAggregator createOutboundAggregator( - ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes) { + public StreamObserver getOutboundObserver( + ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { throw new UnsupportedOperationException("Unexpected call during test."); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 47f85178b0a1..c03f82726740 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -37,7 +37,6 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; @@ -1071,28 +1070,20 @@ private ProcessBundleHandler setupProcessBundleHandlerForSimpleRecordingDoFn( dataOutput.add(input.getValue()); })); - Mockito.doAnswer( - (invocation) -> - new BeamFnDataOutboundAggregator( - PipelineOptionsFactory.create(), - invocation.getArgument(1), - new StreamObserver() { - @Override - public void onNext(Elements elements) { - for (Timers timer : elements.getTimersList()) { - timerOutput.addAll(elements.getTimersList()); - } - } + Mockito.when(beamFnDataClient.getOutboundObserver(any(), any())) + .thenReturn( + new StreamObserver() { + @Override + public void onNext(Elements elements) { + timerOutput.addAll(elements.getTimersList()); + 
} - @Override - public void onError(Throwable throwable) {} + @Override + public void onError(Throwable throwable) {} - @Override - public void onCompleted() {} - }, - invocation.getArgument(2))) - .when(beamFnDataClient) - .createOutboundAggregator(any(), any(), anyBoolean()); + @Override + public void onCompleted() {} + }); return new ProcessBundleHandler( PipelineOptionsFactory.create(), @@ -1409,7 +1400,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws (invocation) -> { String instructionId = invocation.getArgument(0, String.class); CloseableFnDataReceiver data = - invocation.getArgument(2, CloseableFnDataReceiver.class); + invocation.getArgument(3, CloseableFnDataReceiver.class); data.accept( BeamFnApi.Elements.newBuilder() .addData( @@ -1421,7 +1412,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws return null; }) .when(beamFnDataClient) - .registerReceiver(any(), any(), any()); + .registerReceiver(any(), any(), any(), any()); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -1451,8 +1442,8 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws .build()); // Ensure that we unregister during successful processing - verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); - verify(beamFnDataClient).unregisterReceiver(eq("instructionId"), any()); + verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any(), any()); + verify(beamFnDataClient).unregisterReceiver(eq("instructionId"), any(), any()); verifyNoMoreInteractions(beamFnDataClient); } @@ -1475,7 +1466,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { StringUtf8Coder.of().encode("A", encodedData); String instructionId = invocation.getArgument(0, String.class); CloseableFnDataReceiver data = - invocation.getArgument(2, CloseableFnDataReceiver.class); + invocation.getArgument(3, CloseableFnDataReceiver.class); data.accept( 
BeamFnApi.Elements.newBuilder() .addData( @@ -1489,7 +1480,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { return null; }) .when(beamFnDataClient) - .registerReceiver(any(), any(), any()); + .registerReceiver(any(), any(), any(), any()); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -1526,7 +1517,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { .build())); // Ensure that we unregister during successful processing - verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any(), any()); verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 15f83f2582c7..9d9efa0b9c49 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,8 +23,8 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import java.util.Arrays; import java.util.Collection; @@ -49,6 +49,7 @@ import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.WindowedValue; @@ -169,7 +170,6 @@ public StreamObserver 
data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -183,7 +183,7 @@ public StreamObserver data( Collections.emptyList()); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observerA); waitForClientToConnect.await(); outboundServerObserver.get().onNext(ELEMENTS_A_1); @@ -193,7 +193,7 @@ public StreamObserver data( Thread.sleep(100); clientFactory.registerReceiver( - INSTRUCTION_ID_B, Arrays.asList(apiServiceDescriptor), observerB); + INSTRUCTION_ID_B, "", Arrays.asList(apiServiceDescriptor), observerB); // Show that out of order stream completion can occur. observerB.awaitCompletion(); @@ -245,7 +245,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -262,7 +261,7 @@ public StreamObserver data( Collections.emptyList()); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observer); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observer); waitForClientToConnect.await(); @@ -270,12 +269,8 @@ public StreamObserver data( outboundServerObserver.get().onNext(ELEMENTS_A_1); outboundServerObserver.get().onNext(ELEMENTS_A_2); - try { - observer.awaitCompletion(); - fail("Expected channel to fail"); - } catch (Exception e) { - assertEquals(exceptionToThrow, e); - } + Exception e = assertThrows(Exception.class, observer::awaitCompletion); + assertEquals(exceptionToThrow, e); // The server should not have received any values assertThat(inboundServerValues, empty()); // The consumer should have only been invoked once @@ -321,7 +316,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = 
new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -347,7 +341,7 @@ public StreamObserver data( }); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observerA); waitForClientToConnect.await(); outboundServerObserver.get().onNext(ELEMENTS_B_1); @@ -358,11 +352,9 @@ public StreamObserver data( assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); clientFactory.poisonInstructionId(INSTRUCTION_ID_A); - try { - future.get(); - fail(); // We expect the awaitCompletion to fail due to closing. - } catch (Exception ignored) { - } + // We expect the awaitCompletion to fail due to closing. + // Expected. + assertThrows(Exception.class, future::get); outboundServerObserver.get().onNext(ELEMENTS_A_2); @@ -404,16 +396,15 @@ public StreamObserver data( ManagedChannel channel = InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + PipelineOptions options = + PipelineOptionsFactory.fromArgs("--experiments=data_buffer_size_limit=20").create(); BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.fromArgs( - new String[] {"--experiments=data_buffer_size_limit=20"}) - .create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); - BeamFnDataOutboundAggregator aggregator = - clientFactory.createOutboundAggregator( - apiServiceDescriptor, () -> INSTRUCTION_ID_A, false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + INSTRUCTION_ID_A, clientFactory.getOutboundObserver(apiServiceDescriptor, "")); FnDataReceiver> fnDataReceiver = aggregator.registerOutputDataLocation(TRANSFORM_ID_A, CODER); fnDataReceiver.accept(valueInGlobalWindow("ABC")); From 9e1a395a2b9e523342c8bdffb24c7b6635b01f7f Mon Sep 
17 00:00:00 2001 From: Sam Whittle Date: Fri, 12 Jun 2026 14:53:34 +0200 Subject: [PATCH 2/3] fix up fnexecution usage of BeamFnDataOutboundAggregator --- .../fnexecution/control/SdkHarnessClient.java | 2 +- .../runners/fnexecution/data/FnDataService.java | 3 +-- .../runners/fnexecution/data/GrpcDataService.java | 13 ++++++------- .../fnexecution/data/GrpcDataServiceTest.java | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java index 682c45e30795..704d298a195d 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java @@ -298,7 +298,7 @@ public ActiveBundle newBundle( ImmutableMap.Builder> receiverBuilder = ImmutableMap.builder(); BeamFnDataOutboundAggregator beamFnDataOutboundAggregator = - fnApiDataService.createOutboundAggregator(() -> bundleId, false); + fnApiDataService.createOutboundAggregator(bundleId, false); for (RemoteInputDestination remoteInput : remoteInputs) { LogicalEndpoint endpoint = LogicalEndpoint.data(bundleId, remoteInput.getPTransformId()); receiverBuilder.put( diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java index 7c5f110eab28..657ec74553bc 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.fnexecution.data; -import java.util.function.Supplier; import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; @@ -69,5 +68,5 @@ public interface FnDataService { *

The returned aggregator is not thread safe. */ BeamFnDataOutboundAggregator createOutboundAggregator( - Supplier processBundleRequestIdSupplier, boolean collectElementsIfNoFlushes); + String processBundleId, boolean collectElementsIfNoFlushes); } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java index d4e45c8ccf82..a3a8c3244044 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java @@ -23,7 +23,6 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc; @@ -175,13 +174,13 @@ public void unregisterReceiver(String instructionId) { @Override public BeamFnDataOutboundAggregator createOutboundAggregator( - Supplier processBundleRequestIdSupplier, boolean collectElementsIfNoFlushes) { + String instructionId, boolean collectElementsIfNoFlushes) { try { - return new BeamFnDataOutboundAggregator( - options, - processBundleRequestIdSupplier, - connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver(), - collectElementsIfNoFlushes); + BeamFnDataOutboundAggregator aggregator = + new BeamFnDataOutboundAggregator(options, collectElementsIfNoFlushes); + aggregator.prepareForInstruction( + instructionId, connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver()); + return aggregator; } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); diff --git 
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java index f84467077501..363367f1087f 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java @@ -102,7 +102,7 @@ public void testMessageReceivedBySingleClientWhenThereAreMultipleClients() throw for (int i = 0; i < 3; ++i) { final String instructionId = Integer.toString(i); BeamFnDataOutboundAggregator aggregator = - service.createOutboundAggregator(() -> instructionId, false); + service.createOutboundAggregator(instructionId, false); aggregator.start(); FnDataReceiver> consumer = aggregator.registerOutputDataLocation(TRANSFORM_ID, CODER); From a12f159c43a022b8a64dabc4006ff2cd0a3bae27 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 12 Jun 2026 15:37:08 +0200 Subject: [PATCH 3/3] fix tests --- .../fn/data/BeamFnDataOutboundAggregator.java | 22 ++++++++++--- .../BeamFnDataOutboundAggregatorTest.java | 32 +++++++++++++++++++ .../harness/control/ProcessBundleHandler.java | 5 +++ 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java index d2c3ada870c3..d7ce362ef58a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.fn.data; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.IOException; import java.util.Collections; @@ -98,15 +99,31 @@ public void prepareForInstruction( String instructionId, StreamObserver outboundObserver) { if (timeLimit > 0) { synchronized (flushLock) { + checkState(this.instructionId == null && this.outboundObserver == null); this.instructionId = instructionId; this.outboundObserver = outboundObserver; } } else { + checkState(this.instructionId == null && this.outboundObserver == null); this.instructionId = instructionId; this.outboundObserver = outboundObserver; } } + public void finishInstruction() { + if (timeLimit > 0) { + synchronized (flushLock) { + checkState(this.instructionId != null && this.outboundObserver != null); + this.instructionId = null; + this.outboundObserver = null; + } + } else { + checkState(this.instructionId != null && this.outboundObserver != null); + this.instructionId = null; + this.outboundObserver = null; + } + } + /** Starts the flushing daemon thread if data_buffer_time_limit_ms is set. */ public void start() { if (timeLimit > 0 && this.flushFuture == null) { @@ -198,7 +215,6 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { } else { bufferedElements = convertBufferForTransmission(); } - checkNotNull(instructionId); LOG.debug( "Closing streams for instruction {} and outbound data {} and timers {}.", instructionId, @@ -224,13 +240,10 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { entry.getValue().resetStats(); } // This is the end of the bundle so we reset state to prepare for future bundles. 
- instructionId = null; if (collectElementsIfNoFlushes && !hasFlushedForBundle) { - outboundObserver = null; return bufferedElements.build(); } checkNotNull(outboundObserver).onNext(bufferedElements.build()); - outboundObserver = null; hasFlushedForBundle = false; return null; } @@ -247,7 +260,6 @@ public void discard() { } private Elements.Builder convertBufferForTransmission() { - checkNotNull(instructionId); Elements.Builder bufferedElements = Elements.newBuilder(); for (Map.Entry> entry : outputDataReceivers.entrySet()) { if (!entry.getValue().hasBufferedOutput()) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java index a5143eb530d2..9bcf615d638b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java @@ -20,6 +20,7 @@ import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import java.io.IOException; @@ -323,6 +324,37 @@ public void testConfiguredBufferLimitMultipleEndpoints() throws Exception { checkEqualInAnyOrder(builder.build(), values.get(1)); } + @Test + public void testInstructionLifecycle() { + BeamFnDataOutboundAggregator aggregator = + new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), false); + assertThrows( + NullPointerException.class, () -> aggregator.sendElements(Elements.getDefaultInstance())); + aggregator.prepareForInstruction( + "testInstruction", + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); + assertThrows( + IllegalStateException.class, + () -> + aggregator.prepareForInstruction( + "testInstruction", + 
TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build())); + aggregator.finishInstruction(); + assertThrows( + NullPointerException.class, () -> aggregator.sendElements(Elements.getDefaultInstance())); + assertThrows(IllegalStateException.class, aggregator::finishInstruction); + } + private void checkEqualInAnyOrder(Elements first, Elements second) { MatcherAssert.assertThat( first.getDataList(), Matchers.containsInAnyOrder(second.getDataList().toArray())); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 621e60c50452..f1b34641c1a4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -637,6 +637,10 @@ private void embedOutboundElementsIfApplicable( collectedElements.add(elements); } if (!hasFlushedAggregator) { + for (BeamFnDataOutboundAggregator aggregator : + bundleProcessor.getOutboundAggregators().values()) { + aggregator.finishInstruction(); + } Elements.Builder elementsToEmbed = Elements.newBuilder(); for (Elements collectedElement : collectedElements) { elementsToEmbed.mergeFrom(collectedElement); @@ -653,6 +657,7 @@ private void embedOutboundElementsIfApplicable( if (elements != null) { aggregator.sendElements(elements); } + aggregator.finishInstruction(); } } }