diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
index 99e00e792e..7479c217c3 100644
--- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
+++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
@@ -14,9 +14,12 @@
 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring
 
+from unittest import mock
+from google.cloud import bigquery
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
-
+from vertexai._genai import _datasets_utils
+import pandas as pd
 import pytest
 
 METADATA_SCHEMA_URI = (
@@ -25,6 +28,52 @@
 BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
 
 
+@pytest.fixture
+def is_replay_mode(request):
+    return request.config.getoption("--mode") in ["replay", "tap"]
+
+
+@pytest.fixture
+def mock_bigquery_client(is_replay_mode):
+    if is_replay_mode:
+        with mock.patch.object(
+            _datasets_utils, "_try_import_bigquery"
+        ) as mock_try_import_bigquery:
+            mock_dataset = mock.MagicMock()
+            mock_dataset.location = "us-central1"
+
+            mock_client = mock.MagicMock()
+            mock_client.get_dataset.return_value = mock_dataset
+
+            mock_try_import_bigquery.return_value.Client.return_value = mock_client
+            mock_try_import_bigquery.return_value.TableReference = (
+                bigquery.TableReference
+            )
+
+            yield mock_try_import_bigquery
+    else:
+        yield None
+
+
+@pytest.fixture
+def mock_import_bigframes(is_replay_mode):
+    if is_replay_mode:
+        with mock.patch.object(
+            _datasets_utils, "_try_import_bigframes"
+        ) as mock_import_bigframes:
+            session = mock.MagicMock()
+            session.read_pandas.return_value = mock.MagicMock()
+
+            bigframes = mock.MagicMock()
+            bigframes.connect.return_value = mock.MagicMock()
+
+            mock_import_bigframes.return_value = bigframes
+
+            yield mock_import_bigframes
+    else:
+        yield None
+
+
 def test_create_dataset(client):
     create_dataset_operation = client.datasets._create_multimodal_dataset(
         name="projects/vertex-sdk-dev/locations/us-central1",
@@ -78,6 +127,40 @@ def test_create_dataset_from_bigquery_without_bq_prefix(client):
     )
 
 
+@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
+def test_create_dataset_from_pandas(client, is_replay_mode):
+    dataframe = pd.DataFrame(
+        {
+            "col1": ["col1"],
+            "col2": ["col2"],
+        }
+    )
+
+    dataset = client.datasets.create_from_pandas(
+        dataframe=dataframe,
+        target_table_id=BIGQUERY_TABLE_NAME,
+        multimodal_dataset={
+            "display_name": "test-from-pandas",
+        },
+    )
+
+    assert isinstance(dataset, types.MultimodalDataset)
+    assert dataset.display_name == "test-from-pandas"
+    assert dataset.metadata.input_config.bigquery_source.uri == (
+        f"bq://{BIGQUERY_TABLE_NAME}"
+    )
+    if not is_replay_mode:
+        bigquery_client = bigquery.Client(
+            project=client._api_client.project,
+            location=client._api_client.location,
+            credentials=client._api_client._credentials,
+        )
+        rows = bigquery_client.list_rows(
+            dataset.metadata.input_config.bigquery_source.uri[5:]
+        )
+        pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),
@@ -161,3 +244,38 @@ async def test_create_dataset_from_bigquery_async_without_bq_prefix(client):
     assert dataset.metadata.input_config.bigquery_source.uri == (
         f"bq://{BIGQUERY_TABLE_NAME}"
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
+async def test_create_dataset_from_pandas_async(client, is_replay_mode):
+    dataframe = pd.DataFrame(
+        {
+            "col1": ["col1row1", "col1row2"],
+            "col2": ["col2row1", "col2row2"],
+        }
+    )
+
+    dataset = await client.aio.datasets.create_from_pandas(
+        dataframe=dataframe,
+        target_table_id=BIGQUERY_TABLE_NAME,
+        multimodal_dataset={
+            "display_name": "test-from-pandas",
+        },
+    )
+
+    assert isinstance(dataset, types.MultimodalDataset)
+    assert dataset.display_name == "test-from-pandas"
+    assert dataset.metadata.input_config.bigquery_source.uri == (
+        f"bq://{BIGQUERY_TABLE_NAME}"
+    )
+    if not is_replay_mode:
+        bigquery_client = bigquery.Client(
+            project=client._api_client.project,
+            location=client._api_client.location,
+            credentials=client._api_client._credentials,
+        )
+        rows = bigquery_client.list_rows(
+            dataset.metadata.input_config.bigquery_source.uri[5:]
+        )
+        pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)
diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py
index 9e5014bab9..8d01e27a5d 100644
--- a/vertexai/_genai/_datasets_utils.py
+++ b/vertexai/_genai/_datasets_utils.py
@@ -14,13 +14,19 @@
 #
 """Utility functions for multimodal dataset."""
 
