1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
|
package store
import (
"time"
"github.com/garyburd/redigo/redis"
"gopkg.in/throttled/throttled.v1"
)
// redisStore implements a Redis-based store. It keeps one counter per key in
// Redis; every operation takes a connection from pool and selects the
// configured database before issuing its commands.
type redisStore struct {
// pool supplies the connection used by each operation; the connection is
// closed again when the operation returns.
pool *redis.Pool
// prefix is prepended to every key before it is sent to Redis. It may be
// an empty string.
prefix string
// db is the database index handed to selectDB at the start of each
// operation.
db int
}
// NewRedisStore returns a throttled.Store backed by Redis. Connections are
// obtained from pool, every key is stored with keyPrefix prepended (keyPrefix
// may be an empty string), and db is the index of the database in which the
// keys are stored.
func NewRedisStore(pool *redis.Pool, keyPrefix string, db int) throttled.Store {
	s := new(redisStore)
	s.pool = pool
	s.prefix = keyPrefix
	s.db = db
	return s
}
// Incr increments the counter stored under key (with the store's prefix
// prepended). If the counter has no expiration after the increment, which is
// the case when INCR has just created it, it is set to expire after window,
// truncated to whole seconds and raised to a minimum of one second.
//
// It returns the new count value and the number of remaining seconds, or an
// error if any of the Redis commands fails.
func (r *redisStore) Incr(key string, window time.Duration) (int, int, error) {
	conn := r.pool.Get()
	defer conn.Close()
	if err := selectDB(r.db, conn); err != nil {
		return 0, 0, err
	}

	fullKey := r.prefix + key

	// Queue INCR and TTL inside MULTI/EXEC so the new count and its TTL are
	// obtained in a single atomic step.
	if err := conn.Send("MULTI"); err != nil {
		return 0, 0, err
	}
	if err := conn.Send("INCR", fullKey); err != nil {
		return 0, 0, err
	}
	if err := conn.Send("TTL", fullKey); err != nil {
		return 0, 0, err
	}
	vals, err := redis.Values(conn.Do("EXEC"))
	if err != nil {
		// Best-effort cleanup of a possibly still-open transaction. The EXEC
		// error is the one worth reporting, so the DISCARD result is
		// deliberately ignored.
		_, _ = conn.Do("DISCARD")
		return 0, 0, err
	}

	var cnt, ttl int
	if _, err = redis.Scan(vals, &cnt, &ttl); err != nil {
		return 0, 0, err
	}

	// If there was no TTL set, then this is a newly created key (INCR creates
	// the key if it didn't exist), so set it to expire.
	if ttl == -1 {
		ttl = int(window.Seconds())
		if ttl < 1 {
			// EXPIRE with a non-positive timeout deletes the key instead of
			// expiring it, which would restart the count on every call. Use
			// the smallest window expressible in whole seconds instead.
			ttl = 1
		}
		if _, err = conn.Do("EXPIRE", fullKey, ttl); err != nil {
			return 0, 0, err
		}
	}
	return cnt, ttl, nil
}
// Reset sets the value of the key to 1 and restarts its time window, creating
// the key if it does not exist and overwriting it if it does. The window is
// truncated to whole seconds and raised to a minimum of one second.
//
// It returns an error if selecting the database or the SET command fails.
func (r *redisStore) Reset(key string, window time.Duration) error {
	conn := r.pool.Get()
	defer conn.Close()
	if err := selectDB(r.db, conn); err != nil {
		return err
	}

	secs := int(window.Seconds())
	if secs < 1 {
		// SET rejects a non-positive EX value, so fall back to the smallest
		// window expressible in whole seconds.
		secs = 1
	}

	// The SET is unconditional: with the NX option an existing key would be
	// left untouched and the nil reply would surface as redis.ErrNil, so the
	// counter would never actually be reset.
	_, err := redis.String(conn.Do("SET", r.prefix+key, "1", "EX", secs))
	return err
}
// selectDB issues SELECT on conn for the database index db and returns any
// error from that command. For db <= 0 no command is sent and nil is returned.
//
// NOTE(review): because index 0 is never selected explicitly, a pooled
// connection that was previously switched to another database would keep
// using it. Confirm the pool is not shared between stores with different db
// values.
func selectDB(db int, conn redis.Conn) error {
	if db <= 0 {
		return nil
	}
	_, err := redis.String(conn.Do("SELECT", db))
	return err
}
|