Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions pkg/controller/mpi_job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"maps"
"reflect"
"sort"
"strconv"
Expand Down Expand Up @@ -687,9 +688,29 @@ func (c *MPIJobController) syncHandler(key string) error {
}

if launcher != nil {
if isMPIJobSuspended(mpiJob) != isJobSuspended(launcher) {
// align the suspension state of launcher with the MPIJob
launcher.Spec.Suspend = ptr.To(isMPIJobSuspended(mpiJob))
if !isMPIJobSuspended(mpiJob) && isJobSuspended(launcher) {
// We are unsuspending, hence we need to sync the pod template with the current MPIJob spec.
// This is important for interop with Kueue as it may have injected schedulingGates.
// Kubernetes validates that a Job template is immutable once StartTime is set,
// so we must clear it first via a status sub-resource update (consistent with JobSet).
if launcher.Status.StartTime != nil {
launcher.Status.StartTime = nil
var err error
if launcher, err = c.kubeClient.BatchV1().Jobs(namespace).UpdateStatus(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
return err
}
}

// Sync mutable scheduling directives (KEP-2926) and unsuspend.
desiredPodTemplate := c.newLauncherPodTemplate(mpiJob)
syncLauncherSchedulingDirectives(launcher, &desiredPodTemplate)
launcher.Spec.Suspend = ptr.To(false)
if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
return err
}
} else if isMPIJobSuspended(mpiJob) && !isJobSuspended(launcher) {
// align the suspension state of launcher with the MPIJob.
launcher.Spec.Suspend = ptr.To(true)
if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
return err
}
Expand Down Expand Up @@ -1623,6 +1644,24 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev
}
}

// mergeMaps returns a fresh map holding every entry of a and b. When a key
// is present in both inputs, the value from b wins. The result is never nil,
// even when both inputs are nil or empty.
//
// The initial capacity is the larger of the two input sizes: it is exact when
// the key sets overlap heavily and at worst triggers one additional growth
// when they are disjoint (deliberately chosen over len(a)+len(b), which would
// over-allocate for overlapping inputs).
func mergeMaps[K comparable, V any](a, b map[K]V) map[K]V {
	capacity := len(a)
	if len(b) > capacity {
		capacity = len(b)
	}
	merged := make(map[K]V, capacity)
	maps.Copy(merged, a)
	maps.Copy(merged, b)
	return merged
}

// syncLauncherSchedulingDirectives copies the mutable scheduling directives
// (as per KEP-2926) from the desired pod template onto the launcher Job's pod
// template. Labels and annotations are merged (desired entries win on key
// collisions), while nodeSelector, tolerations, and schedulingGates are
// replaced wholesale with the desired values.
func syncLauncherSchedulingDirectives(launcher *batchv1.Job, desired *corev1.PodTemplateSpec) {
	tmpl := &launcher.Spec.Template

	// Metadata maps are merged rather than replaced so launcher-specific
	// entries survive while desired (e.g. Kueue-injected) entries take effect.
	tmpl.Labels = mergeMaps(tmpl.Labels, desired.Labels)
	tmpl.Annotations = mergeMaps(tmpl.Annotations, desired.Annotations)

	// Scheduling fields mirror the desired template exactly.
	tmpl.Spec.NodeSelector = desired.Spec.NodeSelector
	tmpl.Spec.Tolerations = desired.Spec.Tolerations
	tmpl.Spec.SchedulingGates = desired.Spec.SchedulingGates
}

