@@ -60,11 +60,14 @@ type JobSet struct {
6060 logger logr.Logger
6161}
6262
63- var _ framework.WatchExtensionPlugin = (* JobSet )(nil )
64- var _ framework.PodNetworkPlugin = (* JobSet )(nil )
65- var _ framework.ComponentBuilderPlugin = (* JobSet )(nil )
66- var _ framework.TrainJobStatusPlugin = (* JobSet )(nil )
67- var _ framework.CustomValidationPlugin = (* JobSet )(nil )
63+ var (
64+ _ framework.WatchExtensionPlugin = (* JobSet )(nil )
65+ _ framework.PodNetworkPlugin = (* JobSet )(nil )
66+ _ framework.ComponentBuilderPlugin = (* JobSet )(nil )
67+ _ framework.TrainJobStatusPlugin = (* JobSet )(nil )
68+ _ framework.CustomValidationPlugin = (* JobSet )(nil )
69+ _ framework.SuspendSyncPlugin = (* JobSet )(nil )
70+ )
6871
6972const Name = constants .JobSetKind
7073
@@ -145,7 +148,6 @@ func (j *JobSet) Validate(ctx context.Context, info *runtime.Info, oldObj, newOb
145148 }
146149 }
147150 }
148-
149151 }
150152
151153 return nil , allErrs
@@ -239,14 +241,15 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine
239241 return nil , fmt .Errorf ("runtime info or object is missing" )
240242 }
241243
242- // Do not update the JobSet if it already exists and is not suspended
244+ // Check if JobSet already exists
243245 oldJobSet := & jobsetv1alpha2.JobSet {}
244246 if err := j .client .Get (ctx , client .ObjectKeyFromObject (trainJob ), oldJobSet ); err != nil {
245247 if ! apierrors .IsNotFound (err ) {
246248 return nil , err
247249 }
248250 oldJobSet = nil
249251 }
252+
250253 if oldJobSet != nil &&
251254 ! ptr .Deref (trainJob .Spec .Suspend , false ) &&
252255 ! ptr .Deref (oldJobSet .Spec .Suspend , false ) {
@@ -288,12 +291,20 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine
288291
289292 // TODO (andreyvelich): Refactor the builder with wrappers for PodSpec.
290293 // TODO: Once we remove deprecated runtime.Info.Trainer, we should remove JobSet Builder with DeprecatedTrainer().
291- jobSet : = jobSetBuilder .
294+ jobSetBuilder = jobSetBuilder .
292295 Initializer (trainJob ).
293296 Trainer (info , trainJob ).
294297 PodLabels (info .Scheduler .PodLabels ).
295- PodAnnotations (info .Scheduler .PodAnnotations ).
296- Suspend (trainJob .Spec .Suspend ).
298+ PodAnnotations (info .Scheduler .PodAnnotations )
299+
300+ // Only set suspend on initial creation. For existing JobSets, suspend is
301+ // managed separately via SyncSuspend using merge patch. SSA without this
302+ // field leaves the existing suspend value unchanged (releases ownership only).
303+ if oldJobSet == nil {
304+ jobSetBuilder = jobSetBuilder .Suspend (trainJob .Spec .Suspend )
305+ }
306+
307+ jobSet := jobSetBuilder .
297308 Build ().
298309 WithOwnerReferences (metav1ac .OwnerReference ().
299310 WithAPIVersion (trainer .GroupVersion .String ()).
@@ -337,3 +348,23 @@ func (j *JobSet) Status(ctx context.Context, trainJob *trainer.TrainJob) (*train
337348
338349 return status , nil
339350}
351+
352+ func (j * JobSet ) SyncSuspend (ctx context.Context , trainJob * trainer.TrainJob ) error {
353+ jobSet := & jobsetv1alpha2.JobSet {}
354+ if err := j .client .Get (ctx , client .ObjectKeyFromObject (trainJob ), jobSet ); err != nil {
355+ return client .IgnoreNotFound (err )
356+ }
357+
358+ trainJobSuspend := ptr .Deref (trainJob .Spec .Suspend , false )
359+ jobSetSuspend := ptr .Deref (jobSet .Spec .Suspend , false )
360+
361+ if trainJobSuspend != jobSetSuspend {
362+ patch := client .MergeFrom (jobSet .DeepCopy ())
363+ jobSet .Spec .Suspend = ptr .To (trainJobSuspend )
364+ if err := j .client .Patch (ctx , jobSet , patch ); err != nil {
365+ return fmt .Errorf ("failed to patch JobSet suspend field: %w" , err )
366+ }
367+ }
368+
369+ return nil
370+ }
0 commit comments