diff --git a/internal/devices/devices.go b/internal/devices/devices.go index 6f43d0a84..294879bdb 100644 --- a/internal/devices/devices.go +++ b/internal/devices/devices.go @@ -53,3 +53,11 @@ var assertCharDeviceStub = func(path string) error { } return nil } + +func IsOverrideApplied() bool { + return isOverrideAppliedStub() +} + +var isOverrideAppliedStub = func() bool { + return false +} diff --git a/internal/devices/devices_mock.go b/internal/devices/devices_mock.go index ad1eaedcc..5dced6724 100644 --- a/internal/devices/devices_mock.go +++ b/internal/devices/devices_mock.go @@ -23,6 +23,9 @@ var _ Interface = &InterfaceMock{} // DeviceFromPathFunc: func(s1 string, s2 string) (*Device, error) { // panic("mock out the DeviceFromPath method") // }, +// IsOverrideAppliedFunc: func() bool { +// panic("mock out the IsOverrideApplied method") +// }, // } // // // use mockedInterface in code that requires Interface @@ -36,6 +39,9 @@ type InterfaceMock struct { // DeviceFromPathFunc mocks the DeviceFromPath method. DeviceFromPathFunc func(s1 string, s2 string) (*Device, error) + // IsOverrideAppliedFunc mocks the IsOverrideApplied method. + IsOverrideAppliedFunc func() bool + // calls tracks calls to the methods. calls struct { // AssertCharDevice holds details about calls to the AssertCharDevice method. @@ -50,9 +56,13 @@ type InterfaceMock struct { // S2 is the s2 argument value. S2 string } + // IsOverrideApplied holds details about calls to the IsOverrideApplied method. + IsOverrideApplied []struct { + } } - lockAssertCharDevice sync.RWMutex - lockDeviceFromPath sync.RWMutex + lockAssertCharDevice sync.RWMutex + lockDeviceFromPath sync.RWMutex + lockIsOverrideApplied sync.RWMutex } // AssertCharDevice calls AssertCharDeviceFunc. @@ -122,3 +132,30 @@ func (mock *InterfaceMock) DeviceFromPathCalls() []struct { mock.lockDeviceFromPath.RUnlock() return calls } + +// IsOverrideApplied calls IsOverrideAppliedFunc. +func (mock *InterfaceMock) IsOverrideApplied() bool { + if mock.IsOverrideAppliedFunc == nil { + panic("InterfaceMock.IsOverrideAppliedFunc: method is nil but Interface.IsOverrideApplied was just called") + } + callInfo := struct { + }{} + mock.lockIsOverrideApplied.Lock() + mock.calls.IsOverrideApplied = append(mock.calls.IsOverrideApplied, callInfo) + mock.lockIsOverrideApplied.Unlock() + return mock.IsOverrideAppliedFunc() +} + +// IsOverrideAppliedCalls gets all the calls that were made to IsOverrideApplied. +// Check the length with: +// +// len(mockedInterface.IsOverrideAppliedCalls()) +func (mock *InterfaceMock) IsOverrideAppliedCalls() []struct { +} { + var calls []struct { + } + mock.lockIsOverrideApplied.RLock() + calls = mock.calls.IsOverrideApplied + mock.lockIsOverrideApplied.RUnlock() + return calls +} diff --git a/internal/devices/devices_tests.go b/internal/devices/devices_tests.go index b2d6e930d..dc5340119 100644 --- a/internal/devices/devices_tests.go +++ b/internal/devices/devices_tests.go @@ -29,6 +29,7 @@ import ( type Interface interface { DeviceFromPath(string, string) (*Device, error) AssertCharDevice(string) error + IsOverrideApplied() bool } type testDefaults struct{} @@ -47,6 +48,7 @@ func SetInterfaceForTests(m Interface) func() { funcs := []func(){ SetDeviceFromPathForTest(m.DeviceFromPath), SetAssertCharDeviceForTest(m.AssertCharDevice), + SetIsOverrideAppliedForTest(m.IsOverrideApplied), } return func() { for _, f := range funcs { @@ -71,6 +73,14 @@ func SetAssertCharDeviceForTest(testFunc func(string) error) func() { } } +func SetIsOverrideAppliedForTest(testFunc func() bool) func() { + current := isOverrideAppliedStub + isOverrideAppliedStub = testFunc + return func() { + isOverrideAppliedStub = current + } +} + type testDevice struct { Device } @@ -115,3 +125,7 @@ func (t *testDefaults) AssertCharDevice(path string) error { return nil } + +func (t *testDefaults) IsOverrideApplied() bool { + return true +} diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index e7f28df65..3a6bc3bf3 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -27,6 +27,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" ) @@ -180,6 +181,14 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes) } + csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) + if err != nil { + f.logger.Warningf("Failed to get the list of CSV files: %v", err) + } + if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" { + csvFiles = csv.BaseFilesOnly(csvFiles) + } + cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...) f.logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers) var modifiers oci.SpecModifiers @@ -194,6 +203,7 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie nvcdi.WithMode(mode), nvcdi.WithFeatureFlags(nvcdiFeatureFlags...), nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot), + nvcdi.WithCSVFiles(csvFiles), ) if err != nil { return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err) diff --git a/internal/modifier/cdi/spec.go b/internal/modifier/cdi/spec.go index 24b475ee0..69cabb83e 100644 --- a/internal/modifier/cdi/spec.go +++ b/internal/modifier/cdi/spec.go @@ -20,14 +20,16 @@ import ( "fmt" "github.com/opencontainers/runtime-spec/specs-go" - "tags.cncf.io/container-device-interface/pkg/cdi" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" + cdi "tags.cncf.io/container-device-interface/specs-go" + "github.com/NVIDIA/nvidia-container-toolkit/internal/devices" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) // fromCDISpec represents the modifications performed from a raw CDI spec. type fromCDISpec struct { - cdiSpec *cdi.Spec + cdiSpec *cdiapi.Spec } var _ oci.SpecModifier = (*fromCDISpec)(nil) @@ -35,14 +37,34 @@ var _ oci.SpecModifier = (*fromCDISpec)(nil) // Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec. func (m fromCDISpec) Modify(spec *specs.Spec) error { for _, device := range m.cdiSpec.Devices { - device := device - cdiDevice := cdi.Device{ + device := m.enrichDevice(device) + cdiDevice := cdiapi.Device{ Device: &device, } if err := cdiDevice.ApplyEdits(spec); err != nil { - return fmt.Errorf("failed to apply edits for device %q: %v", cdiDevice.GetQualifiedName(), err) + return fmt.Errorf("failed to apply edits for device %q: %v", m.cdiSpec.Kind+"="+device.Name, err) } } return m.cdiSpec.ApplyEdits(spec) } + +func (m fromCDISpec) enrichDevice(device cdi.Device) cdi.Device { + if !devices.IsOverrideApplied() { + return device + } + // For testing we need to override the device node information to ensure + // that we don't trigger the CDI modification that requires the device node + // to exist and be a character device. + // The following condition is used to determine whether a failure to get + // the info is fatal: + // hasMinimalSpecification := d.Type != "" && (d.Major != 0 || d.Type == fifoDevice) + for i, dn := range device.ContainerEdits.DeviceNodes { + dn.Type = "c" + if dn.Major == 0 { + dn.Major = 99 + } + device.ContainerEdits.DeviceNodes[i] = dn + } + return device +} diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index 5842e17d3..b20fdb134 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -22,17 +22,14 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" - "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" - "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" ) // newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. // The modifications are defined by CSV MountSpecs. func (f *Factory) newCSVModifier() (oci.SpecModifier, error) { - devices := f.image.VisibleDevices() + devices := withUniqueDevices(csvDevices(*f.image)).DeviceRequests() if len(devices) == 0 { f.logger.Infof("No modification required; no devices requested") return nil, nil @@ -43,37 +40,7 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) { return nil, fmt.Errorf("requirements not met: %v", err) } - csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) - if err != nil { - return nil, fmt.Errorf("failed to get list of CSV files: %v", err) - } - - if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" { - csvFiles = csv.BaseFilesOnly(csvFiles) - } - - cdilib, err := nvcdi.New( - nvcdi.WithLogger(f.logger), - nvcdi.WithDriverRoot(f.driver.Root), - nvcdi.WithDevRoot(f.driver.DevRoot), - nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path), - nvcdi.WithMode(nvcdi.ModeCSV), - nvcdi.WithCSVFiles(csvFiles), - nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot), - ) - if err != nil { - return nil, fmt.Errorf("failed to construct CDI library: %v", err) - } - - spec, err := cdilib.GetSpec(devices...) - if err != nil { - return nil, fmt.Errorf("failed to get CDI spec: %v", err) - } - - return cdi.New( - cdi.WithLogger(f.logger), - cdi.WithSpec(spec.Raw()), - ) + return f.newAutomaticCDISpecModifier(devices) } func checkRequirements(logger logger.Interface, image *image.CUDA) error { @@ -107,3 +74,14 @@ func checkRequirements(logger logger.Interface, image *image.CUDA) error { return r.Assert() } + +type csvDevices image.CUDA + +func (d csvDevices) DeviceRequests() []string { + var devices []string + i := (image.CUDA)(d) + for _, deviceID := range i.VisibleDevices() { + devices = append(devices, "mode=csv,id="+deviceID) + } + return devices +} diff --git a/internal/modifier/csv_test.go b/internal/modifier/csv_test.go index 809a784c4..d91c6d019 100644 --- a/internal/modifier/csv_test.go +++ b/internal/modifier/csv_test.go @@ -17,40 +17,89 @@ package modifier import ( + "path/filepath" "testing" + "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" "github.com/NVIDIA/nvidia-container-toolkit/api/config/v1" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/NVIDIA/nvidia-container-toolkit/internal/devices" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" + "github.com/NVIDIA/nvidia-container-toolkit/internal/test" + "github.com/NVIDIA/nvidia-container-toolkit/internal/test/to" ) func TestNewCSVModifier(t *testing.T) { logger, _ := testlog.NewNullLogger() + defer devices.SetAllForTest()() + + moduleRoot, err := test.GetModuleRoot() + require.NoError(t, err) + + lookupRoot := filepath.Join(moduleRoot, "testdata", "lookup") + testCases := []struct { - description string - cfg *config.Config - envmap map[string]string - expectedError error - expectedNil bool + description string + cfg config.Config + envmap map[string]string + driverRootfs string + devRootfs string + expectedErrorString string + expectedSpec specs.Spec }{ { description: "visible devices not set returns nil", envmap: map[string]string{}, - expectedNil: true, }, { description: "visible devices empty returns nil", envmap: map[string]string{"NVIDIA_VISIBLE_DEVICES": ""}, - expectedNil: true, }, { description: "visible devices 'void' returns nil", envmap: map[string]string{"NVIDIA_VISIBLE_DEVICES": "void"}, - expectedNil: true, + }, + { + description: "visible devices all", + envmap: map[string]string{"NVIDIA_VISIBLE_DEVICES": "all"}, + driverRootfs: "rootfs-orin", + expectedSpec: specs.Spec{ + Process: &specs.Process{ + Env: []string{"NVIDIA_VISIBLE_DEVICES=void"}, + }, + Mounts: []specs.Mount{ + {Source: "/usr/lib/aarch64-linux-gnu/nvidia/libcuda.so.1.1", Destination: "/usr/lib/aarch64-linux-gnu/nvidia/libcuda.so.1.1", Options: []string{"ro", "nosuid", "nodev", "rbind", "rprivate"}}, + {Source: "/usr/lib/aarch64-linux-gnu/nvidia/libnvidia-ml.so.1", Destination: "/usr/lib/aarch64-linux-gnu/nvidia/libnvidia-ml.so.1", Options: []string{"ro", "nosuid", "nodev", "rbind", "rprivate"}}, + }, + Hooks: &specs.Hooks{ + CreateContainer: []specs.Hook{ + { + Path: "/usr/bin/nvidia-cdi-hook", + Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/aarch64-linux-gnu/nvidia/libcuda.so"}, + Env: []string{"NVIDIA_CTK_DEBUG=false"}, + }, + { + Path: "/usr/bin/nvidia-cdi-hook", + Args: []string{"nvidia-cdi-hook", "update-ldcache", "--folder", "/usr/lib/aarch64-linux-gnu/nvidia"}, + Env: []string{"NVIDIA_CTK_DEBUG=false"}, + }, + }, + }, + Linux: &specs.Linux{ + Devices: []specs.LinuxDevice{ + {Path: "/dev/nvidia0", Type: "c", Major: 99}, + }, + Resources: &specs.LinuxResources{ + Devices: []specs.LinuxDeviceCgroup{ + {Allow: true, Type: "c", Major: to.Ptr[int64](99), Minor: to.Ptr[int64](0), Access: "rwm"}, + }, + }, + }, + }, }, } @@ -59,24 +108,39 @@ func TestNewCSVModifier(t *testing.T) { image, _ := image.New( image.WithEnvMap(tc.envmap), ) + driverRoot := tc.driverRootfs + if driverRoot != "" { + driverRoot = filepath.Join(lookupRoot, tc.driverRootfs) + } + devRoot := tc.devRootfs + if devRoot != "" { + devRoot = filepath.Join(lookupRoot, tc.devRootfs) + } + driver := root.New(root.WithDriverRoot(driverRoot), root.WithDevRoot(devRoot)) + // Override the CSV file search path for this root. + tc.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath = filepath.Join(driverRoot, "/etc/nvidia-container-runtime/host-files-for-container.d") + tc.cfg.NVIDIACTKConfig.Path = "/usr/bin/nvidia-cdi-hook" + f := createFactory( WithLogger(logger), - WithDriver(root.New()), - WithConfig(tc.cfg), + WithDriver(driver), + WithLogger(logger), + WithConfig(&tc.cfg), WithImage(&image), ) m, err := f.newCSVModifier() - if tc.expectedError != nil { - require.Error(t, err) - } else { + if tc.expectedErrorString == "" { require.NoError(t, err) - } - - if tc.expectedNil || tc.expectedError != nil { - require.Nil(t, m) } else { - require.NotNil(t, m) + require.EqualError(t, err, tc.expectedErrorString) + return } + + s := specs.Spec{} + err = list{m}.Modify(&s) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedSpec, test.StripRoot(s, driverRoot)) }) } } diff --git a/internal/modifier/factory.go b/internal/modifier/factory.go index 150fd888d..6b58b56d7 100644 --- a/internal/modifier/factory.go +++ b/internal/modifier/factory.go @@ -62,7 +62,6 @@ func createFactory(opts ...Option) *Factory { for _, opt := range opts { opt(f) } - if f.editsFactory == nil { f.editsFactory = edits.NewFactory(edits.WithLogger(f.logger)) } @@ -71,6 +70,9 @@ func createFactory(opts ...Option) *Factory { } func (f *Factory) validate() error { + if f.driver == nil { + return fmt.Errorf("a driver must be specified") + } switch string(f.runtimeMode) { case "": return fmt.Errorf("a mode must be specified") diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index e529cda5d..1bebb241d 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -21,8 +21,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/api/config/v1" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) @@ -41,84 +39,59 @@ func (f *Factory) newFeatureGatedModifier() (oci.SpecModifier, error) { return nil, nil } - var discoverers []discover.Discover - - if f.image.Getenv("NVIDIA_GDS") == "enabled" { - d, err := discover.NewGDSDiscoverer(f.logger, f.driver) - if err != nil { - return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err) - } - discoverers = append(discoverers, d) - } - - if f.image.Getenv("NVIDIA_MOFED") == "enabled" { - d, err := discover.NewMOFEDDiscoverer(f.logger, f.driver) - if err != nil { - return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err) - } - discoverers = append(discoverers, d) - } - - if f.image.Getenv("NVIDIA_NVSWITCH") == "enabled" { - d, err := discover.NewNvSwitchDiscoverer(f.logger, f.driver) - if err != nil { - return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err) - } - discoverers = append(discoverers, d) - } - - if f.image.Getenv("NVIDIA_GDRCOPY") == "enabled" { - d, err := discover.NewGDRCopyDiscoverer(f.logger, f.driver) + var modifers list + if gatedDeviceRequests := withUniqueDevices(gatedDevices(*f.image)).DeviceRequests(); len(gatedDeviceRequests) != 0 { + featureGatedModifier, err := f.newAutomaticCDISpecModifier(gatedDeviceRequests) if err != nil { - return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err) + return nil, err } - discoverers = append(discoverers, d) + modifers = append(modifers, featureGatedModifier) } // If the feature flag has explicitly been toggled, we don't make any modification. if !f.cfg.Features.DisableCUDACompatLibHook.IsEnabled() { - cudaCompatDiscoverer, err := getCudaCompatModeDiscoverer(f.logger, f.cfg, f.driver, f.hookCreator) + cudaCompatModifer, err := f.getCudaCompatModeModifier() if err != nil { return nil, fmt.Errorf("failed to construct CUDA Compat discoverer: %w", err) } - discoverers = append(discoverers, cudaCompatDiscoverer) + modifers = append(modifers, cudaCompatModifer) } - return f.newModifierFromDiscoverer(discover.Merge(discoverers...)) + return modifers, nil } -func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) { +func (f *Factory) getCudaCompatModeModifier() (oci.SpecModifier, error) { // We don't support the enable-cuda-compat hook in CSV mode. - if cfg.NVIDIAContainerRuntimeConfig.Mode == "csv" { + if f.cfg.NVIDIAContainerRuntimeConfig.Mode == "csv" { return nil, nil } // For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook. - if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook { + if f.cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && f.cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook { return nil, nil } - version, err := driver.Version() + version, err := f.driver.Version() if err != nil { return nil, fmt.Errorf("failed to get driver version: %w", err) } - compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version}) + compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(f.logger, f.hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version}) // For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook. - if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" { - return compatLibHookDiscoverer, nil + if f.cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" { + return f.newModifierFromDiscoverer(compatLibHookDiscoverer) } // For legacy mode, we also need to inject a hook to update the LDCache // after we have modifed the configuration. ldcacheUpdateHookDiscoverer, err := discover.NewLDCacheUpdateHook( - logger, + f.logger, discover.None{}, - hookCreator, + f.hookCreator, ) if err != nil { return nil, fmt.Errorf("failed to construct ldcache update discoverer: %w", err) } - return discover.Merge(compatLibHookDiscoverer, ldcacheUpdateHookDiscoverer), nil + return f.newModifierFromDiscoverer(discover.Merge(compatLibHookDiscoverer, ldcacheUpdateHookDiscoverer)) }