@@ -125,6 +125,27 @@ def _register_custom_resolvers():
125125 OmegaConf .register_new_resolver ("add" , lambda * numbers : sum (numbers ))
126126
127127
128+ def _get_trainining_recipe_gpu_model_name_and_script (model_type : str ):
129+ """Get the model base name and script for the training recipe."""
130+
131+ model_type_to_script = {
132+ "llama_v3" : ("llama" , "llama_pretrain.py" ),
133+ "mistral" : ("mistral" , "mistral_pretrain.py" ),
134+ "mixtral" : ("mixtral" , "mixtral_pretrain.py" ),
135+ "deepseek" : ("deepseek" , "deepseek_pretrain.py" ),
136+ }
137+
138+ for key in model_type_to_script :
139+ if model_type .startswith (key ):
140+ model_type = key
141+ break
142+
143+ if model_type not in model_type_to_script :
144+ raise ValueError (f"Model type { model_type } not supported" )
145+
146+ return model_type_to_script [model_type ][0 ], model_type_to_script [model_type ][1 ]
147+
148+
128149def _configure_gpu_args (
129150 training_recipes_cfg : Dict [str , Any ],
130151 region_name : str ,
@@ -140,24 +161,16 @@ def _configure_gpu_args(
140161 )
141162 _run_clone_command_silent (adapter_repo , recipe_train_dir .name )
142163
143- model_type_to_entry = {
144- "llama_v3" : ("llama" , "llama_pretrain.py" ),
145- "mistral" : ("mistral" , "mistral_pretrain.py" ),
146- "mixtral" : ("mixtral" , "mixtral_pretrain.py" ),
147- }
148-
149164 if "model" not in recipe :
150165 raise ValueError ("Supplied recipe does not contain required field model." )
151166 if "model_type" not in recipe ["model" ]:
152167 raise ValueError ("Supplied recipe does not contain required field model_type." )
153168 model_type = recipe ["model" ]["model_type" ]
154- if model_type not in model_type_to_entry :
155- raise ValueError (f"Model type { model_type } not supported" )
156169
157- source_code . source_dir = os . path . join (
158- recipe_train_dir . name , "examples" , model_type_to_entry [ model_type ][ 0 ]
159- )
160- source_code .entry_script = model_type_to_entry [ model_type ][ 1 ]
170+ model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script ( model_type )
171+
172+ source_code . source_dir = os . path . join ( recipe_train_dir . name , "examples" , model_base_name )
173+ source_code .entry_script = script
161174
162175 gpu_image_cfg = training_recipes_cfg .get ("gpu_image" )
163176 if isinstance (gpu_image_cfg , str ):
0 commit comments