def encode_type_refs(root: bigframe_node.BigFrameNode) -> str:
    """Encode the data types referenced in the tree rooted at ``root``.

    The per-node bitmasks computed by ``_encode_type_refs_from_node`` are
    folded over the tree with ``root.reduce_up``.

    Args:
        root: Root node of the expression tree to inspect.

    Returns:
        The combined bitmask rendered as lowercase hexadecimal digits,
        without a ``0x`` prefix.
    """
    return f"{root.reduce_up(_encode_type_refs_from_node):x}"


def _encode_type_refs_from_node(
    node: bigframe_node.BigFrameNode, child_results: tuple[int, ...]
) -> int:
    """Combine the dtype masks of ``node``'s children with its own.

    Args:
        node: The node currently being reduced.
        child_results: Masks already computed for the children of ``node``,
            as passed in by the caller.

    Returns:
        The bitwise OR of every entry of ``child_results`` and of the masks
        of all expressions that ``node`` itself evaluates.
    """
    mask = functools.reduce(lambda acc, child: acc | child, child_results, 0)
    for expr, source in _referenced_exprs(node):
        mask |= _encode_type_refs_from_expr(expr, source)
    return mask


def _referenced_exprs(
    node: bigframe_node.BigFrameNode,
) -> list[tuple[expression.Expression, bigframe_node.BigFrameNode]]:
    """List the expressions ``node`` evaluates, each paired with the child
    node whose fields that expression is bound against.

    Node types that are not listed below contribute no expressions.
    """
    if isinstance(node, nodes.FilterNode):
        return [(node.predicate, node.child)]

    if isinstance(node, nodes.ProjectionNode):
        # Ignore direct assignments in projection nodes, i.e. assignments
        # whose expression is a bare DerefOp.
        return [
            (assignment[0], node.child)
            for assignment in node.assignments
            if not isinstance(assignment[0], expression.DerefOp)
        ]

    if isinstance(node, nodes.SelectionNode):
        # A selection evaluates no expressions of its own.
        return []

    if isinstance(node, nodes.OrderByNode):
        return [(by.scalar_expression, node.child) for by in node.by]

    if isinstance(node, nodes.JoinNode):
        pairs: list[tuple[expression.Expression, bigframe_node.BigFrameNode]] = []
        for left, right in node.conditions:
            # Each side of a join condition is bound against its own child.
            pairs.append((left, node.left_child))
            pairs.append((right, node.right_child))
        return pairs

    if isinstance(node, nodes.InNode):
        return [(node.left_col, node.left_child)]

    if isinstance(node, nodes.AggregateNode):
        return [(agg, node.child) for agg, _ in node.aggregations]

    if isinstance(node, nodes.WindowOpNode):
        window = node.window_spec
        return [
            *((key, node.child) for key in window.grouping_keys),
            *((order.scalar_expression, node.child) for order in window.ordering),
            *((col_def.expression, node.child) for col_def in node.agg_exprs),
        ]

    return []


def _encode_type_refs_from_expr(
    expr: expression.Expression, child_node: bigframe_node.BigFrameNode
) -> int:
    """Return the combined dtype mask of ``expr`` and all of its descendants.

    Args:
        expr: Expression to inspect. If it is not yet resolved, it is first
            bound to the schema of ``child_node``.
        child_node: Node whose fields ``expr`` is bound against.

    Returns:
        The bitwise OR of ``_get_dtype_mask(output_type)`` for ``expr`` and
        for every expression reachable through ``children``.
    """
    # TODO(b/409387790): Remove this branch once SQLGlot compiler fully replaces Ibis compiler
    if not expr.is_resolved:
        if isinstance(expr, agg_expressions.Aggregation):
            expr = schema_binding._bind_schema_to_aggregation_expr(expr, child_node)
        else:
            expr = expression.bind_schema_fields(expr, child_node.field_by_id)

    mask = _get_dtype_mask(expr.output_type)
    for child_expr in expr.children:
        mask |= _encode_type_refs_from_expr(child_expr, child_node)

    return mask
