Skip to content

Commit 1b3d543

Browse files
timsaucer and claude committed
feat: expose array_filter higher-order function
Add array_filter, the remaining lambda-based higher-order array function in DataFusion (alongside the already-exposed array_transform and array_any_match). Includes the list_filter alias matching upstream, tests, and documentation in the expressions guide and skill.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e0a46a0 commit 1b3d543

5 files changed

Lines changed: 76 additions & 4 deletions

File tree

crates/core/src/functions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ fn array_any_match(array: PyExpr, predicate: PyExpr) -> PyExpr {
190190
datafusion::functions_nested::expr_fn::array_any_match(array.into(), predicate.into()).into()
191191
}
192192

193+
/// Higher-order function: keep the elements of `array` for which `predicate`
194+
/// (a lambda returning a boolean) is true, returning a new filtered array.
195+
#[pyfunction]
196+
fn array_filter(array: PyExpr, predicate: PyExpr) -> PyExpr {
197+
datafusion::functions_nested::expr_fn::array_filter(array.into(), predicate.into()).into()
198+
}
199+
193200
/// Computes a binary hash of the given data. type is the algorithm to use.
194201
/// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3.
195202
// #[pyfunction(value, method)]
@@ -1118,6 +1125,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
11181125
m.add_wrapped(wrap_pyfunction!(lambda_var))?;
11191126
m.add_wrapped(wrap_pyfunction!(array_transform))?;
11201127
m.add_wrapped(wrap_pyfunction!(array_any_match))?;
1128+
m.add_wrapped(wrap_pyfunction!(array_filter))?;
11211129

11221130
// Array Functions
11231131
m.add_wrapped(wrap_pyfunction!(array_append))?;

docs/source/user-guide/common-operations/expressions.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,10 @@ Higher-order functions and lambdas
150150

151151
Some array functions are *higher-order*: they take a lambda that runs once per
152152
element. :py:func:`~datafusion.functions.array_transform` maps a lambda over
153-
every element, and :py:func:`~datafusion.functions.array_any_match` returns
154-
whether any element satisfies a predicate lambda.
153+
every element, :py:func:`~datafusion.functions.array_filter` keeps the elements
154+
for which a predicate lambda is true, and
155+
:py:func:`~datafusion.functions.array_any_match` returns whether any element
156+
satisfies a predicate lambda.
155157

156158
The simplest way to supply a lambda is a Python ``lambda``. Its parameter names
157159
become the lambda parameters, and its return value becomes the body.
@@ -164,6 +166,7 @@ become the lambda parameters, and its return value becomes the body.
164166
ctx = SessionContext()
165167
df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5]]})
166168
df.select(f.array_transform(col("a"), lambda v: v * 2).alias("doubled"))
169+
df.select(f.array_filter(col("a"), lambda v: v > 2).alias("big_only"))
167170
df.select(f.array_any_match(col("a"), lambda v: v > 3).alias("has_big"))
168171
169172
If you need explicit control over parameter names, build the lambda with

python/datafusion/functions.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
"array_empty",
8686
"array_except",
8787
"array_extract",
88+
"array_filter",
8889
"array_has",
8990
"array_has_all",
9091
"array_has_any",
@@ -217,6 +218,7 @@
217218
"list_empty",
218219
"list_except",
219220
"list_extract",
221+
"list_filter",
220222
"list_has",
221223
"list_has_all",
222224
"list_has_any",
@@ -632,6 +634,47 @@ def list_any_match(array: Expr, predicate: Expr | Callable[..., Any]) -> Expr:
632634
return array_any_match(array, predicate)
633635

634636

637+
def array_filter(array: Expr, predicate: Expr | Callable[..., Any]) -> Expr:
638+
"""Keep the elements of ``array`` for which ``predicate`` is ``True``.
639+
640+
``predicate`` may be a Python callable, converted to a lambda
641+
automatically, or an explicit lambda built with :py:func:`lambda_`. It must
642+
return a boolean expression. The result is a new array containing only the
643+
matching elements.
644+
645+
Examples:
646+
Using a Python callable:
647+
648+
>>> ctx = dfn.SessionContext()
649+
>>> df = ctx.from_pydict({"a": [[1, 2, 3, 4, 5]]})
650+
>>> df.select(
651+
... F.array_filter(col("a"), lambda v: v > 2).alias("f")
652+
... ).collect_column("f")[0].as_py()
653+
[3, 4, 5]
654+
655+
Using an explicit lambda built with :py:func:`lambda_`:
656+
657+
>>> predicate = F.lambda_(["v"], F.lambda_var("v") > lit(2))
658+
>>> df.select(
659+
... F.array_filter(col("a"), predicate).alias("f")
660+
... ).collect_column("f")[0].as_py()
661+
[3, 4, 5]
662+
663+
See Also:
664+
:py:func:`array_transform`, :py:func:`array_any_match`, :py:func:`lambda_`.
665+
"""
666+
return Expr(f.array_filter(array.expr, _to_lambda(predicate).expr))
667+
668+
669+
def list_filter(array: Expr, predicate: Expr | Callable[..., Any]) -> Expr:
670+
"""Keep the elements of a list for which a predicate is ``True``.
671+
672+
See Also:
673+
This is an alias for :py:func:`array_filter`.
674+
"""
675+
return array_filter(array, predicate)
676+
677+
635678
def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr:
636679
"""Returns whether the argument is contained within the list ``values``.
637680

python/tests/test_lambda.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,23 @@ def _column(df, expr, name):
7878
[False, True],
7979
id="list_any_match_alias",
8080
),
81+
pytest.param(
82+
lambda: f.array_filter(col("a"), lambda v: v > 2),
83+
[[3], [4, 5]],
84+
id="array_filter_callable",
85+
),
86+
pytest.param(
87+
lambda: f.array_filter(
88+
col("a"), f.lambda_(["v"], f.lambda_var("v") > lit(2))
89+
),
90+
[[3], [4, 5]],
91+
id="array_filter_explicit_lambda",
92+
),
93+
pytest.param(
94+
lambda: f.list_filter(col("a"), lambda v: v > 2),
95+
[[3], [4, 5]],
96+
id="list_filter_alias",
97+
),
8198
],
8299
)
83100
def test_higher_order_function_results(df, build_expr, expected):

skills/datafusion_python/SKILL.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,11 +496,12 @@ return value becomes the body:
496496

497497
```python
498498
F.array_transform(col("a"), lambda v: v * 2) # map: [1,2,3] -> [2,4,6]
499+
F.array_filter(col("a"), lambda v: v > 2) # filter: [1,2,3] -> [3]
499500
F.array_any_match(col("a"), lambda v: v > 3) # predicate: any element > 3
500501
```
501502

502-
Aliases: `list_transform` for `array_transform`; `any_match` / `list_any_match`
503-
for `array_any_match`.
503+
Aliases: `list_transform` for `array_transform`; `list_filter` for
504+
`array_filter`; `any_match` / `list_any_match` for `array_any_match`.
504505

505506
For explicit parameter names, build the lambda by hand:
506507

0 commit comments

Comments
 (0)