Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.netty.util.concurrent.DefaultThreadFactory;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -42,6 +44,8 @@ public final class ArrowFlightConnection extends AvaticaConnection {
private final ArrowFlightSqlClientHandler clientHandler;
private final ArrowFlightConnectionConfigImpl config;
private ExecutorService executorService;
private int metadataResultSetCount;
private Map<Integer, ArrowFlightJdbcFlightStreamResultSet> metadataResultSetMap = new HashMap<>();

/**
* Creates a new {@link ArrowFlightConnection}.
Expand All @@ -66,6 +70,7 @@ private ArrowFlightConnection(
this.config = Preconditions.checkNotNull(config, "Config cannot be null.");
this.allocator = Preconditions.checkNotNull(allocator, "Allocator cannot be null.");
this.clientHandler = Preconditions.checkNotNull(clientHandler, "Handler cannot be null.");
this.metadataResultSetCount = 0;
}

/**
Expand Down Expand Up @@ -173,6 +178,31 @@ synchronized ExecutorService getExecutorService() {
: executorService;
}

/**
* Registers a new metadata ResultSet and assigns it a unique ID. Metadata ResultSets are those
* created without an associated Statement.
*
* @param resultSet the ResultSet to register
* @return the assigned ID
*/
int getNewMetadataResultSetId(ArrowFlightJdbcFlightStreamResultSet resultSet) {
metadataResultSetMap.put(metadataResultSetCount, resultSet);
return metadataResultSetCount++;
}

/**
* Unregisters a metadata ResultSet when it is closed. This method is called by metadata
* ResultSets during their close operation to remove themselves from the tracking map.
*
* @param id the ID of the ResultSet to unregister, or null if not a metadata ResultSet
*/
void onResultSetClose(Integer id) {
if (id == null) {
return;
}
metadataResultSetMap.remove(id);
}

@Override
public Properties getClientInfo() {
final Properties copy = new Properties();
Expand All @@ -190,7 +220,9 @@ public void close() throws SQLException {
} catch (final Exception e) {
topLevelException = e;
}
// copies of the collections are used to avoid concurrent modification problems
ArrayList<AutoCloseable> closeables = new ArrayList<>(statementMap.values());
closeables.addAll(new ArrayList<>(metadataResultSetMap.values()));
closeables.add(clientHandler);
closeables.addAll(allocator.getChildAllocators());
closeables.add(allocator);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public final class ArrowFlightJdbcFlightStreamResultSet
private VectorSchemaRoot currentVectorSchemaRoot;

private Schema schema;
private Integer id = null; // used for metadata result sets only

/** Public constructor used by ArrowFlightJdbcFactory. */
ArrowFlightJdbcFlightStreamResultSet(
Expand Down Expand Up @@ -82,6 +83,7 @@ private ArrowFlightJdbcFlightStreamResultSet(
super(null, state, signature, resultSetMetaData, timeZone, firstFrame);
this.connection = connection;
this.flightInfo = flightInfo;
this.id = connection.getNewMetadataResultSetId(this);
}

/**
Expand Down Expand Up @@ -234,7 +236,12 @@ protected void cancel() {

@Override
public synchronized void close() {

try {
if (isClosed()) {
return;
}
this.connection.onResultSetClose(id);
if (flightEndpointDataQueue != null) {
// flightStreamQueue should close currentFlightStream internally
flightEndpointDataQueue.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.sql.Types;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TimeZone;
import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
Expand Down Expand Up @@ -159,12 +158,10 @@ public void close() {
} catch (final Exception e) {
exceptions.add(e);
}
if (!Objects.isNull(statement)) {
try {
super.close();
} catch (final Exception e) {
exceptions.add(e);
}
try {
super.close();
} catch (final Exception e) {
exceptions.add(e);
}
exceptions.parallelStream().forEach(e -> LOGGER.error(e.getMessage(), e));
exceptions.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,42 @@
*/
package org.apache.arrow.driver.jdbc;

import static java.lang.String.format;
import static java.util.stream.IntStream.range;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import com.google.protobuf.Message;
import java.net.URISyntaxException;
import java.sql.Connection;
import java.sql.Driver;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Map;
import java.util.Properties;
import java.util.function.Consumer;
import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
import org.apache.arrow.flight.NoOpSessionOptionValueVisitor;
import org.apache.arrow.flight.SessionOptionValue;
import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.util.Text;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -698,4 +709,63 @@ public void testStatementsClosedOnConnectionClose() throws Exception {
assertTrue(statements[i].isClosed());
}
}

@Test
public void testResultSetsFromDatabaseMetadataClosedOnConnectionClose() throws Exception {
// set up the FlightProducer to respond to metadata queries
// getTableTypes() is being used, but any other method would work
int rowCount = 3;
final Message commandGetTableTypes = CommandGetTableTypes.getDefaultInstance();
final Consumer<ServerStreamListener> commandGetTableTypesResultProducer =
listener -> {
try (final BufferAllocator allocator = new RootAllocator();
final VectorSchemaRoot root =
VectorSchemaRoot.create(Schemas.GET_TABLE_TYPES_SCHEMA, allocator)) {
final VarCharVector tableType = (VarCharVector) root.getVector("table_type");
range(0, rowCount)
.forEach(i -> tableType.setSafe(i, new Text(format("table_type #%d", i))));
root.setRowCount(rowCount);
listener.start(root);
listener.putNext();
} catch (final Throwable throwable) {
listener.error(throwable);
} finally {
listener.completed();
}
};
PRODUCER.addCatalogQuery(commandGetTableTypes, commandGetTableTypesResultProducer);

// create a connection
final Properties properties = new Properties();
properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost");
properties.put(
ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_EXTENSION.getPort());
properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest);
properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest);
properties.put("useEncryption", false);

Connection connection =
DriverManager.getConnection(
"jdbc:arrow-flight-sql://"
+ FLIGHT_SERVER_TEST_EXTENSION.getHost()
+ ":"
+ FLIGHT_SERVER_TEST_EXTENSION.getPort(),
properties);

// create ResultSets from DatabaseMetadata
int numResultSets = 3;
ResultSet[] resultSets = new ResultSet[numResultSets];
for (int i = 0; i < numResultSets; i++) {
resultSets[i] = connection.getMetaData().getTableTypes();
assertFalse(resultSets[i].isClosed());
}

// close the connection
connection.close();

// assert the ResultSets are closed
for (int i = 0; i < numResultSets; i++) {
assertTrue(resultSets[i].isClosed());
}
}
}
Loading