diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 239eedf6d3..fff03c9310 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -1822,9 +1822,9 @@ def melt( Arguments correspond to pandas.melt arguments. """ # TODO: Implement col_level and ignore_index - value_labels: pd.Index = pd.Index( - [self.col_id_to_label[col_id] for col_id in value_vars] - ) + value_labels: pd.Index = self.column_labels[ + [self.value_columns.index(col_id) for col_id in value_vars] + ] id_labels = [self.col_id_to_label[col_id] for col_id in id_vars] unpivot_expr, (var_col_ids, unpivot_out, passthrough_cols) = unpivot( @@ -3417,6 +3417,7 @@ def unpivot( joined_array, (labels_mapping, column_mapping) = labels_array.relational_join( array_value, type="cross" ) + new_passthrough_cols = [column_mapping[col] for col in passthrough_columns] # Last column is offsets index_col_ids = [labels_mapping[col] for col in labels_array.column_ids[:-1]] @@ -3426,20 +3427,24 @@ def unpivot( unpivot_exprs: List[ex.Expression] = [] # Supports producing multiple stacked ouput columns for stacking only part of hierarchical index for input_ids in unpivot_columns: - # row explode offset used to choose the input column - # we use offset instead of label as labels are not necessarily unique - cases = itertools.chain( - *( - ( - ops.eq_op.as_expr(explode_offsets_id, ex.const(i)), - ex.deref(column_mapping[id_or_null]) - if (id_or_null is not None) - else ex.const(None), + col_expr: ex.Expression + if not input_ids: + col_expr = ex.const(None) + else: + # row explode offset used to choose the input column + # we use offset instead of label as labels are not necessarily unique + cases = itertools.chain( + *( + ( + ops.eq_op.as_expr(explode_offsets_id, ex.const(i)), + ex.deref(column_mapping[id_or_null]) + if (id_or_null is not None) + else ex.const(None), + ) + for i, id_or_null in enumerate(input_ids) ) - for i, id_or_null in enumerate(input_ids) ) - ) - col_expr = 
ops.case_when_op.as_expr(*cases) + col_expr = ops.case_when_op.as_expr(*cases) unpivot_exprs.append(col_expr) joined_array, unpivot_col_ids = joined_array.compute_values(unpivot_exprs) @@ -3457,19 +3462,47 @@ def _pd_index_to_array_value( Create an ArrayValue from a list of label tuples. The last column will be row offsets. """ + id_gen = bigframes.core.identifiers.standard_id_strings() + col_ids = [next(id_gen) for _ in range(index.nlevels)] + offset_id = next(id_gen) + rows = [] labels_as_tuples = utils.index_as_tuples(index) for row_offset in range(len(index)): - id_gen = bigframes.core.identifiers.standard_id_strings() row_label = labels_as_tuples[row_offset] row_label = (row_label,) if not isinstance(row_label, tuple) else row_label row = {} - for label_part, id in zip(row_label, id_gen): - row[id] = label_part if pd.notnull(label_part) else None - row[next(id_gen)] = row_offset + for label_part, col_id in zip(row_label, col_ids): + row[col_id] = label_part if pd.notnull(label_part) else None + row[offset_id] = row_offset rows.append(row) - return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session) + import pyarrow as pa + + if not rows: + from bigframes.dtypes import bigframes_dtype_to_arrow_dtype + + dtypes_list = getattr(index, "dtypes", None) + if dtypes_list is None: + dtypes_list = ( + [index.dtype] if hasattr(index, "dtype") else [pd.Float64Dtype()] + ) + + fields = [] + for col_id, dtype in zip(col_ids, dtypes_list): + try: + pa_type = bigframes_dtype_to_arrow_dtype(dtype) + except Exception: + pa_type = pa.string() + fields.append(pa.field(col_id, pa_type)) + fields.append(pa.field(offset_id, pa.int64())) + schema = pa.schema(fields) + pt = pa.Table.from_pylist([], schema=schema) + else: + pt = pa.Table.from_pylist(rows) + pt = pt.rename_columns([*col_ids, offset_id]) + + return core.ArrayValue.from_pyarrow(pt, session=session) def _resolve_index_col( diff --git a/tests/system/small/test_dataframe.py 
b/tests/system/small/test_dataframe.py index 9683a8bc52..c9f9be880f 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -5902,6 +5902,18 @@ def test_to_gbq_table_labels(scalars_df_index): assert table.labels["test"] == "labels" +def test_dataframe_melt_multiindex(session): + # Regression test: count() lowers to a melt/unpivot, which must not drop MultiIndex column entries when round-tripping through Arrow + df = pd.DataFrame({"A": [1], "B": ["string"], "C": [3]}) + df.columns = pd.MultiIndex.from_tuples( + [("Group1", "A"), ("Group2", "B"), ("Group1", "C")] + ) + bdf = session.read_pandas(df) + + count_df = bdf.count().to_pandas() + assert count_df.shape[0] == 3 + + @pytest.mark.parametrize( ("col_names", "ignore_index"), [