diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 0fd8a010ae..2f486fc9d5 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -152,13 +152,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Coalesce(this=left.expr, expressions=[right.expr]) -@register_unary_op(ops.RemoteFunctionOp, pass_op=True) -def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression: +def _get_remote_function_name(op): routine_ref = op.function_def.routine_ref # Quote project, dataset, and routine IDs to avoid keyword clashes. - func_name = ( + return ( f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`" ) + + +@register_unary_op(ops.RemoteFunctionOp, pass_op=True) +def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression: + func_name = _get_remote_function_name(op) func = sge.func(func_name, expr.expr) if not op.apply_on_null: @@ -175,15 +179,16 @@ def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression: def _( left: TypedExpr, right: TypedExpr, op: ops.BinaryRemoteFunctionOp ) -> sge.Expression: - routine_ref = op.function_def.routine_ref - # Quote project, dataset, and routine IDs to avoid keyword clashes. - func_name = ( - f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`" - ) - + func_name = _get_remote_function_name(op) return sge.func(func_name, left.expr, right.expr) +@register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True) +def _(*operands: TypedExpr, op: ops.NaryRemoteFunctionOp) -> sge.Expression: + func_name = _get_remote_function_name(op) + return sge.func(func_name, *(operand.expr for operand in operands)) + + @register_nary_op(ops.case_when_op) def _(*cases_and_outputs: TypedExpr) -> sge.Expression: # Need to upcast BOOL to INT if any output is numeric diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_nary_remote_function_op/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_nary_remote_function_op/out.sql new file mode 100644 index 0000000000..a6641b13db --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_nary_remote_function_op/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `float64_col`, + `int64_col`, + `string_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `my_project`.`my_dataset`.`my_routine`(`int64_col`, `float64_col`, `string_col`) AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_3` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index c4f16d93a1..2667e482c8 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud import bigquery import pandas as pd import pytest from bigframes import dtypes from bigframes import operations as ops from bigframes.core import expression as ex +from bigframes.functions import udf_def import bigframes.pandas as bpd from bigframes.testing import utils @@ -170,10 +172,6 @@ def test_astype_json_invalid( def test_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot): - from google.cloud import bigquery - - from bigframes.functions import udf_def - bf_df = scalar_types_df[["int64_col"]] function_def = udf_def.BigqueryUdf( routine_ref=bigquery.RoutineReference.from_string( @@ -206,10 +204,6 @@ def test_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot): def test_binary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot): - from google.cloud import bigquery - - from bigframes.functions import udf_def - bf_df = scalar_types_df[["int64_col", "float64_col"]] op = ops.BinaryRemoteFunctionOp( function_def=udf_def.BigqueryUdf( @@ -242,6 +236,44 @@ def test_binary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_nary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "float64_col", "string_col"]] + op = ops.NaryRemoteFunctionOp( + function_def=udf_def.BigqueryUdf( + routine_ref=bigquery.RoutineReference.from_string( + "my_project.my_dataset.my_routine" + ), + signature=udf_def.UdfSignature( + input_types=( + udf_def.UdfField( + "x", + bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 + ), + ), + udf_def.UdfField( + "y", + bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.FLOAT64 + ), + ), + udf_def.UdfField( + "z", + bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.STRING + ), + ), + ), + output_bq_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.FLOAT64 + ), + ), + ) + ) + sql = utils._apply_nary_op(bf_df, op, "int64_col", "float64_col", "string_col") + snapshot.assert_match(sql, "out.sql") + + def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot): ops_map = { "single_case": ops.case_when_op.as_expr(