@@ -81,6 +81,26 @@ def mock_generate_multimodal_dataset_display_name():
yield mock_generate


@pytest.fixture
def mock_try_import_storage():
with mock.patch.object(
_datasets_utils, "_try_import_storage"
) as mock_import_storage:
blob = mock.MagicMock()
blob.download_as_text.return_value = (
'{"contents": ["test1"]}\n{"contents": ["test2"]}'
)

bucket = mock.MagicMock()
bucket.blob.return_value = blob

client = mock.MagicMock()
client.bucket.return_value = bucket
mock_import_storage.return_value.Client.return_value = client

yield mock_import_storage


def test_create_dataset(client):
create_dataset_operation = client.datasets._create_multimodal_dataset(
name="projects/vertex-sdk-dev/locations/us-central1",
@@ -295,6 +315,43 @@ def test_create_dataset_from_bigframes_preserves_other_metadata(client, is_replay_mode):


@pytest.mark.usefixtures(
"mock_bigquery_client", "mock_import_bigframes", "mock_try_import_storage"
)
def test_create_from_gemini_request_jsonl(client, is_replay_mode):
if is_replay_mode:
with mock.patch.object(client.datasets, "create_from_bigframes") as mock_create:
mock_ds = mock.MagicMock()
mock_ds.display_name = "test-from-gemini-jsonl"
# Stub the metadata path asserted below (a bare mock attribute != "requests").
mock_ds.metadata.gemini_request_read_config.assembled_request_column_name = "requests"
mock_create.return_value = mock_ds

dataset = client.datasets.create_from_gemini_request_jsonl(
gcs_uri="gs://test-bucket/test-blob.jsonl",
target_table_id=BIGQUERY_TABLE_NAME,
multimodal_dataset={
"display_name": "test-from-gemini-jsonl",
},
)
assert dataset.display_name == "test-from-gemini-jsonl"
assert (
dataset.metadata.gemini_request_read_config.assembled_request_column_name
== "requests"
)
else:
dataset = client.datasets.create_from_gemini_request_jsonl(
gcs_uri="gs://test-bucket/test-blob.jsonl",
target_table_id=BIGQUERY_TABLE_NAME,
multimodal_dataset={
"display_name": "test-from-gemini-jsonl",
},
)
assert dataset.display_name == "test-from-gemini-jsonl"
assert (
dataset.metadata.gemini_request_read_config.assembled_request_column_name
== "requests"
)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
@@ -549,3 +606,43 @@ async def test_create_dataset_from_bigframes_preserves_other_metadata_async(
assert dataset.metadata.input_config.bigquery_source.uri == (
f"bq://{BIGQUERY_TABLE_NAME}"
)


@pytest.mark.asyncio
@pytest.mark.usefixtures(
"mock_bigquery_client", "mock_import_bigframes", "mock_try_import_storage"
)
async def test_create_from_gemini_request_jsonl_async(client, is_replay_mode):
if is_replay_mode:
with mock.patch.object(
client.aio.datasets, "create_from_bigframes"
) as mock_create:
mock_ds = mock.MagicMock()
mock_ds.display_name = "test-from-gemini-jsonl-async"
# Stub the metadata path asserted below (a bare mock attribute != "requests").
mock_ds.metadata.gemini_request_read_config.assembled_request_column_name = "requests"
mock_create.return_value = mock_ds

dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
gcs_uri="gs://test-bucket/test-blob-async.jsonl",
target_table_id=BIGQUERY_TABLE_NAME,
multimodal_dataset={
"display_name": "test-from-gemini-jsonl-async",
},
)
assert dataset.display_name == "test-from-gemini-jsonl-async"
assert (
dataset.metadata.gemini_request_read_config.assembled_request_column_name
== "requests"
)
else:
dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
gcs_uri="gs://test-bucket/test-blob-async.jsonl",
target_table_id=BIGQUERY_TABLE_NAME,
multimodal_dataset={
"display_name": "test-from-gemini-jsonl-async",
},
)
assert dataset.display_name == "test-from-gemini-jsonl-async"
assert (
dataset.metadata.gemini_request_read_config.assembled_request_column_name
== "requests"
)
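For context, a minimal usage sketch of the new API exercised by these tests. The project, bucket, and table names are hypothetical, and the `vertexai.Client` setup is assumed rather than shown in this PR; only `create_from_gemini_request_jsonl` and its parameters come from the diff.

```python
# Hypothetical usage sketch; surrounding setup is assumed, not from this PR.
import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

# Each line of the JSONL file is expected to hold one GenerateContentRequest.
dataset = client.datasets.create_from_gemini_request_jsonl(
    gcs_uri="gs://my-bucket/path/to/data.jsonl",
    target_table_id="my_dataset.my_table",  # optional; generated when omitted
    multimodal_dataset={"display_name": "my-multimodal-dataset"},
)

# The method pins the assembled request column to "requests".
print(dataset.metadata.gemini_request_read_config.assembled_request_column_name)
```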
vertexai/_genai/_datasets_utils.py (13 additions, 0 deletions)
@@ -115,6 +115,19 @@ def _try_import_bigquery() -> Any:
) from exc


def _try_import_storage() -> Any:
"""Tries to import `storage`."""
try:
from google.cloud import storage # type: ignore[attr-defined]

return storage
except ImportError as exc:
raise ImportError(
"`google-cloud-storage` is not installed. Please run 'pip install"
" google-cloud-storage'."
) from exc


def _bq_dataset_location_allowed(
vertex_location: str, bq_dataset_location: str
) -> bool:
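For context, a sketch of how this lazy-import helper is exercised: keeping the import inside the function means `google-cloud-storage` stays an optional dependency. The caller code below is illustrative, not part of the PR.

```python
# Illustrative only: resolve the optional dependency at call time.
from vertexai._genai import _datasets_utils

try:
    storage = _datasets_utils._try_import_storage()
    storage_client = storage.Client(project="my-project")  # assumed project id
except ImportError as exc:
    # Surfaces the actionable install hint when the package is missing.
    print(exc)
```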
vertexai/_genai/datasets.py (193 additions, 0 deletions)
@@ -16,6 +16,7 @@
# Code generated by the Google Gen AI SDK generator DO NOT EDIT.

import asyncio
import io
import json
import logging
import time
@@ -1112,6 +1113,102 @@ def create_from_bigframes(
multimodal_dataset=multimodal_dataset, config=config
)

def create_from_gemini_request_jsonl(
self,
*,
gcs_uri: str,
multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
target_table_id: Optional[str] = None,
config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
) -> types.MultimodalDataset:
"""Creates a multimodal dataset from a JSONL file stored on GCS.

The JSONL file should contain one Gemini `GenerateContentRequest`
instance per line. The data will be stored in a BigQuery table with a
single column called "requests", and `assembled_request_column_name` in
the dataset's gemini_request_read_config metadata will be set to
"requests".

Args:
gcs_uri (str):
The Google Cloud Storage URI of the JSONL file to import.
For example, 'gs://my-bucket/path/to/data.jsonl'
multimodal_dataset:
Optional. A representation of a multimodal dataset.
target_table_id (str):
Optional. The BigQuery table id where the imported requests will
be uploaded. The table id can be in the format of "dataset.table"
or "project.dataset.table". Note that the BigQuery
dataset must already exist and be in the same location as the
multimodal dataset. If not provided, a generated table id will
be created in the `vertex_datasets` dataset (e.g.
`project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
config:
Optional. A configuration for creating the multimodal dataset. If not
provided, the default configuration will be used.

Returns:
The created multimodal dataset.
"""
storage = _datasets_utils._try_import_storage()

if isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
elif not multimodal_dataset:
multimodal_dataset = types.MultimodalDataset()

gcs_uri_prefix = "gs://"
if gcs_uri.startswith(gcs_uri_prefix):
gcs_uri = gcs_uri[len(gcs_uri_prefix) :]
parts = gcs_uri.split("/", 1)
if len(parts) != 2:
raise ValueError(
"Invalid GCS URI format. Expected: gs://bucket-name/object-path"
)
bucket_name = parts[0]
blob_name = parts[1]

project = self._api_client.project
location = self._api_client.location
credentials = self._api_client._credentials

storage_client = storage.Client(project=project)
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
request_column_name = "requests"

jsonl_string = blob.download_as_text()
lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
json_string = json.dumps({request_column_name: lines})

multimodal_dataset = multimodal_dataset.model_copy(deep=True)
metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()

read_config = (
metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
)
read_config.assembled_request_column_name = request_column_name
metadata.gemini_request_read_config = read_config

multimodal_dataset.metadata = metadata

bigframes = _datasets_utils._try_import_bigframes()
session_options = bigframes.BigQueryOptions(
credentials=credentials,
project=project,
location=location,
)
with bigframes.connect(session_options) as session:
temp_bigframes_df = session.read_json(io.StringIO(json_string))
temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
temp_bigframes_df[request_column_name]
)
return self.create_from_bigframes(
dataframe=temp_bigframes_df,
multimodal_dataset=multimodal_dataset,
target_table_id=target_table_id,
config=config,
)

def update_multimodal_dataset(
self,
*,
@@ -2400,6 +2497,102 @@ async def create_from_bigframes(
multimodal_dataset=multimodal_dataset, config=config
)

async def create_from_gemini_request_jsonl(
self,
*,
gcs_uri: str,
multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
target_table_id: Optional[str] = None,
config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
) -> types.MultimodalDataset:
"""Creates a multimodal dataset from a JSONL file stored on GCS.

The JSONL file should contain one Gemini `GenerateContentRequest`
instance per line. The data will be stored in a BigQuery table with a
single column called "requests", and `assembled_request_column_name` in
the dataset's gemini_request_read_config metadata will be set to
"requests".

Args:
gcs_uri (str):
The Google Cloud Storage URI of the JSONL file to import.
For example, 'gs://my-bucket/path/to/data.jsonl'
multimodal_dataset:
Optional. A representation of a multimodal dataset.
target_table_id (str):
Optional. The BigQuery table id where the imported requests will
be uploaded. The table id can be in the format of "dataset.table"
or "project.dataset.table". Note that the BigQuery
dataset must already exist and be in the same location as the
multimodal dataset. If not provided, a generated table id will
be created in the `vertex_datasets` dataset (e.g.
`project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
config:
Optional. A configuration for creating the multimodal dataset. If not
provided, the default configuration will be used.

Returns:
The created multimodal dataset.
"""
storage = _datasets_utils._try_import_storage()

if isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
elif not multimodal_dataset:
multimodal_dataset = types.MultimodalDataset()

gcs_uri_prefix = "gs://"
if gcs_uri.startswith(gcs_uri_prefix):
gcs_uri = gcs_uri[len(gcs_uri_prefix) :]
parts = gcs_uri.split("/", 1)
if len(parts) != 2:
raise ValueError(
"Invalid GCS URI format. Expected: gs://bucket-name/object-path"
)
bucket_name = parts[0]
blob_name = parts[1]

project = self._api_client.project
location = self._api_client.location
credentials = self._api_client._credentials

storage_client = storage.Client(project=project)
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
request_column_name = "requests"

jsonl_string = await asyncio.to_thread(blob.download_as_text)
lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
json_string = json.dumps({request_column_name: lines})

multimodal_dataset = multimodal_dataset.model_copy(deep=True)
metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()

read_config = (
metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
)
read_config.assembled_request_column_name = request_column_name
metadata.gemini_request_read_config = read_config

multimodal_dataset.metadata = metadata

bigframes = _datasets_utils._try_import_bigframes()
session_options = bigframes.BigQueryOptions(
credentials=credentials,
project=project,
location=location,
)
with bigframes.connect(session_options) as session:
temp_bigframes_df = session.read_json(io.StringIO(json_string))
temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
temp_bigframes_df[request_column_name]
)
return await self.create_from_bigframes(
dataframe=temp_bigframes_df,
multimodal_dataset=multimodal_dataset,
target_table_id=target_table_id,
config=config,
)

async def update_multimodal_dataset(
self,
*,
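Finally, a self-contained sketch of the data-shaping step both new methods share, mirroring the lines in the diff above but without the GCS and BigQuery dependencies: each non-empty JSONL line becomes one element of a single "requests" array, serialized as one JSON document for bigframes' `read_json`.

```python
import json

# Input mirrors the test fixture above: two GenerateContentRequest lines.
jsonl_string = '{"contents": ["test1"]}\n{"contents": ["test2"]}'

# Drop blank lines; each request stays a raw JSON string at this point.
lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]

# One JSON document with a single "requests" array. bigframes reads this into
# a one-column DataFrame, and parse_json() later restores JSON typing.
json_string = json.dumps({"requests": lines})
print(json_string)
# {"requests": ["{\"contents\": [\"test1\"]}", "{\"contents\": [\"test2\"]}"]}
```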