Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions pkg/nvcdi/lib-csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ var _ deviceSpecGeneratorFactory = (*csvlib)(nil)
// If NVML is not available or the disable-multiple-csv-devices feature flag is
// enabled, a single device is assumed.
func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
if l.featureFlags[FeatureDisableMultipleCSVDevices] {
return l.purecsvDeviceSpecGenerators(ids...)
}
hasNVML, _ := l.infolib.HasNvml()
if !hasNVML {
if l.usePureCSVDeviceSpecGenerator() {
return l.purecsvDeviceSpecGenerators(ids...)
}
mixed, err := l.mixedDeviceSpecGenerators(ids...)
Expand All @@ -61,6 +57,29 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
return mixed, nil
}

func (l *csvlib) usePureCSVDeviceSpecGenerator() bool {
if l.featureFlags[FeatureDisableMultipleCSVDevices] {
return true
}
hasNVML, _ := l.infolib.HasNvml()
if !hasNVML {
return true
}
asNvmlLib := (*nvmllib)(l)
err := asNvmlLib.init()
if err != nil {
return true
}
defer asNvmlLib.tryShutdown()

numDevices, ret := l.nvmllib.DeviceGetCount()
if ret != nvml.SUCCESS {
return true
}

return numDevices <= 1
}

func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
for _, id := range ids {
switch id {
Expand All @@ -74,6 +93,9 @@ func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator
csvlib: l,
index: 0,
uuid: "",
// We set noFilterDeviceNodes to true to ensure that the /dev/nvidia[0-1]
// device nodes in the CSV files on the system are consumed as-is.
noFilterDeviceNodes: true,
}
return g, nil
}
Expand All @@ -86,8 +108,9 @@ func (l *csvlib) mixedDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator,
// platform-specific CSV files.
type csvDeviceGenerator struct {
*csvlib
index int
uuid string
index int
uuid string
noFilterDeviceNodes bool
}

func (l *csvDeviceGenerator) GetUUID() (string, error) {
Expand Down Expand Up @@ -132,14 +155,17 @@ func (l *csvDeviceGenerator) GetDeviceSpecs() ([]specs.Device, error) {
// particular device is added to the set of device nodes to be discovered.
func (l *csvDeviceGenerator) deviceNodeDiscoverer() (discover.Discover, error) {
mountSpecs := tegra.Transform(
tegra.Transform(
tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...),
// We remove non-device nodes.
tegra.OnlyDeviceNodes(),
),
// We remove the regular (nvidia[0-9]+) device nodes.
tegra.WithoutRegularDeviceNodes(),
tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...),
// We remove non-device nodes.
tegra.OnlyDeviceNodes(),
)
if !l.noFilterDeviceNodes {
mountSpecs = tegra.Transform(
mountSpecs,
// We remove the regular (nvidia[0-9]+) device nodes.
tegra.WithoutRegularDeviceNodes(),
)
}
return tegra.New(
tegra.WithLogger(l.logger),
tegra.WithDriverRoot(l.driverRoot),
Expand Down