diff --git a/contrib/github.com/go-redis/redis.v5/redis.go b/contrib/github.com/go-redis/redis.v5/redis.go new file mode 100644 index 0000000..cf3ae4b --- /dev/null +++ b/contrib/github.com/go-redis/redis.v5/redis.go @@ -0,0 +1,44 @@ +package redis + +import ( + "time" + + "github.com/ntindall/speedbump/internal" + redis "gopkg.in/redis.v5" +) + +// Wrapper is a wrapper around *redis.Client that implements the +// internal.RedisClient interface. +type Wrapper struct { + *redis.Client +} + +var _ internal.RedisClient = &Wrapper{} + +func (w *Wrapper) Exists(key string) (exists bool, err error) { + return w.Client.Exists(key).Result() +} + +func (w *Wrapper) Get(key string) (value string, err error) { + return w.Client.Get(key).Result() +} + +func (w *Wrapper) IncrAndExpire(key string, duration time.Duration) error { + return w.Client.Watch(func(rx *redis.Tx) error { + _, err := rx.Pipelined(func(pipe *redis.Pipeline) error { + if err := pipe.Incr(key).Err(); err != nil { + return err + } + + return pipe.Expire(key, duration).Err() + }) + + return err + }) +} + +// NewRedisClient constructs a speedbump.RedisClient from a "gopkg.in/redis.v5" +// redis.Client. +func NewRedisClient(redisClient *redis.Client) internal.RedisClient { + return &Wrapper{redisClient} +} diff --git a/contrib/github.com/gomodule/redigo/redis/redis.go b/contrib/github.com/gomodule/redigo/redis/redis.go new file mode 100644 index 0000000..869fed5 --- /dev/null +++ b/contrib/github.com/gomodule/redigo/redis/redis.go @@ -0,0 +1,41 @@ +package redis + +import ( + "time" + + redis "github.com/gomodule/redigo/redis" + "github.com/ntindall/speedbump/internal" +) + +type redisWrapper struct { + conn redis.Conn +} + +var _ internal.RedisClient = &redisWrapper{} + +func (w *redisWrapper) Exists(key string) (exists bool, err error) { + return redis.Bool(w.conn.Do("EXISTS", key)) +} + +func (w *redisWrapper) Get(key string) (value string, err error) { + return redis.String(w.conn.Do("GET", key)) +} + +func (w *redisWrapper) IncrAndExpire(key string, duration time.Duration) error { + if err := w.conn.Send("MULTI"); err != nil { + return err + } + if err := w.conn.Send("INCR", key); err != nil { + return err + } + if err := w.conn.Send("EXPIRE", key, duration/time.Second); err != nil { + return err + } + _, err := w.conn.Do("EXEC") + return err +} + +// NewRedisClient constructs a internal.RedisClient from a redigo connection. +func NewRedisClient(redisConn redis.Conn) internal.RedisClient { + return &redisWrapper{conn: redisConn} +} diff --git a/internal/internal.go b/internal/internal.go new file mode 100644 index 0000000..5f44546 --- /dev/null +++ b/internal/internal.go @@ -0,0 +1,14 @@ +package internal + +import ( + "time" +) + +// RedisClient is an abstraction over speedbump connection to redis. +// It is exported from internal so that it can only be instructed from +// within the package. +type RedisClient interface { + Get(key string) (value string, err error) + Exists(key string) (exists bool, err error) + IncrAndExpire(key string, duration time.Duration) error +} diff --git a/speedbump.go b/speedbump.go index 5809d3f..578881d 100644 --- a/speedbump.go +++ b/speedbump.go @@ -5,13 +5,17 @@ import ( "strconv" "time" - "gopkg.in/redis.v5" + "github.com/ntindall/speedbump/internal" +) + +var ( + redisNil string = "redis: nil" ) // RateLimiter is a Redis-backed rate limiter. type RateLimiter struct { // redisClient is the client that will be used to talk to the Redis server. - redisClient *redis.Client + redisClient internal.RedisClient // hasher is used to generate keys for each counter and to set their // expiration time. hasher RateHasher @@ -36,7 +40,7 @@ type RateHasher interface { // NewLimiter creates a new instance of a rate limiter. func NewLimiter( - client *redis.Client, + client internal.RedisClient, hasher RateHasher, max int64, ) *RateLimiter { @@ -51,7 +55,7 @@ func NewLimiter( // during the current period. func (r *RateLimiter) Has(id string) (bool, error) { hash := r.hasher.Hash(id) - return r.redisClient.Exists(hash).Result() + return r.redisClient.Exists(hash) } // Attempted returns the number of attempted requests for an id in the current @@ -59,10 +63,10 @@ func (r *RateLimiter) Has(id string) (bool, error) { // interval and only returns the max count after this is reached. func (r *RateLimiter) Attempted(id string) (int64, error) { hash := r.hasher.Hash(id) - val, err := r.redisClient.Get(hash).Result() + val, err := r.redisClient.Get(hash) if err != nil { - if err == redis.Nil { + if err.Error() == redisNil { // Key does not exist. See: http://redis.io/commands/GET return 0, nil } @@ -104,9 +108,9 @@ func (r *RateLimiter) Attempt(id string) (bool, error) { // exist. exists := true - val, err := r.redisClient.Get(hash).Result() + val, err := r.redisClient.Get(hash) if err != nil { - if err == redis.Nil { + if err.Error() == redisNil { // Key does not exist. See: http://redis.io/commands/GET exists = false } else { @@ -132,17 +136,7 @@ func (r *RateLimiter) Attempt(id string) (bool, error) { // // See: http://redis.io/commands/INCR // See: http://redis.io/commands/INCR#pattern-rate-limiter-1 - err = r.redisClient.Watch(func(rx *redis.Tx) error { - _, err := rx.Pipelined(func(pipe *redis.Pipeline) error { - if err := pipe.Incr(hash).Err(); err != nil { - return err - } - - return pipe.Expire(hash, r.hasher.Duration()).Err() - }) - - return err - }) + err = r.redisClient.IncrAndExpire(hash, r.hasher.Duration()) if err != nil { return false, err diff --git a/speedbump_test.go b/speedbump_test.go index de1708e..c597d5d 100644 --- a/speedbump_test.go +++ b/speedbump_test.go @@ -1,4 +1,4 @@ -package speedbump +package speedbump_test import ( "fmt" @@ -7,44 +7,33 @@ import ( "time" "github.com/facebookgo/clock" + "github.com/ntindall/speedbump" + contribredis "github.com/ntindall/speedbump/contrib/github.com/go-redis/redis.v5" + "github.com/ntindall/speedbump/internal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "gopkg.in/redis.v5" + redis "gopkg.in/redis.v5" ) -func createClient() *redis.Client { +func createClient() internal.RedisClient { + addr := "localhost:6379" if os.Getenv("WERCKER_REDIS_HOST") != "" { - return redis.NewClient(&redis.Options{ - Addr: os.Getenv("WERCKER_REDIS_HOST") + ":6379", - Password: "", - DB: 0, - }) + addr = os.Getenv("WERCKER_REDIS_HOST") + ":6379" } - return redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", - DB: 0, - }) -} - -func teardown(t *testing.T, client *redis.Client) { - // Flush Redis. - require.NoError(t, client.FlushAll().Err()) + return contribredis.NewRedisClient( + redis.NewClient(&redis.Options{ + Addr: addr, + Password: "", + DB: 0, + }), + ) } -func TestNewLimiter(t *testing.T) { - client := createClient() - hasher := PerSecondHasher{} - max := int64(10) - actual := NewLimiter(client, hasher, max) +func teardown(t *testing.T, client internal.RedisClient) { - assert.Exactly(t, RateLimiter{ - redisClient: client, - hasher: hasher, - max: max, - }, *actual) + // Flush Redis. + require.NoError(t, client.(*contribredis.Wrapper).FlushAll().Err()) } func ExampleNewLimiter() { @@ -52,10 +41,10 @@ func ExampleNewLimiter() { client := createClient() // Create a new hasher. - hasher := PerSecondHasher{} + hasher := speedbump.PerSecondHasher{} // Create a new limiter that will only allow 10 requests per second. - limiter := NewLimiter(client, hasher, 10) + limiter := speedbump.NewLimiter(client, hasher, 10) fmt.Println(limiter.Attempt("127.0.0.1")) // Output: true @@ -66,7 +55,7 @@ func TestHas(t *testing.T) { client := createClient() defer teardown(t, client) // Create limiter of 5 requests/min. - limiter := NewLimiter(client, PerMinuteHasher{}, 5) + limiter := speedbump.NewLimiter(client, speedbump.PerMinuteHasher{}, 5) // Choose an arbitrary id. testID := "test_id" @@ -106,11 +95,11 @@ func TestAttempt(t *testing.T) { defer teardown(t, client) // Create PerMinuteHasher with mock clock. mock := clock.NewMock() - hasher := PerMinuteHasher{ + hasher := speedbump.PerMinuteHasher{ Clock: mock, } // Create limiter of 5 requests/min. - limiter := NewLimiter(client, hasher, 5) + limiter := speedbump.NewLimiter(client, hasher, 5) // Choose an arbitrary id. testID := "test_id" // Ensure no key exists before first request for testID. @@ -217,7 +206,7 @@ func TestAttempt(t *testing.T) { assert.True(t, ok, "Attempts returned false after waiting for interval") } -func makeNAttempts(t *testing.T, limiter *RateLimiter, id string, n int64) { +func makeNAttempts(t *testing.T, limiter *speedbump.RateLimiter, id string, n int64) { var i int64 for i = 0; i < n; i++ { _, err := limiter.Attempt(id) @@ -231,12 +220,12 @@ func TestAttemptedLeft(t *testing.T) { defer teardown(t, client) // Create PerMinuteHasher with mock clock. mock := clock.NewMock() - hasher := PerMinuteHasher{ + hasher := speedbump.PerMinuteHasher{ Clock: mock, } max := int64(5) // Create limiter of 5 requests/min. - limiter := NewLimiter(client, hasher, max) + limiter := speedbump.NewLimiter(client, hasher, max) // Choose an arbitrary id. testID := "test_id"