Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions sagemaker-core/src/sagemaker/core/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,7 @@ def run(
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
kms_key: Optional[str] = None,
entry_point: Optional[str] = None
):
"""Runs a processing job.

Expand Down Expand Up @@ -1216,6 +1217,9 @@ def run(
experiment_config (dict[str, str]): Experiment management configuration.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
entry_point (str): Path (absolute or relative) to a custom entrypoint script file
(e.g., runproc.sh). The python script call is appended automatically.

Returns:
None or pipeline step arguments in case the Processor instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
Expand All @@ -1227,6 +1231,7 @@ def run(
job_name,
inputs,
kms_key,
entry_point
)

# Submit a processing job.
Expand All @@ -1250,6 +1255,7 @@ def _pack_and_upload_code(
job_name,
inputs,
kms_key=None,
entry_point=None
):
"""Pack local code bundle and upload to Amazon S3."""
if code.startswith("s3://"):
Expand All @@ -1274,7 +1280,7 @@ def _pack_and_upload_code(
script = os.path.basename(code)
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
s3_runproc_sh = self._create_and_upload_runproc(
script, evaluated_kms_key, entrypoint_s3_uri
script, evaluated_kms_key, entrypoint_s3_uri, entry_point, source_dir
)

return s3_runproc_sh, inputs, job_name
Expand Down Expand Up @@ -1312,12 +1318,12 @@ def _set_entrypoint(self, command, user_script_name):
)
self.entrypoint = self.framework_entrypoint_command + [user_script_location]

def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri, entry_point=None, source_dir=None):
"""Create runproc shell script and upload to S3 bucket."""
from sagemaker.core.workflow.utilities import _pipeline_config, hash_object

if _pipeline_config and _pipeline_config.pipeline_name:
runproc_file_str = self._generate_framework_script(user_script)
runproc_file_str = self._generate_framework_script(user_script, entry_point, source_dir)
runproc_file_hash = hash_object(runproc_file_str)
s3_uri = s3.s3_path_join(
"s3://",
Expand All @@ -1336,16 +1342,19 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
)
else:
s3_runproc_sh = s3.S3Uploader.upload_string_as_file_body(
self._generate_framework_script(user_script),
self._generate_framework_script(user_script, entry_point, source_dir),
desired_s3_uri=entrypoint_s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)

return s3_runproc_sh

def _generate_framework_script(self, user_script: str) -> str:
def _generate_framework_script(self, user_script: str, entry_point: str = None, source_dir: str = None) -> str:
"""Generate the framework entrypoint file (as text) for a processing job."""
if entry_point:
return self._generate_custom_framework_script(user_script, entry_point, source_dir)

return dedent(
"""\
#!/bin/bash
Expand Down Expand Up @@ -1383,6 +1392,45 @@ def _generate_framework_script(self, user_script: str) -> str:
entry_point=user_script,
)

def _generate_custom_framework_script(
self, user_script: str, entry_point: str, source_dir: str = None
) -> str:
"""
Generate a custom framework script with a user-provided entrypoint embedded.

Reads the entry_point file and embeds its content in the script,
then appends the command to execute the user script.

Args:
user_script (str): Relative path to the user script in the source bundle
entry_point (str): Path to the custom entrypoint script file
source_dir (str): Path to the source directory. If provided and entry_point
is relative, it will be combined with source_dir.

Returns:
str: The generated script content
"""
# Resolve the full path to the entry_point file
if source_dir and not os.path.isabs(entry_point):
full_entry_point_path = os.path.join(source_dir, entry_point)
else:
full_entry_point_path = entry_point

# Read the entry_point file content
with open(full_entry_point_path, "r", encoding="utf-8") as f:
entry_point_content = f.read()

# Generate the script with embedded entry_point content
return dedent("""\
{entry_point_content}

{entry_point_command} {entry_point} "$@"
""").format(
entry_point_content=entry_point_content,
entry_point_command=" ".join(self.command),
entry_point=user_script,
)


class FeatureStoreOutput(ApiObject):
"""Configuration for processing job outputs in Amazon SageMaker Feature Store."""
Expand Down
65 changes: 65 additions & 0 deletions sagemaker-core/tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,71 @@ def test_create_and_upload_runproc_without_pipeline(self, mock_session):
)
assert result == "s3://bucket/runproc.sh"

def test_generate_framework_script_with_custom_entry_point(self, mock_session):
    """The custom entry_point content replaces the default bootstrap template."""
    processor = FrameworkProcessor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        image_uri="test-image:latest",
        command=["python3"],
        instance_count=1,
        instance_type="ml.m5.xlarge",
        sagemaker_session=mock_session,
    )

    expected_content = "#!/bin/bash\necho 'THIS IS THE CUSTOM runproc.sh'\nset -e\n"

    with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as tmp:
        tmp.write(expected_content)
        custom_path = tmp.name

    try:
        generated = processor._generate_framework_script(
            "train.py", entry_point=custom_path
        )
        # The user-provided shell body must appear verbatim...
        assert expected_content in generated
        # ...together with the appended python launch line...
        assert "python3 train.py" in generated
        # ...while the default tarball-extraction template is dropped.
        assert "tar -xzf sourcedir.tar.gz" not in generated
    finally:
        os.unlink(custom_path)

def test_generate_framework_script_with_custom_entry_point_and_source_dir(self, mock_session):
    """A relative entry_point is resolved against source_dir before embedding."""
    processor = FrameworkProcessor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        image_uri="test-image:latest",
        command=["python3"],
        instance_count=1,
        instance_type="ml.m5.xlarge",
        sagemaker_session=mock_session,
    )

    with tempfile.TemporaryDirectory() as workdir:
        expected_content = "#!/bin/bash\necho 'custom from source_dir'\n"
        with open(os.path.join(workdir, "custom_runproc.sh"), "w") as fh:
            fh.write(expected_content)

        generated = processor._generate_framework_script(
            "train.py",
            entry_point="custom_runproc.sh",
            source_dir=workdir,
        )
        # The relative path was found inside workdir and embedded verbatim.
        assert expected_content in generated
        assert "python3 train.py" in generated

def test_generate_framework_script_with_default_entry_point(self, mock_session):
    """Without entry_point, the stock bootstrap template is generated."""
    processor = FrameworkProcessor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        image_uri="test-image:latest",
        command=["python3"],
        instance_count=1,
        instance_type="ml.m5.xlarge",
        sagemaker_session=mock_session,
    )

    generated = processor._generate_framework_script("train.py")
    # Default template: shebang, sourcedir extraction, and the launch line.
    for fragment in ("#!/bin/bash", "tar -xzf sourcedir.tar.gz", "python3 train.py"):
        assert fragment in generated


class TestHelperFunctions:
def test_processing_input_to_request_dict(self):
Expand Down
Loading