1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
|
package throttled
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
)
// deniedStatus is the HTTP status these tests expect once the rate-limit
// quota is exhausted (429 Too Many Requests).
const deniedStatus = http.StatusTooManyRequests
// mapStore is a minimal in-memory store used by these tests. It has no
// locking of any kind, so it must not be accessed concurrently.
type mapStore struct {
	cnt map[string]int       // request count per key
	ts  map[string]time.Time // per-key timestamp recorded by Reset
}

// newMapStore returns an empty mapStore whose maps are already allocated,
// so it can be written to immediately.
func newMapStore() *mapStore {
	return &mapStore{
		cnt: make(map[string]int),
		ts:  make(map[string]time.Time),
	}
}
// Incr bumps the count stored for key and returns the new count, the result
// of RemainingSeconds for the key's stored timestamp and window, and a nil
// error. When no count exists for key it returns 0, 0, ErrNoSuchKey and
// leaves the store unchanged.
func (ms *mapStore) Incr(key string, window time.Duration) (int, int, error) {
	count, found := ms.cnt[key]
	if !found {
		return 0, 0, ErrNoSuchKey
	}
	count++
	ms.cnt[key] = count
	return count, RemainingSeconds(ms.ts[key], window), nil
}
// Reset sets the count for key back to 1 and records the current UTC time
// as its timestamp. The window argument is not used by this store, and the
// returned error is always nil.
func (ms *mapStore) Reset(key string, window time.Duration) error {
	now := time.Now().UTC()
	ms.cnt[key] = 1
	ms.ts[key] = now
	return nil
}
func TestRateLimit(t *testing.T) {
quota := Q{5, 5 * time.Second}
cases := []struct {
limit, remain, reset, status int
}{
0: {5, 4, 5, 200},
1: {5, 3, 4, 200},
2: {5, 2, 4, 200},
3: {5, 1, 3, 200},
4: {5, 0, 3, 200},
5: {5, 0, 2, deniedStatus},
}
// Limit the requests to 2 per second
th := Interval(PerSec(2), 0, nil, 0)
// Rate limit
rl := RateLimit(quota, nil, newMapStore())
// Create the stats
st := &stats{}
// Create the handler
h := th.Throttle(rl.Throttle(st))
// Start the server
srv := httptest.NewServer(h)
defer srv.Close()
for i, c := range cases {
callRateLimited(t, i, c.limit, c.remain, c.reset, c.status, srv.URL)
}
// Wait 3 seconds and call again, should start a new window
time.Sleep(3 * time.Second)
callRateLimited(t, len(cases), 5, 4, 5, 200, srv.URL)
}
// callRateLimited issues a GET to url and checks the response against the
// expected values. i identifies the request in failure messages.
//
// The status code must equal status, and X-RateLimit-Limit and
// X-RateLimit-Remaining must equal limit and remain exactly.
// X-RateLimit-Reset (and Retry-After, when status is deniedStatus) must be
// an integer within one second of reset. A missing or non-numeric header is
// reported as a failure rather than being treated as 0.
func callRateLimited(t *testing.T, i, limit, remain, reset, status int, url string) {
	t.Helper()

	res, err := http.Get(url)
	if err != nil {
		t.Fatalf("%d: GET %s: %v", i, url, err)
	}
	defer res.Body.Close()

	// Assert status code
	if status != res.StatusCode {
		t.Errorf("%d: expected status %d, got %d", i, status, res.StatusCode)
	}

	// Assert headers that must match exactly.
	if v := res.Header.Get("X-RateLimit-Limit"); v != strconv.Itoa(limit) {
		t.Errorf("%d: expected limit header to be %d, got %s", i, limit, v)
	}
	if v := res.Header.Get("X-RateLimit-Remaining"); v != strconv.Itoa(remain) {
		t.Errorf("%d: expected remain header to be %d, got %s", i, remain, v)
	}

	// assertNear checks that header hdr parses as an integer within one
	// second of reset; the wiggle room tolerates timing jitter. what is the
	// human-readable header label used in failure messages.
	assertNear := func(hdr, what string) {
		t.Helper()
		raw := res.Header.Get(hdr)
		got, err := strconv.Atoi(raw)
		if err != nil {
			// Previously ignored: a missing header would silently become 0.
			t.Errorf("%d: expected %s header to be an integer, got %q", i, what, raw)
			return
		}
		if got < reset-1 || got > reset+1 {
			t.Errorf("%d: expected %s header to be close to %d, got %d", i, what, reset, got)
		}
	}

	assertNear("X-RateLimit-Reset", "reset")
	if status == deniedStatus {
		assertNear("Retry-After", "retry after")
	}
}
|