Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl TableProvider for IndexTableProvider {
.partitioned_file()
// provide the starting access plan to the DataSourceExec by
// storing it as "extensions" on PartitionedFile
.with_extensions(Arc::new(access_plan) as _);
.with_extension(access_plan);

// Prepare for scanning
let schema = self.schema();
Expand Down
211 changes: 211 additions & 0 deletions datafusion/common/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! A type-keyed map of opaque, `Arc`'d objects.
//!
//! Used as the backing store for the various `extensions` fields throughout
//! DataFusion (e.g. [`SessionConfig`], [`ExtendedStatistics`],
//! [`PartitionedFile`]) so that independent components can each attach
//! their own data without conflict, each keyed by its concrete Rust type.
//!
//! [`SessionConfig`]: https://docs.rs/datafusion-execution/latest/datafusion_execution/config/struct.SessionConfig.html
//! [`ExtendedStatistics`]: https://docs.rs/datafusion-physical-plan/latest/datafusion_physical_plan/operator_statistics/struct.ExtendedStatistics.html
//! [`PartitionedFile`]: https://docs.rs/datafusion-datasource/latest/datafusion_datasource/struct.PartitionedFile.html

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::hash::{BuildHasherDefault, Hasher};
use std::sync::Arc;

/// A type-keyed map of opaque `Arc`'d values. Each Rust type `T` occupies
/// its own slot, so independent components can each attach their own data
/// without conflict.
///
/// Cloning is cheap: the backing values are reference-counted.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::extensions::Extensions;
/// struct MyData(u32);
/// struct OtherData(&'static str);
///
/// let mut ext = Extensions::new();
/// ext.insert(MyData(42));
/// ext.insert_arc(Arc::new(OtherData("hello")));
///
/// assert_eq!(ext.get::<MyData>().unwrap().0, 42);
/// assert_eq!(ext.get::<OtherData>().unwrap().0, "hello");
/// ```
/// Type-keyed collection of opaque, reference-counted values.
///
/// Every Rust type `T` gets exactly one slot, keyed by its [`TypeId`], so
/// unrelated components can attach their own payloads side by side without
/// stepping on each other.
///
/// Cloning is cheap: only the `Arc` pointers are copied, never the values.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::extensions::Extensions;
/// struct MyData(u32);
/// struct OtherData(&'static str);
///
/// let mut ext = Extensions::new();
/// ext.insert(MyData(42));
/// ext.insert_arc(Arc::new(OtherData("hello")));
///
/// assert_eq!(ext.get::<MyData>().unwrap().0, 42);
/// assert_eq!(ext.get::<OtherData>().unwrap().0, "hello");
/// ```
#[derive(Debug, Clone, Default)]
pub struct Extensions {
    inner: HashMap<TypeId, Arc<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>,
}

impl Extensions {
    /// Create an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns true if no extensions are set.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Number of extensions set.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Insert `value` keyed by its concrete type `T`, returning the value
    /// of type `T` that it displaced, if any.
    ///
    /// The value is boxed into a fresh [`Arc`] internally; callers that
    /// already hold an `Arc<T>` can avoid the extra allocation with
    /// [`Self::insert_arc`].
    pub fn insert<T: Any + Send + Sync>(&mut self, value: T) -> Option<Arc<T>> {
        self.insert_arc(Arc::new(value))
    }

    /// Insert an already `Arc`-wrapped value keyed by its concrete type
    /// `T`, returning the displaced value of that type, if any.
    pub fn insert_arc<T: Any + Send + Sync>(&mut self, value: Arc<T>) -> Option<Arc<T>> {
        let previous = self.inner.insert(TypeId::of::<T>(), value)?;
        // The slot keyed by `TypeId::of::<T>()` can only ever hold a `T`,
        // so the downcast is infallible by construction.
        Some(Arc::downcast::<T>(previous).expect("TypeId matches T"))
    }

    /// Insert an already-type-erased value, keyed by its dynamic
    /// [`TypeId`]. Used internally to support APIs that accept
    /// `Arc<dyn Any + Send + Sync>` for backwards compatibility and need
    /// to recover the concrete type for keying.
    ///
    /// New code should use [`Self::insert`] or [`Self::insert_arc`], which
    /// preserve the concrete type at the call site.
    #[deprecated(
        since = "54.0.0",
        note = "use `insert` or `insert_arc`; only retained to support the deprecated `PartitionedFile::with_extensions` shim"
    )]
    pub fn insert_dyn(
        &mut self,
        value: Arc<dyn Any + Send + Sync>,
    ) -> Option<Arc<dyn Any + Send + Sync>> {
        // `(*value).type_id()` resolves the TypeId of the value *behind*
        // the Arc, not of the `Arc` wrapper itself.
        let key = (*value).type_id();
        self.inner.insert(key, value)
    }

    /// Borrow the extension of type `T`, if set.
    pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
        self.inner.get(&TypeId::of::<T>())?.downcast_ref::<T>()
    }

    /// Get a cloned `Arc<T>` of the extension, if set.
    pub fn get_arc<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        let erased = Arc::clone(self.inner.get(&TypeId::of::<T>())?);
        Some(Arc::downcast::<T>(erased).expect("TypeId matches T"))
    }

    /// Returns true if an extension of type `T` is set.
    pub fn contains<T: Any + Send + Sync>(&self) -> bool {
        self.inner.contains_key(&TypeId::of::<T>())
    }

    /// Merge entries from `other` into `self`. On a type collision the
    /// entry from `other` wins.
    pub fn merge(&mut self, other: &Extensions) {
        self.inner
            .extend(other.inner.iter().map(|(id, ext)| (*id, Arc::clone(ext))));
    }
}

