From 0393bd027f9f7462dbe91a9376d0db184d522a39 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 11 Mar 2026 22:18:42 +0000 Subject: [PATCH 1/2] feat: use EUC for AI IF, CLASSIFY, and SCORE when connection is not provided --- bigframes/bigquery/_operations/ai.py | 12 ++++++------ .../core/compile/sqlglot/expressions/ai_ops.py | 6 +++--- bigframes/operations/ai_ops.py | 6 +++--- .../test_ai_ops/test_ai_classify/None/out.sql | 3 +++ .../out.sql | 7 +++++++ .../snapshots/test_ai_ops/test_ai_if/None/out.sql | 3 +++ .../out.sql | 6 ++++++ .../test_ai_ops/test_ai_score/None/out.sql | 3 +++ .../out.sql | 6 ++++++ .../compile/sqlglot/expressions/test_ai_ops.py | 15 +++++++++------ .../ibis/expr/operations/ai_ops.py | 6 +++--- 11 files changed, 52 insertions(+), 21 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index bc28cb2e353..e578f4be4a7 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -745,7 +745,7 @@ def if_( or pandas Series. connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. - If not provided, the connection from the current session will be used. + If not provided, the query uses your end-user credential. Returns: bigframes.series.Series: A new series of bools. @@ -756,7 +756,7 @@ def if_( operator = ai_ops.AIIf( prompt_context=tuple(prompt_context), - connection_id=_resolve_connection_id(series_list[0], connection_id), + connection_id=connection_id, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) @@ -800,7 +800,7 @@ def classify( Categories to classify the input into. connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. - If not provided, the connection from the current session will be used. + If not provided, the query uses your end-user credential. Returns: bigframes.series.Series: A new series of strings. @@ -812,7 +812,7 @@ def classify( operator = ai_ops.AIClassify( prompt_context=tuple(prompt_context), categories=tuple(categories), - connection_id=_resolve_connection_id(series_list[0], connection_id), + connection_id=connection_id, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) @@ -853,7 +853,7 @@ def score( or pandas Series. connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. - If not provided, the connection from the current session will be used. + If not provided, the query uses your end-user credential. Returns: bigframes.series.Series: A new series of double (float) values. @@ -864,7 +864,7 @@ def score( operator = ai_ops.AIScore( prompt_context=tuple(prompt_context), - connection_id=_resolve_connection_id(series_list[0], connection_id), + connection_id=connection_id, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index cc0cbaad8fe..df659097b33 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -113,9 +113,9 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: ) ) - endpoit = op_args.get("endpoint", None) - if endpoit is not None: - args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit))) + endpoint = op_args.get("endpoint", None) + if endpoint is not None: + args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint))) request_type = op_args.get("request_type", None) if request_type is not None: diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index 8dc8c2ffab4..b20314fe232 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -123,7 +123,7 @@ class AIIf(base_ops.NaryOp): name: ClassVar[str] = "ai_if" prompt_context: Tuple[str | None, ...] - connection_id: str + connection_id: str | None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.BOOL_DTYPE @@ -135,7 +135,7 @@ class AIClassify(base_ops.NaryOp): prompt_context: Tuple[str | None, ...] categories: tuple[str, ...] - connection_id: str + connection_id: str | None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.STRING_DTYPE @@ -146,7 +146,7 @@ class AIScore(base_ops.NaryOp): name: ClassVar[str] = "ai_score" prompt_context: Tuple[str | None, ...] - connection_id: str + connection_id: str | None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.FLOAT_DTYPE diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql new file mode 100644 index 00000000000..6771527318f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.CLASSIFY(input => (`string_col`), categories => ['greeting', 'rejection']) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql new file mode 100644 index 00000000000..63c31d94566 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql @@ -0,0 +1,7 @@ +SELECT + AI.CLASSIFY( + input => (`string_col`), + categories => ['greeting', 'rejection'], + connection_id => 'bigframes-dev.us.bigframes-default-connection' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql new file mode 100644 index 00000000000..bae091982ea --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql new file mode 100644 index 00000000000..698523d2e0b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql @@ -0,0 +1,6 @@ +SELECT + AI.IF( + prompt => (`string_col`, ' is the same as ', `string_col`), + connection_id => 'bigframes-dev.us.bigframes-default-connection' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql new file mode 100644 index 00000000000..6a16276734e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.SCORE(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql new file mode 100644 index 00000000000..92de7cdcdc6 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql @@ -0,0 +1,6 @@ +SELECT + AI.SCORE( + prompt => (`string_col`, ' is the same as ', `string_col`), + connection_id => 'bigframes-dev.us.bigframes-default-connection' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index c0cbece9054..64a5a94c9e7 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -281,12 +281,13 @@ def test_ai_generate_double_with_model_param( snapshot.assert_match(sql, "out.sql") -def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot): +@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) +def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" op = ops.AIIf( prompt_context=(None, " is the same as ", None), - connection_id=CONNECTION_ID, + connection_id=connection_id, ) sql = utils._apply_ops_to_sql( @@ -296,13 +297,14 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot): +@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) +def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" op = ops.AIClassify( prompt_context=(None,), categories=("greeting", "rejection"), - connection_id=CONNECTION_ID, + connection_id=connection_id, ) sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) @@ -310,12 +312,13 @@ def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot): +@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) +def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" op = ops.AIScore( prompt_context=(None, " is the same as ", None), - connection_id=CONNECTION_ID, + connection_id=connection_id, ) sql = utils._apply_ops_to_sql( diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index ef387d33792..fe978e14536 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -113,7 +113,7 @@ class AIIf(Value): """Generate True/False based on the prompt""" prompt: Value - connection_id: Value[dt.String] + connection_id: Optional[Value[dt.String]] shape = rlz.shape_like("prompt") @@ -128,7 +128,7 @@ class AIClassify(Value): input: Value categories: Value[dt.Array[dt.String]] - connection_id: Value[dt.String] + connection_id: Optional[Value[dt.String]] shape = rlz.shape_like("input") @@ -142,7 +142,7 @@ class AIScore(Value): """Generate doubles based on the prompt""" prompt: Value - connection_id: Value[dt.String] + connection_id: Optional[Value[dt.String]] shape = rlz.shape_like("prompt") From 34bf269013825f3d6dc63499e536885e1d3d1c2a Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Thu, 12 Mar 2026 00:54:43 +0000 Subject: [PATCH 2/2] remove unused files --- .../snapshots/test_ai_ops/test_ai_classify/out.sql | 7 ------- .../expressions/snapshots/test_ai_ops/test_ai_if/out.sql | 6 ------ .../snapshots/test_ai_ops/test_ai_score/out.sql | 6 ------ 3 files changed, 19 deletions(-) delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql deleted file mode 100644 index 63c31d94566..00000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql +++ /dev/null @@ -1,7 +0,0 @@ -SELECT - AI.CLASSIFY( - input => (`string_col`), - categories => ['greeting', 'rejection'], - connection_id => 'bigframes-dev.us.bigframes-default-connection' - ) AS `result` -FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql deleted file mode 100644 index 698523d2e0b..00000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql +++ /dev/null @@ -1,6 +0,0 @@ -SELECT - AI.IF( - prompt => (`string_col`, ' is the same as ', `string_col`), - connection_id => 'bigframes-dev.us.bigframes-default-connection' - ) AS `result` -FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql deleted file mode 100644 index 92de7cdcdc6..00000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql +++ /dev/null @@ -1,6 +0,0 @@ -SELECT - AI.SCORE( - prompt => (`string_col`, ' is the same as ', `string_col`), - connection_id => 'bigframes-dev.us.bigframes-default-connection' - ) AS `result` -FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file