diff --git a/rueidiscompat/adapter.go b/rueidiscompat/adapter.go index 3b3cff55..c4be320d 100644 --- a/rueidiscompat/adapter.go +++ b/rueidiscompat/adapter.go @@ -53,7 +53,7 @@ var Nil = rueidis.Nil type Cmdable interface { CoreCmdable - Cache(ttl time.Duration) CacheCompat + Cache(ttl time.Duration) CacheCmdable Subscribe(ctx context.Context, channels ...string) PubSub PSubscribe(ctx context.Context, patterns ...string) PubSub @@ -625,6 +625,93 @@ type JSONCmdable interface { JSONType(ctx context.Context, key, path string) *JSONSliceCmd } +type CacheCmdable interface { + BFExists(ctx context.Context, key string, element interface{}) *BoolCmd + BFInfo(ctx context.Context, key string) *BFInfoCmd + BFInfoArg(ctx context.Context, key string, option string) *BFInfoCmd + BFInfoCapacity(ctx context.Context, key string) *BFInfoCmd + BFInfoExpansion(ctx context.Context, key string) *BFInfoCmd + BFInfoFilters(ctx context.Context, key string) *BFInfoCmd + BFInfoItems(ctx context.Context, key string) *BFInfoCmd + BFInfoSize(ctx context.Context, key string) *BFInfoCmd + BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd + BitFieldRO(ctx context.Context, key string, args ...any) *IntSliceCmd + BitPos(ctx context.Context, key string, bit int64, pos ...int64) *IntCmd + BitPosSpan(ctx context.Context, key string, bit int64, start int64, end int64, span string) *IntCmd + CFCount(ctx context.Context, key string, element interface{}) *IntCmd + CFExists(ctx context.Context, key string, element interface{}) *BoolCmd + CFInfo(ctx context.Context, key string) *CFInfoCmd + CMSInfo(ctx context.Context, key string) *CMSInfoCmd + CMSQuery(ctx context.Context, key string, elements ...interface{}) *IntSliceCmd + EvalRO(ctx context.Context, script string, keys []string, args ...any) *Cmd + EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...any) *Cmd + FCallRO(ctx context.Context, function string, keys []string, args ...any) *Cmd + GeoDist(ctx context.Context, key string, member1 string, member2 string, unit string) *FloatCmd + GeoHash(ctx context.Context, key string, members ...string) *StringSliceCmd + GeoPos(ctx context.Context, key string, members ...string) *GeoPosCmd + GeoRadius(ctx context.Context, key string, longitude float64, latitude float64, query GeoRadiusQuery) *GeoLocationCmd + GeoRadiusByMember(ctx context.Context, key string, member string, query GeoRadiusQuery) *GeoLocationCmd + GeoSearch(ctx context.Context, key string, q GeoSearchQuery) *StringSliceCmd + Get(ctx context.Context, key string) *StringCmd + GetBit(ctx context.Context, key string, offset int64) *IntCmd + GetRange(ctx context.Context, key string, start int64, end int64) *StringCmd + HExists(ctx context.Context, key string, field string) *BoolCmd + HGet(ctx context.Context, key string, field string) *StringCmd + HGetAll(ctx context.Context, key string) *StringStringMapCmd + HKeys(ctx context.Context, key string) *StringSliceCmd + HLen(ctx context.Context, key string) *IntCmd + HMGet(ctx context.Context, key string, fields ...string) *SliceCmd + HStrLen(ctx context.Context, key string, field string) *IntCmd + HVals(ctx context.Context, key string) *StringSliceCmd + JSONArrIndex(ctx context.Context, key string, path string, value ...interface{}) *IntSliceCmd + JSONArrLen(ctx context.Context, key string, path string) *IntSliceCmd + JSONGet(ctx context.Context, key string, paths ...string) *JSONCmd + JSONMGet(ctx context.Context, path string, keys ...string) *JSONSliceCmd + JSONObjKeys(ctx context.Context, key string, path string) *SliceCmd + JSONObjLen(ctx context.Context, key string, path string) *IntPointerSliceCmd + JSONStrLen(ctx context.Context, key string, path string) *IntPointerSliceCmd + JSONType(ctx context.Context, key string, path string) *JSONSliceCmd + LIndex(ctx context.Context, key string, index int64) *StringCmd + LLen(ctx context.Context, key string) *IntCmd + LPos(ctx context.Context, key string, element string, a LPosArgs) *IntCmd + LRange(ctx context.Context, key string, start int64, stop int64) *StringSliceCmd + MGet(ctx context.Context, keys ...string) *SliceCmd + MGetCache(ctx context.Context, keys ...string) *SliceCmd + PTTL(ctx context.Context, key string) *DurationCmd + SCard(ctx context.Context, key string) *IntCmd + SIsMember(ctx context.Context, key string, member any) *BoolCmd + SMIsMember(ctx context.Context, key string, members ...any) *BoolSliceCmd + SMembers(ctx context.Context, key string) *StringSliceCmd + SortRO(ctx context.Context, key string, sort Sort) *StringSliceCmd + StrLen(ctx context.Context, key string) *IntCmd + TTL(ctx context.Context, key string) *DurationCmd + TopKInfo(ctx context.Context, key string) *TopKInfoCmd + TopKList(ctx context.Context, key string) *StringSliceCmd + TopKQuery(ctx context.Context, key string, elements ...interface{}) *BoolSliceCmd + Type(ctx context.Context, key string) *StatusCmd + ZCard(ctx context.Context, key string) *IntCmd + ZCount(ctx context.Context, key string, min string, max string) *IntCmd + ZLexCount(ctx context.Context, key string, min string, max string) *IntCmd + ZMScore(ctx context.Context, key string, members ...string) *FloatSliceCmd + ZRangeArgs(ctx context.Context, z ZRangeArgs) *StringSliceCmd + ZRangeArgsWithScores(ctx context.Context, z ZRangeArgs) *ZSliceCmd + ZRangeByLex(ctx context.Context, key string, opt ZRangeBy) *StringSliceCmd + ZRangeByScore(ctx context.Context, key string, opt ZRangeBy) *StringSliceCmd + ZRangeByScoreWithScores(ctx context.Context, key string, opt ZRangeBy) *ZSliceCmd + ZRangeWithScores(ctx context.Context, key string, start int64, stop int64) *ZSliceCmd + ZRank(ctx context.Context, key string, member string) *IntCmd + ZRankWithScore(ctx context.Context, key string, member string) *RankWithScoreCmd + ZRevRange(ctx context.Context, key string, start int64, stop int64) *StringSliceCmd + ZRevRangeByLex(ctx context.Context, key string, opt ZRangeBy) *StringSliceCmd + ZRevRangeByScore(ctx context.Context, key string, opt ZRangeBy) *StringSliceCmd + ZRevRangeByScoreWithScores(ctx context.Context, key string, opt ZRangeBy) *ZSliceCmd + ZRevRangeWithScores(ctx context.Context, key string, start int64, stop int64) *ZSliceCmd + ZRevRank(ctx context.Context, key string, member string) *IntCmd + ZRevRankWithScore(ctx context.Context, key string, member string) *RankWithScoreCmd + ZScore(ctx context.Context, key string, member string) *FloatCmd + zRangeArgs(withScores bool, z ZRangeArgs) rueidis.Cacheable +} + var _ Cmdable = (*Compat)(nil) type Compat struct { @@ -662,8 +749,8 @@ func NewAdapter(client rueidis.Client, options ...AdapterOption) Cmdable { return c } -func (c *Compat) Cache(ttl time.Duration) CacheCompat { - return CacheCompat{client: c.client, ttl: ttl} +func (c *Compat) Cache(ttl time.Duration) CacheCmdable { + return &CacheCompat{client: c.client, ttl: ttl} } func (c *Compat) Command(ctx context.Context) *CommandsInfoCmd { @@ -1604,15 +1691,16 @@ func (c *Compat) HGetEXWithArgs(ctx context.Context, key string, options *HGetEX } var cmd rueidis.Completed - if options.ExpirationType == HGetEXExpirationEX { + switch options.ExpirationType { + case HGetEXExpirationEX: cmd = c.client.B().Hgetex().Key(key).Ex(options.ExpirationVal).Fields().Numfields(int64(len(fields))).Field(fields...).Build() - } else if options.ExpirationType == HGetEXExpirationPX { + case HGetEXExpirationPX: cmd = c.client.B().Hgetex().Key(key).Px(options.ExpirationVal).Fields().Numfields(int64(len(fields))).Field(fields...).Build() - } else if options.ExpirationType == HGetEXExpirationEXAT { + case HGetEXExpirationEXAT: cmd = c.client.B().Hgetex().Key(key).Exat(options.ExpirationVal).Fields().Numfields(int64(len(fields))).Field(fields...).Build() - } else if options.ExpirationType == HGetEXExpirationPXAT { + case HGetEXExpirationPXAT: cmd = c.client.B().Hgetex().Key(key).Pxat(options.ExpirationVal).Fields().Numfields(int64(len(fields))).Field(fields...).Build() - } else if options.ExpirationType == HGetEXExpirationPERSIST { + case HGetEXExpirationPERSIST: cmd = c.client.B().Hgetex().Key(key).Persist().Fields().Numfields(int64(len(fields))).Field(fields...).Build() } resp := c.client.Do(ctx, cmd) @@ -1637,28 +1725,31 @@ func (c *Compat) HSetEXWithArgs(ctx context.Context, key string, options *HSetEX } var partial cmds.HsetexFieldValue - if options.Condition == HSetEXFNX { - if options.ExpirationType == HSetEXExpirationEX { + switch options.Condition { + case HSetEXFNX: + switch options.ExpirationType { + case HSetEXExpirationEX: partial = c.client.B().Hsetex().Key(key).Fnx().Ex(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationPX { + case HSetEXExpirationPX: partial = c.client.B().Hsetex().Key(key).Fnx().Px(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationEXAT { + case HSetEXExpirationEXAT: partial = c.client.B().Hsetex().Key(key).Fnx().Exat(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationPXAT { + case HSetEXExpirationPXAT: partial = c.client.B().Hsetex().Key(key).Fnx().Pxat(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationKEEPTTL { + case HSetEXExpirationKEEPTTL: partial = c.client.B().Hsetex().Key(key).Fnx().Keepttl().Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() } - } else if options.Condition == HSetEXFXX { - if options.ExpirationType == HSetEXExpirationEX { + case HSetEXFXX: + switch options.ExpirationType { + case HSetEXExpirationEX: partial = c.client.B().Hsetex().Key(key).Fxx().Ex(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationPX { + case HSetEXExpirationPX: partial = c.client.B().Hsetex().Key(key).Fxx().Px(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationEXAT { + case HSetEXExpirationEXAT: partial = c.client.B().Hsetex().Key(key).Fxx().Exat(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationPXAT { + case HSetEXExpirationPXAT: partial = c.client.B().Hsetex().Key(key).Fxx().Pxat(options.ExpirationVal).Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() - } else if options.ExpirationType == HSetEXExpirationKEEPTTL { + case HSetEXExpirationKEEPTTL: partial = c.client.B().Hsetex().Key(key).Fxx().Keepttl().Fields().Numfields(int64(len(fieldsAndValues) / 2)).FieldValue() } } @@ -6044,6 +6135,28 @@ func (c CacheCompat) MGet(ctx context.Context, keys ...string) *SliceCmd { return newSliceCmd(resp, false, keys...) } +// MGetCache will internally call rueidis.MGetCache which only sends the keys to server which are not in cache. +func (c CacheCompat) MGetCache(ctx context.Context, keys ...string) *SliceCmd { + mgetResult, err := rueidis.MGetCache(c.client, ctx, c.ttl, keys) + cmd := &SliceCmd{keys: keys} + vals := make([]any, len(keys)) + for i, key := range keys { + msg := mgetResult[key] + if msg.IsNil() { + vals[i] = nil + } else { + if s, err := msg.ToString(); err == nil { + vals[i] = s + } + } + } + cmd.SetVal(vals) + if err != nil { + cmd.SetErr(err) + } + return cmd +} + func (c CacheCompat) GetBit(ctx context.Context, key string, offset int64) *IntCmd { cmd := c.client.B().Getbit().Key(key).Offset(offset).Cache() resp := c.client.DoCache(ctx, cmd, c.ttl) diff --git a/rueidiscompat/adapter_test.go b/rueidiscompat/adapter_test.go index 74d3d7a5..ab65f162 100644 --- a/rueidiscompat/adapter_test.go +++ b/rueidiscompat/adapter_test.go @@ -7928,6 +7928,28 @@ func testAdapterCache(resp3 bool) { Expect(mGet.Val()).To(Equal([]any{"hello1", nil, "hello2"})) }) + It("should MGetCache", func() { + mGetCache := adapter.Cache(time.Hour).MGetCache(ctx, "_", "key2") + Expect(mGetCache.Err()).NotTo(HaveOccurred()) + Expect(mGetCache.Val()).To(Equal([]any{nil, nil})) + + set := adapter.Set(ctx, "key1", "hello1", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + set = adapter.Set(ctx, "key2", "hello2", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + mGetCache = adapter.Cache(time.Hour).MGetCache(ctx, "key1", "key2", "_") + Expect(mGetCache.Err()).NotTo(HaveOccurred()) + Expect(mGetCache.Val()).To(Equal([]any{"hello1", "hello2", nil})) + + mGetCache = adapter.Cache(time.Hour).MGetCache(ctx, "key1", "_", "key2") + Expect(mGetCache.Err()).NotTo(HaveOccurred()) + Expect(mGetCache.Val()).To(Equal([]any{"hello1", nil, "hello2"})) + }) + It("should GetBit", func() { setBit := adapter.SetBit(ctx, "key", 7, 1) Expect(setBit.Err()).NotTo(HaveOccurred())