/// A no-op hasher for [`TypeId`] keys. A `TypeId` is already a
/// compiler-generated hash, so rather than hash it a second time we store
/// the `u64` it feeds us and return it verbatim.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    fn write(&mut self, _: &[u8]) {
        unreachable!("TypeId calls write_u64");
    }

    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct A(u32);

    #[derive(Debug, PartialEq)]
    struct B(&'static str);

    #[test]
    fn insert_get_replace() {
        let mut ext = Extensions::new();
        assert!(ext.is_empty());

        ext.insert(A(1));
        ext.insert_arc(Arc::new(B("x")));

        // Each type occupies its own independent slot.
        assert_eq!(2, ext.len());
        assert_eq!(Some(&A(1)), ext.get::<A>());
        assert_eq!(Some(&B("x")), ext.get::<B>());
        assert!(ext.contains::<A>());

        // Re-inserting the same type replaces the value and hands back the
        // displaced one.
        let displaced = ext.insert(A(2));
        assert_eq!(Some(&A(1)), displaced.as_deref());
        assert_eq!(Some(&A(2)), ext.get::<A>());
    }

    #[test]
    #[expect(deprecated)]
    fn insert_dyn_keys_by_concrete_type() {
        let mut ext = Extensions::new();
        let type_erased: Arc<dyn Any + Send + Sync> = Arc::new(A(7));
        ext.insert_dyn(type_erased);
        // Keyed by the concrete type behind the Arc, so a typed lookup finds it.
        assert_eq!(Some(&A(7)), ext.get::<A>());
    }

    #[test]
    fn merge_other_wins() {
        let mut target = Extensions::new();
        target.insert(A(1));

        let mut source = Extensions::new();
        source.insert(A(2));
        source.insert(B("hi"));

        target.merge(&source);
        // On collision the merged-in entry takes precedence.
        assert_eq!(Some(&A(2)), target.get::<A>());
        assert_eq!(Some(&B("hi")), target.get::<B>());
    }
}
1 change: 1 addition & 0 deletions datafusion/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod diagnostic;
pub mod display;
pub mod encryption;
pub mod error;
pub mod extensions;
pub mod file_options;
pub mod format;
pub mod hash_utils;
Expand Down
86 changes: 80 additions & 6 deletions datafusion/core/tests/parquet/custom_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use bytes::Bytes;
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
use datafusion_datasource::source::DataSourceExec;
use datafusion_datasource_parquet::metadata::DFParquetMetadata;
use datafusion_datasource_parquet::{ParquetAccessPlan, RowGroupAccess};
use futures::future::BoxFuture;
use futures::{FutureExt, TryFutureExt};
use insta::assert_snapshot;
Expand All @@ -49,6 +50,7 @@ use parquet::arrow::arrow_reader::ArrowReaderOptions;
use parquet::arrow::async_reader::AsyncFileReader;
use parquet::errors::ParquetError;
use parquet::file::metadata::ParquetMetaData;
use parquet::file::properties::WriterProperties;

const EXPECTED_USER_DEFINED_METADATA: &str = "some-user-defined-metadata";

Expand All @@ -71,7 +73,7 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() {
.into_iter()
.map(|meta| {
PartitionedFile::new_from_meta(meta)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to add one integration regression test that attaches both a ParquetAccessPlan and a custom reader payload to the same PartitionedFile. The test could then assert that the custom reader still sees its payload and that the parquet access plan is honored.

The current tests cover the generic map and the two consumers separately, but not quite the end-to-end invariant this PR is trying to protect.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.with_extensions(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA)))
.with_extension(String::from(EXPECTED_USER_DEFINED_METADATA))
})
.collect();