-from typing import Any, TypeVar, Type
+import uuid
+from typing import Any, Type, TypeVar
+
+import google.auth.credentials
 from vertexai._genai.types import common
 from pydantic import BaseModel
 
 METADATA_SCHEMA_URI = (
     "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
 )
+_BQ_MULTIREGIONS = {"us", "eu"}
+_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
+_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"
 
 
 T = TypeVar("T", bound=BaseModel)
@@ -34,3 +40,99 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
         if snake_key in model_field_names:
             filtered_response[snake_key] = value
     return model_type(**filtered_response)
+
+
+def _try_import_bigframes() -> Any:
+    """Tries to import `bigframes`."""
+    try:
+        import bigframes
+        import bigframes.pandas
+        import bigframes.bigquery
+
+        return bigframes
+    except ImportError as exc:
+        raise ImportError(
+            "`bigframes` is not installed. Please call 'pip install bigframes'."
+        ) from exc
+
+
+def _try_import_bigquery() -> Any:
+    """Tries to import `bigquery`."""
+    try:
+        from google.cloud import bigquery
+
+        return bigquery
+    except ImportError as exc:
+        raise ImportError(
+            "`bigquery` is not installed. Please call 'pip install"
+            " google-cloud-bigquery'."
+        ) from exc
+
+
+def _bq_dataset_location_allowed(
+    vertex_location: str, bq_dataset_location: str
+) -> bool:
+    if bq_dataset_location == vertex_location:
+        return True
+    if bq_dataset_location in _BQ_MULTIREGIONS:
+        return vertex_location.startswith(bq_dataset_location)
+    return False
+
+
+def _normalize_and_validate_table_id(
+    *,
+    table_id: str,
+    project: str,
+    location: str,
+    credentials: google.auth.credentials.Credentials,
+) -> str:
+    bigquery = _try_import_bigquery()
+
+    table_ref = bigquery.TableReference.from_string(table_id, default_project=project)
+    if table_ref.project != project:
+        raise ValueError(
+            "The BigQuery table "
+            f"`{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}`"
+            " must be in the same project as the multimodal dataset."
+            f" The multimodal dataset is in `{project}`, but the BigQuery table"
+            f" is in `{table_ref.project}`."
+        )
+
+    dataset_ref = bigquery.DatasetReference(
+        project=table_ref.project, dataset_id=table_ref.dataset_id
+    )
+    client = bigquery.Client(project=project, credentials=credentials)
+    bq_dataset = client.get_dataset(dataset_ref=dataset_ref)
+    if not _bq_dataset_location_allowed(location, bq_dataset.location):
+        raise ValueError(
+            "The BigQuery dataset"
+            f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the"
+            " same location as the multimodal dataset. The multimodal dataset"
+            f" is in `{location}`, but the BigQuery dataset is in"
+            f" `{bq_dataset.location}`."
+        )
+    return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
+
+
+def _create_default_bigquery_dataset_if_not_exists(
+    *,
+    project: str,
+    location: str,
+    credentials: google.auth.credentials.Credentials,
+) -> str:
+    # Loading bigquery lazily to avoid auto-loading it when importing vertexai
+    from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+    bigquery_client = bigquery.Client(project=project, credentials=credentials)
+    location_str = location.lower().replace("-", "_")
+    dataset_id = bigquery.DatasetReference(
+        project, f"{_DEFAULT_BQ_DATASET_PREFIX}_{location_str}"
+    )
+    dataset = bigquery.Dataset(dataset_ref=dataset_id)
+    dataset.location = location
+    bigquery_client.create_dataset(dataset, exists_ok=True)
+    return f"{dataset_id.project}.{dataset_id.dataset_id}"
+
+
+def _generate_target_table_id(dataset_id: str) -> str:
+    return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{str(uuid.uuid4())}"
diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py
index 9a6b507e10..80585ead94 100644
--- a/vertexai/_genai/datasets.py
+++ b/vertexai/_genai/datasets.py
@@ -27,6 +27,7 @@
 from google.genai import types as genai_types
 from google.genai._common import get_value_by_path as getv
 from google.genai._common import set_value_by_path as setv
