Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/devices/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,11 @@ var assertCharDeviceStub = func(path string) error {
}
return nil
}

func IsOverrideApplied() bool {
return isOverrideAppliedStub()
}

var isOverrideAppliedStub = func() bool {
return false
}
41 changes: 39 additions & 2 deletions internal/devices/devices_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions internal/devices/devices_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
type Interface interface {
DeviceFromPath(string, string) (*Device, error)
AssertCharDevice(string) error
IsOverrideApplied() bool
}

type testDefaults struct{}
Expand All @@ -47,6 +48,7 @@ func SetInterfaceForTests(m Interface) func() {
funcs := []func(){
SetDeviceFromPathForTest(m.DeviceFromPath),
SetAssertCharDeviceForTest(m.AssertCharDevice),
SetIsOverrideAppliedForTest(m.IsOverrideApplied),
}
return func() {
for _, f := range funcs {
Expand All @@ -71,6 +73,14 @@ func SetAssertCharDeviceForTest(testFunc func(string) error) func() {
}
}

func SetIsOverrideAppliedForTest(testFunc func() bool) func() {
current := isOverrideAppliedStub
isOverrideAppliedStub = testFunc
return func() {
isOverrideAppliedStub = current
}
}

type testDevice struct {
Device
}
Expand Down Expand Up @@ -115,3 +125,7 @@ func (t *testDefaults) AssertCharDevice(path string) error {

return nil
}

func (t *testDefaults) IsOverrideApplied() bool {
return true
}
10 changes: 10 additions & 0 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
)

Expand Down Expand Up @@ -180,6 +181,14 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes)
}

csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
if err != nil {
f.logger.Warningf("Failed to get the list of CSV files: %v", err)
}
if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" {
csvFiles = csv.BaseFilesOnly(csvFiles)
}

cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...)
f.logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers)
var modifiers oci.SpecModifiers
Expand All @@ -194,6 +203,7 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
nvcdi.WithMode(mode),
nvcdi.WithFeatureFlags(nvcdiFeatureFlags...),
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
nvcdi.WithCSVFiles(csvFiles),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)
Expand Down
32 changes: 27 additions & 5 deletions internal/modifier/cdi/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,51 @@ import (
"fmt"

"github.com/opencontainers/runtime-spec/specs-go"
"tags.cncf.io/container-device-interface/pkg/cdi"
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
cdi "tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/devices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

// fromCDISpec represents the modifications performed from a raw CDI spec.
type fromCDISpec struct {
cdiSpec *cdi.Spec
cdiSpec *cdiapi.Spec
}

var _ oci.SpecModifier = (*fromCDISpec)(nil)

// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec.
func (m fromCDISpec) Modify(spec *specs.Spec) error {
for _, device := range m.cdiSpec.Devices {
device := device
cdiDevice := cdi.Device{
device := m.enrichDevice(device)
cdiDevice := cdiapi.Device{
Device: &device,
}
if err := cdiDevice.ApplyEdits(spec); err != nil {
return fmt.Errorf("failed to apply edits for device %q: %v", cdiDevice.GetQualifiedName(), err)
return fmt.Errorf("failed to apply edits for device %q: %v", m.cdiSpec.Kind+"="+device.Name, err)
}
}

return m.cdiSpec.ApplyEdits(spec)
}

func (m fromCDISpec) enrichDevice(device cdi.Device) cdi.Device {
if !devices.IsOverrideApplied() {
return device
}
// For testing we need to override the device node information to ensure
// that we don't trigger the CDI modification that requires the device node
// to exist and be a character device.
// The following condition is used to determine whether a failure to get
// the info is fatal:
// hasMinimalSpecification := d.Type != "" && (d.Major != 0 || d.Type == fifoDevice)
for i, dn := range device.ContainerEdits.DeviceNodes {
dn.Type = "c"
if dn.Major == 0 {
dn.Major = 99
}
device.ContainerEdits.DeviceNodes[i] = dn
}
return device
}
48 changes: 13 additions & 35 deletions internal/modifier/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,14 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
)

// newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
// The modifications are defined by CSV MountSpecs.
func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
devices := f.image.VisibleDevices()
devices := withUniqueDevices(csvDevices(*f.image)).DeviceRequests()
if len(devices) == 0 {
f.logger.Infof("No modification required; no devices requested")
return nil, nil
Expand All @@ -43,37 +40,7 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
return nil, fmt.Errorf("requirements not met: %v", err)
}

csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
if err != nil {
return nil, fmt.Errorf("failed to get list of CSV files: %v", err)
}

if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" {
csvFiles = csv.BaseFilesOnly(csvFiles)
}

cdilib, err := nvcdi.New(
nvcdi.WithLogger(f.logger),
nvcdi.WithDriverRoot(f.driver.Root),
nvcdi.WithDevRoot(f.driver.DevRoot),
nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path),
nvcdi.WithMode(nvcdi.ModeCSV),
nvcdi.WithCSVFiles(csvFiles),
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library: %v", err)
}

spec, err := cdilib.GetSpec(devices...)
if err != nil {
return nil, fmt.Errorf("failed to get CDI spec: %v", err)
}

return cdi.New(
cdi.WithLogger(f.logger),
cdi.WithSpec(spec.Raw()),
)
return f.newAutomaticCDISpecModifier(devices)
}

func checkRequirements(logger logger.Interface, image *image.CUDA) error {
Expand Down Expand Up @@ -107,3 +74,14 @@ func checkRequirements(logger logger.Interface, image *image.CUDA) error {

return r.Assert()
}

type csvDevices image.CUDA

func (d csvDevices) DeviceRequests() []string {
var devices []string
i := (image.CUDA)(d)
for _, deviceID := range i.VisibleDevices() {
devices = append(devices, "mode=csv,id="+deviceID)
}
return devices
}
Loading