Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@
"pyyaml>=5.3.1,<7",
]
datasets_extra_require = [
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.11'",
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.10'",
"pyarrow >= 10.0.1; python_version=='3.10'",
"pyarrow >= 10.0.1; python_version=='3.11'",
"pyarrow >= 14.0.0; python_version>='3.12'",
]
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/vertexai/genai/replays/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def replays_prefix():
return "test"


@pytest.fixture
def is_replay_mode(request):
    """Whether the test session runs against recorded interactions.

    True for the "replay" and "tap" modes of the ``--mode`` CLI option,
    False for live-API modes.
    """
    mode = request.config.getoption("--mode")
    return mode in ("replay", "tap")


@pytest.fixture
def mock_agent_engine_create_path_exists():
"""Mocks os.path.exists to return True."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,23 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import sys
from unittest import mock

from google.cloud import bigquery
from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import _datasets_utils
from vertexai._genai import types
import pandas as pd
import pytest


METADATA_SCHEMA_URI = (
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
)
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"


@pytest.fixture
def is_replay_mode(request):
return request.config.getoption("--mode") in ["replay", "tap"]


@pytest.fixture
def mock_bigquery_client(is_replay_mode):
if is_replay_mode:
Expand Down Expand Up @@ -161,6 +159,52 @@ def test_create_dataset_from_pandas(client, is_replay_mode):
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_bigframes(client, is_replay_mode):
    """Creates a multimodal dataset from a BigFrames DataFrame.

    In replay mode the BigFrames DataFrame is a MagicMock (no BigQuery
    round-trip); in live mode the created table's contents are read back
    and compared to the source frame.
    """
    # Imported lazily: module-level import would fail on Python < 3.10,
    # which the skipif above guards against.
    import bigframes.pandas

    dataframe = pd.DataFrame(
        {
            "col1": ["col1"],
            "col2": ["col2"],
        }
    )
    if is_replay_mode:
        # Stand-in for a BigFrames DataFrame: only to_gbq() is exercised
        # by create_from_bigframes, so only that call is stubbed.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"
    else:
        bf_dataframe = bigframes.pandas.DataFrame(dataframe)

    dataset = client.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={
            "display_name": "test-from-bigframes",
        },
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    # The dataset should point at the requested target table via a bq:// URI.
    assert dataset.metadata.input_config.bigquery_source.uri == (
        f"bq://{BIGQUERY_TABLE_NAME}"
    )
    if not is_replay_mode:
        # Live mode only: verify the table was actually materialized by
        # reading it back through a real BigQuery client.
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        rows = bigquery_client.list_rows(
            # [5:] strips the "bq://" scheme prefix to get the table ID.
            dataset.metadata.input_config.bigquery_source.uri[5:]
        )
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), dataframe, check_index_type=False
        )


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down Expand Up @@ -279,3 +323,50 @@ async def test_create_dataset_from_pandas_async(client, is_replay_mode):
dataset.metadata.input_config.bigquery_source.uri[5:]
)
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_dataset_from_bigframes_async(client, is_replay_mode):
    """Async variant of test_create_dataset_from_bigframes.

    Same flow via client.aio: mock BigFrames DataFrame in replay mode,
    real round-trip with read-back verification in live mode.
    """
    # Imported lazily: module-level import would fail on Python < 3.10,
    # which the skipif above guards against.
    import bigframes.pandas

    dataframe = pd.DataFrame(
        {
            "col1": ["col1"],
            "col2": ["col2"],
        }
    )
    if is_replay_mode:
        # Stand-in for a BigFrames DataFrame: only to_gbq() is exercised
        # by create_from_bigframes, so only that call is stubbed.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"
    else:
        bf_dataframe = bigframes.pandas.DataFrame(dataframe)

    dataset = await client.aio.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={
            "display_name": "test-from-bigframes",
        },
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    # The dataset should point at the requested target table via a bq:// URI.
    assert dataset.metadata.input_config.bigquery_source.uri == (
        f"bq://{BIGQUERY_TABLE_NAME}"
    )
    if not is_replay_mode:
        # Live mode only: verify the table was actually materialized by
        # reading it back through a real BigQuery client.
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        rows = bigquery_client.list_rows(
            # [5:] strips the "bq://" scheme prefix to get the table ID.
            dataset.metadata.input_config.bigquery_source.uri[5:]
        )
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), dataframe, check_index_type=False
        )
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,34 @@
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import _datasets_utils
from vertexai._genai import types

from unittest import mock
import pytest

BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
DATASET = "8810841321427173376"


@pytest.fixture
def mock_import_bigframes(is_replay_mode):
    """Mocks the bigframes import in replay mode.

    In replay mode, patches ``_datasets_utils._try_import_bigframes`` so that
    ``bigframes.pandas.read_gbq_table(...)`` returns an object whose ``.sql``
    references the expected BigQuery table, without importing the real
    bigframes package. In live mode, yields None and patches nothing.
    """
    if is_replay_mode:
        with mock.patch.object(
            _datasets_utils, "_try_import_bigframes"
        ) as mock_import_bigframes:
            mock_read_gbq_table_result = mock.MagicMock()
            # Fix: "SLECT" -> "SELECT" so the mocked .sql is valid-looking SQL.
            mock_read_gbq_table_result.sql = f"SELECT * FROM `{BIGQUERY_TABLE_NAME}`"

            bigframes = mock.MagicMock()
            bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result

            mock_import_bigframes.return_value = bigframes
            yield mock_import_bigframes
    else:
        yield None


def test_get_dataset(client):
dataset = client.datasets._get_multimodal_dataset(
name=DATASET,
Expand All @@ -41,6 +61,15 @@ def test_get_dataset_from_public_method(client):
assert dataset.display_name == "test-display-name"


@pytest.mark.usefixtures("mock_import_bigframes")
def test_to_bigframes(client):
    """Converts a multimodal dataset to a BigFrames DataFrame.

    The resulting frame's SQL should reference the dataset's backing
    BigQuery table.
    """
    multimodal_dataset = client.datasets.get_multimodal_dataset(
        name=DATASET,
    )
    bf_dataframe = client.datasets.to_bigframes(
        multimodal_dataset=multimodal_dataset
    )
    assert BIGQUERY_TABLE_NAME in bf_dataframe.sql


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand All @@ -67,3 +96,13 @@ async def test_get_dataset_from_public_method_async(client):
assert isinstance(dataset, types.MultimodalDataset)
assert dataset.name.endswith(DATASET)
assert dataset.display_name == "test-display-name"


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_import_bigframes")
async def test_to_bigframes_async(client):
    """Async variant of test_to_bigframes via client.aio.

    The resulting frame's SQL should reference the dataset's backing
    BigQuery table.
    """
    multimodal_dataset = await client.aio.datasets.get_multimodal_dataset(
        name=DATASET,
    )
    bf_dataframe = await client.aio.datasets.to_bigframes(
        multimodal_dataset=multimodal_dataset
    )
    assert BIGQUERY_TABLE_NAME in bf_dataframe.sql
Loading
Loading