66// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
77package replay
88
9- import (
10- "sync/atomic"
11- )
9+ import "sync"
1210
1311const (
1412 // RekeyAfterMessages is the maximum number of messages that can be sent before rekeying.
@@ -17,7 +15,7 @@ const (
1715 RejectAfterMessages = (1 << 64 ) - (1 << 13 ) - 1
1816)
1917
20- type block = atomic. Uint64
18+ type block uint64
2119
2220const (
2321 blockBitLog = 6 // 1<<6 == 64 bits
@@ -28,52 +26,57 @@ const (
2826 bitMask = blockBits - 1
2927)
3028
29+ // Filter rejects replayed messages by checking if message counter value is
30+ // within a sliding window of previously received messages.
31+ // The zero value for Filter is an empty, ready-to-use, thread-safe filter.
3132type Filter struct {
32- last atomic.Uint64
33+ mu sync.Mutex
34+ last uint64
3335 ring [ringBlocks ]block
3436}
3537
38+ // Reset resets the filter to empty state.
3639func (f * Filter ) Reset () {
37- f .last .Store (0 )
38- f .ring [0 ].Store (0 )
40+ f .mu .Lock ()
41+ f .last = 0
42+ f .ring [0 ] = 0
43+ // Optionally clear the rest to be thorough:
44+ for i := 1 ; i < ringBlocks ; i ++ {
45+ f .ring [i ] = 0
46+ }
47+ f .mu .Unlock ()
3948}
4049
4150// ValidateCounter checks if the counter should be accepted.
51+ // Overlimit counters (>= limit) are always rejected.
4252func (f * Filter ) ValidateCounter (counter , limit uint64 ) bool {
4353 if counter >= limit {
4454 return false
4555 }
4656
47- indexBlock := counter >> blockBitLog
48- last := f . last . Load ()
57+ f . mu . Lock ()
58+ defer f . mu . Unlock ()
4959
50- if counter > last {
51- current := last >> blockBitLog
60+ indexBlock := counter >> blockBitLog
61+ if counter > f .last { // move window forward
62+ current := f .last >> blockBitLog
5263 diff := indexBlock - current
5364 if diff > ringBlocks {
54- diff = ringBlocks
65+ diff = ringBlocks // cap diff to clear the whole ring
5566 }
5667 for i := current + 1 ; i <= current + diff ; i ++ {
57- f .ring [i & blockMask ]. Store ( 0 )
68+ f .ring [i & blockMask ] = 0
5869 }
59- f .last . Store ( counter )
60- } else if last - counter > windowSize {
70+ f .last = counter
71+ } else if f . last - counter > windowSize { // behind current window
6172 return false
6273 }
6374
64- indexBlock &= blockMask
75+ // check and set bit
76+ idx := indexBlock & blockMask
6577 indexBit := counter & bitMask
66-
67- ptr := & f .ring [indexBlock ]
68- mask := uint64 (1 ) << indexBit
69-
70- for {
71- old := ptr .Load ()
72- if old & mask != 0 {
73- return false
74- }
75- if ptr .CompareAndSwap (old , old | mask ) {
76- return true
77- }
78- }
78+ old := f .ring [idx ]
79+ new := old | 1 << indexBit
80+ f .ring [idx ] = new
81+ return old != new
7982}
0 commit comments