Expand Down Expand Up @@ -107,6 +109,82 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() {
");
}

/// Regression test for the type-keyed extensions map: independent components
/// must be able to attach their own per-file payloads on the same
/// [`PartitionedFile`] without colliding. Here we attach a custom reader
/// payload (the `String` checked by [`InMemoryParquetFileReaderFactory`])
/// *and* a [`ParquetAccessPlan`] that skips the first row group, then verify
/// (a) the factory still sees its payload (its internal `assert_eq!` would
/// fire if the slot got overwritten) and (b) the access plan is honored — so
/// only the second row group's 5 rows come out, not all 10.
#[tokio::test]
async fn custom_payload_and_access_plan_coexist() {
    // Two row groups of 5 rows each: values 0..=4 in row group 0, 5..=9 in
    // row group 1.
    let c1: ArrayRef = Arc::new(Int64Array::from((0..10).collect::<Vec<i64>>()));
    let batch = create_batch(vec![("c1", c1)]);
    let file_schema = batch.schema().clone();

    let in_memory = InMemory::new();
    let mut buf = Vec::<u8>::with_capacity(32 * 1024);
    // `set_max_row_group_size` caps rows per row group, forcing the writer
    // to split the 10-row batch into two 5-row row groups.
    let props = WriterProperties::builder()
        .set_max_row_group_size(5)
        .build();
    let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();

    let location = Path::parse("two-row-groups.parquet").unwrap();
    let size = buf.len() as u64;
    in_memory
        .put(&location, Bytes::from(buf).into())
        .await
        .unwrap();
    let meta = ObjectMeta {
        location,
        last_modified: chrono::DateTime::from(SystemTime::now()),
        size,
        e_tag: None,
        version: None,
    };

    // Attach BOTH extensions to the same file: the access plan (skip row
    // group 0, scan row group 1) and the factory's expected string payload.
    let access_plan =
        ParquetAccessPlan::new(vec![RowGroupAccess::Skip, RowGroupAccess::Scan]);
    let pf = PartitionedFile::new_from_meta(meta)
        .with_extension(String::from(EXPECTED_USER_DEFINED_METADATA))
        .with_extension(access_plan);

    let store: Arc<dyn ObjectStore> = Arc::new(in_memory);
    let source = Arc::new(
        ParquetSource::new(file_schema.clone()).with_parquet_file_reader_factory(
            Arc::new(InMemoryParquetFileReaderFactory(Arc::clone(&store))),
        ),
    );
    let base_config =
        FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source)
            .with_file_group(vec![pf].into())
            .build();
    let parquet_exec = DataSourceExec::from_data_source(base_config);

    let session_ctx = SessionContext::new();
    let read = collect(parquet_exec, session_ctx.task_ctx()).await.unwrap();

    // (b) the access plan was honored: only row group 1's rows survive.
    let total: usize = read.iter().map(|b| b.num_rows()).sum();
    assert_eq!(
        total, 5,
        "access plan should have skipped the first row group"
    );

    let values: Vec<i64> = read
        .iter()
        .flat_map(|b| {
            let arr = b.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
            (0..arr.len()).map(|i| arr.value(i)).collect::<Vec<_>>()
        })
        .collect();
    assert_eq!(values, vec![5, 6, 7, 8, 9]);
}

#[derive(Debug)]
struct InMemoryParquetFileReaderFactory(Arc<dyn ObjectStore>);

Expand All @@ -119,12 +197,8 @@ impl ParquetFileReaderFactory for InMemoryParquetFileReaderFactory {
metrics: &ExecutionPlanMetricsSet,
) -> Result<Box<dyn AsyncFileReader + Send>> {
let metadata = partitioned_file
.extensions
.as_ref()
.extension::<String>()
.expect("has user defined metadata");
let metadata = metadata
.downcast_ref::<String>()
.expect("has string metadata");

assert_eq!(EXPECTED_USER_DEFINED_METADATA, &metadata[..]);

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/parquet/external_access_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ impl TestFull {

// add the access plan, if any, as an extension
if let Some(access_plan) = access_plan {
partitioned_file = partitioned_file.with_extensions(Arc::new(access_plan));
partitioned_file = partitioned_file.with_extension(access_plan);
}

// Create a DataSourceExec to read the file
Expand Down
Loading
Loading