diff --git a/setup.py b/setup.py index edf3fca906..27571f9f39 100644 --- a/setup.py +++ b/setup.py @@ -69,7 +69,8 @@ "pyyaml>=5.3.1,<7", ] datasets_extra_require = [ - "pyarrow >= 3.0.0, < 8.0.0; python_version<'3.11'", + "pyarrow >= 3.0.0, < 8.0.0; python_version<'3.10'", + "pyarrow >= 10.0.1; python_version=='3.10'", "pyarrow >= 10.0.1; python_version=='3.11'", "pyarrow >= 14.0.0; python_version>='3.12'", ] diff --git a/tests/unit/vertexai/genai/replays/conftest.py b/tests/unit/vertexai/genai/replays/conftest.py index 7b7e41d0ce..5bb9c93f84 100644 --- a/tests/unit/vertexai/genai/replays/conftest.py +++ b/tests/unit/vertexai/genai/replays/conftest.py @@ -123,6 +123,11 @@ def replays_prefix(): return "test" +@pytest.fixture +def is_replay_mode(request): + return request.config.getoption("--mode") in ["replay", "tap"] + + @pytest.fixture def mock_agent_engine_create_path_exists(): """Mocks os.path.exists to return True.""" diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py index 0911fa1b43..245be86f88 100644 --- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py @@ -14,7 +14,9 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring +import sys from unittest import mock + from google.cloud import bigquery from tests.unit.vertexai.genai.replays import pytest_helper from vertexai._genai import _datasets_utils @@ -22,17 +24,13 @@ import pandas as pd import pytest + METADATA_SCHEMA_URI = ( "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" ) BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" -@pytest.fixture -def is_replay_mode(request): - return request.config.getoption("--mode") in ["replay", "tap"] - - @pytest.fixture def mock_bigquery_client(is_replay_mode): if is_replay_mode: @@ -161,6 +159,52 @@ def test_create_dataset_from_pandas(client, is_replay_mode): pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe) +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher" +) +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +def test_create_dataset_from_bigframes(client, is_replay_mode): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = client.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + if not is_replay_mode: + bigquery_client = bigquery.Client( + project=client._api_client.project, + location=client._api_client.location, + credentials=client._api_client._credentials, + ) + rows = bigquery_client.list_rows( + dataset.metadata.input_config.bigquery_source.uri[5:] + ) + pd.testing.assert_frame_equal( + rows.to_dataframe(), dataframe, check_index_type=False + ) + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), @@ -279,3 +323,50 @@ async def test_create_dataset_from_pandas_async(client, is_replay_mode): dataset.metadata.input_config.bigquery_source.uri[5:] ) pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher" +) +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +async def test_create_dataset_from_bigframes_async(client, is_replay_mode): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = await client.aio.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + if not is_replay_mode: + bigquery_client = bigquery.Client( + project=client._api_client.project, + location=client._api_client.location, + credentials=client._api_client._credentials, + ) + rows = bigquery_client.list_rows( + dataset.metadata.input_config.bigquery_source.uri[5:] + ) + pd.testing.assert_frame_equal( + rows.to_dataframe(), dataframe, check_index_type=False + ) diff --git a/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py index f5facfe68f..93508e8287 100644 --- a/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py @@ -15,14 +15,34 @@ # pylint: disable=protected-access,bad-continuation,missing-function-docstring from tests.unit.vertexai.genai.replays import pytest_helper +from vertexai._genai import _datasets_utils from vertexai._genai import types +from unittest import mock import pytest BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" DATASET = "8810841321427173376" +@pytest.fixture +def mock_import_bigframes(is_replay_mode): + if is_replay_mode: + with mock.patch.object( + _datasets_utils, "_try_import_bigframes" + ) as mock_import_bigframes: + mock_read_gbq_table_result = mock.MagicMock() + mock_read_gbq_table_result.sql = f"SLECT * FROM `{BIGQUERY_TABLE_NAME}`" + + bigframes = mock.MagicMock() + bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result + + mock_import_bigframes.return_value = bigframes + yield mock_import_bigframes + else: + yield None + + def test_get_dataset(client): dataset = client.datasets._get_multimodal_dataset( name=DATASET, @@ -41,6 +61,15 @@ def test_get_dataset_from_public_method(client): assert dataset.display_name == "test-display-name" +@pytest.mark.usefixtures("mock_import_bigframes") +def test_to_bigframes(client): + dataset = client.datasets.get_multimodal_dataset( + name=DATASET, + ) + df = client.datasets.to_bigframes(multimodal_dataset=dataset) + assert BIGQUERY_TABLE_NAME in df.sql + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), @@ -67,3 +96,13 @@ async def test_get_dataset_from_public_method_async(client): assert isinstance(dataset, types.MultimodalDataset) assert dataset.name.endswith(DATASET) assert dataset.display_name == "test-display-name" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_import_bigframes") +async def test_to_bigframes_async(client): + dataset = await client.aio.datasets.get_multimodal_dataset( + name=DATASET, + ) + df = await client.aio.datasets.to_bigframes(multimodal_dataset=dataset) + assert BIGQUERY_TABLE_NAME in df.sql diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 1b473c87bd..f5c96f0331 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -899,6 +899,112 @@ def create_from_pandas( config=config, ) + def create_from_bigframes( + self, + *, + dataframe: "bigframes.pandas.DataFrame", # type: ignore # noqa: F821 + multimodal_dataset: types.MultimodalDatasetOrDict, + target_table_id: Optional[str] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Creates a multimodal dataset from a bigframes dataframe. + + Args: + dataframe (bigframes.pandas.DataFrame): + The BigFrames dataframe that will be used for the created + dataset. + multimodal_dataset: + Required. A representation of a multimodal dataset. + target_table_id (str): + Optional. The BigQuery table id where the dataframe will be + uploaded. The table id can be in the format of "dataset.table" + or "project.dataset.table". If a table already exists with the + given table id, it will be overwritten. Note that the BigQuery + dataset must already exist and be in the same location as the + multimodal dataset. If not provided, a generated table id will + be created in the `vertex_datasets` dataset (e.g. + `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`). + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + dataset (MultimodalDataset): + The created multimodal dataset. + """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + bigquery = _datasets_utils._try_import_bigquery() + project = self._api_client.project + location = self._api_client.location + credentials = self._api_client._credentials + + if target_table_id: + target_table_id = _datasets_utils._normalize_and_validate_table_id( + table_id=target_table_id, + project=project, + location=location, + credentials=credentials, + ) + else: + dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists( + project=project, location=location, credentials=credentials + ) + target_table_id = _datasets_utils._generate_target_table_id(dataset_id) + + temp_table_id = dataframe.to_gbq() + client = bigquery.Client(project=project, credentials=credentials) + copy_job = client.copy_table( + sources=temp_table_id, + destination=target_table_id, + ) + copy_job.result() + + return self.create_from_bigquery( + multimodal_dataset=multimodal_dataset.model_copy( + update={ + "metadata": types.SchemaTablesDatasetMetadata( + input_config=types.SchemaTablesDatasetMetadataInputConfig( + bigquery_source=types.SchemaTablesDatasetMetadataBigQuerySource( + uri=f"bq://{target_table_id}" + ) + ) + ) + } + ), + config=config, + ) + + def to_bigframes( + self, + *, + multimodal_dataset: types.MultimodalDatasetOrDict, + ) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821 + """Converts a multimodal dataset to a BigFrames dataframe. + + This is the preferred method to inspect the multimodal dataset in a + notebook. + + Args: + multimodal_dataset: + Required. A representation of a multimodal dataset. + + Returns: + A BigFrames dataframe. + """ + bigframes = _datasets_utils._try_import_bigframes() + + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) + return bigframes.pandas.read_gbq_table(uri.lstrip("bq://")) + def update_multimodal_dataset( self, *, @@ -1948,6 +2054,112 @@ async def create_from_pandas( config=config, ) + async def create_from_bigframes( + self, + *, + dataframe: "bigframes.pandas.DataFrame", # type: ignore # noqa: F821 + multimodal_dataset: types.MultimodalDatasetOrDict, + target_table_id: Optional[str] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Creates a multimodal dataset from a bigframes dataframe. + + Args: + dataframe (bigframes.pandas.DataFrame): + The BigFrames dataframe that will be used for the created + dataset. + multimodal_dataset: + Required. A representation of a multimodal dataset. + target_table_id (str): + Optional. The BigQuery table id where the dataframe will be + uploaded. The table id can be in the format of "dataset.table" + or "project.dataset.table". If a table already exists with the + given table id, it will be overwritten. Note that the BigQuery + dataset must already exist and be in the same location as the + multimodal dataset. If not provided, a generated table id will + be created in the `vertex_datasets` dataset (e.g. + `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`). + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + dataset (MultimodalDataset): + The created multimodal dataset. + """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + bigquery = _datasets_utils._try_import_bigquery() + project = self._api_client.project + location = self._api_client.location + credentials = self._api_client._credentials + + if target_table_id: + target_table_id = _datasets_utils._normalize_and_validate_table_id( + table_id=target_table_id, + project=project, + location=location, + credentials=credentials, + ) + else: + dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists( + project=project, location=location, credentials=credentials + ) + target_table_id = _datasets_utils._generate_target_table_id(dataset_id) + + temp_table_id = dataframe.to_gbq() + client = bigquery.Client(project=project, credentials=credentials) + copy_job = client.copy_table( + sources=temp_table_id, + destination=target_table_id, + ) + copy_job.result() + + return await self.create_from_bigquery( + multimodal_dataset=multimodal_dataset.model_copy( + update={ + "metadata": types.SchemaTablesDatasetMetadata( + input_config=types.SchemaTablesDatasetMetadataInputConfig( + bigquery_source=types.SchemaTablesDatasetMetadataBigQuerySource( + uri=f"bq://{target_table_id}" + ) + ) + ) + } + ), + config=config, + ) + + async def to_bigframes( + self, + *, + multimodal_dataset: types.MultimodalDatasetOrDict, + ) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821 + """Converts a multimodal dataset to a BigFrames dataframe. + + This is the preferred method to inspect the multimodal dataset in a + notebook. + + Args: + multimodal_dataset: + Required. A representation of a multimodal dataset. + + Returns: + A BigFrames dataframe. + """ + bigframes = _datasets_utils._try_import_bigframes() + + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) + return bigframes.pandas.read_gbq_table(uri.lstrip("bq://")) + async def update_multimodal_dataset( self, *,