From 1ea0cbbd3aec613c5eca28bb7fec0b368c359ed9 Mon Sep 17 00:00:00 2001 From: Raul Diaz Garcia Date: Fri, 10 Apr 2026 16:02:06 +0200 Subject: [PATCH] feature: allow custom runproc.sh in FrameworkProcessor --- .../src/sagemaker/core/processing.py | 58 +++++++++++++++-- sagemaker-core/tests/unit/test_processing.py | 65 +++++++++++++++++++ 2 files changed, 118 insertions(+), 5 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index b507ae1a93..6ecdccd101 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -1189,6 +1189,7 @@ def run( job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, kms_key: Optional[str] = None, + entry_point: Optional[str] = None ): """Runs a processing job. @@ -1216,6 +1217,9 @@ def run( experiment_config (dict[str, str]): Experiment management configuration. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). + entry_point (str): Path (absolute or relative) to a custom entrypoint script file + (e.g., runproc.sh). The python script call is appended automatically. + Returns: None or pipeline step arguments in case the Processor instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` @@ -1227,6 +1231,7 @@ def run( job_name, inputs, kms_key, + entry_point ) # Submit a processing job. @@ -1250,6 +1255,7 @@ def _pack_and_upload_code( job_name, inputs, kms_key=None, + entry_point=None ): """Pack local code bundle and upload to Amazon S3.""" if code.startswith("s3://"): @@ -1274,7 +1280,7 @@ def _pack_and_upload_code( script = os.path.basename(code) evaluated_kms_key = kms_key if kms_key else self.output_kms_key s3_runproc_sh = self._create_and_upload_runproc( - script, evaluated_kms_key, entrypoint_s3_uri + script, evaluated_kms_key, entrypoint_s3_uri, entry_point, source_dir ) return s3_runproc_sh, inputs, job_name @@ -1312,12 +1318,12 @@ def _set_entrypoint(self, command, user_script_name): ) self.entrypoint = self.framework_entrypoint_command + [user_script_location] - def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): + def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri, entry_point=None, source_dir=None): """Create runproc shell script and upload to S3 bucket.""" from sagemaker.core.workflow.utilities import _pipeline_config, hash_object if _pipeline_config and _pipeline_config.pipeline_name: - runproc_file_str = self._generate_framework_script(user_script) + runproc_file_str = self._generate_framework_script(user_script, entry_point, source_dir) runproc_file_hash = hash_object(runproc_file_str) s3_uri = s3.s3_path_join( "s3://", @@ -1336,7 +1342,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): ) else: s3_runproc_sh = s3.S3Uploader.upload_string_as_file_body( - self._generate_framework_script(user_script), + self._generate_framework_script(user_script, entry_point, source_dir), desired_s3_uri=entrypoint_s3_uri, kms_key=kms_key, sagemaker_session=self.sagemaker_session, @@ -1344,8 +1350,11 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): return s3_runproc_sh - def _generate_framework_script(self, user_script: str) -> str: + def _generate_framework_script(self, user_script: str, entry_point: str = None, source_dir: str = None) -> str: """Generate the framework entrypoint file (as text) for a processing job.""" + if entry_point: + return self._generate_custom_framework_script(user_script, entry_point, source_dir) + return dedent( """\ #!/bin/bash @@ -1383,6 +1392,45 @@ def _generate_framework_script(self, user_script: str) -> str: entry_point=user_script, ) + def _generate_custom_framework_script( + self, user_script: str, entry_point: str, source_dir: str = None + ) -> str: + """ + Generate a custom framework script with a user-provided entrypoint embedded. + + Reads the entry_point file and embeds its content in the script, + then appends the command to execute the user script. + + Args: + user_script (str): Relative path to the user script in the source bundle + entry_point (str): Path to the custom entrypoint script file + source_dir (str): Path to the source directory. If provided and entry_point + is relative, it will be combined with source_dir. + + Returns: + str: The generated script content + """ + # Resolve the full path to the entry_point file + if source_dir and not os.path.isabs(entry_point): + full_entry_point_path = os.path.join(source_dir, entry_point) + else: + full_entry_point_path = entry_point + + # Read the entry_point file content + with open(full_entry_point_path, "r", encoding="utf-8") as f: + entry_point_content = f.read() + + # Generate the script with embedded entry_point content + return dedent("""\ + {entry_point_content} + + {entry_point_command} {entry_point} "$@" + """).format( + entry_point_content=entry_point_content, + entry_point_command=" ".join(self.command), + entry_point=user_script, + ) + class FeatureStoreOutput(ApiObject): """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" diff --git a/sagemaker-core/tests/unit/test_processing.py b/sagemaker-core/tests/unit/test_processing.py index dbe8d5f9ef..0c6c2bb18f 100644 --- a/sagemaker-core/tests/unit/test_processing.py +++ b/sagemaker-core/tests/unit/test_processing.py @@ -863,6 +863,71 @@ def test_create_and_upload_runproc_without_pipeline(self, mock_session): ) assert result == "s3://bucket/runproc.sh" + def test_generate_framework_script_with_custom_entry_point(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + custom_script_content = "#!/bin/bash\necho 'THIS IS THE CUSTOM runproc.sh'\nset -e\n" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: + f.write(custom_script_content) + entry_point_path = f.name + + try: + script = processor._generate_framework_script( + "train.py", entry_point=entry_point_path + ) + assert custom_script_content in script + assert "python3 train.py" in script + assert "tar -xzf sourcedir.tar.gz" not in script + finally: + os.unlink(entry_point_path) + + def test_generate_framework_script_with_custom_entry_point_and_source_dir(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + custom_script_content = "#!/bin/bash\necho 'custom from source_dir'\n" + script_path = os.path.join(tmpdir, "custom_runproc.sh") + with open(script_path, "w") as f: + f.write(custom_script_content) + + script = processor._generate_framework_script( + "train.py", + entry_point="custom_runproc.sh", + source_dir=tmpdir, + ) + assert custom_script_content in script + assert "python3 train.py" in script + + def test_generate_framework_script_with_default_entry_point(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + script = processor._generate_framework_script("train.py") + assert "#!/bin/bash" in script + assert "tar -xzf sourcedir.tar.gz" in script + assert "python3 train.py" in script + class TestHelperFunctions: def test_processing_input_to_request_dict(self):