Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ const (
deviceListEnvVar = "NVIDIA_VISIBLE_DEVICES"
deviceListAsVolumeMountsHostPath = "/dev/null"
deviceListAsVolumeMountsContainerPathRoot = "/var/run/nvidia-container-devices"

// healthChannelBufferSize defines the buffer capacity for the health
// channel. This is sized to handle bursts of unhealthy device reports
// without blocking the health check goroutine. The value of 64 is
// chosen assuming a single device plugin instance runs per node with
// up to 8 GPUs per node and multiple in-flight events per GPU (XID
// errors, ECC errors, etc.), while keeping a power-of-2 size for
// cache-friendly alignment. Operators running nodes with significantly
// more GPUs should review this assumption.
healthChannelBufferSize = 64
)

// nvidiaDevicePlugin implements the Kubernetes device plugin API
Expand Down Expand Up @@ -108,7 +118,7 @@ func getPluginSocketPath(resource spec.ResourceName) string {

func (plugin *nvidiaDevicePlugin) initialize() {
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
plugin.health = make(chan *rm.Device)
plugin.health = make(chan *rm.Device, healthChannelBufferSize)
plugin.stop = make(chan interface{})
}

Expand Down Expand Up @@ -150,7 +160,7 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {

go func() {
// TODO: add MPS health check
err := plugin.rm.CheckHealth(plugin.stop, plugin.health)
err := plugin.rm.CheckHealth(plugin.ctx, plugin.stop, plugin.health)
if err != nil {
klog.Errorf("Failed to start health check: %v; continuing with health checks disabled", err)
}
Expand Down Expand Up @@ -263,7 +273,8 @@ func (plugin *nvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *plugi
return options, nil
}

// ListAndWatch lists devices and update that list according to the health status
// ListAndWatch lists devices and update that list according to the health
// status.
func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return err
Expand All @@ -274,9 +285,9 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
case <-plugin.stop:
return nil
case d := <-plugin.health:
// FIXME: there is no way to recover from the Unhealthy state.
d.Health = pluginapi.Unhealthy
klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
klog.Infof("'%s' device marked unhealthy: %s (reason: %s)",
plugin.rm.Resource(), d.ID, d.GetUnhealthyReason())
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
Expand Down
55 changes: 55 additions & 0 deletions internal/rm/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"fmt"
"strconv"
"strings"
"sync"
"time"

"k8s.io/klog/v2"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
Expand All @@ -35,6 +37,11 @@ type Device struct {
// Replicas stores the total number of times this device is replicated.
// If this is 0 or 1 then the device is not shared.
Replicas int

// Health tracking fields (protected by healthMu)
healthMu sync.RWMutex
lastUnhealthyTime time.Time // When device became unhealthy
unhealthyReason string // Human-readable reason (e.g., "XID-79")
}

// deviceInfo defines the information the required to construct a Device
Expand Down Expand Up @@ -239,6 +246,54 @@ func (d *Device) GetUUID() string {
return AnnotatedID(d.ID).GetID()
}

// MarkUnhealthy marks the device as unhealthy and records the reason and
// timestamp. This should be called when a health check detects a device
// failure (e.g., XID error). Once marked unhealthy, devices remain in this
// state until external intervention (e.g., node drain, GPU reset, reboot).
// This method is thread-safe.
func (d *Device) MarkUnhealthy(reason string) {
d.healthMu.Lock()
defer d.healthMu.Unlock()
d.Health = pluginapi.Unhealthy
d.lastUnhealthyTime = time.Now()
d.unhealthyReason = reason
}

// IsUnhealthy returns true if the device is currently marked as unhealthy.
// This method is thread-safe.
func (d *Device) IsUnhealthy() bool {
d.healthMu.RLock()
defer d.healthMu.RUnlock()
return d.Health == pluginapi.Unhealthy
}

// GetUnhealthyReason returns the reason the device was marked unhealthy.
// This method is thread-safe.
func (d *Device) GetUnhealthyReason() string {
d.healthMu.RLock()
defer d.healthMu.RUnlock()
return d.unhealthyReason
}

// GetLastUnhealthyTime returns when the device was marked unhealthy.
// This method is thread-safe.
func (d *Device) GetLastUnhealthyTime() time.Time {
d.healthMu.RLock()
defer d.healthMu.RUnlock()
return d.lastUnhealthyTime
}

// UnhealthyDuration returns how long the device has been unhealthy. Returns
// zero duration if the device is healthy. This method is thread-safe.
func (d *Device) UnhealthyDuration() time.Duration {
d.healthMu.RLock()
defer d.healthMu.RUnlock()
if d.Health != pluginapi.Unhealthy {
return 0
}
return time.Since(d.lastUnhealthyTime)
}

// NewAnnotatedID creates a new AnnotatedID from an ID and a replica number.
func NewAnnotatedID(id string, replica int) AnnotatedID {
return AnnotatedID(fmt.Sprintf("%s::%d", id, replica))
Expand Down
Loading
Loading