diff --git a/packages/bigframes/bigframes/core/compile/polars/compiler.py b/packages/bigframes/bigframes/core/compile/polars/compiler.py
index 2477f27b6432..3f899478873c 100644
--- a/packages/bigframes/bigframes/core/compile/polars/compiler.py
+++ b/packages/bigframes/bigframes/core/compile/polars/compiler.py
@@ -37,7 +37,9 @@
 import bigframes.operations.generic_ops as gen_ops
 import bigframes.operations.json_ops as json_ops
 import bigframes.operations.numeric_ops as num_ops
+import bigframes.operations.remote_function_ops as remote_function_ops
 import bigframes.operations.string_ops as string_ops
+import bigframes.operations.struct_ops as struct_ops
 from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec
 from bigframes.core.compile.polars import lowering
 
@@ -122,7 +124,7 @@ def _bigframes_dtype_to_polars_dtype(
                 ]
             )
         if bigframes.dtypes.is_array_like(dtype):
-            return pl.Array(
+            return pl.List(
                 inner=_bigframes_dtype_to_polars_dtype(
                     bigframes.dtypes.get_array_inner_type(dtype)
                 )
@@ -502,6 +504,57 @@ def _(self, op: json_ops.ToJSON, input: pl.Expr) -> pl.Expr:
             else:
                 return input.cast(pl.String())
 
+        @compile_op.register(json_ops.ToJSONString)
+        def _(self, op: json_ops.ToJSONString, input: pl.Expr) -> pl.Expr:
+            """Render ``input`` as a JSON-formatted string expression.
+
+            Nulls become the literal ``null``; BYTES values are base64-encoded
+            first (recursively inside structs and arrays).
+            """
+            from_type = self._expr_types.get(id(input))
+
+            def preprocess_binary(
+                expr: pl.Expr, dtype: bigframes.dtypes.ExpressionType
+            ) -> pl.Expr:
+                if dtype == bigframes.dtypes.BYTES_DTYPE:
+                    return expr.bin.encode("base64")
+                if bigframes.dtypes.is_struct_like(dtype):
+                    fields = bigframes.dtypes.get_struct_fields(dtype)
+                    return pl.struct(
+                        *[
+                            preprocess_binary(
+                                expr.struct.field(name), field_type
+                            ).alias(name)
+                            for name, field_type in fields.items()
+                        ]
+                    )
+                if bigframes.dtypes.is_array_like(dtype):
+                    inner_type = bigframes.dtypes.get_array_inner_type(dtype)
+                    return expr.list.eval(preprocess_binary(pl.element(), inner_type))
+                return expr
+
+            preprocessed = preprocess_binary(input, from_type)
+
+            if bigframes.dtypes.is_struct_like(from_type):
+                result = preprocessed.struct.json_encode()
+            elif from_type == bigframes.dtypes.INT_DTYPE:
+                result = preprocessed.cast(pl.String)
+            elif from_type == bigframes.dtypes.BOOL_DTYPE:
+                result = (
+                    pl.when(preprocessed)
+                    .then(pl.lit("true"))
+                    .otherwise(pl.lit("false"))
+                )
+            elif from_type == bigframes.dtypes.BYTES_DTYPE:
+                result = pl.lit('"') + preprocessed + pl.lit('"')
+            else:
+                # Wrap in a one-field struct so struct.json_encode applies, then strip
+                # the leading `{"value":` (9 chars) and the trailing `}`.
+                wrapped = pl.struct(value=preprocessed).struct.json_encode()
+                result = wrapped.str.slice(9, wrapped.str.len_chars() - 10)
+
+            return pl.when(input.is_null()).then(pl.lit("null")).otherwise(result)
+
         @compile_op.register(arr_ops.ToArrayOp)
         def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
             return pl.concat_list(*inputs)
@@ -532,6 +585,42 @@ def _(self, op: ops.ArrayReduceOp, input: pl.Expr) -> pl.Expr:
                 f"Haven't implemented array aggregation: {op.aggregation}"
             )
 
+        @compile_op.register(struct_ops.StructOp)
+        def _(self, op: struct_ops.StructOp, *inputs: pl.Expr) -> pl.Expr:
+            return pl.struct(**{col: inp for col, inp in zip(op.column_names, inputs)})  # type: ignore
+
+        @compile_op.register(struct_ops.StructFieldOp)
+        def _(self, op: struct_ops.StructFieldOp, *inputs: pl.Expr) -> pl.Expr:
+            return inputs[0].struct[op.name_or_index]
+
+        @compile_op.register(remote_function_ops.PythonUdfOp)
+        def _(self, op: ops.PythonUdfOp, *inputs: pl.Expr) -> pl.Expr:
+            """Apply the UDF's Python callable to each row of ``inputs``.
+
+            The inputs are packed into one struct expression and the handler
+            unpacks each row's values, in order, as the callable's arguments.
+            """
+            from bigframes.functions import function_template
+
+            code = op.function_def.code.to_callable()
+            if op.function_def.signature.is_row_processor:
+
+                def handler(py_struct):
+                    args = list(py_struct.values())
+                    # First value goes through get_pd_series; the rest pass as-is.
+                    series_arg = function_template.get_pd_series(args[0])
+                    return code(series_arg, *args[1:])
+            else:
+
+                def handler(py_struct):
+                    return code(*(field for field in py_struct.values()))
+
+            return pl.struct(*inputs).map_elements(
+                handler,
+                return_dtype=_bigframes_dtype_to_polars_dtype(op.output_type()),
+                skip_nulls=False,
+            )
+
     @dataclasses.dataclass(frozen=True)
     class PolarsAggregateCompiler:
         scalar_compiler = PolarsExpressionCompiler()
diff --git a/packages/bigframes/bigframes/functions/_function_session.py b/packages/bigframes/bigframes/functions/_function_session.py
index 213ac6638490..2bc2b597372b 100644
--- a/packages/bigframes/bigframes/functions/_function_session.py
+++ b/packages/bigframes/bigframes/functions/_function_session.py
@@ -15,21 +15,16 @@
 
 from __future__ import annotations
 
-import collections.abc
 import functools
-import inspect
 import logging
 import random
 import string
-import sys
 import threading
 import time
 import warnings
 from typing import (
     TYPE_CHECKING,
-    Any,
     Literal,
-    Mapping,
     Optional,
     Sequence,
     Union,
@@ -512,22 +507,10 @@ def wrapper(func):
                     TypeError, f"func must be a callable, got {func}"
                 )
 
-            if sys.version_info >= (3, 10):
-                # Add `eval_str = True` so that deferred annotations are turned into their
-                # corresponding type objects. Need Python 3.10 for eval_str parameter.
-                # https://docs.python.org/3/library/inspect.html#inspect.signature
-                signature_kwargs: Mapping[str, Any] = {"eval_str": True}
-            else:
-                signature_kwargs = {}  # type: ignore
-
-            py_sig = _resolve_signature(
-                inspect.signature(func, **signature_kwargs),
+            udf_sig = _utils.get_func_signature(
+                func,
                 input_types,
                 output_type,
-            )
-
-            udf_sig = udf_def.UdfSignature.from_py_signature(
-                py_sig
             ).to_remote_function_compatible()
 
             full_package_requirements = _utils.get_updated_package_requirements(
@@ -786,23 +769,11 @@ def wrapper(func):
                     TypeError, f"func must be a callable, got {func}"
                 )
 
-            if sys.version_info >= (3, 10):
-                # Add `eval_str = True` so that deferred annotations are turned into their
-                # corresponding type objects. Need Python 3.10 for eval_str parameter.
-                # https://docs.python.org/3/library/inspect.html#inspect.signature
-                signature_kwargs: Mapping[str, Any] = {"eval_str": True}
-            else:
-                signature_kwargs = {}  # type: ignore
-
-            py_sig = inspect.signature(
+            udf_sig = _utils.get_func_signature(
                 func,
-                **signature_kwargs,
+                input_types,
+                output_type,
             )
-            py_sig = _resolve_signature(py_sig, input_types, output_type)
-
-            # The function will actually be receiving a pandas Series, but allow
-            # both BigQuery DataFrames and pandas object types for compatibility.
-            udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
 
             code_def = udf_def.CodeDef.from_func(func, package_requirements=packages)
             requirements = udf_def.RuntimeRequirements(
@@ -878,36 +849,6 @@ def deploy_udf(
         return self.udf(_force_deploy=True, **kwargs)(func)
 
 
-def _resolve_signature(
-    py_sig: inspect.Signature,
-    input_types: Union[None, type, Sequence[type]] = None,
-    output_type: Optional[type] = None,
-) -> inspect.Signature:
-    if input_types is not None:
-        if not isinstance(input_types, collections.abc.Sequence):
-            input_types = [input_types]
-        if _utils.has_conflict_input_type(py_sig, input_types):
-            msg = bfe.format_message(
-                "Conflicting input types detected, using the one from the decorator."
-            )
-            warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
-        py_sig = py_sig.replace(
-            parameters=[
-                par.replace(annotation=itype)
-                for par, itype in zip(py_sig.parameters.values(), input_types)
-            ]
-        )
-    if output_type:
-        if _utils.has_conflict_output_type(py_sig, output_type):
-            msg = bfe.format_message(
-                "Conflicting return type detected, using the one from the decorator."
-            )
-            warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
-        py_sig = py_sig.replace(return_annotation=output_type)
-
-    return py_sig
-
-
 def get_cloud_function_name(
     function_def: udf_def.CloudRunFunctionConfig, session_id=None, uniq_suffix=False
 ):
diff --git a/packages/bigframes/bigframes/functions/_utils.py b/packages/bigframes/bigframes/functions/_utils.py
index 36736cd6bd77..358f20b2ab42 100644
--- a/packages/bigframes/bigframes/functions/_utils.py
+++ b/packages/bigframes/bigframes/functions/_utils.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 
+import collections.abc
 import hashlib
 import inspect
 import json
 import sys
 import typing
 import warnings
-from typing import Any, Optional, Sequence, Set, cast
+from typing import Any, Mapping, Optional, Sequence, Set, cast
 
 import cloudpickle
 import google.api_core.exceptions
@@ -31,7 +32,7 @@
 
 import bigframes.exceptions as bfe
 import bigframes.formatting_helpers as bf_formatting
-from bigframes.functions import function_typing
+from bigframes.functions import function_typing, udf_def
 
 # Naming convention for the function artifacts
 _BIGFRAMES_FUNCTION_PREFIX = "bigframes"
@@ -304,3 +305,65 @@ def has_conflict_output_type(
         return False
 
     return return_annotation != output_type
+
+
+def get_func_signature(
+    func,
+    input_types: typing.Union[None, type, Sequence[type]] = None,
+    output_type: Optional[type] = None,
+) -> udf_def.UdfSignature:
+    """Build the ``UdfSignature`` for ``func``.
+
+    Inspects ``func`` and applies the decorator-supplied ``input_types`` and
+    ``output_type`` on top of its annotations via ``resolve_signature``.
+    """
+    if sys.version_info >= (3, 10):
+        # Add `eval_str = True` so that deferred annotations are turned into their
+        # corresponding type objects. Need Python 3.10 for eval_str parameter.
+        # https://docs.python.org/3/library/inspect.html#inspect.signature
+        signature_kwargs: Mapping[str, Any] = {"eval_str": True}
+    else:
+        signature_kwargs = {}  # type: ignore
+
+    py_sig = resolve_signature(
+        inspect.signature(func, **signature_kwargs),
+        input_types,
+        output_type,
+    )
+    return udf_def.UdfSignature.from_py_signature(py_sig)
+
+
+def resolve_signature(
+    py_sig: inspect.Signature,
+    input_types: typing.Union[None, type, Sequence[type]] = None,
+    output_type: Optional[type] = None,
+) -> inspect.Signature:
+    """Return ``py_sig`` with the decorator-supplied types applied.
+
+    Warns with ``FunctionConflictTypeHintWarning`` when a supplied type is
+    reported as conflicting with an annotation; the supplied type is used.
+    """
+    if input_types is not None:
+        if not isinstance(input_types, collections.abc.Sequence):
+            input_types = [input_types]
+        if has_conflict_input_type(py_sig, input_types):
+            msg = bfe.format_message(
+                "Conflicting input types detected, using the one from the decorator."
+            )
+            warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
+        # NOTE: zip() keeps only as many parameters as there are input_types.
+        py_sig = py_sig.replace(
+            parameters=[
+                par.replace(annotation=itype)
+                for par, itype in zip(py_sig.parameters.values(), input_types)
+            ]
+        )
+    if output_type:
+        if has_conflict_output_type(py_sig, output_type):
+            msg = bfe.format_message(
+                "Conflicting return type detected, using the one from the decorator."
+            )
+            warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
+        py_sig = py_sig.replace(return_annotation=output_type)
+
+    return py_sig
diff --git a/packages/bigframes/bigframes/testing/polars_session.py b/packages/bigframes/bigframes/testing/polars_session.py
index d26ec63d9c0d..2806dab53f99 100644
--- a/packages/bigframes/bigframes/testing/polars_session.py
+++ b/packages/bigframes/bigframes/testing/polars_session.py
@@ -26,6 +26,7 @@
 import bigframes.session.execution_spec
 import bigframes.session.executor
 import bigframes.session.metrics
+from bigframes.functions import _utils, function, udf_def
 
 
 # Does not support to_sql, dry_run, peek, cached
@@ -111,6 +112,33 @@ def read_pandas(self, pandas_dataframe, write_engine="default"):
 
         return bf_df
 
+    def udf(
+        self,
+        *,
+        input_types=None,
+        output_type=None,
+        **kwargs,
+    ):
+        """Return a decorator that wraps a function as a ``function.UdfRoutine``.
+
+        The UDF definition is built from ``func`` directly; ``**kwargs`` is
+        accepted but ignored.
+        """
+        def wrapper(func):
+            udf_sig = _utils.get_func_signature(
+                func,
+                input_types,
+                output_type,
+            )
+
+            code_def = udf_def.CodeDef.from_func(func)
+            udf_definition = udf_def.PythonUdf(
+                signature=udf_sig,
+                code=code_def,
+            )
+            return function.UdfRoutine(func=func, _udf_def=udf_definition)
+
+        return wrapper
+
     @property
     def bqclient(self):
         # prevents logger from trying to call bq upon any errors
diff --git a/packages/bigframes/tests/system/small/engines/test_generic_ops.py b/packages/bigframes/tests/system/small/engines/test_generic_ops.py
index 05739a1c1b63..8484bce09951 100644
--- a/packages/bigframes/tests/system/small/engines/test_generic_ops.py
+++ b/packages/bigframes/tests/system/small/engines/test_generic_ops.py
@@ -300,6 +300,25 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng
     assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
 
 
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
+def test_engines_to_json_string(scalars_array_value: array_value.ArrayValue, engine):
+    exprs = [
+        ops.ToJSONString().as_expr(expression.deref("int64_col")),
+        ops.ToJSONString().as_expr(
+            # Use a const since float to json has precision issues
+            expression.const(5.2, bigframes.dtypes.FLOAT_DTYPE)
+        ),
+        ops.ToJSONString().as_expr(expression.deref("bool_col")),
+        ops.ToJSONString().as_expr(
+            # Use a const since "str_col" has special chars.
+            expression.const('"hello world"', bigframes.dtypes.STRING_DTYPE)
+        ),
+    ]
+    arr, _ = scalars_array_value.compute_values(exprs)
+
+    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
+
+
 @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine):
     arr = apply_op(
diff --git a/packages/bigframes/tests/unit/test_dataframe_polars.py b/packages/bigframes/tests/unit/test_dataframe_polars.py
index 190280e0a745..c2dc979b71ef 100644
--- a/packages/bigframes/tests/unit/test_dataframe_polars.py
+++ b/packages/bigframes/tests/unit/test_dataframe_polars.py
@@ -1287,6 +1287,49 @@ def test_apply_series_scalar_callable(
     pandas.testing.assert_series_equal(bf_result, pd_result)
 
 
+def test_df_map_with_udf(session):
+    df = bpd.DataFrame({"x": [1, 2, None, 4], "y": [5, None, 7, 8]}, dtype="Int64")
+
+    @session.udf()
+    def foo(row: pd.Series) -> int:
+        if pd.isna(row["x"]) or pd.isna(row["y"]):
+            return -1
+        return int(row["x"] * row["y"])
+
+    bf_result = df.apply(foo, axis=1).to_pandas()
+    pd_result = pd.Series([5, -1, -1, 32])
+    assert_series_equal(bf_result, pd_result, check_dtype=False)
+
+
+def test_df_apply_complex_udf(session):
+    df = bpd.DataFrame(
+        {"x": [1, 2, 3], "y": ["a", "b", "c"]},
+        index=["row0", "row1", "row2"],
+    )
+
+    @session.udf()
+    def foo(row: pd.Series) -> str:
+        idx = str(row.name)
+        items_str = ";".join(f"{k}={v}" for k, v in row.items())
+        return f"({idx}) -> {items_str}"
+
+    bf_result = df.apply(foo, axis=1).to_pandas()
+
+    pd_df = pd.DataFrame(
+        {"x": [1, 2, 3], "y": ["a", "b", "c"]},
+        index=["row0", "row1", "row2"],
+    )
+
+    def pd_foo(row):
+        idx = str(row.name)
+        items_str = ";".join(f"{k}={v}" for k, v in row.items())
+        return f"({idx}) -> {items_str}"
+
+    pd_result = pd_df.apply(pd_foo, axis=1)
+
+    assert_series_equal(bf_result, pd_result, check_dtype=False, check_index_type=False)
+
+
 def test_df_pipe(
     scalars_df_index,
     scalars_pandas_df_index,
diff --git a/packages/bigframes/tests/unit/test_series_polars.py b/packages/bigframes/tests/unit/test_series_polars.py
index 2e22d6ed4b6b..8b6d97d8b4b3 100644
--- a/packages/bigframes/tests/unit/test_series_polars.py
+++ b/packages/bigframes/tests/unit/test_series_polars.py
@@ -4561,6 +4561,20 @@ def test_map_series_input_duplicates_error(scalars_dfs):
         scalars_df.int64_too.map(bf_map_series, verify_integrity=True)
 
 
+def test_series_map_with_udf(session):
+    series = bpd.Series([1, 2, None, 4], dtype="Int64")
+
+    @session.udf(input_types=[int], output_type=int)
+    def foo(x):
+        if x is None:
+            return -1
+        return x * 2
+
+    bf_result = series.map(foo).to_pandas()
+    pd_result = pd.Series([2, 4, -1, 8])
+    assert_series_equal(bf_result, pd_result, check_dtype=False)
+
+
 @pytest.mark.skip(
     reason="NotImplementedError: Polars compiler hasn't implemented hash()"
 )