From b695743c3804f913ab4e8b2746c9bd1e629420dc Mon Sep 17 00:00:00 2001
From: Erick Benitez-Ramos
Date: Thu, 30 Jan 2025 12:19:04 -0800
Subject: [PATCH 1/3] feat: Add support for deepseek recipes

---
 .../modules/train/sm_recipes/utils.py      | 37 +++++++++++++------
 src/sagemaker/pytorch/estimator.py         |  7 ++++
 .../modules/train/sm_recipes/test_utils.py | 35 ++++++++++++++++++
 3 files changed, 67 insertions(+), 12 deletions(-)

diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py
index ff38bcbde8..d93c5c8595 100644
--- a/src/sagemaker/modules/train/sm_recipes/utils.py
+++ b/src/sagemaker/modules/train/sm_recipes/utils.py
@@ -125,6 +125,27 @@ def _register_custom_resolvers():
     OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
 
 
+def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
+    """Get the model base name and script for the training recipe."""
+
+    model_type_to_script = {
+        "llama_v3": ("llama", "llama_pretrain.py"),
+        "mistral": ("mistral", "mistral_pretrain.py"),
+        "mixtral": ("mixtral", "mixtral_pretrain.py"),
+        "deepseek": ("deepseek", "deepseek_pretrain.py"),
+    }
+
+    for key in model_type_to_script.keys():
+        if model_type.startswith(key):
+            model_type = key
+            break
+
+    if model_type not in model_type_to_script:
+        raise ValueError(f"Model type {model_type} not supported")
+
+    return model_type_to_script[model_type][0], model_type_to_script[model_type][1]
+
+
 def _configure_gpu_args(
     training_recipes_cfg: Dict[str, Any],
     region_name: str,
@@ -140,24 +161,16 @@ def _configure_gpu_args(
     )
     _run_clone_command_silent(adapter_repo, recipe_train_dir.name)
 
-    model_type_to_entry = {
-        "llama_v3": ("llama", "llama_pretrain.py"),
-        "mistral": ("mistral", "mistral_pretrain.py"),
-        "mixtral": ("mixtral", "mixtral_pretrain.py"),
-    }
-
     if "model" not in recipe:
         raise ValueError("Supplied recipe does not contain required field model.")
     if "model_type" not in recipe["model"]:
         raise ValueError("Supplied recipe does not contain required field model_type.")
     model_type = recipe["model"]["model_type"]
 
-    if model_type not in model_type_to_entry:
-        raise ValueError(f"Model type {model_type} not supported")
-    source_code.source_dir = os.path.join(
-        recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
-    )
-    source_code.entry_script = model_type_to_entry[model_type][1]
+    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type)
+
+    source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name)
+    source_code.entry_script = script
 
     gpu_image_cfg = training_recipes_cfg.get("gpu_image")
     if isinstance(gpu_image_cfg, str):
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 46c57581d1..dde465aafc 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -95,6 +95,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
         "llama_v3": ("llama", "llama_pretrain.py"),
         "mistral": ("mistral", "mistral_pretrain.py"),
         "mixtral": ("mixtral", "mixtral_pretrain.py"),
+        "deepseek": ("deepseek", "deepseek_pretrain.py"),
     }
 
     if "model" not in recipe:
@@ -102,6 +103,12 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
         raise ValueError("Supplied recipe does not contain required field model.")
     if "model_type" not in recipe["model"]:
         raise ValueError("Supplied recipe does not contain required field model_type.")
     model_type = recipe["model"]["model_type"]
+
+    for key in model_type_to_script.keys():
+        if model_type.startswith(key):
+            model_type = key
+            break
+
     if model_type not in model_type_to_script:
         raise ValueError(f"Model type {model_type} not supported")
 
diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
index 66eafab4f0..f5f7ceb083 100644
--- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
+++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
@@ -26,6 +26,7 @@
     _load_recipes_cfg,
     _configure_gpu_args,
     _configure_trainium_args,
+    _get_trainining_recipe_gpu_model_name_and_script,
 )
 from sagemaker.modules.utils import _run_clone_command_silent
 from sagemaker.modules.configs import Compute
@@ -178,3 +179,37 @@ def test_get_args_from_recipe_compute(
         assert mock_gpu_args.call_count == 0
         assert mock_trainium_args.call_count == 0
         assert args is None
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "model_type": "llama_v3",
+            "script": "llama_pretrain.py",
+            "model_base_name": "llama",
+        },
+        {
+            "model_type": "mistral",
+            "script": "mistral_pretrain.py",
+            "model_base_name": "mistral",
+        },
+        {
+            "model_type": "deepseek_llamav3",
+            "script": "deepseek_pretrain.py",
+            "model_base_name": "deepseek",
+        },
+        {
+            "model_type": "deepseek_qwenv2",
+            "script": "deepseek_pretrain.py",
+            "model_base_name": "deepseek",
+        },
+    ],
+)
+def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
+    model_type = test_case["model_type"]
+    expected_script = test_case["script"]
+    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
+        model_type
+    )
+    assert model_base_name == test_case["model_base_name"]
+    assert script == expected_script

From 0bd510477453aa74cff175c1c5067fcc71a2d602 Mon Sep 17 00:00:00 2001
From: Erick Benitez-Ramos
Date: Thu, 30 Jan 2025 12:38:05 -0800
Subject: [PATCH 2/3] pylint

---
 src/sagemaker/modules/train/sm_recipes/utils.py | 2 +-
 src/sagemaker/pytorch/estimator.py              | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py
index d93c5c8595..549645cbe2 100644
--- a/src/sagemaker/modules/train/sm_recipes/utils.py
+++ b/src/sagemaker/modules/train/sm_recipes/utils.py
@@ -135,7 +135,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
         "deepseek": ("deepseek", "deepseek_pretrain.py"),
     }
 
-    for key in model_type_to_script.keys():
+    for key in model_type_to_script:
         if model_type.startswith(key):
             model_type = key
             break
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index dde465aafc..8f300d09fd 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -104,7 +104,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
         raise ValueError("Supplied recipe does not contain required field model_type.")
     model_type = recipe["model"]["model_type"]
 
-    for key in model_type_to_script.keys():
+    for key in model_type_to_script:
         if model_type.startswith(key):
             model_type = key
             break

From 0eb83590c7bd52a2bd53facbd3d958e2f8d604d8 Mon Sep 17 00:00:00 2001
From: Erick Benitez-Ramos
Date: Thu, 30 Jan 2025 13:27:47 -0800
Subject: [PATCH 3/3] add unit test

---
 tests/unit/test_pytorch.py | 51 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 50 insertions(+), 1 deletion(-)

diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py
index 6076d44e90..34d3c6784b 100644
--- a/tests/unit/test_pytorch.py
+++ b/tests/unit/test_pytorch.py
@@ -23,7 +23,10 @@
 from sagemaker import image_uris
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel
-from sagemaker.pytorch.estimator import _get_training_recipe_image_uri
+from sagemaker.pytorch.estimator import (
+    _get_training_recipe_image_uri,
+    _get_training_recipe_gpu_script,
+)
 from sagemaker.instance_group import InstanceGroup
 from sagemaker.session_settings import SessionSettings
 
@@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session):
     assert pytorch.distribution == expected_distribution
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "script": "llama_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "llama_v3",
+                },
+            },
+        },
+        {
+            "script": "mistral_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "mistral",
+                },
+            },
+        },
+        {
+            "script": "deepseek_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "deepseek_llamav3",
+                },
+            },
+        },
+        {
+            "script": "deepseek_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "deepseek_qwenv2",
+                },
+            },
+        },
+    ],
+)
+@patch("shutil.copyfile")
+def test_get_training_recipe_gpu_script(mock_copyfile, test_case):
+    script = test_case["script"]
+    recipe = test_case["recipe"]
+    mock_copyfile.return_value = None
+
+    assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script
+
+
 def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session):
     container_log_level = '"logging.INFO"'
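
The sketch below is not part of the patches; it is a minimal, standalone illustration of the prefix-matching lookup the three commits rely on. The mapping and the startswith() loop mirror _get_trainining_recipe_gpu_model_name_and_script and _get_training_recipe_gpu_script above; the resolve() helper name and the sample model types are assumptions made only for this example.

# Sketch of how a recipe's model_type is reduced to a supported prefix and
# mapped to the adapter-repo example directory and pretraining entry script.
model_type_to_script = {
    "llama_v3": ("llama", "llama_pretrain.py"),
    "mistral": ("mistral", "mistral_pretrain.py"),
    "mixtral": ("mixtral", "mixtral_pretrain.py"),
    "deepseek": ("deepseek", "deepseek_pretrain.py"),
}


def resolve(model_type: str):
    # Same prefix matching as the patched helpers: "deepseek_llamav3" and
    # "deepseek_qwenv2" both collapse onto the "deepseek" entry.
    for key in model_type_to_script:
        if model_type.startswith(key):
            return model_type_to_script[key]
    raise ValueError(f"Model type {model_type} not supported")


if __name__ == "__main__":
    for mt in ("llama_v3", "deepseek_llamav3", "deepseek_qwenv2"):
        base, script = resolve(mt)
        print(f"{mt} -> examples/{base}, {script}")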