diff --git a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index ecef3f2e7a94..6543f7d16c18 100644 --- a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -120,6 +120,9 @@ message RemoteGrpcPort { service BeamFnControl { // Instructions sent by the runner to the SDK requesting different types // of work. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc Control( // A stream of responses to instructions the SDK was asked to be // performed. @@ -130,6 +133,9 @@ service BeamFnControl { // Used to get the full process bundle descriptors for bundles one // is asked to process. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc GetProcessBundleDescriptor(GetProcessBundleDescriptorRequest) returns ( ProcessBundleDescriptor) {} } @@ -416,14 +422,21 @@ message ProcessBundleRequest { // at https://s.apache.org/beam-fn-api-control-data-embedding. Elements elements = 3; - // indicates that the runner has no stare for the keys in this bundle + // Indicates that the runner has no state for the keys in this bundle // so SDk can safely begin stateful processing with a locally-generated - // initial empty state + // initial empty state. bool has_no_state = 4; - // indicates that the runner will never process another bundle for the keys + // Indicates that the runner will never process another bundle for the keys // in this bundle so state need not be included in the bundle commit. bool only_bundle_for_keys = 5; + + // (Optional) If non-empty, the ID of the data stream to use for this bundle. + // See comments at BeamFnData.Data for more details. + // + // The runner should only populate this field if the sdk advertises the + // beam:protocol:named_data_streams:v1 capability. + string data_stream_id = 6; } message ProcessBundleResponse { @@ -835,6 +848,11 @@ message Elements { // Stable service BeamFnData { // Used to send data between harnesses. + // + // Header metadata has the specified keys pairs: + // - "worker_id": value is the id of the sdk + // - "data_stream_id": value is the id of the data stream, distinguishing it from other data streams from the same + // sdk. This field should only be populated if requested in a received ProcessBundleRequest from the runner. rpc Data( // A stream of data representing input. stream Elements) @@ -900,6 +918,9 @@ message StateResponse { service BeamFnState { // Used to get/append/clear state stored by the runner on behalf of the SDK. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc State( // A stream of state instructions requested of the runner. stream StateRequest) @@ -1295,6 +1316,11 @@ message LogControl {} service BeamFnLogging { // Allows for the SDK to emit log entries which the runner can // associate with the active job. + // + // Used to get/append/clear state stored by the runner on behalf of the SDK. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc Logging( // A stream of log entries batched into lists emitted by the SDK harness. stream LogEntry.List) @@ -1356,6 +1382,8 @@ message WorkerStatusResponse { // API for SDKs to report debug-related statuses to runner during pipeline execution. service BeamFnWorkerStatus { + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc WorkerStatus (stream WorkerStatusResponse) returns (stream WorkerStatusRequest) {} } diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index 67df8b9e8003..5824c9bf4b73 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -1689,6 +1689,10 @@ message StandardProtocols { // Indicates whether the SDK supports multimap state. MULTIMAP_STATE = 12 [(beam_urn) = "beam:protocol:multimap_state:v1"]; + + // Indicates whether the SDK supports data stream ids being requested by the runner in + // ProcessBundleRequests. + NAMED_DATA_STREAMS = 13 [(beam_urn) = "beam:protocol:named_data_streams:v1"]; } } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java index 682c45e30795..704d298a195d 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java @@ -298,7 +298,7 @@ public ActiveBundle newBundle( ImmutableMap.Builder> receiverBuilder = ImmutableMap.builder(); BeamFnDataOutboundAggregator beamFnDataOutboundAggregator = - fnApiDataService.createOutboundAggregator(() -> bundleId, false); + fnApiDataService.createOutboundAggregator(bundleId, false); for (RemoteInputDestination remoteInput : remoteInputs) { LogicalEndpoint endpoint = LogicalEndpoint.data(bundleId, remoteInput.getPTransformId()); receiverBuilder.put( diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java index 7c5f110eab28..657ec74553bc 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.fnexecution.data; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; @@ -69,5 +68,5 @@ public interface FnDataService { *

The returned aggregator is not thread safe. */ BeamFnDataOutboundAggregator createOutboundAggregator( - Supplier processBundleRequestIdSupplier, boolean collectElementsIfNoFlushes); + String processBundleId, boolean collectElementsIfNoFlushes); } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java index d4e45c8ccf82..a3a8c3244044 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java @@ -23,7 +23,6 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc; @@ -175,13 +174,13 @@ public void unregisterReceiver(String instructionId) { @Override public BeamFnDataOutboundAggregator createOutboundAggregator( - Supplier processBundleRequestIdSupplier, boolean collectElementsIfNoFlushes) { + String instructionId, boolean collectElementsIfNoFlushes) { try { - return new BeamFnDataOutboundAggregator( - options, - processBundleRequestIdSupplier, - connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver(), - collectElementsIfNoFlushes); + BeamFnDataOutboundAggregator aggregator = + new BeamFnDataOutboundAggregator(options, collectElementsIfNoFlushes); + aggregator.prepareForInstruction( + instructionId, connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver()); + return aggregator; } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java index f84467077501..363367f1087f 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java @@ -102,7 +102,7 @@ public void testMessageReceivedBySingleClientWhenThereAreMultipleClients() throw for (int i = 0; i < 3; ++i) { final String instructionId = Integer.toString(i); BeamFnDataOutboundAggregator aggregator = - service.createOutboundAggregator(() -> instructionId, false); + service.createOutboundAggregator(instructionId, false); aggregator.start(); FnDataReceiver> consumer = aggregator.registerOutputDataLocation(TRANSFORM_ID, CODER); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index 8fec8b455cce..0b9d6adab4f0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -63,7 +63,7 @@ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private final Cache poisonedInstructionIds; private static class PoisonedException extends RuntimeException { - public PoisonedException() { + private PoisonedException() { super("Instruction poisoned"); } }; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java index 9b9603706b48..d7ce362ef58a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java @@ -17,6 +17,9 @@ */ package org.apache.beam.sdk.fn.data; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -28,7 +31,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; @@ -56,7 +58,7 @@ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -// The calling thread that invokes sendBufferedDataAndFinishOutboundStreams synchronizes on +// The calling thread that invokes sendOrCollectBufferedDataAndFinishOutboundStreams synchronizes on // flushLock effectively making the periodic flushing no longer read or mutate hasFlushedForBundle // and allowing the calling thread to read and mutate hasFlushedForBundle safely without needing to // create another memory barrier. Also note that flush is always invoked when synchronizing on @@ -72,33 +74,56 @@ public class BeamFnDataOutboundAggregator { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataOutboundAggregator.class); private final int sizeLimit; private final long timeLimit; - private final Supplier processBundleRequestIdSupplier; + private String instructionId; @VisibleForTesting final Map> outputDataReceivers; @VisibleForTesting final Map> outputTimersReceivers; - private final StreamObserver outboundObserver; + @Nullable private StreamObserver outboundObserver; @Nullable @VisibleForTesting ScheduledFuture flushFuture; private long bytesWrittenSinceFlush; private final Object flushLock; private final boolean collectElementsIfNoFlushes; private boolean hasFlushedForBundle; - public BeamFnDataOutboundAggregator( - PipelineOptions options, - Supplier processBundleRequestIdSupplier, - StreamObserver outboundObserver, - boolean collectElementsIfNoFlushes) { + public BeamFnDataOutboundAggregator(PipelineOptions options, boolean collectElementsIfNoFlushes) { this.sizeLimit = getSizeLimit(options); this.timeLimit = getTimeLimit(options); this.collectElementsIfNoFlushes = collectElementsIfNoFlushes; this.outputDataReceivers = new HashMap<>(); this.outputTimersReceivers = new HashMap<>(); - this.outboundObserver = outboundObserver; - this.processBundleRequestIdSupplier = processBundleRequestIdSupplier; this.bytesWrittenSinceFlush = 0L; this.flushLock = new Object(); this.hasFlushedForBundle = false; } + public void prepareForInstruction( + String instructionId, StreamObserver outboundObserver) { + if (timeLimit > 0) { + synchronized (flushLock) { + checkState(this.instructionId == null && this.outboundObserver == null); + this.instructionId = instructionId; + this.outboundObserver = outboundObserver; + } + } else { + checkState(this.instructionId == null && this.outboundObserver == null); + this.instructionId = instructionId; + this.outboundObserver = outboundObserver; + } + } + + public void finishInstruction() { + if (timeLimit > 0) { + synchronized (flushLock) { + checkState(this.instructionId != null && this.outboundObserver != null); + this.instructionId = null; + this.outboundObserver = null; + } + } else { + checkState(this.instructionId != null && this.outboundObserver != null); + this.instructionId = null; + this.outboundObserver = null; + } + } + /** Starts the flushing daemon thread if data_buffer_time_limit_ms is set. */ public void start() { if (timeLimit > 0 && this.flushFuture == null) { @@ -166,7 +191,7 @@ private void flushInternal() { } Elements.Builder elements = convertBufferForTransmission(); if (elements.getDataCount() > 0 || elements.getTimersCount() > 0) { - outboundObserver.onNext(elements.build()); + checkNotNull(outboundObserver).onNext(elements.build()); } hasFlushedForBundle = true; } @@ -177,6 +202,7 @@ private void flushInternal() { * collectElementsIfNoFlushes=true, and there was no previous flush in this bundle, otherwise * returns null. */ + @Nullable public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { if (outputTimersReceivers.isEmpty() && outputDataReceivers.isEmpty()) { return null; @@ -191,14 +217,14 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { } LOG.debug( "Closing streams for instruction {} and outbound data {} and timers {}.", - processBundleRequestIdSupplier.get(), + instructionId, outputDataReceivers, outputTimersReceivers); for (Map.Entry> entry : outputDataReceivers.entrySet()) { String pTransformId = entry.getKey(); bufferedElements .addDataBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(pTransformId) .setIsLast(true); entry.getValue().resetStats(); @@ -207,30 +233,29 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { TimerEndpoint timerKey = entry.getKey(); bufferedElements .addTimersBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(timerKey.pTransformId) .setTimerFamilyId(timerKey.timerFamilyId) .setIsLast(true); entry.getValue().resetStats(); } + // This is the end of the bundle so we reset state to prepare for future bundles. if (collectElementsIfNoFlushes && !hasFlushedForBundle) { return bufferedElements.build(); } - outboundObserver.onNext(bufferedElements.build()); - // This is now at the end of a bundle, so we reset hasFlushedForBundle to prepare for new - // bundles. + checkNotNull(outboundObserver).onNext(bufferedElements.build()); hasFlushedForBundle = false; return null; } // Send the elements to the StreamObserver associated with this aggregator. public void sendElements(Elements elements) { - outboundObserver.onNext(elements); + checkNotNull(outboundObserver).onNext(elements); } public void discard() { if (flushFuture != null) { - flushFuture.cancel(true); + flushFuture.cancel(false); } } @@ -243,7 +268,7 @@ private Elements.Builder convertBufferForTransmission() { ByteString bytes = entry.getValue().toByteStringAndResetBuffer(); bufferedElements .addDataBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(entry.getKey()) .setData(bytes); } @@ -254,7 +279,7 @@ private Elements.Builder convertBufferForTransmission() { ByteString bytes = entry.getValue().toByteStringAndResetBuffer(); bufferedElements .addTimersBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(entry.getKey().pTransformId) .setTimerFamilyId(entry.getKey().timerFamilyId) .setTimers(bytes); @@ -353,10 +378,12 @@ public void accept(T input) throws Exception { } } + @VisibleForTesting public long getByteCount() { return perBundleByteCount; } + @VisibleForTesting public long getElementCount() { return perBundleElementCount; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java index 969bda88d07f..c3b1a7a5235e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java @@ -522,6 +522,7 @@ public static Set getJavaCapabilities() { capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.SDK_CONSUMING_RECEIVED_DATA)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.ORDERED_LIST_STATE)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.MULTIMAP_STATE)); + capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.NAMED_DATA_STREAMS)); return capabilities.build(); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java index 092ba200c94b..9bcf615d638b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java @@ -20,6 +20,7 @@ import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import java.io.IOException; @@ -75,13 +76,12 @@ public void testWithDefaultBuffer() throws Exception { final List values = new ArrayList<>(); final AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - PipelineOptionsFactory.create(), - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -124,14 +124,12 @@ public void testConfiguredBufferLimit() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_size_limit=100")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); aggregator.start(); @@ -187,18 +185,16 @@ public void testConfiguredTimeLimit() throws Exception { .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_time_limit_ms=1")); final CountDownLatch waitForFlush = new CountDownLatch(1); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - values.add(e); - waitForFlush.countDown(); - }) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + values.add(e); + waitForFlush.countDown(); + }) + .build()); // Test that it emits when time passed the time limit FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -214,17 +210,15 @@ public void testConfiguredTimeLimitExceptionPropagation() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_time_limit_ms=1")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - throw new RuntimeException(""); - }) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); // Test that it emits when time passed the time limit FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -243,17 +237,15 @@ public void testConfiguredTimeLimitExceptionPropagation() throws Exception { // expected } - aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - throw new RuntimeException(""); - }) - .build(), - false); + aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); aggregator.start(); dataReceiver.accept(new byte[1]); @@ -279,14 +271,12 @@ public void testConfiguredBufferLimitMultipleEndpoints() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_size_limit=100")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. LogicalEndpoint additionalEndpoint = LogicalEndpoint.data( @@ -334,6 +324,37 @@ public void testConfiguredBufferLimitMultipleEndpoints() throws Exception { checkEqualInAnyOrder(builder.build(), values.get(1)); } + @Test + public void testInstructionLifecycle() { + BeamFnDataOutboundAggregator aggregator = + new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), false); + assertThrows( + NullPointerException.class, () -> aggregator.sendElements(Elements.getDefaultInstance())); + aggregator.prepareForInstruction( + "testInstruction", + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); + assertThrows( + IllegalStateException.class, + () -> + aggregator.prepareForInstruction( + "testInstruction", + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build())); + aggregator.finishInstruction(); + assertThrows( + NullPointerException.class, () -> aggregator.sendElements(Elements.getDefaultInstance())); + assertThrows(IllegalStateException.class, aggregator::finishInstruction); + } + private void checkEqualInAnyOrder(Elements first, Elements second) { MatcherAssert.assertThat( first.getDataList(), Matchers.containsInAnyOrder(second.getDataList().toArray())); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 703e726739a0..009de1b44e60 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -17,6 +17,8 @@ */ package org.apache.beam.fn.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.charset.StandardCharsets; @@ -30,7 +32,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import javax.annotation.Nullable; import org.apache.beam.fn.harness.control.BeamFnControlClient; import org.apache.beam.fn.harness.control.ExecutionStateSampler; import org.apache.beam.fn.harness.control.FinalizeBundleHandler; @@ -72,6 +73,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,9 +94,6 @@ * for further details. * */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class FnHarness { private static final String HARNESS_ID = "HARNESS_ID"; private static final String CONTROL_API_SERVICE_DESCRIPTOR = "CONTROL_API_SERVICE_DESCRIPTOR"; @@ -138,22 +137,30 @@ private static void removeKeyRecursively(JsonNode node, String keyToRemove) { } public static void main(String[] args) throws Exception { - main(System::getenv); + Function environmentVarGetter = System::getenv; + main(environmentVarGetter); } @VisibleForTesting - public static void main(Function environmentVarGetter) throws Exception { + public static void main(Function environmentVarGetter) + throws Exception { JvmInitializers.runOnStartup(); Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor = - getApiServiceDescriptor(environmentVarGetter.apply(LOGGING_API_SERVICE_DESCRIPTOR)); + getApiServiceDescriptor( + checkNotNull( + environmentVarGetter.apply(LOGGING_API_SERVICE_DESCRIPTOR), + "LOGGING_API_SERVICE_DESCRIPTOR env var be set.")); Endpoints.ApiServiceDescriptor controlApiServiceDescriptor = - getApiServiceDescriptor(environmentVarGetter.apply(CONTROL_API_SERVICE_DESCRIPTOR)); + getApiServiceDescriptor( + checkNotNull( + environmentVarGetter.apply(CONTROL_API_SERVICE_DESCRIPTOR), + "CONTROL_API_SERVICE_DESCRIPTOR env var be set.")); + + @Nullable String envVar = environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR); Endpoints.ApiServiceDescriptor statusApiServiceDescriptor = - environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR) == null - ? null - : getApiServiceDescriptor(environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR)); - String id = environmentVarGetter.apply(HARNESS_ID); + (envVar == null) ? null : getApiServiceDescriptor(envVar); + String id = checkNotNull(environmentVarGetter.apply(HARNESS_ID), "HARNESS_ID env var be set."); System.out.format("SDK Fn Harness started%n"); System.out.format("Harness ID %s%n", id); @@ -161,11 +168,11 @@ public static void main(Function environmentVarGetter) throws Ex System.out.format("Control location %s%n", controlApiServiceDescriptor); System.out.format("Status location %s%n", statusApiServiceDescriptor); - String pipelineOptionsJson = environmentVarGetter.apply(PIPELINE_OPTIONS); + @Nullable String pipelineOptionsJson = environmentVarGetter.apply(PIPELINE_OPTIONS); // Try looking for a file first. If that exists it should override PIPELINE_OPTIONS to avoid // maxing out the kernel's environment space try { - String pipelineOptionsPath = environmentVarGetter.apply(PIPELINE_OPTIONS_FILE); + @Nullable String pipelineOptionsPath = environmentVarGetter.apply(PIPELINE_OPTIONS_FILE); System.out.format("Pipeline Options File %s%n", pipelineOptionsPath); if (pipelineOptionsPath != null) { Path filePath = Paths.get(pipelineOptionsPath); @@ -179,12 +186,13 @@ public static void main(Function environmentVarGetter) throws Ex } catch (Exception e) { System.out.format("Problem loading pipeline options from file: %s%n", e.getMessage()); } - - System.out.format("Pipeline options %s%n", pipelineOptionsJson); - // TODO: https://github.com/apache/beam/issues/30301 - pipelineOptionsJson = removeNestedKey(pipelineOptionsJson, "impersonateServiceAccount"); - - PipelineOptions options = PipelineOptionsTranslation.fromJson(pipelineOptionsJson); + if (pipelineOptionsJson != null) { + System.out.format("Pipeline options %s%n", pipelineOptionsJson); + // TODO: https://github.com/apache/beam/issues/30301 + pipelineOptionsJson = removeNestedKey(pipelineOptionsJson, "impersonateServiceAccount"); + } + PipelineOptions options = + PipelineOptionsTranslation.fromJson(pipelineOptionsJson == null ? "" : pipelineOptionsJson); String runnerCapabilitesOrNull = environmentVarGetter.apply(RUNNER_CAPABILITIES); Set runnerCapabilites = @@ -219,7 +227,7 @@ public static void main( Set runnerCapabilities, Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor, Endpoints.ApiServiceDescriptor controlApiServiceDescriptor, - @Nullable Endpoints.ApiServiceDescriptor statusApiServiceDescriptor) + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor) throws Exception { ManagedChannelFactory channelFactory; if (ExperimentalOptions.hasExperiment(options, "beam_fn_api_epoll")) { @@ -263,7 +271,7 @@ public static void main( Set runnerCapabilites, Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor, Endpoints.ApiServiceDescriptor controlApiServiceDescriptor, - Endpoints.ApiServiceDescriptor statusApiServiceDescriptor, + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor, ManagedChannelFactory channelFactory, OutboundObserverFactory outboundObserverFactory, Cache processWideCache) @@ -318,7 +326,7 @@ public static void main( BeamFnControlGrpc.newBlockingStub(channel); BeamFnDataGrpcClient beamFnDataMultiplexer = - new BeamFnDataGrpcClient(options, channelFactory::forDescriptor, outboundObserverFactory); + new BeamFnDataGrpcClient(channelFactory::forDescriptor, outboundObserverFactory); BeamFnStateGrpcClientCache beamFnStateGrpcClientCache = new BeamFnStateGrpcClientCache(idGenerator, channelFactory, outboundObserverFactory); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 5a57b137bf6b..f1b34641c1a4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -378,9 +378,8 @@ public FnDataReceiver addOutgoingDataEndpoint( outboundAggregatorMap.computeIfAbsent( apiServiceDescriptor, asd -> - queueingClient.createOutboundAggregator( - asd, - processBundleInstructionId, + new BeamFnDataOutboundAggregator( + options, runnerCapabilities.contains( BeamUrns.getUrn( StandardRunnerProtocols.Enum @@ -391,21 +390,19 @@ public FnDataReceiver addOutgoingDataEndpoint( @Override public FnDataReceiver> addOutgoingTimersEndpoint( String timerFamilyId, org.apache.beam.sdk.coders.Coder> coder) { - BeamFnDataOutboundAggregator aggregator; if (!processBundleDescriptor.hasTimerApiServiceDescriptor()) { throw new IllegalStateException( String.format( - "Timers are unsupported because the " - + "ProcessBundleRequest %s does not provide a timer ApiServiceDescriptor.", + "Timers are unsupported because the ProcessBundleRequest %s does not" + + " provide a timer ApiServiceDescriptor.", processBundleInstructionId.get())); } - aggregator = + BeamFnDataOutboundAggregator aggregator = outboundAggregatorMap.computeIfAbsent( processBundleDescriptor.getTimerApiServiceDescriptor(), asd -> - queueingClient.createOutboundAggregator( - asd, - processBundleInstructionId, + new BeamFnDataOutboundAggregator( + options, runnerCapabilities.contains( BeamUrns.getUrn( StandardRunnerProtocols.Enum @@ -499,6 +496,8 @@ public BundleFinalizer getBundleFinalizer() { */ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { + String instructionId = request.getInstructionId(); + String dataStreamId = request.getProcessBundle().getDataStreamId(); @Nullable BundleProcessor bundleProcessor = null; try { bundleProcessor = @@ -515,13 +514,20 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } })); + for (Map.Entry entry : + bundleProcessor.getOutboundAggregators().entrySet()) { + BeamFnDataOutboundAggregator aggregator = entry.getValue(); + aggregator.prepareForInstruction( + instructionId, beamFnDataClient.getOutboundObserver(entry.getKey(), dataStreamId)); + } + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { - stateTracker.start(request.getInstructionId()); + stateTracker.start(instructionId); try { // Already in reverse topological order so we don't need to do anything. for (ThrowingRunnable startFunction : startFunctionRegistry.getFunctions()) { @@ -545,12 +551,14 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } else if (!bundleProcessor.getInboundEndpointApiServiceDescriptors().isEmpty()) { BeamFnDataInboundObserver observer = bundleProcessor.getInboundObserver(); beamFnDataClient.registerReceiver( - request.getInstructionId(), + instructionId, + dataStreamId, bundleProcessor.getInboundEndpointApiServiceDescriptors(), observer); observer.awaitCompletion(); beamFnDataClient.unregisterReceiver( - request.getInstructionId(), + instructionId, + dataStreamId, bundleProcessor.getInboundEndpointApiServiceDescriptors()); } @@ -581,7 +589,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) { finalizeBundleHandler.registerCallbacks( - bundleProcessor.getInstructionId(), + instructionId, ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations())); response.setRequiresFinalization(true); } @@ -599,7 +607,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } catch (Exception e) { LOG.debug( "Error processing bundle {} with bundleProcessor for {} after exception", - request.getInstructionId(), + instructionId, request.getProcessBundle().getProcessBundleDescriptorId(), e); if (bundleProcessor != null) { @@ -607,7 +615,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re bundleProcessorCache.discard(bundleProcessor); } // Ensure that if more data arrives for the instruction it is discarded. - beamFnDataClient.poisonInstructionId(request.getInstructionId()); + beamFnDataClient.poisonInstructionId(instructionId); throw e; } } @@ -629,6 +637,10 @@ private void embedOutboundElementsIfApplicable( collectedElements.add(elements); } if (!hasFlushedAggregator) { + for (BeamFnDataOutboundAggregator aggregator : + bundleProcessor.getOutboundAggregators().values()) { + aggregator.finishInstruction(); + } Elements.Builder elementsToEmbed = Elements.newBuilder(); for (Elements collectedElement : collectedElements) { elementsToEmbed.mergeFrom(collectedElement); @@ -645,6 +657,7 @@ private void embedOutboundElementsIfApplicable( if (elements != null) { aggregator.sendElements(elements); } + aggregator.finishInstruction(); } } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 94d59d0fcb62..1a50f5b448c5 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -18,13 +18,12 @@ package org.apache.beam.fn.harness.data; import java.util.List; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; -import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; /** * The {@link BeamFnDataClient} is able to forward inbound elements to a {@link FnDataReceiver} and @@ -47,6 +46,7 @@ public interface BeamFnDataClient { */ void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver); @@ -58,7 +58,8 @@ void registerReceiver( * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via * a call to {@link #poisonInstructionId}. */ - void unregisterReceiver(String instructionId, List apiServiceDescriptors); + void unregisterReceiver( + String instructionId, String dataStreamId, List apiServiceDescriptors); /** * Poisons the instruction id, indicating that future data arriving for it should be discarded. @@ -68,22 +69,7 @@ void registerReceiver( */ void poisonInstructionId(String instructionId); - /** - * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and - * timers over the data plane. It is important that {@link - * BeamFnDataOutboundAggregator#sendOrCollectBufferedDataAndFinishOutboundStreams()} is called on - * the returned BeamFnDataOutboundAggregator at the end of each bundle. If - * collectElementsIfNoFlushes is set to true, {@link - * BeamFnDataOutboundAggregator#sendOrCollectBufferedDataAndFinishOutboundStreams()} returns the - * buffered elements instead of sending it through the outbound StreamObserver if there's no - * previous flush. - * - *

Closing the returned aggregator signals the end of the streams. - * - *

The returned aggregator is not thread safe. - */ - BeamFnDataOutboundAggregator createOutboundAggregator( - Endpoints.ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes); + /** Get the outbound observer for the specified apiServiceDescriptor and dataStreamId. */ + StreamObserver getOutboundObserver( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 499d816f8cc0..2f2a6b0fc660 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -18,20 +18,22 @@ package org.apache.beam.fn.harness.data; import java.util.List; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.sdk.fn.data.BeamFnDataGrpcMultiplexer; -import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.MetadataUtils; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,17 +46,42 @@ public class BeamFnDataGrpcClient implements BeamFnDataClient { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcClient.class); - private final ConcurrentMap - multiplexerCache; + private static class MultiplexerKey { + private final Endpoints.ApiServiceDescriptor apiServiceDescriptor; + private final String dataStreamId; + + private MultiplexerKey( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + this.apiServiceDescriptor = apiServiceDescriptor; + this.dataStreamId = dataStreamId; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof MultiplexerKey)) { + return false; + } + MultiplexerKey that = (MultiplexerKey) o; + return Objects.equals(dataStreamId, that.dataStreamId) + && Objects.equals(apiServiceDescriptor, that.apiServiceDescriptor); + } + + @Override + public int hashCode() { + return Objects.hash(apiServiceDescriptor, dataStreamId); + } + } + + private final ConcurrentMap multiplexerCache; private final Function channelFactory; private final OutboundObserverFactory outboundObserverFactory; - private final PipelineOptions options; public BeamFnDataGrpcClient( - PipelineOptions options, Function channelFactory, OutboundObserverFactory outboundObserverFactory) { - this.options = options; this.channelFactory = channelFactory; this.outboundObserverFactory = outboundObserverFactory; this.multiplexerCache = new ConcurrentHashMap<>(); @@ -63,21 +90,22 @@ public BeamFnDataGrpcClient( @Override public void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver) { LOG.debug("Registering consumer for {}", instructionId); for (int i = 0, size = apiServiceDescriptors.size(); i < size; i++) { - BeamFnDataGrpcMultiplexer client = getClientFor(apiServiceDescriptors.get(i)); + BeamFnDataGrpcMultiplexer client = getMultiplexer(apiServiceDescriptors.get(i), dataStreamId); client.registerConsumer(instructionId, receiver); } } @Override public void unregisterReceiver( - String instructionId, List apiServiceDescriptors) { + String instructionId, String dataStreamId, List apiServiceDescriptors) { LOG.debug("Unregistering consumer for {}", instructionId); for (int i = 0, size = apiServiceDescriptors.size(); i < size; i++) { - BeamFnDataGrpcMultiplexer client = getClientFor(apiServiceDescriptors.get(i)); + BeamFnDataGrpcMultiplexer client = getMultiplexer(apiServiceDescriptors.get(i), dataStreamId); client.unregisterConsumer(instructionId); } } @@ -91,25 +119,32 @@ public void poisonInstructionId(String instructionId) { } @Override - public BeamFnDataOutboundAggregator createOutboundAggregator( - ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes) { - return new BeamFnDataOutboundAggregator( - options, - processBundleRequestIdSupplier, - getClientFor(apiServiceDescriptor).getOutboundObserver(), - collectElementsIfNoFlushes); + public StreamObserver getOutboundObserver( + ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + return getMultiplexer(apiServiceDescriptor, dataStreamId).getOutboundObserver(); } - private BeamFnDataGrpcMultiplexer getClientFor( - Endpoints.ApiServiceDescriptor apiServiceDescriptor) { + private BeamFnDataGrpcMultiplexer getMultiplexer( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + MultiplexerKey key = new MultiplexerKey(apiServiceDescriptor, dataStreamId); return multiplexerCache.computeIfAbsent( - apiServiceDescriptor, - (Endpoints.ApiServiceDescriptor descriptor) -> - new BeamFnDataGrpcMultiplexer( - descriptor, - outboundObserverFactory, - BeamFnDataGrpc.newStub(channelFactory.apply(apiServiceDescriptor))::data)); + key, + k -> { + OutboundObserverFactory.BasicFactory baseOutboundObserverFactory = + inboundObserver -> { + BeamFnDataGrpc.BeamFnDataStub stub = + BeamFnDataGrpc.newStub(channelFactory.apply(apiServiceDescriptor)); + if (dataStreamId != null && !dataStreamId.isEmpty()) { + Metadata headers = new Metadata(); + headers.put( + Metadata.Key.of("data_stream_id", Metadata.ASCII_STRING_MARSHALLER), + dataStreamId); + stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers)); + } + return stub.data(inboundObserver); + }; + return new BeamFnDataGrpcMultiplexer( + apiServiceDescriptor, outboundObserverFactory, baseOutboundObserverFactory); + }); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index 70a894e7b375..6c8abdbb3adf 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -32,7 +32,6 @@ import java.util.Map; import java.util.ServiceLoader; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -107,19 +106,29 @@ public void setUp() { MockitoAnnotations.initMocks(this); } - private BeamFnDataOutboundAggregator createRecordingAggregator( - Map>> output, Supplier bundleId) { + @Test + public void testReuseForMultipleBundles() throws Exception { + AtomicReference bundleId = new AtomicReference<>("0"); + String localInputId = "inputPC"; + RunnerApi.PTransform pTransform = + RemoteGrpcPortWrite.writeToPort(localInputId, PORT_SPEC).toPTransform(); + + List> output0 = new ArrayList<>(); + List> output1 = new ArrayList<>(); + Map aggregators = new HashMap<>(); + PipelineOptions options = PipelineOptionsFactory.create(); options.as(ExperimentalOptions.class).setExperiments(Arrays.asList("data_buffer_size_limit=0")); - return new BeamFnDataOutboundAggregator( - options, - bundleId, + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + + Map>> outputs = ImmutableMap.of("0", output0, "1", output1); + StreamObserver observer = new StreamObserver() { @Override public void onNext(Elements elements) { for (Data data : elements.getDataList()) { try { - output.get(bundleId.get()).add(WIRE_CODER.decode(data.getData().newInput())); + outputs.get(bundleId.get()).add(WIRE_CODER.decode(data.getData().newInput())); } catch (IOException e) { throw new RuntimeException("Failed to decode output."); } @@ -131,22 +140,9 @@ public void onError(Throwable throwable) {} @Override public void onCompleted() {} - }, - false); - } + }; - @Test - public void testReuseForMultipleBundles() throws Exception { - AtomicReference bundleId = new AtomicReference<>("0"); - String localInputId = "inputPC"; - RunnerApi.PTransform pTransform = - RemoteGrpcPortWrite.writeToPort(localInputId, PORT_SPEC).toPTransform(); - - List> output0 = new ArrayList<>(); - List> output1 = new ArrayList<>(); - Map aggregators = new HashMap<>(); - BeamFnDataOutboundAggregator aggregator = - createRecordingAggregator(ImmutableMap.of("0", output0, "1", output1), bundleId::get); + aggregator.prepareForInstruction(bundleId.get(), observer); aggregators.put(PORT_SPEC.getApiServiceDescriptor(), aggregator); PTransformRunnerFactoryTestContext context = @@ -179,6 +175,7 @@ public void testReuseForMultipleBundles() throws Exception { // Process for bundle id 1 bundleId.set("1"); + aggregator.prepareForInstruction(bundleId.get(), observer); pCollectionConsumer.accept(valueInGlobalWindow("GHI")); pCollectionConsumer.accept(valueInGlobalWindow("JKL")); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index 50a2fec0b5a2..2aa555e83cd5 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -839,7 +839,7 @@ private class TestBeamFnDataOutboundAggregator extends BeamFnDataOutboundAggrega private Supplier processBundleRequestIdSupplier; public TestBeamFnDataOutboundAggregator(Supplier bundleIdSupplier) { - super(PipelineOptionsFactory.create(), bundleIdSupplier, null, false); + super(PipelineOptionsFactory.create(), false); this.timers = new HashMap<>(); this.dataOutput = new HashMap<>(); this.processBundleRequestIdSupplier = bundleIdSupplier; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 7b4387738a4c..51e49953b406 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.construction.Timer; import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.joda.time.Instant; /** @@ -74,6 +75,7 @@ public static Builder builder(String pTransformId, RunnerApi.PTransform pTransfo @Override public void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver) { throw new UnsupportedOperationException("Unexpected call during test."); @@ -81,15 +83,15 @@ public void registerReceiver( @Override public void unregisterReceiver( - String instructionId, List apiServiceDescriptors) { + String instructionId, + String dataStreamId, + List apiServiceDescriptors) { throw new UnsupportedOperationException("Unexpected call during test."); } @Override - public BeamFnDataOutboundAggregator createOutboundAggregator( - ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes) { + public StreamObserver getOutboundObserver( + ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { throw new UnsupportedOperationException("Unexpected call during test."); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 47f85178b0a1..c03f82726740 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -37,7 +37,6 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; @@ -1071,28 +1070,20 @@ private ProcessBundleHandler setupProcessBundleHandlerForSimpleRecordingDoFn( dataOutput.add(input.getValue()); })); - Mockito.doAnswer( - (invocation) -> - new BeamFnDataOutboundAggregator( - PipelineOptionsFactory.create(), - invocation.getArgument(1), - new StreamObserver() { - @Override - public void onNext(Elements elements) { - for (Timers timer : elements.getTimersList()) { - timerOutput.addAll(elements.getTimersList()); - } - } + Mockito.when(beamFnDataClient.getOutboundObserver(any(), any())) + .thenReturn( + new StreamObserver() { + @Override + public void onNext(Elements elements) { + timerOutput.addAll(elements.getTimersList()); + } - @Override - public void onError(Throwable throwable) {} + @Override + public void onError(Throwable throwable) {} - @Override - public void onCompleted() {} - }, - invocation.getArgument(2))) - .when(beamFnDataClient) - .createOutboundAggregator(any(), any(), anyBoolean()); + @Override + public void onCompleted() {} + }); return new ProcessBundleHandler( PipelineOptionsFactory.create(), @@ -1409,7 +1400,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws (invocation) -> { String instructionId = invocation.getArgument(0, String.class); CloseableFnDataReceiver data = - invocation.getArgument(2, CloseableFnDataReceiver.class); + invocation.getArgument(3, CloseableFnDataReceiver.class); data.accept( BeamFnApi.Elements.newBuilder() .addData( @@ -1421,7 +1412,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws return null; }) .when(beamFnDataClient) - .registerReceiver(any(), any(), any()); + .registerReceiver(any(), any(), any(), any()); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -1451,8 +1442,8 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws .build()); // Ensure that we unregister during successful processing - verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); - verify(beamFnDataClient).unregisterReceiver(eq("instructionId"), any()); + verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any(), any()); + verify(beamFnDataClient).unregisterReceiver(eq("instructionId"), any(), any()); verifyNoMoreInteractions(beamFnDataClient); } @@ -1475,7 +1466,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { StringUtf8Coder.of().encode("A", encodedData); String instructionId = invocation.getArgument(0, String.class); CloseableFnDataReceiver data = - invocation.getArgument(2, CloseableFnDataReceiver.class); + invocation.getArgument(3, CloseableFnDataReceiver.class); data.accept( BeamFnApi.Elements.newBuilder() .addData( @@ -1489,7 +1480,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { return null; }) .when(beamFnDataClient) - .registerReceiver(any(), any(), any()); + .registerReceiver(any(), any(), any(), any()); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -1526,7 +1517,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { .build())); // Ensure that we unregister during successful processing - verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any(), any()); verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 15f83f2582c7..9d9efa0b9c49 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,8 +23,8 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import java.util.Arrays; import java.util.Collection; @@ -49,6 +49,7 @@ import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.WindowedValue; @@ -169,7 +170,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -183,7 +183,7 @@ public StreamObserver data( Collections.emptyList()); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observerA); waitForClientToConnect.await(); outboundServerObserver.get().onNext(ELEMENTS_A_1); @@ -193,7 +193,7 @@ public StreamObserver data( Thread.sleep(100); clientFactory.registerReceiver( - INSTRUCTION_ID_B, Arrays.asList(apiServiceDescriptor), observerB); + INSTRUCTION_ID_B, "", Arrays.asList(apiServiceDescriptor), observerB); // Show that out of order stream completion can occur. observerB.awaitCompletion(); @@ -245,7 +245,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -262,7 +261,7 @@ public StreamObserver data( Collections.emptyList()); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observer); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observer); waitForClientToConnect.await(); @@ -270,12 +269,8 @@ public StreamObserver data( outboundServerObserver.get().onNext(ELEMENTS_A_1); outboundServerObserver.get().onNext(ELEMENTS_A_2); - try { - observer.awaitCompletion(); - fail("Expected channel to fail"); - } catch (Exception e) { - assertEquals(exceptionToThrow, e); - } + Exception e = assertThrows(Exception.class, observer::awaitCompletion); + assertEquals(exceptionToThrow, e); // The server should not have received any values assertThat(inboundServerValues, empty()); // The consumer should have only been invoked once @@ -321,7 +316,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -347,7 +341,7 @@ public StreamObserver data( }); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observerA); waitForClientToConnect.await(); outboundServerObserver.get().onNext(ELEMENTS_B_1); @@ -358,11 +352,9 @@ public StreamObserver data( assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); clientFactory.poisonInstructionId(INSTRUCTION_ID_A); - try { - future.get(); - fail(); // We expect the awaitCompletion to fail due to closing. - } catch (Exception ignored) { - } + // We expect the awaitCompletion to fail due to closing. + // Expected. + assertThrows(Exception.class, future::get); outboundServerObserver.get().onNext(ELEMENTS_A_2); @@ -404,16 +396,15 @@ public StreamObserver data( ManagedChannel channel = InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + PipelineOptions options = + PipelineOptionsFactory.fromArgs("--experiments=data_buffer_size_limit=20").create(); BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.fromArgs( - new String[] {"--experiments=data_buffer_size_limit=20"}) - .create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); - BeamFnDataOutboundAggregator aggregator = - clientFactory.createOutboundAggregator( - apiServiceDescriptor, () -> INSTRUCTION_ID_A, false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + INSTRUCTION_ID_A, clientFactory.getOutboundObserver(apiServiceDescriptor, "")); FnDataReceiver> fnDataReceiver = aggregator.registerOutputDataLocation(TRANSFORM_ID_A, CODER); fnDataReceiver.accept(valueInGlobalWindow("ABC"));