diff --git a/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java b/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java index 831bb5e6536..9cad4d81c93 100644 --- a/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java +++ b/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java @@ -44,6 +44,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -96,9 +97,8 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser { private static final String UDF_OPTIONS_KEY = "options"; // Model related keys - private static final String MODEL_NAME_KEY = "model-name"; - - private static final String MODEL_CLASS_NAME_KEY = "class-name"; + private static final String MODEL_NAME_KEY = "name"; + private static final String MODEL_TYPE_KEY = "type"; public static final String TRANSFORM_PRIMARY_KEY_KEY = "primary-keys"; @@ -145,7 +145,6 @@ private PipelineDef parse(JsonNode pipelineDefJsonNode, Configuration globalPipe Optional.ofNullable( ((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(MODEL_KEY)) - .map(node -> validateArray("model", node)) .ifPresent(node -> modelDefs.addAll(parseModels(node))); } @@ -428,24 +427,45 @@ private List parseModels(JsonNode modelsNode) { } else { modelDefs.add(convertJsonNodeToModelDef(modelsNode)); } + Set seenNames = new HashSet<>(); + for (ModelDef model : modelDefs) { + if (!seenNames.add(model.getName())) { + throw new IllegalArgumentException( + "Duplicate model name '" + model.getName() + "' in pipeline definition."); + } + } return modelDefs; } private ModelDef convertJsonNodeToModelDef(JsonNode modelNode) { + Preconditions.checkArgument( + modelNode instanceof ObjectNode, + "`model` in `pipeline` should be an object, but got %s", + modelNode); + ObjectNode node = (ObjectNode) modelNode; String name = checkNotNull( - modelNode.get(MODEL_NAME_KEY), + node.remove(MODEL_NAME_KEY), "Missing required field \"%s\" in `model`", MODEL_NAME_KEY) .asText(); - String model = + Preconditions.checkArgument( + name.matches("[a-zA-Z_][a-zA-Z0-9_]*") && !name.startsWith("__"), + "Model name \"%s\" is not a valid identifier. " + + "It must start with a letter or underscore, " + + "contain only letters, digits, or underscores, " + + "and must not start with double underscores.", + name); + String type = checkNotNull( - modelNode.get(MODEL_CLASS_NAME_KEY), + node.remove(MODEL_TYPE_KEY), "Missing required field \"%s\" in `model`", - MODEL_CLASS_NAME_KEY) + MODEL_TYPE_KEY) .asText(); - Map properties = mapper.convertValue(modelNode, Map.class); - return new ModelDef(name, model, properties); + Map options = new HashMap<>(); + node.fields() + .forEachRemaining(entry -> options.put(entry.getKey(), entry.getValue().asText())); + return new ModelDef(name, type, options); } private void validateJsonNodeKeys( diff --git a/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java b/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java index c8c10db76e7..bf564d1c8e7 100644 --- a/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java +++ b/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java @@ -34,6 +34,8 @@ import org.apache.flink.shaded.guava31.com.google.common.io.Resources; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.net.URL; import java.time.Duration; @@ -422,17 +424,28 @@ private void testSchemaEvolutionTypesParsing( "add new uniq_id for each row", null)), Collections.emptyList(), - Collections.singletonList( + Arrays.asList( + new ModelDef( + "Sonnet", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.builder() + .put("model-name", "claude-sonnet-4-6") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") + .build())), new ModelDef( - "GET_EMBEDDING", - "OpenAIEmbeddingModel", + "Opus", + "openai-compatible", new LinkedHashMap<>( ImmutableMap.builder() - .put("model-name", "GET_EMBEDDING") - .put("class-name", "OpenAIEmbeddingModel") - .put("openai.model", "text-embedding-3-small") - .put("openai.host", "https://xxxx") - .put("openai.apikey", "abcd1234") + .put("model-name", "claude-opus-4-5") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") .build()))), Configuration.fromMap( ImmutableMap.builder() @@ -492,11 +505,16 @@ void testParsingFullDefinitionFromString() throws Exception { + " schema-operator.rpc-timeout: 1 h\n" + " execution.runtime-mode: STREAMING\n" + " model:\n" - + " - model-name: GET_EMBEDDING\n" - + " class-name: OpenAIEmbeddingModel\n" - + " openai.model: text-embedding-3-small\n" - + " openai.host: https://xxxx\n" - + " openai.apikey: abcd1234"; + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://idealab.alibaba-inc.com/api/openai/v1\n" + + " api-key: cafebabe\n" + + " - name: Opus\n" + + " type: openai-compatible\n" + + " model-name: claude-opus-4-5\n" + + " endpoint: https://idealab.alibaba-inc.com/api/openai/v1\n" + + " api-key: cafebabe\n"; YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration()); assertThat(pipelineDef).isEqualTo(fullDef); @@ -560,17 +578,28 @@ void testParsingFullDefinitionFromString() throws Exception { "add new uniq_id for each row", null)), Collections.emptyList(), - Collections.singletonList( + Arrays.asList( new ModelDef( - "GET_EMBEDDING", - "OpenAIEmbeddingModel", + "Sonnet", + "openai-compatible", new LinkedHashMap<>( ImmutableMap.builder() - .put("model-name", "GET_EMBEDDING") - .put("class-name", "OpenAIEmbeddingModel") - .put("openai.model", "text-embedding-3-small") - .put("openai.host", "https://xxxx") - .put("openai.apikey", "abcd1234") + .put("model-name", "claude-sonnet-4-6") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") + .build())), + new ModelDef( + "Opus", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.builder() + .put("model-name", "claude-opus-4-5") + .put( + "endpoint", + "https://idealab.alibaba-inc.com/api/openai/v1") + .put("api-key", "cafebabe") .build()))), Configuration.fromMap( ImmutableMap.builder() @@ -710,6 +739,189 @@ void testParsingFullDefinitionFromString() throws Exception { .put("schema-operator.rpc-timeout", "1 h") .build())); + @Test + void testParsingSingleModelAsObject() throws Exception { + String pipelineDefText = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://example.com/v1\n" + + " api-key: test-key\n"; + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration()); + assertThat(pipelineDef.getModels()).hasSize(1); + assertThat(pipelineDef.getModels().get(0)) + .isEqualTo( + new ModelDef( + "Sonnet", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.of( + "model-name", + "claude-sonnet-4-6", + "endpoint", + "https://example.com/v1", + "api-key", + "test-key")))); + } + + @Test + void testParsingMultipleModelsAsList() throws Exception { + String pipelineDefText = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key1\n" + + " - name: Opus\n" + + " type: openai-compatible\n" + + " model-name: claude-opus-4-5\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key2\n"; + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration()); + assertThat(pipelineDef.getModels()).hasSize(2); + assertThat(pipelineDef.getModels().get(0)) + .isEqualTo( + new ModelDef( + "Sonnet", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.of( + "model-name", + "claude-sonnet-4-6", + "endpoint", + "https://example.com/v1", + "api-key", + "key1")))); + assertThat(pipelineDef.getModels().get(1)) + .isEqualTo( + new ModelDef( + "Opus", + "openai-compatible", + new LinkedHashMap<>( + ImmutableMap.of( + "model-name", + "claude-opus-4-5", + "endpoint", + "https://example.com/v1", + "api-key", + "key2")))); + } + + @Test + void testDuplicateModelName() { + String pipelineDefText = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-6\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key1\n" + + " - name: Sonnet\n" + + " type: openai-compatible\n" + + " model-name: claude-sonnet-4-5\n" + + " endpoint: https://example.com/v1\n" + + " api-key: key2\n"; + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + assertThatThrownBy(() -> parser.parse(pipelineDefText, new Configuration())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Duplicate model name 'Sonnet' in pipeline definition."); + } + + @ParameterizedTest + @ValueSource(strings = {"123model", "my-model", "__reserved", "my model", "my.model"}) + void testInvalidModelName(String invalidName) { + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + assertThatThrownBy( + () -> + parser.parse( + buildPipelineDefWithModelName(invalidName), + new Configuration())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Model name \"" + invalidName + "\" is not a valid identifier"); + } + + @Test + void testModelOptionsValueTypes() throws Exception { + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + + // Boolean, integer, and float values in YAML should all be converted to String + String yaml = + "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: myModel\n" + + " type: dummy\n" + + " debug: true\n" + + " max-tokens: 1024\n" + + " temperature: 0.7\n" + + " endpoint: https://example.com/v1\n"; + + PipelineDef def = parser.parse(yaml, new Configuration()); + assertThat(def.getModels()).hasSize(1); + ModelDef model = def.getModels().get(0); + assertThat(model.getOptions()) + .containsEntry("debug", "true") + .containsEntry("max-tokens", "1024") + .containsEntry("temperature", "0.7") + .containsEntry("endpoint", "https://example.com/v1"); + // Verify all values are String type + model.getOptions().values().forEach(v -> assertThat(v).isInstanceOf(String.class)); + } + + @ParameterizedTest + @ValueSource(strings = {"Sonnet", "my_model_v2", "_private"}) + void testValidModelName(String validName) throws Exception { + YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser(); + PipelineDef def = + parser.parse(buildPipelineDefWithModelName(validName), new Configuration()); + assertThat(def.getModels()).hasSize(1); + assertThat(def.getModels().get(0).getName()).isEqualTo(validName); + } + + private String buildPipelineDefWithModelName(String modelName) { + return "source:\n" + + " type: foo\n" + + "sink:\n" + + " type: bar\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " model:\n" + + " - name: " + + modelName + + "\n" + + " type: openai-compatible\n" + + " model-name: some-model\n" + + " endpoint: https://example.com/v1\n" + + " api-key: test-key\n"; + } + private final PipelineDef pipelineDefWithUdf = new PipelineDef( new SourceDef("values", null, new Configuration()), diff --git a/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml b/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml index dec2c25dc23..8af7ac40d35 100644 --- a/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml +++ b/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml @@ -60,8 +60,13 @@ pipeline: schema-operator.rpc-timeout: 1 h execution.runtime-mode: STREAMING model: - - model-name: GET_EMBEDDING - class-name: OpenAIEmbeddingModel - openai.model: text-embedding-3-small - openai.host: https://xxxx - openai.apikey: abcd1234 + - name: Sonnet + type: openai-compatible + model-name: claude-sonnet-4-6 + endpoint: https://idealab.alibaba-inc.com/api/openai/v1 + api-key: cafebabe + - name: Opus + type: openai-compatible + model-name: claude-opus-4-5 + endpoint: https://idealab.alibaba-inc.com/api/openai/v1 + api-key: cafebabe diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/factories/AiModelClientFactory.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/factories/AiModelClientFactory.java new file mode 100644 index 00000000000..974a07fbad3 --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/factories/AiModelClientFactory.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.factories; + +import org.apache.flink.cdc.common.annotation.Experimental; +import org.apache.flink.cdc.common.model.AiModelClient; + +/** + * A factory to create {@link AiModelClient} instances. See also {@link Factory} for more + * information. + */ +@Experimental +public interface AiModelClientFactory extends Factory { + + /** Creates a new {@link AiModelClient} instance. */ + AiModelClient createClient(Context context); +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java new file mode 100644 index 00000000000..0e46508bb2e --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/AiModelClient.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model; + +import org.apache.flink.cdc.common.annotation.Experimental; + +import java.io.Serializable; + +/** + * Marker interface for a runtime AI model client. Concrete capabilities are declared via ability + * interfaces in {@code org.apache.flink.cdc.common.model.abilities}. + * + *

Implementations must be {@link Serializable} so that they can be distributed across Flink task + * managers together with the operator that holds them. + */ +@Experimental +public interface AiModelClient extends Serializable, AutoCloseable { + + default void open() throws Exception { + // Do nothing + } + + @Override + default void close() throws Exception { + // Do nothing + } +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java new file mode 100644 index 00000000000..d6c8fc7d1c3 --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsEmbedding.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model.abilities; + +import org.apache.flink.cdc.common.annotation.Experimental; +import org.apache.flink.cdc.common.model.AiModelClient; + +/** + * Ability interface for {@link AiModelClient} implementations that can produce dense vector + * embeddings from text input. + */ +@Experimental +public interface SupportsEmbedding { + + /** Converts the given text into a dense float vector. */ + float[] embed(String text); +} diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java new file mode 100644 index 00000000000..eb1ea7362e4 --- /dev/null +++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/model/abilities/SupportsTextGeneration.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.common.model.abilities; + +import org.apache.flink.cdc.common.annotation.Experimental; +import org.apache.flink.cdc.common.model.AiModelClient; + +/** + * Ability interface for {@link AiModelClient} implementations that can perform chat-style text + * generation given a system prompt and a user input. + */ +@Experimental +public interface SupportsTextGeneration { + + /** + * Generates text based on a system-level prompt and a user-provided input message. Returns a + * JSON string conforming to the output schema declared by the calling AI function. + */ + String generate(String systemPrompt, String userInput); +} diff --git a/flink-cdc-composer/pom.xml b/flink-cdc-composer/pom.xml index 4161491ba04..522714c9e45 100644 --- a/flink-cdc-composer/pom.xml +++ b/flink-cdc-composer/pom.xml @@ -80,9 +80,11 @@ limitations under the License. ${project.version} test + + org.apache.flink - flink-cdc-pipeline-model + flink-cdc-pipeline-model-dummy ${project.version} test diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java index 21cc6befaf2..6d16e222c7b 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java +++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java @@ -20,41 +20,31 @@ import java.util.Map; import java.util.Objects; -/** - * Common properties of model. - * - *

A transformation definition contains: - * - *

    - *
  • modelName: The name of function. - *
  • className: The model to transform data. - *
  • parameters: The parameters that used to configure the model. - *
- */ +/** Common properties of model. */ public class ModelDef { - private final String modelName; + private final String name; - private final String className; + private final String type; - private final Map parameters; + private final Map options; - public ModelDef(String modelName, String className, Map parameters) { - this.modelName = modelName; - this.className = className; - this.parameters = parameters; + public ModelDef(String name, String type, Map options) { + this.name = name; + this.type = type; + this.options = options; } - public String getModelName() { - return modelName; + public String getName() { + return name; } - public String getClassName() { - return className; + public String getType() { + return type; } - public Map getParameters() { - return parameters; + public Map getOptions() { + return options; } @Override @@ -66,27 +56,27 @@ public boolean equals(Object o) { return false; } ModelDef modelDef = (ModelDef) o; - return Objects.equals(modelName, modelDef.modelName) - && Objects.equals(className, modelDef.className) - && Objects.equals(parameters, modelDef.parameters); + return Objects.equals(name, modelDef.name) + && Objects.equals(type, modelDef.type) + && Objects.equals(options, modelDef.options); } @Override public int hashCode() { - return Objects.hash(modelName, className, parameters); + return Objects.hash(name, type, options); } @Override public String toString() { return "ModelDef{" + "name='" - + modelName + + name + '\'' - + ", model='" - + className + + ", type='" + + type + '\'' - + ", parameters=" - + parameters + + ", options=" + + options + '}'; } } diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java index 1854b05032b..d77c48bfee2 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java +++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java @@ -186,7 +186,6 @@ private void translate(StreamExecutionEnvironment env, PipelineDef pipelineDef) stream, pipelineDef.getTransforms(), pipelineDef.getUdfs(), - pipelineDef.getModels(), dataSource.supportedMetadataColumns()); // PreTransform ---> PostTransform diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java index 6e296db4d1c..1364191500b 100644 --- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java +++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java @@ -18,11 +18,17 @@ package org.apache.flink.cdc.composer.flink.translator; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.cdc.common.configuration.Configuration; import org.apache.flink.cdc.common.event.Event; +import org.apache.flink.cdc.common.factories.AiModelClientFactory; +import org.apache.flink.cdc.common.factories.Factory; +import org.apache.flink.cdc.common.factories.FactoryHelper; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.composer.definition.ModelDef; import org.apache.flink.cdc.composer.definition.TransformDef; import org.apache.flink.cdc.composer.definition.UdfDef; +import org.apache.flink.cdc.composer.utils.FactoryDiscoveryUtils; import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperator; import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperatorBuilder; import org.apache.flink.cdc.runtime.operators.transform.PreTransformOperator; @@ -30,6 +36,8 @@ import org.apache.flink.cdc.runtime.typeutils.EventTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; +import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -40,15 +48,10 @@ */ public class TransformTranslator { - /** Package of built-in model. */ - public static final String PREFIX_CLASSPATH_BUILT_IN_MODEL = - "org.apache.flink.cdc.runtime.model."; - public DataStream translatePreTransform( DataStream input, List transforms, List udfFunctions, - List models, SupportedMetadataColumn[] supportedMetadataColumns) { if (transforms.isEmpty()) { return input; @@ -56,13 +59,12 @@ public DataStream translatePreTransform( return input.transform( "Transform:Schema", new EventTypeInfo(), - generatePreTransform(transforms, udfFunctions, models, supportedMetadataColumns)); + generatePreTransform(transforms, udfFunctions, supportedMetadataColumns)); } private PreTransformOperator generatePreTransform( List transforms, List udfFunctions, - List models, SupportedMetadataColumn[] supportedMetadataColumns) { PreTransformOperatorBuilder preTransformFunctionBuilder = PreTransformOperator.newBuilder(); @@ -79,14 +81,8 @@ private PreTransformOperator generatePreTransform( supportedMetadataColumns); } - preTransformFunctionBuilder - .addUdfFunctions( - udfFunctions.stream() - .map(this::udfDefToUDFTuple) - .collect(Collectors.toList())) - .addUdfFunctions( - models.stream().map(this::modelToUDFTuple).collect(Collectors.toList())); - + preTransformFunctionBuilder.addUdfFunctions( + udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList())); return preTransformFunctionBuilder.build(); } @@ -116,21 +112,43 @@ public DataStream translatePostTransform( transform.getPostTransformConverter(), supportedMetadataColumns); } - postTransformFunctionBuilder.addTimezone(timezone); - postTransformFunctionBuilder.addUdfFunctions( - udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList())); - postTransformFunctionBuilder.addUdfFunctions( - models.stream().map(this::modelToUDFTuple).collect(Collectors.toList())); + postTransformFunctionBuilder + .addTimezone(timezone) + .addUdfFunctions( + udfFunctions.stream() + .map(this::udfDefToUDFTuple) + .collect(Collectors.toList())) + .addModelClients(loadModelClients(models)); + return input.transform( "Transform:Data", new EventTypeInfo(), postTransformFunctionBuilder.build()) .uid(operatorUidGenerator.generateUid("post-transform")); } - private Tuple3> modelToUDFTuple(ModelDef model) { - return Tuple3.of( - model.getModelName(), - PREFIX_CLASSPATH_BUILT_IN_MODEL + model.getClassName(), - model.getParameters()); + /** + * Loads AI model clients for all declared models via SPI, returning a map from Janino parameter + * name to the client instance. + */ + private Map loadModelClients(List models) { + if (models.isEmpty()) { + return Collections.emptyMap(); + } + + Map clients = new LinkedHashMap<>(); + for (ModelDef model : models) { + AiModelClientFactory factory = + FactoryDiscoveryUtils.getFactoryByIdentifier( + model.getType(), AiModelClientFactory.class); + Factory.Context ctx = + new FactoryHelper.DefaultContext( + Configuration.fromMap(model.getOptions()), + new Configuration(), + Thread.currentThread().getContextClassLoader()); + FactoryHelper.createFactoryHelper(factory, ctx).validate(); + AiModelClient client = factory.createClient(ctx); + clients.put(model.getName(), client); + } + return clients; } private Tuple3> udfDefToUDFTuple(UdfDef udf) { diff --git a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java new file mode 100644 index 00000000000..37f1701065f --- /dev/null +++ b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineAiFunctionITCase.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.composer.flink; + +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.data.binary.BinaryStringData; +import org.apache.flink.cdc.common.event.CreateTableEvent; +import org.apache.flink.cdc.common.event.DataChangeEvent; +import org.apache.flink.cdc.common.event.Event; +import org.apache.flink.cdc.common.event.TableId; +import org.apache.flink.cdc.common.pipeline.PipelineOptions; +import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior; +import org.apache.flink.cdc.common.schema.Schema; +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.composer.PipelineExecution; +import org.apache.flink.cdc.composer.definition.ModelDef; +import org.apache.flink.cdc.composer.definition.PipelineDef; +import org.apache.flink.cdc.composer.definition.SinkDef; +import org.apache.flink.cdc.composer.definition.SourceDef; +import org.apache.flink.cdc.composer.definition.TransformDef; +import org.apache.flink.cdc.connectors.values.ValuesDatabase; +import org.apache.flink.cdc.connectors.values.factory.ValuesDataFactory; +import org.apache.flink.cdc.connectors.values.sink.ValuesDataSinkOptions; +import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceHelper; +import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceOptions; +import org.apache.flink.cdc.runtime.typeutils.BinaryRecordDataGenerator; +import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration; +import org.apache.flink.test.junit5.MiniClusterExtension; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL; +import static org.assertj.core.api.Assertions.assertThat; + +/** Integration test for AI functions in the Flink pipeline. */ +class FlinkPipelineAiFunctionITCase { + + private static final int MAX_PARALLELISM = 4; + + private static final org.apache.flink.configuration.Configuration MINI_CLUSTER_CONFIG = + new org.apache.flink.configuration.Configuration(); + + static { + MINI_CLUSTER_CONFIG.set( + ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL, + Collections.singletonList("org.apache.flink.cdc")); + } + + @RegisterExtension + static final MiniClusterExtension MINI_CLUSTER_RESOURCE = + new MiniClusterExtension( + new MiniClusterResourceConfiguration.Builder() + .setNumberTaskManagers(1) + .setNumberSlotsPerTaskManager(MAX_PARALLELISM) + .setConfiguration(MINI_CLUSTER_CONFIG) + .build()); + + private final PrintStream standardOut = System.out; + private final ByteArrayOutputStream outCaptor = new ByteArrayOutputStream(); + + @BeforeEach + void init() { + System.setOut(new PrintStream(outCaptor)); + ValuesDatabase.clear(); + } + + @AfterEach + void cleanup() { + System.setOut(standardOut); + } + + private static final String DUMMY_JSON = "{\"result\":\"dummy result\",\"summary\":\"TL;DR\"}"; + + @Test + void testAiCompleteInProjection() throws Exception { + String[] output = + runAiFunctionTest( + "id, content, AI_COMPLETE('testModel', content, 'Classify sentiment') AS res", + List.of(new ModelDef("testModel", "dummy", new HashMap<>()))); + assertThat(output) + .containsExactly( + "CreateTableEvent{tableId=default_namespace.default_schema.mytable1, schema=columns={`id` INT NOT NULL,`content` STRING,`res` VARIANT}, primaryKeys=id, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.mytable1, before=[], after=[1, I love this product, " + + DUMMY_JSON + + "], op=INSERT, meta=()}"); + } + + @Test + void testAiEmbedInProjection() throws Exception { + String[] output = + runAiFunctionTest( + "id, AI_EMBED('embedModel', content) AS embedding", + List.of(new ModelDef("embedModel", "dummy", new HashMap<>()))); + assertThat(output) + .containsExactly( + "CreateTableEvent{tableId=default_namespace.default_schema.mytable1, schema=columns={`id` INT NOT NULL,`embedding` ARRAY}, primaryKeys=id, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.mytable1, before=[], after=[1, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}"); + } + + @Test + void testMultipleAiFunctionsWithSameModel() throws Exception { + String[] output = + runAiFunctionTest( + "id, AI_COMPLETE('myModel', content, 'Classify') AS cls, AI_SUMMARIZE('myModel', content, 100) AS summary", + List.of(new ModelDef("myModel", "dummy", new HashMap<>()))); + assertThat(output) + .containsExactly( + "CreateTableEvent{tableId=default_namespace.default_schema.mytable1, schema=columns={`id` INT NOT NULL,`cls` VARIANT,`summary` VARIANT}, primaryKeys=id, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.mytable1, before=[], after=[1, " + + DUMMY_JSON + + ", " + + DUMMY_JSON + + "], op=INSERT, meta=()}"); + } + + private String[] runAiFunctionTest(String projection, List models) throws Exception { + FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster(); + + // Source: one table with a single row + TableId tableId = TableId.tableId("default_namespace", "default_schema", "mytable1"); + Schema schema = + Schema.newBuilder() + .physicalColumn("id", DataTypes.INT()) + .physicalColumn("content", DataTypes.STRING()) + .primaryKey("id") + .build(); + BinaryRecordDataGenerator generator = + new BinaryRecordDataGenerator(schema.getColumnDataTypes().toArray(new DataType[0])); + + List events = new ArrayList<>(); + events.add(new CreateTableEvent(tableId, schema)); + events.add( + DataChangeEvent.insertEvent( + tableId, + generator.generate( + new Object[] { + 1, BinaryStringData.fromString("I love this product") + }))); + ValuesDataSourceHelper.setSourceEvents(Collections.singletonList(events)); + + Configuration sourceConfig = new Configuration(); + sourceConfig.set( + ValuesDataSourceOptions.EVENT_SET_ID, + ValuesDataSourceHelper.EventSetId.CUSTOM_SOURCE_EVENTS); + SourceDef sourceDef = + new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig); + + // Sink + Configuration sinkConfig = new Configuration(); + sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true); + SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig); + + // Transform + TransformDef transformDef = + new TransformDef( + "default_namespace.default_schema.mytable1", + projection, + null, + "id", + null, + null, + null, + null); + + // Pipeline + Configuration pipelineConfig = new Configuration(); + pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1); + pipelineConfig.set( + PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE); + PipelineDef pipelineDef = + new PipelineDef( + sourceDef, + sinkDef, + Collections.emptyList(), + Collections.singletonList(transformDef), + Collections.emptyList(), + models, + pipelineConfig); + + // Execute & capture output + PipelineExecution execution = composer.compose(pipelineDef); + execution.execute(); + + return outCaptor.toString().trim().split("\n"); + } +} diff --git a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java index 9d5a8b22e23..dc04c8de66e 100644 --- a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java +++ b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java @@ -21,7 +21,6 @@ import org.apache.flink.cdc.common.pipeline.PipelineOptions; import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior; import org.apache.flink.cdc.composer.PipelineExecution; -import org.apache.flink.cdc.composer.definition.ModelDef; import org.apache.flink.cdc.composer.definition.PipelineDef; import org.apache.flink.cdc.composer.definition.SinkDef; import org.apache.flink.cdc.composer.definition.SourceDef; @@ -40,7 +39,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -48,10 +46,8 @@ import java.io.ByteArrayOutputStream; import java.io.PrintStream; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.stream.Stream; import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL; @@ -911,77 +907,6 @@ void testComplicatedFlinkUdf(ValuesDataSink.SinkApi sinkApi, String language) th "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, , 4, Integer: 42, 2-42], after=[2, x, 4, Integer: 42, 2-42], op=UPDATE, meta=({op_ts=5})}"); } - @ParameterizedTest - @MethodSource("testParams") - @Disabled("For manual test as there is a limit for quota.") - void testTransformWithModel(ValuesDataSink.SinkApi sinkApi, String language) throws Exception { - FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster(); - - // Setup value source - Configuration sourceConfig = new Configuration(); - sourceConfig.set( - ValuesDataSourceOptions.EVENT_SET_ID, - ValuesDataSourceHelper.EventSetId.TRANSFORM_TABLE); - SourceDef sourceDef = - new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig); - - // Setup value sink - Configuration sinkConfig = new Configuration(); - sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true); - sinkConfig.set(ValuesDataSinkOptions.SINK_API, sinkApi); - SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig); - - // Setup transform - TransformDef transformDef = - new TransformDef( - "default_namespace.default_schema.table1", - "*, CHAT(col1) AS emb", - null, - "col1", - null, - "key1=value1", - "", - null); - - // Setup pipeline - Configuration pipelineConfig = new Configuration(); - pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1); - pipelineConfig.set( - PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE); - PipelineDef pipelineDef = - new PipelineDef( - sourceDef, - sinkDef, - Collections.emptyList(), - Collections.singletonList(transformDef), - new ArrayList<>(), - Arrays.asList( - new ModelDef( - "CHAT", - "OpenAIChatModel", - new LinkedHashMap<>( - ImmutableMap.builder() - .put("openai.model", "gpt-4o-mini") - .put( - "openai.host", - "http://langchain4j.dev/demo/openai/v1") - .put("openai.apikey", "demo") - .build()))), - pipelineConfig); - - // Execute the pipeline - PipelineExecution execution = composer.compose(pipelineDef); - execution.execute(); - - // Check the order and content of all received events - String[] outputEvents = outCaptor.toString().trim().split("\n"); - assertThat(outputEvents) - .contains( - "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING NOT NULL,`col2` STRING,`emb` STRING}, primaryKeys=col1, options=({key1=value1})}") - // The result of transform by model is not fixed. - .hasSize(9); - } - @ParameterizedTest @MethodSource("testParams") void testComplicatedUdfReturnTypes(ValuesDataSink.SinkApi sinkApi, String language) diff --git a/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory b/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory index 274faa0078d..df0fb2fdfc1 100644 --- a/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory +++ b/flink-cdc-composer/src/test/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory @@ -18,4 +18,5 @@ org.apache.flink.cdc.composer.utils.factory.DataSourceFactory1 org.apache.flink.cdc.composer.utils.factory.DataSourceFactory2 org.apache.flink.cdc.composer.testsource.factory.DistributedDataSourceFactory org.apache.flink.cdc.composer.flink.FlinkPipelineComposerTest$TestDataSinkFactory -org.apache.flink.cdc.composer.flink.FlinkPipelineComposerTest$TestDataSourceFactory \ No newline at end of file +org.apache.flink.cdc.composer.flink.FlinkPipelineComposerTest$TestDataSourceFactory +org.apache.flink.cdc.models.dummy.DummyModelClientFactory diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml index 24e6213ed9e..733bf3cb77d 100644 --- a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/pom.xml @@ -229,6 +229,18 @@ limitations under the License. ${project.version} test + + org.apache.flink + flink-cdc-pipeline-model-dummy + ${project.version} + test + + + org.apache.flink + flink-cdc-pipeline-model-openai-compatible + ${project.version} + test + @@ -543,6 +555,26 @@ limitations under the License. + + org.apache.flink + flink-cdc-pipeline-model-dummy + ${project.version} + dummy-model.jar + jar + ${project.build.directory}/dependencies + + + + + org.apache.flink + flink-cdc-pipeline-model-openai-compatible + ${project.version} + openai-compatible-model.jar + jar + ${project.build.directory}/dependencies + + + org.apache.flink flink-cdc-pipeline-connector-mysql diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java new file mode 100644 index 00000000000..c3b21bed1b9 --- /dev/null +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/java/org/apache/flink/cdc/pipeline/tests/AiFunctionE2eITCase.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.pipeline.tests; + +import org.apache.flink.cdc.common.test.utils.TestUtils; +import org.apache.flink.cdc.pipeline.tests.utils.PipelineTestEnvironment; + +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.Test; + +import java.nio.file.Path; +import java.time.Duration; + +/** E2e tests for AI functions with the dummy model and openai-compatible model. */ +class AiFunctionE2eITCase extends PipelineTestEnvironment { + + @Test + void testAiFunctionsWithDummyModel() throws Exception { + String pipelineJob = + "source:\n" + + " type: values\n" + + " event-set.id: SINGLE_SPLIT_MULTI_TABLES\n" + + "\n" + + "sink:\n" + + " type: values\n" + + "\n" + + "transform:\n" + + " - source-table: default_namespace.default_schema.table1\n" + + " projection: col1, AI_COMPLETE('myModel', col1, 'Classify into catA or catB') AS cls\n" + + " - source-table: default_namespace.default_schema.table2\n" + + " projection: col1, AI_EMBED('myModel', col1) AS embedding\n" + + "\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " schema.change.behavior: evolve\n" + + " model:\n" + + " name: myModel\n" + + " type: dummy\n" + + " debug: true\n"; + + Path dummyModelJar = TestUtils.getResource("dummy-model.jar"); + submitPipelineJob(pipelineJob, dummyModelJar); + waitUntilJobFinished(Duration.ofMinutes(3)); + + validateResult( + "Successfully opened AI model client 'myModel'.", + "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING NOT NULL,`cls` VARIANT}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[1, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[2, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[3, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[1, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], after=[], op=DELETE, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], after=[2, {\"result\":\"dummy result\",\"summary\":\"TL;DR\"}], op=UPDATE, meta=()}", + "CreateTableEvent{tableId=default_namespace.default_schema.table2, schema=columns={`col1` STRING NOT NULL,`embedding` ARRAY}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[1, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[2, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[3, [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]], op=INSERT, meta=()}", + "Successfully closed AI model client 'myModel'."); + } + + @Test + void testAiFunctionsWithOpenAiCompatibleModel() throws Exception { + String endpoint = System.getenv("OPENAI_BASE_URL"); + String apiKey = System.getenv("OPENAI_API_KEY"); + String model = System.getenv("OPENAI_MODEL"); + Assumptions.assumeThat(endpoint != null && apiKey != null && model != null) + .as("OPENAI_BASE_URL, OPENAI_API_KEY and OPENAI_MODEL must be set") + .isTrue(); + + String pipelineJob = + "source:\n" + + " type: values\n" + + " event-set.id: SINGLE_SPLIT_MULTI_TABLES\n" + + "\n" + + "sink:\n" + + " type: values\n" + + "\n" + + "transform:\n" + + " - source-table: default_namespace.default_schema.table1\n" + + " projection: col1, AI_COMPLETE('openaiModel', col1, 'Reply only the negative value of input number') AS reversed\n" + + " - source-table: default_namespace.default_schema.table2\n" + + " projection: col1, AI_SUMMARIZE('openaiModel', col1, 50) AS summary\n" + + "\n" + + "pipeline:\n" + + " parallelism: 1\n" + + " schema.change.behavior: evolve\n" + + " model:\n" + + " name: openaiModel\n" + + " type: openai-compatible\n" + + " endpoint: " + + endpoint + + "\n" + + " api-key: " + + apiKey + + "\n" + + " model-name: " + + model + + "\n"; + + Path openaiModelJar = TestUtils.getResource("openai-compatible-model.jar"); + submitPipelineJob(pipelineJob, openaiModelJar); + waitUntilJobFinished(Duration.ofMinutes(5)); + + validateResult( + "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING NOT NULL,`reversed` VARIANT}, primaryKeys=col1, options=()}", + "CreateTableEvent{tableId=default_namespace.default_schema.table2, schema=columns={`col1` STRING NOT NULL,`summary` VARIANT}, primaryKeys=col1, options=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[1, {\"result\":-1}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[2, {\"result\":-2}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[], after=[3, {\"result\":-3}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[1, {\"summary\":\"1\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[2, {\"summary\":\"2\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table2, before=[], after=[3, {\"summary\":\"3\"}], op=INSERT, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[1, {\"result\":-1}], after=[], op=DELETE, meta=()}", + "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, {\"result\":-2}], after=[2, {\"result\":-2}], op=UPDATE, meta=()}"); + } +} diff --git a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml index ffed7ad9bf5..6fe4a61e1d6 100644 --- a/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml +++ b/flink-cdc-e2e-tests/flink-cdc-pipeline-e2e-tests/src/test/resources/rules/malformed.yaml @@ -55,20 +55,3 @@ steps: error: | YAML UDF block is expecting an array children, but got an OBJECT ({"name":"addone","classpath":"org.apache.flink.cdc.udf.examples.%s.AddOneFunctionClass"}). Perhaps you missed a dash prefix `-`? - # Models not an array - - type: submit - yaml: | - source: - type: values - sink: - type: values - pipeline: - model: - model-name: GET_EMBEDDING - class-name: OpenAIEmbeddingModel - openai.model: text-embedding-3-small - openai.host: https://xxxx - openai.apikey: abcd1234 - error: | - YAML model block is expecting an array children, but got an OBJECT ({"model-name":"GET_EMBEDDING","class-name":"OpenAIEmbeddingModel","openai.model":"text-embedding-3-small","openai.host":"https://xxxx","openai.apikey":"abcd1234"}). - Perhaps you missed a dash prefix `-`? diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml new file mode 100644 index 00000000000..81f462b0fd8 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/pom.xml @@ -0,0 +1,32 @@ + + + + + + org.apache.flink + flink-cdc-pipeline-model + ${revision} + + + 4.0.0 + + flink-cdc-pipeline-model-dummy + jar + diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java new file mode 100644 index 00000000000..ea33eb6d851 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClient.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.dummy; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +/** Deterministic dummy AI model client for testing. */ +public class DummyModelClient implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + + private static final long serialVersionUID = 1L; + + private final boolean debug; + + public DummyModelClient(boolean debug) { + this.debug = debug; + } + + @Override + public String generate(String systemPrompt, String userInput) { + if (debug) { + System.out.printf("Received prompt: %s\nUser input: %s\n", systemPrompt, userInput); + } + // Returns a JSON covering fields for AI_COMPLETE and AI_SUMMARIZE + return "{\"result\":\"dummy result\",\"summary\":\"TL;DR\"}"; + } + + @Override + public float[] embed(String text) { + return new float[] {3f, 1f, 4f, 1f, 5f, 9f, 2f, 6f}; + } + + @Override + public void open() { + if (debug) { + System.out.println("Dummy model opened."); + } + } + + @Override + public void close() { + if (debug) { + System.out.println("Dummy model closed."); + } + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java new file mode 100644 index 00000000000..042b9762998 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/java/org/apache/flink/cdc/models/dummy/DummyModelClientFactory.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.dummy; + +import org.apache.flink.cdc.common.configuration.ConfigOption; +import org.apache.flink.cdc.common.factories.AiModelClientFactory; +import org.apache.flink.cdc.common.factories.Factory; +import org.apache.flink.cdc.common.model.AiModelClient; + +import java.util.Set; + +import static org.apache.flink.cdc.common.configuration.ConfigOptions.key; + +/** SPI factory for {@link DummyModelClient}. For testing purposes only. */ +public class DummyModelClientFactory implements AiModelClientFactory { + + public static final ConfigOption DEBUG = + key("debug").booleanType().defaultValue(false); + + @Override + public String identifier() { + return "dummy"; + } + + @Override + public Set> requiredOptions() { + return Set.of(); + } + + @Override + public Set> optionalOptions() { + return Set.of(DEBUG); + } + + @Override + public AiModelClient createClient(Factory.Context context) { + boolean debug = context.getFactoryConfiguration().get(DEBUG); + return new DummyModelClient(debug); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory new file mode 100644 index 00000000000..c1ed9c43ff3 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-dummy/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.cdc.models.dummy.DummyModelClientFactory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml new file mode 100644 index 00000000000..549a3d7db06 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/pom.xml @@ -0,0 +1,95 @@ + + + + + + org.apache.flink + flink-cdc-pipeline-model + ${revision} + + + 4.0.0 + + flink-cdc-pipeline-model-openai-compatible + jar + + + 2.13.4 + 3.12.4 + + + + + com.openai + openai-java + 4.32.0 + + + org.mockito + mockito-core + ${mockito.version} + test + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + + + + *:* + + + + + *:* + + META-INF/services/com.fasterxml.** + META-INF/services/kotlin.** + + + + + + com.fasterxml + org.apache.flink.cdc.models.openai.shaded.com.fasterxml + + + okhttp3 + org.apache.flink.cdc.models.openai.shaded.okhttp3 + + + okio + org.apache.flink.cdc.models.openai.shaded.okio + + + + + + + + + diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java new file mode 100644 index 00000000000..fe0fa425131 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClient.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonValue; +import com.openai.models.ResponseFormatJsonObject; +import com.openai.models.ResponseFormatText; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.embeddings.CreateEmbeddingResponse; +import com.openai.models.embeddings.Embedding; +import com.openai.models.embeddings.EmbeddingCreateParams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** AI model client that connects to any OpenAI-compatible endpoint. */ +public class OpenAiCompatibleModelClient + implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiCompatibleModelClient.class); + + private static final long serialVersionUID = 1L; + + private final String endpoint; + private final String apiKey; + private final String modelName; + private final String globalSystemPrompt; + private final Double temperature; + private final Double topP; + private final String stop; + private final Integer maxTokens; + private final Double presencePenalty; + private final Integer n; + private final Long seed; + private final String responseFormat; + private final Map> extraHeaders; + private final Map extraBody; + private final Integer embeddingDimension; + + private transient OpenAIClient client; + + public OpenAiCompatibleModelClient( + String endpoint, + String apiKey, + String modelName, + String globalSystemPrompt, + Double temperature, + Double topP, + String stop, + Integer maxTokens, + Double presencePenalty, + Integer n, + Long seed, + String responseFormat, + Map> extraHeaders, + Map extraBody, + Integer embeddingDimension) { + this.endpoint = endpoint; + this.apiKey = apiKey; + this.modelName = modelName; + this.globalSystemPrompt = globalSystemPrompt; + this.temperature = temperature; + this.topP = topP; + this.stop = stop; + this.maxTokens = maxTokens; + this.presencePenalty = presencePenalty; + this.n = n; + this.seed = seed; + this.responseFormat = responseFormat; + this.extraHeaders = extraHeaders == null ? Collections.emptyMap() : extraHeaders; + this.extraBody = extraBody == null ? Collections.emptyMap() : extraBody; + this.embeddingDimension = embeddingDimension; + } + + @Override + public void open() { + client = OpenAIOkHttpClient.builder().baseUrl(endpoint).apiKey(apiKey).build(); + LOG.info( + "Successfully constructed OpenAI http client. Endpoint: {} Model: {}", + endpoint, + modelName); + try { + client.models().list(); + } catch (Exception e) { + throw new RuntimeException("Failed to perform livecheck on OpenAI model client", e); + } + } + + @Override + public void close() { + if (client != null) { + try { + client.close(); + } finally { + client = null; + } + } + } + + @Override + public String generate(String systemPrompt, String userInput) { + String mergedSystemPrompt = mergeSystemPrompt(globalSystemPrompt, systemPrompt); + ChatCompletionCreateParams.Builder builder = + ChatCompletionCreateParams.builder() + .model(modelName) + .addSystemMessage(mergedSystemPrompt) + .addUserMessage(userInput); + + if (temperature != null) { + builder.temperature(temperature); + } + if (topP != null) { + builder.topP(topP); + } + if (stop != null && !stop.trim().isEmpty()) { + builder.stop(stop); + } + if (maxTokens != null) { + builder.maxTokens(maxTokens.longValue()); + } + if (presencePenalty != null) { + builder.presencePenalty(presencePenalty); + } + if (n != null) { + builder.n(n.longValue()); + } + if (seed != null) { + builder.seed(seed); + } + if ("json_object".equals(responseFormat)) { + builder.responseFormat( + ResponseFormatJsonObject.builder().type(JsonValue.from("json_object")).build()); + } else { + builder.responseFormat( + ResponseFormatText.builder().type(JsonValue.from("text")).build()); + } + for (Map.Entry> entry : extraHeaders.entrySet()) { + builder.putAdditionalHeaders(entry.getKey(), entry.getValue()); + } + for (Map.Entry entry : extraBody.entrySet()) { + builder.putAdditionalBodyProperty(entry.getKey(), JsonValue.from(entry.getValue())); + } + + ChatCompletionCreateParams params = builder.build(); + ChatCompletion completion = client.chat().completions().create(params); + return completion.choices().get(0).message().content().orElse(null); + } + + static String mergeSystemPrompt(String globalSystemPrompt, String runtimeSystemPrompt) { + String normalizedGlobal = normalizePrompt(globalSystemPrompt); + String normalizedRuntime = normalizePrompt(runtimeSystemPrompt); + if (normalizedGlobal == null) { + return normalizedRuntime == null ? "" : normalizedRuntime; + } + if (normalizedRuntime == null) { + return normalizedGlobal; + } + return normalizedGlobal + "\n\n" + normalizedRuntime; + } + + private static String normalizePrompt(String prompt) { + if (prompt == null || prompt.trim().isEmpty()) { + return null; + } + return prompt; + } + + @Override + public float[] embed(String text) { + EmbeddingCreateParams.Builder builder = + EmbeddingCreateParams.builder().model(modelName).input(text); + if (embeddingDimension != null) { + builder.dimensions(embeddingDimension.longValue()); + } + for (Map.Entry> entry : extraHeaders.entrySet()) { + builder.putAdditionalHeaders(entry.getKey(), entry.getValue()); + } + for (Map.Entry entry : extraBody.entrySet()) { + builder.putAdditionalBodyProperty(entry.getKey(), JsonValue.from(entry.getValue())); + } + EmbeddingCreateParams params = builder.build(); + CreateEmbeddingResponse response = client.embeddings().create(params); + List data = response.data(); + if (data.isEmpty()) { + throw new IllegalStateException( + "Embedding response from model '" + + modelName + + "' contained no embeddings. " + + "This indicates a server-side anomaly; refusing to emit an empty vector."); + } + List embedding = data.get(0).embedding(); + float[] result = new float[embedding.size()]; + for (int i = 0; i < result.length; i++) { + result[i] = embedding.get(i); + } + return result; + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java new file mode 100644 index 00000000000..d481cb1a9a2 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactory.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.configuration.ConfigOption; +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.factories.AiModelClientFactory; +import org.apache.flink.cdc.common.factories.Factory; +import org.apache.flink.cdc.common.model.AiModelClient; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** SPI factory for {@link OpenAiCompatibleModelClient}. */ +public class OpenAiCompatibleModelClientFactory implements AiModelClientFactory { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + @Override + public String identifier() { + return "openai-compatible"; + } + + @Override + public Set> requiredOptions() { + return OpenAiCompatibleModelOptions.requiredOptions(); + } + + @Override + public Set> optionalOptions() { + return OpenAiCompatibleModelOptions.optionalOptions(); + } + + @Override + public AiModelClient createClient(Factory.Context context) { + Configuration configuration = context.getFactoryConfiguration(); + String endpoint = configuration.get(OpenAiCompatibleModelOptions.ENDPOINT); + String apiKey = configuration.get(OpenAiCompatibleModelOptions.API_KEY); + String modelName = configuration.get(OpenAiCompatibleModelOptions.MODEL_NAME); + String systemPrompt = + configuration.getOptional(OpenAiCompatibleModelOptions.SYSTEM_PROMPT).orElse(null); + Double temperature = + configuration.getOptional(OpenAiCompatibleModelOptions.TEMPERATURE).orElse(null); + Double topP = configuration.getOptional(OpenAiCompatibleModelOptions.TOP_P).orElse(null); + String stop = configuration.getOptional(OpenAiCompatibleModelOptions.STOP).orElse(null); + Integer maxTokens = + configuration.getOptional(OpenAiCompatibleModelOptions.MAX_TOKENS).orElse(null); + Double presencePenalty = + configuration + .getOptional(OpenAiCompatibleModelOptions.PRESENCE_PENALTY) + .orElse(null); + Integer n = configuration.getOptional(OpenAiCompatibleModelOptions.N).orElse(null); + Long seed = configuration.getOptional(OpenAiCompatibleModelOptions.SEED).orElse(null); + String responseFormat = configuration.get(OpenAiCompatibleModelOptions.RESPONSE_FORMAT); + String extraHeader = + configuration.getOptional(OpenAiCompatibleModelOptions.EXTRA_HEADER).orElse(null); + String extraBody = + configuration.getOptional(OpenAiCompatibleModelOptions.EXTRA_BODY).orElse(null); + Integer dimension = + configuration.getOptional(OpenAiCompatibleModelOptions.DIMENSION).orElse(null); + + validateValueRanges(temperature, presencePenalty, n, maxTokens, dimension); + validateEnumValue("response-format", responseFormat, "text", "json_object"); + + return new OpenAiCompatibleModelClient( + endpoint, + apiKey, + modelName, + systemPrompt, + temperature, + topP, + stop, + maxTokens, + presencePenalty, + n, + seed, + responseFormat, + parseExtraHeaders(extraHeader), + parseExtraBody(extraBody), + dimension); + } + + private static void validateValueRanges( + Double temperature, + Double presencePenalty, + Integer n, + Integer maxTokens, + Integer dimension) { + if (temperature != null && (temperature < 0 || temperature >= 2)) { + throw new IllegalArgumentException("Option 'temperature' must be in range [0, 2)."); + } + if (presencePenalty != null && (presencePenalty < -2 || presencePenalty > 2)) { + throw new IllegalArgumentException( + "Option 'presence-penalty' must be in range [-2.0, 2.0]."); + } + if (n != null && n <= 0) { + throw new IllegalArgumentException("Option 'n' must be greater than 0."); + } + if (maxTokens != null && maxTokens <= 0) { + throw new IllegalArgumentException("Option 'max-tokens' must be greater than 0."); + } + if (dimension != null && dimension <= 0) { + throw new IllegalArgumentException("Option 'dimension' must be greater than 0."); + } + } + + private static void validateEnumValue( + String optionKey, String value, String first, String second) { + if (!first.equals(value) && !second.equals(value)) { + throw new IllegalArgumentException( + String.format( + "Option '%s' must be '%s' or '%s', but got '%s'.", + optionKey, first, second, value)); + } + } + + private static Map> parseExtraHeaders(String rawExtraHeader) { + if (rawExtraHeader == null || rawExtraHeader.trim().isEmpty()) { + return Collections.emptyMap(); + } + Map rawMap = parseJsonObject(rawExtraHeader, "extra-header"); + Map> parsed = new LinkedHashMap<>(); + for (Map.Entry entry : rawMap.entrySet()) { + Object value = entry.getValue(); + if (value instanceof String) { + parsed.put(entry.getKey(), Collections.singletonList((String) value)); + } else if (value instanceof List) { + List values = new ArrayList<>(); + for (Object element : (List) value) { + if (!(element instanceof String)) { + throw new IllegalArgumentException( + "Option 'extra-header' only supports string values or string arrays."); + } + values.add((String) element); + } + parsed.put(entry.getKey(), values); + } else { + throw new IllegalArgumentException( + "Option 'extra-header' only supports string values or string arrays."); + } + } + return parsed; + } + + private static Map parseExtraBody(String rawExtraBody) { + if (rawExtraBody == null || rawExtraBody.trim().isEmpty()) { + return Collections.emptyMap(); + } + return parseJsonObject(rawExtraBody, "extra-body"); + } + + private static Map parseJsonObject(String rawJson, String optionName) { + try { + return OBJECT_MAPPER.readValue(rawJson, new TypeReference>() {}); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Option '%s' must be a valid JSON object string.", optionName), + e); + } + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java new file mode 100644 index 00000000000..9fcfc64ca37 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelOptions.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.configuration.ConfigOption; + +import java.util.Set; + +import static org.apache.flink.cdc.common.configuration.ConfigOptions.key; + +/** Config options for {@link OpenAiCompatibleModelClient}. */ +public class OpenAiCompatibleModelOptions { + + public static final ConfigOption ENDPOINT = + key("endpoint").stringType().noDefaultValue(); + + public static final ConfigOption API_KEY = key("api-key").stringType().noDefaultValue(); + + public static final ConfigOption MODEL_NAME = + key("model-name").stringType().noDefaultValue(); + + public static final ConfigOption SYSTEM_PROMPT = + key("system-prompt").stringType().noDefaultValue(); + + public static final ConfigOption TEMPERATURE = + key("temperature").doubleType().noDefaultValue(); + + public static final ConfigOption TOP_P = key("top-p").doubleType().noDefaultValue(); + + public static final ConfigOption STOP = key("stop").stringType().noDefaultValue(); + + public static final ConfigOption MAX_TOKENS = + key("max-tokens").intType().noDefaultValue(); + + public static final ConfigOption PRESENCE_PENALTY = + key("presence-penalty").doubleType().noDefaultValue(); + + public static final ConfigOption N = key("n").intType().noDefaultValue(); + + public static final ConfigOption SEED = key("seed").longType().noDefaultValue(); + + public static final ConfigOption RESPONSE_FORMAT = + key("response-format").stringType().defaultValue("text"); + + public static final ConfigOption EXTRA_HEADER = + key("extra-header").stringType().noDefaultValue(); + + public static final ConfigOption EXTRA_BODY = + key("extra-body").stringType().noDefaultValue(); + + public static final ConfigOption DIMENSION = + key("dimension").intType().noDefaultValue(); + + public static Set> requiredOptions() { + return Set.of(ENDPOINT, API_KEY, MODEL_NAME); + } + + public static Set> optionalOptions() { + return Set.of( + SYSTEM_PROMPT, + TEMPERATURE, + TOP_P, + STOP, + MAX_TOKENS, + PRESENCE_PENALTY, + N, + SEED, + RESPONSE_FORMAT, + EXTRA_HEADER, + EXTRA_BODY, + DIMENSION); + } + + private OpenAiCompatibleModelOptions() {} +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory new file mode 100644 index 00000000000..017156db45f --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/main/resources/META-INF/services/org.apache.flink.cdc.common.factories.Factory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.cdc.models.openai.OpenAiCompatibleModelClientFactory diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java new file mode 100644 index 00000000000..bba5a102f2f --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientFactoryTest.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.apache.flink.cdc.common.configuration.ConfigOption; +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.factories.Factory; +import org.apache.flink.cdc.common.factories.FactoryHelper; +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class OpenAiCompatibleModelClientFactoryTest { + + private final OpenAiCompatibleModelClientFactory factory = + new OpenAiCompatibleModelClientFactory(); + + private Factory.Context contextWithOptions(Map options) { + return new FactoryHelper.DefaultContext( + Configuration.fromMap(options), + new Configuration(), + Thread.currentThread().getContextClassLoader()); + } + + @Test + void testIdentifier() { + assertThat(factory.identifier()).isEqualTo("openai-compatible"); + } + + @Test + void testRequiredOptions() { + assertThat( + factory.requiredOptions().stream() + .map(ConfigOption::key) + .collect(Collectors.toSet())) + .containsExactlyInAnyOrder( + OpenAiCompatibleModelOptions.ENDPOINT.key(), + OpenAiCompatibleModelOptions.API_KEY.key(), + OpenAiCompatibleModelOptions.MODEL_NAME.key()); + } + + @Test + void testOptionalOptions() { + assertThat( + factory.optionalOptions().stream() + .map(ConfigOption::key) + .collect(Collectors.toSet())) + .contains( + OpenAiCompatibleModelOptions.SYSTEM_PROMPT.key(), + OpenAiCompatibleModelOptions.TEMPERATURE.key(), + OpenAiCompatibleModelOptions.TOP_P.key(), + OpenAiCompatibleModelOptions.STOP.key(), + OpenAiCompatibleModelOptions.MAX_TOKENS.key(), + OpenAiCompatibleModelOptions.PRESENCE_PENALTY.key(), + OpenAiCompatibleModelOptions.N.key(), + OpenAiCompatibleModelOptions.SEED.key(), + OpenAiCompatibleModelOptions.RESPONSE_FORMAT.key(), + OpenAiCompatibleModelOptions.EXTRA_HEADER.key(), + OpenAiCompatibleModelOptions.EXTRA_BODY.key(), + OpenAiCompatibleModelOptions.DIMENSION.key()); + } + + @Test + void testCreateClient() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + AiModelClient client = factory.createClient(contextWithOptions(options)); + assertThat(client).isInstanceOf(OpenAiCompatibleModelClient.class); + assertThat(client).isInstanceOf(SupportsTextGeneration.class); + assertThat(client).isInstanceOf(SupportsEmbedding.class); + } + + @Test + void testCreateClientWithRequiredOptionsOnlyUsesDefaults() throws Exception { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + OpenAiCompatibleModelClient client = + (OpenAiCompatibleModelClient) factory.createClient(contextWithOptions(options)); + assertThat(readField(client, "globalSystemPrompt")).isNull(); + assertThat(readField(client, "responseFormat")).isEqualTo("text"); + } + + @SuppressWarnings("unchecked") + @Test + void testCreateClientParsesExtraHeaderAndBody() throws Exception { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put( + "extra-header", "{\"X-Trace-Id\":\"trace-1\",\"X-Tags\":[\"tag-a\",\"tag-b\"]}"); + options.put("extra-body", "{\"foo\":\"bar\",\"num\":1}"); + + OpenAiCompatibleModelClient client = + (OpenAiCompatibleModelClient) factory.createClient(contextWithOptions(options)); + Map> extraHeaders = + (Map>) readField(client, "extraHeaders"); + Map extraBody = (Map) readField(client, "extraBody"); + + assertThat(extraHeaders).containsEntry("X-Trace-Id", Arrays.asList("trace-1")); + assertThat(extraHeaders).containsEntry("X-Tags", Arrays.asList("tag-a", "tag-b")); + assertThat(extraBody).containsEntry("foo", "bar"); + assertThat(extraBody).containsKey("num"); + } + + @Test + void testValidatePassesWithAllRequiredOptions() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + FactoryHelper.createFactoryHelper(factory, contextWithOptions(options)).validate(); + } + + @Test + void testValidatePassesWithOptionalSystemPrompt() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("system-prompt", "You are a strict JSON generator."); + + FactoryHelper.createFactoryHelper(factory, contextWithOptions(options)).validate(); + } + + @Test + void testValidateThrowsOnMissingRequiredOption() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + + assertThatThrownBy( + () -> + FactoryHelper.createFactoryHelper( + factory, contextWithOptions(options)) + .validate()) + .hasMessageContaining("required options are missing"); + } + + @Test + void testValidateThrowsOnUnknownOption() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("unknown-key", "value"); + + assertThatThrownBy( + () -> + FactoryHelper.createFactoryHelper( + factory, contextWithOptions(options)) + .validate()) + .hasMessageContaining("Unsupported options"); + } + + @Test + void testCreateClientFailFast() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + + AiModelClient client = factory.createClient(contextWithOptions(options)); + assertThat(client).isInstanceOf(OpenAiCompatibleModelClient.class); + assertThat(client).isInstanceOf(SupportsTextGeneration.class); + assertThat(client).isInstanceOf(SupportsEmbedding.class); + + assertThatThrownBy(client::open) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Failed to perform livecheck on OpenAI model client"); + } + + @Test + void testCreateClientThrowsOnInvalidTemperature() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("temperature", "2.0"); + + assertThatThrownBy(() -> factory.createClient(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("temperature"); + } + + @Test + void testCreateClientThrowsOnInvalidExtraHeaderJson() { + Map options = new HashMap<>(); + options.put("endpoint", "https://api.example.com/v1"); + options.put("api-key", "sk-test"); + options.put("model-name", "gpt-4"); + options.put("extra-header", "not-a-json"); + + assertThatThrownBy(() -> factory.createClient(contextWithOptions(options))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("extra-header"); + } + + private static Object readField(Object target, String fieldName) throws Exception { + Field field = target.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(target); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java new file mode 100644 index 00000000000..b508caeb414 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientITCase.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; + +class OpenAiCompatibleModelClientITCase { + + private OpenAiCompatibleModelClient client; + + private final String endpoint = System.getenv("OPENAI_BASE_URL"); + private final String apiKey = System.getenv("OPENAI_API_KEY"); + private final String model = System.getenv("OPENAI_MODEL"); + + @BeforeEach + void setUp() { + Assumptions.assumeThat(endpoint != null && apiKey != null && model != null) + .as("OPENAI_BASE_URL, OPENAI_API_KEY and OPENAI_MODEL must be set") + .isTrue(); + + client = + new OpenAiCompatibleModelClient( + endpoint, + apiKey, + model, + "You are a helpful assistant.", + null, + null, + null, + null, + null, + null, + null, + "text", + Collections.emptyMap(), + Collections.emptyMap(), + null); + client.open(); + } + + @AfterEach + void tearDown() { + if (client != null) { + client.close(); + } + } + + @Test + void testGenerate() { + String result = + client.generate("You are a calculator.", "What is 1 + 1? Answer only the number."); + assertThat(result).isNotNull().contains("2"); + } + + @Test + void testGenerateWithEmptyUserInput() { + String result = client.generate("Reply with exactly: OK", ""); + assertThat(result).isNotNull().contains("OK"); + } + + @Test + void testGenerateWithGlobalSystemPrompt() { + String marker = "GLOBAL_MARKER_9F2A"; + OpenAiCompatibleModelClient globalPromptClient = + new OpenAiCompatibleModelClient( + endpoint, + apiKey, + model, + "You must always include the exact token '" + + marker + + "' in every response.", + null, + null, + null, + null, + null, + null, + null, + "text", + Collections.emptyMap(), + Collections.emptyMap(), + null); + try (globalPromptClient) { + globalPromptClient.open(); + String result = + globalPromptClient.generate( + "Reply in one short sentence to say hello.", "Say hello."); + assertThat(result).isNotNull().contains(marker); + } + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java new file mode 100644 index 00000000000..512cfef02ee --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/java/org/apache/flink/cdc/models/openai/OpenAiCompatibleModelClientTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.models.openai; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class OpenAiCompatibleModelClientTest { + + @Test + void testMergeSystemPromptWithGlobalAndRuntime() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt("global", "runtime")) + .isEqualTo("global\n\nruntime"); + } + + @Test + void testMergeSystemPromptWithGlobalOnly() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt("global", null)) + .isEqualTo("global"); + } + + @Test + void testMergeSystemPromptWithRuntimeOnly() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt(null, "runtime")) + .isEqualTo("runtime"); + } + + @Test + void testMergeSystemPromptSkipsBlankGlobal() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt(" ", "runtime")) + .isEqualTo("runtime"); + } + + @Test + void testMergeSystemPromptSkipsBlankRuntime() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt("global", " ")) + .isEqualTo("global"); + } + + @Test + void testMergeSystemPromptReturnsEmptyWhenBothBlank() { + assertThat(OpenAiCompatibleModelClient.mergeSystemPrompt(" ", " ")).isEmpty(); + } +} diff --git a/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000000..0d45bab8011 --- /dev/null +++ b/flink-cdc-pipeline-model/flink-cdc-pipeline-model-openai-compatible/src/test/resources/log4j2-test.properties @@ -0,0 +1,25 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Set root logger level to ERROR to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level = ERROR +rootLogger.appenderRef.test.ref = TestLogger + +appender.testlogger.name = TestLogger +appender.testlogger.type = CONSOLE +appender.testlogger.target = SYSTEM_ERR +appender.testlogger.layout.type = PatternLayout +appender.testlogger.layout.pattern = %-4r [%t] %-5p %c - %m%n diff --git a/flink-cdc-pipeline-model/pom.xml b/flink-cdc-pipeline-model/pom.xml index e7dba7f6f66..7cfb2fb53fc 100644 --- a/flink-cdc-pipeline-model/pom.xml +++ b/flink-cdc-pipeline-model/pom.xml @@ -19,16 +19,20 @@ limitations under the License. xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - flink-cdc-parent org.apache.flink + flink-cdc-parent ${revision} + 4.0.0 flink-cdc-pipeline-model - - 0.23.0 - + pom + + + flink-cdc-pipeline-model-dummy + flink-cdc-pipeline-model-openai-compatible + @@ -37,51 +41,6 @@ limitations under the License. ${project.version} provided - - org.apache.flink - flink-test-utils-junit - ${flink.version} - test - - - org.testcontainers - testcontainers - - - - - dev.langchain4j - langchain4j - ${langchain4j.version} - - - dev.langchain4j - langchain4j-open-ai - ${langchain4j.version} - - - com.theokanning.openai-gpt3-java - service - 0.12.0 - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - test-jar - - test-jar - - - - - - - - \ No newline at end of file + diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java deleted file mode 100644 index f56b76d5bae..00000000000 --- a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.ConfigOption; -import org.apache.flink.cdc.common.configuration.ConfigOptions; - -/** Options of built-in model. */ -public class ModelOptions { - - // Options for Open AI Model. - public static final ConfigOption OPENAI_MODEL_NAME = - ConfigOptions.key("openai.model") - .stringType() - .noDefaultValue() - .withDescription("Name of model to be called."); - - public static final ConfigOption OPENAI_HOST = - ConfigOptions.key("openai.host") - .stringType() - .noDefaultValue() - .withDescription("Host of the Model server to be connected."); - - public static final ConfigOption OPENAI_API_KEY = - ConfigOptions.key("openai.apikey") - .stringType() - .noDefaultValue() - .withDescription("Api Key for verification of the Model server."); - - public static final ConfigOption OPENAI_CHAT_PROMPT = - ConfigOptions.key("openai.chat.prompt") - .stringType() - .noDefaultValue() - .withDescription("Prompt for chat using OpenAI."); -} diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java deleted file mode 100644 index 2fa147f509f..00000000000 --- a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.types.DataType; -import org.apache.flink.cdc.common.types.DataTypes; -import org.apache.flink.cdc.common.udf.UserDefinedFunction; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; -import org.apache.flink.cdc.common.utils.Preconditions; - -import dev.langchain4j.data.message.UserMessage; -import dev.langchain4j.model.openai.OpenAiChatModel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collections; - -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_CHAT_PROMPT; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME; - -/** - * A {@link UserDefinedFunction} that use Model defined by OpenAI to generate text, refer to docs}. - */ -public class OpenAIChatModel implements UserDefinedFunction { - - private static final Logger LOG = LoggerFactory.getLogger(OpenAIChatModel.class); - - private OpenAiChatModel chatModel; - - private String modelName; - - private String host; - - private String prompt; - - public String eval(String input) { - return chat(input); - } - - private String chat(String input) { - if (input == null || input.trim().isEmpty()) { - LOG.warn("Empty or null input provided for embedding."); - return ""; - } - if (prompt != null) { - input = prompt + ": " + input; - } - return chatModel - .generate(Collections.singletonList(new UserMessage(input))) - .content() - .text(); - } - - @Override - public DataType getReturnType() { - return DataTypes.STRING(); - } - - @Override - public void open(UserDefinedFunctionContext userDefinedFunctionContext) { - Configuration modelOptions = userDefinedFunctionContext.configuration(); - this.modelName = modelOptions.get(OPENAI_MODEL_NAME); - Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty."); - this.host = modelOptions.get(OPENAI_HOST); - Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty."); - String apiKey = modelOptions.get(OPENAI_API_KEY); - Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty."); - this.prompt = modelOptions.get(OPENAI_CHAT_PROMPT); - LOG.info("Opening OpenAIChatModel " + modelName + " " + host); - this.chatModel = - OpenAiChatModel.builder().apiKey(apiKey).baseUrl(host).modelName(modelName).build(); - } - - @Override - public void close() { - LOG.info("Closed OpenAIChatModel " + modelName + " " + host); - } -} diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java deleted file mode 100644 index dbc29c307ec..00000000000 --- a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.data.ArrayData; -import org.apache.flink.cdc.common.data.GenericArrayData; -import org.apache.flink.cdc.common.types.DataType; -import org.apache.flink.cdc.common.types.DataTypes; -import org.apache.flink.cdc.common.udf.UserDefinedFunction; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; -import org.apache.flink.cdc.common.utils.Preconditions; - -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.openai.OpenAiEmbeddingModel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collections; -import java.util.List; - -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST; -import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME; - -/** - * A {@link UserDefinedFunction} that use Model defined by OpenAI to generate vector data, refer to - * docs}. - */ -public class OpenAIEmbeddingModel implements UserDefinedFunction { - - private static final Logger LOG = LoggerFactory.getLogger(OpenAIEmbeddingModel.class); - - private String modelName; - - private String host; - - private OpenAiEmbeddingModel embeddingModel; - - public ArrayData eval(String input) { - return getEmbedding(input); - } - - private ArrayData getEmbedding(String input) { - if (input == null || input.trim().isEmpty()) { - LOG.debug("Empty or null input provided for embedding."); - return new GenericArrayData(new Float[0]); - } - - TextSegment textSegment = new TextSegment(input, new Metadata()); - - List embeddings = - embeddingModel.embedAll(Collections.singletonList(textSegment)).content(); - - if (embeddings != null && !embeddings.isEmpty()) { - List embeddingList = embeddings.get(0).vectorAsList(); - Float[] embeddingArray = embeddingList.toArray(new Float[0]); - return new GenericArrayData(embeddingArray); - } else { - LOG.warn("No embedding results returned for input: {}", input); - return new GenericArrayData(new Float[0]); - } - } - - @Override - public DataType getReturnType() { - return DataTypes.ARRAY(DataTypes.FLOAT()); - } - - @Override - public void open(UserDefinedFunctionContext userDefinedFunctionContext) { - Configuration modelOptions = userDefinedFunctionContext.configuration(); - this.modelName = modelOptions.get(OPENAI_MODEL_NAME); - Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty."); - this.host = modelOptions.get(OPENAI_HOST); - Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty."); - String apiKey = modelOptions.get(OPENAI_API_KEY); - Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty."); - LOG.info("Opening OpenAIEmbeddingModel " + modelName + " " + host); - this.embeddingModel = - OpenAiEmbeddingModel.builder() - .apiKey(apiKey) - .baseUrl(host) - .modelName(modelName) - .build(); - } - - @Override - public void close() { - LOG.info("Closed OpenAIEmbeddingModel " + modelName + " " + host); - } -} diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java deleted file mode 100644 index b14a7d89c65..00000000000 --- a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; - -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -/** A test for {@link OpenAIChatModel}. */ -class TestOpenAIChatModel { - @Test - @Disabled("For manual test as there is a limit for quota.") - public void testEval() { - OpenAIChatModel openAIChatModel = new OpenAIChatModel(); - Configuration configuration = new Configuration(); - configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1"); - configuration.set(ModelOptions.OPENAI_API_KEY, "demo"); - configuration.set(ModelOptions.OPENAI_MODEL_NAME, "gpt-4o-mini"); - UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration; - openAIChatModel.open(userDefinedFunctionContext); - String response = openAIChatModel.eval("Who invented the electric light?"); - Assertions.assertThat(response).isNotEmpty(); - } -} diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java deleted file mode 100644 index 2be50d41e05..00000000000 --- a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cdc.runtime.model; - -import org.apache.flink.cdc.common.configuration.Configuration; -import org.apache.flink.cdc.common.data.ArrayData; -import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; - -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -/** A test for {@link OpenAIEmbeddingModel}. */ -class TestOpenAIEmbeddingModel { - - @Test - @Disabled("For manual test as there is a limit for quota.") - public void testEval() { - OpenAIEmbeddingModel openAIEmbeddingModel = new OpenAIEmbeddingModel(); - Configuration configuration = new Configuration(); - configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1"); - configuration.set(ModelOptions.OPENAI_API_KEY, "demo"); - configuration.set(ModelOptions.OPENAI_MODEL_NAME, "text-embedding-3-small"); - UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration; - openAIEmbeddingModel.open(userDefinedFunctionContext); - ArrayData arrayData = - openAIEmbeddingModel.eval("Flink CDC is a streaming data integration tool"); - Assertions.assertThat(arrayData).isNotNull(); - } -} diff --git a/flink-cdc-runtime/pom.xml b/flink-cdc-runtime/pom.xml index fc474013c22..9fcfdc86c05 100644 --- a/flink-cdc-runtime/pom.xml +++ b/flink-cdc-runtime/pom.xml @@ -95,12 +95,6 @@ limitations under the License. ${project.version} test - - org.apache.flink - flink-cdc-pipeline-model - ${project.version} - test - diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java new file mode 100644 index 00000000000..3554f604c2f --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiEmbeddingFunctionDef.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.ai; + +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; + +/** Built-in AI embedding function definitions with configurable input and output types. */ +public enum AiEmbeddingFunctionDef { + AI_EMBED("AI_EMBED", DataTypes.STRING(), DataTypes.ARRAY(DataTypes.FLOAT())); + + private final String functionName; + private final DataType inputType; + private final DataType outputType; + + AiEmbeddingFunctionDef(String functionName, DataType inputType, DataType outputType) { + this.functionName = functionName; + this.inputType = inputType; + this.outputType = outputType; + } + + public String getFunctionName() { + return functionName; + } + + /** The type of the input value (e.g. STRING for text embedding). */ + public DataType getInputType() { + return inputType; + } + + /** The type of the output value (e.g. ARRAY<FLOAT> for vector embedding). */ + public DataType getOutputType() { + return outputType; + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java new file mode 100644 index 00000000000..845a082c771 --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/ai/AiTextFunctionDef.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.ai; + +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.common.types.RowType; + +/** + * Built-in AI text generation function definitions with their prompt templates and type metadata. + */ +public enum AiTextFunctionDef { + AI_COMPLETE( + "AI_COMPLETE", + RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"systemPrompt"}), + RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"result"}), + "%s\n"), + + AI_SUMMARIZE( + "AI_SUMMARIZE", + RowType.of(new DataType[] {DataTypes.INT()}, new String[] {"maxLength"}), + RowType.of(new DataType[] {DataTypes.STRING()}, new String[] {"summary"}), + "You are a text summarization expert. Generate an accurate, coherent, and informative " + + "summary that does not exceed %d characters.\n" + + "Output requirements:\n" + + "- summary: the summarized content\n" + + "Principles:\n" + + "- Stay within the specified length\n" + + "- Preserve core ideas and key information\n" + + "- Use concise language with clear logic\n" + + "- Maintain text coherence\n" + + "- Avoid subjective opinions\n"); + + private final String functionName; + private final RowType inputType; + private final RowType outputType; + private final String promptTemplate; + + AiTextFunctionDef( + String functionName, RowType inputType, RowType outputType, String promptTemplate) { + this.functionName = functionName; + this.inputType = inputType; + this.outputType = outputType; + this.promptTemplate = promptTemplate; + } + + public String getFunctionName() { + return functionName; + } + + /** + * Returns the additional parameter types for promptTemplate placeholders. + * + *

Input text parameter is always added by runtime, not included here. + */ + public RowType getInputType() { + return inputType; + } + + public RowType getOutputType() { + return outputType; + } + + /** Builds the core system prompt by filling in the template placeholders. */ + public String buildPrompt(Object... args) { + return String.format(promptTemplate, args); + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java new file mode 100644 index 00000000000..21c911f218c --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctions.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.functions.impl; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; +import org.apache.flink.cdc.common.types.RowType; +import org.apache.flink.cdc.common.types.variant.BinaryVariant; +import org.apache.flink.cdc.common.types.variant.BinaryVariantInternalBuilder; +import org.apache.flink.cdc.runtime.ai.AiTextFunctionDef; + +import org.apache.flink.shaded.guava31.com.google.common.primitives.Floats; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +/** Built-in AI functions available as static imports in Janino-compiled transform expressions. */ +public class AiFunctions { + + private static final Logger LOG = LoggerFactory.getLogger(AiFunctions.class); + + /** General-purpose text completion with a user-provided system prompt. */ + public static BinaryVariant aiComplete(AiModelClient model, String input, String systemPrompt) { + return invokeTextGeneration(model, AiTextFunctionDef.AI_COMPLETE, input, systemPrompt); + } + + /** Text summarization with a maximum length constraint. */ + public static BinaryVariant aiSummarize(AiModelClient model, String input, int maxLength) { + return invokeTextGeneration(model, AiTextFunctionDef.AI_SUMMARIZE, input, maxLength); + } + + /** Text embedding that converts input text to a vector representation. */ + public static List aiEmbed(AiModelClient model, String input) { + if (!(model instanceof SupportsEmbedding)) { + throw new UnsupportedOperationException( + "Model " + model.getClass().getName() + " does not support embedding"); + } + + long startTime = 0; + if (LOG.isDebugEnabled()) { + startTime = System.currentTimeMillis(); + } + float[] embeddingResult = ((SupportsEmbedding) model).embed(input); + if (LOG.isDebugEnabled()) { + long endTime = System.currentTimeMillis(); + LOG.debug( + "Generated {}-dim vector in {} ms", + embeddingResult.length, + endTime - startTime); + } + + return Floats.asList(embeddingResult); + } + + private static BinaryVariant invokeTextGeneration( + AiModelClient model, AiTextFunctionDef funcDef, String input, Object... args) { + if (!(model instanceof SupportsTextGeneration)) { + throw new UnsupportedOperationException( + "Model " + model.getClass().getName() + " does not support text generation"); + } + StringBuilder promptBuilder = new StringBuilder(); + promptBuilder.append(funcDef.buildPrompt(args)); + promptBuilder.append("\n").append(buildOutputSchemaHint(funcDef.getOutputType())); + + String systemPrompt = promptBuilder.toString(); + + long startTime = 0; + if (LOG.isDebugEnabled()) { + startTime = System.currentTimeMillis(); + } + String json = ((SupportsTextGeneration) model).generate(systemPrompt, input); + + if (LOG.isDebugEnabled()) { + long endTime = System.currentTimeMillis(); + LOG.debug("Generated {} characters in {} ms", json.length(), endTime - startTime); + } + + if (json == null) { + return null; + } + try { + return BinaryVariantInternalBuilder.parseJson(json, false); + } catch (java.io.IOException e) { + throw new RuntimeException("Failed to parse AI response as JSON: " + json, e); + } + } + + /** Builds the JSON output schema hint based on outputType. */ + private static String buildOutputSchemaHint(RowType outputType) { + StringBuilder sb = new StringBuilder(); + sb.append("You must return the result strictly in the following JSON format:\n"); + sb.append("{\n"); + List fieldNames = outputType.getFieldNames(); + for (int i = 0; i < fieldNames.size(); i++) { + sb.append(" \"") + .append(fieldNames.get(i)) + .append("\": <") + .append(fieldNames.get(i)) + .append(">"); + if (i < fieldNames.size() - 1) { + sb.append(","); + } + sb.append("\n"); + } + sb.append("}\n"); + sb.append( + "Important: Return only valid JSON with no additional text, without YAML code blocks.\n"); + return sb.toString(); + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java index 0a8b63703b9..c90a7de8c17 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java @@ -29,6 +29,7 @@ import org.apache.flink.cdc.common.event.Event; import org.apache.flink.cdc.common.event.SchemaChangeEvent; import org.apache.flink.cdc.common.event.TableId; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.schema.Schema; import org.apache.flink.cdc.common.schema.Selectors; import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; @@ -41,6 +42,7 @@ import org.apache.flink.cdc.runtime.typeutils.BinaryRecordDataGenerator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.shaded.guava31.com.google.common.cache.CacheBuilder; import org.apache.flink.shaded.guava31.com.google.common.cache.CacheLoader; @@ -79,6 +81,9 @@ public class PostTransformOperator extends AbstractStreamOperatorAdapter // Tuple3 items are: function name, class path, and extra options. private final List>> udfFunctions; + // Serializable AI model clients keyed by model name, e.g. myModel. + private final Map modelClients; + private transient List transformers; private transient List udfDescriptors; private transient List udfFunctionInstances; @@ -98,13 +103,15 @@ public static PostTransformOperatorBuilder newBuilder() { PostTransformOperator( List transformRules, String timezone, - List>> udfFunctions) { + List>> udfFunctions, + Map modelClients) { this.timezone = timezone; this.transformRules = transformRules; this.hasAsteriskMap = new HashMap<>(); this.projectedColumnsMap = new HashMap<>(); this.postTransformInfoMap = new ConcurrentHashMap<>(); this.udfFunctions = udfFunctions; + this.modelClients = modelClients; } @Override @@ -115,6 +122,9 @@ public void open() throws Exception { this.projectionProcessors = HashBasedTable.create(); this.filterProcessors = HashBasedTable.create(); + // Initialize AI model clients + initializeAiModelClients(); + // Be sure to initialize UDF related fields before creating transformers initializeUdf(); @@ -136,6 +146,7 @@ public void close() throws Exception { super.close(); TransformExpressionCompiler.cleanUp(); destroyUdf(); + destroyAiModelClients(); } @Override @@ -447,7 +458,8 @@ private TransformProjectionProcessor getProjectionProcessor( timezone, udfDescriptors, udfFunctionInstances, - postTransformer.getSupportedMetadataColumns())); + postTransformer.getSupportedMetadataColumns(), + modelClients)); } return projectionProcessors.get(tableId, postTransformer); } @@ -472,7 +484,8 @@ private TransformFilterProcessor getFilterProcessor( timezone, udfDescriptors, udfFunctionInstances, - postTransformer.getSupportedMetadataColumns())); + postTransformer.getSupportedMetadataColumns(), + modelClients)); } } return filterProcessors.get(tableId, postTransformer); @@ -558,4 +571,28 @@ private void destroyUdf() { udfDescriptors.clear(); udfFunctionInstances.clear(); } + + private void initializeAiModelClients() { + for (Map.Entry entry : modelClients.entrySet()) { + try { + entry.getValue().open(); + LOG.info("Successfully opened AI model client '{}'.", entry.getKey()); + } catch (Exception e) { + LOG.error("Failed to open AI model client '{}'.", entry.getKey(), e); + throw new FlinkRuntimeException( + "Failed to initialize AI model: " + entry.getKey(), e); + } + } + } + + private void destroyAiModelClients() { + for (Map.Entry entry : modelClients.entrySet()) { + try { + entry.getValue().close(); + LOG.info("Successfully closed AI model client '{}'.", entry.getKey()); + } catch (Exception e) { + LOG.warn("Failed to close AI model client '{}'.", entry.getKey(), e); + } + } + } } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java index 380d9343a74..be2e8823b84 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperatorBuilder.java @@ -18,6 +18,7 @@ package org.apache.flink.cdc.runtime.operators.transform; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.pipeline.PipelineOptions; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; @@ -25,6 +26,7 @@ import java.time.ZoneId; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -34,6 +36,7 @@ public class PostTransformOperatorBuilder { private String timezone; private final List>> udfFunctions = new ArrayList<>(); + private final Map modelClients = new LinkedHashMap<>(); public PostTransformOperatorBuilder addTransform( String tableInclusions, @@ -111,7 +114,12 @@ public PostTransformOperatorBuilder addUdfFunctions( return this; } + public PostTransformOperatorBuilder addModelClients(Map clients) { + this.modelClients.putAll(clients); + return this; + } + public PostTransformOperator build() { - return new PostTransformOperator(transformRules, timezone, udfFunctions); + return new PostTransformOperator(transformRules, timezone, udfFunctions, modelClients); } } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java index db5c37ca258..2dc86b37bea 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/ProjectionColumnProcessor.java @@ -18,6 +18,7 @@ package org.apache.flink.cdc.runtime.operators.transform; import org.apache.flink.cdc.common.converter.JavaClassConverter; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.schema.Column; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.runtime.parser.JaninoCompiler; @@ -26,6 +27,7 @@ import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -45,6 +47,7 @@ public class ProjectionColumnProcessor { private final TransformExpressionKey transformExpressionKey; private final Map supportedMetadataColumns; private final List udfFunctionInstances; + private final Map modelClients; private final ExpressionEvaluator expressionEvaluator; public ProjectionColumnProcessor( @@ -53,15 +56,17 @@ public ProjectionColumnProcessor( String timezone, List udfDescriptors, final List udfFunctionInstances, - Map supportedMetadataColumns) { + Map supportedMetadataColumns, + Map modelClients) { this.tableInfo = tableInfo; this.projectionColumn = projectionColumn; this.timezone = timezone; this.supportedMetadataColumns = supportedMetadataColumns; + this.modelClients = modelClients; this.transformExpressionKey = generateTransformExpressionKey(); this.expressionEvaluator = TransformExpressionCompiler.compileExpression( - transformExpressionKey, udfDescriptors); + transformExpressionKey, udfDescriptors, modelClients); this.udfFunctionInstances = udfFunctionInstances; } @@ -72,13 +77,32 @@ public static ProjectionColumnProcessor of( List udfDescriptors, List udfFunctionInstances, Map supportedMetadataColumns) { + return of( + tableInfo, + projectionColumn, + timezone, + udfDescriptors, + udfFunctionInstances, + supportedMetadataColumns, + Collections.emptyMap()); + } + + public static ProjectionColumnProcessor of( + PostTransformChangeInfo tableInfo, + ProjectionColumn projectionColumn, + String timezone, + List udfDescriptors, + List udfFunctionInstances, + Map supportedMetadataColumns, + Map modelClients) { return new ProjectionColumnProcessor( tableInfo, projectionColumn, timezone, udfDescriptors, udfFunctionInstances, - supportedMetadataColumns); + supportedMetadataColumns, + modelClients); } public Object evaluate(Object[] rowData, TransformContext context) { @@ -123,6 +147,9 @@ private Object[] generateParams(Object[] rowData, TransformContext context) { // 3 - Add UDF function instances params.addAll(udfFunctionInstances); + + // 4 - Add AI model client instances + params.addAll(modelClients.values()); return params.toArray(); } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java index 1085a7df82d..9faa111966f 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformExpressionCompiler.java @@ -18,6 +18,7 @@ package org.apache.flink.cdc.runtime.operators.transform; import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.runtime.operators.transform.exceptions.TransformException; import org.apache.flink.util.FlinkRuntimeException; @@ -31,6 +32,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; /** * The processor of the transform expression. It processes the expression of projections and @@ -51,9 +53,17 @@ public static void cleanUp() { COMPILED_EXPRESSION_CACHE.invalidateAll(); } - /** Compiles an expression code to a janino {@link ExpressionEvaluator}. */ + /** + * Compiles an expression code to a janino {@link ExpressionEvaluator}, with additional {@link + * AiModelClient} instances appended after UDF instances. + * + *

{@code modelClients} maps model names (e.g. {@code myModel}) to the corresponding client + * instances. + */ public static ExpressionEvaluator compileExpression( - TransformExpressionKey key, List udfDescriptors) { + TransformExpressionKey key, + List udfDescriptors, + Map modelClients) { try { return COMPILED_EXPRESSION_CACHE.get( key, @@ -68,6 +78,11 @@ public static ExpressionEvaluator compileExpression( argumentClasses.add(Class.forName(udfFunction.getClasspath())); } + for (String paramName : modelClients.keySet()) { + argumentNames.add(paramName); + argumentClasses.add(AiModelClient.class); + } + // Input args expressionEvaluator.setParameters( argumentNames.toArray(new String[0]), diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java index 77dc5cfe6ef..a72e7aa94e6 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformFilterProcessor.java @@ -19,6 +19,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.cdc.common.converter.JavaClassConverter; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.schema.Column; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.runtime.parser.JaninoCompiler; @@ -47,6 +48,7 @@ public class TransformFilterProcessor { private final String timezone; private final List udfFunctionInstances; private final Map supportedMetadataColumns; + private final Map modelClients; private final TransformExpressionKey transformExpressionKey; private final ExpressionEvaluator expressionEvaluator; @@ -58,13 +60,15 @@ protected TransformFilterProcessor( String timezone, List udfDescriptors, List udfFunctionInstances, - Map supportedMetadataColumns) { + Map supportedMetadataColumns, + Map modelClients) { this.isNoOp = isNoOp; this.tableInfo = tableInfo; this.transformFilter = transformFilter; this.timezone = timezone; this.udfFunctionInstances = udfFunctionInstances; this.supportedMetadataColumns = supportedMetadataColumns; + this.modelClients = modelClients; if (isNoOp) { this.transformExpressionKey = null; @@ -79,12 +83,12 @@ protected TransformFilterProcessor( .toArray(new SupportedMetadataColumn[0])); this.expressionEvaluator = TransformExpressionCompiler.compileExpression( - transformExpressionKey, udfDescriptors); + transformExpressionKey, udfDescriptors, modelClients); } } public static TransformFilterProcessor ofNoOp() { - return new TransformFilterProcessor(true, null, null, null, null, null, null); + return new TransformFilterProcessor(true, null, null, null, null, null, null, null); } public static TransformFilterProcessor of( @@ -93,7 +97,8 @@ public static TransformFilterProcessor of( String timezone, List udfDescriptors, List udfFunctionInstances, - SupportedMetadataColumn[] supportedMetadataColumns) { + SupportedMetadataColumn[] supportedMetadataColumns, + Map modelClients) { Map supportedMetadataColumnsMap = new HashMap<>(); for (SupportedMetadataColumn supportedMetadataColumn : supportedMetadataColumns) { supportedMetadataColumnsMap.put( @@ -106,7 +111,8 @@ public static TransformFilterProcessor of( timezone, udfDescriptors, udfFunctionInstances, - supportedMetadataColumnsMap); + supportedMetadataColumnsMap, + modelClients); } public boolean test(Object[] preRow, Object[] postRow, TransformContext context) { @@ -208,6 +214,9 @@ private Object[] generateParams(Object[] preRow, Object[] postRow, TransformCont // 3 - Add UDF function instances params.addAll(udfFunctionInstances); + + // 4 - Add AI model client instances + params.addAll(modelClients.values()); return params.toArray(); } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java index 231e72875ac..09fe5fb19db 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/TransformProjectionProcessor.java @@ -17,6 +17,7 @@ package org.apache.flink.cdc.runtime.operators.transform; +import org.apache.flink.cdc.common.model.AiModelClient; import org.apache.flink.cdc.common.source.SupportedMetadataColumn; import org.apache.flink.cdc.common.utils.Preconditions; import org.apache.flink.cdc.runtime.parser.TransformParser; @@ -52,6 +53,7 @@ public class TransformProjectionProcessor { private final List columnProcessors; private final SupportedMetadataColumn[] supportedMetadataColumns; private final Map supportedMetadataColumnsMap; + private final Map modelClients; public TransformProjectionProcessor( PostTransformChangeInfo changeInfo, @@ -59,13 +61,15 @@ public TransformProjectionProcessor( String timezone, List udfDescriptors, List udfFunctionInstances, - SupportedMetadataColumn[] supportedMetadataColumns) { + SupportedMetadataColumn[] supportedMetadataColumns, + Map modelClients) { this.changeInfo = changeInfo; this.projectionExpression = projectionExpression; this.timezone = timezone; this.udfDescriptors = udfDescriptors; this.udfFunctionInstances = udfFunctionInstances; this.supportedMetadataColumns = supportedMetadataColumns; + this.modelClients = modelClients; // Construct a mapping table ad-hoc to accelerate looking-up Map supportedMetadataColumnsMap = new HashMap<>(); @@ -105,7 +109,8 @@ private List createProjectionColumnProcessors() { timezone, udfDescriptors, udfFunctionInstances, - supportedMetadataColumnsMap)) + supportedMetadataColumnsMap, + modelClients)) .collect(Collectors.toList()); LOG.info("Successfully created projection column processors cache."); diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java index 14c483d7082..a022f879c42 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java @@ -27,6 +27,8 @@ import org.apache.flink.cdc.common.types.DataTypeRoot; import org.apache.flink.cdc.common.utils.Preconditions; import org.apache.flink.cdc.common.utils.StringUtils; +import org.apache.flink.cdc.runtime.ai.AiEmbeddingFunctionDef; +import org.apache.flink.cdc.runtime.ai.AiTextFunctionDef; import org.apache.flink.cdc.runtime.operators.transform.UserDefinedFunctionDescriptor; import org.apache.calcite.sql.SqlBasicCall; @@ -46,6 +48,8 @@ import org.codehaus.janino.ExpressionEvaluator; import org.codehaus.janino.Java; +import javax.lang.model.SourceVersion; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -90,7 +94,7 @@ public class JaninoCompiler { public static final String DEFAULT_TIME_ZONE = "__time_zone__"; private static final String[] BUILTIN_FUNCTION_MODULES = { - "Arithmetic", "Casting", "Comparison", "Logical", "String", "Struct", "Temporal" + "Ai", "Arithmetic", "Casting", "Comparison", "Logical", "String", "Struct", "Temporal" }; @VisibleForTesting @@ -193,6 +197,10 @@ private static Java.Rvalue translateSqlSqlLiteral(Context context, SqlLiteral sq } private static Java.Rvalue translateSqlBasicCall(Context context, SqlBasicCall sqlBasicCall) { + String operationName = sqlBasicCall.getOperator().getName().toUpperCase(); + if (isAiFunction(operationName)) { + validateAiFunctionModelArg(sqlBasicCall); + } List operandList = sqlBasicCall.getOperandList(); List atoms = new ArrayList<>(); for (SqlNode sqlNode : operandList) { @@ -526,23 +534,71 @@ private static Java.Rvalue generateOtherFunctionOperation( context.udfDescriptors.stream() .filter(e -> e.getName().equalsIgnoreCase(operationName)) .findFirst(); - return udfFunctionOptional - .map( - udfFunction -> - new Java.MethodInvocation( - Location.NOWHERE, - null, - generateInvokeExpression(udfFunction), - atoms)) - .orElseGet( - () -> - new Java.MethodInvocation( - Location.NOWHERE, - null, - StringUtils.convertToCamelCase( - sqlBasicCall.getOperator().getName()), - atoms)); + if (udfFunctionOptional.isPresent()) { + return new Java.MethodInvocation( + Location.NOWHERE, + null, + generateInvokeExpression(udfFunctionOptional.get()), + atoms); + } + if (isAiFunction(operationName) && atoms.length >= 1) { + rewriteAiFunctionModelArg(atoms); + } + return new Java.MethodInvocation( + Location.NOWHERE, + null, + StringUtils.convertToCamelCase(sqlBasicCall.getOperator().getName()), + atoms); + } + } + + private static boolean isAiFunction(String upperCaseName) { + for (AiTextFunctionDef def : AiTextFunctionDef.values()) { + if (def.getFunctionName().equals(upperCaseName)) { + return true; + } + } + for (AiEmbeddingFunctionDef def : AiEmbeddingFunctionDef.values()) { + if (def.getFunctionName().equals(upperCaseName)) { + return true; + } + } + return false; + } + + private static void validateAiFunctionModelArg(SqlBasicCall sqlBasicCall) { + String functionName = sqlBasicCall.getOperator().getName(); + List operandList = sqlBasicCall.getOperandList(); + if (operandList.isEmpty()) { + throw new ParseException( + "AI function '" + + functionName + + "' requires the model name as the first argument."); + } + SqlNode first = operandList.get(0); + if (!(first instanceof SqlCharStringLiteral)) { + throw new ParseException( + "The first argument of AI function '" + + functionName + + "' must be a string literal naming the model, but got: " + + first + + "."); + } + } + + private static void rewriteAiFunctionModelArg(Java.Rvalue[] atoms) { + String modelName = atoms[0].toString(); + if (modelName.startsWith("\"") && modelName.endsWith("\"")) { + modelName = modelName.substring(1, modelName.length() - 1); + } + if (!SourceVersion.isIdentifier(modelName) || SourceVersion.isKeyword(modelName)) { + throw new ParseException( + "AI function model name '" + + modelName + + "' is not a valid Java identifier. " + + "Model names must follow Java identifier rules and must not be reserved keywords."); } + atoms[0] = new Java.AmbiguousName(Location.NOWHERE, new String[] {modelName}); } private static Java.Rvalue generateTimezoneFreeTemporalFunctionOperation( diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java index 5ee7153013c..115dfedd7f2 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java @@ -24,6 +24,7 @@ import org.apache.flink.cdc.common.utils.Preconditions; import org.apache.flink.cdc.runtime.operators.transform.ProjectionColumn; import org.apache.flink.cdc.runtime.operators.transform.UserDefinedFunctionDescriptor; +import org.apache.flink.cdc.runtime.parser.metadata.AiFunctionSqlOperatorTable; import org.apache.flink.cdc.runtime.parser.metadata.TransformSchemaFactory; import org.apache.flink.cdc.runtime.parser.metadata.TransformSqlOperatorTable; import org.apache.flink.cdc.runtime.typeutils.CalciteDataTypeConverter; @@ -163,9 +164,13 @@ private static RelNode sqlToRel( new CalciteConnectionConfigImpl(new Properties())); TransformSqlOperatorTable transformSqlOperatorTable = TransformSqlOperatorTable.instance(); SqlOperatorTable udfOperatorTable = SqlOperatorTables.of(udfFunctions); + SqlOperatorTable aiFunctionOperatorTable = AiFunctionSqlOperatorTable.create(); SqlValidator validator = SqlValidatorUtil.newValidator( - SqlOperatorTables.chain(transformSqlOperatorTable, udfOperatorTable), + SqlOperatorTables.chain( + transformSqlOperatorTable, + udfOperatorTable, + aiFunctionOperatorTable), calciteCatalogReader, factory, SqlValidator.Config.DEFAULT diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java new file mode 100644 index 00000000000..69628c5d327 --- /dev/null +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/AiFunctionSqlOperatorTable.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.parser.metadata; + +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.RowType; +import org.apache.flink.cdc.runtime.ai.AiEmbeddingFunctionDef; +import org.apache.flink.cdc.runtime.ai.AiTextFunctionDef; +import org.apache.flink.cdc.runtime.typeutils.CalciteDataTypeConverter; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.util.SqlOperatorTables; + +import java.util.ArrayList; +import java.util.List; + +/** Creates SqlOperatorTable from {@link AiTextFunctionDef} definitions. */ +public class AiFunctionSqlOperatorTable { + + private AiFunctionSqlOperatorTable() {} + + /** Creates an SqlOperatorTable containing all AI functions defined in AiFunctionDef. */ + public static org.apache.calcite.sql.SqlOperatorTable create() { + List functions = new ArrayList<>(); + for (AiTextFunctionDef def : AiTextFunctionDef.values()) { + functions.add(createTextSqlFunction(def)); + } + for (AiEmbeddingFunctionDef def : AiEmbeddingFunctionDef.values()) { + functions.add(createEmbeddingSqlFunction(def)); + } + return SqlOperatorTables.of(functions); + } + + private static SqlFunction createTextSqlFunction(AiTextFunctionDef def) { + return new SqlFunction( + def.getFunctionName(), + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.VARIANT), + null, + OperandTypes.family(toSqlTypeFamiliesWithAdditionalParams(def.getInputType())), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } + + private static SqlFunction createEmbeddingSqlFunction(AiEmbeddingFunctionDef def) { + return new SqlFunction( + def.getFunctionName(), + SqlKind.OTHER_FUNCTION, + opBinding -> + CalciteDataTypeConverter.convertCalciteType( + opBinding.getTypeFactory(), def.getOutputType()), + null, + OperandTypes.family(SqlTypeFamily.STRING, toSqlTypeFamily(def.getInputType())), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } + + /** + * Converts inputType to SqlTypeFamily array, prepending additional parameters: modelName + * (STRING) and input (STRING). + */ + private static SqlTypeFamily[] toSqlTypeFamiliesWithAdditionalParams(RowType inputType) { + List families = new ArrayList<>(); + families.add(SqlTypeFamily.STRING); // modelName + families.add(SqlTypeFamily.STRING); // input + for (DataType fieldType : inputType.getFieldTypes()) { + families.add(toSqlTypeFamily(fieldType)); + } + return families.toArray(new SqlTypeFamily[0]); + } + + private static SqlTypeFamily toSqlTypeFamily(DataType dataType) { + switch (dataType.getTypeRoot()) { + case VARCHAR: + case CHAR: + return SqlTypeFamily.STRING; + case INTEGER: + return SqlTypeFamily.INTEGER; + case BIGINT: + return SqlTypeFamily.NUMERIC; + case FLOAT: + case DOUBLE: + return SqlTypeFamily.APPROXIMATE_NUMERIC; + case BOOLEAN: + return SqlTypeFamily.BOOLEAN; + default: + return SqlTypeFamily.ANY; + } + } +} diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java index eef0755c112..26408270ccc 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java @@ -406,38 +406,6 @@ public SqlSyntax getSyntax() { // Supports accessing elements of ARRAY[index], ROW[index], MAP[key], and VARIANT[index/key] public static final SqlOperator ITEM = new VariantAwareItemOperator(); - public static final SqlFunction AI_CHAT_PREDICT = - new SqlFunction( - "AI_CHAT_PREDICT", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(SqlTypeName.VARCHAR), - null, - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), - SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // Define the AI_EMBEDDING function - public static final SqlFunction GET_EMBEDDING = - new SqlFunction( - "GET_EMBEDDING", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(SqlTypeName.VARCHAR), - null, - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), - SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // Define the AI_LANGCHAIN_PREDICT function - public static final SqlFunction AI_LANGCHAIN_PREDICT = - new SqlFunction( - "AI_LANGCHAIN_PREDICT", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(SqlTypeName.VARCHAR), - null, - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), - SqlFunctionCategory.USER_DEFINED_FUNCTION); - // -------------------------------------------------------------------------------------------- // Variant functions // -------------------------------------------------------------------------------------------- diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java new file mode 100644 index 00000000000..01498675aad --- /dev/null +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/functions/impl/AiFunctionsTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.functions.impl; + +import org.apache.flink.cdc.common.model.AiModelClient; +import org.apache.flink.cdc.common.model.abilities.SupportsEmbedding; +import org.apache.flink.cdc.common.model.abilities.SupportsTextGeneration; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Unit tests for {@link AiFunctions}. */ +class AiFunctionsTest { + + private static class MockModelClient + implements AiModelClient, SupportsTextGeneration, SupportsEmbedding { + private static final long serialVersionUID = 1L; + + @Override + public String generate(String systemPrompt, String userInput) { + if (systemPrompt.contains("summarization expert")) { + return "{\"summary\": \"This is a summary.\"}"; + } + // Default for AI_COMPLETE + return "{\"result\": \"hello world\"}"; + } + + @Override + public float[] embed(String text) { + return new float[] {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + } + } + + private static class UselessClient implements AiModelClient {} + + @Test + void testAiFunctionInvocation() { + MockModelClient model = new MockModelClient(); + assertThat(AiFunctions.aiComplete(model, "Say hello", "You are a helpful assistant.")) + .hasToString("{\"result\":\"hello world\"}"); + assertThat(AiFunctions.aiSummarize(model, "Long text here...", 100)) + .hasToString("{\"summary\":\"This is a summary.\"}"); + assertThat(AiFunctions.aiEmbed(model, "Test text")) + .containsExactly(0.1f, 0.2f, 0.3f, 0.4f, 0.5f); + } + + @Test + void testUnsupportedModel() { + UselessClient model = new UselessClient(); + assertThatThrownBy(() -> AiFunctions.aiComplete(model, "test", "prompt")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support text generation"); + assertThatThrownBy(() -> AiFunctions.aiSummarize(model, "test", 100)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support text generation"); + assertThatThrownBy(() -> AiFunctions.aiEmbed(model, "test")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support embedding"); + } + + @Test + void testAiCompleteWithInvalidJsonResponse() { + MockModelClient model = + new MockModelClient() { + @Override + public String generate(String systemPrompt, String userInput) { + return "invalid json"; + } + }; + + assertThatThrownBy(() -> AiFunctions.aiComplete(model, "test", "prompt")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Failed to parse AI response as JSON"); + } + + @Test + void testAiCompleteWithEmptyJsonResponse() { + MockModelClient model = + new MockModelClient() { + @Override + public String generate(String systemPrompt, String userInput) { + return "{}"; + } + }; + assertThat(AiFunctions.aiComplete(model, "test", "prompt")).hasToString("{}"); + } + + @Test + void testAiCompleteWithNullResponse() { + MockModelClient model = + new MockModelClient() { + @Override + public String generate(String systemPrompt, String userInput) { + return null; + } + }; + assertThat(AiFunctions.aiComplete(model, "test", "prompt")).isNull(); + assertThat(AiFunctions.aiSummarize(model, "test", 100)).isNull(); + } + + @Test + void testAiFunctionCornerCase() { + MockModelClient model = new MockModelClient(); + assertThat(AiFunctions.aiEmbed(model, "")).containsExactly(0.1f, 0.2f, 0.3f, 0.4f, 0.5f); + assertThat(AiFunctions.aiEmbed(model, "中文嵌入测试")) + .containsExactly(0.1f, 0.2f, 0.3f, 0.4f, 0.5f); + } +} diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java index 79f9c92af5b..bfd1bb098d4 100644 --- a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java @@ -20,10 +20,8 @@ import org.apache.flink.cdc.common.types.DataType; import org.apache.flink.cdc.common.types.DataTypes; import org.apache.flink.cdc.common.udf.UserDefinedFunction; -import org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel; import org.apache.flink.table.functions.ScalarFunction; -import com.fasterxml.jackson.core.JsonProcessingException; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -50,7 +48,7 @@ public static class FlinkUdf extends ScalarFunction {} public static class NotUDF {} @Test - void testUserDefinedFunctionDescriptor() throws JsonProcessingException { + void testUserDefinedFunctionDescriptor() { assertThat(new UserDefinedFunctionDescriptor("cdc_udf", CdcUdf.class.getName())) .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf") @@ -95,14 +93,5 @@ void testUserDefinedFunctionDescriptor() throws JsonProcessingException { "not_even_exist", "not.a.valid.class.path")) .isExactlyInstanceOf(IllegalArgumentException.class) .hasMessage("Failed to instantiate UDF not_even_exist@not.a.valid.class.path"); - String name = "GET_EMBEDDING"; - assertThat(new UserDefinedFunctionDescriptor(name, OpenAIEmbeddingModel.class.getName())) - .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf") - .containsExactly( - "GET_EMBEDDING", - "OpenAIEmbeddingModel", - "org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel", - DataTypes.ARRAY(DataTypes.FLOAT()), - true); } } diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java new file mode 100644 index 00000000000..64e226c0c56 --- /dev/null +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/AiFunctionParserTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.parser; + +import org.apache.flink.cdc.common.schema.Column; +import org.apache.flink.cdc.common.source.SupportedMetadataColumn; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.runtime.operators.transform.ProjectionColumn; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Unit tests for {@link TransformParser} and {@link JaninoCompiler} with AI functions. */ +class AiFunctionParserTest { + + private static final List DUMMY_COLUMNS = + List.of( + Column.physicalColumn("id", DataTypes.INT()), + Column.physicalColumn("content", DataTypes.STRING()), + Column.physicalColumn("text", DataTypes.STRING()), + Column.physicalColumn("flag", DataTypes.BOOLEAN())); + private static final Map COLUMN_MAP = + Map.of( + "id", "$1", + "content", "$2", + "text", "$3", + "flag", "$4"); + + @Test + void testTranslateAiFunctionInProjection() { + assertThat( + translateAsProjection( + "id, AI_COMPLETE('myModel', content, 'You are a classifier') AS output_col")) + .map(ProjectionColumn::toString) + .containsExactly( + "ProjectionColumn{column=`id` INT, expression='id', scriptExpression='$0', originalColumnNames=[id], columnNameMap={id=$0}}", + "ProjectionColumn{column=`output_col` VARIANT, expression='AI_COMPLETE('myModel', `TB`.`content`, 'You are a classifier')', scriptExpression='aiComplete(myModel, $0, \"You are a classifier\")', originalColumnNames=[content], columnNameMap={content=$0}}"); + + assertThat(translateAsProjection("id, AI_SUMMARIZE('m1', content, 100) AS summary")) + .map(ProjectionColumn::toString) + .containsExactly( + "ProjectionColumn{column=`id` INT, expression='id', scriptExpression='$0', originalColumnNames=[id], columnNameMap={id=$0}}", + "ProjectionColumn{column=`summary` VARIANT, expression='AI_SUMMARIZE('m1', `TB`.`content`, 100)', scriptExpression='aiSummarize(m1, $0, 100)', originalColumnNames=[content], columnNameMap={content=$0}}"); + + assertThatThrownBy( + () -> translateAsProjection("AI_SUMMARIZE('m', content, flag) AS out_col")) + .hasMessageContaining( + "Cannot apply 'AI_SUMMARIZE' to arguments of type 'AI_SUMMARIZE(, , )'.") + .hasMessageContaining( + "Supported form(s): 'AI_SUMMARIZE(, , )'"); + + assertThatThrownBy(() -> translateAsProjection("AI_EMBED('m') AS out_col")) + .hasMessageContaining("Invalid number of arguments to function 'AI_EMBED'.") + .hasMessageContaining("Was expecting 2 arguments"); + + assertThatThrownBy(() -> translateAsProjection("AI_COMPLETE('m', content) AS out_col")) + .hasMessageContaining("Invalid number of arguments to function 'AI_COMPLETE'.") + .hasMessageContaining("Was expecting 3 arguments"); + } + + @Test + void testAiFunctionRejectsNonStringLiteralModelArg() { + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE(content, content, 'prompt') AS out_col")) + .hasMessageContaining( + "The first argument of AI function 'AI_COMPLETE' must be a string literal naming the model"); + + assertThatThrownBy(() -> translateAsProjection("AI_EMBED(123, content) AS out_col")) + .hasMessageContaining( + "The first argument of AI function 'AI_EMBED' must be a string literal naming the model"); + + assertThatThrownBy(() -> translateAsProjection("AI_EMBED(UPPER('m'), content) AS out_col")) + .hasMessageContaining( + "The first argument of AI function 'AI_EMBED' must be a string literal naming the model"); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('my-model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'my-model' is not a valid Java identifier.") + .hasMessageContaining( + "Model names must follow Java identifier rules and must not be reserved keywords."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('class', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'class' is not a valid Java identifier.") + .hasMessageContaining( + "Model names must follow Java identifier rules and must not be reserved keywords."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('my.model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'my.model' is not a valid Java identifier."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('my model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name 'my model' is not a valid Java identifier."); + + assertThatThrownBy( + () -> + translateAsProjection( + "AI_COMPLETE('123model', content, 'p') AS out_col")) + .hasMessageContaining( + "AI function model name '123model' is not a valid Java identifier."); + + assertThatThrownBy(() -> translateAsProjection("AI_COMPLETE('', content, 'p') AS out_col")) + .hasMessageContaining("AI function model name '' is not a valid Java identifier."); + } + + @Test + void testAiFunctionPreservesModelNameCase() { + assertThat(translateAsFilter("AI_COMPLETE('MyModel', content, 'p')")) + .isEqualTo("aiComplete(MyModel, $2, \"p\")"); + assertThat(translateAsFilter("AI_COMPLETE('mymodel', content, 'p')")) + .isEqualTo("aiComplete(mymodel, $2, \"p\")"); + } + + @Test + void testTranslateAiFunctionInFilter() { + assertThat(translateAsFilter("AI_COMPLETE('myModel', content, 'Classify this text')")) + .isEqualTo("aiComplete(myModel, $2, \"Classify this text\")"); + assertThat(translateAsFilter("AI_SUMMARIZE('summarizer', content, 100)")) + .isEqualTo("aiSummarize(summarizer, $2, 100)"); + assertThat(translateAsFilter("AI_EMBED('embedder', content)")) + .isEqualTo("aiEmbed(embedder, $2)"); + assertThat(translateAsFilter("ai_complete('myModel', content, 'prompt')")) + .isEqualTo("aiComplete(myModel, $2, \"prompt\")"); + } + + private List translateAsProjection(String expression) { + return TransformParser.generateProjectionColumns( + expression, DUMMY_COLUMNS, Collections.emptyList(), new SupportedMetadataColumn[0]); + } + + private String translateAsFilter(String expression) { + return TransformParser.translateFilterExpressionToJaninoExpression( + expression, + DUMMY_COLUMNS, + Collections.emptyList(), + new SupportedMetadataColumn[0], + COLUMN_MAP); + } +}