From ee9e387ce5ad6938060deada8bd9bfb841e8abec Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Mon, 12 Jan 2026 12:32:51 +0800 Subject: [PATCH] finish --- .../ainode/it/AINodeCallInferenceIT.java | 22 +++++++++++++++++++ .../ainode/core/manager/inference_manager.py | 6 ++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java index 3131e398059e..9bf0927bb622 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -48,6 +48,8 @@ public class AINodeCallInferenceIT { private static final String CALL_INFERENCE_SQL_TEMPLATE = "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; + private static final String CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE = + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")"; private static final int DEFAULT_INPUT_LENGTH = 256; private static final int DEFAULT_OUTPUT_LENGTH = 48; @@ -69,6 +71,7 @@ public void callInferenceTest() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { callInferenceTest(statement, modelInfo); + callInferenceByDefaultTest(statement, modelInfo); } } } @@ -96,4 +99,23 @@ public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeMo } } } + + public static void callInferenceByDefaultTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { + // Invoke call inference for specified models, there should exist result. + for (int i = 0; i < 4; i++) { + String callInferenceSQL = + String.format(CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE, modelInfo.getModelId(), i); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "output"); + int count = 0; + while (resultSet.next()) { + count++; + } + // Ensure the call inference return results + Assert.assertTrue(count > 0); + } + } + } } diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index ada641dd54c7..b758908c5e40 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -265,7 +265,11 @@ def inference(self, req: TInferenceReq): req, data_getter=lambda r: r.dataset, extract_attrs=lambda r: { - "output_length": int(r.inferenceAttributes.pop("outputLength", 96)), + "output_length": ( + 96 + if r.inferenceAttributes is None + else int(r.inferenceAttributes.pop("outputLength", 96)) + ), **(r.inferenceAttributes or {}), }, resp_cls=TInferenceResp,