diff --git a/frac/active_lids.go b/frac/active_lids.go index 47abe92a..4875deb8 100644 --- a/frac/active_lids.go +++ b/frac/active_lids.go @@ -1,29 +1,12 @@ package frac import ( + "cmp" "math" - "sort" + "slices" "sync" ) -type queueIDs struct { - lids []uint32 - mids []uint64 - rids []uint64 -} - -func (p *queueIDs) Len() int { return len(p.lids) } -func (p *queueIDs) Less(i, j int) bool { - if p.mids[p.lids[i]] == p.mids[p.lids[j]] { - if p.rids[p.lids[i]] == p.rids[p.lids[j]] { - return p.lids[i] > p.lids[j] - } - return p.rids[p.lids[i]] > p.rids[p.lids[j]] - } - return p.mids[p.lids[i]] > p.mids[p.lids[j]] -} -func (p *queueIDs) Swap(i, j int) { p.lids[i], p.lids[j] = p.lids[j], p.lids[i] } - type TokenLIDs struct { sortedMu sync.Mutex // global merge mutex, making the merge process strictly sequential sorted []uint32 // slice of actual sorted and merged LIDs of token @@ -39,16 +22,20 @@ func (tl *TokenLIDs) GetLIDs(mids, rids *UInt64s) []uint32 { lids := tl.getQueuedLIDs() if len(lids) != 0 { - midsVals := mids.GetVals() - ridsVals := rids.GetVals() + mids := mids.GetVals() + rids := rids.GetVals() - sort.Sort(&queueIDs{ - lids: lids, - mids: midsVals, - rids: ridsVals, + slices.SortFunc(lids, func(i, j uint32) int { + if mids[i] == mids[j] { + if rids[i] == rids[j] { + return -cmp.Compare(i, j) + } + return -cmp.Compare(rids[i], rids[j]) + } + return -cmp.Compare(mids[i], mids[j]) }) - tl.sorted = mergeSorted(tl.sorted, lids, midsVals, ridsVals) + tl.sorted = mergeSorted(tl.sorted, lids, mids, rids) } return tl.sorted @@ -98,7 +85,7 @@ func mergeSorted(right, left []uint32, mids, rids []uint64) []uint32 { val := uint32(0) prev := uint32(math.MaxUint32) - cmp := SeqIDCmp{ + c := SeqIDCmp{ mid: mids, rid: rids, } @@ -108,7 +95,7 @@ func mergeSorted(right, left []uint32, mids, rids []uint64) []uint32 { for l != len(left) && r != len(right) { ri, li := right[r], left[l] - switch cmp.compare(ri, li) { + switch c.compare(ri, li) { case 0: val = ri r++