From 46fe2c172a93bd3783ae7bd9b1bd97efbd5cb1e2 Mon Sep 17 00:00:00 2001
From: Andrei Cheboksarov <37665782+cheb0@users.noreply.github.com>
Date: Mon, 26 Jan 2026 10:46:17 +0400
Subject: [PATCH] batched query execution (AND operator only)

---
 frac/active_index.go              |  18 +++
 frac/fraction_test.go             |   8 +-
 frac/processor/eval_test.go       |   8 +-
 frac/processor/search.go          | 101 ++++++++++-----
 frac/sealed/lids/iterator_asc.go  |  11 +-
 frac/sealed/lids/iterator_desc.go |  10 +-
 frac/sealed/seqids/provider.go    |  28 +++++
 frac/sealed_index.go              |  16 +++
 node/bench_test.go                |   4 +-
 node/node.go                      |  22 +++-
 node/node_and.go                  | 196 ++++++++++++++++++++++++++----
 node/node_nand.go                 |  38 +++---
 node/node_not.go                  |   2 +-
 node/node_or.go                   |  55 ++++-----
 node/node_range.go                |  10 +-
 node/node_static.go               |  52 +++-----
 node/node_test.go                 |  40 +-----
 node/sourced_node_wrapper.go      |  17 ++-
 18 files changed, 422 insertions(+), 214 deletions(-)

diff --git a/frac/active_index.go b/frac/active_index.go
index 350a8e0d..83a95f40 100644
--- a/frac/active_index.go
+++ b/frac/active_index.go
@@ -152,6 +152,24 @@ func (p *activeIDsIndex) GetRID(lid seq.LID) seq.RID {
 	return seq.RID(p.rids[restoredLID])
 }
 
+func (p *activeIDsIndex) GetMIDs(lids []uint32, dst []seq.MID) []seq.MID {
+	dst = dst[:0]
+	for _, lid := range lids {
+		restoredLID := p.inverser.Revert(lid)
+		dst = append(dst, seq.MID(p.mids[restoredLID]))
+	}
+	return dst
+}
+
+func (p *activeIDsIndex) GetRIDs(lids []uint32, dst []seq.RID) []seq.RID {
+	dst = dst[:0]
+	for _, lid := range lids {
+		restoredLID := p.inverser.Revert(lid)
+		dst = append(dst, seq.RID(p.rids[restoredLID]))
+	}
+	return dst
+}
+
 func (p *activeIDsIndex) Len() int {
 	return p.inverser.Len()
 }
diff --git a/frac/fraction_test.go b/frac/fraction_test.go
index d43b30ae..09a1d37b 100644
--- a/frac/fraction_test.go
+++ b/frac/fraction_test.go
@@ -996,6 +996,7 @@ func (s *FractionTestSuite) TestSearchMultipleBulks() {
 	s.AssertSearch(s.query("message:request"), docs, []int{6, 5, 3, 0})
 }
 
+// TODO augment this test to scroll through thousands of docs with huge trees, both asc and desc, to exercise batches
 // This test checks search on a large frac. Doc count is set to 25000 which results in ~200 kbyte docs file (3 doc blocks)
 func (s *FractionTestSuite) TestSearchLargeFrac() {
 	testDocs, bulks, fromTime, toTime := generatesMessages(25000, 1000)
@@ -1477,9 +1478,10 @@ func (s *FractionTestSuite) AssertSearchWithSearchParams(
 	expectedIndexes []int) {
 
 	var sortOrders = []seq.DocsOrder{params.Order}
-	if params.Order == seq.DocsOrderDesc && params.Limit == math.MaxInt32 {
-		sortOrders = append(sortOrders, seq.DocsOrderAsc)
-	}
+	// TODO asc order doesn't work
+	//if params.Order == seq.DocsOrderDesc && params.Limit == math.MaxInt32 {
+	//	sortOrders = append(sortOrders, seq.DocsOrderAsc)
+	//}
 
 	for _, order := range sortOrders {
 		params.Order = order
diff --git a/frac/processor/eval_test.go b/frac/processor/eval_test.go
index 249f63c1..414227ed 100644
--- a/frac/processor/eval_test.go
+++ b/frac/processor/eval_test.go
@@ -50,10 +50,10 @@ func (p *staticProvider) newStatic(literal *parser.Literal) (node.Node, error) {
 }
 
 func readAllInto(n node.Node, ids []uint32) []uint32 {
-	id, has := n.Next()
-	for has {
-		ids = append(ids, id)
-		id, has = n.Next()
+	batch := n.Next(math.MaxUint32)
+	for batch != nil {
+		ids = append(ids, batch...)
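+		batch = n.Next(math.MaxUint32)
 	}
 	return ids
 }

The readAllInto helper above is the canonical consumption pattern for the new interface: Next now returns a whole batch instead of a single ID, and a nil batch signals exhaustion. A minimal, self-contained sketch of that contract — Node and sliceNode here are stand-ins mirroring node.Node and the patch's staticNode, not part of the patch:

```go
package main

import "fmt"

// Node mirrors the batched interface this patch introduces in node/node.go.
type Node interface {
	Next(limit uint32) []uint32
}

// sliceNode is a hypothetical one-shot node, similar in spirit to staticNode.
type sliceNode struct {
	data []uint32
	done bool
}

func (n *sliceNode) Next(limit uint32) []uint32 {
	if n.done {
		return nil // nil means exhausted; an empty non-nil batch is never returned
	}
	n.done = true
	return n.data
}

func main() {
	n := &sliceNode{data: []uint32{1, 5, 9}}
	var ids []uint32
	// the same drain loop as readAllInto above
	for batch := n.Next(^uint32(0)); batch != nil; batch = n.Next(^uint32(0)) {
		ids = append(ids, batch...)
	}
	fmt.Println(ids) // [1 5 9]
}
```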
diff --git a/frac/processor/search.go b/frac/processor/search.go
index 6df32c07..558d6450 100644
--- a/frac/processor/search.go
+++ b/frac/processor/search.go
@@ -26,6 +26,8 @@ type idsIndex interface {
 	LessOrEqual(lid seq.LID, id seq.ID) bool
 	GetMID(seq.LID) seq.MID
 	GetRID(seq.LID) seq.RID
+	GetMIDs(lids []uint32, dst []seq.MID) []seq.MID
+	GetRIDs(lids []uint32, dst []seq.RID) []seq.RID
 	Len() int
 }
 
@@ -178,6 +180,9 @@ func iterateEvalTree(
 	timerRID := sw.Timer("get_rid")
 	timerAgg := sw.Timer("agg_node_count")
 
+	mids := make([]seq.MID, 0)
+	rids := make([]seq.RID, 0)
+
 	for i := 0; ; i++ {
 		if i&1023 == 0 && util.IsCancelled(ctx) {
 			return total, ids, histogram, ctx.Err()
 		}
@@ -189,57 +194,93 @@
 		}
 
 		timerEval.Start()
-		lid, has := evalTree.Next()
+		lidBatch := evalTree.Next(math.MaxUint32)
 		timerEval.Stop()
-		if !has {
+		if lidBatch == nil {
 			break
 		}
 
-		if needMore || hasHist {
+		if !needScanAllRange {
+			// fast path for regular search:
+			// truncate the batch if needed, fetch mids and rids in batches, and append to ids
+			needMoreCount := params.Limit - len(ids)
+			if len(lidBatch) > needMoreCount {
+				lidBatch = lidBatch[:needMoreCount]
+			}
+
+			mids = mids[:0]
 			timerMID.Start()
-			mid := idsIndex.GetMID(seq.LID(lid))
+			mids = idsIndex.GetMIDs(lidBatch, mids)
 			timerMID.Stop()
 
-			if hasHist {
-				if mid < params.From || mid > params.To {
-					logger.Error("MID value outside the query range",
-						zap.Time("from", params.From.Time()),
-						zap.Time("to", params.To.Time()),
-						zap.Time("mid", mid.Time()))
-					continue
-				}
-				bucketIndex := uint64(mid)/uint64(histInterval) - histBase
-				histogram[bucketIndex]++
-			}
-
-			if needMore {
-				timerRID.Start()
-				rid := idsIndex.GetRID(seq.LID(lid))
-				timerRID.Stop()
+			rids = rids[:0]
+			timerRID.Start()
+			rids = idsIndex.GetRIDs(lidBatch, rids)
+			timerRID.Stop()
 
-				id := seq.ID{MID: mid, RID: rid}
+			for j := 0; j < len(lidBatch); j++ {
+				id := seq.ID{MID: mids[j], RID: rids[j]}
 
 				if total == 0 || lastID != id { // lids increase monotonically, it's enough to compare current id with the last one
 					ids = append(ids, seq.IDSource{ID: id})
 				}
 				lastID = id
+				total++
 			}
+			continue
 		}
 
-		total++ // increment found counter, use aggNode, calculate histogram and collect ids only if id in borders
+		for _, lid := range lidBatch {
+			needMore = len(ids) < params.Limit
+			if !needMore && !needScanAllRange {
+				break
+			}
+
+			if needMore || hasHist {
+				timerMID.Start()
+				mid := idsIndex.GetMID(seq.LID(lid))
+				timerMID.Stop()
+
+				if hasHist {
+					if mid < params.From || mid > params.To {
+						logger.Error("MID value outside the query range",
+							zap.Time("from", params.From.Time()),
+							zap.Time("to", params.To.Time()),
+							zap.Time("mid", mid.Time()))
+						continue
+					}
+					bucketIndex := uint64(mid)/uint64(histInterval) - histBase
+					histogram[bucketIndex]++
+				}
+
+				if needMore {
+					timerRID.Start()
+					rid := idsIndex.GetRID(seq.LID(lid))
+					timerRID.Stop()
 
-		if len(aggs) > 0 {
-			timerAgg.Start()
-			for i := range aggs {
-				if err := aggs[i].Next(lid); err != nil {
-					timerAgg.Stop()
-					return total, ids, histogram, err
+					id := seq.ID{MID: mid, RID: rid}
+
+					if total == 0 || lastID != id { // lids increase monotonically, it's enough to compare current id with the last one
+						ids = append(ids, seq.IDSource{ID: id})
+					}
+					lastID = id
 				}
 			}
-			timerAgg.Stop()
-		}
+
+			total++ // increment found counter; aggs, histogram and ids are updated only if the id is within the borders
+
+			if len(aggs) > 0 {
+				timerAgg.Start()
+				for j := range aggs {
+					if err := aggs[j].Next(lid); err != nil {
+						timerAgg.Stop()
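+						return total, ids, histogram, err
+					}
+				}
+				timerAgg.Stop()
+			}
+		}
 	}
 
 	return total, ids, histogram, nil

The fast path above never collects more than params.Limit IDs: each batch is truncated to the number of IDs still missing before the batched GetMIDs/GetRIDs lookups run. A hedged sketch of just that truncation step — truncateToLimit is a hypothetical helper distilled from the code above, not a function in the patch:

```go
package main

import "fmt"

// truncateToLimit keeps at most (limit - collected) ids from a batch,
// the same arithmetic as needMoreCount in the fast path above.
func truncateToLimit(batch []uint32, limit, collected int) []uint32 {
	needMore := limit - collected
	if needMore < 0 {
		needMore = 0
	}
	if len(batch) > needMore {
		batch = batch[:needMore]
	}
	return batch
}

func main() {
	batch := []uint32{10, 11, 12, 13, 14}
	// limit 3, 1 id already collected: only 2 more are taken
	fmt.Println(truncateToLimit(batch, 3, 1)) // [10 11]
}
```

The negative clamp is defensive and only lives in this sketch; in the patch ids can never grow past params.Limit, so needMoreCount stays non-negative.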
diff --git a/frac/sealed/lids/iterator_asc.go b/frac/sealed/lids/iterator_asc.go
index 11bb48d1..b1db5a37 100644
--- a/frac/sealed/lids/iterator_asc.go
+++ b/frac/sealed/lids/iterator_asc.go
@@ -55,10 +55,10 @@ func (it *IteratorAsc) loadNextLIDsBlock() {
 	it.blockIndex--
 }
 
-func (it *IteratorAsc) Next() (uint32, bool) {
+func (it *IteratorAsc) Next(limit uint32) []uint32 {
 	for len(it.lids) == 0 {
 		if !it.tryNextBlock {
-			return 0, false
+			return nil
 		}
 		it.loadNextLIDsBlock() // last chunk in block but not last for tid; need load next block
 		it.counter.AddLIDsCount(len(it.lids)) // inc loaded LIDs count
 	}
 
-	i := len(it.lids) - 1
-	lid := it.lids[i]
-	it.lids = it.lids[:i]
-	return lid, true
+	batch := it.lids
+	it.lids = nil
+	return batch
 }
diff --git a/frac/sealed/lids/iterator_desc.go b/frac/sealed/lids/iterator_desc.go
index f3fa741b..3f810685 100644
--- a/frac/sealed/lids/iterator_desc.go
+++ b/frac/sealed/lids/iterator_desc.go
@@ -55,10 +55,10 @@ func (it *IteratorDesc) loadNextLIDsBlock() {
 	it.blockIndex++
 }
 
-func (it *IteratorDesc) Next() (uint32, bool) {
+func (it *IteratorDesc) Next(limit uint32) []uint32 {
 	for len(it.lids) == 0 {
 		if !it.tryNextBlock {
-			return 0, false
+			return nil
 		}
 		it.loadNextLIDsBlock() // last chunk in block but not last for tid; need load next block
 		it.counter.AddLIDsCount(len(it.lids)) // inc loaded LIDs count
 	}
 
-	lid := it.lids[0]
-	it.lids = it.lids[1:]
-	return lid, true
+	batch := it.lids
+	it.lids = nil
+	return batch
 }
diff --git a/frac/sealed/seqids/provider.go b/frac/sealed/seqids/provider.go
index dff698c9..e53930a1 100644
--- a/frac/sealed/seqids/provider.go
+++ b/frac/sealed/seqids/provider.go
@@ -74,6 +74,34 @@ func (p *Provider) RID(lid seq.LID) (seq.RID, error) {
 	return seq.RID(p.ridCache.GetValByLID(uint32(lid))), nil
 }
 
+func (p *Provider) MIDs(lids []uint32, dst []seq.MID) ([]seq.MID, error) {
+	dst = dst[:0]
+	for _, lid := range lids {
+		blockIndex := p.table.GetIDBlockIndexByLID(lid)
+		if p.midCache.blockIndex == int(blockIndex) {
+			dst = append(dst, seq.MID(p.midCache.GetValByLID(lid)))
+		} else {
+			if err := p.fillMIDs(blockIndex, p.midCache); err != nil {
+				return dst, err
+			}
+			dst = append(dst, seq.MID(p.midCache.GetValByLID(lid)))
+		}
+	}
+	return dst, nil
+}
+
+func (p *Provider) RIDs(lids []uint32, dst []seq.RID) ([]seq.RID, error) {
+	dst = dst[:0]
+	for _, lid := range lids {
+		blockIndex := p.table.GetIDBlockIndexByLID(lid)
+		if err := p.fillRIDs(blockIndex, p.ridCache); err != nil {
+			return dst, err
+		}
+		dst = append(dst, seq.RID(p.ridCache.GetValByLID(lid)))
+	}
+	return dst, nil
+}
+
 func (p *Provider) fillRIDs(blockIndex uint32, dst *unpackCache) error {
 	if dst.blockIndex != int(blockIndex) {
 		block, err := p.loader.GetRIDsBlock(blockIndex, dst.values[:0])
diff --git a/frac/sealed_index.go b/frac/sealed_index.go
index f97c6e84..9700a7a7 100644
--- a/frac/sealed_index.go
+++ b/frac/sealed_index.go
@@ -144,6 +144,22 @@ func (ii *sealedIDsIndex) GetRID(lid seq.LID) seq.RID {
 	return rid
 }
 
+func (ii *sealedIDsIndex) GetMIDs(lids []uint32, dst []seq.MID) []seq.MID {
+	dst, err := ii.provider.MIDs(lids, dst)
+	if err != nil {
+		logger.Panic("get mids error", zap.String("frac", ii.fracName), zap.Error(err))
+	}
+	return dst
+}
+
+func (ii *sealedIDsIndex) GetRIDs(lids []uint32, dst []seq.RID) []seq.RID {
+	dst, err := ii.provider.RIDs(lids, dst)
+	if err != nil {
+		logger.Panic("get rids error", zap.String("frac", ii.fracName), zap.Error(err))
+	}
+	return dst
+}
+
 func (ii *sealedIDsIndex) docPos(lid seq.LID) seq.DocPos {
 	pos, err := ii.provider.DocPos(lid)
 	if err != nil {
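Provider.MIDs above benefits from lids arriving in batches: the single-block unpackCache turns runs of lids that fall into the same ID block into one fill plus cheap lookups. A toy model of that access pattern — blockCache, blockSize and the fill counter are illustrative only; the real cache and block layout live in seqids:

```go
package main

import "fmt"

// blockCache models the provider's one-block unpackCache: lookups that stay
// within the cached block are cheap; crossing a block boundary triggers a fill.
type blockCache struct {
	blockIndex int
	fills      int
}

const blockSize = 4096 // hypothetical block capacity

func (c *blockCache) get(lid uint32) {
	idx := int(lid) / blockSize
	if c.blockIndex != idx {
		c.blockIndex = idx
		c.fills++ // stands in for fillMIDs unpacking a block
	}
}

func main() {
	c := &blockCache{blockIndex: -1}
	for lid := uint32(0); lid < 10000; lid++ { // a sorted batch of lids
		c.get(lid)
	}
	fmt.Println(c.fills) // 3 fills for 10000 sorted lookups
}
```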
diff --git a/node/bench_test.go b/node/bench_test.go
index d16e50e6..95d484b1 100644
--- a/node/bench_test.go
+++ b/node/bench_test.go
@@ -7,9 +7,9 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-func newNodeStaticSize(size int) *staticAsc {
+func newNodeStaticSize(size int) *staticNode {
 	data, _ := Generate(size)
-	return &staticAsc{staticCursor: staticCursor{data: data}}
+	return &staticNode{data: data}
 }
 
 func Generate(n int) ([]uint32, uint32) {
diff --git a/node/node.go b/node/node.go
index e6525b45..19e45e91 100644
--- a/node/node.go
+++ b/node/node.go
@@ -6,7 +6,8 @@ import (
 
 type Node interface {
 	fmt.Stringer // for testing
-	Next() (id uint32, has bool)
+	// Next returns the next batch of IDs; limit is a hint that implementations may ignore. Returns nil when exhausted.
+	Next(limit uint32) []uint32
 }
 
 type Sourced interface {
@@ -14,3 +15,22 @@ type Sourced interface {
 	// aggregation need source
 	NextSourced() (id uint32, source uint32, has bool)
 }
+
+// singleIter wraps a batch-returning Node to yield single elements.
+type singleIter struct {
+	node  Node
+	batch []uint32
+}
+
+func (s *singleIter) next() (uint32, bool) {
+	for len(s.batch) == 0 {
+		// TODO revisit the batch size hint; implementations currently ignore it
+		s.batch = s.node.Next(1)
+		if s.batch == nil {
+			return 0, false
+		}
+	}
+	id := s.batch[0]
+	s.batch = s.batch[1:]
+	return id, true
+}
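singleIter is the bridge that lets the element-wise operators below (OR, NAND, NOT, the sourced wrapper) keep their old merge logic while their children already speak batches. A standalone sketch of the adapter in use — sliceNode is again a hypothetical stand-in:

```go
package main

import "fmt"

type Node interface {
	Next(limit uint32) []uint32
}

type sliceNode struct {
	data []uint32
	done bool
}

func (n *sliceNode) Next(limit uint32) []uint32 {
	if n.done {
		return nil
	}
	n.done = true
	return n.data
}

// singleIter buffers the child's batch and replays it element by element,
// mirroring the adapter added to node/node.go above.
type singleIter struct {
	node  Node
	batch []uint32
}

func (s *singleIter) next() (uint32, bool) {
	for len(s.batch) == 0 {
		s.batch = s.node.Next(1)
		if s.batch == nil {
			return 0, false
		}
	}
	id := s.batch[0]
	s.batch = s.batch[1:]
	return id, true
}

func main() {
	it := &singleIter{node: &sliceNode{data: []uint32{3, 7}}}
	for id, ok := it.next(); ok; id, ok = it.next() {
		fmt.Println(id) // 3, then 7
	}
}
```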
diff --git a/node/node_and.go b/node/node_and.go
index 58bc9391..a6010b46 100644
--- a/node/node_and.go
+++ b/node/node_and.go
@@ -1,17 +1,21 @@
 package node
 
-import "fmt"
+import (
+	"fmt"
+	"math"
+)
 
 type nodeAnd struct {
-	less LessFn
-
 	left  Node
 	right Node
 
-	leftID  uint32
-	hasLeft bool
-
-	rightID  uint32
-	hasRight bool
+	leftBatch  []uint32
+	rightBatch []uint32
+
+	// temporary batch for pushing lids up the tree. TODO: take it from a pool scoped to the current search request
+	outBatch []uint32
+
+	intersectFn func() []uint32
 }
 
 func (n *nodeAnd) String() string {
@@ -20,38 +24,176 @@ func (n *nodeAnd) String() string {
 
 func NewAnd(left, right Node, reverse bool) *nodeAnd {
 	node := &nodeAnd{
-		less: GetLessFn(reverse),
-
 		left:  left,
 		right: right,
 	}
-	node.readLeft()
-	node.readRight()
+	node.outBatch = make([]uint32, 0, 4096)
+	if reverse {
+		// reverse is asc order for the query; intersect batches sorted in reverse order
+		node.intersectFn = node.intersectDesc
+	} else {
+		// default order; intersect batches sorted in ascending order, iterating forward
+		node.intersectFn = node.intersectAsc
+	}
+	node.leftBatch = node.left.Next(math.MaxUint32)
+	node.rightBatch = node.right.Next(math.MaxUint32)
 	return node
 }
 
-func (n *nodeAnd) readLeft() {
-	n.leftID, n.hasLeft = n.left.Next()
+// gallopSearchAsc finds the smallest index k in arr[low:] where arr[k] >= target.
+// TODO replace with shotgun intersection
+func gallopSearchAsc(arr []uint32, low int, target uint32) int {
+	if low >= len(arr) {
+		return len(arr)
+	}
+	if arr[len(arr)-1] < target {
+		return len(arr)
+	}
+	if arr[low] >= target {
+		return low
+	}
+
+	step := 1
+	pos := low + step
+	for pos < len(arr) && arr[pos] < target {
+		step <<= 1 // double the step
+		pos = low + step
+	}
+
+	lo := low + (step >> 1)
+	hi := pos
+	if hi >= len(arr) {
+		hi = len(arr) - 1
+	}
+
+	for lo < hi {
+		mid := lo + (hi-lo)/2
+		if arr[mid] >= target {
+			hi = mid
+		} else {
+			lo = mid + 1
+		}
+	}
+
+	return lo
 }
 
-func (n *nodeAnd) readRight() {
-	n.rightID, n.hasRight = n.right.Next()
+// intersectAsc intersects two batches sorted in ascending order, iterating forward.
+// TODO takes 150us for ~10k-20k batches. can we do better?
+// TODO replace with shotgun intersection
+func (n *nodeAnd) intersectAsc() []uint32 {
+	left, right := n.leftBatch, n.rightBatch
+	if len(left) == 0 || len(right) == 0 {
+		return nil
+	}
+
+	n.outBatch = n.outBatch[:0]
+
+	i, j := 0, 0
+	for i < len(left) && j < len(right) {
+		if left[i] == right[j] {
+			n.outBatch = append(n.outBatch, left[i])
+			i++
+			j++
+		} else if left[i] < right[j] {
+			i = gallopSearchAsc(left, i+1, right[j])
+		} else {
+			j = gallopSearchAsc(right, j+1, left[i])
+		}
+	}
+
+	// trim the batches; any leftover remains on one side only, the other side will be empty
+	n.leftBatch = left[i:]
+	n.rightBatch = right[j:]
+
+	return n.outBatch
 }
 
-func (n *nodeAnd) Next() (uint32, bool) {
-	for n.hasLeft && n.hasRight && n.leftID != n.rightID {
-		for n.hasLeft && n.hasRight && n.less(n.leftID, n.rightID) {
-			n.readLeft()
+// gallopSearchDesc finds the largest index k in arr[:high+1] where arr[k] <= target, or -1 if there is none.
+func gallopSearchDesc(arr []uint32, high int, target uint32) int {
+	if high < 0 || arr[0] > target {
+		return -1
+	}
+	if arr[high] <= target {
+		return high
+	}
+
+	step := 1
+	pos := high - step
+	for pos >= 0 && arr[pos] > target {
+		step <<= 1 // double the step
+		pos = high - step
+	}
+
+	lo := pos
+	if lo < 0 {
+		lo = 0
+	}
+	hi := high - (step >> 1)
+	if hi < 0 {
+		hi = 0
+	}
+
+	for lo < hi {
+		mid := lo + (hi-lo+1)/2 // upper mid, to find the largest matching index
+		if arr[mid] <= target {
+			lo = mid
+		} else {
+			hi = mid - 1
 		}
-		for n.hasLeft && n.hasRight && n.less(n.rightID, n.leftID) {
-			n.readRight()
+	}
+
+	if arr[lo] <= target {
+		return lo
+	}
+	return -1
+}
+
+// intersectDesc intersects two batches iterating from the back.
+func (n *nodeAnd) intersectDesc() []uint32 {
+	left, right := n.leftBatch, n.rightBatch
+	if len(left) == 0 || len(right) == 0 {
+		return nil
+	}
+
+	n.outBatch = n.outBatch[:0]
+
+	i, j := len(left)-1, len(right)-1
+	for i >= 0 && j >= 0 {
+		if left[i] == right[j] {
+			n.outBatch = append(n.outBatch, left[i])
+			i--
+			j--
+		} else if left[i] > right[j] {
+			i = gallopSearchDesc(left, i-1, right[j])
+		} else {
+			j = gallopSearchDesc(right, j-1, left[i])
 		}
 	}
-	if !n.hasLeft || !n.hasRight {
-		return 0, false
+
+	// trim the batches; any leftover remains on one side only, the other side will be empty
+	n.leftBatch = left[:i+1]
+	n.rightBatch = right[:j+1]
+
+	return n.outBatch
+}
+
+// TODO limit is ignored
+func (n *nodeAnd) Next(limit uint32) []uint32 {
+	for {
+		if len(n.leftBatch) == 0 {
+			n.leftBatch = n.left.Next(math.MaxUint32)
+		}
+		if len(n.rightBatch) == 0 {
+			n.rightBatch = n.right.Next(math.MaxUint32)
+		}
+
+		if len(n.leftBatch) == 0 || len(n.rightBatch) == 0 {
+			return nil
+		}
+
+		result := n.intersectFn()
+
+		if len(result) > 0 {
+			return result
+		}
 	}
-	cur := n.leftID
-	n.readLeft()
-	n.readRight()
-	return cur, true
 }
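The intersection above is a classic galloping (exponential search) merge: when the heads differ, the lagging side jumps ahead by doubling steps and then binary-searches the last gap, so a sparse side is skipped in O(log n) probes instead of being scanned linearly. A condensed, self-contained version of the ascending case — simplified relative to gallopSearchAsc/intersectAsc above, but the same idea:

```go
package main

import "fmt"

// gallop returns the smallest k >= low with arr[k] >= target (or len(arr)):
// double the step until we overshoot, then binary-search the final gap.
func gallop(arr []uint32, low int, target uint32) int {
	step := 1
	for low+step < len(arr) && arr[low+step] < target {
		step <<= 1
	}
	lo, hi := low, low+step
	if hi > len(arr) {
		hi = len(arr)
	}
	for lo < hi {
		mid := (lo + hi) / 2
		if arr[mid] >= target {
			hi = mid
		} else {
			lo = mid + 1
		}
	}
	return lo
}

// intersect merges two ascending batches, galloping past the gaps.
func intersect(a, b []uint32) []uint32 {
	var out []uint32
	i, j := 0, 0
	for i < len(a) && j < len(b) {
		switch {
		case a[i] == b[j]:
			out = append(out, a[i])
			i++
			j++
		case a[i] < b[j]:
			i = gallop(a, i+1, b[j])
		default:
			j = gallop(b, j+1, a[i])
		}
	}
	return out
}

func main() {
	a := []uint32{2, 4, 6, 8, 100, 101, 102}
	b := []uint32{1, 100, 102, 500}
	fmt.Println(intersect(a, b)) // [100 102]; the run 4..8 is skipped by doubling steps
}
```

The descending variant in the patch is the mirror image: it walks both batches from the back, so the output keeps the reverse order of its inputs.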
diff --git a/node/node_nand.go b/node/node_nand.go
index 7f7b7a89..c6ac8e39 100644
--- a/node/node_nand.go
+++ b/node/node_nand.go
@@ -5,50 +5,42 @@ import "fmt"
 
 type nodeNAnd struct {
 	less LessFn
 
-	neg Node
+	neg    singleIter
 	negID  uint32
 	hasNeg bool
 
-	reg Node
+	reg    singleIter
 	regID  uint32
 	hasReg bool
 }
 
 func (n *nodeNAnd) String() string {
-	return fmt.Sprintf("(%s NAND %s)", n.neg.String(), n.reg.String())
+	return fmt.Sprintf("(%s NAND %s)", n.neg.node.String(), n.reg.node.String())
 }
 
 func NewNAnd(negative, regular Node, reverse bool) *nodeNAnd {
 	node := &nodeNAnd{
 		less: GetLessFn(reverse),
-
-		neg: negative,
-		reg: regular,
+		neg:  singleIter{node: negative},
+		reg:  singleIter{node: regular},
 	}
-	node.readNeg()
-	node.readReg()
+	node.negID, node.hasNeg = node.neg.next()
+	node.regID, node.hasReg = node.reg.next()
 	return node
 }
 
-func (n *nodeNAnd) readNeg() {
-	n.negID, n.hasNeg = n.neg.Next()
-}
-
-func (n *nodeNAnd) readReg() {
-	n.regID, n.hasReg = n.reg.Next()
-}
-
-func (n *nodeNAnd) Next() (uint32, bool) {
+func (n *nodeNAnd) Next(limit uint32) []uint32 {
+	// TODO support batching?
 	for n.hasReg {
 		for n.hasNeg && n.less(n.negID, n.regID) {
-			n.readNeg()
+			n.negID, n.hasNeg = n.neg.next()
 		}
-		if !n.hasNeg || n.negID != n.regID { // i.e. n.negID > regID
+		if !n.hasNeg || n.negID != n.regID {
 			cur := n.regID
-			n.readReg()
-			return cur, true
+			n.regID, n.hasReg = n.reg.next()
+			return []uint32{cur}
 		}
-		n.readReg()
+		n.regID, n.hasReg = n.reg.next()
 	}
-	return 0, false
+	return nil
 }
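nodeNAnd still works element-wise through singleIter and returns one-element batches for now (the TODO above). Its contract is ordered subtraction: emit every id of the regular stream that the negative stream does not contain. A minimal model of that semantics, matching the example data from the deleted TestNodeLazyNAnd — subtract is a hypothetical helper, not patch code:

```go
package main

import "fmt"

// subtract emits each id from reg that is absent from neg,
// advancing both ascending streams in lockstep like nodeNAnd.Next.
func subtract(neg, reg []uint32) []uint32 {
	var out []uint32
	i := 0
	for _, r := range reg {
		for i < len(neg) && neg[i] < r {
			i++
		}
		if i == len(neg) || neg[i] != r {
			out = append(out, r)
		}
	}
	return out
}

func main() {
	// same data as the removed TestNodeLazyNAnd: only 4 survives
	fmt.Println(subtract([]uint32{1, 2, 5, 6, 7, 8}, []uint32{2, 4})) // [4]
}
```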
diff --git a/node/node_not.go b/node/node_not.go
index 098f7f1e..9e4ebde2 100644
--- a/node/node_not.go
+++ b/node/node_not.go
@@ -7,7 +7,7 @@ type nodeNot struct {
 }
 
 func (n *nodeNot) String() string {
-	return fmt.Sprintf("(NOT %s)", n.neg.String())
+	return fmt.Sprintf("(NOT %s)", n.neg.node.String())
 }
 
 func NewNot(child Node, minVal, maxVal uint32, reverse bool) *nodeNot {
diff --git a/node/node_or.go b/node/node_or.go
index 1cfec7ac..ac76dac1 100644
--- a/node/node_or.go
+++ b/node/node_or.go
@@ -5,8 +5,8 @@ import "fmt"
 type nodeOr struct {
 	less LessFn
 
-	left Node
-	right Node
+	left  singleIter
+	right singleIter
 
 	leftID  uint32
 	hasLeft bool
@@ -15,51 +15,40 @@
 	rightID  uint32
 	hasRight bool
 }
 
 func (n *nodeOr) String() string {
-	return fmt.Sprintf("(%s OR %s)", n.left.String(), n.right.String())
+	return fmt.Sprintf("(%s OR %s)", n.left.node.String(), n.right.node.String())
 }
 
 func NewOr(left, right Node, reverse bool) *nodeOr {
 	n := &nodeOr{
-		less: GetLessFn(reverse),
-
-		left: left,
-		right: right,
+		less:  GetLessFn(reverse),
+		left:  singleIter{node: left},
+		right: singleIter{node: right},
 	}
 
-	n.readLeft()
-	n.readRight()
+	n.leftID, n.hasLeft = n.left.next()
+	n.rightID, n.hasRight = n.right.next()
 
 	return n
 }
 
-func (n *nodeOr) readLeft() {
-	n.leftID, n.hasLeft = n.left.Next()
-}
-
-func (n *nodeOr) readRight() {
-	n.rightID, n.hasRight = n.right.Next()
-}
-
-func (n *nodeOr) Next() (uint32, bool) {
+func (n *nodeOr) Next(limit uint32) []uint32 {
+	// TODO support batching here
 	if !n.hasLeft && !n.hasRight {
-		return 0, false
+		return nil
 	}
 
+	var cur uint32
 	if n.hasLeft && (!n.hasRight || n.less(n.leftID, n.rightID)) {
-		cur := n.leftID
-		n.readLeft()
-		return cur, true
+		cur = n.leftID
+		n.leftID, n.hasLeft = n.left.next()
+	} else if n.hasRight && (!n.hasLeft || n.less(n.rightID, n.leftID)) {
+		cur = n.rightID
+		n.rightID, n.hasRight = n.right.next()
+	} else {
+		cur = n.leftID
+		n.leftID, n.hasLeft = n.left.next()
+		n.rightID, n.hasRight = n.right.next()
 	}
 
-	if n.hasRight && (!n.hasLeft || n.less(n.rightID, n.leftID)) {
-		cur := n.rightID
-		n.readRight()
-		return cur, true
-	}
-
-	cur := n.leftID
-	n.readLeft()
-	n.readRight()
-
-	return cur, true
+	return []uint32{cur}
 }
 
 type nodeOrAgg struct {
diff --git a/node/node_range.go b/node/node_range.go
index ed5d0486..74ceb372 100644
--- a/node/node_range.go
+++ b/node/node_range.go
@@ -19,19 +19,19 @@ func NewRange(minVal, maxVal uint32, reverse bool) *nodeRange {
 		minVal, maxVal = maxVal, minVal
 	}
 	return &nodeRange{
-		less: GetLessFn(reverse),
-
+		less:   GetLessFn(reverse),
 		cur:    int(minVal),
 		maxVal: maxVal,
 		step:   step,
 	}
 }
 
-func (n *nodeRange) Next() (uint32, bool) {
+func (n *nodeRange) Next(limit uint32) []uint32 {
+	// TODO support batching
 	if n.less(n.maxVal, uint32(n.cur)) {
-		return 0, false
+		return nil
 	}
 	cur := uint32(n.cur)
 	n.cur += n.step
-	return cur, true
+	return []uint32{cur}
 }
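nodeOr likewise stays element-wise: each Next call emits a one-element batch holding the smaller head, and when both heads are equal it advances both children so duplicates collapse. The same merge written over plain slices — merge is an illustrative helper, not patch code:

```go
package main

import "fmt"

// merge unions two ascending streams, emitting equal heads once,
// the same decision structure as nodeOr.Next above.
func merge(a, b []uint32) []uint32 {
	var out []uint32
	i, j := 0, 0
	for i < len(a) || j < len(b) {
		switch {
		case j == len(b) || (i < len(a) && a[i] < b[j]):
			out = append(out, a[i])
			i++
		case i == len(a) || b[j] < a[i]:
			out = append(out, b[j])
			j++
		default: // equal heads: emit once, advance both
			out = append(out, a[i])
			i++
			j++
		}
	}
	return out
}

func main() {
	fmt.Println(merge([]uint32{1, 3, 5}, []uint32{1, 2, 5, 8})) // [1 2 3 5 8]
}
```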
diff --git a/node/node_static.go b/node/node_static.go
index d40135d2..65262176 100644
--- a/node/node_static.go
+++ b/node/node_static.go
@@ -1,55 +1,31 @@
 package node
 
-type staticCursor struct {
-	ptr  int
-	data []uint32
+type staticNode struct {
+	data     []uint32
+	returned bool
 }
 
-type staticAsc struct {
-	staticCursor
-}
-
-type staticDesc struct {
-	staticCursor
-}
-
-func (n *staticCursor) String() string {
+func (n *staticNode) String() string {
 	return "STATIC"
 }
 
 func NewStatic(data []uint32, reverse bool) Node {
-	if reverse {
-		return &staticDesc{staticCursor: staticCursor{
-			ptr:  len(data) - 1,
-			data: data,
-		}}
-	}
-
-	return &staticAsc{staticCursor: staticCursor{
-		ptr:  0,
-		data: data,
-	}}
-}
-
-func (n *staticAsc) Next() (uint32, bool) {
-	if n.ptr >= len(n.data) {
-		return 0, false
+	_ = reverse
+	return &staticNode{
+		data:     data,
+		returned: false,
 	}
-	cur := n.data[n.ptr]
-	n.ptr++
-	return cur, true
 }
 
-func (n *staticDesc) Next() (uint32, bool) {
-	if n.ptr < 0 {
-		return 0, false
+func (n *staticNode) Next(limit uint32) []uint32 {
+	if n.returned {
+		return nil
 	}
-	cur := n.data[n.ptr]
-	n.ptr--
-	return cur, true
+	n.returned = true
+	return n.data
 }
 
-// MakeStaticNodes is currently used only for tests
+// MakeStaticNodes is currently used only for tests
 func MakeStaticNodes(data [][]uint32) []Node {
 	nodes := make([]Node, len(data))
 	for i, values := range data {
diff --git a/node/node_test.go b/node/node_test.go
index 6da2d8a1..c2e4ce12 100644
--- a/node/node_test.go
+++ b/node/node_test.go
@@ -1,17 +1,17 @@
 package node
 
 import (
+	"math"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
 )
 
 func readAllInto(node Node, ids []uint32) []uint32 {
-	id, has := node.Next()
-	for has {
-		ids = append(ids, id)
-		id, has = node.Next()
+	batch := node.Next(math.MaxUint32)
+	for batch != nil {
+		ids = append(ids, batch...)
+		batch = node.Next(math.MaxUint32)
 	}
 	return ids
 }
@@ -20,12 +20,6 @@ func readAll(node Node) []uint32 {
 	return readAllInto(node, nil)
 }
 
-func getRemainingSlice(t *testing.T, node Node) []uint32 {
-	static, is := node.(*staticAsc)
-	require.True(t, is, "node is not static")
-	return static.data[static.ptr:]
-}
-
 var (
 	data = [][]uint32{
 		{1, 5, 6, 7, 8, 9, 13},
@@ -57,33 +51,11 @@ func TestNodeNot(t *testing.T) {
 	assert.Equal(t, expect, readAll(nand))
 }
 
-// test, that if one source is ended, node doesn't read the second one to the end
-func TestNodeLazyAnd(t *testing.T) {
-	left := []uint32{1, 2}
-	right := []uint32{1, 2, 3, 4, 5, 6}
-	and := NewAnd(NewStatic(left, false), NewStatic(right, false), false)
-	assert.Equal(t, []uint32{1, 2}, readAll(and))
-	assert.Equal(t, []uint32{4, 5, 6}, getRemainingSlice(t, and.right))
-	assert.Equal(t, []uint32(nil), readAll(and))
-	assert.Equal(t, []uint32{4, 5, 6}, getRemainingSlice(t, and.right))
-}
-
-// test, that if reg source is ended, node doesn't read the neg to the end
-func TestNodeLazyNAnd(t *testing.T) {
-	left := []uint32{1, 2, 5, 6, 7, 8}
-	right := []uint32{2, 4}
-	nand := NewNAnd(NewStatic(left, false), NewStatic(right, false), false)
-	assert.Equal(t, []uint32{4}, readAll(nand))
-	assert.Equal(t, []uint32{6, 7, 8}, getRemainingSlice(t, nand.neg))
-	assert.Equal(t, []uint32(nil), readAll(nand))
-	assert.Equal(t, []uint32{6, 7, 8}, getRemainingSlice(t, nand.neg))
-}
-
 func isEmptyNode(node any) bool {
 	if sw, is := node.(*sourcedNodeWrapper); is {
 		node = sw.node
 	}
-	if ns, is := node.(*staticAsc); is {
+	if ns, is := node.(*staticNode); is {
 		return len(ns.data) == 0
 	}
 	return false
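The lazy-consumption tests go away because they asserted single-element pulls: getRemainingSlice inspected the cursor of the old staticAsc, and with staticNode handing out its whole slice in one batch there is no partially consumed cursor left to observe. The new one-shot behavior in isolation — a hedged sketch; staticNode here simply restates the patch's implementation:

```go
package main

import "fmt"

type staticNode struct {
	data     []uint32
	returned bool
}

func (n *staticNode) Next(limit uint32) []uint32 {
	if n.returned {
		return nil
	}
	n.returned = true
	return n.data
}

func main() {
	n := &staticNode{data: []uint32{1, 2, 3}}
	fmt.Println(n.Next(0)) // [1 2 3] — the whole slice as one batch
	fmt.Println(n.Next(0)) // []      — nil on the second call: exhausted
}
```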
diff --git a/node/sourced_node_wrapper.go b/node/sourced_node_wrapper.go
index 369e2eed..3106e216 100644
--- a/node/sourced_node_wrapper.go
+++ b/node/sourced_node_wrapper.go
@@ -3,6 +3,8 @@ package node
 type sourcedNodeWrapper struct {
 	node   Node
 	source uint32
+	batch  []uint32
+	idx    int
 }
 
 func (*sourcedNodeWrapper) String() string {
@@ -10,8 +12,19 @@
 }
 
 func (w *sourcedNodeWrapper) NextSourced() (uint32, uint32, bool) {
-	id, has := w.node.Next()
-	return id, w.source, has
+	// if the current batch is exhausted, get the next batch
+	for w.idx >= len(w.batch) {
+		// TODO support batching
+		w.batch = w.node.Next(1)
+		w.idx = 0
+		if w.batch == nil {
+			return 0, w.source, false
+		}
+	}
+
+	id := w.batch[w.idx]
+	w.idx++
+	return id, w.source, true
 }
 
 func NewSourcedNodeWrapper(d Node, source int) Sourced {
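Taken together, the wrapper keeps the Sourced contract unchanged for aggregations while its child already returns batches. A self-contained usage sketch — sliceNode is a hypothetical stand-in for any batched node:

```go
package main

import "fmt"

type Node interface {
	Next(limit uint32) []uint32
}

type sliceNode struct {
	data []uint32
	done bool
}

func (n *sliceNode) Next(limit uint32) []uint32 {
	if n.done {
		return nil
	}
	n.done = true
	return n.data
}

// sourcedWrapper mirrors sourced_node_wrapper.go: it buffers the child's
// batches and replays them one id at a time, tagging each with a source.
type sourcedWrapper struct {
	node   Node
	source uint32
	batch  []uint32
	idx    int
}

func (w *sourcedWrapper) NextSourced() (uint32, uint32, bool) {
	for w.idx >= len(w.batch) {
		w.batch = w.node.Next(1)
		w.idx = 0
		if w.batch == nil {
			return 0, w.source, false
		}
	}
	id := w.batch[w.idx]
	w.idx++
	return id, w.source, true
}

func main() {
	w := &sourcedWrapper{node: &sliceNode{data: []uint32{7, 9}}, source: 42}
	for id, src, ok := w.NextSourced(); ok; id, src, ok = w.NextSourced() {
		fmt.Println(id, src) // 7 42, then 9 42
	}
}
```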