84 changes: 84 additions & 0 deletions group.go
@@ -42,6 +42,8 @@ type Group interface {
Remove(context.Context, string) error
UsedBytes() (int64, int64)
Name() string
RemoveKeys(ctx context.Context, keys ...string) error
GroupStats() GroupStats
}

// A Getter loads data for a key.
@@ -108,6 +110,11 @@ func (g *group) Name() string {
return g.name
}

// GroupStats returns the stats for this group.
func (g *group) GroupStats() GroupStats {
return g.Stats
}

// UsedBytes returns the total number of bytes used by the main and hot caches
func (g *group) UsedBytes() (mainCache int64, hotCache int64) {
return g.mainCache.Bytes(), g.hotCache.Bytes()
@@ -443,6 +450,79 @@ func (g *group) LocalRemove(key string) {
})
}

func (g *group) RemoveKeys(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}

g.Stats.RemoveKeysRequests.Add(1)
g.Stats.RemovedKeys.Add(int64(len(keys)))

keysByOwner := make(map[peer.Client][]string)
var localKeys []string

for _, key := range keys {
owner, isRemote := g.instance.PickPeer(key)
if isRemote {
keysByOwner[owner] = append(keysByOwner[owner], key)
} else {
localKeys = append(localKeys, key)
}
}

for _, key := range localKeys {
g.LocalRemove(key)
}

multiErr := &MultiError{}
errCh := make(chan error)

// Send removeKeys requests to owners (parallel)
var wg sync.WaitGroup
for owner, ownerKeys := range keysByOwner {
wg.Add(1)
go func(p peer.Client, k []string) {
errCh <- p.RemoveKeys(ctx, &pb.RemoveKeysRequest{
Group: &g.name,
Keys: k,
})
wg.Done()
}(owner, ownerKeys)
}

allPeers := g.instance.getAllPeers()
for _, p := range allPeers {
if p.PeerInfo().IsSelf {
continue
}
if _, isOwner := keysByOwner[p]; isOwner {
continue
}

wg.Add(1)
go func(peer peer.Client) {
errCh <- peer.RemoveKeys(ctx, &pb.RemoveKeysRequest{
Group: &g.name,
Keys: keys,
})
wg.Done()
}(p)
}

go func() {
wg.Wait()
close(errCh)
}()

for err := range errCh {
if err != nil {
multiErr.Add(err)
}
}

return multiErr.NilOrError()
}

func (g *group) populateCache(key string, value transport.ByteView, cache Cache) {
if g.maxCacheBytes <= 0 {
return
@@ -524,6 +604,8 @@ func (g *group) registerInstruments(meter otelmetric.Meter) error {
o.ObserveInt64(instruments.LocalLoadsCounter(), g.Stats.LocalLoads.Get(), observeOptions...)
o.ObserveInt64(instruments.LocalLoadErrsCounter(), g.Stats.LocalLoadErrs.Get(), observeOptions...)
o.ObserveInt64(instruments.GetFromPeersLatencyMaxGauge(), g.Stats.GetFromPeersLatencyLower.Get(), observeOptions...)
o.ObserveInt64(instruments.RemoveKeysRequestsCounter(), g.Stats.RemoveKeysRequests.Get(), observeOptions...)
o.ObserveInt64(instruments.RemovedKeysCounter(), g.Stats.RemovedKeys.Get(), observeOptions...)

return nil
},
@@ -536,6 +618,8 @@ func (g *group) registerInstruments(meter otelmetric.Meter) error {
instruments.LocalLoadsCounter(),
instruments.LocalLoadErrsCounter(),
instruments.GetFromPeersLatencyMaxGauge(),
instruments.RemoveKeysRequestsCounter(),
instruments.RemovedKeysCounter(),
)

return err
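Note: RemoveKeys above collects per-peer failures through a MultiError value with Add and NilOrError methods. That type is not part of this diff and its real definition may differ; the following is a minimal, hypothetical sketch of the shape the code assumes. No locking is needed because Add is only called from the single goroutine draining errCh.

package groupcache

import "errors"

// MultiError (sketch) accumulates errors from the per-peer RemoveKeys calls.
type MultiError struct {
	errs []error
}

// Add records an error; nil errors are ignored.
func (m *MultiError) Add(err error) {
	if err == nil {
		return
	}
	m.errs = append(m.errs, err)
}

// NilOrError returns nil when nothing was recorded, otherwise a single
// error wrapping every recorded error (errors.Join, Go 1.20+).
func (m *MultiError) NilOrError() error {
	if len(m.errs) == 0 {
		return nil
	}
	return errors.Join(m.errs...)
}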
4 changes: 3 additions & 1 deletion instance_test.go
@@ -523,11 +523,13 @@ func TestNewGroupRegistersMetricsWithMeterProvider(t *testing.T) {
"groupcache.group.loads.deduped",
"groupcache.group.local.loads",
"groupcache.group.local.load_errors",
"groupcache.group.remove_keys.requests",
"groupcache.group.removed_keys",
}
assert.Equal(t, expectedCounters, recMeter.counterNames)
assert.Equal(t, []string{"groupcache.group.peer.latency_max_ms"}, recMeter.updownNames)
assert.True(t, recMeter.callbackRegistered, "expected callback registration for metrics")
assert.Equal(t, 9, recMeter.instrumentCount)
assert.Equal(t, 11, recMeter.instrumentCount)
}

func TestNewGroupFailsWhenMetricRegistrationFails(t *testing.T) {
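The two metric names added to this test (groupcache.group.remove_keys.requests and groupcache.group.removed_keys) are backed by instruments.RemoveKeysRequestsCounter() and instruments.RemovedKeysCounter(), whose definitions are not shown in this diff. As a rough standalone illustration only, not the instruments package's actual code, observable counters with these names can be created and fed from a callback like this:

package main

import (
	"context"
	"log"

	"go.opentelemetry.io/otel/metric"
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
)

func main() {
	meter := sdkmetric.NewMeterProvider().Meter("groupcache")

	removeReqs, err := meter.Int64ObservableCounter("groupcache.group.remove_keys.requests")
	if err != nil {
		log.Fatal(err)
	}
	removedKeys, err := meter.Int64ObservableCounter("groupcache.group.removed_keys")
	if err != nil {
		log.Fatal(err)
	}

	// In registerInstruments these values come from g.Stats; fixed numbers
	// stand in for them here.
	_, err = meter.RegisterCallback(func(_ context.Context, o metric.Observer) error {
		o.ObserveInt64(removeReqs, 1)
		o.ObserveInt64(removedKeys, 3)
		return nil
	}, removeReqs, removedKeys)
	if err != nil {
		log.Fatal(err)
	}
}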
260 changes: 260 additions & 0 deletions remove_keys_test.go
@@ -0,0 +1,260 @@
/*
Copyright 2024 Groupcache Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package groupcache_test

import (
"context"
"fmt"
"testing"
"time"

"github.com/groupcache/groupcache-go/v3"
"github.com/groupcache/groupcache-go/v3/cluster"
"github.com/groupcache/groupcache-go/v3/transport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRemoveKeys(t *testing.T) {
ctx := context.Background()

err := cluster.Start(ctx, 3, groupcache.Options{})
require.NoError(t, err)
defer func() { _ = cluster.Shutdown(ctx) }()

callCount := make(map[string]int)
getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
callCount[key]++
return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5))
})

// Register the group on ALL daemons (required for broadcast)
group, err := cluster.DaemonAt(0).NewGroup("test-remove-keys", 3000000, getter)
require.NoError(t, err)
for i := 1; i < 3; i++ {
_, err := cluster.DaemonAt(i).NewGroup("test-remove-keys", 3000000, getter)
require.NoError(t, err)
}

keys := []string{"key1", "key2", "key3"}

// First, populate the cache by getting each key
for _, key := range keys {
var value string
err := group.Get(ctx, key, transport.StringSink(&value))
require.NoError(t, err)
assert.Equal(t, fmt.Sprintf("value-%s", key), value)
}

// Verify getter was called for each key
for _, key := range keys {
assert.Equal(t, 1, callCount[key], "getter should be called once for %s", key)
}

// Now remove all keys using variadic signature
err = group.RemoveKeys(ctx, "key1", "key2", "key3")
require.NoError(t, err)

// Fetch again - getter should be called again since keys were removed
for _, key := range keys {
var value string
err := group.Get(ctx, key, transport.StringSink(&value))
require.NoError(t, err)
}

// Verify getter was called again for each key
for _, key := range keys {
assert.Equal(t, 2, callCount[key], "getter should be called twice for %s after removal", key)
}
}

func TestRemoveKeysEmpty(t *testing.T) {
ctx := context.Background()

err := cluster.Start(ctx, 2, groupcache.Options{})
require.NoError(t, err)
defer func() { _ = cluster.Shutdown(ctx) }()

getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
return dest.SetString("value", time.Now().Add(time.Minute))
})

// Register the group on ALL daemons
group, err := cluster.DaemonAt(0).NewGroup("test-remove-empty", 3000000, getter)
require.NoError(t, err)
_, err = cluster.DaemonAt(1).NewGroup("test-remove-empty", 3000000, getter)
require.NoError(t, err)

// Test RemoveKeys with no keys - should not error
err = group.RemoveKeys(ctx)
require.NoError(t, err)
}

func TestRemoveKeysWithSlice(t *testing.T) {
ctx := context.Background()

err := cluster.Start(ctx, 2, groupcache.Options{})
require.NoError(t, err)
defer func() { _ = cluster.Shutdown(ctx) }()

getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5))
})

// Register the group on ALL daemons
group, err := cluster.DaemonAt(0).NewGroup("test-remove-slice", 3000000, getter)
require.NoError(t, err)
_, err = cluster.DaemonAt(1).NewGroup("test-remove-slice", 3000000, getter)
require.NoError(t, err)

keys := []string{"key1", "key2", "key3"}

// Populate cache
for _, key := range keys {
var value string
err := group.Get(ctx, key, transport.StringSink(&value))
require.NoError(t, err)
}

// Test RemoveKeys with slice expansion
err = group.RemoveKeys(ctx, keys...)
require.NoError(t, err)
}

func TestRemoveKeysStats(t *testing.T) {
ctx := context.Background()

err := cluster.Start(ctx, 2, groupcache.Options{})
require.NoError(t, err)
defer func() { _ = cluster.Shutdown(ctx) }()

getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5))
})

// Register the group on ALL daemons
transportGroup, err := cluster.DaemonAt(0).NewGroup("test-remove-stats", 3000000, getter)
require.NoError(t, err)
_, err = cluster.DaemonAt(1).NewGroup("test-remove-stats", 3000000, getter)
require.NoError(t, err)

// Cast to groupcache.Group to access GroupStats()
group, ok := transportGroup.(groupcache.Group)
require.True(t, ok, "expected transportGroup to implement groupcache.Group")

// Capture stats before RemoveKeys
statsBefore := group.GroupStats()
removeKeysRequestsBefore := statsBefore.RemoveKeysRequests.Get()
removedKeysBefore := statsBefore.RemovedKeys.Get()

err = group.RemoveKeys(ctx, "key1", "key2", "key3")
require.NoError(t, err)

// Verify stats were incremented correctly
statsAfter := group.GroupStats()
assert.Equal(t, removeKeysRequestsBefore+1, statsAfter.RemoveKeysRequests.Get(), "RemoveKeysRequests should be incremented by 1")
assert.Equal(t, removedKeysBefore+3, statsAfter.RemovedKeys.Get(), "RemovedKeys should be incremented by 3")
}

func BenchmarkRemoveKeys(b *testing.B) {
ctx := context.Background()

err := cluster.Start(ctx, 3, groupcache.Options{})
if err != nil {
b.Fatal(err)
}
defer func() { _ = cluster.Shutdown(ctx) }()

getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5))
})

// Register the group on ALL daemons
group, err := cluster.DaemonAt(0).NewGroup("bench-remove", 3000000, getter)
if err != nil {
b.Fatal(err)
}
for i := 1; i < 3; i++ {
_, err := cluster.DaemonAt(i).NewGroup("bench-remove", 3000000, getter)
if err != nil {
b.Fatal(err)
}
}

// Prepare keys
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("key-%d", i)
}

// Populate cache first
for _, key := range keys {
var value string
_ = group.Get(ctx, key, transport.StringSink(&value))
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = group.RemoveKeys(ctx, keys...)
}
}

func BenchmarkRemoveKeysVsLoop(b *testing.B) {
ctx := context.Background()

err := cluster.Start(ctx, 3, groupcache.Options{})
if err != nil {
b.Fatal(err)
}
defer func() { _ = cluster.Shutdown(ctx) }()

getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5))
})

// Register the group on ALL daemons
group, err := cluster.DaemonAt(0).NewGroup("bench-compare", 3000000, getter)
if err != nil {
b.Fatal(err)
}
for i := 1; i < 3; i++ {
_, err := cluster.DaemonAt(i).NewGroup("bench-compare", 3000000, getter)
if err != nil {
b.Fatal(err)
}
}

// Prepare keys
keys := make([]string, 50)
for i := 0; i < 50; i++ {
keys[i] = fmt.Sprintf("key-%d", i)
}

b.Run("RemoveKeys", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = group.RemoveKeys(ctx, keys...)
}
})

b.Run("LoopRemove", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, key := range keys {
_ = group.Remove(ctx, key)
}
}
})
}
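To round out the tests and benchmarks above, here is a condensed, hypothetical usage sketch built from the same test-cluster helpers; the group name and cache size are arbitrary, and a real deployment would configure its own instance rather than the cluster test package. As in the tests, the group is registered on every daemon so the broadcast in RemoveKeys reaches all peers.

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/groupcache/groupcache-go/v3"
	"github.com/groupcache/groupcache-go/v3/cluster"
	"github.com/groupcache/groupcache-go/v3/transport"
)

func main() {
	ctx := context.Background()

	// Start a small in-process cluster, as the tests above do.
	if err := cluster.Start(ctx, 3, groupcache.Options{}); err != nil {
		log.Fatal(err)
	}
	defer func() { _ = cluster.Shutdown(ctx) }()

	getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error {
		return dest.SetString("value-"+key, time.Now().Add(5*time.Minute))
	})

	// Register the group on every daemon (required for the broadcast).
	group, err := cluster.DaemonAt(0).NewGroup("example-group", 3000000, getter)
	if err != nil {
		log.Fatal(err)
	}
	for i := 1; i < 3; i++ {
		if _, err := cluster.DaemonAt(i).NewGroup("example-group", 3000000, getter); err != nil {
			log.Fatal(err)
		}
	}

	var v string
	if err := group.Get(ctx, "key1", transport.StringSink(&v)); err != nil {
		log.Fatal(err)
	}
	fmt.Println("cached:", v)

	// Invalidate several keys across the cluster in one call instead of
	// looping over Remove.
	if err := group.RemoveKeys(ctx, "key1", "key2", "key3"); err != nil {
		log.Fatal(err)
	}
}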