Skip to content

Commit 87ffc1f

Browse files
cleop-google authored and copybara-github committed
feat: GenAI SDK client(multimodal) - Support creating multimodal dataset from bigframe DataFrame
PiperOrigin-RevId: 882664740
1 parent 14c298a commit 87ffc1f

5 files changed

Lines changed: 343 additions & 6 deletions

File tree

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@
6969
"pyyaml>=5.3.1,<7",
7070
]
7171
datasets_extra_require = [
72-
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.11'",
72+
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.10'",
73+
"pyarrow >= 10.0.1; python_version=='3.10'",
7374
"pyarrow >= 10.0.1; python_version=='3.11'",
7475
"pyarrow >= 14.0.0; python_version>='3.12'",
7576
]

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ def replays_prefix():
123123
return "test"
124124

125125

126+
@pytest.fixture
def is_replay_mode(request):
    """Return True when the ``--mode`` pytest option is "replay" or "tap"."""
    selected_mode = request.config.getoption("--mode")
    return selected_mode in ("replay", "tap")
129+
130+
126131
@pytest.fixture
127132
def mock_agent_engine_create_path_exists():
128133
"""Mocks os.path.exists to return True."""

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

1717
from unittest import mock
18+
19+
import bigframes.pandas
1820
from google.cloud import bigquery
1921
from tests.unit.vertexai.genai.replays import pytest_helper
2022
from vertexai._genai import _datasets_utils
@@ -28,11 +30,6 @@
2830
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2931

3032

31-
@pytest.fixture
32-
def is_replay_mode(request):
33-
return request.config.getoption("--mode") in ["replay", "tap"]
34-
35-
3633
@pytest.fixture
3734
def mock_bigquery_client(is_replay_mode):
3835
if is_replay_mode:
@@ -161,6 +158,47 @@ def test_create_dataset_from_pandas(client, is_replay_mode):
161158
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)
162159

163160

161+
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_bigframes(client, is_replay_mode):
    """Create a multimodal dataset from a BigFrames DataFrame and verify it.

    In replay mode the BigFrames DataFrame is a ``MagicMock`` whose ``to_gbq``
    returns a fixed table id. Otherwise a real BigFrames DataFrame is built
    and the rows of the resulting BigQuery table are compared to the source
    pandas DataFrame.
    """
    expected_df = pd.DataFrame({"col1": ["col1"], "col2": ["col2"]})

    if not is_replay_mode:
        bf_dataframe = bigframes.pandas.DataFrame(expected_df)
    else:
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"

    dataset = client.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={"display_name": "test-from-bigframes"},
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri

    if is_replay_mode:
        # Nothing was written to BigQuery, so there are no rows to compare.
        return

    api_client = client._api_client
    bigquery_client = bigquery.Client(
        project=api_client.project,
        location=api_client.location,
        credentials=api_client._credentials,
    )
    # uri[5:] drops the leading "bq://" to obtain the table id.
    table_id = dataset.metadata.input_config.bigquery_source.uri[5:]
    rows = bigquery_client.list_rows(table_id)
    pd.testing.assert_frame_equal(
        rows.to_dataframe(), expected_df, check_index_type=False
    )
