diff --git a/controllers/object_controls.go b/controllers/object_controls.go index 992a698b9..cf67e13d2 100644 --- a/controllers/object_controls.go +++ b/controllers/object_controls.go @@ -1314,14 +1314,19 @@ func TransformToolkit(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicySpec, n } } - if len(config.Toolkit.Env) > 0 { - for _, env := range config.Toolkit.Env { - setContainerEnv(toolkitMainContainer, env.Name, env.Value) + // configure runtime + runtime := n.runtime.String() + // Update the main container environment from the user-specified values. + for _, env := range config.Toolkit.Env { + if env.Name == "RUNTIME" { + // If the user has specified the runtime, we override the detected + // value. + // TODO: Add logging. + runtime = env.Value } + setContainerEnv(toolkitMainContainer, env.Name, env.Value) } - // configure runtime - runtime := n.runtime.String() err = transformForRuntime(obj, config, runtime, toolkitMainContainer) if err != nil { return fmt.Errorf("error transforming toolkit daemonset : %w", err) @@ -1332,6 +1337,11 @@ func TransformToolkit(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicySpec, n func transformForRuntime(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicySpec, runtime string, container *corev1.Container) error { setContainerEnv(container, "RUNTIME", runtime) + // If the user has explicitly requested 'none' as a runtime, we make no + // additional changes to the container. 
+ if runtime == "none" { + return nil + } if runtime == gpuv1.Containerd.String() { // Set the runtime class name that is to be configured for containerd diff --git a/controllers/transforms_test.go b/controllers/transforms_test.go index ff384c676..95e1bd83b 100644 --- a/controllers/transforms_test.go +++ b/controllers/transforms_test.go @@ -17,6 +17,7 @@ package controllers import ( + "errors" "path/filepath" "testing" @@ -777,11 +778,12 @@ func TestApplyCommonDaemonsetMetadata(t *testing.T) { func TestTransformToolkit(t *testing.T) { testCases := []struct { - description string - ds Daemonset // Input DaemonSet - cpSpec *gpuv1.ClusterPolicySpec // Input configuration - runtime gpuv1.Runtime - expectedDs Daemonset // Expected output DaemonSet + description string + ds Daemonset // Input DaemonSet + cpSpec *gpuv1.ClusterPolicySpec // Input configuration + runtime gpuv1.Runtime + expectedError error + expectedDs Daemonset // Expected output DaemonSet }{ { description: "transform nvidia-container-toolkit-ctr container", @@ -1002,6 +1004,42 @@ func TestTransformToolkit(t *testing.T) { WithHostPathVolume("crio-config", "/etc/crio", ptr.To(corev1.HostPathDirectoryOrCreate)). WithHostPathVolume("crio-drop-in-config", "/etc/crio/crio.conf.d", ptr.To(corev1.HostPathDirectoryOrCreate)), }, + { + description: "transform nvidia-container-toolkit-ctr container with none runtime", + ds: NewDaemonset(). + WithContainer(corev1.Container{Name: "nvidia-container-toolkit-ctr"}), + runtime: gpuv1.Containerd, + cpSpec: &gpuv1.ClusterPolicySpec{ + Toolkit: gpuv1.ToolkitSpec{ + Repository: "nvcr.io/nvidia/cloud-native", + Image: "nvidia-container-toolkit", + Version: "v1.0.0", + Env: []gpuv1.EnvVar{ + {Name: "RUNTIME", Value: "none"}, + }, + }, + }, + expectedDs: NewDaemonset(). 
+ WithContainer(corev1.Container{ + Name: "nvidia-container-toolkit-ctr", + Image: "nvcr.io/nvidia/cloud-native/nvidia-container-toolkit:v1.0.0", + ImagePullPolicy: corev1.PullIfNotPresent, + Env: []corev1.EnvVar{ + {Name: "CDI_ENABLED", Value: "true"}, + {Name: "NVIDIA_RUNTIME_SET_AS_DEFAULT", Value: "false"}, + {Name: "NVIDIA_CONTAINER_RUNTIME_MODE", Value: "cdi"}, + {Name: "CRIO_CONFIG_MODE", Value: "config"}, + {Name: "RUNTIME", Value: "none"}, + }, + VolumeMounts: nil, + }), + }, + { + description: "no nvidia-container-toolkit-ctr container", + ds: NewDaemonset(), + expectedError: errors.New(`failed to find toolkit container "nvidia-container-toolkit-ctr"`), + expectedDs: NewDaemonset(), + }, } for _, tc := range testCases { @@ -1012,7 +1050,7 @@ func TestTransformToolkit(t *testing.T) { } err := TransformToolkit(tc.ds.DaemonSet, tc.cpSpec, controller) - require.NoError(t, err) + require.EqualValues(t, tc.expectedError, err) require.EqualValues(t, tc.expectedDs, tc.ds) }) }