-
Notifications
You must be signed in to change notification settings - Fork 25
Add pkg/nvpassthrough for binding GPUs to the vfio-pci driver #83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| /* | ||
| * Copyright (c) NVIDIA CORPORATION. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package nvpassthrough | ||
|
|
||
| import ( | ||
| "bufio" | ||
| "fmt" | ||
| "os" | ||
| "os/exec" | ||
| "path/filepath" | ||
| "strconv" | ||
| "strings" | ||
| ) | ||
|
|
||
| const ( | ||
| procModules = "/proc/modules" | ||
| ) | ||
|
|
||
| type kernelModules struct { | ||
| log basicLogger | ||
|
|
||
| root string | ||
| } | ||
|
|
||
| func newKernelModules(log basicLogger, root string) *kernelModules { | ||
| km := &kernelModules{ | ||
| log: log, | ||
| root: root, | ||
| } | ||
| return km | ||
| } | ||
|
|
||
| func (km *kernelModules) list(searchKey string) error { | ||
| modsFilePath := filepath.Join(km.root, procModules) | ||
| file, err := os.Open(modsFilePath) | ||
| if err != nil { | ||
| return fmt.Errorf("error opening file %s: %w", modsFilePath, err) | ||
| } | ||
| defer func(file *os.File) { | ||
| err := file.Close() | ||
| if err != nil { | ||
| km.log.Warnf("error closing file %s: %v", modsFilePath, err) | ||
| } | ||
| }(file) | ||
|
|
||
| scanner := bufio.NewScanner(file) | ||
| km.log.Infof("%-20s %-10s %-15s %s\n", "Module", "Size", "Ref Count", "Used by") // Header | ||
|
|
||
| for scanner.Scan() { | ||
| line := scanner.Text() | ||
|
|
||
| if len(searchKey) > 0 && !strings.Contains(line, searchKey) { | ||
| continue | ||
| } | ||
|
|
||
| fields := strings.Fields(line) | ||
|
|
||
| if len(fields) >= 4 { | ||
| name := fields[0] | ||
|
|
||
| size, err := strconv.Atoi(fields[1]) | ||
| if err != nil { | ||
| km.log.Warnf("error parsing module size %s: %v", fields[1], err) | ||
| continue | ||
| } | ||
|
|
||
| refCnt, err := strconv.Atoi(fields[2]) | ||
| if err != nil { | ||
| km.log.Warnf("error parsing module ref count %s: %v", fields[2], err) | ||
| continue | ||
| } | ||
|
|
||
| usedBy := fields[3] | ||
|
|
||
| km.log.Infof("%-20s %-10d %-15d %s\n", name, size, refCnt, usedBy) | ||
| } | ||
| } | ||
|
|
||
| if err := scanner.Err(); err != nil { | ||
| km.log.Warnf("error reading /proc/modules: %v\n", err) | ||
| return err | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (km *kernelModules) load(module string) error { | ||
| cmd := exec.Command("chroot", km.root, "modprobe", module) | ||
| return cmd.Run() | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| /** | ||
| # Copyright (c) NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| **/ | ||
|
|
||
| package nvpassthrough | ||
|
|
||
| type basicLogger interface { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: should this be in an |
||
| Debugf(string, ...interface{}) | ||
| Warnf(string, ...interface{}) | ||
| Infof(string, ...interface{}) | ||
| } | ||
|
|
||
| type nullLogger struct{} | ||
|
|
||
| func (n *nullLogger) Debugf(string, ...interface{}) {} | ||
|
|
||
| func (n *nullLogger) Warnf(string, ...interface{}) {} | ||
|
|
||
| func (n *nullLogger) Infof(string, ...interface{}) {} | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,213 @@ | ||||||||||||||||||||
| /* | ||||||||||||||||||||
| * Copyright (c) NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||||
| * | ||||||||||||||||||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||
| * you may not use this file except in compliance with the License. | ||||||||||||||||||||
| * You may obtain a copy of the License at | ||||||||||||||||||||
| * | ||||||||||||||||||||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||
| * | ||||||||||||||||||||
| * Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||
| * See the License for the specific language governing permissions and | ||||||||||||||||||||
| * limitations under the License. | ||||||||||||||||||||
| */ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| package nvpassthrough | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import ( | ||||||||||||||||||||
| "fmt" | ||||||||||||||||||||
| "math" | ||||||||||||||||||||
| "reflect" | ||||||||||||||||||||
| "strings" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| "golang.org/x/sys/unix" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| const ( | ||||||||||||||||||||
| vfioPciAliasPrefix string = "alias vfio_pci:" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // modAlias is a decomposed version of string like this | ||||||||||||||||||||
| // | ||||||||||||||||||||
| // vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN | ||||||||||||||||||||
| // | ||||||||||||||||||||
| // The "NNNN" are always of the length in the example | ||||||||||||||||||||
| // unless replaced with a wildcard ("*"). | ||||||||||||||||||||
| type modAlias struct { | ||||||||||||||||||||
| vendor string // v | ||||||||||||||||||||
| device string // d | ||||||||||||||||||||
| subvendor string // sv | ||||||||||||||||||||
| subdevice string // sd | ||||||||||||||||||||
| baseClass string // bc | ||||||||||||||||||||
| subClass string // sc | ||||||||||||||||||||
| programmingInterface string // i | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // vfioAlias represents an entry from the modules.alias file for a vfio driver. | ||||||||||||||||||||
| type vfioAlias struct { | ||||||||||||||||||||
| modAlias *modAlias // The modalias pattern | ||||||||||||||||||||
| driver string // The vfio driver name | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| func parseModAliasString(input string) (*modAlias, error) { | ||||||||||||||||||||
| if input == "" { | ||||||||||||||||||||
| return nil, fmt.Errorf("modalias string is empty") | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| input = strings.TrimSpace(input) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Trim the leading "pci:" prefix in the modalias file | ||||||||||||||||||||
| split := strings.SplitN(input, ":", 2) | ||||||||||||||||||||
| if len(split) != 2 { | ||||||||||||||||||||
| return nil, fmt.Errorf("unexpected number of parts in modalias after trimming 'pci:' prefix: %s", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| input = split[1] | ||||||||||||||||||||
|
Comment on lines
+62
to
+66
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Alternatively, sicne we're not checking for the |
||||||||||||||||||||
|
|
||||||||||||||||||||
| if !strings.HasPrefix(input, "v") { | ||||||||||||||||||||
| return nil, fmt.Errorf("modalias must start with 'v', got: %s", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ma := &modAlias{} | ||||||||||||||||||||
| var before, after string | ||||||||||||||||||||
| var found bool | ||||||||||||||||||||
| after = input[1:] // cut leading 'v' | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: Why not |
||||||||||||||||||||
|
|
||||||||||||||||||||
| before, after, found = strings.Cut(after, "d") | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the spec for the string we're processing, it seems we're dealing with fixed-length segments. Should we use these lengths when parsing? |
||||||||||||||||||||
| if !found { | ||||||||||||||||||||
| return nil, fmt.Errorf("failed to find delimiter 'd' in %q", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ma.vendor = before | ||||||||||||||||||||
|
|
||||||||||||||||||||
| before, after, found = strings.Cut(after, "sv") | ||||||||||||||||||||
| if !found { | ||||||||||||||||||||
| return nil, fmt.Errorf("failed to find delimiter 'sv' in %q", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ma.device = before | ||||||||||||||||||||
|
|
||||||||||||||||||||
| before, after, found = strings.Cut(after, "sd") | ||||||||||||||||||||
| if !found { | ||||||||||||||||||||
| return nil, fmt.Errorf("failed to find delimiter 'sd' in %q", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ma.subvendor = before | ||||||||||||||||||||
|
|
||||||||||||||||||||
| before, after, found = strings.Cut(after, "bc") | ||||||||||||||||||||
| if !found { | ||||||||||||||||||||
| return nil, fmt.Errorf("failed to find delimiter 'bc' in %q", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ma.subdevice = before | ||||||||||||||||||||
|
|
||||||||||||||||||||
| before, after, found = strings.Cut(after, "sc") | ||||||||||||||||||||
| if !found { | ||||||||||||||||||||
| return nil, fmt.Errorf("failed to find delimiter 'sc' in input %q", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ma.baseClass = before | ||||||||||||||||||||
|
|
||||||||||||||||||||
| before, after, found = strings.Cut(after, "i") | ||||||||||||||||||||
| if !found { | ||||||||||||||||||||
| return nil, fmt.Errorf("failed to find delimiter 'i' in %q", input) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ma.subClass = before | ||||||||||||||||||||
| ma.programmingInterface = after | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return ma, nil | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| func getKernelVersion() (string, error) { | ||||||||||||||||||||
| var uname unix.Utsname | ||||||||||||||||||||
| if err := unix.Uname(&uname); err != nil { | ||||||||||||||||||||
| return "", err | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Convert C-style byte array to Go string | ||||||||||||||||||||
| release := make([]byte, 0, len(uname.Release)) | ||||||||||||||||||||
| for _, c := range uname.Release { | ||||||||||||||||||||
| if c == 0 { | ||||||||||||||||||||
| break | ||||||||||||||||||||
| } | ||||||||||||||||||||
| release = append(release, c) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return string(release), nil | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // getVFIOAliases returns the vfio driver aliases from the input string. | ||||||||||||||||||||
| // The input string is expected to be the content of a modules.alias file. | ||||||||||||||||||||
| // Only lines that begin with 'alias vfio_pci:' are parsed, with the | ||||||||||||||||||||
| // format being: alias vfio_pci:<modalias string> <driver_name>. | ||||||||||||||||||||
| func getVFIOAliases(input string) []vfioAlias { | ||||||||||||||||||||
| var aliases []vfioAlias | ||||||||||||||||||||
|
|
||||||||||||||||||||
| lines := strings.Split(input, "\n") | ||||||||||||||||||||
| for _, line := range lines { | ||||||||||||||||||||
| line = strings.TrimSpace(line) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if !strings.HasPrefix(line, vfioPciAliasPrefix) { | ||||||||||||||||||||
| continue | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| split := strings.SplitN(line, " ", 3) | ||||||||||||||||||||
| if len(split) != 3 { | ||||||||||||||||||||
| continue | ||||||||||||||||||||
| } | ||||||||||||||||||||
| modAliasStr := split[1] | ||||||||||||||||||||
| modAlias, err := parseModAliasString(modAliasStr) | ||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||
| continue | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| driver := split[2] | ||||||||||||||||||||
| aliases = append(aliases, vfioAlias{ | ||||||||||||||||||||
| modAlias: modAlias, | ||||||||||||||||||||
| driver: driver, | ||||||||||||||||||||
| }) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return aliases | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // findBestMatch finds the best matching VFIO driver for the given modalias | ||||||||||||||||||||
| // by comparing against all available vfio alias patterns. The best match | ||||||||||||||||||||
| // is the one with the fewest wildcard characters. | ||||||||||||||||||||
| func findBestMatch(deviceModAlias *modAlias, aliases []vfioAlias) string { | ||||||||||||||||||||
| var bestDriver string | ||||||||||||||||||||
| bestWildcardCount := math.MaxInt | ||||||||||||||||||||
|
|
||||||||||||||||||||
| for _, alias := range aliases { | ||||||||||||||||||||
| if matches, wildcardCount := matchModalias(deviceModAlias, alias.modAlias); matches { | ||||||||||||||||||||
| if wildcardCount < bestWildcardCount { | ||||||||||||||||||||
| bestDriver = alias.driver | ||||||||||||||||||||
| bestWildcardCount = wildcardCount | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return bestDriver | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // matchModalias checks if a device modalias matches a pattern from modules.alias | ||||||||||||||||||||
| // Returns true if it matches and the number of wildcards. | ||||||||||||||||||||
| func matchModalias(deviceModAlias, patternModAlias *modAlias) (bool, int) { | ||||||||||||||||||||
| wildcardCount := 0 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| modAliasType := reflect.TypeOf(*deviceModAlias) | ||||||||||||||||||||
| deviceModAliasValue := reflect.ValueOf(*deviceModAlias) | ||||||||||||||||||||
| patternModAliasValue := reflect.ValueOf(*patternModAlias) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // iterate over both modAlias structs, comparing each field | ||||||||||||||||||||
| for i := 0; i < modAliasType.NumField(); i++ { | ||||||||||||||||||||
| deviceValue := deviceModAliasValue.Field(i).String() | ||||||||||||||||||||
| patternValue := patternModAliasValue.Field(i).String() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if patternValue == "*" { | ||||||||||||||||||||
| wildcardCount++ | ||||||||||||||||||||
| continue | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if deviceValue != patternValue { | ||||||||||||||||||||
| return false, wildcardCount | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return true, wildcardCount | ||||||||||||||||||||
| } | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question: Why is the
listfunction not used?