func (c *MPIJobController) jobPods(j *batchv1.Job) ([]*corev1.Pod, error) {
selector, err := metav1.LabelSelectorAsSelector(j.Spec.Selector)
if err != nil {
Expand Down
183 changes: 181 additions & 2 deletions pkg/controller/mpi_job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,14 +1024,16 @@ func TestResumeMPIJob(t *testing.T) {
// resume the MPIJob
mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)

// expect creation of the pods
// expect creation of the worker pods
for i := 0; i < int(replicas); i++ {
worker := fmjc.newWorker(mpiJob, i)
f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
}

// expect the launcher update to resume it
// expect the launcher update to sync scheduling directives and resume it
launcherCopy := launcher.DeepCopy()
desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
launcherCopy.Spec.Suspend = ptr.To(false)
f.expectUpdateJobAction(launcherCopy)

Expand All @@ -1044,6 +1046,183 @@ func TestResumeMPIJob(t *testing.T) {
f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
}

// TestResumeMPIJobWithExistingLauncher exercises the controller's resume path
// when a suspended launcher Job already exists with startTime == nil (it was
// never started), so no status sub-resource update is expected — only an
// in-place launcher update with synced scheduling directives (KEP-2926).
func TestResumeMPIJobWithExistingLauncher(t *testing.T) {
	// Tests the running→suspended→resumed path where a launcher already exists
	// (from before suspension) with startTime == nil. The launcher should be
	// updated in place with synced scheduling directives (KEP-2926).
	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
	f := newFixture(t, "")

	// Build a suspended MPIJob whose conditions reflect created+suspended state.
	var replicas int32 = 8
	startTime := metav1.Now()
	mpiJob := newMPIJob("test", &replicas, &startTime, nil)
	mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
	updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
	updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
	msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
	updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
	mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
		kubeflow.MPIReplicaTypeLauncher: {},
		kubeflow.MPIReplicaTypeWorker:   {},
	}
	f.setUpMPIJob(mpiJob)

	// Seed the fixture with the dependent objects the sync loop expects to find.
	scheme.Scheme.Default(mpiJob)
	f.expectCreateServiceAction(newJobService(mpiJob))
	cfgMap := newConfigMap(mpiJob, replicas, "")
	updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
	f.setUpConfigMap(cfgMap)
	secret, err := newSSHAuthSecret(mpiJob)
	if err != nil {
		t.Fatalf("Failed creating secret")
	}
	f.setUpSecret(secret)

	// set up an existing suspended launcher (startTime == nil, never started)
	fmjc := f.newFakeMPIJobController()
	launcher := fmjc.newLauncherJob(mpiJob)
	launcher.Spec.Suspend = ptr.To(true)
	// Simulate Kueue injecting scheduling directives into the MPIJob template
	// after the launcher was already created (so the launcher has stale templates).
	launcherSpec := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
	launcherSpec.Spec.NodeSelector = map[string]string{
		"foo": "bar",
	}
	launcherSpec.Spec.Tolerations = []corev1.Toleration{
		{Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule},
	}
	launcherSpec.Spec.SchedulingGates = []corev1.PodSchedulingGate{
		{Name: "kueue.x-k8s.io/topology"},
	}
	if launcherSpec.Annotations == nil {
		launcherSpec.Annotations = make(map[string]string)
	}
	launcherSpec.Annotations["kueue.x-k8s.io/workload"] = "my-workload"
	f.setUpLauncher(launcher)

	fakeClock.Sleep(time.Second)

	// resume the MPIJob
	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)

	// expect creation of the worker pods
	for i := 0; i < int(replicas); i++ {
		worker := fmjc.newWorker(mpiJob, i)
		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
	}

	// expect the launcher to be updated (scheduling directives synced + unsuspended)
	launcherCopy := launcher.DeepCopy()
	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
	launcherCopy.Spec.Suspend = ptr.To(false)

	// Verify the synced launcher has the Kueue-injected scheduling directives.
	// This sanity-checks the expected object itself before it is registered as
	// an expected action, so a broken sync helper fails loudly here.
	tmpl := &launcherCopy.Spec.Template
	if tmpl.Spec.NodeSelector["foo"] != "bar" {
		t.Errorf("expected nodeSelector to be synced, got %v", tmpl.Spec.NodeSelector)
	}
	if len(tmpl.Spec.Tolerations) != 1 || tmpl.Spec.Tolerations[0].Key != "gpu" {
		t.Errorf("expected tolerations to be synced, got %v", tmpl.Spec.Tolerations)
	}
	if len(tmpl.Spec.SchedulingGates) != 1 || tmpl.Spec.SchedulingGates[0].Name != "kueue.x-k8s.io/topology" {
		t.Errorf("expected schedulingGates to be synced, got %v", tmpl.Spec.SchedulingGates)
	}
	if tmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" {
		t.Errorf("expected annotations to be synced, got %v", tmpl.Annotations)
	}

	f.expectUpdateJobAction(launcherCopy)

	// expect status update
	mpiJobCopy := mpiJob.DeepCopy()
	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
	f.expectUpdateMPIJobStatusAction(mpiJobCopy)

	// Run one sync and let the fixture compare recorded vs. expected actions.
	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
}

