Skip to content

Commit 6ffbaa1

Browse files
committed
feat(rust/core)!: define cancellation in a sensible way
Closes #3454.
1 parent 11205f1 commit 6ffbaa1

8 files changed

Lines changed: 154 additions & 74 deletions

File tree

rust/core/src/lib.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,24 @@ pub trait Database: Optionable<Option = OptionDatabase> {
117117
) -> Result<Self::ConnectionType>;
118118
}
119119

120+
/// A handle to cancel an in-progress operation on a connection.
121+
///
122+
/// This is a separated handle because otherwise it would be impossible to
123+
/// call a `cancel` method on a connection or statement itself.
124+
pub trait CancelHandle: Send {
125+
/// Cancel the in-progress operation on a connection.
126+
fn try_cancel(&self) -> Result<()>;
127+
}
128+
129+
/// A cancellation handle that does nothing (because cancellation is unsupported).
130+
pub struct NoOpCancellationHandle;
131+
132+
impl CancelHandle for NoOpCancellationHandle {
133+
fn try_cancel(&self) -> Result<()> {
134+
Ok(())
135+
}
136+
}
137+
120138
/// A handle to an ADBC connection.
121139
///
122140
/// Connections provide methods for query execution, managing prepared
@@ -133,8 +151,10 @@ pub trait Connection: Optionable<Option = OptionConnection> {
133151
/// Allocate and initialize a new statement.
134152
fn new_statement(&mut self) -> Result<Self::StatementType>;
135153

136-
/// Cancel the in-progress operation on a connection.
137-
fn cancel(&mut self) -> Result<()>;
154+
/// Get a handle to cancel operations on this connection.
155+
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
156+
Box::new(NoOpCancellationHandle {})
157+
}
138158

139159
/// Get metadata about the database/driver.
140160
///
@@ -494,15 +514,17 @@ pub trait Statement: Optionable<Option = OptionStatement> {
494514
/// expected to be executed repeatedly, call [Statement::prepare] first.
495515
fn set_substrait_plan(&mut self, plan: impl AsRef<[u8]>) -> Result<()>;
496516

497-
/// Cancel execution of an in-progress query.
517+
/// Get a handle to cancel operations on this statement.
498518
///
499-
/// This can be called during [Statement::execute] (or similar), or while
500-
/// consuming a result set returned from such.
519+
/// The resulting handle can be called during [Statement::execute] (or
520+
/// similar), or while consuming a result set returned from such.
501521
///
502522
/// # Since
503523
///
504524
/// ADBC API revision 1.1.0
505-
fn cancel(&mut self) -> Result<()>;
525+
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
526+
Box::new(NoOpCancellationHandle {})
527+
}
506528
}
507529

508530
/// Each data partition is described by an opaque byte array and can be

rust/driver/datafusion/src/lib.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -735,10 +735,6 @@ impl Connection for DataFusionConnection {
735735
})
736736
}
737737

738-
fn cancel(&mut self) -> adbc_core::error::Result<()> {
739-
todo!()
740-
}
741-
742738
fn get_info(
743739
&self,
744740
codes: Option<std::collections::HashSet<adbc_core::options::InfoCode>>,
@@ -981,10 +977,6 @@ impl Statement for DataFusionStatement {
981977
self.substrait_plan = Some(Plan::decode(plan.as_ref()).unwrap());
982978
Ok(())
983979
}
984-
985-
fn cancel(&mut self) -> adbc_core::error::Result<()> {
986-
todo!()
987-
}
988980
}
989981

990982
#[cfg(feature = "ffi")]

rust/driver/dummy/src/lib.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -294,16 +294,24 @@ impl Connection for DummyConnection {
294294
Ok(Self::StatementType::default())
295295
}
296296

