diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/emr_step.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/emr_step.py index 5b251af573..60f189a19a 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/emr_step.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/emr_step.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """The step definitions for workflow.""" + from __future__ import absolute_import from typing import Any, Dict, List, Union, Optional @@ -27,7 +28,12 @@ class EMRStepConfig: """Config for a Hadoop Jar step.""" def __init__( - self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None + self, + jar, + args: List[str] = None, + main_class: str = None, + properties: List[dict] = None, + output_args: dict[str, str] = None, ): """Create a definition for input data used by an EMR cluster(job flow) step. @@ -41,12 +47,24 @@ def __init__( jar(str): A path to a JAR file run during the step. main_class(str): The name of the main class in the specified Java file. properties(List(dict)): A list of key-value pairs that are set when the step runs. + output_args(dict[str, str]): + A dict of argument-value pairs (output_name: S3 URI) that extends the command line + args and can be accessible in other steps via EMRStep.emr_outputs[output_name]. + Argument names are prepended by '--' automatically. + Example: {"output-path": "s3://my-bucket/output/"} will result in the following + command line args: ["--output-path", "s3://my-bucket/output/"] """ self.jar = jar self.args = args self.main_class = main_class self.properties = properties + self.output_args_index = {} + if output_args: + for output_arg_name, output_arg_value in output_args.items(): + self.args.extend([f"--{output_arg_name}", output_arg_value]) + self.output_args_index[output_arg_name] = len(self.args) - 1 + def to_request(self) -> RequestType: """Convert EMRStepConfig object to request dict.""" config = {"HadoopJarStep": {"Jar": self.jar}} @@ -230,6 +248,11 @@ def __init__( self.cache_config = cache_config self._properties = root_property + self.emr_outputs = { + output_name: self.properties.Config.Args[step_config.output_args_index[output_name]] + for output_name in step_config.output_args_index + } + @property def arguments(self) -> RequestType: """The arguments dict that is used to call `AddJobFlowSteps`. @@ -250,4 +273,4 @@ def to_request(self) -> RequestType: request_dict = super().to_request() if self.cache_config: request_dict.update(self.cache_config.config) - return request_dict \ No newline at end of file + return request_dict diff --git a/sagemaker-mlops/tests/unit/workflow/test_emr_step.py b/sagemaker-mlops/tests/unit/workflow/test_emr_step.py index bb636fde80..6ddeaeef37 100644 --- a/sagemaker-mlops/tests/unit/workflow/test_emr_step.py +++ b/sagemaker-mlops/tests/unit/workflow/test_emr_step.py @@ -17,6 +17,7 @@ from sagemaker.mlops.workflow.emr_step import EMRStep, EMRStepConfig from sagemaker.mlops.workflow.steps import StepTypeEnum +from sagemaker.core.workflow.properties import Properties def test_emr_step_config_init(): @@ -39,7 +40,7 @@ def test_emr_step_with_cluster_id(): display_name="EMR Step", description="Test EMR step", cluster_id="j-123456", - step_config=config + step_config=config, ) assert step.name == "emr-step" assert step.step_type == StepTypeEnum.EMR @@ -48,9 +49,7 @@ def test_emr_step_with_cluster_id(): def test_emr_step_with_cluster_config(): config = EMRStepConfig(jar="s3://bucket/my.jar") cluster_config = { - "Instances": { - "InstanceGroups": [{"InstanceType": "m5.xlarge", "InstanceCount": 1}] - } + "Instances": {"InstanceGroups": [{"InstanceType": "m5.xlarge", "InstanceCount": 1}]} } step = EMRStep( name="emr-step", @@ -58,7 +57,7 @@ def test_emr_step_with_cluster_config(): description="Test EMR step", cluster_id=None, step_config=config, - cluster_config=cluster_config + cluster_config=cluster_config, ) assert step.name == "emr-step" @@ -71,7 +70,7 @@ def test_emr_step_without_cluster_id_or_config_raises_error(): display_name="EMR Step", description="Test EMR step", cluster_id=None, - step_config=config + step_config=config, ) @@ -84,5 +83,17 @@ def test_emr_step_with_both_cluster_id_and_config_raises_error(): description="Test EMR step", cluster_id="j-123456", step_config=config, - cluster_config={"Instances": {}} + cluster_config={"Instances": {}}, ) + +def test_emr_step_with_output_args(): + config = EMRStepConfig(jar="s3://bucket/my.jar", args=["arg1"], output_args={"output": "s3://bucket/my/output/path"}) + step = EMRStep( + name="emr-step", + display_name="EMR Step", + description="Test EMR step", + cluster_id="j-123456", + step_config=config, + ) + assert "output" in step.emr_outputs + assert isinstance(step.emr_outputs["output"], Properties)