From 79477c95de6312792dddc3c1905bbb4d2d68685a Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 19 Feb 2026 13:52:24 +0100 Subject: [PATCH] fix: Reuse instantiated editsFactory in CDI This change ensures that the editsFactory is instantiated once and passed to the nvcdi constructor. This makes it unnecessary to reprocess optional arguments and configs. Signed-off-by: Evan Lezar --- internal/modifier/cdi.go | 9 ++------- internal/modifier/csv_test.go | 1 - internal/modifier/discover_test.go | 2 ++ internal/modifier/factory.go | 32 +++++++++++++++++++----------- pkg/nvcdi/lib.go | 5 +---- pkg/nvcdi/options.go | 16 +++++++++++++++ 6 files changed, 41 insertions(+), 24 deletions(-) diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 3a6bc3bf3..005c42f48 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -18,7 +18,6 @@ package modifier import ( "fmt" - "slices" "strings" "tags.cncf.io/container-device-interface/pkg/parser" @@ -176,11 +175,6 @@ func filterAutomaticDevices(devices []string) []string { func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifier, error) { f.logger.Debugf("Generating in-memory CDI specs for devices %v", devices) - nvcdiFeatureFlags := slices.Clone(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags) - if f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled() { - nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes) - } - csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) if err != nil { f.logger.Warningf("Failed to get the list of CSV files: %v", err) @@ -198,10 +192,11 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path), nvcdi.WithDriverRoot(f.driver.Root), nvcdi.WithDevRoot(f.driver.DevRoot), + nvcdi.WithEditsFactory(f.editsFactory), nvcdi.WithVendor(automaticDeviceVendor), nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]), nvcdi.WithMode(mode), - nvcdi.WithFeatureFlags(nvcdiFeatureFlags...), + nvcdi.WithFeatureFlags(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags...), nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot), nvcdi.WithCSVFiles(csvFiles), ) diff --git a/internal/modifier/csv_test.go b/internal/modifier/csv_test.go index d91c6d019..994ceb65f 100644 --- a/internal/modifier/csv_test.go +++ b/internal/modifier/csv_test.go @@ -124,7 +124,6 @@ func TestNewCSVModifier(t *testing.T) { f := createFactory( WithLogger(logger), WithDriver(driver), - WithLogger(logger), WithConfig(&tc.cfg), WithImage(&image), ) diff --git a/internal/modifier/discover_test.go b/internal/modifier/discover_test.go index 76de5fc75..136716da4 100644 --- a/internal/modifier/discover_test.go +++ b/internal/modifier/discover_test.go @@ -24,6 +24,7 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "github.com/NVIDIA/nvidia-container-toolkit/api/config/v1" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" ) @@ -132,6 +133,7 @@ func TestDiscoverModifier(t *testing.T) { factory := createFactory( WithLogger(logger), + WithConfig(&config.Config{}), ) for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { diff --git a/internal/modifier/factory.go b/internal/modifier/factory.go index 6b58b56d7..3ec104fa2 100644 --- a/internal/modifier/factory.go +++ b/internal/modifier/factory.go @@ -32,14 +32,20 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) -type Factory struct { +// factoryOptions define the set of options that must be set when constructing +// a modifier factory. +type factoryOptions struct { logger logger.Interface cfg *config.Config driver *root.Driver hookCreator discover.HookCreator image *image.CUDA runtimeMode info.RuntimeMode +} +type Factory struct { + factoryOptions + // An editsFactory is created at construction. editsFactory edits.Factory } @@ -60,12 +66,14 @@ func New(opts ...Option) (oci.SpecModifier, error) { func createFactory(opts ...Option) *Factory { f := &Factory{} for _, opt := range opts { - opt(f) - } - if f.editsFactory == nil { - f.editsFactory = edits.NewFactory(edits.WithLogger(f.logger)) + opt(&f.factoryOptions) } + f.editsFactory = edits.NewFactory( + edits.WithLogger(f.logger), + edits.WithNoAdditionalGIDsForDeviceNodes(f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled()), + ) + return f } @@ -125,39 +133,39 @@ func (f *Factory) create() (oci.SpecModifier, error) { return modifiers, nil } -type Option func(*Factory) +type Option func(*factoryOptions) func WithConfig(cfg *config.Config) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.cfg = cfg } } func WithDriver(driver *root.Driver) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.driver = driver } } func WithHookCreator(hookCreator discover.HookCreator) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.hookCreator = hookCreator } } func WithImage(image *image.CUDA) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.image = image } } func WithLogger(logger logger.Interface) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.logger = logger } } func WithRuntimeMode(runtimeMode info.RuntimeMode) Option { - return func(f *Factory) { + return func(f *factoryOptions) { f.runtimeMode = runtimeMode } } diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 91f116cb9..fa84e19c5 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -71,10 +71,7 @@ func New(opts ...Option) (Interface, error) { discover.WithLdconfigPath(o.ldconfigPath), discover.WithDisabledHooks(o.disabledHooks...), ), - editsFactory: edits.NewFactory( - edits.WithLogger(o.logger), - edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]), - ), + editsFactory: o.editsFactory, } var factory deviceSpecGeneratorFactory diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 907af6c99..a5e2b1c40 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" @@ -52,6 +53,8 @@ type options struct { disabledHooks []discover.HookName enabledHooks []discover.HookName + + editsFactory edits.Factory } type platformlibs struct { @@ -116,6 +119,13 @@ func populateOptions(opts ...Option) *options { o.disabledHooks = append(o.disabledHooks, HookEnableCudaCompat, DisableDeviceNodeModificationHook) } + if o.editsFactory == nil { + o.editsFactory = edits.NewFactory( + edits.WithLogger(o.logger), + edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]), + ) + } + return o } @@ -191,6 +201,12 @@ func WithDevRoot(root string) Option { } } +func WithEditsFactory(editsFactory edits.Factory) Option { + return func(l *options) { + l.editsFactory = editsFactory + } +} + // WithLogger sets the logger for the library func WithLogger(logger logger.Interface) Option { return func(l *options) {