diff --git a/python/zerocopy/Cargo.lock b/python/zerocopy/Cargo.lock
index 330c59e9f..585b1e68f 100644
--- a/python/zerocopy/Cargo.lock
+++ b/python/zerocopy/Cargo.lock
@@ -403,7 +403,7 @@ checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
 
 [[package]]
 name = "tskit"
-version = "0.15.0-alpha.4"
+version = "0.16.3"
 dependencies = [
  "bindgen",
  "cc",
diff --git a/python/zerocopy/src/lib.rs b/python/zerocopy/src/lib.rs
index fa1bb4f11..661e0f8ec 100644
--- a/python/zerocopy/src/lib.rs
+++ b/python/zerocopy/src/lib.rs
@@ -73,3 +73,104 @@ fn treeseq_roundtrip() {
         unsafe { pyo3::ffi::PyMem_Free(tables_ptr.as_ptr().cast::<std::ffi::c_void>()) };
     });
 }
+
+/// A `tsk_treeseq_t` allocated with Python's allocator can be wrapped with
+/// `TreeSequence::new_from_raw`, reclaimed with `into_mut_ptr`, and then
+/// released through the same allocator.
+#[test]
+fn test_treeseq_new_from_raw() {
+    use pyo3::prelude::*;
+    Python::attach(|_py| {
+        let mut tables = tskit::TableCollection::new(100.).unwrap();
+        tables.add_node(0, 0.0, -1, -1).unwrap();
+
+        let treeseq = unsafe {
+            pyo3::ffi::PyMem_Malloc(std::mem::size_of::<tskit::bindings::tsk_treeseq_t>())
+        } as *mut tskit::bindings::tsk_treeseq_t;
+        // Check the allocation before C code writes through the pointer.
+        let ptr = std::ptr::NonNull::new(treeseq).expect("PyMem_Malloc returned null");
+        // SAFETY: `ptr` is a non-null allocation sized for a `tsk_treeseq_t`.
+        let rv = unsafe {
+            tskit::bindings::tsk_treeseq_init(
+                ptr.as_ptr(),
+                tables.into_mut_ptr().unwrap().as_ptr(),
+                tskit::bindings::TSK_TAKE_OWNERSHIP | tskit::bindings::TSK_TS_INIT_BUILD_INDEXES,
+            )
+        };
+        assert_eq!(rv, 0);
+        let rs_treeseq = unsafe { tskit::TreeSequence::new_from_raw(ptr) }.unwrap();
+        assert_eq!(rs_treeseq.nodes().num_rows(), 1);
+        let mut ptr = rs_treeseq.into_mut_ptr().unwrap();
+        let rv = unsafe { tskit::bindings::tsk_treeseq_free(ptr.as_mut()) };
+        assert_eq!(rv, 0);
+        unsafe {
+            pyo3::ffi::PyMem_Free(ptr.as_ptr() as *mut std::ffi::c_void);
+        }
+    });
+}
+
+/// Both the table collection and the tree sequence are allocated with
+/// Python's allocator, so the tables are torn down by hand before the free.
+#[test]
+fn test_treeseq_new_from_raw_tables_also_py_allocated() {
+    use pyo3::prelude::*;
+    use tskit::bindings::tsk_table_collection_init;
+    use tskit::bindings::tsk_table_collection_t;
+
+    Python::attach(|_py| {
+        let tables_ptr = unsafe {
+            pyo3::ffi::PyMem_Malloc(std::mem::size_of::<tsk_table_collection_t>())
+                .cast::<tsk_table_collection_t>()
+        };
+        assert!(!tables_ptr.is_null());
+
+        // SAFETY: ptr is not null
+        let rv = unsafe { tsk_table_collection_init(tables_ptr, 0) };
+        // Only touch the struct once initialization is known to have succeeded.
+        assert_eq!(rv, 0);
+        unsafe { (*tables_ptr).sequence_length = 100.0 };
+        // Read the field in place rather than copying the whole struct out.
+        assert_eq!(unsafe { (*tables_ptr).sequence_length }, 100.);
+        let tables_ptr = std::ptr::NonNull::new(tables_ptr).unwrap();
+        // Not null and initialized w/o error
+        let mut tables = unsafe { tskit::TableCollection::new_from_raw(tables_ptr) }.unwrap();
+        let _ = tables.add_node(0, 0.0, -1, -1).unwrap();
+        assert_eq!(tables.sequence_length(), 100.);
+
+        let treeseq = unsafe {
+            pyo3::ffi::PyMem_Malloc(std::mem::size_of::<tskit::bindings::tsk_treeseq_t>())
+        } as *mut tskit::bindings::tsk_treeseq_t;
+        // Check the allocation before C code writes through the pointer.
+        let ptr = std::ptr::NonNull::new(treeseq).expect("PyMem_Malloc returned null");
+        // SAFETY: `ptr` is a non-null allocation sized for a `tsk_treeseq_t`.
+        let rv = unsafe {
+            tskit::bindings::tsk_treeseq_init(
+                ptr.as_ptr(),
+                tables.into_mut_ptr().unwrap().as_ptr(),
+                tskit::bindings::TSK_TAKE_OWNERSHIP | tskit::bindings::TSK_TS_INIT_BUILD_INDEXES,
+            )
+        };
+        assert_eq!(rv, 0);
+        let rs_treeseq = unsafe { tskit::TreeSequence::new_from_raw(ptr) }.unwrap();
+        assert_eq!(rs_treeseq.nodes().num_rows(), 1);
+        let mut ptr = rs_treeseq.into_mut_ptr().unwrap();
+
+        // We allocated the tables via the Python allocator.
+        // Internally, tskit will free it with C's free, which is
+        // UB!
+        // To circumvent UB, we must manually do the steps below.
+        // We know to do these steps b/c we have read the implementation
+        // of tsk_treeseq_free.
+        unsafe {
+            tskit::bindings::tsk_table_collection_free(ptr.as_mut().tables);
+            pyo3::ffi::PyMem_Free(ptr.as_mut().tables as *mut std::ffi::c_void);
+            ptr.as_mut().tables = std::ptr::null_mut();
+        }
+
+        let rv = unsafe { tskit::bindings::tsk_treeseq_free(ptr.as_mut()) };
+        assert_eq!(rv, 0);
+        unsafe {
+            pyo3::ffi::PyMem_Free(ptr.as_ptr() as *mut std::ffi::c_void);
+        }
+    });
+}
diff --git a/src/sys/treeseq.rs b/src/sys/treeseq.rs
index 0a78eaa96..1bc01a6e8 100644
--- a/src/sys/treeseq.rs
+++ b/src/sys/treeseq.rs
@@ -46,6 +46,20 @@ impl TreeSequence {
         Ok(Self(tsk))
     }
 
