Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions sagemaker-core/src/sagemaker/core/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
)
from sagemaker.core.local.local_session import LocalSession
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input, ProcessingS3Output
from sagemaker.core.resources import ProcessingJob
from sagemaker.core.workflow.pipeline_context import PipelineSession
from sagemaker.core.common_utils import (
Expand Down Expand Up @@ -483,7 +483,46 @@ def _normalize_outputs(self, outputs=None):
# Generate a name for the ProcessingOutput if it doesn't have one.
if output.output_name is None:
output.output_name = "output-{}".format(count)
if output.s3_output and is_pipeline_variable(output.s3_output.s3_uri):
if output.s3_output and output.s3_output.s3_uri is not None and is_pipeline_variable(output.s3_output.s3_uri):
normalized_outputs.append(output)
continue
# If s3_output or s3_uri is None, auto-generate an S3 URI
if output.s3_output is None or output.s3_output.s3_uri is None:
if _pipeline_config:
s3_uri = Join(
on="/",
values=[
"s3:/",
self.sagemaker_session.default_bucket(),
*(
[self.sagemaker_session.default_bucket_prefix]
if self.sagemaker_session.default_bucket_prefix
else []
),
_pipeline_config.pipeline_name,
ExecutionVariables.PIPELINE_EXECUTION_ID,
_pipeline_config.step_name,
"output",
output.output_name,
],
)
else:
s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix,
self._current_job_name,
"output",
output.output_name,
)
if output.s3_output is None:
output.s3_output = ProcessingS3Output(
s3_uri=s3_uri,
local_path="/opt/ml/processing/output",
s3_upload_mode="EndOfJob",
)
else:
output.s3_output.s3_uri = s3_uri
normalized_outputs.append(output)
continue
# If the output's s3_uri is not an s3_uri, create one.
Expand Down Expand Up @@ -1421,11 +1460,13 @@ def _processing_output_to_request_dict(processing_output):
}

if processing_output.s3_output:
request_dict["S3Output"] = {
"S3Uri": processing_output.s3_output.s3_uri,
s3_output_dict = {
"LocalPath": processing_output.s3_output.local_path,
"S3UploadMode": processing_output.s3_output.s3_upload_mode,
}
if processing_output.s3_output.s3_uri is not None:
s3_output_dict["S3Uri"] = processing_output.s3_output.s3_uri
request_dict["S3Output"] = s3_output_dict

return request_dict

Expand Down
143 changes: 143 additions & 0 deletions sagemaker-core/tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,149 @@ def test_normalize_outputs_invalid_type(self, mock_session):



class TestProcessingS3OutputOptionalS3Uri:
    """Cover ProcessingS3Output when its s3_uri is omitted or None (issue #5559)."""

    @staticmethod
    def _build_processor(session):
        """Create a Processor on *session* with its current job name set to "test-job"."""
        proc = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=session,
        )
        proc._current_job_name = "test-job"
        return proc

    @staticmethod
    def _s3_output_without_uri():
        """Create a ProcessingS3Output whose s3_uri is passed explicitly as None."""
        return ProcessingS3Output(
            s3_uri=None,
            local_path="/opt/ml/processing/output",
            s3_upload_mode="EndOfJob",
        )

    def test_processing_s3_output_with_none_s3_uri_is_valid(self):
        """An explicit s3_uri=None is accepted and the other fields are kept as given."""
        shape = self._s3_output_without_uri()

        assert shape.s3_uri is None
        assert shape.local_path == "/opt/ml/processing/output"
        assert shape.s3_upload_mode == "EndOfJob"

    def test_processing_s3_output_default_s3_uri_is_none(self):
        """Leaving s3_uri out of the constructor call yields s3_uri None."""
        # s3_uri is deliberately not passed here: the default is what is under test.
        shape = ProcessingS3Output(
            local_path="/opt/ml/processing/output",
            s3_upload_mode="EndOfJob",
        )

        assert shape.s3_uri is None

    def test_normalize_outputs_with_none_s3_uri_generates_s3_path(self, mock_session):
        """Outside a pipeline, a None s3_uri is replaced by a generated S3 path."""
        proc = self._build_processor(mock_session)
        job_outputs = [
            ProcessingOutput(output_name="my-output", s3_output=self._s3_output_without_uri())
        ]

        with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
            result = proc._normalize_outputs(job_outputs)

        assert len(result) == 1
        generated_uri = result[0].s3_output.s3_uri
        assert generated_uri is not None
        # The generated path must mention the scheme, the job name and the output name.
        rendered = str(generated_uri)
        for fragment in ("s3://", "test-job", "my-output"):
            assert fragment in rendered

    def test_normalize_outputs_with_none_s3_uri_and_pipeline_config(self, mock_session):
        """With a pipeline config patched in, a None s3_uri becomes a Join expression."""
        from sagemaker.core.workflow.functions import Join

        proc = self._build_processor(mock_session)
        job_outputs = [
            ProcessingOutput(output_name="my-output", s3_output=self._s3_output_without_uri())
        ]

        with patch("sagemaker.core.workflow.utilities._pipeline_config") as mock_config:
            mock_config.pipeline_name = "test-pipeline"
            mock_config.step_name = "test-step"
            result = proc._normalize_outputs(job_outputs)

        assert len(result) == 1
        generated_uri = result[0].s3_output.s3_uri
        assert generated_uri is not None
        assert isinstance(generated_uri, Join)

    def test_normalize_outputs_with_none_s3_output_generates_s3_path(self, mock_session):
        """A ProcessingOutput with s3_output=None gets an s3_output with a generated path."""
        proc = self._build_processor(mock_session)
        job_outputs = [ProcessingOutput(output_name="my-output", s3_output=None)]

        with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
            result = proc._normalize_outputs(job_outputs)

        assert len(result) == 1
        created = result[0].s3_output
        assert created is not None
        assert created.s3_uri is not None
        rendered = str(created.s3_uri)
        for fragment in ("s3://", "test-job", "my-output"):
            assert fragment in rendered

    def test_processing_output_to_request_dict_with_none_s3_uri(self):
        """The request dict leaves out "S3Uri" when s3_uri is None but keeps the other keys."""
        processing_output = ProcessingOutput(
            output_name="results",
            s3_output=self._s3_output_without_uri(),
        )

        result = _processing_output_to_request_dict(processing_output)

        assert result["OutputName"] == "results"
        assert "S3Output" in result
        s3_section = result["S3Output"]
        assert "S3Uri" not in s3_section
        assert s3_section["LocalPath"] == "/opt/ml/processing/output"
        assert s3_section["S3UploadMode"] == "EndOfJob"

    def test_processing_s3_output_with_explicit_s3_uri_unchanged(self, mock_session):
        """Regression check: an explicitly supplied s3_uri passes through normalization as-is."""
        proc = self._build_processor(mock_session)
        explicit = ProcessingS3Output(
            s3_uri="s3://my-bucket/my-output",
            local_path="/opt/ml/processing/output",
            s3_upload_mode="EndOfJob",
        )
        job_outputs = [ProcessingOutput(output_name="my-output", s3_output=explicit)]

        with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
            result = proc._normalize_outputs(job_outputs)

        assert len(result) == 1
        assert result[0].s3_output.s3_uri == "s3://my-bucket/my-output"


class TestBugConditionFileUriReplacedInLocalMode:
"""Bug condition exploration test: file:// URIs should be preserved in local mode.

Expand Down
Loading