From b05a72732f1701c42134ca77cf93cc22afca70ef Mon Sep 17 00:00:00 2001 From: Mika Naylor Date: Mon, 1 Jun 2026 21:42:55 +0200 Subject: [PATCH] [FLINK-39378][table] Add support for Context, timers and on_time to ProcessTableFunctionTestHarness --- .../docs/dev/table/functions/ptfs.md | 112 ++- docs/content/docs/dev/table/functions/ptfs.md | 112 ++- .../runtime/functions/InvocationContext.java | 58 ++ .../ProcessTableFunctionTestHarness.java | 937 ++++++++++++++---- .../runtime/functions/ResolvedMethod.java | 57 ++ .../runtime/functions/RowStateConverter.java | 2 +- .../StructuredTypeStateConverter.java | 2 +- .../functions/TestHarnessCallContext.java | 9 +- .../functions/TestHarnessTableSemantics.java | 15 +- .../functions/TestHarnessTimerManager.java | 179 ++++ .../flink/table/runtime/functions/Timer.java | 118 +++ .../ProcessTableFunctionTestHarnessTest.java | 857 +++++++++++++++- 12 files changed, 2266 insertions(+), 192 deletions(-) create mode 100644 flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/InvocationContext.java create mode 100644 flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ResolvedMethod.java create mode 100644 flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTimerManager.java create mode 100644 flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/Timer.java diff --git a/docs/content.zh/docs/dev/table/functions/ptfs.md b/docs/content.zh/docs/dev/table/functions/ptfs.md index bf24d7ced0681..6b5c1b697822a 100644 --- a/docs/content.zh/docs/dev/table/functions/ptfs.md +++ b/docs/content.zh/docs/dev/table/functions/ptfs.md @@ -2465,6 +2465,115 @@ void testStateMutation() throws Exception { {{< /tab >}} {{< /tabs >}} +#### Testing with Timers and Context + +The harness supports the `Context` parameter, timer registration via `TimeContext`, and `onTimer` +callbacks. Use `.withOnTimeColumn()` to configure the event time column and `.setWatermark()` to +advance watermarks and fire eligible timers. + +{{< tabs "timer-testing" >}} +{{< tab "Java" >}} +```java +// A PTF that registers a named timer 5 seconds after each event, and emits when it fires. +@DataTypeHint("ROW") +public class TimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + String name = input.getFieldAs("name"); + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("timeout-" + name, timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered-" + name)); + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("timer-fired-" + ctx.currentTimer())); + } +} + +@Test +void testTimerRegistrationAndFiring() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument("input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", "Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + // Verify the timer was registered + assertThat(harness.getPendingTimers()).hasSize(1); + assertThat(harness.getPendingTimers().get(0).getName()).isEqualTo("timeout-Alice"); + + // Advance watermark past the timer's timestamp to fire it + harness.clearOutput(); + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "timer-fired-timeout-Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + + assertThat(harness.getPendingTimers()).isEmpty(); + assertThat(harness.getFiredTimers()).hasSize(1); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**Timers with State**: State persisted during `eval()` is accessible in `onTimer()`: + +{{< tabs "timer-state-testing" >}} +{{< tab "Java" >}} +```java +@DataTypeHint("ROW") +public class TimerWithStatePTF extends ProcessTableFunction { + public static class CounterState { + public long count = 0L; + } + + public void eval( + Context ctx, + @StateHint CounterState state, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + state.count++; + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("check", timeCtx.time().plus(Duration.ofSeconds(5))); + } + + public void onTimer(OnTimerContext ctx, @StateHint CounterState state) { + collect(Row.of("count=" + state.count)); + } +} + +@Test +void testTimerWithState() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerWithStatePTF.class) + .withTableArgument("input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "count=3", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + #### Optional Partitioning For PTFs with `OPTIONAL_PARTITION_BY`, you can omit `withPartitionBy()` during harness setup. The @@ -2582,8 +2691,5 @@ void testPOJO() throws Exception { ### PTF Features Unsupported by the TestHarness -- `Context` parameter -- Timers (`onTimer`) -- `on_time` / `rowtime` - Update traits (`SUPPORTS_UPDATES`, `REQUIRE_UPDATE_BEFORE`) - State TTL (state is supported but TTL expiration is not yet implemented) diff --git a/docs/content/docs/dev/table/functions/ptfs.md b/docs/content/docs/dev/table/functions/ptfs.md index 2abf34fd06630..a20d2347296ed 100644 --- a/docs/content/docs/dev/table/functions/ptfs.md +++ b/docs/content/docs/dev/table/functions/ptfs.md @@ -2468,6 +2468,115 @@ void testStateMutation() throws Exception { {{< /tab >}} {{< /tabs >}} +#### Testing with Timers and Context + +The harness supports the `Context` parameter, timer registration via `TimeContext`, and `onTimer` +callbacks. Use `.withOnTimeColumn()` to configure the event time column and `.setWatermark()` to +advance watermarks and fire eligible timers. + +{{< tabs "timer-testing" >}} +{{< tab "Java" >}} +```java +// A PTF that registers a named timer 5 seconds after each event, and emits when it fires. +@DataTypeHint("ROW") +public class TimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + String name = input.getFieldAs("name"); + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("timeout-" + name, timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered-" + name)); + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("timer-fired-" + ctx.currentTimer())); + } +} + +@Test +void testTimerRegistrationAndFiring() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument("input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", "Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + // Verify the timer was registered + assertThat(harness.getPendingTimers()).hasSize(1); + assertThat(harness.getPendingTimers().get(0).getName()).isEqualTo("timeout-Alice"); + + // Advance watermark past the timer's timestamp to fire it + harness.clearOutput(); + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "timer-fired-timeout-Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + + assertThat(harness.getPendingTimers()).isEmpty(); + assertThat(harness.getFiredTimers()).hasSize(1); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +**Timers with State**: State persisted during `eval()` is accessible in `onTimer()`: + +{{< tabs "timer-state-testing" >}} +{{< tab "Java" >}} +```java +@DataTypeHint("ROW") +public class TimerWithStatePTF extends ProcessTableFunction { + public static class CounterState { + public long count = 0L; + } + + public void eval( + Context ctx, + @StateHint CounterState state, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + state.count++; + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("check", timeCtx.time().plus(Duration.ofSeconds(5))); + } + + public void onTimer(OnTimerContext ctx, @StateHint CounterState state) { + collect(Row.of("count=" + state.count)); + } +} + +@Test +void testTimerWithState() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerWithStatePTF.class) + .withTableArgument("input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "count=3", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } +} +``` +{{< /tab >}} +{{< /tabs >}} + #### Optional Partitioning For PTFs with `OPTIONAL_PARTITION_BY`, you can omit `withPartitionBy()` during harness setup. The @@ -2585,8 +2694,5 @@ void testPOJO() throws Exception { ### PTF Features Unsupported by the TestHarness -- `Context` parameter -- Timers (`onTimer`) -- `on_time` / `rowtime` - Update traits (`SUPPORTS_UPDATES`, `REQUIRE_UPDATE_BEFORE`) - State TTL (state is supported but TTL expiration is not yet implemented) diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/InvocationContext.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/InvocationContext.java new file mode 100644 index 0000000000000..af54cc6856507 --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/InvocationContext.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.types.Row; + +import javax.annotation.Nullable; + +/** Captures the per-invocation state for an eval() or onTimer() call in the test harness. */ +class InvocationContext { + final Row partitionKey; + @Nullable final Row row; + @Nullable final String tableArgumentName; + @Nullable final Timer firingTimer; + + private InvocationContext( + Row partitionKey, + @Nullable Row row, + @Nullable String tableArgumentName, + @Nullable Timer firingTimer) { + this.partitionKey = partitionKey; + this.row = row; + this.tableArgumentName = tableArgumentName; + this.firingTimer = firingTimer; + } + + static InvocationContext forEval(Row partitionKey, Row row, String tableArgumentName) { + return new InvocationContext(partitionKey, row, tableArgumentName, null); + } + + static InvocationContext forTimer(Timer timer) { + return new InvocationContext(timer.partitionKey, null, null, timer); + } + + boolean isTimerInvocation() { + return firingTimer != null; + } + + boolean isEvalInvocation() { + return tableArgumentName != null; + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java index 5a6eecb50f880..cc16317ecfe0d 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java @@ -20,7 +20,12 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.annotation.ArgumentTrait; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableRuntimeException; +import org.apache.flink.table.api.dataview.ListView; +import org.apache.flink.table.api.dataview.MapView; import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.conversion.DataStructureConverter; import org.apache.flink.table.data.conversion.DataStructureConverters; @@ -31,6 +36,7 @@ import org.apache.flink.table.types.AbstractDataType; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.FieldsDataType; +import org.apache.flink.table.types.extraction.ExtractionUtils; import org.apache.flink.table.types.inference.StateTypeStrategy; import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; @@ -43,16 +49,22 @@ import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.StructuredType; import org.apache.flink.table.types.utils.TypeConversions; +import org.apache.flink.table.utils.DateTimeUtils; import org.apache.flink.types.Row; import org.apache.flink.types.RowKind; import org.apache.flink.util.Collector; +import javax.annotation.Nullable; + import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.time.Duration; +import java.time.Instant; +import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; @@ -61,6 +73,7 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -111,7 +124,8 @@ private static class TableArgumentConverters { private final TestHarnessStateManager stateManager; private final String defaultTableArgument; - private final Method evalMethod; + private final ResolvedMethod eval; + @Nullable private final ResolvedMethod onTimer; private final List arguments; private final Map argumentsByName; @@ -121,22 +135,35 @@ private static class TableArgumentConverters { private final Map argumentConverters; private final DataStructureConverter harnessOutputConverter; + private final TestHarnessTimerManager timerManager; + @Nullable private final String onTimeColumnName; + + @Nullable private final Class rowtimeConversionClass; + + @Nullable private InvocationContext currentInvocation; + private ProcessTableFunctionTestHarness( ProcessTableFunction function, FunctionContext functionContext, - Method evalMethod, + ResolvedMethod eval, + @Nullable ResolvedMethod onTimer, List arguments, Map argumentConverters, DataStructureConverter harnessOutputConverter, - TestHarnessStateManager stateManager) + TestHarnessStateManager stateManager, + TestHarnessTimerManager timerManager, + @Nullable String onTimeColumnName) throws Exception { this.function = function; this.functionContext = functionContext; - this.evalMethod = evalMethod; + this.eval = eval; + this.onTimer = onTimer; this.arguments = arguments; this.argumentConverters = argumentConverters; this.harnessOutputConverter = harnessOutputConverter; this.stateManager = stateManager; + this.timerManager = timerManager; + this.onTimeColumnName = onTimeColumnName; this.output = new ArrayList<>(); this.collector = new HarnessCollector(); this.isOpen = false; @@ -159,6 +186,8 @@ private ProcessTableFunctionTestHarness( this.isSingleTableFunction = false; } + this.rowtimeConversionClass = resolveRowtimeConversionClass(tableArguments); + openFunction(); } @@ -256,7 +285,7 @@ public void process() throws Exception { Object[] args = arguments.stream().map(arg -> ((ScalarArgumentInfo) arg).value).toArray(); try { - evalMethod.invoke(function, args); + eval.method.invoke(function, args); } catch (InvocationTargetException e) { handleEvalInvocationException( "Exception occurred during scalar-only PTF eval() invocation.\n", args, e); @@ -303,6 +332,388 @@ public void clearStateForKey(String stateName, Row partitionKey) { stateManager.clearStateForKey(stateName, partitionKey); } + // ------------------------------------------------------------------------- + // Watermark & Timer API + // ------------------------------------------------------------------------- + + /** + * Sets the watermark for all tables to the given {@link LocalDateTime} and fires eligible + * timers. + */ + public void setWatermark(LocalDateTime watermark) throws Exception { + checkNotNull(watermark, "watermark must not be null"); + setWatermarkMillis(DateTimeUtils.toTimestampMillis(watermark)); + } + + /** Sets the watermark for all tables to the given {@link Instant} and fires eligible timers. */ + public void setWatermark(Instant watermark) throws Exception { + checkNotNull(watermark, "watermark must not be null"); + setWatermarkMillis(watermark.toEpochMilli()); + } + + /** + * Sets the watermark for a specific table to the given {@link LocalDateTime} and fires eligible + * timers. + */ + public void setWatermarkForTable(String tableArgument, LocalDateTime watermark) + throws Exception { + checkNotNull(watermark, "watermark must not be null"); + setWatermarkForTableMillis(tableArgument, DateTimeUtils.toTimestampMillis(watermark)); + } + + /** + * Sets the watermark for a specific table to the given {@link Instant} and fires eligible + * timers. + */ + public void setWatermarkForTable(String tableArgument, Instant watermark) throws Exception { + checkNotNull(watermark, "watermark must not be null"); + setWatermarkForTableMillis(tableArgument, watermark.toEpochMilli()); + } + + /** Returns all timers (both pending and fired), sorted by timestamp then name. */ + public List getTimers() { + return Stream.concat( + timerManager.getPendingTimers().stream(), + timerManager.getFiredTimers().stream()) + .sorted() + .collect(Collectors.toList()); + } + + /** Returns all pending (not yet fired) timers, sorted by timestamp then name. */ + public List getPendingTimers() { + return timerManager.getPendingTimers(); + } + + /** Returns all pending timers with the given name. */ + public List getPendingTimers(String timerName) { + return timerManager.getPendingTimers().stream() + .filter(t -> timerName.equals(t.getName())) + .collect(Collectors.toList()); + } + + /** Returns all timers that have fired, in the order they fired. */ + public List getFiredTimers() { + return timerManager.getFiredTimers(); + } + + /** Returns all fired timers with the given name. */ + public List getFiredTimers(String timerName) { + return timerManager.getFiredTimers().stream() + .filter(t -> timerName.equals(t.getName())) + .collect(Collectors.toList()); + } + + /** Clears the fired timer history. */ + public void clearFiredTimers() { + timerManager.clearFiredTimers(); + } + + private void setWatermarkMillis(long millis) throws Exception { + checkState(isOpen, "Harness is not open"); + for (TableArgumentInfo tableArg : ArgumentInfo.filterTableArguments(arguments)) { + timerManager.setTableWatermark(tableArg.name, millis); + } + timerManager.updateGlobalWatermarkAndFireTimers(this::fireTimer); + } + + private void setWatermarkForTableMillis(String tableArgument, long millis) throws Exception { + checkState(isOpen, "Harness is not open"); + checkNotNull(tableArgument, "tableArgument must not be null"); + checkArgument( + argumentsByName.get(tableArgument) instanceof TableArgumentInfo, + "Unknown or non-table argument: %s", + tableArgument); + timerManager.setTableWatermark(tableArgument, millis); + timerManager.updateGlobalWatermarkAndFireTimers(this::fireTimer); + } + + private void fireTimer(Timer timer) throws Exception { + if (onTimer == null) { + throw new IllegalStateException( + "Timer fired but no onTimer() method is defined in " + + function.getClass().getSimpleName()); + } + + currentInvocation = InvocationContext.forTimer(timer); + + try { + Map stateMap = stateManager.loadStateForKey(timer.partitionKey); + + List stateArgs = ArgumentInfo.filterStateArguments(arguments); + Object[] methodArgs = new Object[stateArgs.size()]; + for (int i = 0; i < stateArgs.size(); i++) { + methodArgs[i] = stateMap.get(stateArgs.get(i).name); + } + + onTimer.invoke(function, new TestOnTimerContext(stateMap), methodArgs); + stateManager.updateStateForKey(timer.partitionKey, stateMap); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception) { + throw (Exception) cause; + } + throw new RuntimeException("onTimer() invocation failed", e); + } finally { + currentInvocation = null; + } + } + + // ------------------------------------------------------------------------- + // Context implementations + // ------------------------------------------------------------------------- + + private class TestContext implements ProcessTableFunction.Context { + final Map stateMap; + + TestContext(Map stateMap) { + this.stateMap = stateMap; + } + + @Override + public ProcessTableFunction.TimeContext timeContext( + Class conversionClass) { + return new TestTimeContext<>(conversionClass); + } + + @Override + public TableSemantics tableSemanticsFor(String argName) { + ArgumentInfo argInfo = argumentsByName.get(argName); + if (argInfo == null) { + throw new IllegalArgumentException( + String.format( + "Argument '%s' not found. Available arguments: %s", + argName, argumentsByName.keySet())); + } + if (!(argInfo instanceof TableArgumentInfo)) { + throw new IllegalArgumentException( + String.format( + "Argument '%s' is not a table argument (type: %s)", + argName, argInfo.getClass().getSimpleName())); + } + TableArgumentInfo tableArg = (TableArgumentInfo) argInfo; + int[] partitionIndices = getPartitionColumnIndices(tableArg); + int timeColumn = + onTimeColumnName != null + ? getFieldNames(tableArg.dataType).indexOf(onTimeColumnName) + : -1; + return new TestHarnessTableSemantics(tableArg.dataType, partitionIndices, timeColumn); + } + + @Override + public void clearState(String stateName) { + stateMap.remove(stateName); + } + + @Override + public void clearAllState() { + stateMap.clear(); + } + + @Override + public void clearAllTimers() { + timerManager.clearAll(currentInvocation.partitionKey); + } + + @Override + public void clearAll() { + stateMap.clear(); + timerManager.clearAll(currentInvocation.partitionKey); + } + + @Override + public ChangelogMode getChangelogMode() { + return ChangelogMode.insertOnly(); + } + } + + private class TestTimeContext implements ProcessTableFunction.TimeContext { + private final Class conversionClass; + + TestTimeContext(Class conversionClass) { + this.conversionClass = conversionClass; + } + + @Override + public TimeType time() { + InvocationContext ctx = currentInvocation; + if (ctx.isTimerInvocation()) { + return fromMillis(ctx.firingTimer.timestamp); + } + if (ctx.isEvalInvocation() && onTimeColumnName != null) { + ArgumentInfo argInfo = argumentsByName.get(ctx.tableArgumentName); + if (argInfo instanceof TableArgumentInfo) { + TableArgumentInfo tableArg = (TableArgumentInfo) argInfo; + if (!getFieldNames(tableArg.dataType).contains(onTimeColumnName)) { + return null; + } + } + Object timeValue = ctx.row.getField(onTimeColumnName); + if (timeValue == null) { + return null; + } + return fromMillis(toMillis(timeValue)); + } + return null; + } + + @Override + public TimeType tableWatermark() { + InvocationContext ctx = currentInvocation; + if (!ctx.isEvalInvocation()) { + return null; + } + Long wm = timerManager.getWatermarkForTable(ctx.tableArgumentName); + return wm != null ? fromMillis(wm) : null; + } + + @Override + public TimeType currentWatermark() { + Long wm = timerManager.getGlobalWatermark(); + return wm != null ? fromMillis(wm) : null; + } + + @Override + public void registerOnTime(String name, TimeType time) { + checkTimersEnabled(); + checkNotNull(name, "Timer name must not be null"); + checkNotNull(time, "Timer timestamp must not be null"); + timerManager.register(currentInvocation.partitionKey, toMillis(time), name); + } + + @Override + public void registerOnTime(TimeType time) { + checkTimersEnabled(); + checkNotNull(time, "Timer timestamp must not be null"); + timerManager.register(currentInvocation.partitionKey, toMillis(time), null); + } + + @Override + public void clearTimer(String name) { + checkNotNull(name, "Timer name must not be null"); + timerManager.clearByName(currentInvocation.partitionKey, name); + } + + @Override + public void clearTimer(TimeType time) { + checkNotNull(time, "Timer timestamp must not be null"); + timerManager.clearByTimestamp(currentInvocation.partitionKey, toMillis(time)); + } + + @Override + public void clearAllTimers() { + timerManager.clearAll(currentInvocation.partitionKey); + } + + private void checkTimersEnabled() { + boolean enabled = + ArgumentInfo.filterTableArguments(arguments).stream() + .anyMatch( + t -> + t.isSetSemantic + && t.prependStrategy + != OutputPrependStrategy.ALL_COLUMNS); + if (!enabled) { + throw new TableRuntimeException( + "Timers are not supported in the current PTF declaration. " + + "Note that only PTFs that take set semantic tables support timers. " + + "Also timers are not available for advanced traits such as " + + "supporting pass-through columns or updates."); + } + } + + private TimeType fromMillis(long millis) { + return convertFromMillis(millis, conversionClass); + } + + private long toMillis(Object time) { + if (time instanceof Long) { + return (Long) time; + } else if (time instanceof Instant) { + return ((Instant) time).toEpochMilli(); + } else if (time instanceof LocalDateTime) { + return DateTimeUtils.toTimestampMillis((LocalDateTime) time); + } else if (time instanceof java.sql.Timestamp) { + return ((java.sql.Timestamp) time).getTime(); + } + throw new IllegalArgumentException( + "Unsupported time type: " + time.getClass().getSimpleName()); + } + } + + private class TestOnTimerContext extends TestContext + implements ProcessTableFunction.OnTimerContext { + TestOnTimerContext(Map stateMap) { + super(stateMap); + } + + @Override + public String currentTimer() { + if (currentInvocation.isTimerInvocation()) { + return currentInvocation.firingTimer.getName(); + } + return null; + } + } + + private static int[] getPartitionColumnIndices(TableArgumentInfo arg) { + if (arg.partitionColumnNames == null || arg.partitionColumnNames.length == 0) { + return new int[0]; + } + List fieldNames = getFieldNames(arg.dataType); + int[] indices = new int[arg.partitionColumnNames.length]; + for (int i = 0; i < arg.partitionColumnNames.length; i++) { + String colName = arg.partitionColumnNames[i]; + int index = fieldNames.indexOf(colName); + if (index < 0) { + throw new IllegalStateException( + "Partition column '" + + colName + + "' not found in table argument. " + + "Available fields: " + + fieldNames); + } + indices[i] = index; + } + return indices; + } + + @Nullable + private Class resolveRowtimeConversionClass(List tableArguments) { + if (onTimeColumnName == null) { + return null; + } + for (TableArgumentInfo tableArg : tableArguments) { + List fieldNames = getFieldNames(tableArg.dataType); + int idx = fieldNames.indexOf(onTimeColumnName); + if (idx >= 0) { + return DataType.getFields(tableArg.dataType) + .get(idx) + .getDataType() + .getConversionClass(); + } + } + throw new IllegalStateException( + "Could not resolve rowtime conversion class for column: " + onTimeColumnName); + } + + private Object rowtimeFromMillis(long millis) { + return convertFromMillis(millis, rowtimeConversionClass); + } + + @SuppressWarnings("unchecked") + private static T convertFromMillis(long millis, Class targetClass) { + if (targetClass == Long.class || targetClass == long.class) { + return (T) Long.valueOf(millis); + } else if (targetClass == Instant.class) { + return (T) Instant.ofEpochMilli(millis); + } else if (targetClass == LocalDateTime.class) { + return (T) DateTimeUtils.toLocalDateTime(millis); + } else if (targetClass == java.sql.Timestamp.class) { + return (T) DateTimeUtils.toSQLTimestamp(millis); + } + throw new IllegalArgumentException("Unsupported time type: " + targetClass); + } + private void invokeEval(TableArgumentInfo activeTableArg, Row activeRow) throws Exception { TableArgumentConverters converters = argumentConverters.get(activeTableArg.name); @@ -310,31 +721,30 @@ private void invokeEval(TableArgumentInfo activeTableArg, Row activeRow) throws Row namedRow = (Row) converters.toNamedRow.toExternal(rowData); Object evalArgument = converters.toEvalArgument.toExternal(rowData); - collector.setContext(activeTableArg, namedRow); - Row partitionKey = extractPartitionKey(activeTableArg, namedRow); + currentInvocation = InvocationContext.forEval(partitionKey, namedRow, activeTableArg.name); + Map stateMap = stateManager.loadStateForKey(partitionKey); - Object[] args = new Object[arguments.size()]; + Object[] methodArgs = new Object[arguments.size()]; int i = 0; - for (ArgumentInfo arg : arguments) { if (arg instanceof StateArgumentInfo) { - args[i++] = stateMap.get(arg.name); + methodArgs[i++] = stateMap.get(arg.name); } else if (arg instanceof TableArgumentInfo) { TableArgumentInfo tableArg = (TableArgumentInfo) arg; if (tableArg.name.equals(activeTableArg.name)) { - args[i++] = evalArgument; + methodArgs[i++] = evalArgument; } else { - args[i++] = null; + methodArgs[i++] = null; } } else if (arg instanceof ScalarArgumentInfo) { - args[i++] = ((ScalarArgumentInfo) arg).value; + methodArgs[i++] = ((ScalarArgumentInfo) arg).value; } } try { - evalMethod.invoke(function, args); + eval.invoke(function, new TestContext(stateMap), methodArgs); stateManager.updateStateForKey(partitionKey, stateMap); } catch (InvocationTargetException e) { String partitionInfo = @@ -348,7 +758,9 @@ private void invokeEval(TableArgumentInfo activeTableArg, Row activeRow) throws String.format( "Exception occurred during PTF eval() while processing table argument '%s'%s.\n", activeTableArg.name, partitionInfo); - handleEvalInvocationException(contextMessage, args, e); + handleEvalInvocationException(contextMessage, methodArgs, e); + } finally { + currentInvocation = null; } } @@ -364,36 +776,28 @@ private Row extractPartitionKey(TableArgumentInfo tableArg, Row row) { /** Collector implementation that stores output in the harness. */ private class HarnessCollector implements Collector { - private ArgumentInfo activeTableArg; - private Row activeRow; - - void setContext(ArgumentInfo tableArg, Row row) { - this.activeTableArg = tableArg; - this.activeRow = row; - } @Override public void collect(OUT record) { OUT finalRecord; - if (activeTableArg == null || !(activeTableArg instanceof TableArgumentInfo)) { - finalRecord = record; - } else { - TableArgumentInfo tableArg = (TableArgumentInfo) activeTableArg; - switch (tableArg.prependStrategy) { - case ALL_COLUMNS: - finalRecord = prependAllColumns(record); - break; - case PARTITION_KEYS: - finalRecord = prependPartitionKeys(record); - break; - case NONE: - finalRecord = record; - break; - default: - throw new IllegalStateException( - "Unknown prepend strategy: " + tableArg.prependStrategy); - } + OutputPrependStrategy strategy = resolvePrependStrategy(); + switch (strategy) { + case ALL_COLUMNS: + finalRecord = prependAllColumns(record); + break; + case PARTITION_KEYS: + finalRecord = prependPartitionKeys(record); + break; + case NONE: + finalRecord = record; + break; + default: + throw new IllegalStateException("Unknown prepend strategy: " + strategy); + } + + if (onTimeColumnName != null) { + finalRecord = appendRowtime(finalRecord); } // After prepending, round-trip through converter to ensure output has proper @@ -402,6 +806,23 @@ public void collect(OUT record) { output.add(structuredRecord); } + private OutputPrependStrategy resolvePrependStrategy() { + InvocationContext ctx = currentInvocation; + if (ctx == null) { + return OutputPrependStrategy.NONE; + } + if (ctx.isTimerInvocation()) { + return OutputPrependStrategy.PARTITION_KEYS; + } + if (ctx.isEvalInvocation()) { + ArgumentInfo argInfo = argumentsByName.get(ctx.tableArgumentName); + if (argInfo instanceof TableArgumentInfo) { + return ((TableArgumentInfo) argInfo).prependStrategy; + } + } + return OutputPrependStrategy.NONE; + } + @SuppressWarnings("unchecked") private OUT applyOutputConverter(OUT record) { if (record instanceof Row) { @@ -412,6 +833,32 @@ private OUT applyOutputConverter(OUT record) { return record; } + @SuppressWarnings("unchecked") + private OUT appendRowtime(OUT ptfOutput) { + if (!(ptfOutput instanceof Row)) { + throw new IllegalStateException( + "Cannot append rowtime to non-Row output type: " + ptfOutput.getClass()); + } + return (OUT) appendField((Row) ptfOutput, resolveRowtimeValue()); + } + + private Object resolveRowtimeValue() { + InvocationContext ctx = currentInvocation; + if (ctx.isTimerInvocation()) { + return rowtimeFromMillis(ctx.firingTimer.timestamp); + } + if (ctx.isEvalInvocation()) { + ArgumentInfo argInfo = argumentsByName.get(ctx.tableArgumentName); + if (argInfo instanceof TableArgumentInfo) { + TableArgumentInfo tableArg = (TableArgumentInfo) argInfo; + if (getFieldNames(tableArg.dataType).contains(onTimeColumnName)) { + return ctx.row.getField(onTimeColumnName); + } + } + } + return null; + } + @SuppressWarnings("unchecked") private OUT prependPartitionKeys(OUT ptfOutput) { if (!(ptfOutput instanceof Row)) { @@ -421,6 +868,7 @@ private OUT prependPartitionKeys(OUT ptfOutput) { } Row ptfRow = (Row) ptfOutput; + Row partitionKey = currentInvocation.partitionKey; int totalPartitionKeyCount = 0; for (ArgumentInfo arg : arguments) { @@ -437,21 +885,13 @@ private OUT prependPartitionKeys(OUT ptfOutput) { Row result = new Row(ptfRow.getKind(), totalArity); - TableArgumentInfo activeTableInfo = (TableArgumentInfo) activeTableArg; - Object[] partitionKeyValues = new Object[activeTableInfo.partitionColumnNames.length]; - for (int i = 0; i < activeTableInfo.partitionColumnNames.length; i++) { - String columnName = activeTableInfo.partitionColumnNames[i]; - int columnIndex = getFieldIndex(activeTableInfo.dataType, columnName); - partitionKeyValues[i] = activeRow.getField(columnIndex); - } - int resultIndex = 0; for (ArgumentInfo arg : arguments) { if (arg instanceof TableArgumentInfo) { TableArgumentInfo tableArg = (TableArgumentInfo) arg; if (tableArg.isSetSemantic && tableArg.partitionColumnNames != null) { for (int i = 0; i < tableArg.partitionColumnNames.length; i++) { - result.setField(resultIndex++, partitionKeyValues[i]); + result.setField(resultIndex++, partitionKey.getField(i)); } } } @@ -464,17 +904,6 @@ private OUT prependPartitionKeys(OUT ptfOutput) { return (OUT) result; } - /** Helper to get field index by name from a DataType. */ - private int getFieldIndex(DataType dataType, String fieldName) { - List fieldNames = getFieldNames(dataType); - int index = fieldNames.indexOf(fieldName); - if (index < 0) { - throw new IllegalStateException( - String.format("Field '%s' not found in type %s", fieldName, dataType)); - } - return index; - } - @SuppressWarnings("unchecked") private OUT prependAllColumns(OUT ptfOutput) { if (!(ptfOutput instanceof Row)) { @@ -483,14 +912,15 @@ private OUT prependAllColumns(OUT ptfOutput) { } Row ptfRow = (Row) ptfOutput; - int inputArity = activeRow.getArity(); + Row inputRow = currentInvocation.row; + int inputArity = inputRow.getArity(); int ptfOutputArity = ptfRow.getArity(); int totalArity = inputArity + ptfOutputArity; Row result = new Row(ptfRow.getKind(), totalArity); for (int i = 0; i < inputArity; i++) { - result.setField(i, activeRow.getField(i)); + result.setField(i, inputRow.getField(i)); } for (int i = 0; i < ptfOutputArity; i++) { @@ -504,30 +934,34 @@ private OUT prependAllColumns(OUT ptfOutput) { public void close() {} } + private static Row appendField(Row row, Object value) { + Row result = new Row(row.getArity() + 1); + for (int i = 0; i < row.getArity(); i++) { + result.setField(i, row.getField(i)); + } + result.setField(row.getArity(), value); + return result; + } + /** Extracts field names from RowType or StructuredType. */ private static List getFieldNames(DataType dataType) { LogicalType logicalType = dataType.getLogicalType(); - List fieldNames = new ArrayList<>(); - if (logicalType instanceof RowType) { - RowType rowType = (RowType) logicalType; - for (RowType.RowField field : rowType.getFields()) { - fieldNames.add(field.getName()); - } + return ((RowType) logicalType) + .getFields().stream() + .map(RowType.RowField::getName) + .collect(Collectors.toList()); } else if (logicalType instanceof StructuredType) { - StructuredType structuredType = (StructuredType) logicalType; - for (StructuredType.StructuredAttribute attr : structuredType.getAttributes()) { - fieldNames.add(attr.getName()); - } - } else { - throw new IllegalStateException( - String.format( - "Unsupported data type: %s. " - + "Only Row and structured types are supported.", - dataType)); + return ((StructuredType) logicalType) + .getAttributes().stream() + .map(StructuredType.StructuredAttribute::getName) + .collect(Collectors.toList()); } - - return fieldNames; + throw new IllegalStateException( + String.format( + "Unsupported data type: %s. " + + "Only Row and structured types are supported.", + dataType)); } /** @@ -543,6 +977,9 @@ public static class Builder { private final Map tableArgs = new HashMap<>(); private final Map partitionConfigs = new HashMap<>(); private final Map stateArgs = new HashMap<>(); + @Nullable private String onTimeColumnName = null; + private final Map initialWatermarks = new HashMap<>(); + @Nullable private Long initialWatermarkForAll = null; private Builder(Class> functionClass) { this.functionClass = checkNotNull(functionClass, "functionClass must not be null"); @@ -655,6 +1092,56 @@ public Builder withPartitionBy(String argumentName, String... columnNames) return this; } + // --------------------------------------------------------------------- + // Timer & Watermark Configuration + // --------------------------------------------------------------------- + + /** + * Configures the on-time column name for the function. + * + * @param columnName The column that carries event time + */ + public Builder withOnTimeColumn(String columnName) { + checkNotNull(columnName, "columnName must not be null"); + this.onTimeColumnName = columnName; + return this; + } + + // --------------------------------------------------------------------- + // Watermark + // --------------------------------------------------------------------- + + /** Sets the initial watermark for all table arguments. */ + public Builder withInitialWatermark(LocalDateTime watermark) { + checkNotNull(watermark, "watermark must not be null"); + this.initialWatermarkForAll = DateTimeUtils.toTimestampMillis(watermark); + return this; + } + + /** Sets the initial watermark for all table arguments. */ + public Builder withInitialWatermark(Instant watermark) { + checkNotNull(watermark, "watermark must not be null"); + this.initialWatermarkForAll = watermark.toEpochMilli(); + return this; + } + + /** Sets the initial watermark for a specific table argument. */ + public Builder withInitialWatermarkForTable( + String tableArgument, LocalDateTime watermark) { + checkNotNull(tableArgument, "tableArgument must not be null"); + checkNotNull(watermark, "watermark must not be null"); + initialWatermarks.put(tableArgument, DateTimeUtils.toTimestampMillis(watermark)); + return this; + } + + /** Sets the initial watermark for a specific table argument. */ + public Builder withInitialWatermarkForTable(String tableArgument, Instant watermark) { + checkNotNull(tableArgument, "tableArgument must not be null"); + checkNotNull(watermark, "watermark must not be null"); + initialWatermarks.put(tableArgument, watermark.toEpochMilli()); + return this; + } + // --------------------------------------------------------------------- // Build // --------------------------------------------------------------------- @@ -683,9 +1170,17 @@ public ProcessTableFunctionTestHarness build() throws Exception { FunctionContext functionContext = new FunctionContext(null, classLoader, null); - Method evalMethod = findEvalMethod(); - - validateEvalMethodSupported(evalMethod, arguments); + ResolvedMethod eval = + ResolvedMethod.of( + findEvalMethod(functionClass), ProcessTableFunction.Context.class); + Method onTimerMethod = findOnTimerMethod(functionClass, arguments); + ResolvedMethod onTimer = + onTimerMethod != null + ? ResolvedMethod.of( + onTimerMethod, ProcessTableFunction.OnTimerContext.class) + : null; + + validateEvalMethodSupported(eval, arguments); validatePartitionConsistency(arguments); validateInitialStateKeys(arguments); @@ -711,25 +1206,67 @@ public ProcessTableFunctionTestHarness build() throws Exception { // Extract table arguments for output type derivation // SystemTypeInference needs table semantics for pass-through column deduplication - List tableArgs = ArgumentInfo.filterTableArguments(arguments); + List tableArgInfos = ArgumentInfo.filterTableArguments(arguments); // Derive output schema using SystemTypeInference DataType derivedOutputType = deriveOutputTypeFromSystemInference( - function, dataTypeFactory, systemTypeInference, tableArgs); + function, + dataTypeFactory, + systemTypeInference, + arguments, + tableArgInfos); // Create output converter for PTF emissions DataStructureConverter harnessOutputConverter = createPTFOutputConverter(derivedOutputType); + // Validate onTimeColumn configuration + if (onTimeColumnName != null) { + boolean foundInAnyTable = + tableArgInfos.stream() + .anyMatch( + t -> getFieldNames(t.dataType).contains(onTimeColumnName)); + checkArgument( + foundInAnyTable, + "withOnTimeColumn references column '%s' which does not exist in any " + + "table argument. Available table arguments and their columns: %s", + onTimeColumnName, + tableArgInfos.stream() + .collect( + Collectors.toMap( + t -> t.name, t -> getFieldNames(t.dataType)))); + } + + TestHarnessTimerManager timerManager = new TestHarnessTimerManager(); + if (initialWatermarkForAll != null) { + for (TableArgumentInfo tableArg : tableArgInfos) { + timerManager.setTableWatermark(tableArg.name, initialWatermarkForAll); + } + } + Set tableArgNames = + tableArgInfos.stream().map(t -> t.name).collect(Collectors.toSet()); + for (Map.Entry entry : initialWatermarks.entrySet()) { + checkArgument( + tableArgNames.contains(entry.getKey()), + "withInitialWatermarkForTable references unknown table argument '%s'. " + + "Known table arguments: %s", + entry.getKey(), + tableArgNames); + timerManager.setTableWatermark(entry.getKey(), entry.getValue()); + } + return new ProcessTableFunctionTestHarness<>( function, functionContext, - evalMethod, + eval, + onTimer, arguments, argumentConverters, harnessOutputConverter, - stateManager); + stateManager, + timerManager, + onTimeColumnName); } /** @@ -803,7 +1340,7 @@ private void createConverters( } } - private Method findEvalMethod() throws NoSuchMethodException { + private static Method findEvalMethod(Class functionClass) throws NoSuchMethodException { Method[] methods = functionClass.getMethods(); Method evalMethod = null; int evalMethodCount = 0; @@ -823,49 +1360,78 @@ private Method findEvalMethod() throws NoSuchMethodException { "Multiple eval() methods found in " + functionClass.getSimpleName() + ". ProcessTableFunction must have exactly one eval() method."); - } else { - return evalMethod; } - } - /** - * Validates that the eval() method doesn't use unsupported features. Temporary, until - * context is supported. - */ - private void validateEvalMethodSupported(Method evalMethod, List arguments) { - Parameter[] parameters = evalMethod.getParameters(); + return evalMethod; + } - for (int i = 0; i < parameters.length; i++) { - Parameter param = parameters[i]; - Class paramType = param.getType(); + @Nullable + private static Method findOnTimerMethod( + Class functionClass, List arguments) { + List candidates = ExtractionUtils.collectMethods(functionClass, "onTimer"); + if (candidates.isEmpty()) { + return null; + } - if (ProcessTableFunction.Context.class.isAssignableFrom(paramType)) { - throw new IllegalStateException( - String.format( - "ProcessTableFunctionTestHarness does not yet support Context parameters. " - + "Found Context parameter at position %d in eval() method. ", - i)); + Class[] stateClasses = + ArgumentInfo.filterStateArguments(arguments).stream() + .map(s -> s.dataType.getConversionClass()) + .toArray(Class[]::new); + + // Try with OnTimerContext first, then without — mirrors the code generator + Class[] withContext = new Class[stateClasses.length + 1]; + withContext[0] = ProcessTableFunction.OnTimerContext.class; + System.arraycopy(stateClasses, 0, withContext, 1, stateClasses.length); + + for (Class[] signature : new Class[][] {withContext, stateClasses}) { + Optional match = + candidates.stream() + .filter( + m -> + ExtractionUtils.isInvokable( + ExtractionUtils.Autoboxing.JVM, + m, + signature)) + .findFirst(); + if (match.isPresent()) { + return match.get(); } } - if (parameters.length != arguments.size()) { + throw new IllegalStateException( + String.format( + "Found %d onTimer() method(s) in %s but none match the expected " + + "signature: optional OnTimerContext followed by state entries %s.", + candidates.size(), + functionClass.getSimpleName(), + Arrays.toString(stateClasses))); + } + + private void validateEvalMethodSupported( + ResolvedMethod eval, List arguments) { + Parameter[] parameters = eval.method.getParameters(); + + int expectedParamCount = arguments.size() + (eval.takesContext ? 1 : 0); + if (parameters.length != expectedParamCount) { long stateCount = ArgumentInfo.filterStateArguments(arguments).size(); long nonStateCount = arguments.size() - stateCount; throw new IllegalStateException( String.format( "Parameter count mismatch: eval() has %d parameters but expected %d " - + "(%d state + %d table/scalar arguments). " + + "(%d state + %d table/scalar arguments%s). " + "eval() signature: %s. " + "This may indicate missing @ArgumentHint or @StateHint annotations.", parameters.length, - arguments.size(), + expectedParamCount, stateCount, nonStateCount, - evalMethod)); + eval.takesContext ? " + Context" : "", + eval.method)); } + int argOffset = eval.takesContext ? 1 : 0; for (int i = 0; i < arguments.size(); i++) { - Parameter param = parameters[i]; + Parameter param = parameters[i + argOffset]; Class paramType = param.getType(); ArgumentInfo arg = arguments.get(i); @@ -877,7 +1443,7 @@ private void validateEvalMethodSupported(Method evalMethod, List a "Type mismatch for scalar argument '%s' at position %d: " + "eval() parameter expects %s but provided value is %s", arg.name, - i, + i + argOffset, paramType.getName(), value.getClass().getName())); } @@ -1164,40 +1730,45 @@ private List extractAndValidateTypeInference( } /** Creates appropriate StateConverter for the given state data type. */ - private StateConverter createStateConverter(DataType stateDataType, ClassLoader classLoader) - throws Exception { - LogicalType logicalType = stateDataType.getLogicalType(); + private StateConverter createStateConverter( + DataType stateDataType, ClassLoader classLoader) { + DataType resolvedType = + ListView.class.isAssignableFrom(stateDataType.getConversionClass()) + || MapView.class.isAssignableFrom( + stateDataType.getConversionClass()) + ? stateDataType.getChildren().get(0) + : stateDataType; + + LogicalType logicalType = resolvedType.getLogicalType(); if (logicalType instanceof ArrayType) { - ArrayType arrayType = (ArrayType) logicalType; - DataType elementType = stateDataType.getChildren().get(0); + DataType elementType = resolvedType.getChildren().get(0); DataStructureConverter elementConverter = DataStructureConverters.getConverter(elementType); elementConverter.open(classLoader); - return new ListViewStateConverter(arrayType, elementConverter); + return new ListViewStateConverter((ArrayType) logicalType, elementConverter); } else if (logicalType instanceof MapType) { - MapType mapType = (MapType) logicalType; - DataType keyType = stateDataType.getChildren().get(0); - DataType valueType = stateDataType.getChildren().get(1); + DataType keyType = resolvedType.getChildren().get(0); + DataType valueType = resolvedType.getChildren().get(1); DataStructureConverter keyConverter = DataStructureConverters.getConverter(keyType); DataStructureConverter valueConverter = DataStructureConverters.getConverter(valueType); keyConverter.open(classLoader); valueConverter.open(classLoader); - return new MapViewStateConverter(mapType, keyConverter, valueConverter); + return new MapViewStateConverter( + (MapType) logicalType, keyConverter, valueConverter); } else if (logicalType instanceof RowType) { - RowType rowType = (RowType) logicalType; DataStructureConverter converter = - DataStructureConverters.getConverter(stateDataType); + DataStructureConverters.getConverter(resolvedType); converter.open(classLoader); - return new RowStateConverter(converter, rowType); + return new RowStateConverter((RowType) logicalType, converter); } else { DataStructureConverter converter = - DataStructureConverters.getConverter(stateDataType); + DataStructureConverters.getConverter(resolvedType); converter.open(classLoader); - Class stateClass = stateDataType.getConversionClass(); - return new StructuredTypeStateConverter(converter, stateClass); + return new StructuredTypeStateConverter( + resolvedType.getConversionClass(), converter); } } @@ -1386,25 +1957,76 @@ private ProcessTableFunction instantiateFunction() throws IllegalArgumentEx /** * Derives the output schema using SystemTypeInference, including field name deduplication. + * + *

SystemTypeInference's staticArgs list includes both user-declared arguments and + * system-injected arguments (on_time, uid). The CallContext we build must mirror this full + * list positionally — each index maps to a staticArg. User args get their resolved + * DataType; system args get placeholder types (DESCRIPTOR for on_time, STRING for uid). + * Table semantics are attached at the positions of table arguments so SystemTypeInference + * can perform pass-through column deduplication. */ private DataType deriveOutputTypeFromSystemInference( ProcessTableFunction function, DataTypeFactory dataTypeFactory, TypeInference systemTypeInference, - List arguments) { + List allArguments, + List tableArguments) { - List argumentDataTypes = new ArrayList<>(); - for (TableArgumentInfo arg : arguments) { - argumentDataTypes.add(arg.dataType); + Optional> staticArgsOpt = systemTypeInference.getStaticArguments(); + List staticArgs = + staticArgsOpt.orElseThrow( + () -> + new IllegalStateException( + "SystemTypeInference has no static arguments")); + + Map tableArgsByName = new HashMap<>(); + for (TableArgumentInfo tableArg : tableArguments) { + tableArgsByName.put(tableArg.name, tableArg); + } + Map allArgsByName = new HashMap<>(); + for (ArgumentInfo arg : allArguments) { + if (arg.name != null) { + allArgsByName.put(arg.name, arg); + } } + List argumentDataTypes = new ArrayList<>(); Map tableSemanticsMap = new HashMap<>(); - for (int i = 0; i < arguments.size(); i++) { - TableArgumentInfo arg = arguments.get(i); - int[] partitionIndices = getPartitionColumnIndices(arg); - TableSemantics semantics = - new TestHarnessTableSemantics(arg.dataType, partitionIndices); - tableSemanticsMap.put(i, semantics); + int onTimePos = -1; + + for (int i = 0; i < staticArgs.size(); i++) { + StaticArgument staticArg = staticArgs.get(i); + String argName = staticArg.getName(); + + if (SystemTypeInference.PROCESS_TABLE_FUNCTION_ARG_ON_TIME.equals(argName)) { + argumentDataTypes.add(DataTypes.DESCRIPTOR()); + onTimePos = i; + } else if (SystemTypeInference.PROCESS_TABLE_FUNCTION_ARG_UID.equals(argName)) { + argumentDataTypes.add(DataTypes.STRING()); + } else { + ArgumentInfo argInfo = allArgsByName.get(argName); + if (argInfo != null) { + argumentDataTypes.add(argInfo.dataType); + } else { + argumentDataTypes.add(DataTypes.NULL()); + } + + TableArgumentInfo tableArg = tableArgsByName.get(argName); + if (tableArg != null) { + int[] partitionIndices = getPartitionColumnIndices(tableArg); + int timeColumnIndex = -1; + if (onTimeColumnName != null) { + int idx = getFieldNames(tableArg.dataType).indexOf(onTimeColumnName); + if (idx >= 0) { + timeColumnIndex = idx; + } + } + tableSemanticsMap.put( + i, + new TestHarnessTableSemantics( + tableArg.dataType, partitionIndices, timeColumnIndex)); + } + } } TestHarnessCallContext callContext = new TestHarnessCallContext(); @@ -1414,6 +2036,13 @@ private DataType deriveOutputTypeFromSystemInference( callContext.tableSemantics = tableSemanticsMap; callContext.name = function.getClass().getSimpleName(); + if (onTimePos >= 0 && onTimeColumnName != null) { + callContext.argumentValues.put( + onTimePos, + org.apache.flink.types.ColumnList.of( + Collections.singletonList(onTimeColumnName))); + } + TypeStrategy outputStrategy = systemTypeInference.getOutputTypeStrategy(); Optional outputTypeOpt = outputStrategy.inferType(callContext); @@ -1425,31 +2054,6 @@ private DataType deriveOutputTypeFromSystemInference( return outputTypeOpt.get(); } - - private int[] getPartitionColumnIndices(TableArgumentInfo arg) { - if (arg.partitionColumnNames == null || arg.partitionColumnNames.length == 0) { - return new int[0]; - } - - List fieldNames = getFieldNames(arg.dataType); - - int[] indices = new int[arg.partitionColumnNames.length]; - for (int i = 0; i < arg.partitionColumnNames.length; i++) { - String colName = arg.partitionColumnNames[i]; - int index = fieldNames.indexOf(colName); - if (index < 0) { - throw new IllegalStateException( - "Partition column '" - + colName - + "' not found in table argument. " - + "Available fields: " - + fieldNames); - } - indices[i] = index; - } - - return indices; - } } private enum OutputPrependStrategy { @@ -1459,16 +2063,17 @@ private enum OutputPrependStrategy { } private void handleEvalInvocationException( - String contextMessage, Object[] args, InvocationTargetException e) throws Exception { + String contextMessage, Object[] methodArgs, InvocationTargetException e) + throws Exception { Throwable cause = e.getCause(); StringBuilder details = new StringBuilder(); details.append(contextMessage); details.append("Expected parameter types: "); - details.append(Arrays.toString(evalMethod.getParameterTypes())); + details.append(Arrays.toString(eval.method.getParameterTypes())); details.append("\nActual arguments:\n"); for (int i = 0; i < arguments.size(); i++) { ArgumentInfo arg = arguments.get(i); - Object value = args[i]; + Object value = methodArgs[i]; details.append( String.format( " [%d] %s: %s (type: %s)\n", diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ResolvedMethod.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ResolvedMethod.java new file mode 100644 index 0000000000000..0a1954462c69b --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ResolvedMethod.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import javax.annotation.Nullable; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** + * A resolved PTF method (eval or onTimer) paired with whether its first parameter accepts a + * Context. + */ +class ResolvedMethod { + final Method method; + final boolean takesContext; + + static ResolvedMethod of(Method method, Class contextClass) { + boolean takesContext = + method.getParameterTypes().length > 0 + && contextClass.isAssignableFrom(method.getParameterTypes()[0]); + return new ResolvedMethod(method, takesContext); + } + + private ResolvedMethod(Method method, boolean takesContext) { + this.method = method; + this.takesContext = takesContext; + } + + void invoke(Object target, @Nullable Object context, Object[] methodArgs) + throws InvocationTargetException, IllegalAccessException { + if (takesContext) { + Object[] args = new Object[1 + methodArgs.length]; + args[0] = context; + System.arraycopy(methodArgs, 0, args, 1, methodArgs.length); + method.invoke(target, args); + } else { + method.invoke(target, methodArgs); + } + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java index f293edb6aae91..7f0de53de0f40 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java @@ -30,7 +30,7 @@ class RowStateConverter implements StateConverter { private final DataStructureConverter converter; private final RowType rowType; - RowStateConverter(DataStructureConverter converter, RowType rowType) { + RowStateConverter(RowType rowType, DataStructureConverter converter) { this.converter = converter; this.rowType = rowType; } diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java index 5599aef14a7ac..5ea6cdc5334fb 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java @@ -33,7 +33,7 @@ class StructuredTypeStateConverter implements StateConverter { private final Class pojoClass; StructuredTypeStateConverter( - DataStructureConverter converter, Class pojoClass) { + Class pojoClass, DataStructureConverter converter) { this.converter = converter; this.pojoClass = pojoClass; } diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessCallContext.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessCallContext.java index 62c99bc41a124..b5d4583362d69 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessCallContext.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessCallContext.java @@ -26,6 +26,7 @@ import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.CallContext; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -40,6 +41,7 @@ class TestHarnessCallContext implements CallContext { List argumentDataTypes; FunctionDefinition functionDefinition; Map tableSemantics; + Map argumentValues = new HashMap<>(); String name; @Override @@ -54,7 +56,7 @@ public FunctionDefinition getFunctionDefinition() { @Override public boolean isArgumentLiteral(int pos) { - return false; + return argumentValues.containsKey(pos); } @Override @@ -63,7 +65,12 @@ public boolean isArgumentNull(int pos) { } @Override + @SuppressWarnings("unchecked") public Optional getArgumentValue(int pos, Class clazz) { + Object value = argumentValues.get(pos); + if (value != null && clazz.isInstance(value)) { + return Optional.of((T) value); + } return Optional.empty(); } diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTableSemantics.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTableSemantics.java index 91edd9a059d69..a271070093a09 100644 --- a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTableSemantics.java +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTableSemantics.java @@ -33,16 +33,25 @@ class TestHarnessTableSemantics implements TableSemantics { private final DataType dataType; private final int[] partitionByColumns; private final List upsertKeyColumns; + private final int timeColumnIndex; TestHarnessTableSemantics(DataType dataType, int[] partitionByColumns) { - this(dataType, partitionByColumns, Collections.emptyList()); + this(dataType, partitionByColumns, Collections.emptyList(), -1); + } + + TestHarnessTableSemantics(DataType dataType, int[] partitionByColumns, int timeColumnIndex) { + this(dataType, partitionByColumns, Collections.emptyList(), timeColumnIndex); } TestHarnessTableSemantics( - DataType dataType, int[] partitionByColumns, List upsertKeyColumns) { + DataType dataType, + int[] partitionByColumns, + List upsertKeyColumns, + int timeColumnIndex) { this.dataType = dataType; this.partitionByColumns = partitionByColumns; this.upsertKeyColumns = upsertKeyColumns; + this.timeColumnIndex = timeColumnIndex; } @Override @@ -67,7 +76,7 @@ public TableSemantics.SortDirection[] orderByDirections() { @Override public int timeColumn() { - return -1; + return timeColumnIndex; } @Override diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTimerManager.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTimerManager.java new file mode 100644 index 0000000000000..25d1a7e70abf1 --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessTimerManager.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.types.Row; +import org.apache.flink.util.function.ThrowingConsumer; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Timer and watermark manager for {@link ProcessTableFunctionTestHarness}. + * + *

Handles timer registration, clearing, watermark tracking, and determining which timers are + * eligible to fire. + */ +@Internal +class TestHarnessTimerManager { + + private final Map> pendingTimersByPartition = new HashMap<>(); + private final List firedTimers = new ArrayList<>(); + private final Map watermarkByTable = new HashMap<>(); + @Nullable private Long globalWatermark; + + TestHarnessTimerManager() {} + + // ------------------------------------------------------------------------- + // Watermark + // ------------------------------------------------------------------------- + + void setTableWatermark(String tableName, long absoluteMillis) { + Long current = watermarkByTable.get(tableName); + if (current != null && absoluteMillis < current) { + throw new IllegalArgumentException( + String.format( + "Cannot move watermark backward for table '%s': current=%d, new=%d", + tableName, current, absoluteMillis)); + } + watermarkByTable.put(tableName, absoluteMillis); + } + + void updateGlobalWatermarkAndFireTimers(ThrowingConsumer firer) + throws Exception { + if (watermarkByTable.isEmpty()) { + return; + } + long newGlobalWatermark = Collections.min(watermarkByTable.values()); + if (globalWatermark != null && newGlobalWatermark < globalWatermark) { + throw new IllegalArgumentException( + String.format( + "Cannot move global watermark backward: current=%d, new=%d", + globalWatermark, newGlobalWatermark)); + } + globalWatermark = newGlobalWatermark; + fireEligibleTimers(newGlobalWatermark, firer); + } + + @Nullable + Long getGlobalWatermark() { + return globalWatermark; + } + + @Nullable + Long getWatermarkForTable(String tableName) { + return watermarkByTable.get(tableName); + } + + // ------------------------------------------------------------------------- + // Timer Registration + // ------------------------------------------------------------------------- + + void register(Row partitionKey, long timestampMillis, @Nullable String name) { + Set timerSet = + pendingTimersByPartition.computeIfAbsent(partitionKey, k -> new HashSet<>()); + + if (name != null) { + timerSet.removeIf(t -> name.equals(t.name)); + } + timerSet.add(new Timer(timestampMillis, name, partitionKey)); + } + + void clearByName(Row partitionKey, String name) { + Set timerSet = pendingTimersByPartition.get(partitionKey); + if (timerSet != null) { + timerSet.removeIf(t -> name.equals(t.name)); + } + } + + /** Clears unnamed timers matching the given timestamp. Named timers are not affected. */ + void clearByTimestamp(Row partitionKey, long timestamp) { + Set timerSet = pendingTimersByPartition.get(partitionKey); + if (timerSet != null) { + timerSet.removeIf(t -> t.name == null && t.timestamp == timestamp); + } + } + + void clearAll(Row partitionKey) { + pendingTimersByPartition.remove(partitionKey); + } + + // ------------------------------------------------------------------------- + // Timer Introspection + // ------------------------------------------------------------------------- + + List getPendingTimers() { + return pendingTimersByPartition.values().stream() + .flatMap(Collection::stream) + .sorted() + .collect(Collectors.toList()); + } + + List getFiredTimers() { + return new ArrayList<>(firedTimers); + } + + void clearFiredTimers() { + firedTimers.clear(); + } + + // ------------------------------------------------------------------------- + // Internal + // ------------------------------------------------------------------------- + + // Loop until no more eligible timers — handles cascading registrations + private void fireEligibleTimers(long watermark, ThrowingConsumer firer) + throws Exception { + while (true) { + List timersToFire = + pendingTimersByPartition.values().stream() + .flatMap(Collection::stream) + .filter(t -> t.timestamp <= watermark) + .sorted() + .collect(Collectors.toList()); + + if (timersToFire.isEmpty()) { + break; + } + + for (Timer timer : timersToFire) { + Set timerSet = + pendingTimersByPartition.getOrDefault( + timer.partitionKey, Collections.emptySet()); + + if (timerSet.contains(timer)) { + timerSet.remove(timer); + timer.markFired(); + firedTimers.add(timer); + firer.accept(timer); + } + } + } + } +} diff --git a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/Timer.java b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/Timer.java new file mode 100644 index 0000000000000..6e9c78a720c66 --- /dev/null +++ b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/Timer.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.types.Row; + +import javax.annotation.Nullable; + +import java.util.Objects; + +/** + * A timer registered by a {@link org.apache.flink.table.functions.ProcessTableFunction} during + * testing. + */ +@PublicEvolving +public class Timer implements Comparable { + + final long timestamp; + @Nullable final String name; + final Row partitionKey; + private boolean fired; + + Timer(long timestamp, @Nullable String name, Row partitionKey) { + this.timestamp = timestamp; + this.name = name; + this.partitionKey = partitionKey; + } + + public long getTimestamp() { + return timestamp; + } + + @Nullable + public String getName() { + return name; + } + + public Row getKey() { + return partitionKey; + } + + public boolean hasFired() { + return fired; + } + + void markFired() { + this.fired = true; + } + + /** + * Comparison of timers is done by timestamp first, then by name (unnamed timers sort after + * named ones), then by partition key, for a deterministic firing order. + */ + @Override + public int compareTo(Timer other) { + int cmp = Long.compare(this.timestamp, other.timestamp); + if (cmp != 0) { + return cmp; + } + if (this.name == null && other.name == null) { + return this.partitionKey.toString().compareTo(other.partitionKey.toString()); + } + if (this.name == null) { + return 1; + } + if (other.name == null) { + return -1; + } + cmp = this.name.compareTo(other.name); + if (cmp != 0) { + return cmp; + } + return this.partitionKey.toString().compareTo(other.partitionKey.toString()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Timer)) { + return false; + } + Timer that = (Timer) o; + return this.timestamp == that.timestamp + && Objects.equals(this.name, that.name) + && Objects.equals(this.partitionKey, that.partitionKey); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, name, partitionKey); + } + + @Override + public String toString() { + return String.format( + "Timer{timestamp=%d, name=%s, key=%s, fired=%s}", + timestamp, name != null ? "'" + name + "'" : "null", partitionKey, fired); + } +} diff --git a/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java b/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java index 88e1c8b32d84c..bc9bdaf4bbc0c 100644 --- a/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java +++ b/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java @@ -23,15 +23,20 @@ import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableRuntimeException; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.api.dataview.ListView; import org.apache.flink.table.api.dataview.MapView; +import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.functions.ProcessTableFunction; +import org.apache.flink.table.functions.TableSemantics; import org.apache.flink.types.Row; import org.apache.flink.types.RowKind; import org.junit.jupiter.api.Test; +import java.time.Duration; +import java.time.LocalDateTime; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -196,6 +201,20 @@ public int hashCode() { } } + public static class TimedEvent { + public String key; + + @DataTypeHint("TIMESTAMP(3)") + public LocalDateTime ts; + + public TimedEvent() {} + + public TimedEvent(String key, LocalDateTime ts) { + this.key = key; + this.ts = ts; + } + } + /** PTF for testing structured type inputs. */ @DataTypeHint("ROW") public static class UserPTF extends ProcessTableFunction { @@ -299,7 +318,6 @@ public void eval( } } - /** PTF with Context parameter - should be rejected by test harness. */ @DataTypeHint("ROW") public static class PTFWithContext extends ProcessTableFunction { public void eval(Context ctx, @ArgumentHint(ArgumentTrait.ROW_SEMANTIC_TABLE) Row input) { @@ -1171,19 +1189,14 @@ void testProcessElementForTableWithInvalidName() throws Exception { } @Test - void testContextParameterRejected() { - Exception exception = - assertThrows( - IllegalStateException.class, - () -> - ProcessTableFunctionTestHarness.ofClass(PTFWithContext.class) - .withTableArgument("input", DataTypes.of("ROW")) - .build()); - - assertThat(exception.getMessage()) - .contains("does not yet support Context parameters") - .contains("Context parameter") - .contains("position 0"); + void testContextParameterSupported() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PTFWithContext.class) + .withTableArgument("input", DataTypes.of("ROW")) + .build()) { + harness.processElement(Row.of(42)); + assertThat(harness.getOutput()).containsExactly(Row.of(42)); + } } @Test @@ -1728,4 +1741,820 @@ void testPartitionKeyValidationOnClearState() throws Exception { harness.close(); } + + // ------------------------------------------------------------------------- + // Timer PTFs + // ------------------------------------------------------------------------- + + @DataTypeHint("ROW") + public static class TimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + String name = input.getFieldAs("name"); + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("timeout-" + name, timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered-" + name)); + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("timer-fired-" + ctx.currentTimer())); + } + } + + @DataTypeHint("ROW") + public static class UnnamedTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime(timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered")); + } + + public void onTimer(OnTimerContext ctx) { + String timerName = ctx.currentTimer(); + collect(Row.of("fired-unnamed-" + (timerName == null ? "null" : timerName))); + } + } + + @DataTypeHint("ROW") + public static class TimerWithStatePTF extends ProcessTableFunction { + public static class CounterState { + public long count = 0L; + } + + public void eval( + Context ctx, + @StateHint CounterState state, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + state.count++; + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("check", timeCtx.time().plus(Duration.ofSeconds(5))); + } + + public void onTimer(OnTimerContext ctx, @StateHint CounterState state) { + collect(Row.of("count=" + state.count)); + } + } + + @DataTypeHint("ROW") + public static class PassThroughTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ + ArgumentTrait.SET_SEMANTIC_TABLE, + ArgumentTrait.PASS_COLUMNS_THROUGH, + ArgumentTrait.REQUIRE_ON_TIME + }) + Row input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("timer", timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered")); + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("fired")); + } + } + + @DataTypeHint("ROW") + public static class NoOnTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("timer", timeCtx.time().plus(Duration.ofSeconds(10))); + collect(Row.of("registered")); + } + } + + @DataTypeHint("ROW") + public static class MultipleOnTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("timer", timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered")); + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("fired-no-state")); + } + + public void onTimer(OnTimerContext ctx, @StateHint String state) { + collect(Row.of("fired-with-state")); + } + } + + @DataTypeHint("ROW") + public static class CascadingTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("first", timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("eval")); + } + + public void onTimer(OnTimerContext ctx) { + String name = ctx.currentTimer(); + collect(Row.of("fired-" + name)); + if ("first".equals(name)) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("second", timeCtx.time().minus(Duration.ofSeconds(2))); + } + } + } + + @DataTypeHint("ROW") + public static class MultiTableTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row leftTable, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row rightTable) { + if (leftTable != null) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("check", timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("left")); + } + if (rightTable != null) { + collect(Row.of("right")); + } + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("timer-fired")); + } + } + + @DataTypeHint("ROW") + public static class ContextClearStatePTF extends ProcessTableFunction { + public static class CounterState { + public long count = 0L; + } + + public void eval( + Context ctx, + @StateHint CounterState state, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + String action = input.getFieldAs("action"); + if ("increment".equals(action)) { + if (state == null) { + state = new CounterState(); + } + state.count++; + collect(Row.of("count=" + state.count)); + } else if ("clear-state".equals(action)) { + ctx.clearState("state"); + collect(Row.of("cleared")); + } else if ("clear-all-state".equals(action)) { + ctx.clearAllState(); + collect(Row.of("cleared-all")); + } else if ("clear-all".equals(action)) { + ctx.clearAll(); + collect(Row.of("cleared-everything")); + } else if ("register-timer".equals(action)) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("t", timeCtx.time().plus(Duration.ofHours(1))); + collect(Row.of("timer-registered")); + } + } + + public void onTimer(OnTimerContext ctx, @StateHint CounterState state) { + collect(Row.of("timer-fired")); + } + } + + @DataTypeHint("ROW") + public static class PojoTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + TimedEvent input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + timeCtx.registerOnTime("check", timeCtx.time().plus(Duration.ofSeconds(5))); + collect(Row.of("registered-" + input.key)); + } + + public void onTimer(OnTimerContext ctx) { + collect(Row.of("fired-" + ctx.currentTimer())); + } + } + + @DataTypeHint("ROW") + public static class ContextSemanticsIntrospectionPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + TableSemantics semantics = ctx.tableSemanticsFor("input"); + collect( + Row.of( + java.util.Arrays.toString(semantics.partitionByColumns()), + semantics.timeColumn(), + ctx.getChangelogMode().toString())); + } + } + + @DataTypeHint("ROW") + public static class ClearTimerPTF extends ProcessTableFunction { + public void eval( + Context ctx, + @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE, ArgumentTrait.REQUIRE_ON_TIME}) + Row input) { + TimeContext timeCtx = ctx.timeContext(LocalDateTime.class); + String action = input.getFieldAs("action"); + if ("register-named".equals(action)) { + timeCtx.registerOnTime("myTimer", timeCtx.time().plus(Duration.ofSeconds(5))); + } else if ("register-unnamed".equals(action)) { + timeCtx.registerOnTime(timeCtx.time().plus(Duration.ofSeconds(5))); + } else if ("clear-by-name".equals(action)) { + timeCtx.clearTimer("myTimer"); + } else if ("clear-by-timestamp".equals(action)) { + timeCtx.clearTimer(timeCtx.time().plus(Duration.ofSeconds(4))); + } else if ("clear-all".equals(action)) { + ctx.clearAllTimers(); + } + } + + public void onTimer(OnTimerContext ctx) { + collect( + Row.of( + "fired-" + + (ctx.currentTimer() != null + ? ctx.currentTimer() + : "unnamed"))); + } + } + + // ------------------------------------------------------------------------- + // Context Tests + // ------------------------------------------------------------------------- + + @Test + void testPerTableWatermarkTimerFiring() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(MultiTableTimerPTF.class) + .withTableArgument( + "leftTable", DataTypes.of("ROW")) + .withTableArgument( + "rightTable", + DataTypes.of("ROW")) + .withPartitionBy("leftTable", "partition") + .withPartitionBy("rightTable", "partition") + .withOnTimeColumn("ts") + .build()) { + + // Register a timer at t=5000 via the left table + harness.processElementForTable( + "leftTable", Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.clearOutput(); + + // Advance only left table past the timer — global watermark is still min(left, right) + // Since right has no watermark yet, global won't advance enough to fire + harness.setWatermarkForTable("rightTable", LocalDateTime.of(2025, 1, 1, 0, 0, 3)); + harness.setWatermarkForTable("leftTable", LocalDateTime.of(2025, 1, 1, 0, 0, 10)); + assertThat(harness.getFiredTimers()).isEmpty(); + + // Now advance right table past the timer — global watermark catches up, timer fires + harness.setWatermarkForTable("rightTable", LocalDateTime.of(2025, 1, 1, 0, 0, 10)); + assertThat(harness.getFiredTimers()).hasSize(1); + // Both tables' partition keys are prepended (one "P1" per table) + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "P1", + "timer-fired", + LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } + } + + @Test + void testContextClearState() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ContextClearStatePTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "count=1", LocalDateTime.of(2025, 1, 1, 0, 0, 1)), + Row.of("P1", "count=2", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + harness.clearOutput(); + + harness.processElement( + Row.of("P1", "clear-state", LocalDateTime.of(2025, 1, 1, 0, 0, 3))); + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 4))); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "cleared", LocalDateTime.of(2025, 1, 1, 0, 0, 3)), + Row.of("P1", "count=1", LocalDateTime.of(2025, 1, 1, 0, 0, 4))); + } + } + + @Test + void testContextClearAllState() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ContextClearStatePTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.clearOutput(); + + harness.processElement( + Row.of("P1", "clear-all-state", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 3))); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "cleared-all", LocalDateTime.of(2025, 1, 1, 0, 0, 2)), + Row.of("P1", "count=1", LocalDateTime.of(2025, 1, 1, 0, 0, 3))); + } + } + + @Test + void testContextClearAll() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ContextClearStatePTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement( + Row.of("P1", "register-timer", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + assertThat(harness.getPendingTimers()).hasSize(1); + harness.clearOutput(); + + harness.processElement( + Row.of("P1", "clear-all", LocalDateTime.of(2025, 1, 1, 0, 0, 3))); + assertThat(harness.getPendingTimers()).isEmpty(); + + harness.processElement( + Row.of("P1", "increment", LocalDateTime.of(2025, 1, 1, 0, 0, 4))); + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "cleared-everything", + LocalDateTime.of(2025, 1, 1, 0, 0, 3)), + Row.of("P1", "count=1", LocalDateTime.of(2025, 1, 1, 0, 0, 4))); + } + } + + @Test + void testContextTableSemanticsAndChangelogMode() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ContextSemanticsIntrospectionPTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + assertThat(harness.getOutput()).hasSize(1); + Row result = harness.getOutput().get(0); + // Field 0 is prepended partition key + assertThat(result.getFieldAs(0).toString()).isEqualTo("P1"); + // partitionByColumns: [0] (index of "partition" column) + assertThat(result.getFieldAs(1).toString()).isEqualTo("[0]"); + // timeColumn: 1 (index of "ts" column) + assertThat((int) result.getFieldAs(2)).isEqualTo(1); + // changelogMode: insert-only + assertThat(result.getFieldAs(3).toString()) + .isEqualTo(ChangelogMode.insertOnly().toString()); + } + } + + @Test + void testPojoInputWithOnTimeColumn() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PojoTimerPTF.class) + .withTableArgument("input") + .withPartitionBy("input", "key") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "registered-P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getPendingTimers()).hasSize(1); + harness.clearOutput(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 10)); + assertThat(harness.getFiredTimers()).hasSize(1); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "fired-check", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } + } + + // ------------------------------------------------------------------------- + // Timer Tests + // ------------------------------------------------------------------------- + + @Test + void testNamedTimerRegistrationAndFiring() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument( + "input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", "Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "registered-Alice", + LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.clearOutput(); + + assertThat(harness.getPendingTimers()).hasSize(1); + assertThat(harness.getPendingTimers().get(0).getName()).isEqualTo("timeout-Alice"); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "timer-fired-timeout-Alice", + LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + + assertThat(harness.getPendingTimers()).isEmpty(); + assertThat(harness.getFiredTimers()).hasSize(1); + } + } + + @Test + void testUnnamedTimerFiring() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(UnnamedTimerPTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.clearOutput(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "fired-unnamed-null", + LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } + } + + @Test + void testTimerReplacementSemantics() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument( + "input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + // First eval: event at T+1s → timer at T+6s + // Second eval: event at T+6s → replaces timer to T+11s + harness.processElement(Row.of("P1", "Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", "Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + harness.clearOutput(); + + assertThat(harness.getPendingTimers()).hasSize(1); + + // Watermark at T+7s — timer at T+11s should NOT fire + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()).isEmpty(); + + // Watermark at T+12s — timer at T+11s should fire + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 12)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "timer-fired-timeout-Alice", + LocalDateTime.of(2025, 1, 1, 0, 0, 11))); + } + } + + @Test + void testMultipleTimersFiringOrder() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument( + "input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", "Charlie", LocalDateTime.of(2025, 1, 1, 0, 0, 5))); + harness.processElement(Row.of("P1", "Alice", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", "Bob", LocalDateTime.of(2025, 1, 1, 0, 0, 3))); + harness.clearOutput(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 11)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of( + "P1", + "timer-fired-timeout-Alice", + LocalDateTime.of(2025, 1, 1, 0, 0, 6)), + Row.of( + "P1", + "timer-fired-timeout-Bob", + LocalDateTime.of(2025, 1, 1, 0, 0, 8)), + Row.of( + "P1", + "timer-fired-timeout-Charlie", + LocalDateTime.of(2025, 1, 1, 0, 0, 10))); + } + } + + @Test + void testStateAccessInOnTimer() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerWithStatePTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "count=3", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } + } + + @Test + void testWatermarkCannotMoveBackward() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument( + "input", + DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .withInitialWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 5)) + .build()) { + + assertThrows( + IllegalArgumentException.class, + () -> harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + } + } + + @Test + void testPassThroughWithTimersIsRejected() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(PassThroughTimerPTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + assertThrows( + TableRuntimeException.class, + () -> + harness.processElement( + Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1)))); + } + } + + @Test + void testNoOnTimerMethodThrows() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(NoOnTimerPTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + + assertThrows( + IllegalStateException.class, + () -> harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 1))); + } + } + + @Test + void testMultipleOnTimerMethods() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(MultipleOnTimerPTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.clearOutput(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()).hasSize(1); + } + } + + @Test + void testCascadingTimerFromOnTimer() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(CascadingTimerPTF.class) + .withTableArgument( + "input", DataTypes.of("ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement(Row.of("P1", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getOutput()) + .containsExactly(Row.of("P1", "eval", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.clearOutput(); + + // "first" fires, registers "second" at 3000ms (past time), + // "second" should cascade and fire in the same advance + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 0, 7)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "fired-first", LocalDateTime.of(2025, 1, 1, 0, 0, 6)), + Row.of("P1", "fired-second", LocalDateTime.of(2025, 1, 1, 0, 0, 4))); + assertThat(harness.getPendingTimers()).isEmpty(); + assertThat(harness.getFiredTimers()).hasSize(2); + } + } + + @Test + void testClearTimerByName() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ClearTimerPTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "register-named", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getPendingTimers()).hasSize(1); + + harness.processElement( + Row.of("P1", "clear-by-name", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + assertThat(harness.getPendingTimers()).isEmpty(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 1)); + assertThat(harness.getOutput()).isEmpty(); + } + } + + @Test + void testClearTimerByTimestamp() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ClearTimerPTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "register-unnamed", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getPendingTimers()).hasSize(1); + + harness.processElement( + Row.of("P1", "clear-by-timestamp", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + assertThat(harness.getPendingTimers()).isEmpty(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 1)); + assertThat(harness.getOutput()).isEmpty(); + } + } + + @Test + void testClearTimerByTimestampDoesNotClearNamedTimers() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ClearTimerPTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "register-named", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + assertThat(harness.getPendingTimers()).hasSize(1); + + harness.processElement( + Row.of("P1", "clear-by-timestamp", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + assertThat(harness.getPendingTimers()).hasSize(1); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 1)); + assertThat(harness.getOutput()) + .containsExactly( + Row.of("P1", "fired-myTimer", LocalDateTime.of(2025, 1, 1, 0, 0, 6))); + } + } + + @Test + void testClearAllTimers() throws Exception { + try (ProcessTableFunctionTestHarness harness = + ProcessTableFunctionTestHarness.ofClass(ClearTimerPTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .build()) { + + harness.processElement( + Row.of("P1", "register-named", LocalDateTime.of(2025, 1, 1, 0, 0, 1))); + harness.processElement( + Row.of("P1", "register-unnamed", LocalDateTime.of(2025, 1, 1, 0, 0, 2))); + assertThat(harness.getPendingTimers()).hasSize(2); + + harness.processElement( + Row.of("P1", "clear-all", LocalDateTime.of(2025, 1, 1, 0, 0, 3))); + assertThat(harness.getPendingTimers()).isEmpty(); + + harness.setWatermark(LocalDateTime.of(2025, 1, 1, 0, 1)); + assertThat(harness.getOutput()).isEmpty(); + } + } + + @Test + void testInitialWatermarkForUnknownTableThrows() { + assertThrows( + IllegalArgumentException.class, + () -> + ProcessTableFunctionTestHarness.ofClass(TimerPTF.class) + .withTableArgument( + "input", + DataTypes.of( + "ROW")) + .withPartitionBy("input", "partition") + .withOnTimeColumn("ts") + .withInitialWatermarkForTable( + "nonexistent", LocalDateTime.of(2025, 1, 1, 0, 0)) + .build()); + } }