Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions google/cloud/aiplatform/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def stage_local_data_in_gcs(

Raises:
RuntimeError: When source_path does not exist.
RuntimeError: When staging_gcs_dir is not provided and staging_bucket
is not configured via aiplatform.init().
GoogleCloudError: When the upload process fails.
"""
data_path_obj = pathlib.Path(data_path)
Expand All @@ -194,26 +196,12 @@ def stage_local_data_in_gcs(

staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
if not staging_gcs_dir:
project = project or initializer.global_config.project
location = location or initializer.global_config.location
credentials = credentials or initializer.global_config.credentials
# Creating the bucket if it does not exist.
# Currently we only do this when staging_gcs_dir is not specified.
# The buckets that we create are regional.
# This prevents errors when some service required regional bucket.
# E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...` is in location `us`. It must be in the same regional location as the service location `us-central1`."
# We are making the bucket name region-specific since the bucket is regional.
staging_bucket_name = project + "-vertex-staging-" + location
client = storage.Client(project=project, credentials=credentials)
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
if not staging_bucket.exists():
_logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
staging_bucket = client.create_bucket(
bucket_or_name=staging_bucket,
project=project,
location=location,
)
staging_gcs_dir = "gs://" + staging_bucket_name
raise RuntimeError(
"staging_gcs_dir should be passed to stage_local_data_in_gcs or "
"should be set using aiplatform.init(staging_bucket='gs://my-bucket'). "
"This is required to prevent the use of predictable bucket names "
"which could be exploited via bucket squatting attacks."
)

timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")
staging_gcs_subdir = (
Expand All @@ -239,20 +227,32 @@ def generate_gcs_directory_for_pipeline_artifacts(
project: Optional[str] = None,
location: Optional[str] = None,
):
"""Gets or creates the GCS directory for Vertex Pipelines artifacts.
"""Gets the GCS directory for Vertex Pipelines artifacts.

Requires staging_bucket to be configured via aiplatform.init().
The project and location parameters are deprecated and ignored.

Args:
project: Optional. Google Cloud Project that contains the staging bucket.
location: Optional. Google Cloud location to use for the staging bucket.
project: Deprecated. No longer used.
location: Deprecated. No longer used.

Returns:
Google Cloud Storage URI of the staged data.
Google Cloud Storage URI for pipeline artifacts.

Raises:
RuntimeError: When staging_bucket is not configured via aiplatform.init().
"""
project = project or initializer.global_config.project
location = location or initializer.global_config.location
pipeline_root = initializer.global_config.staging_bucket
if not pipeline_root:
raise RuntimeError(
"pipeline_root should be passed to PipelineJob or "
"should be set using aiplatform.init(staging_bucket='gs://my-bucket'). "
"This is required to prevent the use of predictable bucket names "
"which could be exploited via bucket squatting attacks."
)
validate_gcs_path(pipeline_root)

pipelines_bucket_name = project + "-vertex-pipelines-" + location
output_artifacts_gcs_dir = "gs://" + pipelines_bucket_name + "/output_artifacts/"
output_artifacts_gcs_dir = pipeline_root.rstrip("/") + "/output_artifacts/"
return output_artifacts_gcs_dir


Expand Down
35 changes: 25 additions & 10 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,27 @@ def test_stage_local_data_in_gcs(
== f"{staging_gcs_dir}/vertex_ai_auto_staging/{timestamp}/test.json"
)

def test_generate_gcs_directory_for_pipeline_artifacts(self):
    """Check the artifacts directory returned for an explicit project and location."""
    # NOTE(review): this looks like the pre-change side of the diff -- it
    # expects a bucket name derived from project and location rather than a
    # configured staging bucket. Confirm against the hunk before relying on it.
    expected_dir = "gs://project-vertex-pipelines-us-central1/output_artifacts/"
    actual_dir = gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
        "project", "us-central1"
    )
    assert actual_dir == expected_dir
def test_generate_gcs_directory_for_pipeline_artifacts_with_staging_bucket(self):
    """Check that the artifacts directory is rooted at the configured staging bucket.

    The global config's ``staging_bucket`` is patched for the duration of the
    call; the positional project/location arguments are still passed through.
    """
    global_config = gcs_utils.initializer.global_config
    staging_bucket_patch = patch.object(
        global_config, "staging_bucket", "gs://my-staging-bucket"
    )
    with staging_bucket_patch:
        artifacts_dir = gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
            "project", "us-central1"
        )
    # The result is a plain local value, so it can be checked after the patch
    # has been undone.
    assert artifacts_dir == "gs://my-staging-bucket/output_artifacts/"

def test_generate_gcs_directory_for_pipeline_artifacts_raises_without_staging_bucket(
    self,
):
    """Check that a RuntimeError is raised when no staging bucket is configured.

    The global config's ``staging_bucket`` is patched to ``None`` and the call
    must fail with a message matching "pipeline_root should be passed".
    """
    global_config = gcs_utils.initializer.global_config
    no_staging_bucket = patch.object(global_config, "staging_bucket", None)
    expect_runtime_error = pytest.raises(
        RuntimeError, match="pipeline_root should be passed"
    )
    # Enter the patch first, then the raises-context, mirroring a nested
    # pair of ``with`` blocks.
    with no_staging_bucket, expect_runtime_error:
        gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
            "project", "us-central1"
        )

@patch.object(storage.Bucket, "exists", return_value=False)
@patch.object(storage, "Client")
Expand All @@ -593,15 +609,14 @@ def test_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
):
output = (
gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
project="test-project", location="us-central1"
output_artifacts_gcs_dir="gs://my-bucket/output_artifacts/",
project="test-project",
location="us-central1",
)
)
assert mock_storage_client.called
assert mock_bucket_not_exist.called
assert mock_get_project_number.called
assert (
output == "gs://test-project-vertex-pipelines-us-central1/output_artifacts/"
)
assert output == "gs://my-bucket/output_artifacts/"

def test_download_from_gcs_dir(
self, mock_storage_client_list_blobs, mock_storage_blob_download_to_filename
Expand Down