// TestResumeMPIJobClearsStartTime exercises the re-admission resume path where
// the suspended launcher was previously started (startTime != nil). Because
// the Job pod template is immutable once StartTime is set, the controller must
// first clear StartTime via a status sub-resource update before syncing the
// scheduling directives and unsuspending the launcher.
func TestResumeMPIJobClearsStartTime(t *testing.T) {
	// Tests the re-admission case where the launcher has startTime != nil.
	// The controller should clear StartTime via a status sub-resource update
	// (consistent with JobSet), then sync scheduling directives and unsuspend.
	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
	f := newFixture(t, "")

	// Build a suspended MPIJob whose conditions reflect created+suspended state.
	var replicas int32 = 8
	startTime := metav1.Now()
	mpiJob := newMPIJob("test", &replicas, &startTime, nil)
	mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
	updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
	updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
	msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
	updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
	mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
		kubeflow.MPIReplicaTypeLauncher: {},
		kubeflow.MPIReplicaTypeWorker:   {},
	}
	f.setUpMPIJob(mpiJob)

	// Seed the fixture with the dependent objects the sync loop expects to find.
	scheme.Scheme.Default(mpiJob)
	f.expectCreateServiceAction(newJobService(mpiJob))
	cfgMap := newConfigMap(mpiJob, replicas, "")
	updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
	f.setUpConfigMap(cfgMap)
	secret, err := newSSHAuthSecret(mpiJob)
	if err != nil {
		t.Fatalf("Failed creating secret")
	}
	f.setUpSecret(secret)

	// set up an existing suspended launcher that was previously started (startTime != nil)
	fmjc := f.newFakeMPIJobController()
	launcher := fmjc.newLauncherJob(mpiJob)
	launcher.Spec.Suspend = ptr.To(true)
	launcherStartTime := metav1.Now()
	launcher.Status.StartTime = &launcherStartTime
	f.setUpLauncher(launcher)

	fakeClock.Sleep(time.Second)

	// resume the MPIJob
	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)

	// expect creation of worker pods
	for i := 0; i < int(replicas); i++ {
		worker := fmjc.newWorker(mpiJob, i)
		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
	}

	// expect a status sub-resource update to clear launcher's StartTime
	// (must precede the spec update below; the fixture checks action order).
	launcherStatusCleared := launcher.DeepCopy()
	launcherStatusCleared.Status.StartTime = nil
	f.kubeActions = append(f.kubeActions, core.NewUpdateSubresourceAction(
		schema.GroupVersionResource{Resource: "jobs", Group: "batch", Version: "v1"},
		"status",
		mpiJob.Namespace,
		launcherStatusCleared,
	))

	// expect the launcher to be updated (scheduling directives synced + unsuspended)
	launcherCopy := launcher.DeepCopy()
	launcherCopy.Status.StartTime = nil
	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
	launcherCopy.Spec.Suspend = ptr.To(false)
	f.expectUpdateJobAction(launcherCopy)

	// expect MPIJob status update
	mpiJobCopy := mpiJob.DeepCopy()
	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
	f.expectUpdateMPIJobStatusAction(mpiJobCopy)

	// Run one sync and let the fixture compare recorded vs. expected actions.
	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
}

