From 303aa92626f87217748286664c9b08f2695444b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 13 Mar 2026 23:54:27 +0000 Subject: [PATCH 1/3] move orchestration out of ArrowFlightMetaImpl --- .../driver/jdbc/ArrowFlightConnection.java | 28 +++ .../driver/jdbc/ArrowFlightJdbcFactory.java | 25 +- .../ArrowFlightJdbcFlightStreamResultSet.java | 2 +- ...owFlightJdbcVectorSchemaRootResultSet.java | 2 +- .../driver/jdbc/ArrowFlightMetaImpl.java | 229 +++++++----------- .../jdbc/ArrowFlightPreparedStatement.java | 129 ++++++++-- .../driver/jdbc/ArrowFlightStatement.java | 18 +- 7 files changed, 249 insertions(+), 184 deletions(-) diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index 623c2b81be..166c1157a1 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -19,6 +19,8 @@ import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; import io.netty.util.concurrent.DefaultThreadFactory; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; @@ -257,4 +259,30 @@ BufferAllocator getBufferAllocator() { public ArrowFlightMetaImpl getMeta() { return (ArrowFlightMetaImpl) this.meta; } + + @Override + public PreparedStatement prepareStatement(final String sql) throws SQLException { + checkOpen(); + return prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + } + + @Override + public PreparedStatement prepareStatement( + final String sql, final int resultSetType, final int resultSetConcurrency) + throws SQLException { + checkOpen(); + return prepareStatement(sql, resultSetType, resultSetConcurrency, getHoldability()); + } + + @Override + public PreparedStatement prepareStatement( + final String sql, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) + throws SQLException { + checkOpen(); + return getMeta() + .createPreparedStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java index e1ccfc820f..202b491f5c 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java @@ -20,7 +20,6 @@ import java.sql.SQLException; import java.util.Properties; import java.util.TimeZone; -import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; import org.apache.arrow.memory.RootAllocator; import org.apache.calcite.avatica.AvaticaConnection; import org.apache.calcite.avatica.AvaticaFactory; @@ -79,20 +78,20 @@ public ArrowFlightPreparedStatement newPreparedStatement( final Meta.Signature signature, final int resultType, final int resultSetConcurrency, - final int resultSetHoldability) - throws SQLException { + final int resultSetHoldability) { final ArrowFlightConnection flightConnection = (ArrowFlightConnection) connection; - ArrowFlightSqlClientHandler.PreparedStatement preparedStatement = - flightConnection.getMeta().getPreparedStatement(statementHandle); + final AvaticaStatement existingStatement = + flightConnection.statementMap.get(statementHandle.id); + if (existingStatement instanceof ArrowFlightPreparedStatement) { + return (ArrowFlightPreparedStatement) existingStatement; + } + if (existingStatement != null) { + throw new IllegalStateException( + "Unexpected statement type found for prepared statement handle: " + statementHandle); + } - return ArrowFlightPreparedStatement.newPreparedStatement( - flightConnection, - preparedStatement, - statementHandle, - signature, - resultType, - resultSetConcurrency, - resultSetHoldability); + throw new IllegalStateException( + "PreparedStatement was not pre-created for handle: " + statementHandle); } @Override diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 2885f7895b..f230c95340 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -106,7 +106,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( final TimeZone timeZone = TimeZone.getDefault(); final QueryState state = new QueryState(); - final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null); + final Meta.Signature signature = ArrowFlightMetaImpl.buildDefaultSignature(); final AvaticaResultSetMetaData resultSetMetaData = new AvaticaResultSetMetaData(null, null, signature); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java index 49334951de..e084a6e9c3 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -73,7 +73,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( final TimeZone timeZone = TimeZone.getDefault(); final QueryState state = new QueryState(); - final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null); + final Meta.Signature signature = ArrowFlightMetaImpl.buildDefaultSignature(); final AvaticaResultSetMetaData resultSetMetaData = new AvaticaResultSetMetaData(null, null, signature); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index 64529b50c8..6415455db2 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -17,20 +17,17 @@ package org.apache.arrow.driver.jdbc; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLTimeoutException; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; -import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.AvaticaConnection; import org.apache.calcite.avatica.AvaticaParameter; +import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.ColumnMetaData; import org.apache.calcite.avatica.MetaImpl; import org.apache.calcite.avatica.NoSuchStatementException; @@ -39,8 +36,6 @@ /** Metadata handler for Arrow Flight. */ public class ArrowFlightMetaImpl extends MetaImpl { - private final Map statementHandlePreparedStatementMap; - /** * Constructs a {@link MetaImpl} object specific for Arrow Flight. * @@ -48,42 +43,14 @@ public class ArrowFlightMetaImpl extends MetaImpl { */ public ArrowFlightMetaImpl(final AvaticaConnection connection) { super(connection); - this.statementHandlePreparedStatementMap = new ConcurrentHashMap<>(); setDefaultConnectionProperties(); } - /** Construct a signature. */ - static Signature newSignature(final String sql, Schema resultSetSchema, Schema parameterSchema) { - List columnMetaData = - resultSetSchema == null - ? new ArrayList<>() - : ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()); - List parameters = - parameterSchema == null - ? new ArrayList<>() - : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields()); - StatementType statementType = - resultSetSchema == null || resultSetSchema.getFields().isEmpty() - ? StatementType.IS_DML - : StatementType.SELECT; - return new Signature( - columnMetaData, - sql, - parameters, - Collections.emptyMap(), - null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor - statementType); - } - @Override public void closeStatement(final StatementHandle statementHandle) { - PreparedStatement preparedStatement = - statementHandlePreparedStatementMap.remove(new StatementHandleKey(statementHandle)); - // Testing if the prepared statement was created because the statement can be - // not created until - // this moment - if (preparedStatement != null) { - preparedStatement.close(); + AvaticaStatement statement = connection.statementMap.get(statementHandle.id); + if (statement instanceof ArrowFlightPreparedStatement) { + ((ArrowFlightPreparedStatement) statement).closePreparedResources(); } } @@ -97,36 +64,8 @@ public ExecuteResult execute( final StatementHandle statementHandle, final List typedValues, final long maxRowCount) { - Preconditions.checkArgument( - connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); - PreparedStatement preparedStatement = getPreparedStatement(statementHandle); - - if (preparedStatement == null) { - throw new IllegalStateException("Prepared statement not found: " + statementHandle); - } - - new AvaticaParameterBinder( - preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()) - .bind(typedValues); - - if (statementHandle.signature == null - || statementHandle.signature.statementType == StatementType.IS_DML) { - // Update query - long updatedCount = preparedStatement.executeUpdate(); - return new ExecuteResult( - Collections.singletonList( - MetaResultSet.count(statementHandle.connectionId, statementHandle.id, updatedCount))); - } else { - // TODO Why is maxRowCount ignored? - return new ExecuteResult( - Collections.singletonList( - MetaResultSet.create( - statementHandle.connectionId, - statementHandle.id, - true, - statementHandle.signature, - null))); - } + return getPreparedStatementInstance(statementHandle) + .executeWithTypedValues(statementHandle, typedValues, maxRowCount); } @Override @@ -141,24 +80,8 @@ public ExecuteResult execute( public ExecuteBatchResult executeBatch( final StatementHandle statementHandle, final List> parameterValuesList) throws IllegalStateException { - Preconditions.checkArgument( - connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); - PreparedStatement preparedStatement = getPreparedStatement(statementHandle); - - if (preparedStatement == null) { - throw new IllegalStateException("Prepared statement not found: " + statementHandle); - } - - final AvaticaParameterBinder binder = - new AvaticaParameterBinder( - preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()); - for (int i = 0; i < parameterValuesList.size(); i++) { - binder.bind(parameterValuesList.get(i), i); - } - - // Update query - long[] updatedCounts = {preparedStatement.executeUpdate()}; - return new ExecuteBatchResult(updatedCounts); + return getPreparedStatementInstance(statementHandle) + .executeBatchWithTypedValues(statementHandle, parameterValuesList); } @Override @@ -173,22 +96,36 @@ public Frame fetch( String.format("%s does not use frames.", this), AvaticaConnection.HELPER.unsupported()); } - private PreparedStatement prepareForHandle(final String query, StatementHandle handle) { - final PreparedStatement preparedStatement = - ((ArrowFlightConnection) connection).getClientHandler().prepare(query); - handle.signature = - newSignature( - query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema()); - statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement); - return preparedStatement; + ArrowFlightPreparedStatement createPreparedStatement( + final String query, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) + throws SQLException { + final StatementHandle handle = super.createStatement(connection.handle); + return ArrowFlightPreparedStatement.createPrepared( + (ArrowFlightConnection) connection, + handle, + // null, + query, + resultSetType, + resultSetConcurrency, + resultSetHoldability); } @Override public StatementHandle prepare( final ConnectionHandle connectionHandle, final String query, final long maxRowCount) { - final StatementHandle handle = super.createStatement(connectionHandle); - prepareForHandle(query, handle); - return handle; + try { + return createPreparedStatement( + query, + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + connection.getHoldability()) + .handle; + } catch (SQLException e) { + throw new RuntimeException(e); + } } @Override @@ -211,19 +148,20 @@ public ExecuteResult prepareAndExecute( final PrepareCallback callback) throws NoSuchStatementException { try { - PreparedStatement preparedStatement = prepareForHandle(query, handle); - final StatementType statementType = preparedStatement.getType(); - - final long updateCount = - statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; - synchronized (callback.getMonitor()) { - callback.clear(); - callback.assign(handle.signature, null, updateCount); + final AvaticaStatement statement = connection.statementMap.get(handle.id); + if (!(statement instanceof ArrowFlightStatement) + && !(statement instanceof ArrowFlightPreparedStatement)) { + throw new IllegalStateException("Prepared statement not found: " + handle); } - callback.execute(); - final MetaResultSet metaResultSet = - MetaResultSet.create(handle.connectionId, handle.id, false, handle.signature, null); - return new ExecuteResult(Collections.singletonList(metaResultSet)); + final ArrowFlightPreparedStatement preparedStatement = + ArrowFlightPreparedStatement.createPrepared( + (ArrowFlightConnection) connection, + handle, + query, + statement.getResultSetType(), + statement.getResultSetConcurrency(), + statement.getResultSetHoldability()); + return preparedStatement.prepareAndExecute(callback); } catch (SQLTimeoutException e) { // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and // Runtime @@ -280,45 +218,48 @@ void setDefaultConnectionProperties() { .setTransactionIsolation(Connection.TRANSACTION_NONE); } - PreparedStatement getPreparedStatement(StatementHandle statementHandle) { - return statementHandlePreparedStatementMap.get(new StatementHandleKey(statementHandle)); + private ArrowFlightPreparedStatement getPreparedStatementInstance( + StatementHandle statementHandle) { + AvaticaStatement statement = connection.statementMap.get(statementHandle.id); + if (!(statement instanceof ArrowFlightPreparedStatement)) { + throw new IllegalStateException("Prepared statement not found: " + statementHandle); + } + return (ArrowFlightPreparedStatement) statement; } - // Helper used to look up prepared statement instances later. Avatica doesn't - // give us the - // signature in - // an UPDATE code path so we can't directly use StatementHandle as a map key. - private static final class StatementHandleKey { - public final String connectionId; - public final int id; - - StatementHandleKey(StatementHandle statementHandle) { - this.connectionId = statementHandle.connectionId; - this.id = statementHandle.id; + ArrowFlightPreparedStatement getPreparedStatementInstanceOrNull(StatementHandle statementHandle) { + AvaticaStatement statement = connection.statementMap.get(statementHandle.id); + if (statement instanceof ArrowFlightPreparedStatement) { + return (ArrowFlightPreparedStatement) statement; } + return null; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - StatementHandleKey that = (StatementHandleKey) o; - - if (id != that.id) { - return false; - } - return connectionId.equals(that.connectionId); - } + public static Signature buildDefaultSignature() { + return buildSignature(null, null, null); + } - @Override - public int hashCode() { - int result = connectionId.hashCode(); - result = 31 * result + id; - return result; - } + /** Builds an Avatica signature from Arrow result and parameter schemas. */ + public static Signature buildSignature( + final String sql, final Schema resultSetSchema, final Schema parameterSchema) { + List columnMetaData = + resultSetSchema == null + ? new ArrayList<>() + : ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()); + List parameters = + parameterSchema == null + ? new ArrayList<>() + : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields()); + StatementType statementType = + resultSetSchema == null || resultSetSchema.getFields().isEmpty() + ? StatementType.IS_DML + : StatementType.SELECT; + return new Signature( + columnMetaData, + sql, + parameters, + Collections.emptyMap(), + null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor + statementType); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java index d7af6902f4..7856d38f07 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -16,24 +16,32 @@ */ package org.apache.arrow.driver.jdbc; -import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.Collections; +import java.util.List; import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.Meta.ExecuteBatchResult; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.MetaResultSet; +import org.apache.calcite.avatica.Meta.PrepareCallback; import org.apache.calcite.avatica.Meta.Signature; import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.Meta.StatementType; +import org.apache.calcite.avatica.remote.TypedValue; -/** Arrow Flight JBCS's implementation {@link PreparedStatement}. */ +/** Arrow Flight JDBC's implementation {@link java.sql.PreparedStatement}. */ public class ArrowFlightPreparedStatement extends AvaticaPreparedStatement implements ArrowFlightInfoStatement { - private final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; + private ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; private ArrowFlightPreparedStatement( final ArrowFlightConnection connection, - final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement, final StatementHandle handle, final Signature signature, final int resultSetType, @@ -41,26 +49,34 @@ private ArrowFlightPreparedStatement( final int resultSetHoldability) throws SQLException { super(connection, handle, signature, resultSetType, resultSetConcurrency, resultSetHoldability); - this.preparedStatement = Preconditions.checkNotNull(preparedStatement); } - static ArrowFlightPreparedStatement newPreparedStatement( + static ArrowFlightPreparedStatement createPrepared( final ArrowFlightConnection connection, - final ArrowFlightSqlClientHandler.PreparedStatement preparedStmt, final StatementHandle statementHandle, - final Signature signature, + final String query, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) throws SQLException { - return new ArrowFlightPreparedStatement( - connection, - preparedStmt, - statementHandle, - signature, - resultSetType, - resultSetConcurrency, - resultSetHoldability); + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement = + connection.getClientHandler().prepare(query); + final Signature signature = + ArrowFlightMetaImpl.buildSignature( + query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema()); + statementHandle.signature = signature; + + final ArrowFlightPreparedStatement statement = + new ArrowFlightPreparedStatement( + connection, + statementHandle, + signature, + resultSetType, + resultSetConcurrency, + resultSetHoldability); + statement.preparedStatement = Preconditions.checkNotNull(preparedStatement); + statement.setSignature(signature); + return statement; } @Override @@ -68,14 +84,93 @@ public ArrowFlightConnection getConnection() throws SQLException { return (ArrowFlightConnection) super.getConnection(); } + ExecuteResult prepareAndExecute(final PrepareCallback callback) throws SQLException { + ensurePrepared(); + final StatementType statementType = preparedStatement.getType(); + final long updateCount = + statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; + synchronized (callback.getMonitor()) { + callback.clear(); + callback.assign(handle.signature, null, updateCount); + } + callback.execute(); + final MetaResultSet metaResultSet = + MetaResultSet.create(handle.connectionId, handle.id, false, handle.signature, null); + return new ExecuteResult(Collections.singletonList(metaResultSet)); + } + + Schema getDataSetSchema() { + ensurePrepared(); + return preparedStatement.getDataSetSchema(); + } + @Override public synchronized void close() throws SQLException { - this.preparedStatement.close(); super.close(); } + void closePreparedResources() { + if (preparedStatement != null) { + preparedStatement.close(); + preparedStatement = null; + } + } + + ExecuteResult executeWithTypedValues( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + ensurePrepared(); + Preconditions.checkArgument( + connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); + new AvaticaParameterBinder( + preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()) + .bind(typedValues); + + if (statementHandle.signature == null + || statementHandle.signature.statementType == StatementType.IS_DML) { + long updatedCount = preparedStatement.executeUpdate(); + return new ExecuteResult( + Collections.singletonList( + MetaResultSet.count(statementHandle.connectionId, statementHandle.id, updatedCount))); + } + + // TODO Why is maxRowCount ignored? + return new ExecuteResult( + Collections.singletonList( + MetaResultSet.create( + statementHandle.connectionId, + statementHandle.id, + true, + statementHandle.signature, + null))); + } + + ExecuteBatchResult executeBatchWithTypedValues( + final StatementHandle statementHandle, final List> parameterValuesList) { + ensurePrepared(); + Preconditions.checkArgument( + connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); + final AvaticaParameterBinder binder = + new AvaticaParameterBinder( + preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()); + for (int i = 0; i < parameterValuesList.size(); i++) { + binder.bind(parameterValuesList.get(i), i); + } + + long[] updatedCounts = {preparedStatement.executeUpdate()}; + return new ExecuteBatchResult(updatedCounts); + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { + ensurePrepared(); return preparedStatement.executeQuery(); } + + private void ensurePrepared() { + if (preparedStatement == null) { + throw new IllegalStateException("PreparedStatement is already closed."); + } + } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java index 577aee3b4a..9e514ccc9f 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java @@ -17,7 +17,6 @@ package org.apache.arrow.driver.jdbc; import java.sql.SQLException; -import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.vector.types.pojo.Schema; @@ -44,18 +43,21 @@ public ArrowFlightConnection getConnection() throws SQLException { @Override public FlightInfo executeFlightInfoQuery() throws SQLException { - final PreparedStatement preparedStatement = - getConnection().getMeta().getPreparedStatement(handle); + final ArrowFlightPreparedStatement preparedStatement = + getConnection().getMeta().getPreparedStatementInstanceOrNull(handle); final Meta.Signature signature = getSignature(); if (signature == null) { return null; } - final Schema resultSetSchema = preparedStatement.getDataSetSchema(); - signature.columns.addAll( - ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); - setSignature(signature); + if (preparedStatement != null) { + final Schema resultSetSchema = preparedStatement.getDataSetSchema(); + signature.columns.addAll( + ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + setSignature(signature); + return preparedStatement.executeFlightInfoQuery(); + } - return preparedStatement.executeQuery(); + throw new IllegalStateException("Prepared statement query not found: " + handle); } } From b62d85b7cd070c3950cff237e898f82028c97fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Sat, 14 Mar 2026 23:13:01 +0000 Subject: [PATCH 2/3] add builder pattern --- .../driver/jdbc/ArrowFlightMetaImpl.java | 30 +++-- .../jdbc/ArrowFlightPreparedStatement.java | 108 +++++++++++++----- .../ArrowFlightPreparedStatementTest.java | 23 ++++ .../jdbc/ArrowFlightStatementExecuteTest.java | 17 +++ 4 files changed, 136 insertions(+), 42 deletions(-) diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index 6415455db2..e0a8727630 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -102,15 +102,13 @@ ArrowFlightPreparedStatement createPreparedStatement( final int resultSetConcurrency, final int resultSetHoldability) throws SQLException { - final StatementHandle handle = super.createStatement(connection.handle); - return ArrowFlightPreparedStatement.createPrepared( - (ArrowFlightConnection) connection, - handle, - // null, - query, - resultSetType, - resultSetConcurrency, - resultSetHoldability); + return ArrowFlightPreparedStatement.builder((ArrowFlightConnection) connection) + .withQuery(query) + .withGeneratedHandle() + .withResultSetType(resultSetType) + .withResultSetConcurrency(resultSetConcurrency) + .withResultSetHoldability(resultSetHoldability) + .build(); } @Override @@ -153,14 +151,14 @@ public ExecuteResult prepareAndExecute( && !(statement instanceof ArrowFlightPreparedStatement)) { throw new IllegalStateException("Prepared statement not found: " + handle); } + if (statement instanceof ArrowFlightPreparedStatement) { + ((ArrowFlightPreparedStatement) statement).closePreparedResources(); + } final ArrowFlightPreparedStatement preparedStatement = - ArrowFlightPreparedStatement.createPrepared( - (ArrowFlightConnection) connection, - handle, - query, - statement.getResultSetType(), - statement.getResultSetConcurrency(), - statement.getResultSetHoldability()); + ArrowFlightPreparedStatement.builder((ArrowFlightConnection) connection) + .withQuery(query) + .withExistingStatement(statement) + .build(); return preparedStatement.prepareAndExecute(callback); } catch (SQLTimeoutException e) { // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java index 7856d38f07..1c6f0cdb21 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -25,6 +25,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.Meta.ExecuteBatchResult; import org.apache.calcite.avatica.Meta.ExecuteResult; import org.apache.calcite.avatica.Meta.MetaResultSet; @@ -44,39 +45,19 @@ private ArrowFlightPreparedStatement( final ArrowFlightConnection connection, final StatementHandle handle, final Signature signature, + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) throws SQLException { super(connection, handle, signature, resultSetType, resultSetConcurrency, resultSetHoldability); + this.preparedStatement = Preconditions.checkNotNull(preparedStatement); + this.handle.signature = signature; + setSignature(signature); } - static ArrowFlightPreparedStatement createPrepared( - final ArrowFlightConnection connection, - final StatementHandle statementHandle, - final String query, - final int resultSetType, - final int resultSetConcurrency, - final int resultSetHoldability) - throws SQLException { - final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement = - connection.getClientHandler().prepare(query); - final Signature signature = - ArrowFlightMetaImpl.buildSignature( - query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema()); - statementHandle.signature = signature; - - final ArrowFlightPreparedStatement statement = - new ArrowFlightPreparedStatement( - connection, - statementHandle, - signature, - resultSetType, - resultSetConcurrency, - resultSetHoldability); - statement.preparedStatement = Preconditions.checkNotNull(preparedStatement); - statement.setSignature(signature); - return statement; + static Builder builder(final ArrowFlightConnection connection) { + return new Builder(connection); } @Override @@ -173,4 +154,79 @@ private void ensurePrepared() { throw new IllegalStateException("PreparedStatement is already closed."); } } + + static final class Builder { + private final ArrowFlightConnection connection; + private StatementHandle handle; + private String query; + private Integer resultSetType; + private Integer resultSetConcurrency; + private Integer resultSetHoldability; + private boolean generateHandle; + + private Builder(final ArrowFlightConnection connection) { + this.connection = Preconditions.checkNotNull(connection); + } + + Builder withQuery(final String query) { + this.query = Preconditions.checkNotNull(query); + return this; + } + + Builder withGeneratedHandle() { + this.generateHandle = true; + this.handle = null; + return this; + } + + Builder withExistingStatement(final AvaticaStatement statement) throws SQLException { + Preconditions.checkNotNull(statement); + this.generateHandle = false; + this.handle = Preconditions.checkNotNull(statement.handle); + this.resultSetType = statement.getResultSetType(); + this.resultSetConcurrency = statement.getResultSetConcurrency(); + this.resultSetHoldability = statement.getResultSetHoldability(); + return this; + } + + Builder withResultSetType(final int resultSetType) { + this.resultSetType = resultSetType; + return this; + } + + Builder withResultSetConcurrency(final int resultSetConcurrency) { + this.resultSetConcurrency = resultSetConcurrency; + return this; + } + + Builder withResultSetHoldability(final int resultSetHoldability) { + this.resultSetHoldability = resultSetHoldability; + return this; + } + + ArrowFlightPreparedStatement build() throws SQLException { + Preconditions.checkNotNull(query); + Preconditions.checkNotNull(resultSetType); + Preconditions.checkNotNull(resultSetConcurrency); + Preconditions.checkNotNull(resultSetHoldability); + if (!generateHandle && handle == null) { + throw new IllegalStateException("PreparedStatement builder requires a handle."); + } + + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement = + connection.getClientHandler().prepare(query); + final Signature signature = + ArrowFlightMetaImpl.buildSignature( + query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema()); + + return new ArrowFlightPreparedStatement( + connection, + generateHandle ? null : handle, + signature, + preparedStatement, + resultSetType, + resultSetConcurrency, + resultSetHoldability); + } + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index 0369c3a162..d4e7a0953d 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -21,6 +21,8 @@ import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import java.nio.charset.StandardCharsets; @@ -98,6 +100,27 @@ public void testSimpleQueryNoParameterBindingWithExecute() throws SQLException { } } + @Test + public void testPrepareStatementRegistersCreatedStatementByGeneratedHandle() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + final ArrowFlightConnection flightConnection = (ArrowFlightConnection) connection; + + try (final PreparedStatement preparedStatement = connection.prepareStatement(query)) { + final ArrowFlightPreparedStatement arrowPreparedStatement = + (ArrowFlightPreparedStatement) preparedStatement; + + assertNotNull( + flightConnection + .getMeta() + .getPreparedStatementInstanceOrNull(arrowPreparedStatement.handle)); + assertSame( + arrowPreparedStatement, + flightConnection + .getMeta() + .getPreparedStatementInstanceOrNull(arrowPreparedStatement.handle)); + } + } + @Test public void testQueryWithParameterBinding() throws SQLException { final String query = "Fake query with parameters"; diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java index 632cb0ba56..e4df71967b 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java @@ -22,6 +22,8 @@ import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import java.sql.Connection; import java.sql.ResultSet; @@ -137,6 +139,21 @@ public void testExecuteShouldRunSelectQuery() throws SQLException { is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(-1L)))); } + @Test + public void testExecuteReplacesStatementMapEntryWithPreparedStatement() throws SQLException { + final ArrowFlightStatement arrowStatement = (ArrowFlightStatement) statement; + final ArrowFlightConnection arrowConnection = (ArrowFlightConnection) connection; + + assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); + + final ArrowFlightPreparedStatement preparedStatement = + arrowConnection.getMeta().getPreparedStatementInstanceOrNull(arrowStatement.handle); + + assertNotNull(preparedStatement); + assertSame(preparedStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); + assertThat(preparedStatement.handle.id, is(equalTo(arrowStatement.handle.id))); + } + @Test public void testExecuteShouldRunUpdateQueryForSmallUpdate() throws SQLException { assertThat(statement.execute(SAMPLE_UPDATE_QUERY), is(false)); // Means this is an UPDATE query. From c9568aa9f643c95bc3c436819982a505454854bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 20 Mar 2026 15:26:11 +0000 Subject: [PATCH 3/3] Optimize statement queries 3 (#3) * reduced number of requests for Statement * move meta orchestration * detect prepared statement --- .../driver/jdbc/ArrowFlightConnection.java | 9 +- .../driver/jdbc/ArrowFlightInfoStatement.java | 36 -- .../ArrowFlightJdbcFlightStreamResultSet.java | 2 +- .../driver/jdbc/ArrowFlightMetaImpl.java | 93 ++--- .../driver/jdbc/ArrowFlightMetaStatement.java | 60 +++ .../jdbc/ArrowFlightPreparedStatement.java | 36 +- .../driver/jdbc/ArrowFlightStatement.java | 139 ++++++- .../client/ArrowFlightSqlClientHandler.java | 10 + .../example/ArrowFlightJdbcSampleApp.java | 120 ++++++ .../ArrowFlightPreparedStatementTest.java | 9 +- .../jdbc/ArrowFlightStatementExecuteTest.java | 20 +- ...ArrowFlightStatementExecuteUpdateTest.java | 10 + .../ArrowFlightStatementProtocolTest.java | 376 ++++++++++++++++++ .../jdbc/utils/MockFlightSqlProducer.java | 22 + 14 files changed, 825 insertions(+), 117 deletions(-) delete mode 100644 flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java create mode 100644 flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java create mode 100644 flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java create mode 100644 flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index 166c1157a1..107cfa0c2f 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -282,7 +282,12 @@ public PreparedStatement prepareStatement( final int resultSetHoldability) throws SQLException { checkOpen(); - return getMeta() - .createPreparedStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + return ArrowFlightPreparedStatement.builder(this) + .withQuery(sql) + .withGeneratedHandle() + .withResultSetType(resultSetType) + .withResultSetConcurrency(resultSetConcurrency) + .withResultSetHoldability(resultSetHoldability) + .build(); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java deleted file mode 100644 index 37ee93722a..0000000000 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.driver.jdbc; - -import java.sql.SQLException; -import java.sql.Statement; -import org.apache.arrow.flight.FlightInfo; - -/** A {@link Statement} that deals with {@link FlightInfo}. */ -public interface ArrowFlightInfoStatement extends Statement { - - @Override - ArrowFlightConnection getConnection() throws SQLException; - - /** - * Executes the query this {@link Statement} is holding. - * - * @return the {@link FlightInfo} for the results. - * @throws SQLException on error. - */ - FlightInfo executeFlightInfoQuery() throws SQLException; -} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index f230c95340..d383d239d1 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -67,7 +67,7 @@ public final class ArrowFlightJdbcFlightStreamResultSet throws SQLException { super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); this.connection = (ArrowFlightConnection) statement.connection; - this.flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); + this.flightInfo = ((ArrowFlightMetaStatement) statement).executeFlightInfoQuery(); } /** Private constructor for fromFlightInfo. */ diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index e0a8727630..4da182ca63 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -48,10 +48,7 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) { @Override public void closeStatement(final StatementHandle statementHandle) { - AvaticaStatement statement = connection.statementMap.get(statementHandle.id); - if (statement instanceof ArrowFlightPreparedStatement) { - ((ArrowFlightPreparedStatement) statement).closePreparedResources(); - } + getMetaStatement(statementHandle).closeStatement(); } @Override @@ -64,8 +61,7 @@ public ExecuteResult execute( final StatementHandle statementHandle, final List typedValues, final long maxRowCount) { - return getPreparedStatementInstance(statementHandle) - .executeWithTypedValues(statementHandle, typedValues, maxRowCount); + return getMetaStatement(statementHandle).execute(statementHandle, typedValues, maxRowCount); } @Override @@ -80,8 +76,7 @@ public ExecuteResult execute( public ExecuteBatchResult executeBatch( final StatementHandle statementHandle, final List> parameterValuesList) throws IllegalStateException { - return getPreparedStatementInstance(statementHandle) - .executeBatchWithTypedValues(statementHandle, parameterValuesList); + return getMetaStatement(statementHandle).executeBatch(statementHandle, parameterValuesList); } @Override @@ -96,31 +91,16 @@ public Frame fetch( String.format("%s does not use frames.", this), AvaticaConnection.HELPER.unsupported()); } - ArrowFlightPreparedStatement createPreparedStatement( - final String query, - final int resultSetType, - final int resultSetConcurrency, - final int resultSetHoldability) - throws SQLException { - return ArrowFlightPreparedStatement.builder((ArrowFlightConnection) connection) - .withQuery(query) - .withGeneratedHandle() - .withResultSetType(resultSetType) - .withResultSetConcurrency(resultSetConcurrency) - .withResultSetHoldability(resultSetHoldability) - .build(); - } - @Override public StatementHandle prepare( final ConnectionHandle connectionHandle, final String query, final long maxRowCount) { try { - return createPreparedStatement( - query, - ResultSet.TYPE_FORWARD_ONLY, - ResultSet.CONCUR_READ_ONLY, - connection.getHoldability()) - .handle; + // This is the Avatica entry point used by Connection.prepareStatement(String). + ArrowFlightPreparedStatement stmt = + (ArrowFlightPreparedStatement) + connection.prepareStatement( + query, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + return stmt.handle; } catch (SQLException e) { throw new RuntimeException(e); } @@ -133,6 +113,7 @@ public ExecuteResult prepareAndExecute( final long maxRowCount, final PrepareCallback prepareCallback) throws NoSuchStatementException { + // This is the Avatica entry point used by Statement.execute(String). return prepareAndExecute( statementHandle, query, maxRowCount, -1 /* Not used */, prepareCallback); } @@ -146,20 +127,9 @@ public ExecuteResult prepareAndExecute( final PrepareCallback callback) throws NoSuchStatementException { try { - final AvaticaStatement statement = connection.statementMap.get(handle.id); - if (!(statement instanceof ArrowFlightStatement) - && !(statement instanceof ArrowFlightPreparedStatement)) { - throw new IllegalStateException("Prepared statement not found: " + handle); - } - if (statement instanceof ArrowFlightPreparedStatement) { - ((ArrowFlightPreparedStatement) statement).closePreparedResources(); - } - final ArrowFlightPreparedStatement preparedStatement = - ArrowFlightPreparedStatement.builder((ArrowFlightConnection) connection) - .withQuery(query) - .withExistingStatement(statement) - .build(); - return preparedStatement.prepareAndExecute(callback); + // This is the Avatica entry point used by Statement.execute(String). + return getMetaStatement(handle) + .prepareAndExecute(query, maxRowCount, maxRowsInFirstFrame, callback); } catch (SQLTimeoutException e) { // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and // Runtime @@ -216,30 +186,37 @@ void setDefaultConnectionProperties() { .setTransactionIsolation(Connection.TRANSACTION_NONE); } - private ArrowFlightPreparedStatement getPreparedStatementInstance( - StatementHandle statementHandle) { + private ArrowFlightMetaStatement getMetaStatement(StatementHandle statementHandle) { AvaticaStatement statement = connection.statementMap.get(statementHandle.id); - if (!(statement instanceof ArrowFlightPreparedStatement)) { - throw new IllegalStateException("Prepared statement not found: " + statementHandle); + if (statement instanceof ArrowFlightMetaStatement) { + return (ArrowFlightMetaStatement) statement; } - return (ArrowFlightPreparedStatement) statement; + throw new IllegalStateException("Statement not found: " + statementHandle); } - ArrowFlightPreparedStatement getPreparedStatementInstanceOrNull(StatementHandle statementHandle) { - AvaticaStatement statement = connection.statementMap.get(statementHandle.id); - if (statement instanceof ArrowFlightPreparedStatement) { - return (ArrowFlightPreparedStatement) statement; - } - return null; + public static Signature buildDefaultSignature() { + return buildSignature(null, StatementType.SELECT); } - public static Signature buildDefaultSignature() { - return buildSignature(null, null, null); + public static Signature buildSignature(final String sql, final StatementType type) { + return buildSignature(sql, null, null, type); } /** Builds an Avatica signature from Arrow result and parameter schemas. */ public static Signature buildSignature( final String sql, final Schema resultSetSchema, final Schema parameterSchema) { + StatementType statementType = + resultSetSchema == null || resultSetSchema.getFields().isEmpty() + ? StatementType.IS_DML + : StatementType.SELECT; + return buildSignature(sql, resultSetSchema, parameterSchema, statementType); + } + + private static Signature buildSignature( + final String sql, + final Schema resultSetSchema, + final Schema parameterSchema, + final StatementType statementType) { List columnMetaData = resultSetSchema == null ? new ArrayList<>() @@ -248,10 +225,6 @@ public static Signature buildSignature( parameterSchema == null ? new ArrayList<>() : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields()); - StatementType statementType = - resultSetSchema == null || resultSetSchema.getFields().isEmpty() - ? StatementType.IS_DML - : StatementType.SELECT; return new Signature( columnMetaData, sql, diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java new file mode 100644 index 0000000000..415af19e8f --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import org.apache.arrow.flight.FlightInfo; +import org.apache.calcite.avatica.Meta.ExecuteBatchResult; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.PrepareCallback; +import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.remote.TypedValue; + +/** Statement capabilities used by {@link ArrowFlightMetaImpl}. */ +interface ArrowFlightMetaStatement extends Statement { + + @Override + ArrowFlightConnection getConnection() throws SQLException; + + FlightInfo executeFlightInfoQuery() throws SQLException; + + /** + * Avatica routes {@link Statement#execute(String)} through Meta.prepareAndExecute(...), so plain + * statements still need this hook even when they support direct executeQuery/executeUpdate paths. + */ + ExecuteResult prepareAndExecute( + String query, long maxRowCount, int maxRowsInFirstFrame, PrepareCallback callback) + throws SQLException; + + default ExecuteResult execute( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + throw new IllegalStateException( + "Statement operation is not supported for handle: " + statementHandle); + } + + default ExecuteBatchResult executeBatch( + final StatementHandle statementHandle, final List> parameterValuesList) { + throw new IllegalStateException( + "Statement operation is not supported for handle: " + statementHandle); + } + + default void closeStatement() {} +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java index 1c6f0cdb21..bd7ebbe0e4 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -37,7 +37,7 @@ /** Arrow Flight JDBC's implementation {@link java.sql.PreparedStatement}. */ public class ArrowFlightPreparedStatement extends AvaticaPreparedStatement - implements ArrowFlightInfoStatement { + implements ArrowFlightMetaStatement { private ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; @@ -80,6 +80,21 @@ ExecuteResult prepareAndExecute(final PrepareCallback callback) throws SQLExcept return new ExecuteResult(Collections.singletonList(metaResultSet)); } + @Override + public ExecuteResult prepareAndExecute( + final String query, + final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws SQLException { + + return ArrowFlightPreparedStatement.builder(getConnection()) + .withQuery(query) + .withExistingStatement(this) + .build() + .prepareAndExecute(callback); + } + Schema getDataSetSchema() { ensurePrepared(); return preparedStatement.getDataSetSchema(); @@ -143,6 +158,25 @@ ExecuteBatchResult executeBatchWithTypedValues( return new ExecuteBatchResult(updatedCounts); } + @Override + public ExecuteResult execute( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + return executeWithTypedValues(statementHandle, typedValues, maxRowCount); + } + + @Override + public ExecuteBatchResult executeBatch( + final StatementHandle statementHandle, final List> parameterValuesList) { + return executeBatchWithTypedValues(statementHandle, parameterValuesList); + } + + @Override + public void closeStatement() { + closePreparedResources(); + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { ensurePrepared(); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java index 9e514ccc9f..0df8f20d2a 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java @@ -16,16 +16,24 @@ */ package org.apache.arrow.driver.jdbc; +import java.sql.ResultSet; import java.sql.SQLException; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaResultSet; import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.PrepareCallback; import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.Meta.StatementType; /** A SQL statement for querying data from an Arrow Flight server. */ -public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightInfoStatement { +public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightMetaStatement { ArrowFlightStatement( final ArrowFlightConnection connection, @@ -41,23 +49,140 @@ public ArrowFlightConnection getConnection() throws SQLException { return (ArrowFlightConnection) super.getConnection(); } + @Override + public ExecuteResult prepareAndExecute( + final String query, + final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws SQLException { + // Keep Avatica Statement.execute(String) behavior: Avatica calls Meta.prepareAndExecute, + // which resolves to this statement hook. + this.closeStatement(); + + return ArrowFlightPreparedStatement.builder(getConnection()) + .withQuery(query) + .withExistingStatement(this) + .build() + .prepareAndExecute(callback); + } + + @Override + public ResultSet executeQuery(final String sql) throws SQLException { + checkOpen(); + updateCount = -1; + switchToDirectStatementMode(); + try { + final Meta.Signature signature = + ArrowFlightMetaImpl.buildSignature(sql, StatementType.SELECT); + setSignature(signature); + return executeQueryInternal(signature, false); + } catch (Exception exception) { + throw wrapStatementExecutionException(sql, exception); + } + } + + @Override + public long executeLargeUpdate(final String sql) throws SQLException { + checkOpen(); + clearOpenResultSet(); + updateCount = -1; + switchToDirectStatementMode(); + + try { + final long updatedCount = getConnection().getClientHandler().executeUpdate(sql); + setSignature(ArrowFlightMetaImpl.buildSignature(sql, StatementType.IS_DML)); + updateCount = updatedCount; + return updatedCount; + } catch (Exception exception) { + throw wrapStatementExecutionException(sql, exception); + } + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { - final ArrowFlightPreparedStatement preparedStatement = - getConnection().getMeta().getPreparedStatementInstanceOrNull(handle); + final ArrowFlightConnection connection = getConnection(); final Meta.Signature signature = getSignature(); if (signature == null) { return null; } - if (preparedStatement != null) { - final Schema resultSetSchema = preparedStatement.getDataSetSchema(); + // A Statement handle can point to either this direct statement instance or a prepared + // statement instance created by Avatica Statement.execute(String) through + // Meta.prepareAndExecute. + final AvaticaStatement currentStatement = connection.statementMap.get(handle.id); + if (currentStatement instanceof ArrowFlightPreparedStatement) { + // Prepared path: reuse the current statement implementation associated with the handle. + final FlightInfo flightInfo = + ((ArrowFlightPreparedStatement) currentStatement).executeFlightInfoQuery(); + updateSignatureColumnsFromFlightInfo(signature, flightInfo); + return flightInfo; + } + + // Direct Statement.executeQuery(String) / executeUpdate(String) path. + final FlightInfo flightInfo = connection.getClientHandler().getInfo(signature.sql); + updateSignatureColumnsFromFlightInfo(signature, flightInfo); + return flightInfo; + } + + private void updateSignatureColumnsFromFlightInfo( + final Meta.Signature signature, final FlightInfo flightInfo) { + final Schema resultSetSchema = flightInfo.getSchemaOptional().orElse(null); + if (resultSetSchema != null) { signature.columns.addAll( ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); setSignature(signature); - return preparedStatement.executeFlightInfoQuery(); } + } + + private SQLException wrapStatementExecutionException(final String sql, final Exception exception) + throws SQLException { + if (!(exception instanceof SQLException)) { + return AvaticaConnection.HELPER.createException( + "Error while executing SQL \"" + sql + "\": " + exception.getMessage(), exception); + } + final SQLException sqlException = (SQLException) exception; + final String prefix = "Error while executing SQL \"" + sql + "\""; + final String message = sqlException.getMessage(); + if (message != null && message.startsWith(prefix)) { + return sqlException; + } + final Throwable cause = sqlException.getCause(); + if (cause instanceof FlightRuntimeException) { + final FlightStatusCode statusCode = ((FlightRuntimeException) cause).status().code(); + if (statusCode == FlightStatusCode.UNAVAILABLE) { + return sqlException; + } + } + return AvaticaConnection.HELPER.createException(prefix + ": " + message, sqlException); + } - throw new IllegalStateException("Prepared statement query not found: " + handle); + private void clearOpenResultSet() throws SQLException { + synchronized (this) { + if (openResultSet != null) { + final AvaticaResultSet resultSet = openResultSet; + openResultSet = null; + try { + resultSet.close(); + } catch (Exception exception) { + throw AvaticaConnection.HELPER.createException( + "Error while closing previous result set", exception); + } + } + } + } + + private void switchToDirectStatementMode() throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final AvaticaStatement existingStatement = connection.statementMap.get(handle.id); + if (existingStatement == this) { + return; + } + if (existingStatement instanceof ArrowFlightPreparedStatement) { + // Release resources from previously attached statement implementation before switching back + // to direct statement mode for executeQuery/executeUpdate. + ((ArrowFlightPreparedStatement) existingStatement).closeStatement(); + } + connection.statementMap.put(handle.id, this); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index f0ea284239..08b2c5f93e 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -267,6 +267,16 @@ public FlightInfo getInfo(final String query) { return sqlClient.execute(query, getOptions()); } + /** + * Executes an update query directly, without creating a prepared statement first. + * + * @param query The update query. + * @return the number of rows affected. + */ + public long executeUpdate(final String query) { + return sqlClient.executeUpdate(query, getOptions()); + } + @Override public void close() throws SQLException { if (catalog.isPresent()) { diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java new file mode 100644 index 0000000000..10ce0bd285 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.example; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +/** + * Minimal sample app for using the Arrow Flight SQL JDBC driver. + * + *

Defaults are configured for a local Dremio instance: + * + *

    + *
  • host: {@code localhost} + *
  • port: {@code 32010} + *
  • user: {@code dremio} + *
  • password: {@code dremio123} + *
+ * + *

Arguments are optional and positional: + * + *

+ *   [host] [port] [user] [password] [selectSql] [updateSql]
+ * 
+ * + *

If {@code updateSql} is omitted, only {@code Statement.executeQuery(...)} is executed. + */ +public final class ArrowFlightJdbcSampleApp { + private static final String DEFAULT_HOST = "localhost"; + private static final int DEFAULT_PORT = 32010; + private static final String DEFAULT_USER = "dremio"; + private static final String DEFAULT_PASSWORD = "dremio123"; + private static final String DEFAULT_SELECT_SQL = "SELECT 1 AS sample_value"; + + private ArrowFlightJdbcSampleApp() {} + + public static void main(final String[] args) throws Exception { + final String host = getArg(args, 0, DEFAULT_HOST); + final int port = Integer.parseInt(getArg(args, 1, Integer.toString(DEFAULT_PORT))); + final String user = getArg(args, 2, DEFAULT_USER); + final String password = getArg(args, 3, DEFAULT_PASSWORD); + final String selectSql = getArg(args, 4, DEFAULT_SELECT_SQL); + final String updateSql = getArg(args, 5, ""); + + final String url = String.format("jdbc:arrow-flight-sql://%s:%d", host, port); + final Properties properties = new Properties(); + properties.setProperty("user", user); + properties.setProperty("password", password); + properties.setProperty("useEncryption", "false"); + + System.out.println("Connecting to " + url); + try (Connection connection = DriverManager.getConnection(url, properties); + Statement statement = connection.createStatement()) { + runSelect(statement, selectSql); + + if (updateSql.isEmpty()) { + System.out.println( + "Skipping Statement.executeUpdate(...) because no updateSql argument was provided."); + } else { + runUpdate(statement, updateSql); + } + } + } + + private static void runSelect(final Statement statement, final String selectSql) + throws SQLException { + System.out.println("Running Statement.executeQuery: " + selectSql); + try (ResultSet resultSet = statement.executeQuery(selectSql)) { + final ResultSetMetaData metadata = resultSet.getMetaData(); + final int columnCount = metadata.getColumnCount(); + int rowCount = 0; + while (resultSet.next()) { + rowCount++; + final StringBuilder rowBuilder = new StringBuilder(); + for (int i = 1; i <= columnCount; i++) { + if (i > 1) { + rowBuilder.append(", "); + } + rowBuilder.append(metadata.getColumnLabel(i)).append('=').append(resultSet.getObject(i)); + } + System.out.println("row " + rowCount + ": " + rowBuilder); + } + System.out.println("Statement.executeQuery returned " + rowCount + " row(s)"); + } + } + + private static void runUpdate(final Statement statement, final String updateSql) + throws SQLException { + System.out.println("Running Statement.executeUpdate: " + updateSql); + final int updateCount = statement.executeUpdate(updateSql); + System.out.println("Statement.executeUpdate affected " + updateCount + " row(s)"); + } + + private static String getArg(final String[] args, final int index, final String defaultValue) { + if (index >= args.length) { + return defaultValue; + } + final String arg = args[index]; + return arg == null || arg.isEmpty() ? defaultValue : arg; + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index d4e7a0953d..138a8e5b76 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -109,15 +109,10 @@ public void testPrepareStatementRegistersCreatedStatementByGeneratedHandle() thr final ArrowFlightPreparedStatement arrowPreparedStatement = (ArrowFlightPreparedStatement) preparedStatement; - assertNotNull( - flightConnection - .getMeta() - .getPreparedStatementInstanceOrNull(arrowPreparedStatement.handle)); + assertNotNull(flightConnection.statementMap.get(arrowPreparedStatement.handle.id)); assertSame( arrowPreparedStatement, - flightConnection - .getMeta() - .getPreparedStatementInstanceOrNull(arrowPreparedStatement.handle)); + flightConnection.statementMap.get(arrowPreparedStatement.handle.id)); } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java index e4df71967b..20e2059722 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java @@ -18,6 +18,7 @@ import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; @@ -146,12 +147,25 @@ public void testExecuteReplacesStatementMapEntryWithPreparedStatement() throws S assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); - final ArrowFlightPreparedStatement preparedStatement = - arrowConnection.getMeta().getPreparedStatementInstanceOrNull(arrowStatement.handle); + final Object preparedStatement = arrowConnection.statementMap.get(arrowStatement.handle.id); assertNotNull(preparedStatement); assertSame(preparedStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); - assertThat(preparedStatement.handle.id, is(equalTo(arrowStatement.handle.id))); + assertThat(preparedStatement, instanceOf(ArrowFlightPreparedStatement.class)); + } + + @Test + public void testExecuteQueryRestoresStatementMapEntryWithStatement() throws SQLException { + final ArrowFlightStatement arrowStatement = (ArrowFlightStatement) statement; + final ArrowFlightConnection arrowConnection = (ArrowFlightConnection) connection; + + assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); + + try (ResultSet resultSet = statement.executeQuery(SAMPLE_QUERY_CMD)) { + assertThat(resultSet.next(), is(true)); + } + + assertSame(arrowStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); } @Test diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java index f7c31c590c..05e85227f0 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -224,4 +225,13 @@ public void testShouldFailToPrepareStatementForBadStatement() { } assertThat(count, is(1)); } + + @Test + public void testExecuteLargeUpdateShouldWrapBadStatement() { + final String badQuery = "BAD INVALID UPDATE"; + final SQLException exception = + assertThrows(SQLException.class, () -> statement.executeLargeUpdate(badQuery)); + assertThat( + exception.getMessage(), startsWith(format("Error while executing SQL \"%s\"", badQuery))); + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java new file mode 100644 index 0000000000..c5c35a9173 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.protobuf.Message; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Collections; +import java.util.function.Consumer; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.FlightSqlUtils; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class ArrowFlightStatementProtocolTest { + private static final String SELECT_QUERY = "SELECT * FROM PROTOCOL_SELECT"; + private static final String UPDATE_QUERY = "UPDATE PROTOCOL_UPDATE"; + private static final Schema QUERY_SCHEMA = + new Schema(Collections.singletonList(Field.nullable("id", MinorType.INT.getType()))); + + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + + @RegisterExtension + public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION = + FlightServerTestExtension.createStandardTestExtension(PRODUCER); + + private Connection connection; + + @BeforeAll + public static void setUpBeforeClass() { + PRODUCER.addSelectQuery( + SELECT_QUERY, + QUERY_SCHEMA, + Collections.singletonList( + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(QUERY_SCHEMA, allocator)) { + IntVector vector = (IntVector) root.getVector("id"); + vector.setSafe(0, 1); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + PRODUCER.addUpdateQuery(UPDATE_QUERY, 1); + + final Message commandGetDbSchemas = CommandGetDbSchemas.getDefaultInstance(); + final Consumer commandGetSchemasResultProducer = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = + VectorSchemaRoot.create(Schemas.GET_SCHEMAS_SCHEMA, allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + catalogName.setSafe(0, new Text("catalog_name #0")); + schemaName.setSafe(0, new Text("db_schema_name #0")); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + PRODUCER.addCatalogQuery(commandGetDbSchemas, commandGetSchemasResultProducer); + } + + @BeforeEach + public void setUp() throws SQLException { + PRODUCER.clearActionTypeCounter(); + PRODUCER.clearCommandTypeCounter(); + connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection); + } + + @AfterAll + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testStatementExecuteQueryUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(SELECT_QUERY)) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testStatementExecuteUsesPreparedProtocolForQuery() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(SELECT_QUERY), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testStatementExecuteUpdateUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.executeUpdate(UPDATE_QUERY), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_UPDATE, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(0)); + } + + @Test + public void testStatementExecuteUsesPreparedProtocolForUpdate() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(UPDATE_QUERY), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testStatementExecuteThenExecuteUpdateUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(SELECT_QUERY), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + assertThat(statement.executeUpdate(UPDATE_QUERY), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testStatementExecuteUpdateThenExecuteQueryUsesStatementProtocol() + throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(UPDATE_QUERY), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + try (ResultSet resultSet = statement.executeQuery(SELECT_QUERY)) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(1)); + } + + @Test + public void testPreparedStatementExecuteQueryUsesPreparedProtocol() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(SELECT_QUERY); + ResultSet resultSet = statement.executeQuery()) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testPreparedStatementExecuteUsesPreparedProtocolForQuery() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(SELECT_QUERY)) { + assertThat(statement.execute(), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testPreparedStatementExecuteUpdateUsesPreparedProtocol() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(UPDATE_QUERY)) { + assertThat(statement.executeUpdate(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testPreparedStatementExecuteUsesPreparedProtocolForUpdate() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(UPDATE_QUERY)) { + assertThat(statement.execute(), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testMetadataGetSchemasUsesJdbcApi() throws SQLException { + final DatabaseMetaData metaData = connection.getMetaData(); + try (ResultSet resultSet = metaData.getSchemas()) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java index 45c2a96404..230c1346fb 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java @@ -103,6 +103,12 @@ public final class MockFlightSqlProducer implements FlightSqlProducer { private final Map>> expectedParameterValues = new HashMap<>(); private final Map actionTypeCounter = new HashMap<>(); + private final Map commandTypeCounter = new HashMap<>(); + + public static final String COMMAND_STATEMENT_QUERY = "statement_query"; + public static final String COMMAND_STATEMENT_UPDATE = "statement_update"; + public static final String COMMAND_PREPARED_STATEMENT_QUERY = "prepared_statement_query"; + public static final String COMMAND_PREPARED_STATEMENT_UPDATE = "prepared_statement_update"; private static FlightInfo getFlightInfoExportedAndImportedKeys( final Message message, final FlightDescriptor descriptor) { @@ -269,6 +275,7 @@ public FlightInfo getFlightInfoStatement( final CommandStatementQuery commandStatementQuery, final CallContext callContext, final FlightDescriptor flightDescriptor) { + incrementCommandTypeCounter(COMMAND_STATEMENT_QUERY); final String query = commandStatementQuery.getQuery(); final Entry> queryInfo = Preconditions.checkNotNull( @@ -289,6 +296,7 @@ public FlightInfo getFlightInfoPreparedStatement( final CommandPreparedStatementQuery commandPreparedStatementQuery, final CallContext callContext, final FlightDescriptor flightDescriptor) { + incrementCommandTypeCounter(COMMAND_PREPARED_STATEMENT_QUERY); final ByteString preparedStatementHandle = commandPreparedStatementQuery.getPreparedStatementHandle(); @@ -356,6 +364,7 @@ public Runnable acceptPutStatement( final CallContext callContext, final FlightStream flightStream, final StreamListener streamListener) { + incrementCommandTypeCounter(COMMAND_STATEMENT_UPDATE); return () -> { final String query = commandStatementUpdate.getQuery(); final BiConsumer> resultProvider = @@ -429,6 +438,7 @@ public Runnable acceptPutPreparedStatementUpdate( final CallContext callContext, final FlightStream flightStream, final StreamListener streamListener) { + incrementCommandTypeCounter(COMMAND_PREPARED_STATEMENT_UPDATE); final ByteString handle = commandPreparedStatementUpdate.getPreparedStatementHandle(); final String query = Preconditions.checkNotNull( @@ -651,10 +661,22 @@ public void clearActionTypeCounter() { actionTypeCounter.clear(); } + public void clearCommandTypeCounter() { + commandTypeCounter.clear(); + } + public Map getActionTypeCounter() { return actionTypeCounter; } + public Map getCommandTypeCounter() { + return commandTypeCounter; + } + + private void incrementCommandTypeCounter(String commandType) { + commandTypeCounter.put(commandType, commandTypeCounter.getOrDefault(commandType, 0) + 1); + } + private void getStreamCatalogFunctions( final Message ticket, final ServerStreamListener serverStreamListener) { Preconditions.checkNotNull(