Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pyarrow as pa
import pytest
from datafusion import SessionContext
from datafusion_ffi_example import MyTableProviderFactory


def test_table_provider_factory_ffi() -> None:
    """Register an FFI table provider factory and query a table created via DDL."""
    session = SessionContext()
    factory = MyTableProviderFactory()
    session.register_table_factory("MY_FORMAT", factory)

    # Create a new external table backed by the registered factory.
    # NOTE(review): the factory is registered as "MY_FORMAT" while the DDL says
    # `my_format`; presumably the format lookup is case-normalised -- confirm.
    create_statement = """
CREATE EXTERNAL TABLE
foo
STORED AS my_format
LOCATION '';
"""
    session.sql(create_statement)

    # Query the pre-populated table and check the shape of what comes back.
    batches = session.sql("SELECT * FROM foo;").collect()
    assert len(batches) == 2
    assert batches[0].num_columns == 2
3 changes: 3 additions & 0 deletions examples/datafusion-ffi-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP
use crate::scalar_udf::IsNullUDF;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
use crate::table_provider_factory::MyTableProviderFactory;
use crate::window_udf::MyRankUDF;

pub(crate) mod aggregate_udf;
pub(crate) mod catalog_provider;
pub(crate) mod scalar_udf;
pub(crate) mod table_function;
pub(crate) mod table_provider;
pub(crate) mod table_provider_factory;
pub(crate) mod utils;
pub(crate) mod window_udf;

Expand All @@ -37,6 +39,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init();

m.add_class::<MyTableProvider>()?;
m.add_class::<MyTableProviderFactory>()?;
m.add_class::<MyTableFunction>()?;
m.add_class::<MyCatalogProvider>()?;
m.add_class::<MyCatalogProviderList>()?;
Expand Down
87 changes: 87 additions & 0 deletions examples/datafusion-ffi-example/src/table_provider_factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider, TableProviderFactory};
use datafusion_common::error::Result as DataFusionResult;
use datafusion_expr::CreateExternalTable;
use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
use pyo3::types::PyCapsule;
use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};

use crate::catalog_provider;
use crate::utils::ffi_logical_codec_from_pycapsule;

/// Stateless table provider factory used by the FFI example.
///
/// The type carries no fields; it only exists so that a `TableProviderFactory`
/// implementation can be attached to it.
#[derive(Debug)]
pub(crate) struct ExampleTableProviderFactory;

impl ExampleTableProviderFactory {
    /// Creates the (field-less) factory value.
    fn new() -> Self {
        Self
    }
}

#[async_trait]
impl TableProviderFactory for ExampleTableProviderFactory {
    /// Returns the table produced by `catalog_provider::my_table`.
    ///
    /// Both the session and the `CREATE EXTERNAL TABLE` command are ignored, so
    /// every call yields whatever that helper builds regardless of the DDL.
    async fn create(
        &self,
        _state: &dyn Session,
        _cmd: &CreateExternalTable,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        let provider: Arc<dyn TableProvider> = catalog_provider::my_table();
        Ok(provider)
    }
}

// Python-facing wrapper around `ExampleTableProviderFactory`.
//
// Exposed to Python as `datafusion_ffi_example.MyTableProviderFactory` and
// declared subclassable. Plain `//` comments are used here on purpose so the
// Python-visible class docstring stays as it was.
#[pyclass(name = "MyTableProviderFactory", module = "datafusion_ffi_example", subclass)]
#[derive(Debug)]
pub struct MyTableProviderFactory {
    // Shared handle to the Rust factory that gets exported through the FFI capsule.
    inner: Arc<ExampleTableProviderFactory>,
}

impl Default for MyTableProviderFactory {
    /// Builds a wrapper around a freshly constructed `ExampleTableProviderFactory`.
    fn default() -> Self {
        Self {
            inner: Arc::new(ExampleTableProviderFactory::new()),
        }
    }
}

#[pymethods]
impl MyTableProviderFactory {
    // Python constructor; delegates to the `Default` implementation.
    #[new]
    pub fn new() -> Self {
        Default::default()
    }

    // Exports the wrapped factory as a PyCapsule named
    // `datafusion_table_provider_factory`.
    //
    // `codec` is converted with `ffi_logical_codec_from_pycapsule`; a failure of
    // that conversion or of the capsule creation is returned to the caller.
    pub fn __datafusion_table_provider_factory__<'py>(
        &self,
        py: Python<'py>,
        codec: Bound<PyAny>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        let logical_codec = ffi_logical_codec_from_pycapsule(codec)?;