297-
// This method is used to test that errors round-trip correctly.
298-
fn cancel(&mut self) -> Result<()> {
299-
let mut error = Error::with_message_and_status("message", Status::Cancelled);
300-
error.vendor_code = constants::ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA;
301-
error.sqlstate = [1, 2, 3, 4, 5];
302-
error.details = Some(vec![
303-
("key1".into(), b"AAA".into()),
304-
("key2".into(), b"ZZZZZ".into()),
305-
]);
306-
Err(error)
297+
/// This method is used to test that errors round-trip correctly.
298+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
299+
struct CancelHandle;
300+
301+
impl adbc_core::CancelHandle for CancelHandle {
302+
fn try_cancel(&self) -> Result<()> {
303+
let mut error = Error::with_message_and_status("message", Status::Cancelled);
304+
error.vendor_code = constants::ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA;
305+
error.sqlstate = [1, 2, 3, 4, 5];
306+
error.details = Some(vec![
307+
("key1".into(), b"AAA".into()),
308+
("key2".into(), b"ZZZZZ".into()),
309+
]);
310+
Err(error)
311+
}
312+
}
313+
314+
Box::new(CancelHandle)
307315
}
308316

309317
fn commit(&mut self) -> Result<()> {
@@ -848,10 +856,6 @@ impl Statement for DummyStatement {
848856
Ok(())
849857
}
850858

851-
fn cancel(&mut self) -> Result<()> {
852-
Ok(())
853-
}
854-
855859
fn execute(&mut self) -> Result<impl RecordBatchReader> {
856860
maybe_panic("StatementExecuteQuery");
857861
let batch = get_table_data();

rust/driver/snowflake/src/connection.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ impl adbc_core::Connection for Connection {
7070
self.0.new_statement().map(Statement)
7171
}
7272

73-
fn cancel(&mut self) -> Result<()> {
74-
self.0.cancel()
73+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
74+
self.0.get_cancel_handle()
7575
}
7676

7777
fn get_info(&self, codes: Option<HashSet<InfoCode>>) -> Result<impl RecordBatchReader + Send> {

rust/driver/snowflake/src/statement.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl adbc_core::Statement for Statement {
9696
self.0.set_substrait_plan(plan)
9797
}
9898

99-
fn cancel(&mut self) -> Result<()> {
100-
self.0.cancel()
99+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
100+
self.0.get_cancel_handle()
101101
}
102102
}

rust/driver_manager/src/lib.rs

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,31 @@ pub struct ManagedConnection {
11981198
inner: Arc<ManagedConnectionInner>,
11991199
}
12001200

1201+
struct ConnectionCancelHandle {
1202+
inner: std::sync::Weak<ManagedConnectionInner>,
1203+
}
1204+
1205+
impl adbc_core::CancelHandle for ConnectionCancelHandle {
1206+
fn try_cancel(&self) -> Result<()> {
1207+
if let Some(inner) = self.inner.upgrade() {
1208+
if let AdbcVersion::V100 = inner.database.driver.version {
1209+
return Err(Error::with_message_and_status(
1210+
ERR_CANCEL_UNSUPPORTED,
1211+
Status::NotImplemented,
1212+
));
1213+
}
1214+
let driver = &inner.database.driver.driver;
1215+
let mut connection = inner.connection.lock().unwrap();
1216+
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1217+
let method = driver_method!(driver, ConnectionCancel);
1218+
let status = unsafe { method(connection.deref_mut(), &mut error) };
1219+
check_status(status, error)
1220+
} else {
1221+
Ok(())
1222+
}
1223+
}
1224+
}
1225+
12011226
impl ManagedConnection {
12021227
fn ffi_driver(&self) -> &adbc_ffi::FFI_AdbcDriver {
12031228
&self.inner.database.driver.driver
@@ -1296,19 +1321,10 @@ impl Connection for ManagedConnection {
12961321
Ok(Self::StatementType { inner })
12971322
}
12981323

1299-
fn cancel(&mut self) -> Result<()> {
1300-
if let AdbcVersion::V100 = self.driver_version() {
1301-
return Err(Error::with_message_and_status(
1302-
ERR_CANCEL_UNSUPPORTED,
1303-
Status::NotImplemented,
1304-
));
1305-
}
1306-
let driver = self.ffi_driver();
1307-
let mut connection = self.inner.connection.lock().unwrap();
1308-
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1309-
let method = driver_method!(driver, ConnectionCancel);
1310-
let status = unsafe { method(connection.deref_mut(), &mut error) };
1311-
check_status(status, error)
1324+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
1325+
Box::new(ConnectionCancelHandle {
1326+
inner: Arc::downgrade(&self.inner),
1327+
})
13121328
}
13131329

13141330
fn commit(&mut self) -> Result<()> {
@@ -1612,6 +1628,31 @@ impl ManagedStatement {
16121628
}
16131629
}
16141630

1631+
struct StatementCancelHandle {
1632+
inner: std::sync::Weak<ManagedStatementInner>,
1633+
}
1634+
1635+
impl adbc_core::CancelHandle for StatementCancelHandle {
1636+
fn try_cancel(&self) -> Result<()> {
1637+
if let Some(inner) = self.inner.upgrade() {
1638+
if let AdbcVersion::V100 = inner.connection.database.driver.version {
1639+
return Err(Error::with_message_and_status(
1640+
ERR_CANCEL_UNSUPPORTED,
1641+
Status::NotImplemented,
1642+
));
1643+
}
1644+
let driver = &inner.connection.database.driver.driver;
1645+
let mut statement = inner.statement.lock().unwrap();
1646+
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1647+
let method = driver_method!(driver, StatementCancel);
1648+
let status = unsafe { method(statement.deref_mut(), &mut error) };
1649+
check_status(status, error)
1650+
} else {
1651+
Ok(())
1652+
}
1653+
}
1654+
}
1655+
16151656
impl Statement for ManagedStatement {
16161657
fn bind(&mut self, batch: RecordBatch) -> Result<()> {
16171658
let driver = self.ffi_driver();
@@ -1636,19 +1677,10 @@ impl Statement for ManagedStatement {
16361677
Ok(())
16371678
}
16381679

1639-
fn cancel(&mut self) -> Result<()> {
1640-
if let AdbcVersion::V100 = self.driver_version() {
1641-
return Err(Error::with_message_and_status(
1642-
ERR_CANCEL_UNSUPPORTED,
1643-
Status::NotImplemented,
1644-
));
1645-
}
1646-
let driver = self.ffi_driver();
1647-
let mut statement = self.inner.statement.lock().unwrap();
1648-
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1649-
let method = driver_method!(driver, StatementCancel);
1650-
let status = unsafe { method(statement.deref_mut(), &mut error) };
1651-
check_status(status, error)
1680+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
1681+
Box::new(StatementCancelHandle {
1682+
inner: Arc::downgrade(&self.inner),
1683+
})
16521684
}
16531685

16541686
fn execute(&mut self) -> Result<impl RecordBatchReader> {

rust/driver_manager/tests/driver_manager_sqlite.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,10 @@ fn test_connection_get_option() {
128128
fn test_connection_cancel() {
129129
let mut driver = get_driver();
130130
let database = get_database(&mut driver);
131-
let mut connection = database.new_connection().unwrap();
131+
let connection = database.new_connection().unwrap();
132132

133-
let error = connection.cancel().unwrap_err();
133+
let handle = connection.get_cancel_handle();
134+
let error = handle.try_cancel().unwrap_err();
134135
assert_eq!(error.status, Status::NotImplemented);
135136
}
136137

@@ -285,9 +286,10 @@ fn test_statement_cancel() {
285286
let mut driver = get_driver();
286287
let database = get_database(&mut driver);
287288
let mut connection = database.new_connection().unwrap();
288-
let mut statement = connection.new_statement().unwrap();
289+
let statement = connection.new_statement().unwrap();
289290

290-
let error = statement.cancel().unwrap_err();
291+
let handle = statement.get_cancel_handle();
292+
let error = handle.try_cancel().unwrap_err();
291293
assert_eq!(error.status, Status::NotImplemented);
292294
}
293295

0 commit comments

Comments
 (0)