def encode_types(inputs: Sequence[dtypes.Dtype]) -> str:
    """Return the expected encoding for ``inputs``.

    The masks reported by ``data_types._get_dtype_mask`` for each dtype are
    OR-ed together and rendered as lowercase hexadecimal digits.
    """
    encoded_val = 0
    for dtype in inputs:
        encoded_val |= data_types._get_dtype_mask(dtype)

    return f"{encoded_val:x}"


def _assert_type_refs(node, expected_types: Sequence[dtypes.Dtype]) -> None:
    """Assert that ``node`` encodes to exactly the masks of ``expected_types``."""
    assert data_types.encode_type_refs(node) == encode_types(expected_types)


def test_get_type_refs_no_op(scalars_df_index):
    """The unmodified fixture frame should encode to no type references."""
    plan = scalars_df_index._block._expr.node
    no_types: list[dtypes.Dtype] = []

    _assert_type_refs(plan, no_types)


def test_get_type_refs_projection(scalars_df_index):
    """``datetime_col - datetime_col`` should reference DATETIME and TIMEDELTA."""
    difference = scalars_df_index["datetime_col"] - scalars_df_index["datetime_col"]
    plan = difference._block._expr.node

    _assert_type_refs(plan, [dtypes.DATETIME_DTYPE, dtypes.TIMEDELTA_DTYPE])


def test_get_type_refs_filter(scalars_df_index):
    """Filtering on ``int64_col > 0`` should reference INT and BOOL."""
    filtered = scalars_df_index[scalars_df_index["int64_col"] > 0]
    plan = filtered._block._expr.node

    _assert_type_refs(plan, [dtypes.INT_DTYPE, dtypes.BOOL_DTYPE])


def test_get_type_refs_order_by(scalars_df_index):
    """``sort_index()`` should reference only INT."""
    ordered = scalars_df_index.sort_index()
    plan = ordered._block._expr.node

    _assert_type_refs(plan, [dtypes.INT_DTYPE])


def test_get_type_refs_join(scalars_df_index):
    """Merging ``int64_col`` with ``float64_col`` should reference INT and FLOAT."""
    merged = scalars_df_index[["int64_col"]].merge(
        scalars_df_index[["float64_col"]],
        left_on="int64_col",
        right_on="float64_col",
    )
    plan = merged._block._expr.node

    _assert_type_refs(plan, [dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE])


def test_get_type_refs_isin(scalars_df_index):
    """``string_col.isin([...])`` should reference STRING and BOOL."""
    membership = scalars_df_index["string_col"].isin(["a"])
    plan = membership._block._expr.node

    _assert_type_refs(plan, [dtypes.STRING_DTYPE, dtypes.BOOL_DTYPE])


def test_get_type_refs_agg(scalars_df_index):
    """``count()`` over ``bool_col`` and ``string_col`` should reference INT,
    BOOL, STRING and FLOAT.
    """
    counted = scalars_df_index[["bool_col", "string_col"]].count()
    plan = counted._block._expr.node

    _assert_type_refs(
        plan,
        [
            dtypes.INT_DTYPE,
            dtypes.BOOL_DTYPE,
            dtypes.STRING_DTYPE,
            dtypes.FLOAT_DTYPE,
        ],
    )


def test_get_type_refs_window(scalars_df_index):
    """A rolling count grouped by ``string_col`` should reference STRING, BOOL
    and INT.
    """
    rolled = (
        scalars_df_index[["string_col", "bool_col"]]
        .groupby("string_col")
        .rolling(window=3)
        .count()
    )
    plan = rolled._block._expr.node

    _assert_type_refs(
        plan, [dtypes.STRING_DTYPE, dtypes.BOOL_DTYPE, dtypes.INT_DTYPE]
    )