@@ -829,7 +829,13 @@ def _optimize_for_jumpstart(
829829 self .pysdk_model ._enable_network_isolation = False
830830
831831 if quantization_config or sharding_config or is_compilation :
832- return create_optimization_job_args
832+ # only apply default image for vLLM usecases.
833+ # vLLM does not support compilation for now so skip on compilation
834+ return (
835+ create_optimization_job_args
836+ if is_compilation
837+ else self ._set_optimization_image_default (create_optimization_job_args )
838+ )
833839 return None
834840
835841 def _is_gated_model (self , model = None ) -> bool :
@@ -986,3 +992,28 @@ def _get_neuron_model_env_vars(
986992 )
987993 return job_model .env
988994 return None
995+
996+ def _set_optimization_image_default (
997+ self , create_optimization_job_args : Dict [str , Any ]
998+ ) -> Dict [str , Any ]:
999+ """Defaults the optimization image to the JumpStart deployment config default
1000+
1001+ Args:
1002+ create_optimization_job_args (Dict[str, Any]): create optimization job request
1003+
1004+ Returns:
1005+ Dict[str, Any]: create optimization job request with image uri default
1006+ """
1007+
1008+ for optimization_config in create_optimization_job_args .get ("OptimizationConfigs" ):
1009+ if optimization_config .get ("ModelQuantizationConfig" ):
1010+ model_quantization_config = optimization_config .get ("ModelQuantizationConfig" )
1011+ if not model_quantization_config .get ("Image" ):
1012+ model_quantization_config ["Image" ] = self .pysdk_model .init_kwargs ["image_uri" ]
1013+
1014+ if optimization_config .get ("ModelShardingConfig" ):
1015+ model_sharding_config = optimization_config .get ("ModelShardingConfig" )
1016+ if not model_sharding_config .get ("Image" ):
1017+ model_sharding_config ["Image" ] = self .pysdk_model .init_kwargs ["image_uri" ]
1018+
1019+ return create_optimization_job_args
0 commit comments