158158 InstanceTypePlacementGroupValidator ,
159159 InstanceTypeValidator ,
160160 KeyPairValidator ,
161+ LaunchTemplateOverridesValidator ,
161162 PlacementGroupCapacityReservationValidator ,
162163 PlacementGroupCapacityTypeValidator ,
163164 PlacementGroupNamingValidator ,
@@ -1631,6 +1632,7 @@ def __init__(
16311632 self .managed_head_node_security_group = None
16321633 self .managed_compute_security_group = None
16331634 self .instance_types_data_version = ""
1635+ self .run_instances_overrides_version = ""
16341636
16351637 def _register_validators (self , context : ValidatorContext = None ): # noqa: D102 #pylint: disable=unused-argument
16361638 self ._register_validator (RegionValidator , region = self .region )
@@ -2222,6 +2224,15 @@ def scheduler_resources(self):
22222224 return str (files (__package__ ).parent / "resources" / "batch" )
22232225
22242226
class LaunchTemplateOverrides(Resource):
    """
    Hold the LaunchTemplateOverrides settings attached to a compute resource.

    The section carries two values, both stored through ``Resource.init_param``:
    ``launch_template_id`` (the identifier of the launch template) and
    ``version`` (the version of that launch template).
    """

    def __init__(self, launch_template_id: str = None, version: int = None, **kwargs):
        # Any remaining keyword arguments are forwarded untouched to the Resource base class.
        super().__init__(**kwargs)
        self.launch_template_id = Resource.init_param(launch_template_id)
        self.version = Resource.init_param(version)

2235+
22252236class _BaseSlurmComputeResource (BaseComputeResource ):
22262237 """Represent the Slurm Compute Resource."""
22272238
@@ -2240,6 +2251,7 @@ def __init__(
22402251 tags : List [Tag ] = None ,
22412252 static_node_priority : int = None ,
22422253 dynamic_node_priority : int = None ,
2254+ launch_template_overrides = None ,
22432255 ** kwargs ,
22442256 ):
22452257 super ().__init__ (** kwargs )
@@ -2260,6 +2272,7 @@ def __init__(
22602272 self .tags = tags
22612273 self .static_node_priority = Resource .init_param (static_node_priority , default = 1 )
22622274 self .dynamic_node_priority = Resource .init_param (dynamic_node_priority , default = 1000 )
2275+ self .launch_template_overrides = launch_template_overrides
22632276
22642277 @abstractmethod
22652278 def is_flexible (self ) -> bool :
@@ -2362,6 +2375,15 @@ def _register_validators(self, context: ValidatorContext = None):
23622375 ec2memory = min_memory ,
23632376 instance_type = smallest_type ,
23642377 )
2378+ if self .launch_template_overrides :
2379+ self ._register_validator (
2380+ LaunchTemplateOverridesValidator ,
2381+ launch_template_id = self .launch_template_overrides .launch_template_id ,
2382+ version = self .launch_template_overrides .version ,
2383+ instance_types = self .instance_types ,
2384+ max_network_cards = self .max_network_cards ,
2385+ is_flexible = self .is_flexible (),
2386+ )
23652387
23662388 def is_flexible (self ):
23672389 """Return True because the ComputeResource can contain multiple instance types."""
@@ -2449,6 +2471,15 @@ def _register_validators(self, context: ValidatorContext = None):
24492471 ec2memory = self ._instance_type_info .ec2memory_size_in_mib (),
24502472 instance_type = self .instance_type ,
24512473 )
2474+ if self .launch_template_overrides :
2475+ self ._register_validator (
2476+ LaunchTemplateOverridesValidator ,
2477+ launch_template_id = self .launch_template_overrides .launch_template_id ,
2478+ version = self .launch_template_overrides .version ,
2479+ instance_types = self .instance_types ,
2480+ max_network_cards = self .max_network_cards ,
2481+ is_flexible = self .is_flexible (),
2482+ )
24522483
24532484 @property
24542485 def architecture (self ) -> str :
@@ -2975,6 +3006,40 @@ def get_instance_types_data(self):
29753006 result [instance_type ] = instance_type_info .instance_type_data
29763007 return result
29773008
def get_run_instances_overrides(self):
    """
    Collect launch template data for compute resources that declare LaunchTemplateOverrides.

    Every compute resource of every queue in ``self.scheduling.queues`` is inspected.
    When a compute resource has a truthy ``launch_template_overrides`` attribute, the
    referenced launch template version is fetched through the EC2 client of ``AWSApi``.

    :return: nested dict ``{queue_name: {compute_resource_name: launch_template_data}}``.
        Compute resources without overrides, or whose lookup yields a falsy result,
        are left out; the dict is empty when nothing was collected.
    """
    collected = {}
    for queue in self.scheduling.queues:
        for resource in queue.compute_resources:
            # getattr with a default: not every compute resource type exposes this attribute.
            override_cfg = getattr(resource, "launch_template_overrides", None)
            if override_cfg:
                template_id = override_cfg.launch_template_id
                template_version = override_cfg.version

                LOGGER.info(
                    "Fetching launch template %s version %s for queue %s, compute resource %s",
                    template_id,
                    template_version,
                    queue.name,
                    resource.name,
                )
                template_data = AWSApi.instance().ec2.describe_launch_template_version(
                    template_id, template_version
                )

                # Only record entries for which the lookup actually returned data.
                if template_data:
                    collected.setdefault(queue.name, {})[resource.name] = template_data

    return collected

3042+
29783043 @property
29793044 def login_nodes_ami (self ):
29803045 """Get the image id of the LoginNodes."""
0 commit comments