Skip to content

Commit 3afbc7d

Browse files
timsaucer and claude committed
feat: accept distinct kwarg on sum and avg
Upstream exposes `sum_distinct` / `avg_distinct` / `count_distinct` as sibling functions that call the same underlying UDAF with `distinct: bool = true`. The Rust binding side already routes `distinct=Some(true)` through the aggregate builder for `sum`, `avg`, and `count` — but only `count` exposed the kwarg on the Python wrapper.

Add `distinct: bool = False` to `sum()` and `avg()` mirroring the existing `count()` signature, and update SKILL.md so the check-upstream audit does not re-flag the three upstream `*_distinct` shortcuts as gaps.

The plan emitted by `sum(col, distinct=True)` matches what upstream's `sum_distinct(col)` builds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent dac9ec6 commit 3afbc7d

3 files changed

Lines changed: 64 additions & 5 deletions

File tree

.ai/skills/check-upstream/SKILL.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,18 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
8282
- Python API: `python/datafusion/functions.py` (aggregate functions are mixed in with scalar functions)
8383
- Rust bindings: `crates/core/src/functions.rs`
8484

85+
**Evaluated and not requiring separate Python exposure:**
86+
- `count_distinct` — covered by `count(expr, distinct=True)`. Both forms call
87+
`count_udaf` with `distinct: bool = true` and produce the same logical plan.
88+
- `sum_distinct` — covered by `sum(expr, distinct=True)`.
89+
- `avg_distinct` — covered by `avg(expr, distinct=True)`.
90+
8591
**How to check:**
8692
1. Fetch the upstream aggregate function documentation page
8793
2. Compare against aggregate functions in `python/datafusion/functions.py` (check `__all__` list and function definitions)
8894
3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding
89-
4. Report only functions missing from the Python API
95+
4. Check against the "evaluated and not requiring exposure" list before flagging as a gap
96+
5. Report only functions missing from the Python API
9097

9198
### 3. Window Functions
9299

python/datafusion/functions.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4521,17 +4521,20 @@ def grouping(
45214521

45224522
def avg(
45234523
expression: Expr,
4524+
distinct: bool = False,
45244525
filter: Expr | None = None,
45254526
) -> Expr:
45264527
"""Returns the average value.
45274528
45284529
This aggregate function expects a numeric expression and will return a float.
45294530
45304531
If using the builder functions described in ref:`_aggregation` this function ignores
4531-
the options ``order_by``, ``null_treatment``, and ``distinct``.
4532+
the options ``order_by`` and ``null_treatment``.
45324533
45334534
Args:
45344535
expression: Values to combine into an array
4536+
distinct: If True, only distinct values are averaged. Equivalent to the
4537+
upstream ``avg_distinct`` shortcut.
45354538
filter: If provided, only compute against rows for which the filter is True
45364539
45374540
Examples:
@@ -4551,9 +4554,17 @@ def avg(
45514554
... ).alias("v")])
45524555
>>> result.collect_column("v")[0].as_py()
45534556
2.5
4557+
4558+
>>> df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]})
4559+
>>> result = df.aggregate(
4560+
... [], [dfn.functions.avg(
4561+
... dfn.col("a"), distinct=True,
4562+
... ).alias("v")])
4563+
>>> result.collect_column("v")[0].as_py()
4564+
2.0
45544565
"""
45554566
filter_raw = filter.expr if filter is not None else None
4556-
return Expr(f.avg(expression.expr, filter=filter_raw))
4567+
return Expr(f.avg(expression.expr, distinct=distinct, filter=filter_raw))
45574568

45584569

45594570
def corr(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
@@ -4838,17 +4849,20 @@ def min(expression: Expr, filter: Expr | None = None) -> Expr:
48384849

48394850
def sum(
48404851
expression: Expr,
4852+
distinct: bool = False,
48414853
filter: Expr | None = None,
48424854
) -> Expr:
48434855
"""Computes the sum of a set of numbers.
48444856
48454857
This aggregate function expects a numeric expression.
48464858
48474859
If using the builder functions described in ref:`_aggregation` this function ignores
4848-
the options ``order_by``, ``null_treatment``, and ``distinct``.
4860+
the options ``order_by`` and ``null_treatment``.
48494861
48504862
Args:
48514863
expression: Values to combine into an array
4864+
distinct: If True, only distinct values are summed. Equivalent to the
4865+
upstream ``sum_distinct`` shortcut.
48524866
filter: If provided, only compute against rows for which the filter is True
48534867
48544868
Examples:
@@ -4868,9 +4882,17 @@ def sum(
48684882
... ).alias("v")])
48694883
>>> result.collect_column("v")[0].as_py()
48704884
5
4885+
4886+
>>> df = ctx.from_pydict({"a": [1, 1, 2, 3]})
4887+
>>> result = df.aggregate(
4888+
... [], [dfn.functions.sum(
4889+
... dfn.col("a"), distinct=True,
4890+
... ).alias("v")])
4891+
>>> result.collect_column("v")[0].as_py()
4892+
6
48714893
"""
48724894
filter_raw = filter.expr if filter is not None else None
4873-
return Expr(f.sum(expression.expr, filter=filter_raw))
4895+
return Expr(f.sum(expression.expr, distinct=distinct, filter=filter_raw))
48744896

48754897

48764898
def stddev(expression: Expr, filter: Expr | None = None) -> Expr:

python/tests/test_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,6 +1957,36 @@ def test_get_field(df):
19571957
assert result.column(1) == pa.array([4, 5, 6])
19581958

19591959

1960+
def test_sum_distinct_kwarg():
1961+
ctx = SessionContext()
1962+
df = ctx.from_pydict({"a": [1, 1, 2, 3]})
1963+
distinct = (
1964+
df.aggregate([], [f.sum(column("a"), distinct=True).alias("v")])
1965+
.collect_column("v")[0]
1966+
.as_py()
1967+
)
1968+
total = (
1969+
df.aggregate([], [f.sum(column("a")).alias("v")]).collect_column("v")[0].as_py()
1970+
)
1971+
assert distinct == 6
1972+
assert total == 7
1973+
1974+
1975+
def test_avg_distinct_kwarg():
1976+
ctx = SessionContext()
1977+
df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]})
1978+
distinct = (
1979+
df.aggregate([], [f.avg(column("a"), distinct=True).alias("v")])
1980+
.collect_column("v")[0]
1981+
.as_py()
1982+
)
1983+
mean = (
1984+
df.aggregate([], [f.avg(column("a")).alias("v")]).collect_column("v")[0].as_py()
1985+
)
1986+
assert distinct == 2.0
1987+
assert mean == 1.75
1988+
1989+
19601990
def test_arrow_metadata():
19611991
ctx = SessionContext()
19621992
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})

0 commit comments

Comments
 (0)