        // Share the inner factory rather than cloning it; the FFI wrapper takes a
        // trait object.
        let shared = Arc::clone(&self.inner) as Arc<dyn TableProviderFactory + Send>;
        let ffi_factory =
            FFI_TableProviderFactory::new_with_ffi_codec(shared, None, logical_codec);

        PyCapsule::new(
            py,
            ffi_factory,
            Some(cr"datafusion_table_provider_factory".into()),
        )
    }
}
9 changes: 9 additions & 0 deletions python/datafusion/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,15 @@ def kind(self) -> str:
return self._inner.kind


class TableProviderFactoryExportable(Protocol):
    """Structural type for objects exposing a table provider factory PyCapsule.

    Any object that defines ``__datafusion_table_provider_factory__`` satisfies
    this protocol.

    https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProviderFactory.html
    """

    # NOTE(review): the Rust caller added in this change passes a logical codec
    # capsule as the single positional argument, so the name `session` may be
    # stale -- confirm before relying on it as a keyword.
    def __datafusion_table_provider_factory__(self, session: Any) -> object:
        ...


class CatalogProviderList(ABC):
"""Abstract class for defining a Python based Catalog Provider List."""

Expand Down
15 changes: 15 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CatalogProviderExportable,
CatalogProviderList,
CatalogProviderListExportable,
TableProviderFactoryExportable,
)
from datafusion.dataframe import DataFrame
from datafusion.expr import sort_list_to_raw_sort_list
Expand Down Expand Up @@ -830,6 +831,20 @@ def deregister_table(self, name: str) -> None:
"""Remove a table from the session."""
self.ctx.deregister_table(name)

def register_table_factory(
    self, format: str, factory: TableProviderFactoryExportable
) -> None:
    """Register a :py:class:`~datafusion.catalog.TableProviderFactoryExportable`.

    The registered factory can be referenced from SQL DDL statements executed
    against this context.

    Args:
        format: The value to be used in the ``STORED AS ${format}`` clause.
        factory: An object implementing
            ``__datafusion_table_provider_factory__``, which returns the
            PyCapsule wrapping the table provider factory.
    """
    # Registration itself is delegated to the underlying (Rust) context.
    self.ctx.register_table_factory(format, factory)

def catalog_names(self) -> set[str]:
"""Returns the list of catalogs in this context."""
return self.ctx.catalog_names()
Expand Down
31 changes: 30 additions & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::{CatalogProvider, CatalogProviderList};
use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
use datafusion::common::{ScalarValue, TableReference, exec_err};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
Expand All @@ -51,6 +51,7 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
Expand Down Expand Up @@ -659,6 +660,34 @@ impl PySessionContext {
Ok(())
}

// Registers a table provider factory, exported by a Python object through the
// DataFusion FFI, under the `format` key of this context's session state.
//
// `factory` must provide `__datafusion_table_provider_factory__`. That method is
// called with a capsule built from this context's logical codec and has to
// return a PyCapsule named `datafusion_table_provider_factory`. Failures of the
// attribute lookup, the call, the cast or the capsule validation are returned
// as errors.
pub fn register_table_factory(
    &self,
    format: &str,
    factory: Bound<'_, PyAny>,
) -> PyDataFusionResult<()> {
    let codec_capsule =
        create_logical_extension_capsule(factory.py(), self.logical_codec.as_ref())?;

    // Ask the Python object for its FFI capsule and make sure it is the right kind.
    let exported = factory
        .getattr("__datafusion_table_provider_factory__")?
        .call1((codec_capsule,))?;
    let capsule = exported.cast::<PyCapsule>().map_err(py_datafusion_err)?;
    validate_pycapsule(capsule, "datafusion_table_provider_factory")?;

    let ffi_ptr: NonNull<FFI_TableProviderFactory> = capsule
        .pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
        .cast();
    // SAFETY(review): relies on the capsule, whose name was checked above,
    // actually holding an `FFI_TableProviderFactory` that stays alive while
    // `capsule` is borrowed -- confirm against the exporting side.
    let ffi_factory = unsafe { ffi_ptr.as_ref() };
    let provider_factory: Arc<dyn TableProviderFactory> = ffi_factory.into();

    // Take the write lock on the session state only for the registration itself.
    let state = self.ctx.state_ref();
    let mut state_guard = state.write();
    state_guard
        .table_factories_mut()
        .insert(format.to_owned(), provider_factory);

    Ok(())
}

pub fn register_catalog_provider_list(
&self,
mut provider: Bound<PyAny>,
Expand Down