func TestWorkerNotControlledByUs(t *testing.T) {
f := newFixture(t, "")
startTime := metav1.Now()
Expand Down
62 changes: 61 additions & 1 deletion test/integration/mpi_job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,14 +385,37 @@ func TestMPIJobResumingAndSuspending(t *testing.T) {
t.Errorf("MPIJob missing Suspended condition")
}
if !isJobSuspended(launcherJob) {
t.Errorf("LauncherJob is suspended")
t.Errorf("LauncherJob is not suspended")
}
if mpiJob.Status.StartTime != nil {
t.Errorf("MPIJob has unexpected start time: %v", mpiJob.Status.StartTime)
}

s.events.verify(t)

// Simulate Kueue injecting scheduling directives into the MPIJob template
// while suspended. When resumed, these must propagate to the launcher Job.
launcherTemplate := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
launcherTemplate.Labels = map[string]string{
"foo": "bar",
}
launcherTemplate.Annotations = map[string]string{
"kueue.x-k8s.io/workload": "my-workload",
}
launcherTemplate.Spec.NodeSelector = map[string]string{
"example.com/accelerator": "example-model",
}
launcherTemplate.Spec.Tolerations = []corev1.Toleration{
{Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule},
}
launcherTemplate.Spec.SchedulingGates = []corev1.PodSchedulingGate{
{Name: "kueue.x-k8s.io/topology"},
}
mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
if err != nil {
t.Fatalf("Failed to update the MPIJob: %v", err)
}

// 2. Resume the MPIJob
mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(mpiJob.Namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
Expand Down Expand Up @@ -422,6 +445,24 @@ func TestMPIJobResumingAndSuspending(t *testing.T) {

s.events.verify(t)

// Verify all scheduling directives were propagated to the launcher Job's pod template.
launcherTmpl := &launcherJob.Spec.Template
if launcherTmpl.Labels["foo"] != "bar" {
t.Errorf("expected label 'foo=bar' on launcher Job template, got labels: %v", launcherTmpl.Labels)
}
if launcherTmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" {
t.Errorf("expected annotation 'kueue.x-k8s.io/workload' on launcher Job template, got annotations: %v", launcherTmpl.Annotations)
}
if launcherTmpl.Spec.NodeSelector["example.com/accelerator"] != "example-model" {
t.Errorf("expected nodeSelector 'example.com/accelerator=example-model' on launcher Job template, got: %v", launcherTmpl.Spec.NodeSelector)
}
if len(launcherTmpl.Spec.Tolerations) == 0 || launcherTmpl.Spec.Tolerations[len(launcherTmpl.Spec.Tolerations)-1].Key != "gpu" {
t.Errorf("expected toleration with key 'gpu' on launcher Job template, got: %v", launcherTmpl.Spec.Tolerations)
}
if len(launcherTmpl.Spec.SchedulingGates) == 0 || launcherTmpl.Spec.SchedulingGates[len(launcherTmpl.Spec.SchedulingGates)-1].Name != "kueue.x-k8s.io/topology" {
t.Errorf("expected schedulingGate 'kueue.x-k8s.io/topology' on launcher Job template, got: %v", launcherTmpl.Spec.SchedulingGates)
}

// 3. Set the pods to be running
err = updatePodsToPhase(ctx, s.kClient, workerPods, corev1.PodRunning)
if err != nil {
Expand Down Expand Up @@ -473,6 +514,25 @@ func TestMPIJobResumingAndSuspending(t *testing.T) {
if !mpiJobHasConditionWithStatus(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse) {
t.Errorf("MPIJob has unexpected Running condition")
}

// Update the MPIJob launcher template again and resume, verifying the
// launcher Job gets the updated scheduling directives on second resume.
mpiJobLauncherTemplate := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
mpiJobLauncherTemplate.Labels["foo"] = "baz"
mpiJobLauncherTemplate.Spec.NodeSelector["example.com/accelerator"] = "example-model-v2"
mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
if err != nil {
t.Fatalf("Failed to update the MPIJob: %v", err)
}

_, launcherJob = validateMPIJobDependencies(ctx, t, s.kClient, mpiJob, 2, nil)
if launcherJob.Spec.Template.Labels["foo"] != "baz" {
t.Errorf("expected label 'foo=baz' on launcher Job template, got labels: %v", launcherJob.Spec.Template.Labels)
}
if launcherJob.Spec.Template.Spec.NodeSelector["example.com/accelerator"] != "example-model-v2" {
t.Errorf("expected nodeSelector 'example.com/accelerator=example-model-v2' on launcher Job template, got: %v", launcherJob.Spec.Template.Spec.NodeSelector)
}
}

func TestMPIJobFailure(t *testing.T) {
Expand Down
Loading