diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 3a6bc3bf3..005c42f48 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -18,7 +18,6 @@ package modifier import ( "fmt" - "slices" "strings" "tags.cncf.io/container-device-interface/pkg/parser" @@ -176,11 +175,6 @@ func filterAutomaticDevices(devices []string) []string { func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifier, error) { f.logger.Debugf("Generating in-memory CDI specs for devices %v", devices) - nvcdiFeatureFlags := slices.Clone(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags) - if f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled() { - nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes) - } - csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) if err != nil { f.logger.Warningf("Failed to get the list of CSV files: %v", err) @@ -198,10 +192,11 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path), nvcdi.WithDriverRoot(f.driver.Root), nvcdi.WithDevRoot(f.driver.DevRoot), + nvcdi.WithEditsFactory(f.editsFactory), nvcdi.WithVendor(automaticDeviceVendor), nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]), nvcdi.WithMode(mode), - nvcdi.WithFeatureFlags(nvcdiFeatureFlags...), + nvcdi.WithFeatureFlags(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags...), nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot), nvcdi.WithCSVFiles(csvFiles), ) diff --git a/internal/modifier/csv_test.go b/internal/modifier/csv_test.go index d91c6d019..994ceb65f 100644 --- a/internal/modifier/csv_test.go +++ b/internal/modifier/csv_test.go @@ -124,7 +124,6 @@ func TestNewCSVModifier(t *testing.T) { f := createFactory( WithLogger(logger), WithDriver(driver), - WithLogger(logger), WithConfig(&tc.cfg), WithImage(&image), ) diff --git a/internal/modifier/discover_test.go b/internal/modifier/discover_test.go index 76de5fc75..136716da4 100644 --- a/internal/modifier/discover_test.go +++ b/internal/modifier/discover_test.go @@ -24,6 +24,7 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "github.com/NVIDIA/nvidia-container-toolkit/api/config/v1" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" ) @@ -132,6 +133,7 @@ func TestDiscoverModifier(t *testing.T) { factory := createFactory( WithLogger(logger), + WithConfig(&config.Config{}), ) for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { diff --git a/internal/modifier/factory.go b/internal/modifier/factory.go index 6b58b56d7..3ec104fa2 100644 --- a/internal/modifier/factory.go +++ b/internal/modifier/factory.go @@ -32,14 +32,20 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) -type Factory struct { +// factoryOptions define the set of options that must be set when constructing +// a modifier factory. +type factoryOptions struct { logger logger.Interface cfg *config.Config driver *root.Driver hookCreator discover.HookCreator image *image.CUDA runtimeMode info.RuntimeMode +} +type Factory struct { + factoryOptions + // An editsFactory is created at construction. editsFactory edits.Factory } @@ -60,12 +66,14 @@ func New(opts ...Option) (oci.SpecModifier, error) { func createFactory(opts ...Option) *Factory { f := &Factory{} for _, opt := range opts { - opt(f) - } - if f.editsFactory == nil { - f.editsFactory = edits.NewFactory(edits.WithLogger(f.logger)) + opt(&f.factoryOptions) } + f.editsFactory = edits.NewFactory( + edits.WithLogger(f.logger), + edits.WithNoAdditionalGIDsForDeviceNodes(f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled()), + ) + return f } @@ -125,39 +133,39 @@ func (f *Factory) create() (oci.SpecModifier, error) { return modifiers, nil } -type Option func(*Factory) +type Option func(*factoryOptions) func WithConfig(cfg *config.Config) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.cfg = cfg } } func WithDriver(driver *root.Driver) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.driver = driver } } func WithHookCreator(hookCreator discover.HookCreator) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.hookCreator = hookCreator } } func WithImage(image *image.CUDA) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.image = image } } func WithLogger(logger logger.Interface) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.logger = logger } } func WithRuntimeMode(runtimeMode info.RuntimeMode) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.runtimeMode = runtimeMode } } diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 91f116cb9..fa84e19c5 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -71,10 +71,7 @@ func New(opts ...Option) (Interface, error) { discover.WithLdconfigPath(o.ldconfigPath), discover.WithDisabledHooks(o.disabledHooks...), ), - editsFactory: edits.NewFactory( - edits.WithLogger(o.logger), - edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]), - ), + editsFactory: o.editsFactory, } var factory deviceSpecGeneratorFactory diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 907af6c99..a5e2b1c40 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" @@ -52,6 +53,8 @@ type options struct { disabledHooks []discover.HookName enabledHooks []discover.HookName + + editsFactory edits.Factory } type platformlibs struct { @@ -116,6 +119,13 @@ func populateOptions(opts ...Option) *options { o.disabledHooks = append(o.disabledHooks, HookEnableCudaCompat, DisableDeviceNodeModificationHook) } + if o.editsFactory == nil { + o.editsFactory = edits.NewFactory( + edits.WithLogger(o.logger), + edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]), + ) + } + return o } @@ -191,6 +201,12 @@ func WithDevRoot(root string) Option { } } +func WithEditsFactory(editsFactory edits.Factory) Option { + return func(l *options) { + l.editsFactory = editsFactory + } +} + // WithLogger sets the logger for the library func WithLogger(logger logger.Interface) Option { return func(l *options) {