-
Notifications
You must be signed in to change notification settings - Fork 67
fix: support melting empty DataFrames without crashing #2509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1822,9 +1822,9 @@ def melt( | |
| Arguments correspond to pandas.melt arguments. | ||
| """ | ||
| # TODO: Implement col_level and ignore_index | ||
| value_labels: pd.Index = pd.Index( | ||
| [self.col_id_to_label[col_id] for col_id in value_vars] | ||
| ) | ||
| value_labels: pd.Index = self.column_labels[ | ||
| [self.value_columns.index(col_id) for col_id in value_vars] | ||
| ] | ||
| id_labels = [self.col_id_to_label[col_id] for col_id in id_vars] | ||
|
|
||
| unpivot_expr, (var_col_ids, unpivot_out, passthrough_cols) = unpivot( | ||
|
|
@@ -3417,6 +3417,7 @@ def unpivot( | |
| joined_array, (labels_mapping, column_mapping) = labels_array.relational_join( | ||
| array_value, type="cross" | ||
| ) | ||
|
|
||
| new_passthrough_cols = [column_mapping[col] for col in passthrough_columns] | ||
| # Last column is offsets | ||
| index_col_ids = [labels_mapping[col] for col in labels_array.column_ids[:-1]] | ||
|
|
@@ -3426,20 +3427,24 @@ def unpivot( | |
| unpivot_exprs: List[ex.Expression] = [] | ||
| # Supports producing multiple stacked output columns for stacking only part of hierarchical index | ||
| for input_ids in unpivot_columns: | ||
| # row explode offset used to choose the input column | ||
| # we use offset instead of label as labels are not necessarily unique | ||
| cases = itertools.chain( | ||
| *( | ||
| ( | ||
| ops.eq_op.as_expr(explode_offsets_id, ex.const(i)), | ||
| ex.deref(column_mapping[id_or_null]) | ||
| if (id_or_null is not None) | ||
| else ex.const(None), | ||
| col_expr: ex.Expression | ||
| if not input_ids: | ||
| col_expr = ex.const(None) | ||
| else: | ||
| # row explode offset used to choose the input column | ||
| # we use offset instead of label as labels are not necessarily unique | ||
| cases = itertools.chain( | ||
| *( | ||
| ( | ||
| ops.eq_op.as_expr(explode_offsets_id, ex.const(i)), | ||
| ex.deref(column_mapping[id_or_null]) | ||
| if (id_or_null is not None) | ||
| else ex.const(None), | ||
| ) | ||
| for i, id_or_null in enumerate(input_ids) | ||
| ) | ||
| for i, id_or_null in enumerate(input_ids) | ||
| ) | ||
| ) | ||
| col_expr = ops.case_when_op.as_expr(*cases) | ||
| col_expr = ops.case_when_op.as_expr(*cases) | ||
| unpivot_exprs.append(col_expr) | ||
|
|
||
| joined_array, unpivot_col_ids = joined_array.compute_values(unpivot_exprs) | ||
|
|
@@ -3457,19 +3462,47 @@ def _pd_index_to_array_value( | |
| Create an ArrayValue from a list of label tuples. | ||
| The last column will be row offsets. | ||
| """ | ||
| id_gen = bigframes.core.identifiers.standard_id_strings() | ||
| col_ids = [next(id_gen) for _ in range(index.nlevels)] | ||
| offset_id = next(id_gen) | ||
|
|
||
| rows = [] | ||
| labels_as_tuples = utils.index_as_tuples(index) | ||
| for row_offset in range(len(index)): | ||
| id_gen = bigframes.core.identifiers.standard_id_strings() | ||
| row_label = labels_as_tuples[row_offset] | ||
| row_label = (row_label,) if not isinstance(row_label, tuple) else row_label | ||
| row = {} | ||
| for label_part, id in zip(row_label, id_gen): | ||
| row[id] = label_part if pd.notnull(label_part) else None | ||
| row[next(id_gen)] = row_offset | ||
| for label_part, col_id in zip(row_label, col_ids): | ||
| row[col_id] = label_part if pd.notnull(label_part) else None | ||
| row[offset_id] = row_offset | ||
| rows.append(row) | ||
|
|
||
| return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session) | ||
| import pyarrow as pa | ||
|
|
||
| if not rows: | ||
| from bigframes.dtypes import bigframes_dtype_to_arrow_dtype | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we have |
||
|
|
||
| dtypes_list = getattr(index, "dtypes", None) | ||
| if dtypes_list is None: | ||
| dtypes_list = ( | ||
| [index.dtype] if hasattr(index, "dtype") else [pd.Float64Dtype()] | ||
| ) | ||
|
|
||
| fields = [] | ||
| for col_id, dtype in zip(col_ids, dtypes_list): | ||
| try: | ||
| pa_type = bigframes_dtype_to_arrow_dtype(dtype) | ||
| except Exception: | ||
| pa_type = pa.string() | ||
| fields.append(pa.field(col_id, pa_type)) | ||
| fields.append(pa.field(offset_id, pa.int64())) | ||
| schema = pa.schema(fields) | ||
| pt = pa.Table.from_pylist([], schema=schema) | ||
| else: | ||
| pt = pa.Table.from_pylist(rows) | ||
| pt = pt.rename_columns([*col_ids, offset_id]) | ||
|
|
||
| return core.ArrayValue.from_pyarrow(pt, session=session) | ||
|
|
||
|
|
||
| def _resolve_index_col( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5902,6 +5902,18 @@ def test_to_gbq_table_labels(scalars_df_index): | |
| assert table.labels["test"] == "labels" | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: move |
||
| def test_dataframe_melt_multiindex(session): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like I cannot reproduce the issue on the main branch. See screenshot/AXRSuFmXPTt5gzj |
||
| # Tests that `melt` operations via count do not cause MultiIndex drops in Arrow | ||
| df = pd.DataFrame({"A": [1], "B": ["string"], "C": [3]}) | ||
| df.columns = pd.MultiIndex.from_tuples( | ||
| [("Group1", "A"), ("Group2", "B"), ("Group1", "C")] | ||
| ) | ||
| bdf = session.read_pandas(df) | ||
|
|
||
| count_df = bdf.count().to_pandas() | ||
| assert count_df.shape[0] == 3 | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("col_names", "ignore_index"), | ||
| [ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, this import can be removed because the header of this file already imports pyarrow?