diff --git a/examples/datafusion-ffi-example/python/tests/_test_table_provider_factory.py b/examples/datafusion-ffi-example/python/tests/_test_table_provider_factory.py new file mode 100644 index 000000000..b1e94ec73 --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_table_provider_factory.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datafusion import SessionContext +from datafusion_ffi_example import MyTableProviderFactory + + +def test_table_provider_factory_ffi() -> None: + ctx = SessionContext() + table = MyTableProviderFactory() + + ctx.register_table_factory("MY_FORMAT", table) + + # Create a new external table + ctx.sql(""" + CREATE EXTERNAL TABLE + foo + STORED AS my_format + LOCATION ''; + """).collect() + + # Query the pre-populated table + result = ctx.sql("SELECT * FROM foo;").collect() + assert len(result) == 2 + assert result[0].num_columns == 2 diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 23f2001a2..405cc0a46 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -22,6 +22,7 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; +use crate::table_provider_factory::MyTableProviderFactory; use crate::window_udf::MyRankUDF; pub(crate) mod aggregate_udf; @@ -29,6 +30,7 @@ pub(crate) mod catalog_provider; pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; +pub(crate) mod table_provider_factory; pub(crate) mod utils; pub(crate) mod window_udf; @@ -37,6 +39,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/examples/datafusion-ffi-example/src/table_provider_factory.rs b/examples/datafusion-ffi-example/src/table_provider_factory.rs new file mode 100644 index 000000000..1e139e919 --- /dev/null +++ b/examples/datafusion-ffi-example/src/table_provider_factory.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion_catalog::{Session, TableProvider, TableProviderFactory}; +use datafusion_common::error::Result as DataFusionResult; +use datafusion_expr::CreateExternalTable; +use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory; +use pyo3::types::PyCapsule; +use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods}; + +use crate::catalog_provider; +use crate::utils::ffi_logical_codec_from_pycapsule; + +#[derive(Debug)] +pub(crate) struct ExampleTableProviderFactory {} + +impl ExampleTableProviderFactory { + fn new() -> Self { + Self {} + } +} + +#[async_trait] +impl TableProviderFactory for ExampleTableProviderFactory { + async fn create( + &self, + _state: &dyn Session, + _cmd: &CreateExternalTable, + ) -> DataFusionResult> { + Ok(catalog_provider::my_table()) + } +} + +#[pyclass( + name = "MyTableProviderFactory", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug)] +pub struct MyTableProviderFactory { + inner: Arc, +} + +impl Default for MyTableProviderFactory { + fn default() -> Self { + let inner = Arc::new(ExampleTableProviderFactory::new()); + Self { inner } + } +} + +#[pymethods] +impl MyTableProviderFactory { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn __datafusion_table_provider_factory__<'py>( + &self, + py: Python<'py>, + codec: Bound, + ) -> PyResult> { + let name = cr"datafusion_table_provider_factory".into(); + let codec = ffi_logical_codec_from_pycapsule(codec)?; + let factory = Arc::clone(&self.inner) as Arc; + let factory = FFI_TableProviderFactory::new_with_ffi_codec(factory, None, codec); + + PyCapsule::new(py, factory, Some(name)) + } +} diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index bc43cf349..03c0ddc68 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -29,6 +29,7 @@ from datafusion import DataFrame, SessionContext from datafusion.context import TableProviderExportable + from datafusion.expr import CreateExternalTable try: from warnings import deprecated # Python 3.13+ @@ -243,6 +244,24 @@ def kind(self) -> str: return self._inner.kind +class TableProviderFactory(ABC): + """Abstract class for defining a Python based Table Provider Factory.""" + + @abstractmethod + def create(self, cmd: CreateExternalTable) -> Table: + """Create a table using the :class:`CreateExternalTable`.""" + ... + + +class TableProviderFactoryExportable(Protocol): + """Type hint for object that has __datafusion_table_provider_factory__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProviderFactory.html + """ + + def __datafusion_table_provider_factory__(self, session: Any) -> object: ... + + class CatalogProviderList(ABC): """Abstract class for defining a Python based Catalog Provider List.""" diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 0d8259774..ba9290a58 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -37,6 +37,8 @@ CatalogProviderExportable, CatalogProviderList, CatalogProviderListExportable, + TableProviderFactory, + TableProviderFactoryExportable, ) from datafusion.dataframe import DataFrame from datafusion.expr import sort_list_to_raw_sort_list @@ -830,6 +832,22 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def register_table_factory( + self, + format: str, + factory: TableProviderFactory | TableProviderFactoryExportable, + ) -> None: + """Register a :py:class:`~datafusion.TableProviderFactoryExportable`. + + The registered factory can be referenced from SQL DDL statements executed + against this context. + + Args: + format: The value to be used in `STORED AS ${format}` clause. + factory: A PyCapsule that implements :class:`TableProviderFactoryExportable` + """ + self.ctx.register_table_factory(format, factory) + def catalog_names(self) -> set[str]: """Returns the list of catalogs in this context.""" return self.ctx.catalog_names() diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 9310da506..c89da36bf 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -120,6 +120,12 @@ def register_catalog( self.catalogs[name] = catalog +class CustomTableProviderFactory(dfn.catalog.TableProviderFactory): + def create(self, cmd: dfn.expr.CreateExternalTable): + assert cmd.name() == "test_table_factory" + return create_dataset() + + def test_python_catalog_provider_list(ctx: SessionContext): ctx.register_catalog_provider_list(CustomCatalogProviderList()) @@ -314,3 +320,24 @@ def my_table_function_udtf() -> Table: assert len(result[0]) == 1 assert len(result[0][0]) == 1 assert result[0][0][0].as_py() == 3 + + +def test_register_python_table_provider_factory(ctx: SessionContext): + ctx.register_table_factory("CUSTOM_FACTORY", CustomTableProviderFactory()) + + ctx.sql(""" + CREATE EXTERNAL TABLE test_table_factory + STORED AS CUSTOM_FACTORY + LOCATION foo; + """).collect() + + result = ctx.sql("SELECT * FROM test_table_factory;").collect() + + expect = [ + pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + ] + + assert result == expect diff --git a/src/context.rs b/src/context.rs index 2eaf5a737..a0436c143 100644 --- a/src/context.rs +++ b/src/context.rs @@ -27,7 +27,7 @@ use arrow::pyarrow::FromPyArrow; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::{CatalogProvider, CatalogProviderList}; +use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory}; use datafusion::common::{ScalarValue, TableReference, exec_err}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::parquet::ParquetFormat; @@ -51,6 +51,7 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList; use datafusion_ffi::execution::FFI_TaskContextProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory; use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; @@ -77,7 +78,7 @@ use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::sql::util::replace_placeholders_with_strings; use crate::store::StorageContexts; -use crate::table::PyTable; +use crate::table::{PyTable, RustWrappedPyTableProviderFactory}; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; @@ -659,6 +660,43 @@ impl PySessionContext { Ok(()) } + pub fn register_table_factory( + &self, + format: &str, + mut factory: Bound<'_, PyAny>, + ) -> PyDataFusionResult<()> { + if factory.hasattr("__datafusion_table_provider_factory__")? { + let py = factory.py(); + let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?; + factory = factory + .getattr("__datafusion_table_provider_factory__")? + .call1((codec_capsule,))?; + } + + let factory: Arc = + if let Ok(capsule) = factory.cast::().map_err(py_datafusion_err) { + validate_pycapsule(capsule, "datafusion_table_provider_factory")?; + + let data: NonNull = capsule + .pointer_checked(Some(c_str!("datafusion_table_provider_factory")))? + .cast(); + let factory = unsafe { data.as_ref() }; + factory.into() + } else { + Arc::new(RustWrappedPyTableProviderFactory::new( + factory.into(), + self.logical_codec.clone(), + )) + }; + + let st = self.ctx.state_ref(); + let mut lock = st.write(); + lock.table_factories_mut() + .insert(format.to_owned(), factory); + + Ok(()) + } + pub fn register_catalog_provider_list( &self, mut provider: Bound, diff --git a/src/table.rs b/src/table.rs index b9f30af9c..5527fbe9a 100644 --- a/src/table.rs +++ b/src/table.rs @@ -21,19 +21,23 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::pyarrow::ToPyArrow; use async_trait::async_trait; -use datafusion::catalog::Session; +use datafusion::catalog::{Session, TableProviderFactory}; use datafusion::common::Column; use datafusion::datasource::{TableProvider, TableType}; -use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown}; +use datafusion::logical_expr::{ + CreateExternalTable, Expr, LogicalPlanBuilder, TableProviderFilterPushDown, +}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::DataFrame; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use pyo3::IntoPyObjectExt; use pyo3::prelude::*; use crate::context::PySessionContext; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; -use crate::utils::table_provider_from_pycapsule; +use crate::expr::create_external_table::PyCreateExternalTable; +use crate::{errors, utils}; /// This struct is used as a common method for all TableProviders, /// whether they refer to an FFI provider, an internally known @@ -91,7 +95,7 @@ impl PyTable { Some(session) => session, None => PySessionContext::global_ctx()?.into_bound_py_any(obj.py())?, }; - table_provider_from_pycapsule(obj.clone(), session)? + utils::table_provider_from_pycapsule(obj.clone(), session)? } { Ok(PyTable::from(provider)) } else { @@ -206,3 +210,51 @@ impl TableProvider for TempViewTable { Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) } } + +#[derive(Debug)] +pub(crate) struct RustWrappedPyTableProviderFactory { + pub(crate) table_provider_factory: Py, + pub(crate) codec: Arc, +} + +impl RustWrappedPyTableProviderFactory { + pub fn new(table_provider_factory: Py, codec: Arc) -> Self { + Self { + table_provider_factory, + codec, + } + } + + fn create_inner( + &self, + cmd: CreateExternalTable, + codec: Bound, + ) -> PyResult> { + Python::attach(|py| { + let provider = self.table_provider_factory.bind(py); + let cmd = PyCreateExternalTable::from(cmd); + + provider + .call_method1("create", (cmd,)) + .and_then(|t| PyTable::new(t, Some(codec))) + .map(|t| t.table()) + }) + } +} + +#[async_trait] +impl TableProviderFactory for RustWrappedPyTableProviderFactory { + async fn create( + &self, + _: &dyn Session, + cmd: &CreateExternalTable, + ) -> datafusion::common::Result> { + Python::attach(|py| { + let codec = utils::create_logical_extension_capsule(py, self.codec.as_ref()) + .map_err(errors::to_datafusion_err)?; + + self.create_inner(cmd.clone(), codec.into_any()) + .map_err(errors::to_datafusion_err) + }) + } +}