Skip to content

Commit c17a2d1

Browse files
committed
fix(operator): fix TrainJob suspend/resume webhook error (#3008)
1 parent 63494f6 commit c17a2d1

8 files changed

Lines changed: 86 additions & 12 deletions

File tree

pkg/controller/trainjob_controller.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ func WithWatchers(watchers ...TrainJobWatcher) TrainJobReconcilerOption {
7070
}
7171
}
7272

73-
var _ reconcile.Reconciler = (*TrainJobReconciler)(nil)
74-
var _ predicate.TypedPredicate[*trainer.TrainJob] = (*TrainJobReconciler)(nil)
73+
var (
74+
_ reconcile.Reconciler = (*TrainJobReconciler)(nil)
75+
_ predicate.TypedPredicate[*trainer.TrainJob] = (*TrainJobReconciler)(nil)
76+
)
7577

7678
func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runtimes map[string]jobruntimes.Runtime, opts ...TrainJobReconcilerOption) *TrainJobReconciler {
7779
options := &TrainJobReconcilerOptions{}
@@ -145,6 +147,12 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
145147
}
146148

147149
func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
150+
if ptr.Deref(trainJob.Spec.Suspend, false) {
151+
if err := runtime.SyncSuspend(ctx, trainJob); err != nil {
152+
return err
153+
}
154+
}
155+
148156
objects, err := runtime.NewObjects(ctx, trainJob)
149157
if err != nil {
150158
return err
@@ -154,6 +162,11 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru
154162
return err
155163
}
156164
}
165+
if !ptr.Deref(trainJob.Spec.Suspend, false) {
166+
if err := runtime.SyncSuspend(ctx, trainJob); err != nil {
167+
return err
168+
}
169+
}
157170
return nil
158171
}
159172

pkg/runtime/core/clustertrainingruntime.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ func (r *ClusterTrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *t
7777
return r.TrainingRuntime.TrainJobStatus(ctx, trainJob)
7878
}
7979

80+
// SyncSuspend forwards the call unchanged to r.TrainingRuntime.SyncSuspend and
// returns whatever error that call returns.
func (r *ClusterTrainingRuntime) SyncSuspend(ctx context.Context, trainJob *trainer.TrainJob) error {
81+
return r.TrainingRuntime.SyncSuspend(ctx, trainJob)
82+
}
83+
8084
func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
8185
return nil
8286
}

pkg/runtime/core/trainingruntime.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,7 @@ func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer
263263
info, _ := r.newRuntimeInfo(new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) // ignoring the error here as the runtime configured should be valid
264264
return r.framework.RunCustomValidationPlugins(ctx, info, old, new)
265265
}
266+
267+
// SyncSuspend hands the TrainJob to the runtime's plugin framework by calling
// r.framework.RunSuspendSyncPlugins, and returns that call's error as-is.
func (r *TrainingRuntime) SyncSuspend(ctx context.Context, trainJob *trainer.TrainJob) error {
268+
return r.framework.RunSuspendSyncPlugins(ctx, trainJob)
269+
}

pkg/runtime/framework/core/framework.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type Framework struct {
4444
podNetworkPlugins []framework.PodNetworkPlugin
4545
componentBuilderPlugins []framework.ComponentBuilderPlugin
4646
trainJobStatusPlugin framework.TrainJobStatusPlugin
47+
suspendSyncPlugins []framework.SuspendSyncPlugin
4748
}
4849

4950
func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
@@ -85,6 +86,9 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
8586
}
8687
f.trainJobStatusPlugin = p
8788
}
89+
if p, ok := plugin.(framework.SuspendSyncPlugin); ok {
90+
f.suspendSyncPlugins = append(f.suspendSyncPlugins, p)
91+
}
8892
}
8993
f.plugins = plugins
9094
return f, nil
@@ -151,6 +155,15 @@ func (f *Framework) RunTrainJobStatusPlugin(ctx context.Context, trainJob *train
151155
return nil, nil
152156
}
153157

158+
// RunSuspendSyncPlugins calls SyncSuspend on each entry of f.suspendSyncPlugins
// in slice order, passing the same ctx and trainJob to every plugin.
//
// The first non-nil error aborts the loop and is returned unwrapped; plugins
// later in the slice are not invoked in that case. It returns nil when every
// plugin succeeds or when no plugins are registered.
func (f *Framework) RunSuspendSyncPlugins(ctx context.Context, trainJob *trainer.TrainJob) error {
159+
for _, plugin := range f.suspendSyncPlugins {
160+
if err := plugin.SyncSuspend(ctx, trainJob); err != nil {
161+
return err
162+
}
163+
}
164+
return nil
165+
}
166+
154167
func (f *Framework) WatchExtensionPlugins() []framework.WatchExtensionPlugin {
155168
return f.watchExtensionPlugins
156169
}

