Skip to content

Commit 12cf674

Browse files
timsaucerclaude
andcommitted
refactor: opt-in UDTF session injection via with_session flag
Replaces signature sniffing with an explicit ``with_session=True`` kwarg on ``TableFunction`` / ``udtf``. Avoids name-based detection footguns (positional-only ``session`` params, accidental ``**kwargs`` opt-in, shadowing by unrelated params) and makes author intent visible at registration. Also documents the feature in the UDTF user guide. Rust field renamed ``accepts_session`` -> ``inject_session_on_call`` to match the Python-side opt-in semantics. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 67715c6 commit 12cf674

4 files changed

Lines changed: 111 additions & 58 deletions

File tree

crates/core/src/udtf.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ use crate::table::PyTable;
3939
#[derive(Debug, Clone)]
4040
pub(crate) struct PythonTableFunctionCallable {
4141
pub(crate) callable: Arc<Py<PyAny>>,
42-
/// Whether the callable's signature accepts a ``session`` keyword
43-
/// argument (or ``**kwargs``). When true the calling
44-
/// :class:`SessionContext` is threaded through on each invocation.
45-
pub(crate) accepts_session: bool,
42+
/// When true, the calling :class:`SessionContext` is passed to the
43+
/// callable as a ``session`` keyword argument on every invocation.
44+
/// Opt-in at registration time via ``with_session=True`` on the
45+
/// Python wrapper.
46+
pub(crate) inject_session_on_call: bool,
4647
}
4748

