|
15 | 15 | # pylint: disable=protected-access,bad-continuation,missing-function-docstring |
16 | 16 |
|
17 | 17 | from unittest import mock |
| 18 | + |
| 19 | +import bigframes.pandas |
18 | 20 | from google.cloud import bigquery |
19 | 21 | from tests.unit.vertexai.genai.replays import pytest_helper |
20 | 22 | from vertexai._genai import _datasets_utils |
|
28 | 30 | BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" |
29 | 31 |
|
30 | 32 |
|
31 | | -@pytest.fixture |
32 | | -def is_replay_mode(request): |
33 | | - return request.config.getoption("--mode") in ["replay", "tap"] |
34 | | - |
35 | | - |
36 | 33 | @pytest.fixture |
37 | 34 | def mock_bigquery_client(is_replay_mode): |
38 | 35 | if is_replay_mode: |
@@ -161,6 +158,47 @@ def test_create_dataset_from_pandas(client, is_replay_mode): |
161 | 158 | pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe) |
162 | 159 |
|
163 | 160 |
|
| 161 | +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") |
| 162 | +def test_create_dataset_from_bigframes(client, is_replay_mode): |
| 163 | + dataframe = pd.DataFrame( |
| 164 | + { |
| 165 | + "col1": ["col1"], |
| 166 | + "col2": ["col2"], |
| 167 | + } |
| 168 | + ) |
| 169 | + if is_replay_mode: |
| 170 | + bf_dataframe = mock.MagicMock() |
| 171 | + bf_dataframe.to_gbq.return_value = "temp_table_id" |
| 172 | + else: |
| 173 | + bf_dataframe = bigframes.pandas.DataFrame(dataframe) |
| 174 | + |
| 175 | + dataset = client.datasets.create_from_bigframes( |
| 176 | + dataframe=bf_dataframe, |
| 177 | + target_table_id=BIGQUERY_TABLE_NAME, |
| 178 | + multimodal_dataset={ |
| 179 | + "display_name": "test-from-bigframes", |
| 180 | + }, |
| 181 | + ) |
| 182 | + |
| 183 | + assert isinstance(dataset, types.MultimodalDataset) |
| 184 | + assert dataset.display_name == "test-from-bigframes" |
| 185 | + assert dataset.metadata.input_config.bigquery_source.uri == ( |
| 186 | + f"bq://{BIGQUERY_TABLE_NAME}" |
| 187 | + ) |
| 188 | + if not is_replay_mode: |
| 189 | + bigquery_client = bigquery.Client( |
| 190 | + project=client._api_client.project, |
| 191 | + location=client._api_client.location, |
| 192 | + credentials=client._api_client._credentials, |
| 193 | + ) |
| 194 | + rows = bigquery_client.list_rows( |
| 195 | + dataset.metadata.input_config.bigquery_source.uri[5:] |
| 196 | + ) |
| 197 | + pd.testing.assert_frame_equal( |
| 198 | + rows.to_dataframe(), dataframe, check_index_type=False |
| 199 | + ) |
| 200 | + |
| 201 | + |
164 | 202 | pytestmark = pytest_helper.setup( |
165 | 203 | file=__file__, |
166 | 204 | globals_for_file=globals(), |
@@ -279,3 +317,45 @@ async def test_create_dataset_from_pandas_async(client, is_replay_mode): |
279 | 317 | dataset.metadata.input_config.bigquery_source.uri[5:] |
280 | 318 | ) |
281 | 319 | pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe) |
| 320 | + |
| 321 | + |
| 322 | +@pytest.mark.asyncio |
| 323 | +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") |
| 324 | +async def test_create_dataset_from_bigframes_async(client, is_replay_mode): |
| 325 | + dataframe = pd.DataFrame( |
| 326 | + { |
| 327 | + "col1": ["col1"], |
| 328 | + "col2": ["col2"], |
| 329 | + } |
| 330 | + ) |
| 331 | + if is_replay_mode: |
| 332 | + bf_dataframe = mock.MagicMock() |
| 333 | + bf_dataframe.to_gbq.return_value = "temp_table_id" |
| 334 | + else: |
| 335 | + bf_dataframe = bigframes.pandas.DataFrame(dataframe) |
| 336 | + |
| 337 | + dataset = await client.aio.datasets.create_from_bigframes( |
| 338 | + dataframe=bf_dataframe, |
| 339 | + target_table_id=BIGQUERY_TABLE_NAME, |
| 340 | + multimodal_dataset={ |
| 341 | + "display_name": "test-from-bigframes", |
| 342 | + }, |
| 343 | + ) |
| 344 | + |
| 345 | + assert isinstance(dataset, types.MultimodalDataset) |
| 346 | + assert dataset.display_name == "test-from-bigframes" |
| 347 | + assert dataset.metadata.input_config.bigquery_source.uri == ( |
| 348 | + f"bq://{BIGQUERY_TABLE_NAME}" |
| 349 | + ) |
| 350 | + if not is_replay_mode: |
| 351 | + bigquery_client = bigquery.Client( |
| 352 | + project=client._api_client.project, |
| 353 | + location=client._api_client.location, |
| 354 | + credentials=client._api_client._credentials, |
| 355 | + ) |
| 356 | + rows = bigquery_client.list_rows( |
| 357 | + dataset.metadata.input_config.bigquery_source.uri[5:] |
| 358 | + ) |
| 359 | + pd.testing.assert_frame_equal( |
| 360 | + rows.to_dataframe(), dataframe, check_index_type=False |
| 361 | + ) |
0 commit comments