diff --git a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java index c5a51fa47..b333b2a20 100644 --- a/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java +++ b/framework/fel/java/fel-core/src/main/java/modelengine/fel/core/util/Tip.java @@ -118,6 +118,9 @@ public Tip addAll(Map args) { * @return 表示当前的 {@link Tip}。 */ public Tip merge(Tip other) { + if (other == null) { + return this; + } return this.addAll(other.values); } diff --git a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java index 19b32abe5..975563c53 100644 --- a/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java +++ b/framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/activities/AiStart.java @@ -592,10 +592,7 @@ public final AiState runnableParallel(Pattern... pat .orElseGet(() -> new AiParallel<>(this.start.parallel(), mineFlow).fork(branchProcessor)); } - AiState state = aiFork.join(Tip::new, (acc, data) -> { - acc.merge(data); - return acc; - }); + AiState state = aiFork.join(Tip::new, (acc, data) -> acc.merge(data)); ((Processor) state.publisher()).displayAs("runnableParallel"); return state; } diff --git a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java index 47eac6f0e..733492fb4 100644 --- a/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java +++ b/framework/fel/java/fel-flow/src/test/java/modelengine/fel/engine/operators/PatternTest.java @@ -42,6 +42,7 @@ import modelengine.fitframework.util.StringUtils; import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import java.util.Collection; @@ -97,6 +98,23 @@ void shouldOkWhenAiFlowWithExampleSelector() { assertThat(converse.offer("1+2").await().text()).isEqualTo("2+2=4\n2+3=5\n1+2="); } + @RepeatedTest(1000) + @DisplayName("测试 RunnableParallel 并发稳定性") + void shouldStableWhenRunnableParallelUnderConcurrency() { + Example[] examples = {new DefaultExample("2+2", "4"), new DefaultExample("2+3", "5")}; + Conversation converse = AiFlows.create() + .runnableParallel(question(), + fewShot(ExampleSelector.builder() + .template("{{q}}={{a}}", "q", "a") + .delimiter("\n") + .example(examples) + .build())) + .prompt(Prompts.human("{{examples}}\n{{question}}=")) + .close() + .converse(); + assertThat(converse.offer("1+2").await().text()).isEqualTo("2+2=4\n2+3=5\n1+2="); + } + @Test @DisplayName("测试 Retriever") void shouldOkWhenAiFlowWithRetriever() { diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Fork.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Fork.java index e49eed575..93f9956d5 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Fork.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/states/Fork.java @@ -93,7 +93,13 @@ public synchronized R process(FlowContext input) { acc = Tuple.from((R) "", 0); } } - acc = Tuple.from(processor.process(acc.first(), input.getData()), acc.second() + 1); + + O inputData = input.getData(); + if (inputData == null) { + return null; + } + R processedResult = processor.process(acc.first(), inputData); + acc = Tuple.from(processedResult, acc.second() + 1); accs.put(key, acc); if (acc.second() == forkNumber.get()) { diff --git a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java index cb5520193..0d565caf9 100644 --- a/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java +++ b/framework/waterflow/java/waterflow-core/src/main/java/modelengine/fit/waterflow/domain/stream/nodes/To.java @@ -937,7 +937,9 @@ public List> process(To to, List