4849
/// Represents a user defined table function
@@ -62,12 +63,12 @@ pub(crate) enum PyTableFunctionInner {
6263
#[pymethods]
6364
impl PyTableFunction {
6465
#[new]
65-
#[pyo3(signature=(name, func, session, accepts_session=false))]
66+
#[pyo3(signature=(name, func, session, inject_session_on_call=false))]
6667
pub fn new(
6768
name: &str,
6869
func: Bound<'_, PyAny>,
6970
session: Option<Bound<PyAny>>,
70-
accepts_session: bool,
71+
inject_session_on_call: bool,
7172
) -> PyResult<Self> {
7273
let inner = if func.hasattr("__datafusion_table_function__")? {
7374
let py = func.py();
@@ -95,7 +96,7 @@ impl PyTableFunction {
9596
} else {
9697
PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable {
9798
callable: Arc::new(func.unbind()),
98-
accepts_session,
99+
inject_session_on_call,
99100
})
100101
};
101102

@@ -162,7 +163,7 @@ fn call_python_table_function(
162163
func: &PythonTableFunctionCallable,
163164
args: TableFunctionArgs,
164165
) -> DataFusionResult<Arc<dyn TableProvider>> {
165-
let py_session = if func.accepts_session {
166+
let py_session = if func.inject_session_on_call {
166167
Some(py_session_from_session(args.session())?)
167168
} else {
168169
None

docs/source/user-guide/common-operations/udf-and-udfa.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,39 @@ that you wish to expose via PyO3, you need to expose it as a ``PyCapsule``.
431431
PyCapsule::new(py, provider, Some(name))
432432
}
433433
}
434+
435+
Accessing the Calling Session
436+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
437+
438+
Pure-Python UDTFs can opt into receiving the calling
439+
:py:class:`~datafusion.SessionContext` by registering with
440+
``with_session=True``. The context is passed as a ``session`` keyword
441+
argument on every invocation. Use it to look up registered tables,
442+
UDFs, or session configuration from inside the callback.
443+
444+
.. code-block:: python
445+
446+
from datafusion import SessionContext, Table, udtf
447+
from datafusion.context import TableProviderExportable
448+
import pyarrow as pa
449+
import pyarrow.dataset as ds
450+
451+
@udtf("list_tables", with_session=True)
452+
def list_tables(*, session: SessionContext) -> TableProviderExportable:
453+
names = sorted(session.catalog().schema().names())
454+
batch = pa.RecordBatch.from_pydict({"name": names})
455+
return Table(ds.dataset([batch]))
456+
457+
ctx = SessionContext()
458+
ctx.register_batch("t1", pa.RecordBatch.from_pydict({"x": [1]}))
459+
ctx.register_udtf(list_tables)
460+
ctx.sql("SELECT * FROM list_tables()").show()
461+
462+
Without ``with_session=True``, the callback receives only the positional
463+
expression arguments. The flag is opt-in so existing UDTFs keep working
464+
unchanged.
465+
466+
The injected ``session`` is a fresh :py:class:`~datafusion.SessionContext`
467+
wrapper backed by the same underlying state as the caller, so registries
468+
(tables, UDFs, catalogs) are visible. Mutations made through it affect
469+
the live session.

python/datafusion/user_defined.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,29 +1054,6 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
10541054
)
10551055

10561056

1057-
def _callable_accepts_session_kwarg(func: object) -> bool:
1058-
"""Return True if ``func`` accepts a ``session`` keyword argument.
1059-
1060-
Used to opt a Python UDTF callback into receiving the calling
1061-
:class:`SessionContext` at invocation time. ``**kwargs`` callables
1062-
are treated as accepting it; built-ins and objects without an
1063-
introspectable signature fall back to ``False``.
1064-
"""
1065-
import inspect # noqa: PLC0415
1066-
1067-
try:
1068-
signature = inspect.signature(func)
1069-
except (TypeError, ValueError):
1070-
return False
1071-
1072-
for parameter in signature.parameters.values():
1073-
if parameter.name == "session":
1074-
return True
1075-
if parameter.kind is inspect.Parameter.VAR_KEYWORD:
1076-
return True
1077-
return False
1078-
1079-
10801057
def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]:
10811058
"""Adapt the raw internal session pyo3 object back to a Python wrapper.
10821059
@@ -1103,23 +1080,27 @@ class TableFunction:
11031080
"""
11041081

11051082
def __init__(
1106-
self, name: str, func: Callable[[], any], ctx: SessionContext | None = None
1083+
self,
1084+
name: str,
1085+
func: Callable[..., Any],
1086+
ctx: SessionContext | None = None,
1087+
*,
1088+
with_session: bool = False,
11071089
) -> None:
11081090
"""Instantiate a user-defined table function (UDTF).
11091091
1110-
If ``func``'s signature accepts a ``session`` keyword (or
1111-
``**kwargs``), the calling :class:`SessionContext` is threaded
1112-
through to it on each invocation. Use it inside the body to look
1113-
up registered tables, UDFs, or session configuration. Callables
1114-
whose signatures do not declare ``session`` are invoked with the
1115-
positional expression arguments only.
1092+
Set ``with_session=True`` to have the calling
1093+
:class:`SessionContext` passed as a ``session`` keyword argument
1094+
on each invocation. Use it inside the callback to look up
1095+
registered tables, UDFs, or session configuration. When
1096+
``with_session`` is ``False`` (the default), ``func`` is invoked
1097+
with the positional expression arguments only.
11161098
11171099
See :py:func:`udtf` for a convenience function and argument
11181100
descriptions.
11191101
"""
1120-
accepts_session = _callable_accepts_session_kwarg(func)
1121-
registered = _wrap_session_kwarg_for_udtf(func) if accepts_session else func
1122-
self._udtf = df_internal.TableFunction(name, registered, ctx, accepts_session)
1102+
registered = _wrap_session_kwarg_for_udtf(func) if with_session else func
1103+
self._udtf = df_internal.TableFunction(name, registered, ctx, with_session)
11231104

11241105
def __call__(self, *args: Expr) -> Any:
11251106
"""Execute the UDTF and return a table provider."""
@@ -1130,47 +1111,66 @@ def __call__(self, *args: Expr) -> Any:
11301111
@staticmethod
11311112
def udtf(
11321113
name: str,
1114+
*,
1115+
with_session: bool = False,
11331116
) -> Callable[..., Any]: ...
11341117

11351118
@overload
11361119
@staticmethod
11371120
def udtf(
1138-
func: Callable[[], Any],
1121+
func: Callable[..., Any],
11391122
name: str,
1123+
*,
1124+
with_session: bool = False,
11401125
) -> TableFunction: ...
11411126

11421127
@staticmethod
1143-
def udtf(*args: Any, **kwargs: Any):
1144-
"""Create a new User-Defined Table Function (UDTF)."""
1128+
def udtf(*args: Any, with_session: bool = False, **kwargs: Any):
1129+
"""Create a new User-Defined Table Function (UDTF).
1130+
1131+
Pass ``with_session=True`` to have the calling
1132+
:class:`SessionContext` injected as a ``session`` keyword
1133+
argument on each invocation.
1134+
"""
11451135
if args and callable(args[0]):
11461136
# Case 1: Used as a function, require the first parameter to be callable
1147-
return TableFunction._create_table_udf(*args, **kwargs)
1137+
return TableFunction._create_table_udf(
1138+
*args, with_session=with_session, **kwargs
1139+
)
11481140
if args and hasattr(args[0], "__datafusion_table_function__"):
11491141
# Case 2: We have a datafusion FFI provided function
11501142
return TableFunction(args[1], args[0])
11511143
# Case 3: Used as a decorator with parameters
1152-
return TableFunction._create_table_udf_decorator(*args, **kwargs)
1144+
return TableFunction._create_table_udf_decorator(
1145+
*args, with_session=with_session, **kwargs
1146+
)
11531147

11541148
@staticmethod
11551149
def _create_table_udf(
11561150
func: Callable[..., Any],
11571151
name: str,
1152+
*,
1153+
with_session: bool = False,
11581154
) -> TableFunction:
11591155
"""Create a TableFunction instance from function arguments."""
11601156
if not callable(func):
11611157
msg = "`func` must be callable."
11621158
raise TypeError(msg)
11631159

1164-
return TableFunction(name, func)
1160+
return TableFunction(name, func, with_session=with_session)
11651161

11661162
@staticmethod
11671163
def _create_table_udf_decorator(
11681164
name: str | None = None,
1169-
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
1170-
"""Create a decorator for a WindowUDF."""
1171-
1172-
def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
1173-
return TableFunction._create_table_udf(func, name)
1165+
*,
1166+
with_session: bool = False,
1167+
) -> Callable[[Callable[..., Any]], TableFunction]:
1168+
"""Create a decorator for a TableFunction."""
1169+
1170+
def decorator(func: Callable[..., Any]) -> TableFunction:
1171+
return TableFunction._create_table_udf(
1172+
func, name, with_session=with_session
1173+
)
11741174

11751175
return decorator
11761176

python/tests/test_udtf.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable:
137137

138138

139139
def test_python_table_function_receives_session() -> None:
140-
"""A UDTF whose signature declares ``session`` gets the calling ctx."""
140+
"""A UDTF registered ``with_session=True`` gets the calling ctx."""
141141
ctx = SessionContext()
142142
captured: list[SessionContext] = []
143143

144-
@udtf("session_aware_func")
144+
@udtf("session_aware_func", with_session=True)
145145
def session_aware_func(*, session: SessionContext) -> TableProviderExportable:
146146
captured.append(session)
147147
batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
@@ -165,7 +165,7 @@ def test_python_table_function_session_used_for_metadata() -> None:
165165

166166
seen_tables: list[set[str]] = []
167167

168-
@udtf("table_inventory")
168+
@udtf("table_inventory", with_session=True)
169169
def table_inventory(*, session: SessionContext) -> TableProviderExportable:
170170
# Stash the visible tables to verify the session wired through.
171171
seen_tables.append(session.catalog().schema().names())
@@ -179,8 +179,8 @@ def table_inventory(*, session: SessionContext) -> TableProviderExportable:
179179
assert result[0].column(0).to_pylist() == ["base_tbl"]
180180

181181

182-
def test_python_table_function_class_callable_session_kwarg() -> None:
183-
"""Class-based UDTFs whose __call__ accepts ``session`` get it too."""
182+
def test_python_table_function_class_callable_with_session() -> None:
183+
"""Class-based UDTFs opt in via ``with_session=True``."""
184184
ctx = SessionContext()
185185
captured: list[SessionContext] = []
186186

@@ -193,9 +193,25 @@ def __call__(
193193
batch = pa.RecordBatch.from_pydict({"a": list(range(count))})
194194
return Table(ds.dataset([batch]))
195195

196-
ctx.register_udtf(udtf(SessionAware(), "session_class_func"))
196+
ctx.register_udtf(udtf(SessionAware(), "session_class_func", with_session=True))
197197
result = ctx.sql("SELECT * FROM session_class_func(3)").collect()
198198

199199
assert len(captured) == 1
200200
assert isinstance(captured[0], SessionContext)
201201
assert result[0].column(0).to_pylist() == [0, 1, 2]
202+
203+
204+
def test_python_table_function_without_session_flag_no_injection() -> None:
205+
"""Default registration (no ``with_session``) calls func positionally."""
206+
ctx = SessionContext()
207+
208+
@udtf("plain_func")
209+
def plain_func(n: Expr) -> TableProviderExportable:
210+
count = n.to_variant().value_i64()
211+
batch = pa.RecordBatch.from_pydict({"a": list(range(count))})
212+
return Table(ds.dataset([batch]))
213+
214+
ctx.register_udtf(plain_func)
215+
result = ctx.sql("SELECT * FROM plain_func(4)").collect()
216+
217+
assert result[0].column(0).to_pylist() == [0, 1, 2, 3]

0 commit comments

Comments
 (0)