Skip to content
This repository was archived by the owner on Dec 1, 2025. It is now read-only.

Commit cb2bbbf

Browse files
authored
Merge pull request #90 from lincc-frameworks/infer_nesting
add infer_nesting to reduce
2 parents 49d26eb + d7e0d53 commit cb2bbbf

2 files changed

Lines changed: 71 additions & 3 deletions

File tree

src/nested_dask/core.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pandas as pd
1212
import pyarrow as pa
1313
from dask.dataframe.dask_expr._collection import new_collection
14+
from dask.dataframe.dask_expr._expr import no_default as dsk_no_default
1415
from nested_pandas.series.dtype import NestedDtype
1516
from nested_pandas.series.packer import pack, pack_flat, pack_lists
1617
from pandas._libs import lib
@@ -731,7 +732,7 @@ def sort_values(
731732
meta=self._meta,
732733
)
733734

734-
def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
735+
def reduce(self, func, *args, meta=dsk_no_default, infer_nesting=True, **kwargs) -> NestedFrame:
735736
"""
736737
Takes a function and applies it to each top-level row of the NestedFrame.
737738
@@ -751,7 +752,15 @@ def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
751752
Positional arguments to pass to the function, the first *args should be the names of the
752753
columns to apply the function to.
753754
meta : dataframe or series-like, optional
754-
The dask meta of the output.
755+
The dask meta of the output. If not provided, dask will try to
756+
infer the metadata. This may lead to unexpected results, so
757+
providing meta is recommended.
758+
infer_nesting : bool, default True
759+
If True, the function will pack output columns into nested
760+
structures based on column names adhering to a nested naming
761+
scheme. E.g. "nested.b" and "nested.c" will be packed into a column
762+
called "nested" with columns "b" and "c". If False, all outputs
763+
will be returned as base columns.
755764
kwargs : keyword arguments, optional
756765
Keyword arguments to pass to the function.
757766
@@ -773,6 +782,26 @@ def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
773782
>>> '''reduce will return a NestedFrame with two columns'''
774783
>>> return {"sum_col1": sum(col1), "sum_col2": sum(col2)}
775784
785+
When using nesting inference (infer_nesting=True), the output may
786+
contain nested columns. In such cases, the meta should be provided with
787+
the appropriate dtype for these columns. For example, the following
788+
function, which produces a nested column "lc":
789+
790+
>>> def complex_output(flux):
791+
>>> return {"max_flux": np.max(flux),
792+
>>> "lc.flux_quantiles": np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]),
793+
>>> "lc.labels": [0.1, 0.2, 0.3, 0.4, 0.5]}
794+
795+
would require the following meta:
796+
797+
>>> # create a NestedDtype for the nested column "lc"
798+
>>> from nested_pandas.series.dtype import NestedDtype
799+
>>> lc_dtype = NestedDtype(pa.struct([pa.field("flux_quantiles", pa.list_(pa.float64())),
800+
>>> pa.field("labels", pa.list_(pa.float64()))]))
801+
>>> # use the lc_dtype in meta creation
802+
>>> result_meta = npd.NestedFrame({'max_flux':pd.Series([], dtype='float'),
803+
>>> 'lc':pd.Series([], dtype=lc_dtype)})
804+
776805
"""
777806

778807
# Handle meta shorthands to produce nestedframe output
@@ -787,7 +816,9 @@ def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
787816
# apply nested_pandas reduce via map_partitions
788817
# wrap the partition in a npd.NestedFrame call for:
789818
# https://github.com/lincc-frameworks/nested-dask/issues/21
790-
return self.map_partitions(lambda x: npd.NestedFrame(x).reduce(func, *args, **kwargs), meta=meta)
819+
return self.map_partitions(
820+
lambda x: npd.NestedFrame(x).reduce(func, *args, infer_nesting=infer_nesting, **kwargs), meta=meta
821+
)
791822

792823
def to_parquet(self, path, by_layer=True, **kwargs) -> None:
793824
"""Creates parquet file(s) with the data of a NestedFrame, either

tests/nested_dask/test_nestedframe.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,43 @@ def mean_arr(arr): # type: ignore
363363
assert isinstance(reduced.compute(), npd.NestedFrame)
364364

365365

366+
def test_reduce_output_inference():
367+
"""test the extension of the reduce result nesting inference"""
368+
369+
ndd = generate_data(20, 20, npartitions=2, seed=1)
370+
371+
def complex_output(flux):
372+
return {
373+
"max_flux": np.max(flux),
374+
"lc.flux_quantiles": np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]),
375+
"lc.labels": [0.1, 0.2, 0.3, 0.4, 0.5],
376+
"meta.colors": ["green", "red", "blue"],
377+
}
378+
379+
# The expected meta is written out by hand: one NestedDtype per dotted
# prefix returned by complex_output ("lc" and "meta"), with each sub-column
# declared as a pyarrow list field, plus the plain "max_flux" base column.
380+
result_meta = npd.NestedFrame(
381+
{
382+
"max_flux": pd.Series([], dtype="float"),
383+
"lc": pd.Series(
384+
[],
385+
dtype=NestedDtype(
386+
pa.struct(
387+
[
388+
pa.field("flux_quantiles", pa.list_(pa.float64())),
389+
pa.field("labels", pa.list_(pa.float64())),
390+
]
391+
)
392+
),
393+
),
394+
"meta": pd.Series([], dtype=NestedDtype(pa.struct([pa.field("colors", pa.list_(pa.string()))]))),
395+
}
396+
)
397+
result = ndd.reduce(complex_output, "nested.flux", infer_nesting=True, meta=result_meta)
398+
399+
# The lazily declared dtypes/columns (taken from meta) must agree with what
# the computed result actually contains.
assert list(result.dtypes) == list(result.compute().dtypes)
400+
assert list(result.columns) == list(result.compute().columns)
401+
402+
366403
def test_to_parquet_combined(test_dataset, tmp_path):
367404
"""test to_parquet when saving all layers to a single directory"""
368405

0 commit comments

Comments
 (0)