summaryrefslogtreecommitdiffstats
path: root/vendor/gopkg.in/throttled/throttled.v1/store/redis.go
blob: b089f9f4e4812bbc03d7069060c1fbdb90de24c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package store

import (
	"time"

	"github.com/garyburd/redigo/redis"
	"gopkg.in/throttled/throttled.v1"
)

// redisStore implements a Redis-based store.
//
// Every operation takes a connection from pool and closes it when done.
// prefix is prepended verbatim to each key before it is sent to Redis, and db
// is the database index handed to selectDB at the start of each operation
// (selectDB only issues a SELECT when the index is greater than zero).
type redisStore struct {
	pool   *redis.Pool
	prefix string
	db     int
}

// NewRedisStore returns a throttled.Store backed by Redis. Connections are
// obtained from pool, every key is prefixed with keyPrefix (an empty prefix is
// allowed), and db is the index of the database in which the keys are stored.
func NewRedisStore(pool *redis.Pool, keyPrefix string, db int) throttled.Store {
	s := new(redisStore)
	s.pool = pool
	s.prefix = keyPrefix
	s.db = db
	return s
}

// Incr increments the specified key (with the store's prefix prepended). If the
// key did not exist, INCR creates it with a value of 1 and Incr then sets it to
// expire after window, expressed in whole seconds and never less than one
// second.
//
// It returns the new count value and the number of remaining seconds, or an error
// if the operation fails. On error both integers are zero.
func (r *redisStore) Incr(key string, window time.Duration) (int, int, error) {
	conn := r.pool.Get()
	defer conn.Close()
	if err := selectDB(r.db, conn); err != nil {
		return 0, 0, err
	}

	fullKey := r.prefix + key

	// Atomically increment and read the TTL. A failure to queue any of the
	// commands is reported right away instead of being silently dropped.
	if err := conn.Send("MULTI"); err != nil {
		return 0, 0, err
	}
	if err := conn.Send("INCR", fullKey); err != nil {
		return 0, 0, err
	}
	if err := conn.Send("TTL", fullKey); err != nil {
		return 0, 0, err
	}
	vals, err := redis.Values(conn.Do("EXEC"))
	if err != nil {
		// Best-effort cleanup of the transaction; the EXEC error is the one
		// worth reporting, so the DISCARD result is deliberately ignored.
		_, _ = conn.Do("DISCARD")
		return 0, 0, err
	}
	var cnt, ttl int
	if _, err = redis.Scan(vals, &cnt, &ttl); err != nil {
		return 0, 0, err
	}
	// If there was no TTL set, then this is a newly created key (INCR creates the key
	// if it didn't exist), so set it to expire.
	if ttl == -1 {
		ttl = int(window.Seconds())
		// A sub-second window truncates to 0, and EXPIRE with a non-positive
		// timeout deletes the key immediately, which would throw away the
		// count that was just stored. Keep the key for at least one second.
		if ttl < 1 {
			ttl = 1
		}
		if _, err = conn.Do("EXPIRE", fullKey, ttl); err != nil {
			return 0, 0, err
		}
	}
	return cnt, ttl, nil
}

// Reset sets the value of the key (with the store's prefix prepended) to 1 and
// gives it a fresh time window of window, expressed in whole seconds and never
// less than one second.
//
// NOTE(review): the SET is issued with the NX flag, so it only takes effect
// when the key does not already exist. When the key is still present Redis
// replies with a nil value, which redis.String is expected to report as
// redis.ErrNil - confirm that callers are meant to see that error.
func (r *redisStore) Reset(key string, window time.Duration) error {
	conn := r.pool.Get()
	defer conn.Close()
	if err := selectDB(r.db, conn); err != nil {
		return err
	}
	secs := int(window.Seconds())
	// A sub-second window truncates to 0, and Redis rejects "EX 0" as an
	// invalid expire time, so use the smallest expiry it accepts.
	if secs < 1 {
		secs = 1
	}
	_, err := redis.String(conn.Do("SET", r.prefix+key, "1", "EX", secs, "NX"))
	return err
}

// selectDB switches conn to the database with index db. Indexes that are not
// greater than zero are left alone: no SELECT is sent and nil is returned.
// Otherwise the error, if any, from the SELECT command is returned.
func selectDB(db int, conn redis.Conn) error {
	if db <= 0 {
		// Nothing to select; the connection stays on whatever database it is using.
		return nil
	}
	_, err := redis.String(conn.Do("SELECT", db))
	return err
}