diff --git a/go.mod b/go.mod
index 1dd775257a0..54c141ff657 100644
--- a/go.mod
+++ b/go.mod
@@ -215,4 +215,6 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d
 // well).
 go 1.23.6
 
+replace github.com/lightningnetwork/lnd/queue => ./queue
+
 retract v0.0.2
diff --git a/peer/brontide.go b/peer/brontide.go
index da4aa610a44..c54164717de 100644
--- a/peer/brontide.go
+++ b/peer/brontide.go
@@ -1730,12 +1730,9 @@ type msgStream struct {
 	startMsg string
 	stopMsg  string
 
-	msgCond *sync.Cond
-	msgs    []lnwire.Message
-
-	mtx sync.Mutex
-
-	producerSema chan struct{}
+	// queue is the underlying backpressure-aware queue that manages
+	// messages.
+	queue *queue.BackpressureQueue[lnwire.Message]
 
 	wg   sync.WaitGroup
 	quit chan struct{}
@@ -1744,28 +1741,25 @@ type msgStream struct {
 // newMsgStream creates a new instance of a chanMsgStream for a particular
 // channel identified by its channel ID. bufSize is the max number of messages
 // that should be buffered in the internal queue. Callers should set this to a
-// sane value that avoids blocking unnecessarily, but doesn't allow an
-// unbounded amount of memory to be allocated to buffer incoming messages.
-func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize uint32,
+// sane value that avoids blocking unnecessarily, but doesn't allow an unbounded
+// amount of memory to be allocated to buffer incoming messages.
+func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize int,
 	apply func(lnwire.Message)) *msgStream {
 
 	stream := &msgStream{
-		peer:         p,
-		apply:        apply,
-		startMsg:     startMsg,
-		stopMsg:      stopMsg,
-		producerSema: make(chan struct{}, bufSize),
-		quit:         make(chan struct{}),
-	}
-	stream.msgCond = sync.NewCond(&stream.mtx)
-
-	// Before we return the active stream, we'll populate the producer's
-	// semaphore channel. We'll use this to ensure that the producer won't
-	// attempt to allocate memory in the queue for an item until it has
-	// sufficient extra space.
-	for i := uint32(0); i < bufSize; i++ {
-		stream.producerSema <- struct{}{}
-	}
+		peer:     p,
+		apply:    apply,
+		startMsg: startMsg,
+		stopMsg:  stopMsg,
+		quit:     make(chan struct{}),
+	}
+
+	// Initialize the backpressure queue with a predicate determined by
+	// build tags.
+	dropPredicate := getMsgStreamDropPredicate()
+	stream.queue = queue.NewBackpressureQueue[lnwire.Message](
+		bufSize, dropPredicate,
+	)
 
 	return stream
 }
@@ -1778,17 +1772,8 @@ func (ms *msgStream) Start() {
 
 // Stop stops the chanMsgStream.
 func (ms *msgStream) Stop() {
-	// TODO(roasbeef): signal too?
-
 	close(ms.quit)
 
-	// Now that we've closed the channel, we'll repeatedly signal the msg
-	// consumer until we've detected that it has exited.
-	for atomic.LoadInt32(&ms.streamShutdown) == 0 {
-		ms.msgCond.Signal()
-		time.Sleep(time.Millisecond * 100)
-	}
-
 	ms.wg.Wait()
 }
 
@@ -1796,82 +1781,49 @@ func (ms *msgStream) Stop() {
 // readHandler directly to the target channel.
 func (ms *msgStream) msgConsumer() {
 	defer ms.wg.Done()
-	defer peerLog.Tracef(ms.stopMsg)
+	defer ms.peer.log.Tracef(ms.stopMsg)
 	defer atomic.StoreInt32(&ms.streamShutdown, 1)
 
-	peerLog.Tracef(ms.startMsg)
+	ms.peer.log.Tracef(ms.startMsg)
+
+	ctx, _ := ms.peer.cg.Create(context.Background())
 
 	for {
-		// First, we'll check our condition. If the queue of messages
-		// is empty, then we'll wait until a new item is added.
-		ms.msgCond.L.Lock()
-		for len(ms.msgs) == 0 {
-			ms.msgCond.Wait()
-
-			// If we woke up in order to exit, then we'll do so.
-			// Otherwise, we'll check the message queue for any new
-			// items.
-			select {
-			case <-ms.peer.cg.Done():
-				ms.msgCond.L.Unlock()
-				return
-			case <-ms.quit:
-				ms.msgCond.L.Unlock()
-				return
-			default:
-			}
+		// Dequeue the next message. This will block until a message is
+		// available or the context is canceled.
+		msg, err := ms.queue.Dequeue(ctx)
+		if err != nil {
+			ms.peer.log.Warnf("unable to dequeue message: %v", err)
+			return
 		}
 
-		// Grab the message off the front of the queue, shifting the
-		// slice's reference down one in order to remove the message
-		// from the queue.
-		msg := ms.msgs[0]
-		ms.msgs[0] = nil // Set to nil to prevent GC leak.
-		ms.msgs = ms.msgs[1:]
-
-		ms.msgCond.L.Unlock()
-
+		// Apply the dequeued message.
 		ms.apply(msg)
 
-		// We've just successfully processed an item, so we'll signal
-		// to the producer that a new slot in the buffer. We'll use
-		// this to bound the size of the buffer to avoid allowing it to
-		// grow indefinitely.
+		// As a precaution, we'll check to see if we're already
+		// shutting down before processing the next message.
 		select {
-		case ms.producerSema <- struct{}{}:
 		case <-ms.peer.cg.Done():
 			return
 		case <-ms.quit:
 			return
+		default:
 		}
 	}
 }
 
 // AddMsg adds a new message to the msgStream. This function is safe for
 // concurrent access.
-func (ms *msgStream) AddMsg(msg lnwire.Message) {
-	// First, we'll attempt to receive from the producerSema struct. This
-	// acts as a semaphore to prevent us from indefinitely buffering
-	// incoming items from the wire. Either the msg queue isn't full, and
-	// we'll not block, or the queue is full, and we'll block until either
-	// we're signalled to quit, or a slot is freed up.
-	select {
-	case <-ms.producerSema:
-	case <-ms.peer.cg.Done():
-		return
-	case <-ms.quit:
+func (ms *msgStream) AddMsg(ctx context.Context, msg lnwire.Message) {
+	dropped, err := ms.queue.Enqueue(ctx, msg).Unpack()
+	if err != nil {
+		ms.peer.log.Warnf("unable to enqueue message: %v", err)
 		return
 	}
 
-	// Next, we'll lock the condition, and add the message to the end of
-	// the message queue.
-	ms.msgCond.L.Lock()
-	ms.msgs = append(ms.msgs, msg)
-	ms.msgCond.L.Unlock()
-
-	// With the message added, we signal to the msgConsumer that there are
-	// additional messages to consume.
-	ms.msgCond.Signal()
+	if dropped {
+		ms.peer.log.Debugf("message %T dropped by predicate", msg)
+	}
 }
 
 // waitUntilLinkActive waits until the target link is active and returns a
@@ -2026,6 +1978,8 @@ func (p *Brontide) readHandler() {
 	// gossiper?
 	p.initGossipSync()
 
+	ctx, _ := p.cg.Create(context.Background())
+
 	discStream := newDiscMsgStream(p)
 	discStream.Start()
 	defer discStream.Stop()
@@ -2141,11 +2095,15 @@ out:
 
 		case *lnwire.Warning:
 			targetChan = msg.ChanID
-			isLinkUpdate = p.handleWarningOrError(targetChan, msg)
+			isLinkUpdate = p.handleWarningOrError(
+				ctx, targetChan, msg,
+			)
 
 		case *lnwire.Error:
 			targetChan = msg.ChanID
-			isLinkUpdate = p.handleWarningOrError(targetChan, msg)
+			isLinkUpdate = p.handleWarningOrError(
+				ctx, targetChan, msg,
+			)
 
 		case *lnwire.ChannelReestablish:
 			targetChan = msg.ChanID
@@ -2193,7 +2151,7 @@ out:
 			*lnwire.ReplyChannelRange,
 			*lnwire.ReplyShortChanIDsEnd:
 
-			discStream.AddMsg(msg)
+			discStream.AddMsg(ctx, msg)
 
 		case *lnwire.Custom:
 			err := p.handleCustomMessage(msg)
@@ -2215,7 +2173,7 @@ out:
 		if isLinkUpdate {
 			// If this is a channel update, then we need to feed it
 			// into the channel's in-order message stream.
-			p.sendLinkUpdateMsg(targetChan, nextMsg)
+			p.sendLinkUpdateMsg(ctx, targetChan, nextMsg)
 		}
 
 		idleTimer.Reset(idleTimeout)
@@ -2330,8 +2288,8 @@ func (p *Brontide) storeError(err error) {
 // an error from a peer with an active channel, we'll store it in memory.
 //
 // NOTE: This method should only be called from within the readHandler.
-func (p *Brontide) handleWarningOrError(chanID lnwire.ChannelID,
-	msg lnwire.Message) bool {
+func (p *Brontide) handleWarningOrError(ctx context.Context,
+	chanID lnwire.ChannelID, msg lnwire.Message) bool {
 
 	if errMsg, ok := msg.(*lnwire.Error); ok {
 		p.storeError(errMsg)
@@ -2342,7 +2300,7 @@ func (p *Brontide) handleWarningOrError(chanID lnwire.ChannelID,
 	// with this peer.
 	case chanID == lnwire.ConnectionWideID:
 		for _, chanStream := range p.activeMsgStreams {
-			chanStream.AddMsg(msg)
+			chanStream.AddMsg(ctx, msg)
 		}
 
 		return false
@@ -5297,7 +5255,9 @@ func (p *Brontide) handleRemovePendingChannel(req *newChannelMsg) {
 
 // sendLinkUpdateMsg sends a message that updates the channel to the
 // channel's message stream.
-func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) {
+func (p *Brontide) sendLinkUpdateMsg(ctx context.Context,
+	cid lnwire.ChannelID, msg lnwire.Message) {
+
 	p.log.Tracef("Sending link update msg=%v", msg.MsgType())
 
 	chanStream, ok := p.activeMsgStreams[cid]
@@ -5317,7 +5277,7 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) {
 
 	// With the stream obtained, add the message to the stream so we can
 	// continue processing message.
-	chanStream.AddMsg(msg)
+	chanStream.AddMsg(ctx, msg)
 }
 
 // scaleTimeout multiplies the argument duration by a constant factor depending
diff --git a/peer/drop_predicate.go b/peer/drop_predicate.go
new file mode 100644
index 00000000000..beb67129614
--- /dev/null
+++ b/peer/drop_predicate.go
@@ -0,0 +1,57 @@
+//go:build !integration
+
+package peer
+
+import (
+	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/queue"
+)
+
+const (
+	// redMinThreshold is the minimum queue length before RED starts
+	// dropping messages.
+	redMinThreshold = 10
+
+	// redMaxThreshold is the queue length at or above which RED drops all
+	// messages (that are not protected by type).
+	redMaxThreshold = 40
+)
+
+// isProtectedMsgType checks if a message is of a type that should not be
+// dropped by the predicate.
+func isProtectedMsgType(msg lnwire.Message) bool {
+	switch msg.(type) {
+	// Never drop any messages that are heading to an active channel.
+	case lnwire.LinkUpdater:
+		return true
+
+	// Make sure to never drop an incoming announcement signatures
+	// message, as we need this to be able to advertise channels.
+	//
+	// TODO(roasbeef): don't drop any gossip if doing IGD?
+	case *lnwire.AnnounceSignatures1:
+		return true
+
+	default:
+		return false
+	}
+}
+
+// getMsgStreamDropPredicate returns the drop predicate for the msgStream's
+// BackpressureQueue. For non-integration builds, this combines a type-based
+// check for critical messages with Random Early Detection (RED).
+func getMsgStreamDropPredicate() queue.DropPredicate[lnwire.Message] {
+	redPred := queue.RandomEarlyDrop[lnwire.Message](
+		redMinThreshold, redMaxThreshold,
+	)
+
+	// We'll never drop protected messages; for the rest, we'll use the
+	// RED predicate.
+	return func(queueLen int, item lnwire.Message) bool {
+		if isProtectedMsgType(item) {
+			return false
+		}
+
+		return redPred(queueLen, item)
+	}
+}
diff --git a/peer/drop_predicate_integration.go b/peer/drop_predicate_integration.go
new file mode 100644
index 00000000000..eb252a0f050
--- /dev/null
+++ b/peer/drop_predicate_integration.go
@@ -0,0 +1,17 @@
+//go:build integration
+
+package peer
+
+import (
+	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/queue"
+)
+
+// getMsgStreamDropPredicate returns the drop predicate for the msgStream's
+// BackpressureQueue. For integration builds, this predicate never drops
+// messages.
+func getMsgStreamDropPredicate() queue.DropPredicate[lnwire.Message] {
+	return func(queueLen int, item lnwire.Message) bool {
+		return false
+	}
+}
diff --git a/queue/back_pressure.go b/queue/back_pressure.go
new file mode 100644
index 00000000000..de3c8d08644
--- /dev/null
+++ b/queue/back_pressure.go
@@ -0,0 +1,165 @@
+package queue
+
+import (
+	"context"
+	"math/rand"
+
+	"github.com/lightningnetwork/lnd/fn/v2"
+)
+
+// DropPredicate decides whether to drop an item when the queue is full.
+// It receives the current queue length and the item, and returns true to drop,
+// false to enqueue.
+type DropPredicate[T any] func(queueLen int, item T) bool
+
+// BackpressureQueue is a generic, fixed-capacity queue with predicate-based
+// drop behavior. When full, it uses the DropPredicate to perform early drops
+// (e.g., RED-style).
+type BackpressureQueue[T any] struct {
+	ch            chan T
+	dropPredicate DropPredicate[T]
+}
+
+// NewBackpressureQueue creates a new BackpressureQueue with the given capacity
+// and drop predicate.
+func NewBackpressureQueue[T any](capacity int,
+	predicate DropPredicate[T]) *BackpressureQueue[T] {
+
+	return &BackpressureQueue[T]{
+		ch:            make(chan T, capacity),
+		dropPredicate: predicate,
+	}
+}
+
+// Enqueue attempts to add an item to the queue, respecting context
+// cancellation. Returns true if the item was dropped by the predicate, false
+// if enqueued successfully, or an error if ctx is done before enqueue.
+func (q *BackpressureQueue[T]) Enqueue(ctx context.Context,
+	item T) fn.Result[bool] {
+
+	// First, consult the drop predicate based on the current queue length.
+	// If the predicate decides to drop the item, return true (dropped).
+	if q.dropPredicate(len(q.ch), item) {
+		return fn.Ok(true)
+	}
+
+	// If the predicate decides not to drop, attempt to enqueue the item.
+	select {
+	case q.ch <- item:
+		return fn.Ok(false)
+
+	default:
+		// Channel is full, and the predicate decided not to drop. We
+		// must block until space is available or context is cancelled.
+		select {
+		case q.ch <- item:
+			return fn.Ok(false)
+
+		case <-ctx.Done():
+			return fn.Err[bool](ctx.Err())
+		}
+	}
+}
+
+// Dequeue retrieves the next item from the queue, blocking until available or
+// context done. Returns the item or an error if ctx is done before an item is
+// available.
+func (q *BackpressureQueue[T]) Dequeue(ctx context.Context) (T, error) {
+	select {
+	case item := <-q.ch:
+		return item, nil
+
+	case <-ctx.Done():
+		var zero T
+		return zero, ctx.Err()
+	}
+}
+
+// redConfig holds configuration for RandomEarlyDrop.
+type redConfig struct {
+	randSrc func() float64
+}
+
+// REDOption is a functional option for configuring RandomEarlyDrop.
+type REDOption func(*redConfig)
+
+// WithRandSource provides a custom random number source (a function that
+// returns a float64 between 0.0 and 1.0).
+func WithRandSource(src func() float64) REDOption {
+	return func(cfg *redConfig) {
+		cfg.randSrc = src
+	}
+}
+
+// RandomEarlyDrop returns a DropPredicate that implements Random Early
+// Detection (RED), inspired by TCP-RED queue management.
+//
+// RED prevents sudden buffer overflows by proactively dropping packets before
+// the queue is full. It establishes two thresholds:
+//
+//  1. minThreshold: queue length below which no drops occur.
+//  2. maxThreshold: queue length at or above which all items are dropped.
+//
+// Between these points, the drop probability p increases linearly:
+//
+//	p = (queueLen - minThreshold) / (maxThreshold - minThreshold)
+//
+// For example, with minThreshold=15 and maxThreshold=35:
+//   - At queueLen=15, p=0.0 (0% drop chance)
+//   - At queueLen=25, p=0.5 (50% drop chance)
+//   - At queueLen=35, p=1.0 (100% drop chance)
+//
+// This smooth ramp helps avoid tail-drop spikes, smooths queue occupancy,
+// and gives early back-pressure signals to senders.
+func RandomEarlyDrop[T any](minThreshold, maxThreshold int,
+	opts ...REDOption) DropPredicate[T] {
+
+	cfg := redConfig{
+		randSrc: rand.Float64,
+	}
+
+	for _, opt := range opts {
+		opt(&cfg)
+	}
+	if cfg.randSrc == nil {
+		cfg.randSrc = rand.Float64
+	}
+
+	return func(queueLen int, _ T) bool {
+		// If the queue is below the minimum threshold, then we never
+		// drop.
+		if queueLen < minThreshold {
+			return false
+		}
+
+		// If the queue is at or above the maximum threshold, then we
+		// always drop.
+		if queueLen >= maxThreshold {
+			return true
+		}
+
+		// If we're in the middle, then we implement linear scaling of
+		// the drop probability based on our thresholds. At this point,
+		// minThreshold <= queueLen < maxThreshold, which also implies
+		// minThreshold < maxThreshold, so the denominator below is
+		// never zero.
+		denominator := float64(maxThreshold - minThreshold)
+
+		// The two guards above make this branch unreachable when
+		// minThreshold == maxThreshold, so this check is purely
+		// defensive.
+		if denominator <= 0 {
+			return true
+		}
+
+		p := float64(queueLen-minThreshold) / denominator
+
+		return cfg.randSrc() < p
+	}
+}
diff --git a/queue/back_pressure_test.go b/queue/back_pressure_test.go
new file mode 100644
index 00000000000..c583603c1b8
--- /dev/null
+++ b/queue/back_pressure_test.go
@@ -0,0 +1,375 @@
+package queue
+
+import (
+	"context"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"pgregory.net/rapid"
+)
+
+// queueMachine is the generic state machine logic for testing
+// BackpressureQueue. T must be comparable for use in assertions.
+type queueMachine[T comparable] struct {
+	tb rapid.TB
+
+	capacity int
+
+	queue *BackpressureQueue[T]
+
+	modelQueue []T
+
+	dropPredicate DropPredicate[T]
+
+	itemGenerator *rapid.Generator[T]
+}
+
+// Enqueue is a state machine action. It enqueues an item and updates the model.
+func (m *queueMachine[T]) Enqueue(t *rapid.T) {
+	item := m.itemGenerator.Draw(t, "item")
+
+	result := m.queue.Enqueue(context.Background(), item)
+	require.True(
+		m.tb, result.IsOk(), "Enqueue returned an error with "+
+			"background context: %v", result.Err(),
+	)
+
+	// Unpack the boolean value indicating if the predicate actually dropped
+	// the item.
+	actualDrop, _ := result.Unpack()
+
+	if !actualDrop {
+		// If the item was not dropped, it must have been enqueued. Add
+		// it to the model. The modelQueue should not exceed capacity.
+		// This is also checked in Check().
+		m.modelQueue = append(m.modelQueue, item)
+	}
+}
+
+// Dequeue is a state machine action. It dequeues an item and updates the model.
+func (m *queueMachine[T]) Dequeue(t *rapid.T) {
+	if len(m.modelQueue) == 0 {
+		// If the model is empty, the actual queue channel should also
+		// be empty.
+		require.Zero(
+			m.tb, len(m.queue.ch), "actual queue channel not "+
+				"empty when model is empty",
+		)
+
+		// Attempting to dequeue from an empty queue should block. We
+		// verify this by trying to dequeue with a very short timeout.
+		ctx, cancel := context.WithTimeout(
+			context.Background(), 5*time.Millisecond,
+		)
+		defer cancel()
+
+		_, err := m.queue.Dequeue(ctx)
+		require.ErrorIs(
+			m.tb, err, context.DeadlineExceeded, "Dequeue should "+
+				"block on empty queue",
+		)
+		return
+	}
+
+	// The model is not empty, so we expect to dequeue an item.
+	expectedItem := m.modelQueue[0]
+	m.modelQueue = m.modelQueue[1:]
+
+	// Perform the dequeue operation.
+	actualItem, err := m.queue.Dequeue(context.Background())
+	require.NoError(m.tb, err, "Dequeue failed when model was not empty")
+	require.Equal(
+		m.tb, expectedItem, actualItem, "dequeued item does not "+
+			"match model (FIFO violation or model error)",
+	)
+}
+
+// Check is called by rapid after each action to verify invariants.
+func (m *queueMachine[T]) Check(t *rapid.T) {
+	// Invariant 1: The length of the internal channel must not exceed
+	// capacity.
+	require.LessOrEqual(
+		m.tb, len(m.queue.ch), m.capacity,
+		"queue channel length exceeds capacity",
+	)
+
+	// Invariant 2: The length of our model queue must match the length of
+	// the actual queue's channel.
+	require.Equal(
+		m.tb, len(m.modelQueue), len(m.queue.ch),
+		"model queue length mismatch with actual queue channel length",
+	)
+
+	// Invariant 3: The model queue itself should not exceed capacity.
+	require.LessOrEqual(
+		m.tb, len(m.modelQueue), m.capacity,
+		"model queue length exceeds capacity",
+	)
+}
+
+// intQueueMachine is a concrete wrapper for queueMachine[int] for rapid.
+type intQueueMachine struct {
+	*queueMachine[int]
+}
+
+// NewIntqueueMachine creates a new queueMachine specialized for int items.
+func NewIntqueueMachine(rt *rapid.T) *intQueueMachine {
+	// Draw the main parameters of our queue from rapid's distributions.
+	capacity := rapid.IntRange(1, 50).Draw(rt, "capacity")
+	minThreshold := rapid.IntRange(0, capacity).Draw(rt, "minThreshold")
+	maxThreshold := rapid.IntRange(
+		minThreshold, capacity,
+	).Draw(rt, "maxThreshold")
+
+	// Draw a seed for this machine's local RNG using rapid. This makes the
+	// predicate's randomness part of rapid's generated test case.
+	machineSeed := rapid.Int64().Draw(rt, "machine_rng_seed")
+	localRngFixed := rand.New(rand.NewSource(machineSeed))
+
+	rt.Logf("NewIntqueueMachine: capacity=%d, minT=%d, maxT=%d, "+
+		"machineSeed=%d", capacity, minThreshold, maxThreshold,
+		machineSeed)
+
+	predicate := RandomEarlyDrop[int](
+		minThreshold, maxThreshold,
+		WithRandSource(localRngFixed.Float64),
+	)
+
+	q := NewBackpressureQueue[int](capacity, predicate)
+
+	return &intQueueMachine{
+		queueMachine: &queueMachine[int]{
+			tb:            rt,
+			capacity:      capacity,
+			queue:         q,
+			modelQueue:    make([]int, 0, capacity),
+			dropPredicate: predicate,
+			itemGenerator: rapid.IntRange(-1000, 1000),
+		},
+	}
+}
+
+// Enqueue forwards the call to the generic queueMachine.
+func (m *intQueueMachine) Enqueue(t *rapid.T) { m.queueMachine.Enqueue(t) }
+
+// Dequeue forwards the call to the generic queueMachine.
+func (m *intQueueMachine) Dequeue(t *rapid.T) { m.queueMachine.Dequeue(t) }
+
+// Check forwards the call to the generic queueMachine.
+func (m *intQueueMachine) Check(t *rapid.T) { m.queueMachine.Check(t) }
+
+// TestBackpressureQueueRapidInt is the main property-based test for
+// BackpressureQueue using the IntqueueMachine state machine.
+func TestBackpressureQueueRapidInt(t *testing.T) {
+	rapid.Check(t, func(rt *rapid.T) {
+		// Initialize the state machine instance within the property
+		// function. NewIntqueueMachine expects *rapid.T, which rt is.
+		machine := NewIntqueueMachine(rt)
+
+		// Generate the actions map from the machine's methods. Rapid
+		// will randomly call the methods, and then use the `Check`
+		// method to verify invariants.
+		rt.Repeat(rapid.StateMachineActions(machine))
+	})
+}
+
+// TestBackpressureQueueEnqueueCancellation tests that Enqueue respects context
+// cancellation when it would otherwise block.
+func TestBackpressureQueueEnqueueCancellation(t *testing.T) {
+	rapid.Check(t, func(rt *rapid.T) {
+		capacity := rapid.IntRange(1, 20).Draw(rt, "capacity")
+
+		// Use a predicate that never drops when full, to force blocking
+		// behavior.
+		q := NewBackpressureQueue[int](capacity,
+			func(_ int, _ int) bool { return false },
+		)
+
+		// Fill the queue to its capacity.
+		for i := 0; i < capacity; i++ {
+			res := q.Enqueue(context.Background(), i)
+			require.True(
+				rt, res.IsOk(), "Enqueue failed during "+
+					"setup: %v", res.Err(),
+			)
+
+			dropped, _ := res.Unpack()
+			require.False(rt, dropped, "Item dropped during setup")
+		}
+		require.Equal(
+			rt, capacity, len(q.ch),
+			"Queue should be full after setup",
+		)
+
+		// Attempt to enqueue one more item with an immediately
+		// cancelled context.
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+
+		enqueueResult := q.Enqueue(ctx, 999)
+		require.True(
+			rt, enqueueResult.IsErr(), "Enqueue should have "+
+				"returned an error for cancelled context",
+		)
+		require.ErrorIs(
+			rt, enqueueResult.Err(), context.Canceled, "Error "+
+				"should be context.Canceled",
+		)
+
+		// Ensure the queue state (length) is unchanged.
+		require.Equal(
+			rt, capacity, len(q.ch), "queue length changed "+
+				"after cancelled enqueue attempt",
+		)
+	})
+}
+
+// TestBackpressureQueueDequeueCancellation tests that Dequeue respects context
+// cancellation when the queue is empty and it would otherwise block.
+func TestBackpressureQueueDequeueCancellation(t *testing.T) {
+	rapid.Check(t, func(rt *rapid.T) {
+		capacity := rapid.IntRange(0, 20).Draw(rt, "capacity")
+
+		// The predicate doesn't matter much here as the queue will be
+		// empty.
+		q := NewBackpressureQueue[int](
+			capacity, RandomEarlyDrop[int](0, capacity),
+		)
+
+		require.Zero(
+			rt, len(q.ch), "queue should be empty initially for "+
+				"Dequeue cancellation test",
+		)
+
+		// Attempt to dequeue from the empty queue with an immediately
+		// cancelled context.
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+
+		_, err := q.Dequeue(ctx)
+		require.Error(
+			rt, err, "dequeue should have returned an error "+
+				"for cancelled context",
+		)
+		require.ErrorIs(
+			rt, err, context.Canceled, "Error should be "+
+				"context.Canceled",
+		)
+	})
+}
+
+// TestBackpressureQueueComposedPredicate demonstrates testing with a composed
+// predicate. This is a scenario-based test rather than a full property-based
+// state machine.
+func TestBackpressureQueueComposedPredicate(t *testing.T) {
+	capacity := 10
+	minThresh, maxThresh := 3, 7
+
+	// Use a deterministic random source for this specific test case to
+	// ensure predictable behavior of RandomEarlyDrop.
+	const testSeed = int64(12345)
+	localRng := rand.New(rand.NewSource(testSeed))
+
+	redPred := RandomEarlyDrop[int](
+		minThresh, maxThresh, WithRandSource(localRng.Float64),
+	)
+
+	// Next, we'll define a custom predicate: drop items with value 42.
+	customValuePredicate := func(queueLen int, item int) bool {
+		return item == 42
+	}
+
+	// We'll also make a composed predicate: drop if RED says so OR if item
+	// is 42.
+	composedPredicate := func(queueLen int, item int) bool {
+		isRedDrop := redPred(queueLen, item)
+		isCustomDrop := customValuePredicate(queueLen, item)
+		return isRedDrop || isCustomDrop
+	}
+
+	q := NewBackpressureQueue[int](capacity, composedPredicate)
+
+	// Scenario 1: Enqueue item 42 when queue length is between min/max
+	// thresholds. As we're below the max threshold, we shouldn't drop
+	// anything.
+	for i := 0; i < minThresh; i++ {
+		// None of these items is 42, so they shouldn't be dropped.
+		res := q.Enqueue(context.Background(), i)
+		require.True(
+			t, res.IsOk(), "enqueue S1 setup item %d Ok: %v", i,
+			res.Err(),
+		)
+
+		droppedVal, _ := res.Unpack()
+		require.False(
+			t, droppedVal, "enqueue S1 setup item %d (qLen "+
+				"before: %d) should not be dropped. Predicate "+
+				"was redPred(%d,%d) || customPred(%d,%d)", i,
+			len(q.ch)-1, len(q.ch)-1, i, len(q.ch)-1, i,
+		)
+	}
+
+	currentLen := len(q.ch)
+	require.Equal(t, minThresh, currentLen, "queue length after S1 setup")
+
+	// Enqueue item 42. customValuePredicate is true, so composedPredicate
+	// is true. Item 42 should be dropped regardless of what redPred
+	// decides.
+	res := q.Enqueue(context.Background(), 42)
+	require.True(
+		t, res.IsOk(), "Enqueue of 42 should not error: %v", res.Err(),
+	)
+	require.True(t,
+		res.UnwrapOrFail(t), "Item 42 should have been dropped by "+
+			"composed predicate",
+	)
+	require.Equal(
+		t, currentLen, len(q.ch), "queue length should not change "+
+			"after dropping 42",
+	)
+
+	// Scenario 2: Queue is full. Item 100 (not 42). Reset the localRng
+	// state by re-seeding for the setup of Scenario 2 to ensure its
+	// behavior is independent of S1.
+	localRng.Seed(testSeed + 1)
+
+	// The goal is to get the queue to capacity for the next step. For this
+	// test, we'll simplify by directly filling the SUT queue rather than
+	// creating a temporary queue with only redPred.
+	localRng.Seed(testSeed + 2)
+
+	// Re-create the main SUT queue with the composedPredicate. We will
+	// manually fill its channel to capacity to bypass Enqueue logic for
+	// setup.
+	q = NewBackpressureQueue[int](capacity, composedPredicate)
+	for i := 0; i < capacity; i++ {
+		// Directly add items to the channel, bypassing Enqueue logic
+		// for this setup. Ensure items are not 42, so
+		// customValuePredicate is false for them. Behavior of redPred
+		// part is controlled by localRng.
+		q.ch <- i
+	}
+	require.Equal(
+		t, capacity, len(q.ch), "queue manually filled to capacity "+
+			"for S2 test",
+	)
+
+	localRng.Seed(testSeed + 3)
+
+	res = q.Enqueue(context.Background(), 100)
+	require.True(
+		t, res.IsOk(), "Enqueue of 100 should not error: %v", res.Err(),
+	)
+
+	// Expect drop because queue is full (len=capacity), so
+	// redPred(capacity, 100) is true. customValuePredicate(capacity, 100)
+	// is false. Thus, composedPredicate should be true.
+	require.True(
+		t, res.UnwrapOrFail(t), "Item 100 should be dropped (due to "+
+			"RED part of composed predicate) when queue full",
+	)
+	require.Equal(
+		t, capacity, len(q.ch), "Queue length should not change "+
+			"after dropping 100",
+	)
+}
diff --git a/queue/go.mod b/queue/go.mod
index 58267e27606..7dd9930d31b 100644
--- a/queue/go.mod
+++ b/queue/go.mod
@@ -1,6 +1,19 @@
 module github.com/lightningnetwork/lnd/queue
 
-require github.com/lightningnetwork/lnd/ticker v1.0.0
+require (
+	github.com/lightningnetwork/lnd/fn/v2 v2.0.8
+	github.com/lightningnetwork/lnd/ticker v1.0.0
+	github.com/stretchr/testify v1.8.1
+	pgregory.net/rapid v1.2.0
+)
+
+require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
+	golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect
+	golang.org/x/sync v0.7.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
+)
 
 replace github.com/lightningnetwork/lnd/ticker v1.0.0 => ../ticker
diff --git a/queue/go.sum b/queue/go.sum
index e69de29bb2d..575b2bc6777 100644
--- a/queue/go.sum
+++ b/queue/go.sum
@@ -0,0 +1,25 @@
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/lightningnetwork/lnd/fn/v2 v2.0.8 h1:r2SLz7gZYQPVc3IZhU82M66guz3Zk2oY+Rlj9QN5S3g=
+github.com/lightningnetwork/lnd/fn/v2 v2.0.8/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4=
+golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
+golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
+golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
+pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
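
A minimal usage sketch of the new queue API follows (illustrative only, not part of the diff). It relies solely on the exported symbols added above — NewBackpressureQueue, RandomEarlyDrop, Enqueue (via fn.Result's Unpack) and Dequeue; the thresholds, item type, and the producer/consumer wiring shown are hypothetical.

// backpressure_example.go (hypothetical sketch, not part of this change):
// wiring a producer/consumer pair around the new BackpressureQueue with a
// RED drop predicate.
package main

import (
	"context"
	"fmt"

	"github.com/lightningnetwork/lnd/queue"
)

func main() {
	// Buffer up to 50 items; RED starts probabilistic drops at length 10
	// and drops everything at length 40 or more (example thresholds).
	q := queue.NewBackpressureQueue[string](
		50, queue.RandomEarlyDrop[string](10, 40),
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Consumer: block on Dequeue until an item arrives or the context is
	// cancelled.
	go func() {
		for {
			item, err := q.Dequeue(ctx)
			if err != nil {
				return
			}
			fmt.Println("consumed:", item)
		}
	}()

	// Producer: Enqueue consults the predicate first; if the predicate
	// declines to drop and the buffer is full, the call blocks until a
	// slot frees up or the context is cancelled.
	for i := 0; i < 100; i++ {
		item := fmt.Sprintf("msg-%d", i)
		dropped, err := q.Enqueue(ctx, item).Unpack()
		if err != nil {
			return
		}
		if dropped {
			fmt.Printf("%s dropped by RED\n", item)
		}
	}
}

The same shape appears in peer/brontide.go above: the consumer (msgConsumer) owns a long-lived context and blocks on Dequeue, while producers (AddMsg callers) pass their own context, so a slow consumer throttles the read loop through Enqueue instead of letting the buffer grow without bound.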