diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py
index cde1598481..93e4a35f16 100644
--- a/sagemaker-train/src/sagemaker/train/tuner.py
+++ b/sagemaker-train/src/sagemaker/train/tuner.py
@@ -1430,6 +1430,7 @@ def _build_training_job_definition(self, inputs):
                         input_data_config.append(
                             Channel(
                                 channel_name=inp.channel_name,
+                                content_type=inp.content_type,
                                 data_source=DataSource(
                                     s3_data_source=S3DataSource(
                                         s3_data_type="S3Prefix",
diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py
index c0255eac47..e5bb042d1b 100644
--- a/sagemaker-train/tests/unit/train/test_tuner.py
+++ b/sagemaker-train/tests/unit/train/test_tuner.py
@@ -574,3 +574,36 @@ def test_build_training_job_definition_includes_internal_channels(self):
         assert "train" in channel_names, "User 'train' channel should be included"
         assert "validation" in channel_names, "User 'validation' channel should be included"
         assert len(channel_names) == 4, "Should have exactly 4 channels"
+
+    def test_build_training_job_definition_preserves_content_type(self):
+        """Test that InputData content_type is preserved when converting to Channel.
+
+        This test verifies the fix for GitHub issue #5632 where content_type was
+        dropped during InputData -> Channel conversion in HyperparameterTuner.
+        """
+        from sagemaker.core.training.configs import InputData
+
+        mock_trainer = _create_mock_model_trainer(with_internal_channels=False)
+
+        user_inputs = [
+            InputData(
+                channel_name="train",
+                data_source="s3://bucket/train",
+                content_type="text/csv",
+            ),
+        ]
+
+        tuner = HyperparameterTuner(
+            model_trainer=mock_trainer,
+            objective_metric_name="accuracy",
+            hyperparameter_ranges=_create_single_hp_range(),
+        )
+
+        definition = tuner._build_training_job_definition(user_inputs)
+
+        train_channel = next(
+            ch for ch in definition.input_data_config if ch.channel_name == "train"
+        )
+        assert train_channel.content_type == "text/csv", (
+            "content_type should be preserved when converting InputData to Channel"
+        )