From c14f979cec4cfb1317f03cdca3bfd571e50d30b9 Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Tue, 12 May 2026 15:42:14 +0800 Subject: [PATCH 01/11] [FLINK-39568] YAML built-in AI Function and model providers overhaul --- .../parser/YamlPipelineDefinitionParser.java | 40 ++- .../YamlPipelineDefinitionParserTest.java | 254 ++++++++++++++++-- .../definitions/pipeline-definition-full.yaml | 15 +- .../flink/cdc/common/model/AiModelClient.java | 42 +++ .../common/model/AiModelClientFactory.java | 89 ++++++ .../flink/cdc/common/model/ModelContext.java | 36 +++ .../model/abilities/SupportsEmbedding.java | 32 +++ .../abilities/SupportsTextGeneration.java | 35 +++ .../model/AiModelClientFactoryTest.java | 166 ++++++++++++ flink-cdc-composer/pom.xml | 4 +- .../cdc/composer/definition/ModelDef.java | 56 ++-- .../composer/flink/FlinkPipelineComposer.java | 1 - .../flink/translator/TransformTranslator.java | 106 ++++++-- .../flink/FlinkPipelineAiFunctionITCase.java | 212 +++++++++++++++ .../flink/FlinkPipelineUdfITCase.java | 75 ------ ...link.cdc.common.model.AiModelClientFactory | 16 ++ .../flink-cdc-pipeline-e2e-tests/pom.xml | 10 + .../pipeline/tests/AiFunctionE2eITCase.java | 123 +++++++++ .../src/test/resources/rules/malformed.yaml | 17 -- .../flink-cdc-pipeline-model-dummy/pom.xml | 32 +++ .../cdc/models/dummy/DummyModelClient.java | 62 +++++ .../models/dummy/DummyModelClientFactory.java | 49 ++++ ...link.cdc.common.model.AiModelClientFactory | 16 ++ flink-cdc-pipeline-model/pom.xml | 58 +--- .../flink/cdc/runtime/model/ModelOptions.java | 50 ---- .../cdc/runtime/model/OpenAIChatModel.java | 97 ------- .../runtime/model/OpenAIEmbeddingModel.java | 109 -------- .../runtime/model/TestOpenAIChatModel.java | 42 --- .../model/TestOpenAIEmbeddingModel.java | 45 ---- flink-cdc-runtime/pom.xml | 6 - .../runtime/ai/AiEmbeddingFunctionDef.java | 50 ++++ .../cdc/runtime/ai/AiTextFunctionDef.java | 83 ++++++ 
.../runtime/functions/impl/AiFunctions.java | 98 +++++++ .../transform/PostTransformOperator.java | 43 ++- .../PostTransformOperatorBuilder.java | 10 +- .../transform/ProjectionColumnProcessor.java | 33 ++- .../TransformExpressionCompiler.java | 22 ++ .../transform/TransformFilterProcessor.java | 19 +- .../TransformProjectionProcessor.java | 9 +- .../cdc/runtime/parser/JaninoCompiler.java | 57 ++-- .../cdc/runtime/parser/TransformParser.java | 7 +- .../metadata/AiFunctionSqlOperatorTable.java | 109 ++++++++ .../metadata/TransformSqlOperatorTable.java | 32 --- .../functions/impl/AiFunctionsTest.java | 125 +++++++++ .../UserDefinedFunctionDescriptorTest.java | 13 +- .../runtime/parser/AiFunctionParserTest.java | 107 ++++++++ 46 files changed, 2049 insertions(+), 663 deletions(-) create mode 100644 flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java create mode 100644 flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java create mode 100644 flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java create mode 100644 flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java create mode 100644 flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java create mode 100644 flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java create mode 100644 flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java create mode 100644 flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory create mode 100644 flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml create mode 100644 
flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory delete mode 100644 flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java delete mode 100644 flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java delete mode 100644 flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java delete mode 100644 flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java delete mode 100644 flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java create mode 100644 flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java create mode 100644 flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java create mode 100644 flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java create mode 100644 flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java create mode 100644 flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java create mode 100644 flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java diff --git a/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java b/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java index 831bb5e6536..9cad4d81c93 100644 --- 
a/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java +++ b/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java @@ -44,6 +44,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -96,9 +97,8 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser { private static final String UDF_OPTIONS_KEY = "options"; // Model related keys - private static final String MODEL_NAME_KEY = "model-name"; - - private static final String MODEL_CLASS_NAME_KEY = "class-name"; + private static final String MODEL_NAME_KEY = "name"; + private static final String MODEL_TYPE_KEY = "type"; public static final String TRANSFORM_PRIMARY_KEY_KEY = "primary-keys"; @@ -145,7 +145,6 @@ private PipelineDef parse(JsonNode pipelineDefJsonNode, Configuration globalPipe Optional.ofNullable( ((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(MODEL_KEY)) - .map(node -> validateArray("model", node)) .ifPresent(node -> modelDefs.addAll(parseModels(node))); } @@ -428,24 +427,45 @@ private List parseModels(JsonNode modelsNode) { } else { modelDefs.add(convertJsonNodeToModelDef(modelsNode)); } + Set seenNames = new HashSet<>(); + for (ModelDef model : modelDefs) { + if (!seenNames.add(model.getName())) { + throw new IllegalArgumentException( + "Duplicate model name '" + model.getName() + "' in pipeline definition."); + } + } return modelDefs; } private ModelDef convertJsonNodeToModelDef(JsonNode modelNode) { + Preconditions.checkArgument( + modelNode instanceof ObjectNode, + "`model` in `pipeline` should be an object, but got %s", + modelNode); + ObjectNode node = (ObjectNode) modelNode; String name = checkNotNull( - modelNode.get(MODEL_NAME_KEY), + node.remove(MODEL_NAME_KEY), "Missing required field \"%s\" in `model`", MODEL_NAME_KEY) .asText(); - String 
model = + Preconditions.checkArgument( + name.matches("[a-zA-Z_][a-zA-Z0-9_]*") && !name.startsWith("__"), + "Model name \"%s\" is not a valid identifier. " + + "It must start with a letter or underscore, " + + "contain only letters, digits, or underscores, " + + "and must not start with double underscores.", + name); + String type = checkNotNull( - modelNode.get(MODEL_CLASS_NAME_KEY), + node.remove(MODEL_TYPE_KEY), "Missing required field \"%s\" in `model`", - MODEL_CLASS_NAME_KEY) + MODEL_TYPE_KEY) .asText(); - Map properties = mapper.convertValue(modelNode, Map.class); - return new ModelDef(name, model, properties); + Map options = new HashMap<>(); + node.fields() + .forEachRemaining(entry -> options.put(entry.getKey(), entry.getValue().asText())); + return new ModelDef(name, type, options); } private void validateJsonNodeKeys( diff --git a/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java b/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java index c8c10db76e7..bf564d1c8e7 100644 --- a/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java +++ b/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java @@ -34,6 +34,8 @@ import org.apache.flink.shaded.guava31.com.google.common.io.Resources; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.net.URL; import java.time.Duration; @@ -422,17 +424,28 @@ private void testSchemaEvolutionTypesParsing( "add new uniq_id for each row", null)), Collections.emptyList(), - Collections.singletonList( + Arrays.asList( + new ModelDef( + "Sonnet", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.builder() + .put("model-name", "claude-sonnet-4-6") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", 
"cafebabe") + .build())), new ModelDef( - "GET_EMBEDDING", - "OpenAIEmbeddingModel", + "Opus", + "openai-compatible", new LinkedHashMap<>( ImmutableMap.builder() - .put("model-name", "GET_EMBEDDING") - .put("class-name", "OpenAIEmbeddingModel") - .put("openai.model", "text-embedding-3-small") - .put("openai.host", "https://xxxx") - .put("openai.apikey", "abcd1234") + .put("model-name", "claude-opus-4-5") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") .build()))), Configuration.fromMap( ImmutableMap.builder() @@ -492,11 +505,16 @@ void testParsingFullDefinitionFromString() throws Exception { + " schema-operator.rpc-timeout: 1 h\n" + " execution.runtime-mode: STREAMING\n" + " model:\n" - + " - model-name: GET_EMBEDDING\n" - + " class-name: OpenAIEmbeddingModel\n" - + " openai.model: text-embedding-3-small\n" - + " openai.host: https://xxxx\n" - + " openai.apikey: abcd1234"; + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://idealab.alibaba-inc.com/api/openai/v1\n" + + " api-key: cafebabe\n" + + " - name: Opus\n" + + " type: openai-compatible\n" + + " model-name: claude-opus-4-5\n" + + " endpoint: https://idealab.alibaba-inc.com/api/openai/v1\n" + + " api-key: cafebabe\n"; YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration()); assertThat(pipelineDef).isEqualTo(fullDef); @@ -560,17 +578,28 @@ void testParsingFullDefinitionFromString() throws Exception { "add new uniq_id for each row", null)), Collections.emptyList(), - Collections.singletonList( + Arrays.asList( new ModelDef( - "GET_EMBEDDING", - "OpenAIEmbeddingModel", + "Sonnet", + "openai-compatible", new LinkedHashMap<>( ImmutableMap.builder() - .put("model-name", "GET_EMBEDDING") - .put("class-name", "OpenAIEmbeddingModel") - .put("openai.model", "text-embedding-3-small") - .put("openai.host", 
"https://xxxx") - .put("openai.apikey", "abcd1234") + .put("model-name", "claude-sonnet-4-6") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") + .build())), + new ModelDef( + "Opus", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.builder() + .put("model-name", "claude-opus-4-5") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") .build()))), Configuration.fromMap( ImmutableMap.builder() @@ -710,6 +739,189 @@ void testParsingFullDefinitionFromString() throws Exception { .put("schema-operator.rpc-timeout", "1 h") .build())); + @Test + void testParsingSingleModelAsObject() throws Exception { + String pipelineDefText = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://example.com/v1\n" + + " api-key: test-key\n"; + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration()); + assertThat(pipelineDef.getModels()).hasSize(1); + assertThat(pipelineDef.getModels().get(0)) + .isEqualTo( + new ModelDef( + "Sonnet", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.of( + "model-name", + "claude-sonnet-4-6", + "endpoint", + "https://example.com/v1", + "api-key", + "test-key")))); + } + + @Test + void testParsingMultipleModelsAsList() throws Exception { + String pipelineDefText = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key1\n" + + " - name: Opus\n" + + " type: openai-compatible\n" + + " model-name: claude-opus-4-5\n" + + " endpoint: 
https://example.com/v1\n" + + " api-key: key2\n"; + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration()); + assertThat(pipelineDef.getModels()).hasSize(2); + assertThat(pipelineDef.getModels().get(0)) + .isEqualTo( + new ModelDef( + "Sonnet", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.of( + "model-name", + "claude-sonnet-4-6", + "endpoint", + "https://example.com/v1", + "api-key", + "key1")))); + assertThat(pipelineDef.getModels().get(1)) + .isEqualTo( + new ModelDef( + "Opus", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.of( + "model-name", + "claude-opus-4-5", + "endpoint", + "https://example.com/v1", + "api-key", + "key2")))); + } + + @Test + void testDuplicateModelName() { + String pipelineDefText = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key1\n" + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-5\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key2\n"; + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + assertThatThrownBy(() -> parser.parse(pipelineDefText, new Configuration())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Duplicate model name 'Sonnet' in pipeline definition."); + } + + @ParameterizedTest + @ValueSource(strings = {"123model", "my-model", "__reserved", "my model", "my.model"}) + void testInvalidModelName(String invalidName) { + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + assertThatThrownBy( + () -> + parser.parse( + buildPipelineDefWithModelName(invalidName), + new Configuration())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( 
+ "Model name \"" + invalidName + "\" is not a valid identifier"); + } + + @Test + void testModelOptionsValueTypes() throws Exception { + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + + // Boolean, integer, and float values in YAML should all be converted to String + String yaml = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: myModel\n" + + " type: dummy\n" + + " debug: true\n" + + " max-tokens: 1024\n" + + " temperature: 0.7\n" + + " endpoint: https://example.com/v1\n"; + + PipelineDef def = parser.parse(yaml, new Configuration()); + assertThat(def.getModels()).hasSize(1); + ModelDef model = def.getModels().get(0); + assertThat(model.getOptions()) + .containsEntry("debug", "true") + .containsEntry("max-tokens", "1024") + .containsEntry("temperature", "0.7") + .containsEntry("endpoint", "https://example.com/v1"); + // Verify all values are String type + model.getOptions().values().forEach(v -> assertThat(v).isInstanceOf(String.class)); + } + + @ParameterizedTest + @ValueSource(strings = {"Sonnet", "my_model_v2", "_private"}) + void testValidModelName(String validName) throws Exception { + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + PipelineDef def = + parser.parse(buildPipelineDefWithModelName(validName), new Configuration()); + assertThat(def.getModels()).hasSize(1); + assertThat(def.getModels().get(0).getName()).isEqualTo(validName); + } + + private String buildPipelineDefWithModelName(String modelName) { + return "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: " + + modelName + + "\n" + + " type: openai-compatible\n" + + " model-name: some-model\n" + + " endpoint: https://example.com/v1\n" + + " api-key: test-key\n"; + } + private final PipelineDef pipelineDefWithUdf = new PipelineDef( new SourceDef("values", null, new 
Configuration()), diff --git a/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml b/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml index dec2c25dc23..8af7ac40d35 100644 --- a/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml +++ b/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml @@ -60,8 +60,13 @@ pipeline: schema-operator.rpc-timeout: 1 h execution.runtime-mode: STREAMING model: - - model-name: GET_EMBEDDING - class-name: OpenAIEmbeddingModel - openai.model: text-embedding-3-small - openai.host: https://xxxx - openai.apikey: abcd1234 + - name: Sonnet + type: openai-compatible + model-name: claude-sonnet-4-6 + endpoint: https://idealab.alibaba-inc.com/api/openai/v1 + api-key: cafebabe + - name: Opus + type: openai-compatible + model-name: claude-opus-4-5 + endpoint: https://idealab.alibaba-inc.com/api/openai/v1 + api-key: cafebabe diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java new file mode 100644 index 00000000000..0e46508bb2e --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model; + +import org.apache.flink.cdc.common.annotation.Experimental; + +import java.io.Serializable; + +/** + * Marker interface for a runtime AI model client. Concrete capabilities are declared via ability + * interfaces in {@code org.apache.flink.cdc.common.model.abilities}. + * + *