+    /// Consume `self` and return the raw pointer held by the inner `TskBox`.
+    pub fn into_raw(self) -> *mut tsk_treeseq_t {
+        self.0.into_raw()
+    }
+
+    /// Wrap an existing `tsk_treeseq_t` in an owning `TreeSequence`.
+    ///
+    /// # Safety
+    ///
+    /// `treeseq` must point to an initialized `tsk_treeseq_t`.
+    pub unsafe fn new_owning_from_nonnull(treeseq: std::ptr::NonNull<tsk_treeseq_t>) -> Self {
+        Self(TskBox::new_init_owning_from_ptr(treeseq))
+    }
+
     pub fn as_ref(&self) -> &bindings::tsk_treeseq_t {
         self.0.as_ref()
     }
diff --git a/src/trees/treeseq.rs b/src/trees/treeseq.rs
index 596056e67..1d1114b31 100644
--- a/src/trees/treeseq.rs
+++ b/src/trees/treeseq.rs
@@ -18,6 +18,7 @@ use crate::TreeFlags;
 use crate::TreeSequenceFlags;
 use crate::TskReturnValue;
 use sys::bindings as ll_bindings;
+use sys::bindings::tsk_treeseq_t;
 
 #[cfg(feature = "provenance")]
 use std::ffi::c_char;
@@ -732,6 +733,41 @@ impl TreeSequence {
     ) -> impl crate::TableColumn<crate::EdgeId, crate::EdgeId> + '_ {
         crate::table_column::OpaqueTableColumn(self.edge_removal_order())
     }
+
+    /// Wrap an already-initialized `tsk_treeseq_t` in a [`TreeSequence`].
+    ///
+    /// `ptr` is handed to an owning wrapper, while the table collection is a
+    /// borrowed view of the `tables` field behind `ptr`.
+    ///
+    /// # Safety
+    ///
+    /// `ptr` must point to an initialized `tsk_treeseq_t`.
+    ///
+    /// # Errors
+    ///
+    /// Returns any error produced while wrapping the borrowed tables.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `tables` field behind `ptr` is null.
+    #[cfg(feature = "unsafe_init")]
+    pub unsafe fn new_from_raw(
+        ptr: std::ptr::NonNull<tsk_treeseq_t>,
+    ) -> Result<Self, crate::TskitError> {
+        let tables = unsafe {
+            TableCollection::new_from_ll(sys::TableCollection::new_borrowed(
+                std::ptr::NonNull::new(ptr.as_ref().tables)
+                    .expect("initialized tsk_treeseq_t has non-null tables"),
+            ))
+        }?;
+        let inner = sys::TreeSequence::new_owning_from_nonnull(ptr);
+        Ok(Self { inner, tables })
+    }
+
+    /// Consume `self`, returning the underlying pointer (`None` if it is null).
+    pub fn into_mut_ptr(self) -> Option<std::ptr::NonNull<tsk_treeseq_t>> {
+        std::ptr::NonNull::new(self.inner.into_raw())
+    }
 }
 
 impl TryFrom<TableCollection> for TreeSequence {
@@ -741,3 +777,28 @@ impl TryFrom<TableCollection> for TreeSequence {
         Self::new(value, TreeSequenceFlags::default())
     }
 }
+
+/// A `malloc`-allocated, initialized `tsk_treeseq_t` can be adopted by
+/// `TreeSequence::new_from_raw`.
+#[cfg(feature = "unsafe_init")]
+#[test]
+fn test_new_from_raw() {
+    let mut tables = crate::TableCollection::new(100.).unwrap();
+    tables.add_node(0, 0.0, -1, -1).unwrap();
+
+    let treeseq = unsafe { libc::malloc(std::mem::size_of::<sys::bindings::tsk_treeseq_t>()) }
+        as *mut sys::bindings::tsk_treeseq_t;
+    // Check the allocation before C code writes through the pointer.
+    let ptr = std::ptr::NonNull::new(treeseq).expect("malloc returned null");
+    // SAFETY: `ptr` is a non-null allocation sized for a `tsk_treeseq_t`.
+    let rv = unsafe {
+        sys::bindings::tsk_treeseq_init(
+            ptr.as_ptr(),
+            tables.into_mut_ptr().unwrap().as_ptr(),
+            sys::bindings::TSK_TAKE_OWNERSHIP | sys::bindings::TSK_TS_INIT_BUILD_INDEXES,
+        )
+    };
+    assert_eq!(rv, 0);
+    let rs_treeseq = unsafe { TreeSequence::new_from_raw(ptr) }.unwrap();
+    assert_eq!(rs_treeseq.nodes().num_rows(), 1);
+}