diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index 77fcd95b3..c01115462 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -46,11 +46,7 @@ var _ deviceSpecGeneratorFactory = (*csvlib)(nil) // If NVML is not available or the disable-multiple-csv-devices feature flag is // enabled, a single device is assumed. func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) { - if l.featureFlags[FeatureDisableMultipleCSVDevices] { - return l.purecsvDeviceSpecGenerators(ids...) - } - hasNVML, _ := l.infolib.HasNvml() - if !hasNVML { + if l.usePureCSVDeviceSpecGenerator() { return l.purecsvDeviceSpecGenerators(ids...) } mixed, err := l.mixedDeviceSpecGenerators(ids...) @@ -61,6 +57,29 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error return mixed, nil } +func (l *csvlib) usePureCSVDeviceSpecGenerator() bool { + if l.featureFlags[FeatureDisableMultipleCSVDevices] { + return true + } + hasNVML, _ := l.infolib.HasNvml() + if !hasNVML { + return true + } + asNvmlLib := (*nvmllib)(l) + err := asNvmlLib.init() + if err != nil { + return true + } + defer asNvmlLib.tryShutdown() + + numDevices, ret := l.nvmllib.DeviceGetCount() + if ret != nvml.SUCCESS { + return true + } + + return numDevices <= 1 +} + func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) { for _, id := range ids { switch id { @@ -74,6 +93,9 @@ func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator csvlib: l, index: 0, uuid: "", + // We set noFilterDeviceNodes to true to ensure that the /dev/nvidia[0-1] + // device nodes in the CSV files on the system are consumed as-is. + noFilterDeviceNodes: true, } return g, nil } @@ -86,8 +108,9 @@ func (l *csvlib) mixedDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, // platform-specific CSV files. type csvDeviceGenerator struct { *csvlib - index int - uuid string + index int + uuid string + noFilterDeviceNodes bool } func (l *csvDeviceGenerator) GetUUID() (string, error) { @@ -132,14 +155,17 @@ func (l *csvDeviceGenerator) GetDeviceSpecs() ([]specs.Device, error) { // particular device is added to the set of device nodes to be discovered. func (l *csvDeviceGenerator) deviceNodeDiscoverer() (discover.Discover, error) { mountSpecs := tegra.Transform( - tegra.Transform( - tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...), - // We remove non-device nodes. - tegra.OnlyDeviceNodes(), - ), - // We remove the regular (nvidia[0-9]+) device nodes. - tegra.WithoutRegularDeviceNodes(), + tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...), + // We remove non-device nodes. + tegra.OnlyDeviceNodes(), ) + if !l.noFilterDeviceNodes { + mountSpecs = tegra.Transform( + mountSpecs, + // We remove the regular (nvidia[0-9]+) device nodes. + tegra.WithoutRegularDeviceNodes(), + ) + } return tegra.New( tegra.WithLogger(l.logger), tegra.WithDriverRoot(l.driverRoot),