From 3e448c5d0b3d178103c744a84093387a3c9e2cc0 Mon Sep 17 00:00:00 2001 From: GonzaloSaez <11050889+GonzaloSaez@users.noreply.github.com> Date: Sun, 15 Feb 2026 00:56:39 +0000 Subject: [PATCH] Fix launcher job scheduling directives when unsuspending Signed-off-by: GonzaloSaez <11050889+GonzaloSaez@users.noreply.github.com> --- pkg/controller/mpi_job_controller.go | 45 ++++- pkg/controller/mpi_job_controller_test.go | 183 +++++++++++++++++++- test/integration/mpi_job_controller_test.go | 62 ++++++- 3 files changed, 284 insertions(+), 6 deletions(-) diff --git a/pkg/controller/mpi_job_controller.go b/pkg/controller/mpi_job_controller.go index 47a1b589..4af15e25 100644 --- a/pkg/controller/mpi_job_controller.go +++ b/pkg/controller/mpi_job_controller.go @@ -24,6 +24,7 @@ import ( "encoding/pem" "errors" "fmt" + "maps" "reflect" "sort" "strconv" @@ -687,9 +688,29 @@ func (c *MPIJobController) syncHandler(key string) error { } if launcher != nil { - if isMPIJobSuspended(mpiJob) != isJobSuspended(launcher) { - // align the suspension state of launcher with the MPIJob - launcher.Spec.Suspend = ptr.To(isMPIJobSuspended(mpiJob)) + if !isMPIJobSuspended(mpiJob) && isJobSuspended(launcher) { + // We are unsuspending, hence we need to sync the pod template with the current MPIJob spec. + // This is important for interop with Kueue as it may have injected schedulingGates. + // Kubernetes validates that a Job template is immutable once StartTime is set, + // so we must clear it first via a status sub-resource update (consistent with JobSet). + if launcher.Status.StartTime != nil { + launcher.Status.StartTime = nil + var err error + if launcher, err = c.kubeClient.BatchV1().Jobs(namespace).UpdateStatus(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil { + return err + } + } + + // Sync mutable scheduling directives (KEP-2926) and unsuspend. + desiredPodTemplate := c.newLauncherPodTemplate(mpiJob) + syncLauncherSchedulingDirectives(launcher, &desiredPodTemplate) + launcher.Spec.Suspend = ptr.To(false) + if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil { + return err + } + } else if isMPIJobSuspended(mpiJob) && !isJobSuspended(launcher) { + // align the suspension state of launcher with the MPIJob. + launcher.Spec.Suspend = ptr.To(true) if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil { return err } @@ -1623,6 +1644,24 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev } } +func mergeMaps[K comparable, V any](a, b map[K]V) map[K]V { + merged := make(map[K]V, max(len(a), len(b))) + maps.Copy(merged, a) + maps.Copy(merged, b) + return merged +} + +// syncLauncherSchedulingDirectives updates the mutable scheduling directives (as per KEP-2926) on +// the launcher Job's pod template to match the desired template. +func syncLauncherSchedulingDirectives(launcher *batchv1.Job, desired *corev1.PodTemplateSpec) { + launcher.Spec.Template.Labels = mergeMaps(launcher.Spec.Template.Labels, desired.Labels) + launcher.Spec.Template.Annotations = mergeMaps(launcher.Spec.Template.Annotations, desired.Annotations) + + launcher.Spec.Template.Spec.NodeSelector = desired.Spec.NodeSelector + launcher.Spec.Template.Spec.Tolerations = desired.Spec.Tolerations + launcher.Spec.Template.Spec.SchedulingGates = desired.Spec.SchedulingGates +} + func (c *MPIJobController) jobPods(j *batchv1.Job) ([]*corev1.Pod, error) { selector, err := metav1.LabelSelectorAsSelector(j.Spec.Selector) if err != nil { diff --git a/pkg/controller/mpi_job_controller_test.go b/pkg/controller/mpi_job_controller_test.go index ea39f21c..888a7f4a 100644 --- a/pkg/controller/mpi_job_controller_test.go +++ b/pkg/controller/mpi_job_controller_test.go @@ -1024,14 +1024,16 @@ func TestResumeMPIJob(t *testing.T) { // resume the MPIJob mpiJob.Spec.RunPolicy.Suspend = ptr.To(false) - // expect creation of the pods + // expect creation of the worker pods for i := 0; i < int(replicas); i++ { worker := fmjc.newWorker(mpiJob, i) f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker)) } - // expect the launcher update to resume it + // expect the launcher update to sync scheduling directives and resume it launcherCopy := launcher.DeepCopy() + desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob) + syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate) launcherCopy.Spec.Suspend = ptr.To(false) f.expectUpdateJobAction(launcherCopy) @@ -1044,6 +1046,183 @@ func TestResumeMPIJob(t *testing.T) { f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock) } +func TestResumeMPIJobWithExistingLauncher(t *testing.T) { + // Tests the running→suspended→resumed path where a launcher already exists + // (from before suspension) with startTime == nil. The launcher should be + // updated in place with synced scheduling directives (KEP-2926). + fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second)) + f := newFixture(t, "") + + var replicas int32 = 8 + startTime := metav1.Now() + mpiJob := newMPIJob("test", &replicas, &startTime, nil) + mpiJob.Spec.RunPolicy.Suspend = ptr.To(true) + msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name) + updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg) + updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended") + msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name) + updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg) + mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{ + kubeflow.MPIReplicaTypeLauncher: {}, + kubeflow.MPIReplicaTypeWorker: {}, + } + f.setUpMPIJob(mpiJob) + + scheme.Scheme.Default(mpiJob) + f.expectCreateServiceAction(newJobService(mpiJob)) + cfgMap := newConfigMap(mpiJob, replicas, "") + updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "") + f.setUpConfigMap(cfgMap) + secret, err := newSSHAuthSecret(mpiJob) + if err != nil { + t.Fatalf("Failed creating secret") + } + f.setUpSecret(secret) + + // set up an existing suspended launcher (startTime == nil, never started) + fmjc := f.newFakeMPIJobController() + launcher := fmjc.newLauncherJob(mpiJob) + launcher.Spec.Suspend = ptr.To(true) + // Simulate Kueue injecting scheduling directives into the MPIJob template + // after the launcher was already created (so the launcher has stale templates). + launcherSpec := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template + launcherSpec.Spec.NodeSelector = map[string]string{ + "foo": "bar", + } + launcherSpec.Spec.Tolerations = []corev1.Toleration{ + {Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule}, + } + launcherSpec.Spec.SchedulingGates = []corev1.PodSchedulingGate{ + {Name: "kueue.x-k8s.io/topology"}, + } + if launcherSpec.Annotations == nil { + launcherSpec.Annotations = make(map[string]string) + } + launcherSpec.Annotations["kueue.x-k8s.io/workload"] = "my-workload" + f.setUpLauncher(launcher) + + fakeClock.Sleep(time.Second) + + // resume the MPIJob + mpiJob.Spec.RunPolicy.Suspend = ptr.To(false) + + // expect creation of the worker pods + for i := 0; i < int(replicas); i++ { + worker := fmjc.newWorker(mpiJob, i) + f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker)) + } + + // expect the launcher to be updated (scheduling directives synced + unsuspended) + launcherCopy := launcher.DeepCopy() + desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob) + syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate) + launcherCopy.Spec.Suspend = ptr.To(false) + + // Verify the synced launcher has the Kueue-injected scheduling directives. + tmpl := &launcherCopy.Spec.Template + if tmpl.Spec.NodeSelector["foo"] != "bar" { + t.Errorf("expected nodeSelector to be synced, got %v", tmpl.Spec.NodeSelector) + } + if len(tmpl.Spec.Tolerations) != 1 || tmpl.Spec.Tolerations[0].Key != "gpu" { + t.Errorf("expected tolerations to be synced, got %v", tmpl.Spec.Tolerations) + } + if len(tmpl.Spec.SchedulingGates) != 1 || tmpl.Spec.SchedulingGates[0].Name != "kueue.x-k8s.io/topology" { + t.Errorf("expected schedulingGates to be synced, got %v", tmpl.Spec.SchedulingGates) + } + if tmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" { + t.Errorf("expected annotations to be synced, got %v", tmpl.Annotations) + } + + f.expectUpdateJobAction(launcherCopy) + + // expect status update + mpiJobCopy := mpiJob.DeepCopy() + mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()} + updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed") + f.expectUpdateMPIJobStatusAction(mpiJobCopy) + + f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock) +} + +func TestResumeMPIJobClearsStartTime(t *testing.T) { + // Tests the re-admission case where the launcher has startTime != nil. + // The controller should clear StartTime via a status sub-resource update + // (consistent with JobSet), then sync scheduling directives and unsuspend. + fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second)) + f := newFixture(t, "") + + var replicas int32 = 8 + startTime := metav1.Now() + mpiJob := newMPIJob("test", &replicas, &startTime, nil) + mpiJob.Spec.RunPolicy.Suspend = ptr.To(true) + msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name) + updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg) + updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended") + msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name) + updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg) + mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{ + kubeflow.MPIReplicaTypeLauncher: {}, + kubeflow.MPIReplicaTypeWorker: {}, + } + f.setUpMPIJob(mpiJob) + + scheme.Scheme.Default(mpiJob) + f.expectCreateServiceAction(newJobService(mpiJob)) + cfgMap := newConfigMap(mpiJob, replicas, "") + updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "") + f.setUpConfigMap(cfgMap) + secret, err := newSSHAuthSecret(mpiJob) + if err != nil { + t.Fatalf("Failed creating secret") + } + f.setUpSecret(secret) + + // set up an existing suspended launcher that was previously started (startTime != nil) + fmjc := f.newFakeMPIJobController() + launcher := fmjc.newLauncherJob(mpiJob) + launcher.Spec.Suspend = ptr.To(true) + launcherStartTime := metav1.Now() + launcher.Status.StartTime = &launcherStartTime + f.setUpLauncher(launcher) + + fakeClock.Sleep(time.Second) + + // resume the MPIJob + mpiJob.Spec.RunPolicy.Suspend = ptr.To(false) + + // expect creation of worker pods + for i := 0; i < int(replicas); i++ { + worker := fmjc.newWorker(mpiJob, i) + f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker)) + } + + // expect a status sub-resource update to clear launcher's StartTime + launcherStatusCleared := launcher.DeepCopy() + launcherStatusCleared.Status.StartTime = nil + f.kubeActions = append(f.kubeActions, core.NewUpdateSubresourceAction( + schema.GroupVersionResource{Resource: "jobs", Group: "batch", Version: "v1"}, + "status", + mpiJob.Namespace, + launcherStatusCleared, + )) + + // expect the launcher to be updated (scheduling directives synced + unsuspended) + launcherCopy := launcher.DeepCopy() + launcherCopy.Status.StartTime = nil + desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob) + syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate) + launcherCopy.Spec.Suspend = ptr.To(false) + f.expectUpdateJobAction(launcherCopy) + + // expect MPIJob status update + mpiJobCopy := mpiJob.DeepCopy() + mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()} + updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed") + f.expectUpdateMPIJobStatusAction(mpiJobCopy) + + f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock) +} + func TestWorkerNotControlledByUs(t *testing.T) { f := newFixture(t, "") startTime := metav1.Now() diff --git a/test/integration/mpi_job_controller_test.go b/test/integration/mpi_job_controller_test.go index de5da72f..0b8f8b4a 100644 --- a/test/integration/mpi_job_controller_test.go +++ b/test/integration/mpi_job_controller_test.go @@ -385,7 +385,7 @@ func TestMPIJobResumingAndSuspending(t *testing.T) { t.Errorf("MPIJob missing Suspended condition") } if !isJobSuspended(launcherJob) { - t.Errorf("LauncherJob is suspended") + t.Errorf("LauncherJob is not suspended") } if mpiJob.Status.StartTime != nil { t.Errorf("MPIJob has unexpected start time: %v", mpiJob.Status.StartTime) @@ -393,6 +393,29 @@ func TestMPIJobResumingAndSuspending(t *testing.T) { s.events.verify(t) + // Simulate Kueue injecting scheduling directives into the MPIJob template + // while suspended. When resumed, these must propagate to the launcher Job. + launcherTemplate := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template + launcherTemplate.Labels = map[string]string{ + "foo": "bar", + } + launcherTemplate.Annotations = map[string]string{ + "kueue.x-k8s.io/workload": "my-workload", + } + launcherTemplate.Spec.NodeSelector = map[string]string{ + "example.com/accelerator": "example-model", + } + launcherTemplate.Spec.Tolerations = []corev1.Toleration{ + {Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule}, + } + launcherTemplate.Spec.SchedulingGates = []corev1.PodSchedulingGate{ + {Name: "kueue.x-k8s.io/topology"}, + } + mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Update(ctx, mpiJob, metav1.UpdateOptions{}) + if err != nil { + t.Fatalf("Failed to update the MPIJob: %v", err) + } + // 2. Resume the MPIJob mpiJob.Spec.RunPolicy.Suspend = ptr.To(false) mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(mpiJob.Namespace).Update(ctx, mpiJob, metav1.UpdateOptions{}) @@ -422,6 +445,24 @@ func TestMPIJobResumingAndSuspending(t *testing.T) { s.events.verify(t) + // Verify all scheduling directives were propagated to the launcher Job's pod template. + launcherTmpl := &launcherJob.Spec.Template + if launcherTmpl.Labels["foo"] != "bar" { + t.Errorf("expected label 'foo=bar' on launcher Job template, got labels: %v", launcherTmpl.Labels) + } + if launcherTmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" { + t.Errorf("expected annotation 'kueue.x-k8s.io/workload' on launcher Job template, got annotations: %v", launcherTmpl.Annotations) + } + if launcherTmpl.Spec.NodeSelector["example.com/accelerator"] != "example-model" { + t.Errorf("expected nodeSelector 'example.com/accelerator=example-model' on launcher Job template, got: %v", launcherTmpl.Spec.NodeSelector) + } + if len(launcherTmpl.Spec.Tolerations) == 0 || launcherTmpl.Spec.Tolerations[len(launcherTmpl.Spec.Tolerations)-1].Key != "gpu" { + t.Errorf("expected toleration with key 'gpu' on launcher Job template, got: %v", launcherTmpl.Spec.Tolerations) + } + if len(launcherTmpl.Spec.SchedulingGates) == 0 || launcherTmpl.Spec.SchedulingGates[len(launcherTmpl.Spec.SchedulingGates)-1].Name != "kueue.x-k8s.io/topology" { + t.Errorf("expected schedulingGate 'kueue.x-k8s.io/topology' on launcher Job template, got: %v", launcherTmpl.Spec.SchedulingGates) + } + // 3. Set the pods to be running err = updatePodsToPhase(ctx, s.kClient, workerPods, corev1.PodRunning) if err != nil { @@ -473,6 +514,25 @@ func TestMPIJobResumingAndSuspending(t *testing.T) { if !mpiJobHasConditionWithStatus(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse) { t.Errorf("MPIJob has unexpected Running condition") } + + // Update the MPIJob launcher template again and resume, verifying the + // launcher Job gets the updated scheduling directives on second resume. + mpiJobLauncherTemplate := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template + mpiJobLauncherTemplate.Labels["foo"] = "baz" + mpiJobLauncherTemplate.Spec.NodeSelector["example.com/accelerator"] = "example-model-v2" + mpiJob.Spec.RunPolicy.Suspend = ptr.To(false) + mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Update(ctx, mpiJob, metav1.UpdateOptions{}) + if err != nil { + t.Fatalf("Failed to update the MPIJob: %v", err) + } + + _, launcherJob = validateMPIJobDependencies(ctx, t, s.kClient, mpiJob, 2, nil) + if launcherJob.Spec.Template.Labels["foo"] != "baz" { + t.Errorf("expected label 'foo=baz' on launcher Job template, got labels: %v", launcherJob.Spec.Template.Labels) + } + if launcherJob.Spec.Template.Spec.NodeSelector["example.com/accelerator"] != "example-model-v2" { + t.Errorf("expected nodeSelector 'example.com/accelerator=example-model-v2' on launcher Job template, got: %v", launcherJob.Spec.Template.Spec.NodeSelector) + } } func TestMPIJobFailure(t *testing.T) {