pkg/runtime/framework/core/framework_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ func TestNew(t *testing.T) {
114114
&mpi.MPI{},
115115
},
116116
trainJobStatusPlugin: &jobset.JobSet{},
117+
suspendSyncPlugins: []framework.SuspendSyncPlugin{
118+
&jobset.JobSet{},
119+
},
117120
},
118121
},
119122
"indexer key for trainingRuntime and runtimeClass is an empty": {

pkg/runtime/framework/interface.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,8 @@ type TrainJobStatusPlugin interface {
6565
Plugin
6666
Status(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error)
6767
}
68+
69+
// SuspendSyncPlugin is a Plugin that additionally exposes a SyncSuspend hook
// for a TrainJob.
// NOTE(review): presumably an implementation propagates trainJob.Spec.Suspend
// to the resources that plugin manages — confirm against the implementations.
type SuspendSyncPlugin interface {
70+
Plugin
71+
// SyncSuspend is invoked with the TrainJob being processed; a non-nil
// error reports that the sync did not complete.
SyncSuspend(ctx context.Context, trainJob *trainer.TrainJob) error
72+
}

pkg/runtime/framework/plugins/jobset/jobset.go

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,14 @@ type JobSet struct {
6060
logger logr.Logger
6161
}
6262

63-
var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
64-
var _ framework.PodNetworkPlugin = (*JobSet)(nil)
65-
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
66-
var _ framework.TrainJobStatusPlugin = (*JobSet)(nil)
67-
var _ framework.CustomValidationPlugin = (*JobSet)(nil)
63+
var (
64+
_ framework.WatchExtensionPlugin = (*JobSet)(nil)
65+
_ framework.PodNetworkPlugin = (*JobSet)(nil)
66+
_ framework.ComponentBuilderPlugin = (*JobSet)(nil)
67+
_ framework.TrainJobStatusPlugin = (*JobSet)(nil)
68+
_ framework.CustomValidationPlugin = (*JobSet)(nil)
69+
_ framework.SuspendSyncPlugin = (*JobSet)(nil)
70+
)
6871

6972
const Name = constants.JobSetKind
7073

@@ -145,7 +148,6 @@ func (j *JobSet) Validate(ctx context.Context, info *runtime.Info, oldObj, newOb
145148
}
146149
}
147150
}
148-
149151
}
150152

151153
return nil, allErrs
@@ -239,14 +241,15 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine
239241
return nil, fmt.Errorf("runtime info or object is missing")
240242
}
241243

242-
// Do not update the JobSet if it already exists and is not suspended
244+
// Check if JobSet already exists
243245
oldJobSet := &jobsetv1alpha2.JobSet{}
244246
if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil {
245247
if !apierrors.IsNotFound(err) {
246248
return nil, err
247249
}
248250
oldJobSet = nil
249251
}
252+
250253
if oldJobSet != nil &&
251254
!ptr.Deref(trainJob.Spec.Suspend, false) &&
252255
!ptr.Deref(oldJobSet.Spec.Suspend, false) {
@@ -288,12 +291,20 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine
288291

289292
// TODO (andreyvelich): Refactor the builder with wrappers for PodSpec.
290293
// TODO: Once we remove deprecated runtime.Info.Trainer, we should remove JobSet Builder with DeprecatedTrainer().
291-
jobSet := jobSetBuilder.
294+
jobSetBuilder = jobSetBuilder.
292295
Initializer(trainJob).
293296
Trainer(info, trainJob).
294297
PodLabels(info.Scheduler.PodLabels).
295-
PodAnnotations(info.Scheduler.PodAnnotations).
296-
Suspend(trainJob.Spec.Suspend).
298+
PodAnnotations(info.Scheduler.PodAnnotations)
299+
300+
// Only set suspend on initial creation. For existing JobSets, suspend is
301+
// managed separately via SyncSuspend using merge patch. SSA without this
302+
// field leaves the existing suspend value unchanged (releases ownership only).
303+
if oldJobSet == nil {
304+
jobSetBuilder = jobSetBuilder.Suspend(trainJob.Spec.Suspend)
305+
}
306+
307+
jobSet := jobSetBuilder.
297308
Build().
298309
WithOwnerReferences(metav1ac.OwnerReference().
299310
WithAPIVersion(trainer.GroupVersion.String()).
@@ -337,3 +348,23 @@ func (j *JobSet) Status(ctx context.Context, trainJob *trainer.TrainJob) (*train
337348

338349
return status, nil
339350
}
351+
352+
// SyncSuspend brings the suspend flag of the JobSet stored under the
// TrainJob's object key in line with trainJob.Spec.Suspend.
//
// It returns nil when that JobSet does not exist or when both flags already
// agree; a nil flag on either side is read as false. Otherwise it patches the
// JobSet's Spec.Suspend and returns the wrapped patch error, if any. Errors
// from the initial Get other than NotFound are returned unwrapped.
func (j *JobSet) SyncSuspend(ctx context.Context, trainJob *trainer.TrainJob) error {
353+
jobSet := &jobsetv1alpha2.JobSet{}
354+
if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil {
// A missing JobSet means there is nothing to sync, so NotFound becomes nil.
355+
return client.IgnoreNotFound(err)
356+
}
357+
358+
trainJobSuspend := ptr.Deref(trainJob.Spec.Suspend, false)
359+
jobSetSuspend := ptr.Deref(jobSet.Spec.Suspend, false)
360+
// Only issue a write when the effective values actually differ.
361+
if trainJobSuspend != jobSetSuspend {
// Snapshot the fetched object first so the merge patch is computed against
// it and carries only the change made below.
362+
patch := client.MergeFrom(jobSet.DeepCopy())
363+
jobSet.Spec.Suspend = ptr.To(trainJobSuspend)
364+
if err := j.client.Patch(ctx, jobSet, patch); err != nil {
365+
return fmt.Errorf("failed to patch JobSet suspend field: %w", err)
366+
}
367+
}
368+
369+
return nil
370+
}

pkg/runtime/interface.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ type Runtime interface {
3737
TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error)
3838
EventHandlerRegistrars() []ReconcilerBuilder
3939
ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList)
40+
SyncSuspend(ctx context.Context, trainJob *trainer.TrainJob) error
4041
}

0 commit comments

Comments
 (0)