diff --git a/sagemaker-mlops/tests/integ/test_pipeline_train_registry.py b/sagemaker-mlops/tests/integ/test_pipeline_train_registry.py index 90c1eb3aaf..68ac923ae6 100644 --- a/sagemaker-mlops/tests/integ/test_pipeline_train_registry.py +++ b/sagemaker-mlops/tests/integ/test_pipeline_train_registry.py @@ -51,7 +51,27 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r # Parameters processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1) + processing_image_uri = ParameterString( + name="ProcessingImageUri", + default_value=image_uris.retrieve( + framework="sklearn", + region=region, + version="1.2-1", + py_version="py3", + instance_type="ml.m5.xlarge", + ), + ) training_instance_count = ParameterInteger(name="TrainingInstanceCount", default_value=1) + training_image_uri = ParameterString( + name="TrainingImageUri", + default_value=image_uris.retrieve( + framework="xgboost", + region=region, + version="1.0-1", + py_version="py3", + instance_type="ml.m5.xlarge", + ), + ) instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") input_data = ParameterString( name="InputDataUrl", @@ -65,13 +85,7 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r # Processing step sklearn_processor = ScriptProcessor( - image_uri=image_uris.retrieve( - framework="sklearn", - region=region, - version="1.2-1", - py_version="py3", - instance_type="ml.m5.xlarge", - ), + image_uri=processing_image_uri, instance_type=instance_type, instance_count=processing_instance_count, base_job_name=f"{base_job_prefix}-sklearn", @@ -128,17 +142,8 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r cache_config=cache_config, ) - # Training step - image_uri = image_uris.retrieve( - framework="xgboost", - region=region, - version="1.0-1", - py_version="py3", - instance_type="ml.m5.xlarge", - ) - model_trainer = ModelTrainer( - training_image=image_uri, + training_image=training_image_uri, compute=Compute(instance_type=instance_type, instance_count=training_instance_count), base_job_name=f"{base_job_prefix}-xgboost", sagemaker_session=pipeline_session, @@ -195,7 +200,9 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r name=pipeline_name, parameters=[ processing_instance_count, + processing_image_uri, training_instance_count, + training_image_uri, instance_type, input_data, hyper_parameter_objective,