Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 53 additions & 20 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,9 +1822,9 @@ def melt(
Arguments correspond to pandas.melt arguments.
"""
# TODO: Implement col_level and ignore_index
value_labels: pd.Index = pd.Index(
[self.col_id_to_label[col_id] for col_id in value_vars]
)
value_labels: pd.Index = self.column_labels[
[self.value_columns.index(col_id) for col_id in value_vars]
]
id_labels = [self.col_id_to_label[col_id] for col_id in id_vars]

unpivot_expr, (var_col_ids, unpivot_out, passthrough_cols) = unpivot(
Expand Down Expand Up @@ -3417,6 +3417,7 @@ def unpivot(
joined_array, (labels_mapping, column_mapping) = labels_array.relational_join(
array_value, type="cross"
)

new_passthrough_cols = [column_mapping[col] for col in passthrough_columns]
# Last column is offsets
index_col_ids = [labels_mapping[col] for col in labels_array.column_ids[:-1]]
Expand All @@ -3426,20 +3427,24 @@ def unpivot(
unpivot_exprs: List[ex.Expression] = []
# Supports producing multiple stacked output columns for stacking only part of hierarchical index
for input_ids in unpivot_columns:
# row explode offset used to choose the input column
# we use offset instead of label as labels are not necessarily unique
cases = itertools.chain(
*(
(
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
ex.deref(column_mapping[id_or_null])
if (id_or_null is not None)
else ex.const(None),
col_expr: ex.Expression
if not input_ids:
col_expr = ex.const(None)
else:
# row explode offset used to choose the input column
# we use offset instead of label as labels are not necessarily unique
cases = itertools.chain(
*(
(
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
ex.deref(column_mapping[id_or_null])
if (id_or_null is not None)
else ex.const(None),
)
for i, id_or_null in enumerate(input_ids)
)
for i, id_or_null in enumerate(input_ids)
)
)
col_expr = ops.case_when_op.as_expr(*cases)
col_expr = ops.case_when_op.as_expr(*cases)
unpivot_exprs.append(col_expr)

joined_array, unpivot_col_ids = joined_array.compute_values(unpivot_exprs)
Expand All @@ -3457,19 +3462,47 @@ def _pd_index_to_array_value(
Create an ArrayValue from a list of label tuples.
The last column will be row offsets.
"""
id_gen = bigframes.core.identifiers.standard_id_strings()
col_ids = [next(id_gen) for _ in range(index.nlevels)]
offset_id = next(id_gen)

rows = []
labels_as_tuples = utils.index_as_tuples(index)
for row_offset in range(len(index)):
id_gen = bigframes.core.identifiers.standard_id_strings()
row_label = labels_as_tuples[row_offset]
row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
row = {}
for label_part, id in zip(row_label, id_gen):
row[id] = label_part if pd.notnull(label_part) else None
row[next(id_gen)] = row_offset
for label_part, col_id in zip(row_label, col_ids):
row[col_id] = label_part if pd.notnull(label_part) else None
row[offset_id] = row_offset
rows.append(row)

return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session)
import pyarrow as pa
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this import can be removed because pyarrow is already imported at the top of this file?


if not rows:
from bigframes.dtypes import bigframes_dtype_to_arrow_dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already import bigframes.dtypes at the top of this file.


dtypes_list = getattr(index, "dtypes", None)
if dtypes_list is None:
dtypes_list = (
[index.dtype] if hasattr(index, "dtype") else [pd.Float64Dtype()]
)

fields = []
for col_id, dtype in zip(col_ids, dtypes_list):
try:
pa_type = bigframes_dtype_to_arrow_dtype(dtype)
except Exception:
pa_type = pa.string()
fields.append(pa.field(col_id, pa_type))
fields.append(pa.field(offset_id, pa.int64()))
schema = pa.schema(fields)
pt = pa.Table.from_pylist([], schema=schema)
else:
pt = pa.Table.from_pylist(rows)
pt = pt.rename_columns([*col_ids, offset_id])

return core.ArrayValue.from_pyarrow(pt, session=session)


def _resolve_index_col(
Expand Down
12 changes: 12 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5902,6 +5902,18 @@ def test_to_gbq_table_labels(scalars_df_index):
assert table.labels["test"] == "labels"


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move test_dataframe_melt_multiindex to the tests/system/small/test_multiindex.py file, which contains the multi-index-related tests.

def test_dataframe_melt_multiindex(session):
    # Regression test: a `count` (which performs a melt under the hood) must
    # not drop MultiIndex column levels when round-tripping through Arrow.
    columns = pd.MultiIndex.from_tuples(
        [("Group1", "A"), ("Group2", "B"), ("Group1", "C")]
    )
    pandas_df = pd.DataFrame({"A": [1], "B": ["string"], "C": [3]})
    pandas_df.columns = columns

    result = session.read_pandas(pandas_df).count().to_pandas()
    # One row per original column: all three columns survive the melt.
    assert result.shape[0] == 3


@pytest.mark.parametrize(
("col_names", "ignore_index"),
[
Expand Down
Loading