200+
201+
164202
pytestmark = pytest_helper.setup(
165203
file=__file__,
166204
globals_for_file=globals(),
@@ -279,3 +317,45 @@ async def test_create_dataset_from_pandas_async(client, is_replay_mode):
279317
dataset.metadata.input_config.bigquery_source.uri[5:]
280318
)
281319
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)
320+
321+
322+
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_dataset_from_bigframes_async(client, is_replay_mode):
    """Async variant: create a multimodal dataset from a BigFrames DataFrame.

    In replay mode the BigFrames DataFrame is a ``MagicMock`` whose ``to_gbq``
    returns a fixed table id. Otherwise a real BigFrames DataFrame is built
    and the rows of the resulting BigQuery table are compared to the source
    pandas DataFrame.
    """
    expected_df = pd.DataFrame({"col1": ["col1"], "col2": ["col2"]})

    if not is_replay_mode:
        bf_dataframe = bigframes.pandas.DataFrame(expected_df)
    else:
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"

    dataset = await client.aio.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={"display_name": "test-from-bigframes"},
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri

    if is_replay_mode:
        # Nothing was written to BigQuery, so there are no rows to compare.
        return

    api_client = client._api_client
    bigquery_client = bigquery.Client(
        project=api_client.project,
        location=api_client.location,
        credentials=api_client._credentials,
    )
    # uri[5:] drops the leading "bq://" to obtain the table id.
    table_id = dataset.metadata.input_config.bigquery_source.uri[5:]
    rows = bigquery_client.list_rows(table_id)
    pd.testing.assert_frame_equal(
        rows.to_dataframe(), expected_df, check_index_type=False
    )

tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,34 @@
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

1717
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import _datasets_utils
1819
from vertexai._genai import types
1920

21+
from unittest import mock
2022
import pytest
2123

2224
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2325
DATASET = "8810841321427173376"
2426

2527

28+
@pytest.fixture
def mock_import_bigframes(is_replay_mode):
    """Patch the BigFrames import helper when running in replay mode.

    In replay mode, ``_datasets_utils._try_import_bigframes`` is patched to
    return a ``MagicMock`` standing in for the ``bigframes`` module. Its
    ``pandas.read_gbq_table`` returns an object whose ``sql`` attribute is a
    query string referencing ``BIGQUERY_TABLE_NAME``.

    Yields:
        The patch mock for ``_try_import_bigframes`` in replay mode, otherwise
        ``None`` (nothing is patched).
    """
    if not is_replay_mode:
        yield None
        return

    with mock.patch.object(
        _datasets_utils, "_try_import_bigframes"
    ) as import_patch:
        read_gbq_table_result = mock.MagicMock()
        # Tests only check that the table name appears in ``sql``; keep the
        # statement well-formed anyway so it reads like a real query.
        read_gbq_table_result.sql = f"SELECT * FROM `{BIGQUERY_TABLE_NAME}`"

        # Named ``fake_bigframes`` so it is not mistaken for the real module.
        fake_bigframes = mock.MagicMock()
        fake_bigframes.pandas.read_gbq_table.return_value = read_gbq_table_result

        import_patch.return_value = fake_bigframes
        yield import_patch
44+
45+
2646
def test_get_dataset(client):
2747
dataset = client.datasets._get_multimodal_dataset(
2848
name=DATASET,
@@ -41,6 +61,15 @@ def test_get_dataset_from_public_method(client):
4161
assert dataset.display_name == "test-display-name"
4262

4363

64+
@pytest.mark.usefixtures("mock_import_bigframes")
def test_to_bigframes(client):
    """Fetch the dataset and check the BigFrames result references its table."""
    fetched_dataset = client.datasets.get_multimodal_dataset(name=DATASET)
    bf_frame = client.datasets.to_bigframes(multimodal_dataset=fetched_dataset)
    assert BIGQUERY_TABLE_NAME in bf_frame.sql
71+
72+
4473
pytestmark = pytest_helper.setup(
4574
file=__file__,
4675
globals_for_file=globals(),
@@ -67,3 +96,13 @@ async def test_get_dataset_from_public_method_async(client):
6796
assert isinstance(dataset, types.MultimodalDataset)
6897
assert dataset.name.endswith(DATASET)
6998
assert dataset.display_name == "test-display-name"
99+
100+
101+
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_import_bigframes")
async def test_to_bigframes_async(client):
    """Async variant: check the BigFrames result references the dataset table."""
    fetched_dataset = await client.aio.datasets.get_multimodal_dataset(name=DATASET)
    bf_frame = await client.aio.datasets.to_bigframes(
        multimodal_dataset=fetched_dataset
    )
    assert BIGQUERY_TABLE_NAME in bf_frame.sql

0 commit comments

Comments
 (0)