Implementations must be {@link Serializable} so that they can be distributed across Flink task + * managers together with the operator that holds them. + */ +@Experimental +public interface AiModelClient extends Serializable, AutoCloseable { + + default void open() throws Exception { + // Do nothing + } + + @Override + default void close() throws Exception { + // Do nothing + } +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java new file mode 100644 index 00000000000..e32d2ed0e09 --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model; + +import org.apache.flink.cdc.common.annotation.Experimental; + +import java.util.Set; +import java.util.stream.Collectors; + +/** + * SPI interface for AI model client factories. Each provider (e.g. OpenAI-compatible, DashScope) + * ships one implementation, discoverable via {@link java.util.ServiceLoader}. + * + *

The {@link #identifier()} value maps to the {@code type} field of a {@code pipeline.model} + * entry in the pipeline YAML. + */ +@Experimental +public interface AiModelClientFactory { + + /** A unique, lower-case identifier for this provider, e.g. {@code "openai-compatible"}. */ + String identifier(); + + /** Option keys that must be present in the model YAML options block. */ + Set requiredOptions(); + + /** Option keys that may optionally appear in the model YAML options block. */ + Set optionalOptions(); + + /** + * Validates that the given context contains all required options and no unknown options. + * Subclasses may override this to add custom validation logic. + */ + default void validate(ModelContext context) { + Set required = requiredOptions(); + Set optional = optionalOptions(); + if (required != null) { + Set missing = + required.stream() + .filter(k -> !context.getOptions().containsKey(k)) + .collect(Collectors.toSet()); + if (!missing.isEmpty()) { + throw new IllegalArgumentException( + "Missing required options for model '" + + context.getModelName() + + "' (type='" + + identifier() + + "'): " + + missing); + } + } + if (required != null && optional != null) { + Set unknown = + context.getOptions().keySet().stream() + .filter(k -> !required.contains(k) && !optional.contains(k)) + .collect(Collectors.toSet()); + if (!unknown.isEmpty()) { + throw new IllegalArgumentException( + "Unknown options for model '" + + context.getModelName() + + "' (type='" + + identifier() + + "'): " + + unknown); + } + } + } + + /** + * Creates a new {@link AiModelClient} from the given context. Called once per model definition + * at pipeline assembly time on the job-manager side; the returned client is serialized and + * shipped to task managers. 
+ */ + AiModelClient createClient(ModelContext context); +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java new file mode 100644 index 00000000000..a6b30250a4d --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model; + +import org.apache.flink.cdc.common.annotation.Experimental; + +import java.util.Map; + +/** Context passed to {@link AiModelClientFactory#createClient} at pipeline assembly time. */ +@Experimental +public interface ModelContext { + + /** The logical name of this model as declared in the pipeline YAML. */ + String getModelName(); + + /** Raw key/value options from the pipeline YAML {@code model.options} block. */ + Map getOptions(); + + /** Class loader to use when loading implementation classes. 
*/ + ClassLoader getClassLoader(); +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java new file mode 100644 index 00000000000..d6c8fc7d1c3 --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model.abilities; + +import org.apache.flink.cdc.common.annotation.Experimental; +import org.apache.flink.cdc.common.model.AiModelClient; + +/** + * Ability interface for {@link AiModelClient} implementations that can produce dense vector + * embeddings from text input. + */ +@Experimental +public interface SupportsEmbedding { + + /** Converts the given text into a dense float vector. 
*/ + float[] embed(String text); +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java new file mode 100644 index 00000000000..eb1ea7362e4 --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model.abilities; + +import org.apache.flink.cdc.common.annotation.Experimental; +import org.apache.flink.cdc.common.model.AiModelClient; + +/** + * Ability interface for {@link AiModelClient} implementations that can perform chat-style text + * generation given a system prompt and a user input. + */ +@Experimental +public interface SupportsTextGeneration { + + /** + * Generates text based on a system-level prompt and a user-provided input message. Returns a + * JSON string conforming to the output schema declared by the calling AI function. 
+ */ + String generate(String systemPrompt, String userInput); +} diff --git a/flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java b/flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java new file mode 100644 index 00000000000..721362ec20e --- /dev/null +++ b/flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** Tests for the default {@link AiModelClientFactory#validate} method. 
*/ +class AiModelClientFactoryTest { + + private static final String IDENTIFIER = "test-provider"; + private static final String MODEL_NAME = "my-model"; + + private static final class StubFactory implements AiModelClientFactory { + private final Set required; + private final Set optional; + + StubFactory(Set required, Set optional) { + this.required = required; + this.optional = optional; + } + + @Override + public String identifier() { + return IDENTIFIER; + } + + @Override + public Set requiredOptions() { + return required; + } + + @Override + public Set optionalOptions() { + return optional; + } + + @Override + public AiModelClient createClient(ModelContext context) { + return new AiModelClient() {}; + } + } + + private static ModelContext contextWithOptions(Map options) { + return new ModelContext() { + @Override + public String getModelName() { + return MODEL_NAME; + } + + @Override + public Map getOptions() { + return options; + } + + @Override + public ClassLoader getClassLoader() { + return Thread.currentThread().getContextClassLoader(); + } + }; + } + + @Test + void testValidatePassesWithAllRequiredOptions() { + StubFactory factory = new StubFactory(Set.of("api-key", "endpoint"), Set.of("timeout")); + + Map options = new HashMap<>(); + options.put("api-key", "sk-xxx"); + options.put("endpoint", "https://api.example.com"); + + // Should not throw + factory.validate(contextWithOptions(options)); + } + + @Test + void testValidatePassesWithRequiredAndOptionalOptions() { + StubFactory factory = new StubFactory(Set.of("api-key", "endpoint"), Set.of("timeout")); + + Map options = new HashMap<>(); + options.put("api-key", "sk-xxx"); + options.put("endpoint", "https://api.example.com"); + options.put("timeout", "30000"); + + factory.validate(contextWithOptions(options)); + } + + @Test + void testValidateThrowsOnMissingRequiredOption() { + StubFactory factory = new StubFactory(Set.of("api-key", "endpoint"), Set.of("timeout")); + + // Missing "endpoint" + Map 
options = new HashMap<>(); + options.put("api-key", "sk-xxx"); + + Assertions.assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Missing required options for model 'my-model' (type='test-provider'): [endpoint]"); + } + + @Test + void testValidateThrowsOnMultipleMissingRequiredOptions() { + StubFactory factory = new StubFactory(Set.of("api-key", "endpoint", "model"), Set.of()); + + // All required options missing + Assertions.assertThatThrownBy( + () -> factory.validate(contextWithOptions(Collections.emptyMap()))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Missing required options for model 'my-model' (type='test-provider'): [endpoint, api-key, model]"); + } + + @Test + void testValidateThrowsOnUnknownOption() { + StubFactory factory = new StubFactory(Set.of("api-key"), Set.of("timeout")); + + Map options = new HashMap<>(); + options.put("api-key", "sk-xxx"); + options.put("bogus", "unexpected"); + + Assertions.assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unknown options for model 'my-model' (type='test-provider'): [bogus]"); + } + + @Test + void testValidateThrowsOnMultipleUnknownOptions() { + StubFactory factory = new StubFactory(Set.of("api-key"), Set.of()); + + Map options = new HashMap<>(); + options.put("api-key", "sk-xxx"); + options.put("foo", "a"); + options.put("bar", "b"); + + Assertions.assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Unknown options for model 'my-model' (type='test-provider'): [bar, foo]"); + } + + @Test + void testValidatePassesWithNoRequiredAndNoOptions() { + StubFactory factory = new StubFactory(Set.of(), Set.of()); + factory.validate(contextWithOptions(Collections.emptyMap())); + } +} diff --git a/flink-cdc-composer/pom.xml 
b/flink-cdc-composer/pom.xml index 4161491ba04..522714c9e45 100644 --- a/flink-cdc-composer/pom.xml +++ b/flink-cdc-composer/pom.xml @@ -80,9 +80,11 @@ limitations under the License. ${project.version} test + + org.apache.flink - flink-cdc-pipeline-model + flink-cdc-pipeline-model-dummy ${project.version} test diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java index 21cc6befaf2..6d16e222c7b 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java +++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java @@ -20,41 +20,31 @@ import java.util.Map; import java.util.Objects; -/** - * Common properties of model. - * - *

<p>A transformation definition contains:
- *
- * <ul>
- *   <li>modelName: The name of function.
- *   <li>className: The model to transform data.
- *   <li>parameters: The parameters that used to configure the model.
- * </ul>
- */ +/** Common properties of model. */ public class ModelDef { - private final String modelName; + private final String name; - private final String className; + private final String type; - private final Map parameters; + private final Map options; - public ModelDef(String modelName, String className, Map parameters) { - this.modelName = modelName; - this.className = className; - this.parameters = parameters; + public ModelDef(String name, String type, Map options) { + this.name = name; + this.type = type; + this.options = options; } - public String getModelName() { - return modelName; + public String getName() { + return name; } - public String getClassName() { - return className; + public String getType() { + return type; } - public Map getParameters() { - return parameters; + public Map getOptions() { + return options; } @Override @@ -66,27 +56,27 @@ public boolean equals(Object o) { return false; } ModelDef modelDef = (ModelDef) o; - return Objects.equals(modelName, modelDef.modelName) - && Objects.equals(className, modelDef.className) - && Objects.equals(parameters, modelDef.parameters); + return Objects.equals(name, modelDef.name) + && Objects.equals(type, modelDef.type) + && Objects.equals(options, modelDef.options); } @Override public int hashCode() { - return Objects.hash(modelName, className, parameters); + return Objects.hash(name, type, options); } @Override public String toString() { return "ModelDef{" + "name='" - + modelName + + name + '\'' - + ", model='" - + className + + ", type='" + + type + '\'' - + ", parameters=" - + parameters + + ", options=" + + options + '}'; } } diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java index 1854b05032b..d77c48bfee2 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java +++ 
b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java @@ -186,7 +186,6 @@ private void translate(StreamExecutionEnvironment env, PipelineDef pipelineDef) stream, pipelineDef.getTransforms(), pipelineDef.getUdfs(), - pipelineDef.getModels(), dataSource.supportedMetadataColumns()); // PreTransform ---> PostTransform diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java index 6e296db4d1c..3056ad57663 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java +++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java @@ -19,6 +19,9 @@ import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.cdc.common.event.Event; +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.AiModelClientFactory; +import org.apache.flink.cdc.common.model.ModelContext; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.composer.definition.ModelDef; import org.apache.flink.cdc.composer.definition.TransformDef; @@ -30,8 +33,12 @@ import org.apache.flink.cdc.runtime.typeutils.EventTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.stream.Collectors; /** @@ -40,15 +47,10 @@ */ public class TransformTranslator { - /** Package of built-in model. 
*/ - public static final String PREFIX_CLASSPATH_BUILT_IN_MODEL = - "org.apache.flink.cdc.runtime.model."; - public DataStream translatePreTransform( DataStream input, List transforms, List udfFunctions, - List models, SupportedMetadataColumn[] supportedMetadataColumns) { if (transforms.isEmpty()) { return input; @@ -56,13 +58,12 @@ public DataStream translatePreTransform( return input.transform( "Transform:Schema", new EventTypeInfo(), - generatePreTransform(transforms, udfFunctions, models, supportedMetadataColumns)); + generatePreTransform(transforms, udfFunctions, supportedMetadataColumns)); } private PreTransformOperator generatePreTransform( List transforms, List udfFunctions, - List models, SupportedMetadataColumn[] supportedMetadataColumns) { PreTransformOperatorBuilder preTransformFunctionBuilder = PreTransformOperator.newBuilder(); @@ -79,14 +80,8 @@ private PreTransformOperator generatePreTransform( supportedMetadataColumns); } - preTransformFunctionBuilder - .addUdfFunctions( - udfFunctions.stream() - .map(this::udfDefToUDFTuple) - .collect(Collectors.toList())) - .addUdfFunctions( - models.stream().map(this::modelToUDFTuple).collect(Collectors.toList())); - + preTransformFunctionBuilder.addUdfFunctions( + udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList())); return preTransformFunctionBuilder.build(); } @@ -116,24 +111,85 @@ public DataStream translatePostTransform( transform.getPostTransformConverter(), supportedMetadataColumns); } - postTransformFunctionBuilder.addTimezone(timezone); - postTransformFunctionBuilder.addUdfFunctions( - udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList())); - postTransformFunctionBuilder.addUdfFunctions( - models.stream().map(this::modelToUDFTuple).collect(Collectors.toList())); + postTransformFunctionBuilder + .addTimezone(timezone) + .addUdfFunctions( + udfFunctions.stream() + .map(this::udfDefToUDFTuple) + .collect(Collectors.toList())) + 
.addModelClients(loadModelClients(models)); + return input.transform( "Transform:Data", new EventTypeInfo(), postTransformFunctionBuilder.build()) .uid(operatorUidGenerator.generateUid("post-transform")); } - private Tuple3> modelToUDFTuple(ModelDef model) { - return Tuple3.of( - model.getModelName(), - PREFIX_CLASSPATH_BUILT_IN_MODEL + model.getClassName(), - model.getParameters()); + /** + * Loads AI model clients for all declared models via SPI, returning a map from Janino parameter + * name to the client instance. + */ + private Map loadModelClients(List models) { + if (models.isEmpty()) { + return Collections.emptyMap(); + } + + Map factories = new HashMap<>(); + ServiceLoader loader = + ServiceLoader.load( + AiModelClientFactory.class, Thread.currentThread().getContextClassLoader()); + for (AiModelClientFactory factory : loader) { + factories.put(factory.identifier(), factory); + } + + Map clients = new LinkedHashMap<>(); + for (ModelDef model : models) { + AiModelClientFactory factory = factories.get(model.getType()); + if (factory == null) { + throw new IllegalArgumentException( + "No AiModelClientFactory found for model type '" + + model.getType() + + "'. 
Available factories: " + + factories.keySet()); + } + ModelContext ctx = + new DefaultModelContext(model, Thread.currentThread().getContextClassLoader()); + factory.validate(ctx); + AiModelClient client = factory.createClient(ctx); + clients.put(model.getName(), client); + } + return clients; } private Tuple3> udfDefToUDFTuple(UdfDef udf) { return Tuple3.of(udf.getName(), udf.getClasspath(), udf.getOptions()); } + + // ------------------------------------------------------------------------- + // Internal ModelContext implementation + // ------------------------------------------------------------------------- + + private static final class DefaultModelContext implements ModelContext { + private final ModelDef modelDef; + private final ClassLoader classLoader; + + DefaultModelContext(ModelDef modelDef, ClassLoader classLoader) { + this.modelDef = modelDef; + this.classLoader = classLoader; + } + + @Override + public String getModelName() { + return modelDef.getName(); + } + + @Override + public Map getOptions() { + return modelDef.getOptions(); + } + + @Override + public ClassLoader getClassLoader() { + return classLoader; + } + } } diff --git a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java new file mode 100644 index 00000000000..37f1701065f --- /dev/null +++ b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.composer.flink; + +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.data.binary.BinaryStringData; +import org.apache.flink.cdc.common.event.CreateTableEvent; +import org.apache.flink.cdc.common.event.DataChangeEvent; +import org.apache.flink.cdc.common.event.Event; +import org.apache.flink.cdc.common.event.TableId; +import org.apache.flink.cdc.common.pipeline.PipelineOptions; +import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior; +import org.apache.flink.cdc.common.schema.Schema; +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.composer.PipelineExecution; +import org.apache.flink.cdc.composer.definition.ModelDef; +import org.apache.flink.cdc.composer.definition.PipelineDef; +import org.apache.flink.cdc.composer.definition.SinkDef; +import org.apache.flink.cdc.composer.definition.SourceDef; +import org.apache.flink.cdc.composer.definition.TransformDef; +import org.apache.flink.cdc.connectors.values.ValuesDatabase; +import org.apache.flink.cdc.connectors.values.factory.ValuesDataFactory; +import org.apache.flink.cdc.connectors.values.sink.ValuesDataSinkOptions; +import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceHelper; +import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceOptions; +import org.apache.flink.cdc.runtime.typeutils.BinaryRecordDataGenerator; +import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration; +import 
org.apache.flink.test.junit5.MiniClusterExtension; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL; +import static org.assertj.core.api.Assertions.assertThat; + +/** Integration test for AI functions in the Flink pipeline. */ +class FlinkPipelineAiFunctionITCase { + + private static final int MAX_PARALLELISM = 4; + + private static final org.apache.flink.configuration.Configuration MINI_CLUSTER_CONFIG = + new org.apache.flink.configuration.Configuration(); + + static { + MINI_CLUSTER_CONFIG.set( + ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL, + Collections.singletonList("org.apache.flink.cdc")); + } + + @RegisterExtension + static final MiniClusterExtension MINI_CLUSTER_RESOURCE = + new MiniClusterExtension( + new MiniClusterResourceConfiguration.Builder() + .setNumberTaskManagers(1) + .setNumberSlotsPerTaskManager(MAX_PARALLELISM) + .setConfiguration(MINI_CLUSTER_CONFIG) + .build()); + + private final PrintStream standardOut = System.out; + private final ByteArrayOutputStream outCaptor = new ByteArrayOutputStream(); + + @BeforeEach + void init() { + System.setOut(new PrintStream(outCaptor)); + ValuesDatabase.clear(); + } + + @AfterEach + void cleanup() { + System.setOut(standardOut); + } + + private static final String DUMMY_JSON = "{\"result\":\"dummy result\",\"summary\":\"TL;DR\"}"; + + @Test + void testAiCompleteInProjection() throws Exception { + String[] output = + runAiFunctionTest( + "id, content, AI_COMPLETE('testModel', content, 'Classify sentiment') AS res", + List.of(new ModelDef("testModel", "dummy", new HashMap<>()))); + 
assertThat(output) + .containsExactly( + "CreateTableEvent{tableId=default_namespace.default_schema.mytable1, schema=columns={`id` INT NOT NULL,`content` STRING,`res` VARIANT}, primaryKeys=id, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.mytable1, before=[], after=[1, I love this product, " + + DUMMY_JSON + + "], op=INSERT, meta=()}"); + } + + @Test + void testAiEmbedInProjection() throws Exception { + String[] output = + runAiFunctionTest( + "id, AI_EMBED('embedModel', content) AS embedding", + List.of(new ModelDef("embedModel", "dummy", new HashMap<>()))); + assertThat(output) + .containsExactly( + "CreateTableEvent{tableId=default_namespace.default_schema.mytable1, schema=columns={`id` INT NOT NULL,`embedding` ARRAY}, primaryKeys=id, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.mytable1, before=[], after=[1, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}"); + } + + @Test + void testMultipleAiFunctionsWithSameModel() throws Exception { + String[] output = + runAiFunctionTest( + "id, AI_COMPLETE('myModel', content, 'Classify') AS cls, AI_SUMMARIZE('myModel', content, 100) AS summary", + List.of(new ModelDef("myModel", "dummy", new HashMap<>()))); + assertThat(output) + .containsExactly( + "CreateTableEvent{tableId=default_namespace.default_schema.mytable1, schema=columns={`id` INT NOT NULL,`cls` VARIANT,`summary` VARIANT}, primaryKeys=id, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.mytable1, before=[], after=[1, " + + DUMMY_JSON + + ", " + + DUMMY_JSON + + "], op=INSERT, meta=()}"); + } + + private String[] runAiFunctionTest(String projection, List models) throws Exception { + FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster(); + + // Source: one table with a single row + TableId tableId = TableId.tableId("default_namespace", "default_schema", "mytable1"); + Schema schema = + Schema.newBuilder() + .physicalColumn("id", DataTypes.INT()) + 
.physicalColumn("content", DataTypes.STRING()) + .primaryKey("id") + .build(); + BinaryRecordDataGenerator generator = + new BinaryRecordDataGenerator(schema.getColumnDataTypes().toArray(new DataType[0])); + + List events = new ArrayList<>(); + events.add(new CreateTableEvent(tableId, schema)); + events.add( + DataChangeEvent.insertEvent( + tableId, + generator.generate( + new Object[] { + 1, BinaryStringData.fromString("I love this product") + }))); + ValuesDataSourceHelper.setSourceEvents(Collections.singletonList(events)); + + Configuration sourceConfig = new Configuration(); + sourceConfig.set( + ValuesDataSourceOptions.EVENT_SET_ID, + ValuesDataSourceHelper.EventSetId.CUSTOM_SOURCE_EVENTS); + SourceDef sourceDef = + new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig); + + // Sink + Configuration sinkConfig = new Configuration(); + sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true); + SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig); + + // Transform + TransformDef transformDef = + new TransformDef( + "default_namespace.default_schema.mytable1", + projection, + null, + "id", + null, + null, + null, + null); + + // Pipeline + Configuration pipelineConfig = new Configuration(); + pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1); + pipelineConfig.set( + PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE); + PipelineDef pipelineDef = + new PipelineDef( + sourceDef, + sinkDef, + Collections.emptyList(), + Collections.singletonList(transformDef), + Collections.emptyList(), + models, + pipelineConfig); + + // Execute & capture output + PipelineExecution execution = composer.compose(pipelineDef); + execution.execute(); + + return outCaptor.toString().trim().split("\n"); + } +} diff --git a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java 
b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java index 9d5a8b22e23..dc04c8de66e 100644 --- a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java +++ b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java @@ -21,7 +21,6 @@ import org.apache.flink.cdc.common.pipeline.PipelineOptions; import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior; import org.apache.flink.cdc.composer.PipelineExecution; -import org.apache.flink.cdc.composer.definition.ModelDef; import org.apache.flink.cdc.composer.definition.PipelineDef; import org.apache.flink.cdc.composer.definition.SinkDef; import org.apache.flink.cdc.composer.definition.SourceDef; @@ -40,7 +39,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -48,10 +46,8 @@ import java.io.ByteArrayOutputStream; import java.io.PrintStream; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.stream.Stream; import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL; @@ -911,77 +907,6 @@ void testComplicatedFlinkUdf(ValuesDataSink.SinkApi sinkApi, String language) th "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, , 4, Integer: 42, 2-42], after=[2, x, 4, Integer: 42, 2-42], op=UPDATE, meta=({op_ts=5})}"); } - @ParameterizedTest - @MethodSource("testParams") - @Disabled("For manual test as there is a limit for quota.") - void testTransformWithModel(ValuesDataSink.SinkApi sinkApi, String language) throws Exception { - FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster(); - - // Setup value 
source - Configuration sourceConfig = new Configuration(); - sourceConfig.set( - ValuesDataSourceOptions.EVENT_SET_ID, - ValuesDataSourceHelper.EventSetId.TRANSFORM_TABLE); - SourceDef sourceDef = - new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig); - - // Setup value sink - Configuration sinkConfig = new Configuration(); - sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true); - sinkConfig.set(ValuesDataSinkOptions.SINK_API, sinkApi); - SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig); - - // Setup transform - TransformDef transformDef = - new TransformDef( - "default_namespace.default_schema.table1", - "*, CHAT(col1) AS emb", - null, - "col1", - null, - "key1=value1", - "", - null); - - // Setup pipeline - Configuration pipelineConfig = new Configuration(); - pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1); - pipelineConfig.set( - PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE); - PipelineDef pipelineDef = - new PipelineDef( - sourceDef, - sinkDef, - Collections.emptyList(), - Collections.singletonList(transformDef), - new ArrayList<>(), - Arrays.asList( - new ModelDef( - "CHAT", - "OpenAIChatModel", - new LinkedHashMap<>( - ImmutableMap.builder() - .put("openai.model", "gpt-4o-mini") - .put( - "openai.host", - "http://langchain4j.dev/demo/openai/v1") - .put("openai.apikey", "demo") - .build()))), - pipelineConfig); - - // Execute the pipeline - PipelineExecution execution = composer.compose(pipelineDef); - execution.execute(); - - // Check the order and content of all received events - String[] outputEvents = outCaptor.toString().trim().split("\n"); - assertThat(outputEvents) - .contains( - "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING NOT NULL,`col2` STRING,`emb` STRING}, primaryKeys=col1, options=({key1=value1})}") - // The result of transform by model is not fixed. 
- .hasSize(9); - } - @ParameterizedTest @MethodSource("testParams") void testComplicatedUdfReturnTypes(ValuesDataSink.SinkApi sinkApi, String language) diff --git a/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory b/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory new file mode 100644 index 00000000000..c1ed9c43ff3 --- /dev/null +++ b/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.cdc.models.dummy.DummyModelClientFactory diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml index 24e6213ed9e..6c1aa689fec 100644 --- a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml @@ -543,6 +543,16 @@ limitations under the License. 
+ + org.apache.flink + flink-cdc-pipeline-model-dummy + ${project.version} + dummy-model.jar + jar + ${project.build.directory}/dependencies + + + org.apache.flink flink-cdc-pipeline-connector-mysql diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java new file mode 100644 index 00000000000..c25e72df103 --- /dev/null +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.pipeline.tests; + +import org.apache.flink.cdc.common.test.utils.TestUtils; +import org.apache.flink.cdc.pipeline.tests.utils.PipelineTestEnvironment; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Path; +import java.time.Duration; + +/** E2e tests for AI functions with the dummy model SPI. 
*/ +class AiFunctionE2eITCase extends PipelineTestEnvironment { + + private static final String TABLE_1 = "default_namespace.default_schema.table1"; + private static final String TABLE_2 = "default_namespace.default_schema.table2"; + private static final String DUMMY_JSON = + "{\"result\":\"dummy result\",\"summary\":\"TL;DR\"}"; + private static final String DUMMY_EMBEDDING = "[3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]"; + + @Test + void testAiFunctionsWithDummyModel() throws Exception { + String pipelineJob = + "source:\n" + + " type: values\n" + + " event-set.id: SINGLE_SPLIT_MULTI_TABLES\n" + + "\n" + + "sink:\n" + + " type: values\n" + + "\n" + + "transform:\n" + + " - source-table: " + + TABLE_1 + + "\n" + + " projection: col1, AI_COMPLETE('myModel', col1, 'Classify into catA or catB') AS cls\n" + + " - source-table: " + + TABLE_2 + + "\n" + + " projection: col1, AI_EMBED('myModel', col1) AS embedding\n" + + "\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " schema.change.behavior: evolve\n" + + " model:\n" + + " name: myModel\n" + + " type: dummy\n" + + " debug: true\n"; + + Path dummyModelJar = TestUtils.getResource("dummy-model.jar"); + submitPipelineJob(pipelineJob, dummyModelJar); + waitUntilJobFinished(Duration.ofMinutes(3)); + + validateResult("Successfully opened AI model client 'myModel'."); + validateResult( + "CreateTableEvent{tableId=" + + TABLE_1 + + ", schema=columns={`col1` STRING NOT NULL,`cls` VARIANT}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=" + + TABLE_1 + + ", before=[], after=[1, " + + DUMMY_JSON + + "], op=INSERT, meta=()}", + "DataChangeEvent{tableId=" + + TABLE_1 + + ", before=[], after=[2, " + + DUMMY_JSON + + "], op=INSERT, meta=()}", + "DataChangeEvent{tableId=" + + TABLE_1 + + ", before=[], after=[3, " + + DUMMY_JSON + + "], op=INSERT, meta=()}", + "DataChangeEvent{tableId=" + + TABLE_1 + + ", before=[1, " + + DUMMY_JSON + + "], after=[], op=DELETE, meta=()}", + "DataChangeEvent{tableId=" + + TABLE_1 + + ", 
before=[2, " + + DUMMY_JSON + + "], after=[2, " + + DUMMY_JSON + + "], op=UPDATE, meta=()}"); + + validateResult( + "CreateTableEvent{tableId=" + + TABLE_2 + + ", schema=columns={`col1` STRING NOT NULL,`embedding` ARRAY}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=" + + TABLE_2 + + ", before=[], after=[1, " + + DUMMY_EMBEDDING + + "], op=INSERT, meta=()}", + "DataChangeEvent{tableId=" + + TABLE_2 + + ", before=[], after=[2, " + + DUMMY_EMBEDDING + + "], op=INSERT, meta=()}", + "DataChangeEvent{tableId=" + + TABLE_2 + + ", before=[], after=[3, " + + DUMMY_EMBEDDING + + "], op=INSERT, meta=()}"); + validateResult("Successfully closed AI model client 'myModel'."); + } +} diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml index ffed7ad9bf5..6fe4a61e1d6 100644 --- a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml @@ -55,20 +55,3 @@ steps: error: | YAML UDF block is expecting an array children, but got an OBJECT ({"name":"addone","classpath":"org.apache.flink.cdc.udf.examples.%s.AddOneFunctionClass"}). Perhaps you missed a dash prefix `-`? - # Models not an array - - type: submit - yaml: | - source: - type: values - sink: - type: values - pipeline: - model: - model-name: GET_EMBEDDING - class-name: OpenAIEmbeddingModel - openai.model: text-embedding-3-small - openai.host: https://xxxx - openai.apikey: abcd1234 - error: | - YAML model block is expecting an array children, but got an OBJECT ({"model-name":"GET_EMBEDDING","class-name":"OpenAIEmbeddingModel","openai.model":"text-embedding-3-small","openai.host":"https://xxxx","openai.apikey":"abcd1234"}). - Perhaps you missed a dash prefix `-`? 
diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml new file mode 100644 index 00000000000..81f462b0fd8 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml @@ -0,0 +1,32 @@ + + + + + + org.apache.flink + flink-cdc-pipeline-model + ${revision} + + + 4.0.0 + + flink-cdc-pipeline-model-dummy + jar + diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java new file mode 100644 index 00000000000..ea33eb6d851 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.models.dummy; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +/** Deterministic dummy AI model client for testing. */ +public class DummyModelClient implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + + private static final long serialVersionUID = 1L; + + private final boolean debug; + + public DummyModelClient(boolean debug) { + this.debug = debug; + } + + @Override + public String generate(String systemPrompt, String userInput) { + if (debug) { + System.out.printf("Received prompt: %s\nUser input: %s\n", systemPrompt, userInput); + } + // Returns a JSON covering fields for AI_COMPLETE and AI_SUMMARIZE + return "{\"result\":\"dummy result\",\"summary\":\"TL;DR\"}"; + } + + @Override + public float[] embed(String text) { + return new float[] {3f, 1f, 4f, 1f, 5f, 9f, 2f, 6f}; + } + + @Override + public void open() { + if (debug) { + System.out.println("Dummy model opened."); + } + } + + @Override + public void close() { + if (debug) { + System.out.println("Dummy model closed."); + } + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java new file mode 100644 index 00000000000..4df45704ae3 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.dummy; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.AiModelClientFactory; +import org.apache.flink.cdc.common.model.ModelContext; + +import java.util.Set; + +/** SPI factory for {@link DummyModelClient}. For testing purposes only. */ +public class DummyModelClientFactory implements AiModelClientFactory { + + @Override + public String identifier() { + return "dummy"; + } + + @Override + public Set requiredOptions() { + return Set.of(); + } + + @Override + public Set optionalOptions() { + return Set.of("debug"); + } + + @Override + public AiModelClient createClient(ModelContext context) { + boolean debug = Boolean.parseBoolean(context.getOptions().getOrDefault("debug", "false")); + return new DummyModelClient(debug); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory new file mode 100644 index 00000000000..c1ed9c43ff3 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory @@ -0,0 +1,16 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.cdc.models.dummy.DummyModelClientFactory diff --git a/flink-cdc-pipeline-model/pom.xml b/flink-cdc-pipeline-model/pom.xml index e7dba7f6f66..29a24da57e9 100644 --- a/flink-cdc-pipeline-model/pom.xml +++ b/flink-cdc-pipeline-model/pom.xml @@ -19,16 +19,19 @@ limitations under the License. xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - flink-cdc-parent org.apache.flink + flink-cdc-parent ${revision} + 4.0.0 flink-cdc-pipeline-model - - 0.23.0 - + pom + + + flink-cdc-pipeline-model-dummy + @@ -37,51 +40,6 @@ limitations under the License. 
${project.version} provided - - org.apache.flink - flink-test-utils-junit - ${flink.version} - test - - - org.testcontainers - testcontainers - - - - - dev.langchain4j - langchain4j - ${langchain4j.version} - - - dev.langchain4j - langchain4j-open-ai - ${langchain4j.version} - - - com.theokanning.openai-gpt3-java - service - 0.12.0 - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - test-jar - - test-jar - - - - - - - - \ No newline at end of file + diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java deleted file mode 100644 index f56b76d5bae..00000000000 --- a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.ConfigOption; -import org.apache.flink.cdc.common.configuration.ConfigOptions; - -/** Options of built-in model. */ -public class ModelOptions { - - // Options for Open AI Model. 
- public static final ConfigOption OPENAI_MODEL_NAME = - ConfigOptions.key("openai.model") - .stringType() - .noDefaultValue() - .withDescription("Name of model to be called."); - - public static final ConfigOption OPENAI_HOST = - ConfigOptions.key("openai.host") - .stringType() - .noDefaultValue() - .withDescription("Host of the Model server to be connected."); - - public static final ConfigOption OPENAI_API_KEY = - ConfigOptions.key("openai.apikey") - .stringType() - .noDefaultValue() - .withDescription("Api Key for verification of the Model server."); - - public static final ConfigOption OPENAI_CHAT_PROMPT = - ConfigOptions.key("openai.chat.prompt") - .stringType() - .noDefaultValue() - .withDescription("Prompt for chat using OpenAI."); -} diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java deleted file mode 100644 index 2fa147f509f..00000000000 --- a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.types.DataType; -import org.apache.flink.cdc.common.types.DataTypes; -import org.apache.flink.cdc.common.udf.UserDefinedFunction; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; -import org.apache.flink.cdc.common.utils.Preconditions; - -import dev.langchain4j.data.message.UserMessage; -import dev.langchain4j.model.openai.OpenAiChatModel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collections; - -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_CHAT_PROMPT; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME; - -/** - * A {@link UserDefinedFunction} that use Model defined by OpenAI to generate text, refer to docs}. 
- */ -public class OpenAIChatModel implements UserDefinedFunction { - - private static final Logger LOG = LoggerFactory.getLogger(OpenAIChatModel.class); - - private OpenAiChatModel chatModel; - - private String modelName; - - private String host; - - private String prompt; - - public String eval(String input) { - return chat(input); - } - - private String chat(String input) { - if (input == null || input.trim().isEmpty()) { - LOG.warn("Empty or null input provided for embedding."); - return ""; - } - if (prompt != null) { - input = prompt + ": " + input; - } - return chatModel - .generate(Collections.singletonList(new UserMessage(input))) - .content() - .text(); - } - - @Override - public DataType getReturnType() { - return DataTypes.STRING(); - } - - @Override - public void open(UserDefinedFunctionContext userDefinedFunctionContext) { - Configuration modelOptions = userDefinedFunctionContext.configuration(); - this.modelName = modelOptions.get(OPENAI_MODEL_NAME); - Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty."); - this.host = modelOptions.get(OPENAI_HOST); - Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty."); - String apiKey = modelOptions.get(OPENAI_API_KEY); - Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty."); - this.prompt = modelOptions.get(OPENAI_CHAT_PROMPT); - LOG.info("Opening OpenAIChatModel " + modelName + " " + host); - this.chatModel = - OpenAiChatModel.builder().apiKey(apiKey).baseUrl(host).modelName(modelName).build(); - } - - @Override - public void close() { - LOG.info("Closed OpenAIChatModel " + modelName + " " + host); - } -} diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java deleted file mode 100644 index dbc29c307ec..00000000000 --- 
a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.data.ArrayData; -import org.apache.flink.cdc.common.data.GenericArrayData; -import org.apache.flink.cdc.common.types.DataType; -import org.apache.flink.cdc.common.types.DataTypes; -import org.apache.flink.cdc.common.udf.UserDefinedFunction; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; -import org.apache.flink.cdc.common.utils.Preconditions; - -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.openai.OpenAiEmbeddingModel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collections; -import java.util.List; - -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST; -import static 
org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME; - -/** - * A {@link UserDefinedFunction} that use Model defined by OpenAI to generate vector data, refer to - * docs}. - */ -public class OpenAIEmbeddingModel implements UserDefinedFunction { - - private static final Logger LOG = LoggerFactory.getLogger(OpenAIEmbeddingModel.class); - - private String modelName; - - private String host; - - private OpenAiEmbeddingModel embeddingModel; - - public ArrayData eval(String input) { - return getEmbedding(input); - } - - private ArrayData getEmbedding(String input) { - if (input == null || input.trim().isEmpty()) { - LOG.debug("Empty or null input provided for embedding."); - return new GenericArrayData(new Float[0]); - } - - TextSegment textSegment = new TextSegment(input, new Metadata()); - - List embeddings = - embeddingModel.embedAll(Collections.singletonList(textSegment)).content(); - - if (embeddings != null && !embeddings.isEmpty()) { - List embeddingList = embeddings.get(0).vectorAsList(); - Float[] embeddingArray = embeddingList.toArray(new Float[0]); - return new GenericArrayData(embeddingArray); - } else { - LOG.warn("No embedding results returned for input: {}", input); - return new GenericArrayData(new Float[0]); - } - } - - @Override - public DataType getReturnType() { - return DataTypes.ARRAY(DataTypes.FLOAT()); - } - - @Override - public void open(UserDefinedFunctionContext userDefinedFunctionContext) { - Configuration modelOptions = userDefinedFunctionContext.configuration(); - this.modelName = modelOptions.get(OPENAI_MODEL_NAME); - Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty."); - this.host = modelOptions.get(OPENAI_HOST); - Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty."); - String apiKey = modelOptions.get(OPENAI_API_KEY); - Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty."); - LOG.info("Opening OpenAIEmbeddingModel " + modelName + " " + 
host); - this.embeddingModel = - OpenAiEmbeddingModel.builder() - .apiKey(apiKey) - .baseUrl(host) - .modelName(modelName) - .build(); - } - - @Override - public void close() { - LOG.info("Closed OpenAIEmbeddingModel " + modelName + " " + host); - } -} diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java deleted file mode 100644 index b14a7d89c65..00000000000 --- a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; - -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -/** A test for {@link OpenAIChatModel}. 
*/ -class TestOpenAIChatModel { - @Test - @Disabled("For manual test as there is a limit for quota.") - public void testEval() { - OpenAIChatModel openAIChatModel = new OpenAIChatModel(); - Configuration configuration = new Configuration(); - configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1"); - configuration.set(ModelOptions.OPENAI_API_KEY, "demo"); - configuration.set(ModelOptions.OPENAI_MODEL_NAME, "gpt-4o-mini"); - UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration; - openAIChatModel.open(userDefinedFunctionContext); - String response = openAIChatModel.eval("Who invented the electric light?"); - Assertions.assertThat(response).isNotEmpty(); - } -} diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java deleted file mode 100644 index 2be50d41e05..00000000000 --- a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.data.ArrayData; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; - -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -/** A test for {@link OpenAIEmbeddingModel}. */ -class TestOpenAIEmbeddingModel { - - @Test - @Disabled("For manual test as there is a limit for quota.") - public void testEval() { - OpenAIEmbeddingModel openAIEmbeddingModel = new OpenAIEmbeddingModel(); - Configuration configuration = new Configuration(); - configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1"); - configuration.set(ModelOptions.OPENAI_API_KEY, "demo"); - configuration.set(ModelOptions.OPENAI_MODEL_NAME, "text-embedding-3-small"); - UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration; - openAIEmbeddingModel.open(userDefinedFunctionContext); - ArrayData arrayData = - openAIEmbeddingModel.eval("Flink CDC is a streaming data integration tool"); - Assertions.assertThat(arrayData).isNotNull(); - } -} diff --git a/flink-cdc-runtime/pom.xml b/flink-cdc-runtime/pom.xml index fc474013c22..9fcfdc86c05 100644 --- a/flink-cdc-runtime/pom.xml +++ b/flink-cdc-runtime/pom.xml @@ -95,12 +95,6 @@ limitations under the License. ${project.version} test - - org.apache.flink - flink-cdc-pipeline-model - ${project.version} - test - diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java new file mode 100644 index 00000000000..3554f604c2f --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.ai; + +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; + +/** Built-in AI embedding function definitions with configurable input and output types. */ +public enum AiEmbeddingFunctionDef { + AI_EMBED("AI_EMBED", DataTypes.STRING(), DataTypes.ARRAY(DataTypes.FLOAT())); + + private final String functionName; + private final DataType inputType; + private final DataType outputType; + + AiEmbeddingFunctionDef(String functionName, DataType inputType, DataType outputType) { + this.functionName = functionName; + this.inputType = inputType; + this.outputType = outputType; + } + + public String getFunctionName() { + return functionName; + } + + /** The type of the input value (e.g. STRING for text embedding). */ + public DataType getInputType() { + return inputType; + } + + /** The type of the output value (e.g. ARRAY<FLOAT> for vector embedding). 
*/ + public DataType getOutputType() { + return outputType; + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java new file mode 100644 index 00000000000..845a082c771 --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.ai; + +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.common.types.RowType; + +/** + * Built-in AI text generation function definitions with their prompt templates and type metadata. 
+ */ +public enum AiTextFunctionDef { + AI_COMPLETE( + "AI_COMPLETE", + RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"systemPrompt"}), + RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"result"}), + "%s\n"), + + AI_SUMMARIZE( + "AI_SUMMARIZE", + RowType.of(new DataType[] {DataTypes.INT()}, new String[] {"maxLength"}), + RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"summary"}), + "You are a text summarization expert. Generate an accurate, coherent, and informative " + + "summary that does not exceed %d characters.\n" + + "Output requirements:\n" + + "- summary: the summarized content\n" + + "Principles:\n" + + "- Stay within the specified length\n" + + "- Preserve core ideas and key information\n" + + "- Use concise language with clear logic\n" + + "- Maintain text coherence\n" + + "- Avoid subjective opinions\n"); + + private final String functionName; + private final RowType inputType; + private final RowType outputType; + private final String promptTemplate; + + AiTextFunctionDef( + String functionName, RowType inputType, RowType outputType, String promptTemplate) { + this.functionName = functionName; + this.inputType = inputType; + this.outputType = outputType; + this.promptTemplate = promptTemplate; + } + + public String getFunctionName() { + return functionName; + } + + /** + * Returns the additional parameter types for promptTemplate placeholders. + * + *

Input text parameter is always added by runtime, not included here. + */ + public RowType getInputType() { + return inputType; + } + + public RowType getOutputType() { + return outputType; + } + + /** Builds the core system prompt by filling in the template placeholders. */ + public String buildPrompt(Object... args) { + return String.format(promptTemplate, args); + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java new file mode 100644 index 00000000000..be49d8d7785 --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.runtime.functions.impl; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; +import org.apache.flink.cdc.common.types.RowType; +import org.apache.flink.cdc.common.types.variant.BinaryVariant; +import org.apache.flink.cdc.common.types.variant.BinaryVariantInternalBuilder; +import org.apache.flink.cdc.runtime.ai.AiTextFunctionDef; + +import org.apache.flink.shaded.guava31.com.google.common.primitives.Floats; + +import java.util.List; + +/** Built-in AI functions available as static imports in Janino-compiled transform expressions. */ +public class AiFunctions { + + /** General-purpose text completion with a user-provided system prompt. */ + public static BinaryVariant aiComplete(AiModelClient model, String input, String systemPrompt) { + return invokeTextGeneration(model, AiTextFunctionDef.AI_COMPLETE, input, systemPrompt); + } + + /** Text summarization with a maximum length constraint. */ + public static BinaryVariant aiSummarize(AiModelClient model, String input, int maxLength) { + return invokeTextGeneration(model, AiTextFunctionDef.AI_SUMMARIZE, input, maxLength); + } + + /** Text embedding that converts input text to a vector representation. */ + public static List aiEmbed(AiModelClient model, String input) { + if (!(model instanceof SupportsEmbedding)) { + throw new UnsupportedOperationException( + "Model " + model.getClass().getName() + " does not support embedding"); + } + return Floats.asList(((SupportsEmbedding) model).embed(input)); + } + + private static BinaryVariant invokeTextGeneration( + AiModelClient model, AiTextFunctionDef funcDef, String input, Object... 
args) { + if (!(model instanceof SupportsTextGeneration)) { + throw new UnsupportedOperationException( + "Model " + model.getClass().getName() + " does not support text generation"); + } + StringBuilder promptBuilder = new StringBuilder(); + promptBuilder.append(funcDef.buildPrompt(args)); + promptBuilder.append("\n").append(buildOutputSchemaHint(funcDef.getOutputType())); + + String systemPrompt = promptBuilder.toString(); + String json = ((SupportsTextGeneration) model).generate(systemPrompt, input); + if (json == null) { + return null; + } + try { + return BinaryVariantInternalBuilder.parseJson(json, false); + } catch (java.io.IOException e) { + throw new RuntimeException("Failed to parse AI response as JSON: " + json, e); + } + } + + /** Builds the JSON output schema hint based on outputType. */ + private static String buildOutputSchemaHint(RowType outputType) { + StringBuilder sb = new StringBuilder(); + sb.append("You must return the result strictly in the following JSON format:\n"); + sb.append("{\n"); + List fieldNames = outputType.getFieldNames(); + for (int i = 0; i < fieldNames.size(); i++) { + sb.append(" \"") + .append(fieldNames.get(i)) + .append("\": <") + .append(fieldNames.get(i)) + .append(">"); + if (i < fieldNames.size() - 1) { + sb.append(","); + } + sb.append("\n"); + } + sb.append("}\n"); + sb.append( + "Important: Return only valid JSON with no additional text, without YAML code blocks.\n"); + return sb.toString(); + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java index 0a8b63703b9..c90a7de8c17 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java @@ -29,6 +29,7 @@ import 
org.apache.flink.cdc.common.event.Event; import org.apache.flink.cdc.common.event.SchemaChangeEvent; import org.apache.flink.cdc.common.event.TableId; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.schema.Schema; import org.apache.flink.cdc.common.schema.Selectors; import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; @@ -41,6 +42,7 @@ import org.apache.flink.cdc.runtime.typeutils.BinaryRecordDataGenerator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.shaded.guava31.com.google.common.cache.CacheBuilder; import org.apache.flink.shaded.guava31.com.google.common.cache.CacheLoader; @@ -79,6 +81,9 @@ public class PostTransformOperator extends AbstractStreamOperatorAdapter // Tuple3 items are: function name, class path, and extra options. private final List>> udfFunctions; + // Serializable AI model clients keyed by model name, e.g. myModel. 
+ private final Map modelClients; + private transient List transformers; private transient List udfDescriptors; private transient List udfFunctionInstances; @@ -98,13 +103,15 @@ public static PostTransformOperatorBuilder newBuilder() { PostTransformOperator( List transformRules, String timezone, - List>> udfFunctions) { + List>> udfFunctions, + Map modelClients) { this.timezone = timezone; this.transformRules = transformRules; this.hasAsteriskMap = new HashMap<>(); this.projectedColumnsMap = new HashMap<>(); this.postTransformInfoMap = new ConcurrentHashMap<>(); this.udfFunctions = udfFunctions; + this.modelClients = modelClients; } @Override @@ -115,6 +122,9 @@ public void open() throws Exception { this.projectionProcessors = HashBasedTable.create(); this.filterProcessors = HashBasedTable.create(); + // Initialize AI model clients + initializeAiModelClients(); + // Be sure to initialize UDF related fields before creating transformers initializeUdf(); @@ -136,6 +146,7 @@ public void close() throws Exception { super.close(); TransformExpressionCompiler.cleanUp(); destroyUdf(); + destroyAiModelClients(); } @Override @@ -447,7 +458,8 @@ private TransformProjectionProcessor getProjectionProcessor( timezone, udfDescriptors, udfFunctionInstances, - postTransformer.getSupportedMetadataColumns())); + postTransformer.getSupportedMetadataColumns(), + modelClients)); } return projectionProcessors.get(tableId, postTransformer); } @@ -472,7 +484,8 @@ private TransformFilterProcessor getFilterProcessor( timezone, udfDescriptors, udfFunctionInstances, - postTransformer.getSupportedMetadataColumns())); + postTransformer.getSupportedMetadataColumns(), + modelClients)); } } return filterProcessors.get(tableId, postTransformer); @@ -558,4 +571,28 @@ private void destroyUdf() { udfDescriptors.clear(); udfFunctionInstances.clear(); } + + private void initializeAiModelClients() { + for (Map.Entry entry : modelClients.entrySet()) { + try { + entry.getValue().open(); + 
LOG.info("Successfully opened AI model client '{}'.", entry.getKey()); + } catch (Exception e) { + LOG.error("Failed to open AI model client '{}'.", entry.getKey(), e); + throw new FlinkRuntimeException( + "Failed to initialize AI model: " + entry.getKey(), e); + } + } + } + + private void destroyAiModelClients() { + for (Map.Entry entry : modelClients.entrySet()) { + try { + entry.getValue().close(); + LOG.info("Successfully closed AI model client '{}'.", entry.getKey()); + } catch (Exception e) { + LOG.warn("Failed to close AI model client '{}'.", entry.getKey(), e); + } + } + } } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java index 380d9343a74..be2e8823b84 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java @@ -18,6 +18,7 @@ package org.apache.flink.cdc.runtime.operators.transform; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.pipeline.PipelineOptions; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; @@ -25,6 +26,7 @@ import java.time.ZoneId; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -34,6 +36,7 @@ public class PostTransformOperatorBuilder { private String timezone; private final List>> udfFunctions = new ArrayList<>(); + private final Map modelClients = new LinkedHashMap<>(); public PostTransformOperatorBuilder addTransform( String tableInclusions, @@ -111,7 +114,12 @@ public PostTransformOperatorBuilder addUdfFunctions( return this; } + public PostTransformOperatorBuilder addModelClients(Map 
clients) { + this.modelClients.putAll(clients); + return this; + } + public PostTransformOperator build() { - return new PostTransformOperator(transformRules, timezone, udfFunctions); + return new PostTransformOperator(transformRules, timezone, udfFunctions, modelClients); } } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java index db5c37ca258..2dc86b37bea 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java @@ -18,6 +18,7 @@ package org.apache.flink.cdc.runtime.operators.transform; import org.apache.flink.cdc.common.converter.JavaClassConverter; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.schema.Column; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.runtime.parser.JaninoCompiler; @@ -26,6 +27,7 @@ import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -45,6 +47,7 @@ public class ProjectionColumnProcessor { private final TransformExpressionKey transformExpressionKey; private final Map supportedMetadataColumns; private final List udfFunctionInstances; + private final Map modelClients; private final ExpressionEvaluator expressionEvaluator; public ProjectionColumnProcessor( @@ -53,15 +56,17 @@ public ProjectionColumnProcessor( String timezone, List udfDescriptors, final List udfFunctionInstances, - Map supportedMetadataColumns) { + Map supportedMetadataColumns, + Map modelClients) { this.tableInfo = tableInfo; this.projectionColumn = projectionColumn; this.timezone = 
timezone; this.supportedMetadataColumns = supportedMetadataColumns; + this.modelClients = modelClients; this.transformExpressionKey = generateTransformExpressionKey(); this.expressionEvaluator = TransformExpressionCompiler.compileExpression( - transformExpressionKey, udfDescriptors); + transformExpressionKey, udfDescriptors, modelClients); this.udfFunctionInstances = udfFunctionInstances; } @@ -72,13 +77,32 @@ public static ProjectionColumnProcessor of( List udfDescriptors, List udfFunctionInstances, Map supportedMetadataColumns) { + return of( + tableInfo, + projectionColumn, + timezone, + udfDescriptors, + udfFunctionInstances, + supportedMetadataColumns, + Collections.emptyMap()); + } + + public static ProjectionColumnProcessor of( + PostTransformChangeInfo tableInfo, + ProjectionColumn projectionColumn, + String timezone, + List udfDescriptors, + List udfFunctionInstances, + Map supportedMetadataColumns, + Map modelClients) { return new ProjectionColumnProcessor( tableInfo, projectionColumn, timezone, udfDescriptors, udfFunctionInstances, - supportedMetadataColumns); + supportedMetadataColumns, + modelClients); } public Object evaluate(Object[] rowData, TransformContext context) { @@ -123,6 +147,9 @@ private Object[] generateParams(Object[] rowData, TransformContext context) { // 3 - Add UDF function instances params.addAll(udfFunctionInstances); + + // 4 - Add AI model client instances + params.addAll(modelClients.values()); return params.toArray(); } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java index 1085a7df82d..776b4af2d0c 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java +++ 
b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java @@ -18,6 +18,7 @@ package org.apache.flink.cdc.runtime.operators.transform; import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.runtime.operators.transform.exceptions.TransformException; import org.apache.flink.util.FlinkRuntimeException; @@ -30,7 +31,9 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; /** * The processor of the transform expression. It processes the expression of projections and @@ -54,6 +57,20 @@ public static void cleanUp() { /** Compiles an expression code to a janino {@link ExpressionEvaluator}. */ public static ExpressionEvaluator compileExpression( TransformExpressionKey key, List udfDescriptors) { + return compileExpression(key, udfDescriptors, Collections.emptyMap()); + } + + /** + * Compiles an expression code to a janino {@link ExpressionEvaluator}, with additional {@link + * AiModelClient} instances appended after UDF instances. + * + *

{@code modelClients} maps model names (e.g. {@code myModel}) to the corresponding client + * instances. + */ + public static ExpressionEvaluator compileExpression( + TransformExpressionKey key, + List udfDescriptors, + Map modelClients) { try { return COMPILED_EXPRESSION_CACHE.get( key, @@ -68,6 +85,11 @@ public static ExpressionEvaluator compileExpression( argumentClasses.add(Class.forName(udfFunction.getClasspath())); } + for (String paramName : modelClients.keySet()) { + argumentNames.add(paramName); + argumentClasses.add(AiModelClient.class); + } + // Input args expressionEvaluator.setParameters( argumentNames.toArray(new String[0]), diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java index 77dc5cfe6ef..a72e7aa94e6 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java @@ -19,6 +19,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.cdc.common.converter.JavaClassConverter; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.schema.Column; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.runtime.parser.JaninoCompiler; @@ -47,6 +48,7 @@ public class TransformFilterProcessor { private final String timezone; private final List udfFunctionInstances; private final Map supportedMetadataColumns; + private final Map modelClients; private final TransformExpressionKey transformExpressionKey; private final ExpressionEvaluator expressionEvaluator; @@ -58,13 +60,15 @@ protected TransformFilterProcessor( String timezone, List udfDescriptors, List udfFunctionInstances, - Map supportedMetadataColumns) { + 
Map supportedMetadataColumns, + Map modelClients) { this.isNoOp = isNoOp; this.tableInfo = tableInfo; this.transformFilter = transformFilter; this.timezone = timezone; this.udfFunctionInstances = udfFunctionInstances; this.supportedMetadataColumns = supportedMetadataColumns; + this.modelClients = modelClients; if (isNoOp) { this.transformExpressionKey = null; @@ -79,12 +83,12 @@ protected TransformFilterProcessor( .toArray(new SupportedMetadataColumn[0])); this.expressionEvaluator = TransformExpressionCompiler.compileExpression( - transformExpressionKey, udfDescriptors); + transformExpressionKey, udfDescriptors, modelClients); } } public static TransformFilterProcessor ofNoOp() { - return new TransformFilterProcessor(true, null, null, null, null, null, null); + return new TransformFilterProcessor(true, null, null, null, null, null, null, null); } public static TransformFilterProcessor of( @@ -93,7 +97,8 @@ public static TransformFilterProcessor of( String timezone, List udfDescriptors, List udfFunctionInstances, - SupportedMetadataColumn[] supportedMetadataColumns) { + SupportedMetadataColumn[] supportedMetadataColumns, + Map modelClients) { Map supportedMetadataColumnsMap = new HashMap<>(); for (SupportedMetadataColumn supportedMetadataColumn : supportedMetadataColumns) { supportedMetadataColumnsMap.put( @@ -106,7 +111,8 @@ public static TransformFilterProcessor of( timezone, udfDescriptors, udfFunctionInstances, - supportedMetadataColumnsMap); + supportedMetadataColumnsMap, + modelClients); } public boolean test(Object[] preRow, Object[] postRow, TransformContext context) { @@ -208,6 +214,9 @@ private Object[] generateParams(Object[] preRow, Object[] postRow, TransformCont // 3 - Add UDF function instances params.addAll(udfFunctionInstances); + + // 4 - Add AI model client instances + params.addAll(modelClients.values()); return params.toArray(); } diff --git 
a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java index 231e72875ac..09fe5fb19db 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java @@ -17,6 +17,7 @@ package org.apache.flink.cdc.runtime.operators.transform; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.common.utils.Preconditions; import org.apache.flink.cdc.runtime.parser.TransformParser; @@ -52,6 +53,7 @@ public class TransformProjectionProcessor { private final List columnProcessors; private final SupportedMetadataColumn[] supportedMetadataColumns; private final Map supportedMetadataColumnsMap; + private final Map modelClients; public TransformProjectionProcessor( PostTransformChangeInfo changeInfo, @@ -59,13 +61,15 @@ public TransformProjectionProcessor( String timezone, List udfDescriptors, List udfFunctionInstances, - SupportedMetadataColumn[] supportedMetadataColumns) { + SupportedMetadataColumn[] supportedMetadataColumns, + Map modelClients) { this.changeInfo = changeInfo; this.projectionExpression = projectionExpression; this.timezone = timezone; this.udfDescriptors = udfDescriptors; this.udfFunctionInstances = udfFunctionInstances; this.supportedMetadataColumns = supportedMetadataColumns; + this.modelClients = modelClients; // Construct a mapping table ad-hoc to accelerate looking-up Map supportedMetadataColumnsMap = new HashMap<>(); @@ -105,7 +109,8 @@ private List createProjectionColumnProcessors() { timezone, udfDescriptors, udfFunctionInstances, - supportedMetadataColumnsMap)) + supportedMetadataColumnsMap, + 
modelClients)) .collect(Collectors.toList()); LOG.info("Successfully created projection column processors cache."); diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java index 14c483d7082..3093231e2c8 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java @@ -27,6 +27,8 @@ import org.apache.flink.cdc.common.types.DataTypeRoot; import org.apache.flink.cdc.common.utils.Preconditions; import org.apache.flink.cdc.common.utils.StringUtils; +import org.apache.flink.cdc.runtime.ai.AiEmbeddingFunctionDef; +import org.apache.flink.cdc.runtime.ai.AiTextFunctionDef; import org.apache.flink.cdc.runtime.operators.transform.UserDefinedFunctionDescriptor; import org.apache.calcite.sql.SqlBasicCall; @@ -90,7 +92,7 @@ public class JaninoCompiler { public static final String DEFAULT_TIME_ZONE = "__time_zone__"; private static final String[] BUILTIN_FUNCTION_MODULES = { - "Arithmetic", "Casting", "Comparison", "Logical", "String", "Struct", "Temporal" + "Ai", "Arithmetic", "Casting", "Comparison", "Logical", "String", "Struct", "Temporal" }; @VisibleForTesting @@ -526,23 +528,44 @@ private static Java.Rvalue generateOtherFunctionOperation( context.udfDescriptors.stream() .filter(e -> e.getName().equalsIgnoreCase(operationName)) .findFirst(); - return udfFunctionOptional - .map( - udfFunction -> - new Java.MethodInvocation( - Location.NOWHERE, - null, - generateInvokeExpression(udfFunction), - atoms)) - .orElseGet( - () -> - new Java.MethodInvocation( - Location.NOWHERE, - null, - StringUtils.convertToCamelCase( - sqlBasicCall.getOperator().getName()), - atoms)); + if (udfFunctionOptional.isPresent()) { + return new Java.MethodInvocation( + Location.NOWHERE, + null, + 
generateInvokeExpression(udfFunctionOptional.get()), + atoms); + } + if (isAiFunction(operationName) && atoms.length >= 1) { + rewriteAiFunctionModelArg(atoms); + } + return new Java.MethodInvocation( + Location.NOWHERE, + null, + StringUtils.convertToCamelCase(sqlBasicCall.getOperator().getName()), + atoms); + } + } + + private static boolean isAiFunction(String upperCaseName) { + for (AiTextFunctionDef def : AiTextFunctionDef.values()) { + if (def.getFunctionName().equals(upperCaseName)) { + return true; + } + } + for (AiEmbeddingFunctionDef def : AiEmbeddingFunctionDef.values()) { + if (def.getFunctionName().equals(upperCaseName)) { + return true; + } + } + return false; + } + + private static void rewriteAiFunctionModelArg(Java.Rvalue[] atoms) { + String modelName = atoms[0].toString(); + if (modelName.startsWith("\"") && modelName.endsWith("\"")) { + modelName = modelName.substring(1, modelName.length() - 1); } + atoms[0] = new Java.AmbiguousName(Location.NOWHERE, new String[] {modelName}); } private static Java.Rvalue generateTimezoneFreeTemporalFunctionOperation( diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java index 5ee7153013c..115dfedd7f2 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java @@ -24,6 +24,7 @@ import org.apache.flink.cdc.common.utils.Preconditions; import org.apache.flink.cdc.runtime.operators.transform.ProjectionColumn; import org.apache.flink.cdc.runtime.operators.transform.UserDefinedFunctionDescriptor; +import org.apache.flink.cdc.runtime.parser.metadata.AiFunctionSqlOperatorTable; import org.apache.flink.cdc.runtime.parser.metadata.TransformSchemaFactory; import org.apache.flink.cdc.runtime.parser.metadata.TransformSqlOperatorTable; import 
org.apache.flink.cdc.runtime.typeutils.CalciteDataTypeConverter; @@ -163,9 +164,13 @@ private static RelNode sqlToRel( new CalciteConnectionConfigImpl(new Properties())); TransformSqlOperatorTable transformSqlOperatorTable = TransformSqlOperatorTable.instance(); SqlOperatorTable udfOperatorTable = SqlOperatorTables.of(udfFunctions); + SqlOperatorTable aiFunctionOperatorTable = AiFunctionSqlOperatorTable.create(); SqlValidator validator = SqlValidatorUtil.newValidator( - SqlOperatorTables.chain(transformSqlOperatorTable, udfOperatorTable), + SqlOperatorTables.chain( + transformSqlOperatorTable, + udfOperatorTable, + aiFunctionOperatorTable), calciteCatalogReader, factory, SqlValidator.Config.DEFAULT diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java new file mode 100644 index 00000000000..69628c5d327 --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.runtime.parser.metadata; + +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.RowType; +import org.apache.flink.cdc.runtime.ai.AiEmbeddingFunctionDef; +import org.apache.flink.cdc.runtime.ai.AiTextFunctionDef; +import org.apache.flink.cdc.runtime.typeutils.CalciteDataTypeConverter; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.util.SqlOperatorTables; + +import java.util.ArrayList; +import java.util.List; + +/** Creates SqlOperatorTable from {@link AiTextFunctionDef} definitions. */ +public class AiFunctionSqlOperatorTable { + + private AiFunctionSqlOperatorTable() {} + + /** Creates an SqlOperatorTable containing all AI functions defined in AiFunctionDef. 
*/ + public static org.apache.calcite.sql.SqlOperatorTable create() { + List functions = new ArrayList<>(); + for (AiTextFunctionDef def : AiTextFunctionDef.values()) { + functions.add(createTextSqlFunction(def)); + } + for (AiEmbeddingFunctionDef def : AiEmbeddingFunctionDef.values()) { + functions.add(createEmbeddingSqlFunction(def)); + } + return SqlOperatorTables.of(functions); + } + + private static SqlFunction createTextSqlFunction(AiTextFunctionDef def) { + return new SqlFunction( + def.getFunctionName(), + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.VARIANT), + null, + OperandTypes.family(toSqlTypeFamiliesWithAdditionalParams(def.getInputType())), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } + + private static SqlFunction createEmbeddingSqlFunction(AiEmbeddingFunctionDef def) { + return new SqlFunction( + def.getFunctionName(), + SqlKind.OTHER_FUNCTION, + opBinding -> + CalciteDataTypeConverter.convertCalciteType( + opBinding.getTypeFactory(), def.getOutputType()), + null, + OperandTypes.family(SqlTypeFamily.STRING, toSqlTypeFamily(def.getInputType())), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } + + /** + * Converts inputType to SqlTypeFamily array, prepending additional parameters: modelName + * (STRING) and input (STRING). 
+ */ + private static SqlTypeFamily[] toSqlTypeFamiliesWithAdditionalParams(RowType inputType) { + List families = new ArrayList<>(); + families.add(SqlTypeFamily.STRING); // modelName + families.add(SqlTypeFamily.STRING); // input + for (DataType fieldType : inputType.getFieldTypes()) { + families.add(toSqlTypeFamily(fieldType)); + } + return families.toArray(new SqlTypeFamily[0]); + } + + private static SqlTypeFamily toSqlTypeFamily(DataType dataType) { + switch (dataType.getTypeRoot()) { + case VARCHAR: + case CHAR: + return SqlTypeFamily.STRING; + case INTEGER: + return SqlTypeFamily.INTEGER; + case BIGINT: + return SqlTypeFamily.NUMERIC; + case FLOAT: + case DOUBLE: + return SqlTypeFamily.APPROXIMATE_NUMERIC; + case BOOLEAN: + return SqlTypeFamily.BOOLEAN; + default: + return SqlTypeFamily.ANY; + } + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java index eef0755c112..26408270ccc 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java @@ -406,38 +406,6 @@ public SqlSyntax getSyntax() { // Supports accessing elements of ARRAY[index], ROW[index], MAP[key], and VARIANT[index/key] public static final SqlOperator ITEM = new VariantAwareItemOperator(); - public static final SqlFunction AI_CHAT_PREDICT = - new SqlFunction( - "AI_CHAT_PREDICT", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(SqlTypeName.VARCHAR), - null, - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), - SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // Define the AI_EMBEDDING function - public static final SqlFunction GET_EMBEDDING = - new SqlFunction( - "GET_EMBEDDING", - SqlKind.OTHER_FUNCTION, - 
ReturnTypes.explicit(SqlTypeName.VARCHAR), - null, - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), - SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // Define the AI_LANGCHAIN_PREDICT function - public static final SqlFunction AI_LANGCHAIN_PREDICT = - new SqlFunction( - "AI_LANGCHAIN_PREDICT", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(SqlTypeName.VARCHAR), - null, - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), - SqlFunctionCategory.USER_DEFINED_FUNCTION); - // -------------------------------------------------------------------------------------------- // Variant functions // -------------------------------------------------------------------------------------------- diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java new file mode 100644 index 00000000000..01498675aad --- /dev/null +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.runtime.functions.impl; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Unit tests for {@link AiFunctions}. */ +class AiFunctionsTest { + + private static class MockModelClient + implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + private static final long serialVersionUID = 1L; + + @Override + public String generate(String systemPrompt, String userInput) { + if (systemPrompt.contains("summarization expert")) { + return "{\"summary\": \"This is a summary.\"}"; + } + // Default for AI_COMPLETE + return "{\"result\": \"hello world\"}"; + } + + @Override + public float[] embed(String text) { + return new float[] {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + } + } + + private static class UselessClient implements AiModelClient {} + + @Test + void testAiFunctionInvocation() { + MockModelClient model = new MockModelClient(); + assertThat(AiFunctions.aiComplete(model, "Say hello", "You are a helpful assistant.")) + .hasToString("{\"result\":\"hello world\"}"); + assertThat(AiFunctions.aiSummarize(model, "Long text here...", 100)) + .hasToString("{\"summary\":\"This is a summary.\"}"); + assertThat(AiFunctions.aiEmbed(model, "Test text")) + .containsExactly(0.1f, 0.2f, 0.3f, 0.4f, 0.5f); + } + + @Test + void testUnsupportedModel() { + UselessClient model = new UselessClient(); + assertThatThrownBy(() -> AiFunctions.aiComplete(model, "test", "prompt")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support text generation"); + assertThatThrownBy(() -> AiFunctions.aiSummarize(model, "test", 100)) + .isInstanceOf(UnsupportedOperationException.class) + 
.hasMessageContaining("does not support text generation"); + assertThatThrownBy(() -> AiFunctions.aiEmbed(model, "test")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support embedding"); + } + + @Test + void testAiCompleteWithInvalidJsonResponse() { + MockModelClient model = + new MockModelClient() { + @Override + public String generate(String systemPrompt, String userInput) { + return "invalid json"; + } + }; + + assertThatThrownBy(() -> AiFunctions.aiComplete(model, "test", "prompt")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Failed to parse AI response as JSON"); + } + + @Test + void testAiCompleteWithEmptyJsonResponse() { + MockModelClient model = + new MockModelClient() { + @Override + public String generate(String systemPrompt, String userInput) { + return "{}"; + } + }; + assertThat(AiFunctions.aiComplete(model, "test", "prompt")).hasToString("{}"); + } + + @Test + void testAiCompleteWithNullResponse() { + MockModelClient model = + new MockModelClient() { + @Override + public String generate(String systemPrompt, String userInput) { + return null; + } + }; + assertThat(AiFunctions.aiComplete(model, "test", "prompt")).isNull(); + assertThat(AiFunctions.aiSummarize(model, "test", 100)).isNull(); + } + + @Test + void testAiFunctionCornerCase() { + MockModelClient model = new MockModelClient(); + assertThat(AiFunctions.aiEmbed(model, "")).containsExactly(0.1f, 0.2f, 0.3f, 0.4f, 0.5f); + assertThat(AiFunctions.aiEmbed(model, "中文嵌入测试")) + .containsExactly(0.1f, 0.2f, 0.3f, 0.4f, 0.5f); + } +} diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java index 79f9c92af5b..bfd1bb098d4 100644 --- 
a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java @@ -20,10 +20,8 @@ import org.apache.flink.cdc.common.types.DataType; import org.apache.flink.cdc.common.types.DataTypes; import org.apache.flink.cdc.common.udf.UserDefinedFunction; -import org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel; import org.apache.flink.table.functions.ScalarFunction; -import com.fasterxml.jackson.core.JsonProcessingException; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -50,7 +48,7 @@ public static class FlinkUdf extends ScalarFunction {} public static class NotUDF {} @Test - void testUserDefinedFunctionDescriptor() throws JsonProcessingException { + void testUserDefinedFunctionDescriptor() { assertThat(new UserDefinedFunctionDescriptor("cdc_udf", CdcUdf.class.getName())) .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf") @@ -95,14 +93,5 @@ void testUserDefinedFunctionDescriptor() throws JsonProcessingException { "not_even_exist", "not.a.valid.class.path")) .isExactlyInstanceOf(IllegalArgumentException.class) .hasMessage("Failed to instantiate UDF not_even_exist@not.a.valid.class.path"); - String name = "GET_EMBEDDING"; - assertThat(new UserDefinedFunctionDescriptor(name, OpenAIEmbeddingModel.class.getName())) - .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf") - .containsExactly( - "GET_EMBEDDING", - "OpenAIEmbeddingModel", - "org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel", - DataTypes.ARRAY(DataTypes.FLOAT()), - true); } } diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java new file mode 100644 index 
00000000000..af1d225e665 --- /dev/null +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.parser; + +import org.apache.flink.cdc.common.schema.Column; +import org.apache.flink.cdc.common.source.SupportedMetadataColumn; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.runtime.operators.transform.ProjectionColumn; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Unit tests for {@link TransformParser} and {@link JaninoCompiler} with AI functions. 
*/ +class AiFunctionParserTest { + + private static final List DUMMY_COLUMNS = + List.of( + Column.physicalColumn("id", DataTypes.INT()), + Column.physicalColumn("content", DataTypes.STRING()), + Column.physicalColumn("text", DataTypes.STRING()), + Column.physicalColumn("flag", DataTypes.BOOLEAN())); + private static final Map COLUMN_MAP = + Map.of( + "id", "$1", + "content", "$2", + "text", "$3", + "flag", "$4"); + + @Test + void testTranslateAiFunctionInProjection() { + assertThat( + translateAsProjection( + "id, AI_COMPLETE('myModel', content, 'You are a classifier') AS output_col")) + .map(ProjectionColumn::toString) + .containsExactly( + "ProjectionColumn{column=`id` INT, expression='id', scriptExpression='$0', originalColumnNames=[id], columnNameMap={id=$0}}", + "ProjectionColumn{column=`output_col` VARIANT, expression='AI_COMPLETE('myModel', `TB`.`content`, 'You are a classifier')', scriptExpression='aiComplete(myModel, $0, \"You are a classifier\")', originalColumnNames=[content], columnNameMap={content=$0}}"); + + assertThat(translateAsProjection("id, AI_SUMMARIZE('m1', content, 100) AS summary")) + .map(ProjectionColumn::toString) + .containsExactly( + "ProjectionColumn{column=`id` INT, expression='id', scriptExpression='$0', originalColumnNames=[id], columnNameMap={id=$0}}", + "ProjectionColumn{column=`summary` VARIANT, expression='AI_SUMMARIZE('m1', `TB`.`content`, 100)', scriptExpression='aiSummarize(m1, $0, 100)', originalColumnNames=[content], columnNameMap={content=$0}}"); + + assertThatThrownBy( + () -> translateAsProjection("AI_SUMMARIZE('m', content, flag) AS out_col")) + .hasMessageContaining( + "Cannot apply 'AI_SUMMARIZE' to arguments of type 'AI_SUMMARIZE(, , )'.") + .hasMessageContaining( + "Supported form(s): 'AI_SUMMARIZE(, , )'"); + + assertThatThrownBy(() -> translateAsProjection("AI_EMBED('m') AS out_col")) + .hasMessageContaining("Invalid number of arguments to function 'AI_EMBED'.") + .hasMessageContaining("Was expecting 2 
arguments"); + + assertThatThrownBy(() -> translateAsProjection("AI_COMPLETE('m', content) AS out_col")) + .hasMessageContaining("Invalid number of arguments to function 'AI_COMPLETE'.") + .hasMessageContaining("Was expecting 3 arguments"); + } + + @Test + void testTranslateAiFunctionInFilter() { + assertThat(translateAsFilter("AI_COMPLETE('myModel', content, 'Classify this text')")) + .isEqualTo("aiComplete(myModel, $2, \"Classify this text\")"); + assertThat(translateAsFilter("AI_SUMMARIZE('summarizer', content, 100)")) + .isEqualTo("aiSummarize(summarizer, $2, 100)"); + assertThat(translateAsFilter("AI_EMBED('embedder', content)")) + .isEqualTo("aiEmbed(embedder, $2)"); + assertThat(translateAsFilter("ai_complete('myModel', content, 'prompt')")) + .isEqualTo("aiComplete(myModel, $2, \"prompt\")"); + } + + private List translateAsProjection(String expression) { + return TransformParser.generateProjectionColumns( + expression, DUMMY_COLUMNS, Collections.emptyList(), new SupportedMetadataColumn[0]); + } + + private String translateAsFilter(String expression) { + return TransformParser.translateFilterExpressionToJaninoExpression( + expression, + DUMMY_COLUMNS, + Collections.emptyList(), + new SupportedMetadataColumn[0], + COLUMN_MAP); + } +} From 9ac7e5a18e4ba6126b376f85b68575053ff75ea0 Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Tue, 12 May 2026 16:09:07 +0800 Subject: [PATCH 02/11] [FLINK-39568-2] Add openai-compatible model provider --- .../flink-cdc-pipeline-e2e-tests/pom.xml | 22 +++ .../pipeline/tests/AiFunctionE2eITCase.java | 133 +++++++++--------- .../pom.xml | 88 ++++++++++++ .../openai/OpenAiCompatibleModelClient.java | 90 ++++++++++++ .../OpenAiCompatibleModelClientFactory.java | 51 +++++++ ...link.cdc.common.model.AiModelClientFactory | 16 +++ ...penAiCompatibleModelClientFactoryTest.java | 118 ++++++++++++++++ .../OpenAiCompatibleModelClientITCase.java | 63 +++++++++ 
.../src/test/resources/log4j2-test.properties | 25 ++++ flink-cdc-pipeline-model/pom.xml | 1 + 10 files changed, 544 insertions(+), 63 deletions(-) create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml index 6c1aa689fec..733bf3cb77d 100644 --- a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml @@ -229,6 +229,18 @@ limitations under the License. ${project.version} test + + org.apache.flink + flink-cdc-pipeline-model-dummy + ${project.version} + test + + + org.apache.flink + flink-cdc-pipeline-model-openai-compatible + ${project.version} + test + @@ -553,6 +565,16 @@ limitations under the License. 
+ + org.apache.flink + flink-cdc-pipeline-model-openai-compatible + ${project.version} + openai-compatible-model.jar + jar + ${project.build.directory}/dependencies + + + org.apache.flink flink-cdc-pipeline-connector-mysql diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java index c25e72df103..c3b21bed1b9 100644 --- a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java @@ -20,20 +20,15 @@ import org.apache.flink.cdc.common.test.utils.TestUtils; import org.apache.flink.cdc.pipeline.tests.utils.PipelineTestEnvironment; +import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.Test; import java.nio.file.Path; import java.time.Duration; -/** E2e tests for AI functions with the dummy model SPI. */ +/** E2e tests for AI functions with the dummy model and openai-compatible model. 
*/ class AiFunctionE2eITCase extends PipelineTestEnvironment { - private static final String TABLE_1 = "default_namespace.default_schema.table1"; - private static final String TABLE_2 = "default_namespace.default_schema.table2"; - private static final String DUMMY_JSON = - "{\"result\":\"dummy result\",\"summary\":\"TL;DR\"}"; - private static final String DUMMY_EMBEDDING = "[3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]"; - @Test void testAiFunctionsWithDummyModel() throws Exception { String pipelineJob = @@ -45,13 +40,9 @@ void testAiFunctionsWithDummyModel() throws Exception { + " type: values\n" + "\n" + "transform:\n" - + " - source-table: " - + TABLE_1 - + "\n" + + " - source-table: default_namespace.default_schema.table1\n" + " projection: col1, AI_COMPLETE('myModel', col1, 'Classify into catA or catB') AS cls\n" - + " - source-table: " - + TABLE_2 - + "\n" + + " - source-table: default_namespace.default_schema.table2\n" + " projection: col1, AI_EMBED('myModel', col1) AS embedding\n" + "\n" + "pipeline:\n" @@ -66,58 +57,74 @@ void testAiFunctionsWithDummyModel() throws Exception { submitPipelineJob(pipelineJob, dummyModelJar); waitUntilJobFinished(Duration.ofMinutes(3)); - validateResult("Successfully opened AI model client 'myModel'."); validateResult( - "CreateTableEvent{tableId=" - + TABLE_1 - + ", schema=columns={`col1` STRING NOT NULL,`cls` VARIANT}, primaryKeys=col1, options=()}", - "DataChangeEvent{tableId=" - + TABLE_1 - + ", before=[], after=[1, " - + DUMMY_JSON - + "], op=INSERT, meta=()}", - "DataChangeEvent{tableId=" - + TABLE_1 - + ", before=[], after=[2, " - + DUMMY_JSON - + "], op=INSERT, meta=()}", - "DataChangeEvent{tableId=" - + TABLE_1 - + ", before=[], after=[3, " - + DUMMY_JSON - + "], op=INSERT, meta=()}", - "DataChangeEvent{tableId=" - + TABLE_1 - + ", before=[1, " - + DUMMY_JSON - + "], after=[], op=DELETE, meta=()}", - "DataChangeEvent{tableId=" - + TABLE_1 - + ", before=[2, " - + DUMMY_JSON - + "], after=[2, " - + DUMMY_JSON - + "], 
op=UPDATE, meta=()}"); + "Successfully opened AI model client 'myModel'.", + "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING NOT NULL,`cls` VARIANT}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[1, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[2, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[3, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[1, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], after=[], op=DELETE, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], after=[2, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=UPDATE, meta=()}", + "CreateTableEvent{tableId=default_namespace.default_schema.table2, schema=columns={`col1` STRING NOT NULL,`embedding` ARRAY}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[1, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[2, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[3, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}", + "Successfully closed AI model client 'myModel'."); + } + + @Test + void testAiFunctionsWithOpenAiCompatibleModel() throws Exception { + String endpoint = System.getenv("OPENAI_BASE_URL"); + String apiKey = System.getenv("OPENAI_API_KEY"); + String model = 
System.getenv("OPENAI_MODEL"); + Assumptions.assumeThat(endpoint != null && apiKey != null && model != null) + .as("OPENAI_BASE_URL, OPENAI_API_KEY and OPENAI_MODEL must be set") + .isTrue(); + + String pipelineJob = + "source:\n" + + " type: values\n" + + " event-set.id: SINGLE_SPLIT_MULTI_TABLES\n" + + "\n" + + "sink:\n" + + " type: values\n" + + "\n" + + "transform:\n" + + " - source-table: default_namespace.default_schema.table1\n" + + " projection: col1, AI_COMPLETE('openaiModel', col1, 'Reply only the negative value of input number') AS reversed\n" + + " - source-table: default_namespace.default_schema.table2\n" + + " projection: col1, AI_SUMMARIZE('openaiModel', col1, 50) AS summary\n" + + "\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " schema.change.behavior: evolve\n" + + " model:\n" + + " name: openaiModel\n" + + " type: openai-compatible\n" + + " endpoint: " + + endpoint + + "\n" + + " api-key: " + + apiKey + + "\n" + + " model-name: " + + model + + "\n"; + + Path openaiModelJar = TestUtils.getResource("openai-compatible-model.jar"); + submitPipelineJob(pipelineJob, openaiModelJar); + waitUntilJobFinished(Duration.ofMinutes(5)); validateResult( - "CreateTableEvent{tableId=" - + TABLE_2 - + ", schema=columns={`col1` STRING NOT NULL,`embedding` ARRAY}, primaryKeys=col1, options=()}", - "DataChangeEvent{tableId=" - + TABLE_2 - + ", before=[], after=[1, " - + DUMMY_EMBEDDING - + "], op=INSERT, meta=()}", - "DataChangeEvent{tableId=" - + TABLE_2 - + ", before=[], after=[2, " - + DUMMY_EMBEDDING - + "], op=INSERT, meta=()}", - "DataChangeEvent{tableId=" - + TABLE_2 - + ", before=[], after=[3, " - + DUMMY_EMBEDDING - + "], op=INSERT, meta=()}"); - validateResult("Successfully closed AI model client 'myModel'."); + "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING NOT NULL,`reversed` VARIANT}, primaryKeys=col1, options=()}", + "CreateTableEvent{tableId=default_namespace.default_schema.table2, 
schema=columns={`col1` STRING NOT NULL,`summary` VARIANT}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[1, {\"result\":-1}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[2, {\"result\":-2}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[3, {\"result\":-3}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[1, {\"summary\":\"1\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[2, {\"summary\":\"2\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[3, {\"summary\":\"3\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[1, {\"result\":-1}], after=[], op=DELETE, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, {\"result\":-2}], after=[2, {\"result\":-2}], op=UPDATE, meta=()}"); } } diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml new file mode 100644 index 00000000000..53ba797441a --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml @@ -0,0 +1,88 @@ + + + + + + org.apache.flink + flink-cdc-pipeline-model + ${revision} + + + 4.0.0 + + flink-cdc-pipeline-model-openai-compatible + jar + + + 2.13.4 + + + + + com.openai + openai-java + 4.32.0 + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + + + + *:* + + + + + *:* + + META-INF/services/com.fasterxml.** + META-INF/services/kotlin.** + + + + + + com.fasterxml + org.apache.flink.cdc.models.openai.shaded.com.fasterxml + + + okhttp3 + 
org.apache.flink.cdc.models.openai.shaded.okhttp3 + + + okio + org.apache.flink.cdc.models.openai.shaded.okio + + + + + + + + + diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java new file mode 100644 index 00000000000..5aed57a35b3 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.embeddings.CreateEmbeddingResponse; +import com.openai.models.embeddings.Embedding; +import com.openai.models.embeddings.EmbeddingCreateParams; + +import java.util.List; + +/** AI model client that connects to any OpenAI-compatible endpoint. */ +public class OpenAiCompatibleModelClient + implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + + private static final long serialVersionUID = 1L; + + private final String endpoint; + private final String apiKey; + private final String modelName; + + private transient OpenAIClient client; + + public OpenAiCompatibleModelClient(String endpoint, String apiKey, String modelName) { + this.endpoint = endpoint; + this.apiKey = apiKey; + this.modelName = modelName; + } + + @Override + public void open() { + client = OpenAIOkHttpClient.builder().baseUrl(endpoint).apiKey(apiKey).build(); + } + + @Override + public void close() { + client = null; + } + + @Override + public String generate(String systemPrompt, String userInput) { + ChatCompletionCreateParams params = + ChatCompletionCreateParams.builder() + .model(modelName) + .addSystemMessage(systemPrompt) + .addUserMessage(userInput) + .build(); + ChatCompletion completion = client.chat().completions().create(params); + return completion.choices().get(0).message().content().orElse(null); + } + + @Override + public float[] embed(String text) { + EmbeddingCreateParams params = + EmbeddingCreateParams.builder().model(modelName).input(text).build(); + 
CreateEmbeddingResponse response = client.embeddings().create(params); + List data = response.data(); + if (data.isEmpty()) { + return new float[0]; + } + List embedding = data.get(0).embedding(); + float[] result = new float[embedding.size()]; + for (int i = 0; i < result.length; i++) { + result[i] = embedding.get(i); + } + return result; + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java new file mode 100644 index 00000000000..c0050dc03dc --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.AiModelClientFactory; +import org.apache.flink.cdc.common.model.ModelContext; + +import java.util.Set; + +/** SPI factory for {@link OpenAiCompatibleModelClient}. */ +public class OpenAiCompatibleModelClientFactory implements AiModelClientFactory { + + @Override + public String identifier() { + return "openai-compatible"; + } + + @Override + public Set requiredOptions() { + return Set.of("endpoint", "api-key", "model-name"); + } + + @Override + public Set optionalOptions() { + return Set.of(); + } + + @Override + public AiModelClient createClient(ModelContext context) { + String endpoint = context.getOptions().get("endpoint"); + String apiKey = context.getOptions().get("api-key"); + String modelName = context.getOptions().get("model-name"); + return new OpenAiCompatibleModelClient(endpoint, apiKey, modelName); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory new file mode 100644 index 00000000000..017156db45f --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.cdc.models.openai.OpenAiCompatibleModelClientFactory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java new file mode 100644 index 00000000000..585ab25adec --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.ModelContext; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class OpenAiCompatibleModelClientFactoryTest { + + private final OpenAiCompatibleModelClientFactory factory = + new OpenAiCompatibleModelClientFactory(); + + private ModelContext contextWithOptions(Map options) { + return new ModelContext() { + @Override + public String getModelName() { + return "test-model"; + } + + @Override + public Map getOptions() { + return options; + } + + @Override + public ClassLoader getClassLoader() { + return Thread.currentThread().getContextClassLoader(); + } + }; + } + + @Test + void testIdentifier() { + assertThat(factory.identifier()).isEqualTo("openai-compatible"); + } + + @Test + void testRequiredOptions() { + assertThat(factory.requiredOptions()) + .containsExactlyInAnyOrder("endpoint", "api-key", "model-name"); + } + + @Test + void testOptionalOptions() { + assertThat(factory.optionalOptions()).isEmpty(); + } + + @Test + void testCreateClient() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + AiModelClient client = factory.createClient(contextWithOptions(options)); + assertThat(client).isInstanceOf(OpenAiCompatibleModelClient.class); + assertThat(client).isInstanceOf(SupportsTextGeneration.class); + assertThat(client).isInstanceOf(SupportsEmbedding.class); + } + + @Test + void testValidatePassesWithAllRequiredOptions() { + Map options = new HashMap<>(); + options.put("endpoint", 
"https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + factory.validate(contextWithOptions(options)); + } + + @Test + void testValidateThrowsOnMissingRequiredOption() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + + assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Missing required options"); + } + + @Test + void testValidateThrowsOnUnknownOption() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("unknown-key", "value"); + + assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unknown options"); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java new file mode 100644 index 00000000000..33a3bc38934 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class OpenAiCompatibleModelClientITCase { + + private OpenAiCompatibleModelClient client; + + @BeforeEach + void setUp() { + String endpoint = System.getenv("OPENAI_BASE_URL"); + String apiKey = System.getenv("OPENAI_API_KEY"); + String model = System.getenv("OPENAI_MODEL"); + Assumptions.assumeThat(endpoint != null && apiKey != null && model != null) + .as("OPENAI_BASE_URL, OPENAI_API_KEY and OPENAI_MODEL must be set") + .isTrue(); + + client = new OpenAiCompatibleModelClient(endpoint, apiKey, model); + client.open(); + } + + @AfterEach + void tearDown() { + if (client != null) { + client.close(); + } + } + + @Test + void testGenerate() { + String result = + client.generate("You are a calculator.", "What is 1 + 1? 
Answer only the number."); + assertThat(result).isNotNull().contains("2"); + } + + @Test + void testGenerateWithEmptyUserInput() { + String result = client.generate("Reply with exactly: OK", ""); + assertThat(result).isNotNull().contains("OK"); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000000..0d45bab8011 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties @@ -0,0 +1,25 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +# Set root logger level to ERROR to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level = ERROR +rootLogger.appenderRef.test.ref = TestLogger + +appender.testlogger.name = TestLogger +appender.testlogger.type = CONSOLE +appender.testlogger.target = SYSTEM_ERR +appender.testlogger.layout.type = PatternLayout +appender.testlogger.layout.pattern = %-4r [%t] %-5p %c - %m%n diff --git a/flink-cdc-pipeline-model/pom.xml b/flink-cdc-pipeline-model/pom.xml index 29a24da57e9..7cfb2fb53fc 100644 --- a/flink-cdc-pipeline-model/pom.xml +++ b/flink-cdc-pipeline-model/pom.xml @@ -31,6 +31,7 @@ limitations under the License. flink-cdc-pipeline-model-dummy + flink-cdc-pipeline-model-openai-compatible From f34c0f3d31ed4d07765ff1229d55eeea27b6f397 Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Fri, 15 May 2026 10:19:48 +0800 Subject: [PATCH 03/11] use stable option orders --- .../apache/flink/cdc/common/model/AiModelClientFactory.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java index e32d2ed0e09..a3009dfd90c 100644 --- a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java @@ -19,6 +19,7 @@ import org.apache.flink.cdc.common.annotation.Experimental; +import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -64,10 +65,11 @@ default void validate(ModelContext context) { } } if (required != null && optional != null) { - Set unknown = + List unknown = context.getOptions().keySet().stream() .filter(k -> !required.contains(k) && !optional.contains(k)) - 
.collect(Collectors.toSet()); + .sorted() + .collect(Collectors.toList()); if (!unknown.isEmpty()) { throw new IllegalArgumentException( "Unknown options for model '" From 4e475f838579e73ee043a613221dbf425e9a9eef Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Fri, 15 May 2026 11:18:29 +0800 Subject: [PATCH 04/11] enforce stricter checks --- .../cdc/runtime/parser/JaninoCompiler.java | 33 +++++++++++++++++ .../runtime/parser/AiFunctionParserTest.java | 36 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java index 3093231e2c8..518f46d5c24 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java @@ -48,6 +48,8 @@ import org.codehaus.janino.ExpressionEvaluator; import org.codehaus.janino.Java; +import javax.lang.model.SourceVersion; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -195,6 +197,10 @@ private static Java.Rvalue translateSqlSqlLiteral(Context context, SqlLiteral sq } private static Java.Rvalue translateSqlBasicCall(Context context, SqlBasicCall sqlBasicCall) { + String operationName = sqlBasicCall.getOperator().getName().toUpperCase(); + if (isAiFunction(operationName)) { + validateAiFunctionModelArg(sqlBasicCall); + } List operandList = sqlBasicCall.getOperandList(); List atoms = new ArrayList<>(); for (SqlNode sqlNode : operandList) { @@ -560,11 +566,38 @@ private static boolean isAiFunction(String upperCaseName) { return false; } + private static void validateAiFunctionModelArg(SqlBasicCall sqlBasicCall) { + String functionName = sqlBasicCall.getOperator().getName(); + List operandList = sqlBasicCall.getOperandList(); + if (operandList.isEmpty()) { + throw new 
ParseException( + "AI function '" + + functionName + + "' requires the model name as the first argument."); + } + SqlNode first = operandList.get(0); + if (!(first instanceof SqlCharStringLiteral)) { + throw new ParseException( + "The first argument of AI function '" + + functionName + + "' must be a string literal naming the model, but got: " + + first + + "."); + } + } + private static void rewriteAiFunctionModelArg(Java.Rvalue[] atoms) { String modelName = atoms[0].toString(); if (modelName.startsWith("\"") && modelName.endsWith("\"")) { modelName = modelName.substring(1, modelName.length() - 1); } + if (!SourceVersion.isName(modelName)) { + throw new ParseException( + "AI function model name '" + + modelName + + "' is not a valid Java identifier. " + + "Model names must follow Java identifier rules and must not be reserved keywords."); + } atoms[0] = new Java.AmbiguousName(Location.NOWHERE, new String[] {modelName}); } diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java index af1d225e665..e0e6441d7d2 100644 --- a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java @@ -79,6 +79,42 @@ void testTranslateAiFunctionInProjection() { .hasMessageContaining("Was expecting 3 arguments"); } + @Test + void testAiFunctionRejectsNonStringLiteralModelArg() { + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE(content, content, 'prompt') AS out_col")) + .hasMessageContaining( + "The first argument of AI function 'AI_COMPLETE' must be a string literal naming the model"); + + assertThatThrownBy(() -> translateAsProjection("AI_EMBED(123, content) AS out_col")) + .hasMessageContaining( + "The first argument of AI function 'AI_EMBED' must be a string literal naming the 
model"); + + assertThatThrownBy(() -> translateAsProjection("AI_EMBED(UPPER('m'), content) AS out_col")) + .hasMessageContaining( + "The first argument of AI function 'AI_EMBED' must be a string literal naming the model"); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('my-model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'my-model' is not a valid Java identifier.") + .hasMessageContaining( + "Model names must follow Java identifier rules and must not be reserved keywords."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('class', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'class' is not a valid Java identifier.") + .hasMessageContaining( + "Model names must follow Java identifier rules and must not be reserved keywords."); + } + @Test void testTranslateAiFunctionInFilter() { assertThat(translateAsFilter("AI_COMPLETE('myModel', content, 'Classify this text')")) From c13f0c306fc3ddda680ebc3071349d9c5ed61f1f Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Fri, 15 May 2026 11:24:37 +0800 Subject: [PATCH 05/11] remove stale overloads --- .../operators/transform/TransformExpressionCompiler.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java index 776b4af2d0c..9faa111966f 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java @@ -31,7 +31,6 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -54,12 
+53,6 @@ public static void cleanUp() { COMPILED_EXPRESSION_CACHE.invalidateAll(); } - /** Compiles an expression code to a janino {@link ExpressionEvaluator}. */ - public static ExpressionEvaluator compileExpression( - TransformExpressionKey key, List udfDescriptors) { - return compileExpression(key, udfDescriptors, Collections.emptyMap()); - } - /** * Compiles an expression code to a janino {@link ExpressionEvaluator}, with additional {@link * AiModelClient} instances appended after UDF instances. From dde3957421c31a778a8c538b73f8000cf4ac7a68 Mon Sep 17 00:00:00 2001 From: yuxiqian <34335406+yuxiqian@users.noreply.github.com> Date: Fri, 15 May 2026 11:27:40 +0800 Subject: [PATCH 06/11] close openai client properly Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../cdc/models/openai/OpenAiCompatibleModelClient.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java index 5aed57a35b3..ae378b7615e 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java @@ -56,7 +56,13 @@ public void open() { @Override public void close() { - client = null; + if (client != null) { + try { + client.close(); + } finally { + client = null; + } + } } @Override From 4dd7716040cf199742aa473795de084cd6f22ecd Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Fri, 15 May 2026 11:35:36 +0800 Subject: [PATCH 07/11] throw error message 
--- .../cdc/models/openai/OpenAiCompatibleModelClient.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java index ae378b7615e..909c385eb91 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java @@ -84,7 +84,11 @@ public float[] embed(String text) { CreateEmbeddingResponse response = client.embeddings().create(params); List data = response.data(); if (data.isEmpty()) { - return new float[0]; + throw new IllegalStateException( + "Embedding response from model '" + + modelName + + "' contained no embeddings. 
" + + "This indicates a server-side anomaly; refusing to emit an empty vector."); } List embedding = data.get(0).embedding(); float[] result = new float[embedding.size()]; From 698fdf430a6e39d5b5ed810a0e58778fc5c57568 Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Fri, 15 May 2026 11:41:29 +0800 Subject: [PATCH 08/11] more ai function parser tests --- .../cdc/runtime/parser/JaninoCompiler.java | 2 +- .../runtime/parser/AiFunctionParserTest.java | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java index 518f46d5c24..a022f879c42 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java @@ -591,7 +591,7 @@ private static void rewriteAiFunctionModelArg(Java.Rvalue[] atoms) { if (modelName.startsWith("\"") && modelName.endsWith("\"")) { modelName = modelName.substring(1, modelName.length() - 1); } - if (!SourceVersion.isName(modelName)) { + if (!SourceVersion.isIdentifier(modelName) || SourceVersion.isKeyword(modelName)) { throw new ParseException( "AI function model name '" + modelName diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java index e0e6441d7d2..64e226c0c56 100644 --- a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java @@ -113,6 +113,38 @@ void testAiFunctionRejectsNonStringLiteralModelArg() { "AI function model name 'class' is not a valid Java identifier.") .hasMessageContaining( "Model 
names must follow Java identifier rules and must not be reserved keywords."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('my.model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'my.model' is not a valid Java identifier."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('my model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'my model' is not a valid Java identifier."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('123model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name '123model' is not a valid Java identifier."); + + assertThatThrownBy(() -> translateAsProjection("AI_COMPLETE('', content, 'p') AS out_col")) + .hasMessageContaining("AI function model name '' is not a valid Java identifier."); + } + + @Test + void testAiFunctionPreservesModelNameCase() { + assertThat(translateAsFilter("AI_COMPLETE('MyModel', content, 'p')")) + .isEqualTo("aiComplete(MyModel, $2, \"p\")"); + assertThat(translateAsFilter("AI_COMPLETE('mymodel', content, 'p')")) + .isEqualTo("aiComplete(mymodel, $2, \"p\")"); } @Test From 574d97fe4a8c2aee696c42dc5af4f56464749435 Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Wed, 20 May 2026 16:32:23 +0800 Subject: [PATCH 09/11] fail fast during open --- .../openai/OpenAiCompatibleModelClient.java | 13 +++++++++++++ .../OpenAiCompatibleModelClientFactoryTest.java | 17 +++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java index 909c385eb91..1f29dca0bf9 100644 --- 
a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java @@ -28,6 +28,8 @@ import com.openai.models.embeddings.CreateEmbeddingResponse; import com.openai.models.embeddings.Embedding; import com.openai.models.embeddings.EmbeddingCreateParams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.List; @@ -35,6 +37,8 @@ public class OpenAiCompatibleModelClient implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + private static final Logger LOG = LoggerFactory.getLogger(OpenAiCompatibleModelClient.class); + private static final long serialVersionUID = 1L; private final String endpoint; @@ -52,6 +56,15 @@ public OpenAiCompatibleModelClient(String endpoint, String apiKey, String modelN @Override public void open() { client = OpenAIOkHttpClient.builder().baseUrl(endpoint).apiKey(apiKey).build(); + LOG.info( + "Successfully constructed OpenAI http client. 
Endpoint: {} Model: {}", + endpoint, + modelName); + try { + client.models().list(); + } catch (Exception e) { + throw new RuntimeException("Failed to perform livecheck on OpenAI model client", e); + } } @Override diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java index 585ab25adec..d9f3d11cd20 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java @@ -115,4 +115,21 @@ void testValidateThrowsOnUnknownOption() { .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Unknown options"); } + + @Test + void testCreateClientFailFast() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + AiModelClient client = factory.createClient(contextWithOptions(options)); + assertThat(client).isInstanceOf(OpenAiCompatibleModelClient.class); + assertThat(client).isInstanceOf(SupportsTextGeneration.class); + assertThat(client).isInstanceOf(SupportsEmbedding.class); + + assertThatThrownBy(client::open) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Failed to perform livecheck on OpenAI model client"); + } } From 13254855e0d931e3d582f0bcfda2c27e31a29f1f Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Wed, 20 May 2026 16:33:06 +0800 Subject: [PATCH 10/11] add timing logs --- .../runtime/functions/impl/AiFunctions.java | 32 ++++++++++++++++++- 
1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java index be49d8d7785..21c911f218c 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java @@ -27,11 +27,16 @@ import org.apache.flink.shaded.guava31.com.google.common.primitives.Floats; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.util.List; /** Built-in AI functions available as static imports in Janino-compiled transform expressions. */ public class AiFunctions { + private static final Logger LOG = LoggerFactory.getLogger(AiFunctions.class); + /** General-purpose text completion with a user-provided system prompt. */ public static BinaryVariant aiComplete(AiModelClient model, String input, String systemPrompt) { return invokeTextGeneration(model, AiTextFunctionDef.AI_COMPLETE, input, systemPrompt); @@ -48,7 +53,21 @@ public static List aiEmbed(AiModelClient model, String input) { throw new UnsupportedOperationException( "Model " + model.getClass().getName() + " does not support embedding"); } - return Floats.asList(((SupportsEmbedding) model).embed(input)); + + long startTime = 0; + if (LOG.isDebugEnabled()) { + startTime = System.currentTimeMillis(); + } + float[] embeddingResult = ((SupportsEmbedding) model).embed(input); + if (LOG.isDebugEnabled()) { + long endTime = System.currentTimeMillis(); + LOG.debug( + "Generated {}-dim vector in {} ms", + embeddingResult.length, + endTime - startTime); + } + + return Floats.asList(embeddingResult); } private static BinaryVariant invokeTextGeneration( @@ -62,7 +81,18 @@ private static BinaryVariant invokeTextGeneration( promptBuilder.append("\n").append(buildOutputSchemaHint(funcDef.getOutputType())); 
String systemPrompt = promptBuilder.toString(); + + long startTime = 0; + if (LOG.isDebugEnabled()) { + startTime = System.currentTimeMillis(); + } String json = ((SupportsTextGeneration) model).generate(systemPrompt, input); + + if (json != null && LOG.isDebugEnabled()) { // null response is handled below; skip timing log + long endTime = System.currentTimeMillis(); + LOG.debug("Generated {} characters in {} ms", json.length(), endTime - startTime); + } + if (json == null) { return null; } From 771508546daacb208bdef251d700411b87c5e115 Mon Sep 17 00:00:00 2001 From: yux <34335406+yuxiqian@users.noreply.github.com> Date: Thu, 21 May 2026 20:51:13 +0800 Subject: [PATCH 11/11] Use cdc common factory instead of inventing my own --- .../AiModelClientFactory.java} | 22 +-- .../common/model/AiModelClientFactory.java | 91 ---------- .../model/AiModelClientFactoryTest.java | 166 ------------------ .../flink/translator/TransformTranslator.java | 66 ++----- ....apache.flink.cdc.common.factories.Factory | 3 +- .../models/dummy/DummyModelClientFactory.java | 20 ++- ....apache.flink.cdc.common.factories.Factory | 0 ...link.cdc.common.model.AiModelClientFactory | 16 -- .../pom.xml | 7 + .../openai/OpenAiCompatibleModelClient.java | 124 ++++++++++++- .../OpenAiCompatibleModelClientFactory.java | 154 ++++++++++++++-- .../openai/OpenAiCompatibleModelOptions.java | 90 ++++++++++ ...apache.flink.cdc.common.factories.Factory} | 0 ...penAiCompatibleModelClientFactoryTest.java | 158 ++++++++++++++--- .../OpenAiCompatibleModelClientITCase.java | 58 +++++- .../OpenAiCompatibleModelClientTest.java | 60 +++++++ 16 files changed, 639 insertions(+), 396 deletions(-) rename 
flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory => flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory (100%) delete mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java rename flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/{org.apache.flink.cdc.common.model.AiModelClientFactory => org.apache.flink.cdc.common.factories.Factory} (100%) create mode 100644 flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/factories/AiModelClientFactory.java similarity index 63% rename from flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java rename to flink-cdc-common/src/main/java/org/apache/flink/cdc/common/factories/AiModelClientFactory.java index a6b30250a4d..974a07fbad3 100644 --- a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/ModelContext.java +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/factories/AiModelClientFactory.java @@ -15,22 +15,18 @@ * limitations under the License. 
*/ -package org.apache.flink.cdc.common.model; +package org.apache.flink.cdc.common.factories; import org.apache.flink.cdc.common.annotation.Experimental; +import org.apache.flink.cdc.common.model.AiModelClient; -import java.util.Map; - -/** Context passed to {@link AiModelClientFactory#createClient} at pipeline assembly time. */ +/** + * A factory to create {@link AiModelClient} instances. See also {@link Factory} for more + * information. + */ @Experimental -public interface ModelContext { - - /** The logical name of this model as declared in the pipeline YAML. */ - String getModelName(); - - /** Raw key/value options from the pipeline YAML {@code model.options} block. */ - Map getOptions(); +public interface AiModelClientFactory extends Factory { - /** Class loader to use when loading implementation classes. */ - ClassLoader getClassLoader(); + /** Creates a new {@link AiModelClient} instance. */ + AiModelClient createClient(Context context); } diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java deleted file mode 100644 index a3009dfd90c..00000000000 --- a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClientFactory.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.common.model; - -import org.apache.flink.cdc.common.annotation.Experimental; - -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * SPI interface for AI model client factories. Each provider (e.g. OpenAI-compatible, DashScope) - * ships one implementation, discoverable via {@link java.util.ServiceLoader}. - * - *

The {@link #identifier()} value maps to the {@code type} field of a {@code pipeline.model} - * entry in the pipeline YAML. - */ -@Experimental -public interface AiModelClientFactory { - - /** A unique, lower-case identifier for this provider, e.g. {@code "openai-compatible"}. */ - String identifier(); - - /** Option keys that must be present in the model YAML options block. */ - Set requiredOptions(); - - /** Option keys that may optionally appear in the model YAML options block. */ - Set optionalOptions(); - - /** - * Validates that the given context contains all required options and no unknown options. - * Subclasses may override this to add custom validation logic. - */ - default void validate(ModelContext context) { - Set required = requiredOptions(); - Set optional = optionalOptions(); - if (required != null) { - Set missing = - required.stream() - .filter(k -> !context.getOptions().containsKey(k)) - .collect(Collectors.toSet()); - if (!missing.isEmpty()) { - throw new IllegalArgumentException( - "Missing required options for model '" - + context.getModelName() - + "' (type='" - + identifier() - + "'): " - + missing); - } - } - if (required != null && optional != null) { - List unknown = - context.getOptions().keySet().stream() - .filter(k -> !required.contains(k) && !optional.contains(k)) - .sorted() - .collect(Collectors.toList()); - if (!unknown.isEmpty()) { - throw new IllegalArgumentException( - "Unknown options for model '" - + context.getModelName() - + "' (type='" - + identifier() - + "'): " - + unknown); - } - } - } - - /** - * Creates a new {@link AiModelClient} from the given context. Called once per model definition - * at pipeline assembly time on the job-manager side; the returned client is serialized and - * shipped to task managers. 
- */ - AiModelClient createClient(ModelContext context); -} diff --git a/flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java b/flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java deleted file mode 100644 index 721362ec20e..00000000000 --- a/flink-cdc-common/src/test/java/org/apache/flink/cdc/common/model/AiModelClientFactoryTest.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.common.model; - -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** Tests for the default {@link AiModelClientFactory#validate} method. 
*/ -class AiModelClientFactoryTest { - - private static final String IDENTIFIER = "test-provider"; - private static final String MODEL_NAME = "my-model"; - - private static final class StubFactory implements AiModelClientFactory { - private final Set required; - private final Set optional; - - StubFactory(Set required, Set optional) { - this.required = required; - this.optional = optional; - } - - @Override - public String identifier() { - return IDENTIFIER; - } - - @Override - public Set requiredOptions() { - return required; - } - - @Override - public Set optionalOptions() { - return optional; - } - - @Override - public AiModelClient createClient(ModelContext context) { - return new AiModelClient() {}; - } - } - - private static ModelContext contextWithOptions(Map options) { - return new ModelContext() { - @Override - public String getModelName() { - return MODEL_NAME; - } - - @Override - public Map getOptions() { - return options; - } - - @Override - public ClassLoader getClassLoader() { - return Thread.currentThread().getContextClassLoader(); - } - }; - } - - @Test - void testValidatePassesWithAllRequiredOptions() { - StubFactory factory = new StubFactory(Set.of("api-key", "endpoint"), Set.of("timeout")); - - Map options = new HashMap<>(); - options.put("api-key", "sk-xxx"); - options.put("endpoint", "https://api.example.com"); - - // Should not throw - factory.validate(contextWithOptions(options)); - } - - @Test - void testValidatePassesWithRequiredAndOptionalOptions() { - StubFactory factory = new StubFactory(Set.of("api-key", "endpoint"), Set.of("timeout")); - - Map options = new HashMap<>(); - options.put("api-key", "sk-xxx"); - options.put("endpoint", "https://api.example.com"); - options.put("timeout", "30000"); - - factory.validate(contextWithOptions(options)); - } - - @Test - void testValidateThrowsOnMissingRequiredOption() { - StubFactory factory = new StubFactory(Set.of("api-key", "endpoint"), Set.of("timeout")); - - // Missing "endpoint" - Map 
options = new HashMap<>(); - options.put("api-key", "sk-xxx"); - - Assertions.assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Missing required options for model 'my-model' (type='test-provider'): [endpoint]"); - } - - @Test - void testValidateThrowsOnMultipleMissingRequiredOptions() { - StubFactory factory = new StubFactory(Set.of("api-key", "endpoint", "model"), Set.of()); - - // All required options missing - Assertions.assertThatThrownBy( - () -> factory.validate(contextWithOptions(Collections.emptyMap()))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Missing required options for model 'my-model' (type='test-provider'): [endpoint, api-key, model]"); - } - - @Test - void testValidateThrowsOnUnknownOption() { - StubFactory factory = new StubFactory(Set.of("api-key"), Set.of("timeout")); - - Map options = new HashMap<>(); - options.put("api-key", "sk-xxx"); - options.put("bogus", "unexpected"); - - Assertions.assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Unknown options for model 'my-model' (type='test-provider'): [bogus]"); - } - - @Test - void testValidateThrowsOnMultipleUnknownOptions() { - StubFactory factory = new StubFactory(Set.of("api-key"), Set.of()); - - Map options = new HashMap<>(); - options.put("api-key", "sk-xxx"); - options.put("foo", "a"); - options.put("bar", "b"); - - Assertions.assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Unknown options for model 'my-model' (type='test-provider'): [bar, foo]"); - } - - @Test - void testValidatePassesWithNoRequiredAndNoOptions() { - StubFactory factory = new StubFactory(Set.of(), Set.of()); - factory.validate(contextWithOptions(Collections.emptyMap())); - } -} diff --git 
a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java index 3056ad57663..1364191500b 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java +++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java @@ -18,14 +18,17 @@ package org.apache.flink.cdc.composer.flink.translator; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.cdc.common.configuration.Configuration; import org.apache.flink.cdc.common.event.Event; +import org.apache.flink.cdc.common.factories.AiModelClientFactory; +import org.apache.flink.cdc.common.factories.Factory; +import org.apache.flink.cdc.common.factories.FactoryHelper; import org.apache.flink.cdc.common.model.AiModelClient; -import org.apache.flink.cdc.common.model.AiModelClientFactory; -import org.apache.flink.cdc.common.model.ModelContext; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.composer.definition.ModelDef; import org.apache.flink.cdc.composer.definition.TransformDef; import org.apache.flink.cdc.composer.definition.UdfDef; +import org.apache.flink.cdc.composer.utils.FactoryDiscoveryUtils; import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperator; import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperatorBuilder; import org.apache.flink.cdc.runtime.operators.transform.PreTransformOperator; @@ -34,11 +37,9 @@ import org.apache.flink.streaming.api.datastream.DataStream; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.ServiceLoader; import java.util.stream.Collectors; /** @@ -133,27 +134,17 @@ private Map loadModelClients(List models) { return 
Collections.emptyMap(); } - Map factories = new HashMap<>(); - ServiceLoader loader = - ServiceLoader.load( - AiModelClientFactory.class, Thread.currentThread().getContextClassLoader()); - for (AiModelClientFactory factory : loader) { - factories.put(factory.identifier(), factory); - } - Map clients = new LinkedHashMap<>(); for (ModelDef model : models) { - AiModelClientFactory factory = factories.get(model.getType()); - if (factory == null) { - throw new IllegalArgumentException( - "No AiModelClientFactory found for model type '" - + model.getType() - + "'. Available factories: " - + factories.keySet()); - } - ModelContext ctx = - new DefaultModelContext(model, Thread.currentThread().getContextClassLoader()); - factory.validate(ctx); + AiModelClientFactory factory = + FactoryDiscoveryUtils.getFactoryByIdentifier( + model.getType(), AiModelClientFactory.class); + Factory.Context ctx = + new FactoryHelper.DefaultContext( + Configuration.fromMap(model.getOptions()), + new Configuration(), + Thread.currentThread().getContextClassLoader()); + FactoryHelper.createFactoryHelper(factory, ctx).validate(); AiModelClient client = factory.createClient(ctx); clients.put(model.getName(), client); } @@ -163,33 +154,4 @@ private Map loadModelClients(List models) { private Tuple3> udfDefToUDFTuple(UdfDef udf) { return Tuple3.of(udf.getName(), udf.getClasspath(), udf.getOptions()); } - - // ------------------------------------------------------------------------- - // Internal ModelContext implementation - // ------------------------------------------------------------------------- - - private static final class DefaultModelContext implements ModelContext { - private final ModelDef modelDef; - private final ClassLoader classLoader; - - DefaultModelContext(ModelDef modelDef, ClassLoader classLoader) { - this.modelDef = modelDef; - this.classLoader = classLoader; - } - - @Override - public String getModelName() { - return modelDef.getName(); - } - - @Override - public Map 
getOptions() { - return modelDef.getOptions(); - } - - @Override - public ClassLoader getClassLoader() { - return classLoader; - } - } } diff --git a/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory b/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory index 274faa0078d..df0fb2fdfc1 100644 --- a/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory +++ b/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory @@ -18,4 +18,5 @@ org.apache.flink.cdc.composer.utils.factory.DataSourceFactory1 org.apache.flink.cdc.composer.utils.factory.DataSourceFactory2 org.apache.flink.cdc.composer.testsource.factory.DistributedDataSourceFactory org.apache.flink.cdc.composer.flink.FlinkPipelineComposerTest$TestDataSinkFactory -org.apache.flink.cdc.composer.flink.FlinkPipelineComposerTest$TestDataSourceFactory \ No newline at end of file +org.apache.flink.cdc.composer.flink.FlinkPipelineComposerTest$TestDataSourceFactory +org.apache.flink.cdc.models.dummy.DummyModelClientFactory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java index 4df45704ae3..042b9762998 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java @@ -17,33 +17,39 @@ package org.apache.flink.cdc.models.dummy; +import org.apache.flink.cdc.common.configuration.ConfigOption; +import org.apache.flink.cdc.common.factories.AiModelClientFactory; +import 
org.apache.flink.cdc.common.factories.Factory; import org.apache.flink.cdc.common.model.AiModelClient; -import org.apache.flink.cdc.common.model.AiModelClientFactory; -import org.apache.flink.cdc.common.model.ModelContext; import java.util.Set; +import static org.apache.flink.cdc.common.configuration.ConfigOptions.key; + /** SPI factory for {@link DummyModelClient}. For testing purposes only. */ public class DummyModelClientFactory implements AiModelClientFactory { + public static final ConfigOption DEBUG = + key("debug").booleanType().defaultValue(false); + @Override public String identifier() { return "dummy"; } @Override - public Set requiredOptions() { + public Set> requiredOptions() { return Set.of(); } @Override - public Set optionalOptions() { - return Set.of("debug"); + public Set> optionalOptions() { + return Set.of(DEBUG); } @Override - public AiModelClient createClient(ModelContext context) { - boolean debug = Boolean.parseBoolean(context.getOptions().getOrDefault("debug", "false")); + public AiModelClient createClient(Factory.Context context) { + boolean debug = context.getFactoryConfiguration().get(DEBUG); return new DummyModelClient(debug); } } diff --git a/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory similarity index 100% rename from flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory rename to flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory 
b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory deleted file mode 100644 index c1ed9c43ff3..00000000000 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -org.apache.flink.cdc.models.dummy.DummyModelClientFactory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml index 53ba797441a..549a3d7db06 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml @@ -32,6 +32,7 @@ limitations under the License. 2.13.4 + 3.12.4 @@ -40,6 +41,12 @@ limitations under the License. 
openai-java 4.32.0 + + org.mockito + mockito-core + ${mockito.version} + test + diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java index 1f29dca0bf9..fe0fa425131 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java @@ -23,6 +23,9 @@ import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonValue; +import com.openai.models.ResponseFormatJsonObject; +import com.openai.models.ResponseFormatText; import com.openai.models.chat.completions.ChatCompletion; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.embeddings.CreateEmbeddingResponse; @@ -31,7 +34,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collections; import java.util.List; +import java.util.Map; /** AI model client that connects to any OpenAI-compatible endpoint. 
*/ public class OpenAiCompatibleModelClient @@ -44,13 +49,52 @@ public class OpenAiCompatibleModelClient private final String endpoint; private final String apiKey; private final String modelName; + private final String globalSystemPrompt; + private final Double temperature; + private final Double topP; + private final String stop; + private final Integer maxTokens; + private final Double presencePenalty; + private final Integer n; + private final Long seed; + private final String responseFormat; + private final Map> extraHeaders; + private final Map extraBody; + private final Integer embeddingDimension; private transient OpenAIClient client; - public OpenAiCompatibleModelClient(String endpoint, String apiKey, String modelName) { + public OpenAiCompatibleModelClient( + String endpoint, + String apiKey, + String modelName, + String globalSystemPrompt, + Double temperature, + Double topP, + String stop, + Integer maxTokens, + Double presencePenalty, + Integer n, + Long seed, + String responseFormat, + Map> extraHeaders, + Map extraBody, + Integer embeddingDimension) { this.endpoint = endpoint; this.apiKey = apiKey; this.modelName = modelName; + this.globalSystemPrompt = globalSystemPrompt; + this.temperature = temperature; + this.topP = topP; + this.stop = stop; + this.maxTokens = maxTokens; + this.presencePenalty = presencePenalty; + this.n = n; + this.seed = seed; + this.responseFormat = responseFormat; + this.extraHeaders = extraHeaders == null ? Collections.emptyMap() : extraHeaders; + this.extraBody = extraBody == null ? 
Collections.emptyMap() : extraBody; + this.embeddingDimension = embeddingDimension; } @Override @@ -80,20 +124,86 @@ public void close() { @Override public String generate(String systemPrompt, String userInput) { - ChatCompletionCreateParams params = + String mergedSystemPrompt = mergeSystemPrompt(globalSystemPrompt, systemPrompt); + ChatCompletionCreateParams.Builder builder = ChatCompletionCreateParams.builder() .model(modelName) - .addSystemMessage(systemPrompt) - .addUserMessage(userInput) - .build(); + .addSystemMessage(mergedSystemPrompt) + .addUserMessage(userInput); + + if (temperature != null) { + builder.temperature(temperature); + } + if (topP != null) { + builder.topP(topP); + } + if (stop != null && !stop.trim().isEmpty()) { + builder.stop(stop); + } + if (maxTokens != null) { + builder.maxTokens(maxTokens.longValue()); + } + if (presencePenalty != null) { + builder.presencePenalty(presencePenalty); + } + if (n != null) { + builder.n(n.longValue()); + } + if (seed != null) { + builder.seed(seed); + } + if ("json_object".equals(responseFormat)) { + builder.responseFormat( + ResponseFormatJsonObject.builder().type(JsonValue.from("json_object")).build()); + } else { + builder.responseFormat( + ResponseFormatText.builder().type(JsonValue.from("text")).build()); + } + for (Map.Entry> entry : extraHeaders.entrySet()) { + builder.putAdditionalHeaders(entry.getKey(), entry.getValue()); + } + for (Map.Entry entry : extraBody.entrySet()) { + builder.putAdditionalBodyProperty(entry.getKey(), JsonValue.from(entry.getValue())); + } + + ChatCompletionCreateParams params = builder.build(); ChatCompletion completion = client.chat().completions().create(params); return completion.choices().get(0).message().content().orElse(null); } + static String mergeSystemPrompt(String globalSystemPrompt, String runtimeSystemPrompt) { + String normalizedGlobal = normalizePrompt(globalSystemPrompt); + String normalizedRuntime = normalizePrompt(runtimeSystemPrompt); + if 
(normalizedGlobal == null) { + return normalizedRuntime == null ? "" : normalizedRuntime; + } + if (normalizedRuntime == null) { + return normalizedGlobal; + } + return normalizedGlobal + "\n\n" + normalizedRuntime; + } + + private static String normalizePrompt(String prompt) { + if (prompt == null || prompt.trim().isEmpty()) { + return null; + } + return prompt; + } + @Override public float[] embed(String text) { - EmbeddingCreateParams params = - EmbeddingCreateParams.builder().model(modelName).input(text).build(); + EmbeddingCreateParams.Builder builder = + EmbeddingCreateParams.builder().model(modelName).input(text); + if (embeddingDimension != null) { + builder.dimensions(embeddingDimension.longValue()); + } + for (Map.Entry> entry : extraHeaders.entrySet()) { + builder.putAdditionalHeaders(entry.getKey(), entry.getValue()); + } + for (Map.Entry entry : extraBody.entrySet()) { + builder.putAdditionalBodyProperty(entry.getKey(), JsonValue.from(entry.getValue())); + } + EmbeddingCreateParams params = builder.build(); CreateEmbeddingResponse response = client.embeddings().create(params); List data = response.data(); if (data.isEmpty()) { diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java index c0050dc03dc..d481cb1a9a2 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java @@ -17,35 +17,167 @@ package org.apache.flink.cdc.models.openai; +import org.apache.flink.cdc.common.configuration.ConfigOption; +import 
org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.factories.AiModelClientFactory; +import org.apache.flink.cdc.common.factories.Factory; import org.apache.flink.cdc.common.model.AiModelClient; -import org.apache.flink.cdc.common.model.AiModelClientFactory; -import org.apache.flink.cdc.common.model.ModelContext; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; /** SPI factory for {@link OpenAiCompatibleModelClient}. */ public class OpenAiCompatibleModelClientFactory implements AiModelClientFactory { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + @Override public String identifier() { return "openai-compatible"; } @Override - public Set requiredOptions() { - return Set.of("endpoint", "api-key", "model-name"); + public Set> requiredOptions() { + return OpenAiCompatibleModelOptions.requiredOptions(); } @Override - public Set optionalOptions() { - return Set.of(); + public Set> optionalOptions() { + return OpenAiCompatibleModelOptions.optionalOptions(); } @Override - public AiModelClient createClient(ModelContext context) { - String endpoint = context.getOptions().get("endpoint"); - String apiKey = context.getOptions().get("api-key"); - String modelName = context.getOptions().get("model-name"); - return new OpenAiCompatibleModelClient(endpoint, apiKey, modelName); + public AiModelClient createClient(Factory.Context context) { + Configuration configuration = context.getFactoryConfiguration(); + String endpoint = configuration.get(OpenAiCompatibleModelOptions.ENDPOINT); + String apiKey = configuration.get(OpenAiCompatibleModelOptions.API_KEY); + String modelName = configuration.get(OpenAiCompatibleModelOptions.MODEL_NAME); + String systemPrompt = + 
configuration.getOptional(OpenAiCompatibleModelOptions.SYSTEM_PROMPT).orElse(null); + Double temperature = + configuration.getOptional(OpenAiCompatibleModelOptions.TEMPERATURE).orElse(null); + Double topP = configuration.getOptional(OpenAiCompatibleModelOptions.TOP_P).orElse(null); + String stop = configuration.getOptional(OpenAiCompatibleModelOptions.STOP).orElse(null); + Integer maxTokens = + configuration.getOptional(OpenAiCompatibleModelOptions.MAX_TOKENS).orElse(null); + Double presencePenalty = + configuration + .getOptional(OpenAiCompatibleModelOptions.PRESENCE_PENALTY) + .orElse(null); + Integer n = configuration.getOptional(OpenAiCompatibleModelOptions.N).orElse(null); + Long seed = configuration.getOptional(OpenAiCompatibleModelOptions.SEED).orElse(null); + String responseFormat = configuration.get(OpenAiCompatibleModelOptions.RESPONSE_FORMAT); + String extraHeader = + configuration.getOptional(OpenAiCompatibleModelOptions.EXTRA_HEADER).orElse(null); + String extraBody = + configuration.getOptional(OpenAiCompatibleModelOptions.EXTRA_BODY).orElse(null); + Integer dimension = + configuration.getOptional(OpenAiCompatibleModelOptions.DIMENSION).orElse(null); + + validateValueRanges(temperature, presencePenalty, n, maxTokens, dimension); + validateEnumValue("response-format", responseFormat, "text", "json_object"); + + return new OpenAiCompatibleModelClient( + endpoint, + apiKey, + modelName, + systemPrompt, + temperature, + topP, + stop, + maxTokens, + presencePenalty, + n, + seed, + responseFormat, + parseExtraHeaders(extraHeader), + parseExtraBody(extraBody), + dimension); + } + + private static void validateValueRanges( + Double temperature, + Double presencePenalty, + Integer n, + Integer maxTokens, + Integer dimension) { + if (temperature != null && (temperature < 0 || temperature >= 2)) { + throw new IllegalArgumentException("Option 'temperature' must be in range [0, 2)."); + } + if (presencePenalty != null && (presencePenalty < -2 || presencePenalty > 
2)) { + throw new IllegalArgumentException( + "Option 'presence-penalty' must be in range [-2.0, 2.0]."); + } + if (n != null && n <= 0) { + throw new IllegalArgumentException("Option 'n' must be greater than 0."); + } + if (maxTokens != null && maxTokens <= 0) { + throw new IllegalArgumentException("Option 'max-tokens' must be greater than 0."); + } + if (dimension != null && dimension <= 0) { + throw new IllegalArgumentException("Option 'dimension' must be greater than 0."); + } + } + + private static void validateEnumValue( + String optionKey, String value, String first, String second) { + if (!first.equals(value) && !second.equals(value)) { + throw new IllegalArgumentException( + String.format( + "Option '%s' must be '%s' or '%s', but got '%s'.", + optionKey, first, second, value)); + } + } + + private static Map> parseExtraHeaders(String rawExtraHeader) { + if (rawExtraHeader == null || rawExtraHeader.trim().isEmpty()) { + return Collections.emptyMap(); + } + Map rawMap = parseJsonObject(rawExtraHeader, "extra-header"); + Map> parsed = new LinkedHashMap<>(); + for (Map.Entry entry : rawMap.entrySet()) { + Object value = entry.getValue(); + if (value instanceof String) { + parsed.put(entry.getKey(), Collections.singletonList((String) value)); + } else if (value instanceof List) { + List values = new ArrayList<>(); + for (Object element : (List) value) { + if (!(element instanceof String)) { + throw new IllegalArgumentException( + "Option 'extra-header' only supports string values or string arrays."); + } + values.add((String) element); + } + parsed.put(entry.getKey(), values); + } else { + throw new IllegalArgumentException( + "Option 'extra-header' only supports string values or string arrays."); + } + } + return parsed; + } + + private static Map parseExtraBody(String rawExtraBody) { + if (rawExtraBody == null || rawExtraBody.trim().isEmpty()) { + return Collections.emptyMap(); + } + return parseJsonObject(rawExtraBody, "extra-body"); + } + + private static 
Map parseJsonObject(String rawJson, String optionName) { + try { + return OBJECT_MAPPER.readValue(rawJson, new TypeReference>() {}); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Option '%s' must be a valid JSON object string.", optionName), + e); + } } } diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java new file mode 100644 index 00000000000..9fcfc64ca37 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.configuration.ConfigOption; + +import java.util.Set; + +import static org.apache.flink.cdc.common.configuration.ConfigOptions.key; + +/** Config options for {@link OpenAiCompatibleModelClient}. 
*/ +public class OpenAiCompatibleModelOptions { + + public static final ConfigOption ENDPOINT = + key("endpoint").stringType().noDefaultValue(); + + public static final ConfigOption API_KEY = key("api-key").stringType().noDefaultValue(); + + public static final ConfigOption MODEL_NAME = + key("model-name").stringType().noDefaultValue(); + + public static final ConfigOption SYSTEM_PROMPT = + key("system-prompt").stringType().noDefaultValue(); + + public static final ConfigOption TEMPERATURE = + key("temperature").doubleType().noDefaultValue(); + + public static final ConfigOption TOP_P = key("top-p").doubleType().noDefaultValue(); + + public static final ConfigOption STOP = key("stop").stringType().noDefaultValue(); + + public static final ConfigOption MAX_TOKENS = + key("max-tokens").intType().noDefaultValue(); + + public static final ConfigOption PRESENCE_PENALTY = + key("presence-penalty").doubleType().noDefaultValue(); + + public static final ConfigOption N = key("n").intType().noDefaultValue(); + + public static final ConfigOption SEED = key("seed").longType().noDefaultValue(); + + public static final ConfigOption RESPONSE_FORMAT = + key("response-format").stringType().defaultValue("text"); + + public static final ConfigOption EXTRA_HEADER = + key("extra-header").stringType().noDefaultValue(); + + public static final ConfigOption EXTRA_BODY = + key("extra-body").stringType().noDefaultValue(); + + public static final ConfigOption DIMENSION = + key("dimension").intType().noDefaultValue(); + + public static Set> requiredOptions() { + return Set.of(ENDPOINT, API_KEY, MODEL_NAME); + } + + public static Set> optionalOptions() { + return Set.of( + SYSTEM_PROMPT, + TEMPERATURE, + TOP_P, + STOP, + MAX_TOKENS, + PRESENCE_PENALTY, + N, + SEED, + RESPONSE_FORMAT, + EXTRA_HEADER, + EXTRA_BODY, + DIMENSION); + } + + private OpenAiCompatibleModelOptions() {} +} diff --git 
a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory similarity index 100% rename from flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.model.AiModelClientFactory rename to flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java index d9f3d11cd20..bba5a102f2f 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java @@ -17,15 +17,22 @@ package org.apache.flink.cdc.models.openai; +import org.apache.flink.cdc.common.configuration.ConfigOption; +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.factories.Factory; +import org.apache.flink.cdc.common.factories.FactoryHelper; import org.apache.flink.cdc.common.model.AiModelClient; -import org.apache.flink.cdc.common.model.ModelContext; import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; import org.junit.jupiter.api.Test; +import 
java.lang.reflect.Field; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -35,23 +42,11 @@ class OpenAiCompatibleModelClientFactoryTest { private final OpenAiCompatibleModelClientFactory factory = new OpenAiCompatibleModelClientFactory(); - private ModelContext contextWithOptions(Map options) { - return new ModelContext() { - @Override - public String getModelName() { - return "test-model"; - } - - @Override - public Map getOptions() { - return options; - } - - @Override - public ClassLoader getClassLoader() { - return Thread.currentThread().getContextClassLoader(); - } - }; + private Factory.Context contextWithOptions(Map options) { + return new FactoryHelper.DefaultContext( + Configuration.fromMap(options), + new Configuration(), + Thread.currentThread().getContextClassLoader()); } @Test @@ -61,13 +56,35 @@ void testIdentifier() { @Test void testRequiredOptions() { - assertThat(factory.requiredOptions()) - .containsExactlyInAnyOrder("endpoint", "api-key", "model-name"); + assertThat( + factory.requiredOptions().stream() + .map(ConfigOption::key) + .collect(Collectors.toSet())) + .containsExactlyInAnyOrder( + OpenAiCompatibleModelOptions.ENDPOINT.key(), + OpenAiCompatibleModelOptions.API_KEY.key(), + OpenAiCompatibleModelOptions.MODEL_NAME.key()); } @Test void testOptionalOptions() { - assertThat(factory.optionalOptions()).isEmpty(); + assertThat( + factory.optionalOptions().stream() + .map(ConfigOption::key) + .collect(Collectors.toSet())) + .contains( + OpenAiCompatibleModelOptions.SYSTEM_PROMPT.key(), + OpenAiCompatibleModelOptions.TEMPERATURE.key(), + OpenAiCompatibleModelOptions.TOP_P.key(), + OpenAiCompatibleModelOptions.STOP.key(), + OpenAiCompatibleModelOptions.MAX_TOKENS.key(), + OpenAiCompatibleModelOptions.PRESENCE_PENALTY.key(), + 
OpenAiCompatibleModelOptions.N.key(), + OpenAiCompatibleModelOptions.SEED.key(), + OpenAiCompatibleModelOptions.RESPONSE_FORMAT.key(), + OpenAiCompatibleModelOptions.EXTRA_HEADER.key(), + OpenAiCompatibleModelOptions.EXTRA_BODY.key(), + OpenAiCompatibleModelOptions.DIMENSION.key()); } @Test @@ -83,6 +100,42 @@ void testCreateClient() { assertThat(client).isInstanceOf(SupportsEmbedding.class); } + @Test + void testCreateClientWithRequiredOptionsOnlyUsesDefaults() throws Exception { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + OpenAiCompatibleModelClient client = + (OpenAiCompatibleModelClient) factory.createClient(contextWithOptions(options)); + assertThat(readField(client, "globalSystemPrompt")).isNull(); + assertThat(readField(client, "responseFormat")).isEqualTo("text"); + } + + @SuppressWarnings("unchecked") + @Test + void testCreateClientParsesExtraHeaderAndBody() throws Exception { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put( + "extra-header", "{\"X-Trace-Id\":\"trace-1\",\"X-Tags\":[\"tag-a\",\"tag-b\"]}"); + options.put("extra-body", "{\"foo\":\"bar\",\"num\":1}"); + + OpenAiCompatibleModelClient client = + (OpenAiCompatibleModelClient) factory.createClient(contextWithOptions(options)); + Map> extraHeaders = + (Map>) readField(client, "extraHeaders"); + Map extraBody = (Map) readField(client, "extraBody"); + + assertThat(extraHeaders).containsEntry("X-Trace-Id", Arrays.asList("trace-1")); + assertThat(extraHeaders).containsEntry("X-Tags", Arrays.asList("tag-a", "tag-b")); + assertThat(extraBody).containsEntry("foo", "bar"); + assertThat(extraBody).containsKey("num"); + } + @Test void testValidatePassesWithAllRequiredOptions() { Map options = new HashMap<>(); @@ -90,7 +143,18 @@ void 
testValidatePassesWithAllRequiredOptions() { options.put("api-key", "sk-test"); options.put("model-name", "gpt-4"); - factory.validate(contextWithOptions(options)); + FactoryHelper.createFactoryHelper(factory, contextWithOptions(options)).validate(); + } + + @Test + void testValidatePassesWithOptionalSystemPrompt() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("system-prompt", "You are a strict JSON generator."); + + FactoryHelper.createFactoryHelper(factory, contextWithOptions(options)).validate(); } @Test @@ -98,9 +162,12 @@ void testValidateThrowsOnMissingRequiredOption() { Map options = new HashMap<>(); options.put("endpoint", "https://api.example.com/v1"); - assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Missing required options"); + assertThatThrownBy( + () -> + FactoryHelper.createFactoryHelper( + factory, contextWithOptions(options)) + .validate()) + .hasMessageContaining("required options are missing"); } @Test @@ -111,9 +178,12 @@ void testValidateThrowsOnUnknownOption() { options.put("model-name", "gpt-4"); options.put("unknown-key", "value"); - assertThatThrownBy(() -> factory.validate(contextWithOptions(options))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Unknown options"); + assertThatThrownBy( + () -> + FactoryHelper.createFactoryHelper( + factory, contextWithOptions(options)) + .validate()) + .hasMessageContaining("Unsupported options"); } @Test @@ -132,4 +202,36 @@ void testCreateClientFailFast() { .isInstanceOf(RuntimeException.class) .hasMessageContaining("Failed to perform livecheck on OpenAI model client"); } + + @Test + void testCreateClientThrowsOnInvalidTemperature() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", 
"sk-test"); + options.put("model-name", "gpt-4"); + options.put("temperature", "2.0"); + + assertThatThrownBy(() -> factory.createClient(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("temperature"); + } + + @Test + void testCreateClientThrowsOnInvalidExtraHeaderJson() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("extra-header", "not-a-json"); + + assertThatThrownBy(() -> factory.createClient(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("extra-header"); + } + + private static Object readField(Object target, String fieldName) throws Exception { + Field field = target.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(target); + } } diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java index 33a3bc38934..b508caeb414 100644 --- a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java @@ -22,22 +22,41 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.Collections; + import static org.assertj.core.api.Assertions.assertThat; class OpenAiCompatibleModelClientITCase { private OpenAiCompatibleModelClient client; + private final String endpoint = System.getenv("OPENAI_BASE_URL"); + private final String apiKey = 
System.getenv("OPENAI_API_KEY"); + private final String model = System.getenv("OPENAI_MODEL"); + @BeforeEach void setUp() { - String endpoint = System.getenv("OPENAI_BASE_URL"); - String apiKey = System.getenv("OPENAI_API_KEY"); - String model = System.getenv("OPENAI_MODEL"); Assumptions.assumeThat(endpoint != null && apiKey != null && model != null) .as("OPENAI_BASE_URL, OPENAI_API_KEY and OPENAI_MODEL must be set") .isTrue(); - client = new OpenAiCompatibleModelClient(endpoint, apiKey, model); + client = + new OpenAiCompatibleModelClient( + endpoint, + apiKey, + model, + "You are a helpful assistant.", + null, + null, + null, + null, + null, + null, + null, + "text", + Collections.emptyMap(), + Collections.emptyMap(), + null); client.open(); } @@ -60,4 +79,35 @@ void testGenerateWithEmptyUserInput() { String result = client.generate("Reply with exactly: OK", ""); assertThat(result).isNotNull().contains("OK"); } + + @Test + void testGenerateWithGlobalSystemPrompt() { + String marker = "GLOBAL_MARKER_9F2A"; + OpenAiCompatibleModelClient globalPromptClient = + new OpenAiCompatibleModelClient( + endpoint, + apiKey, + model, + "You must always include the exact token '" + + marker + + "' in every response.", + null, + null, + null, + null, + null, + null, + null, + "text", + Collections.emptyMap(), + Collections.emptyMap(), + null); + try (globalPromptClient) { + globalPromptClient.open(); + String result = + globalPromptClient.generate( + "Reply in one short sentence to say hello.", "Say hello."); + assertThat(result).isNotNull().contains(marker); + } + } } diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java new file mode 100644 index 00000000000..512cfef02ee --- /dev/null +++ 
b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class OpenAiCompatibleModelClientTest { + + @Test + void testMergeSystemPromptWithGlobalAndRuntime() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt("global", "runtime")) + .isEqualTo("global\n\nruntime"); + } + + @Test + void testMergeSystemPromptWithGlobalOnly() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt("global", null)) + .isEqualTo("global"); + } + + @Test + void testMergeSystemPromptWithRuntimeOnly() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt(null, "runtime")) + .isEqualTo("runtime"); + } + + @Test + void testMergeSystemPromptSkipsBlankGlobal() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt(" ", "runtime")) + .isEqualTo("runtime"); + } + + @Test + void testMergeSystemPromptSkipsBlankRuntime() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt("global", " ")) + 
.isEqualTo("global"); + } + + @Test + void testMergeSystemPromptReturnsEmptyWhenBothBlank() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt(" ", " ")).isEmpty(); + } +}