+import pandas as pd
 
 from . import _datasets_utils
 from . import types
@@ -836,6 +837,92 @@ def create_from_bigquery(
         )
         return _datasets_utils.create_from_response(types.MultimodalDataset, response)
 
+    def create_from_pandas(
+        self,
+        *,
+        dataframe: pd.DataFrame,
+        multimodal_dataset: types.MultimodalDatasetOrDict,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a pandas dataframe.
+
+        Args:
+            dataframe (pandas.DataFrame):
+                The pandas dataframe to be used for the created dataset.
+            multimodal_dataset:
+                Required. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
+                given table id, it will be overwritten. Note that the BigQuery
+                dataset must already exist and be in the same location as the
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
+            config:
+                Optional. A configuration for creating the multimodal dataset. If not
+                provided, the default configuration will be used.
+
+        Returns:
+            dataset (MultimodalDataset):
+                The created multimodal dataset.
+        """
+        if isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+        elif not multimodal_dataset:
+            multimodal_dataset = types.MultimodalDataset()
+
+        bigframes = _datasets_utils._try_import_bigframes()
+        bigquery = _datasets_utils._try_import_bigquery()
+        project = self._api_client.project
+        location = self._api_client.location
+        credentials = self._api_client._credentials
+
+        if target_table_id:
+            target_table_id = _datasets_utils._normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _datasets_utils._generate_target_table_id(dataset_id)
+
+        session_options = bigframes.BigQueryOptions(
+            credentials=credentials,
+            project=project,
+            location=location,
+        )
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_pandas(dataframe)
+            temp_table_id = temp_bigframes_df.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        copy_job = client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
+        )
+        copy_job.result()
+
+        return self.create_from_bigquery(
+            multimodal_dataset=multimodal_dataset.model_copy(
+                update={
+                    "metadata": types.SchemaTablesDatasetMetadata(
+                        input_config=types.SchemaTablesDatasetMetadataInputConfig(
+                            bigquery_source=types.SchemaTablesDatasetMetadataBigQuerySource(
+                                uri=f"bq://{target_table_id}"
+                            )
+                        )
+                    )
+                }
+            ),
+            config=config,
+        )
+
     def update_multimodal_dataset(
         self,
         *,
@@ -1847,6 +1934,92 @@ async def create_from_bigquery(
         )
         return _datasets_utils.create_from_response(types.MultimodalDataset, response)
 
+    async def create_from_pandas(
+        self,
+        *,
+        dataframe: pd.DataFrame,
+        multimodal_dataset: types.MultimodalDatasetOrDict,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a pandas dataframe.
+
+        Args:
+            dataframe (pandas.DataFrame):
+                The pandas dataframe to be used for the created dataset.
+            multimodal_dataset:
+                Required. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
+                given table id, it will be overwritten. Note that the BigQuery
+                dataset must already exist and be in the same location as the
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
+            config:
+                Optional. A configuration for creating the multimodal dataset. If not
+                provided, the default configuration will be used.
+
+        Returns:
+            dataset (MultimodalDataset):
+                The created multimodal dataset.
+        """
+        if isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+        elif not multimodal_dataset:
+            multimodal_dataset = types.MultimodalDataset()
+
+        bigframes = _datasets_utils._try_import_bigframes()
+        bigquery = _datasets_utils._try_import_bigquery()
+        project = self._api_client.project
+        location = self._api_client.location
+        credentials = self._api_client._credentials
+
+        if target_table_id:
+            target_table_id = _datasets_utils._normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _datasets_utils._generate_target_table_id(dataset_id)
+
+        session_options = bigframes.BigQueryOptions(
+            credentials=credentials,
+            project=project,
+            location=location,
+        )
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_pandas(dataframe)
+            temp_table_id = temp_bigframes_df.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        copy_job = client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
+        )
+        copy_job.result()
+
+        return await self.create_from_bigquery(
+            multimodal_dataset=multimodal_dataset.model_copy(
+                update={
+                    "metadata": types.SchemaTablesDatasetMetadata(
+                        input_config=types.SchemaTablesDatasetMetadataInputConfig(
+                            bigquery_source=types.SchemaTablesDatasetMetadataBigQuerySource(
+                                uri=f"bq://{target_table_id}"
+                            )
+                        )
+                    )
+                }
+            ),
+            config=config,
+        )
+
     async def update_multimodal_dataset(
         self,
         *,