summaryrefslogtreecommitdiffstats
path: root/app/ratelimit.go
blob: d7b96dae343781806eb2c7487f4d91e270b63f7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
// See License.txt for license information.

package app

import (
	"crypto/sha256"
	"fmt"
	"math"
	"net/http"
	"strconv"
	"strings"

	"github.com/mattermost/mattermost-server/mlog"
	"github.com/mattermost/mattermost-server/model"
	"github.com/mattermost/mattermost-server/utils"
	"github.com/pkg/errors"
	"github.com/throttled/throttled"
	"github.com/throttled/throttled/store/memstore"
)

// RateLimiter throttles HTTP requests with a GCRA limiter from the throttled
// package. The remaining fields control which request attributes GenerateKey
// combines into the per-client rate-limit key.
type RateLimiter struct {
	throttledRateLimiter *throttled.GCRARateLimiter

	// useAuth and useIP are copied from the VaryByUser and VaryByRemoteAddr
	// settings by NewRateLimiter.
	useAuth, useIP bool

	// header, when non-empty, names a request header whose lower-cased value
	// is appended to the key.
	header string
}

// NewRateLimiter builds a RateLimiter from settings.
//
// Limiter state lives in an in-memory store sized by *settings.MemoryStoreSize.
// The quota allows *settings.PerSec requests per second with a burst of
// *settings.MaxBurst. The pointer fields of settings are dereferenced without
// nil checks, so callers must supply populated settings. Store and limiter
// construction failures are returned wrapped with a translated message.
func NewRateLimiter(settings *model.RateLimitSettings) (*RateLimiter, error) {
	memStore, storeErr := memstore.New(*settings.MemoryStoreSize)
	if storeErr != nil {
		return nil, errors.Wrap(storeErr, utils.T("api.server.start_server.rate_limiting_memory_store"))
	}

	gcra, limiterErr := throttled.NewGCRARateLimiter(memStore, throttled.RateQuota{
		MaxRate:  throttled.PerSec(*settings.PerSec),
		MaxBurst: *settings.MaxBurst,
	})
	if limiterErr != nil {
		return nil, errors.Wrap(limiterErr, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
	}

	limiter := &RateLimiter{throttledRateLimiter: gcra}
	limiter.useAuth = *settings.VaryByUser
	limiter.useIP = *settings.VaryByRemoteAddr
	limiter.header = settings.VaryByHeader
	return limiter, nil
}

// GenerateKey derives the rate-limit bucket key for r.
//
// The identity part of the key is chosen as follows:
//   - the request's auth token, when useAuth is set and a token is found;
//   - otherwise the client IP address, when useIP is set;
//   - otherwise the empty string.
//
// If a vary-by header is configured, its lower-cased value is appended
// directly to the identity, with no separator.
func (rl *RateLimiter) GenerateKey(r *http.Request) string {
	identity := ""

	switch {
	case rl.useAuth:
		if token, location := ParseAuthTokenFromRequest(r); location != TokenLocationNotFound {
			identity = token
		} else if rl.useIP {
			// No auth token on the request: fall back to the client IP when allowed.
			identity = utils.GetIpAddress(r)
		}
	case rl.useIP:
		identity = utils.GetIpAddress(r)
	}

	// Usually unnecessary to configure, because utils.GetIpAddress already
	// consults the most common proxy headers.
	if rl.header == "" {
		return identity
	}
	return identity + strings.ToLower(r.Header.Get(rl.header))
}

// RateLimitWriter charges one request against the rate-limit bucket named by
// key.
//
// When the limiter succeeds, the rate-limit headers are set on w. If the
// bucket is exhausted, a 429 "limit exceeded" response is also written and
// true is returned, and the caller should stop handling the request.
// Otherwise false is returned.
//
// If the underlying limiter returns an error, the failure is logged and the
// request is allowed through (fail open).
func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
	limited, result, err := rl.throttledRateLimiter.RateLimit(key, 1)
	if err != nil {
		mlog.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
		return false
	}

	setRateLimitHeaders(w, result)

	if limited {
		// When varying by user, GenerateKey uses the caller's raw auth token
		// as the key, so the key must never be written to the log. A
		// truncated SHA-256 fingerprint still lets operators correlate
		// repeated denials for the same bucket.
		fingerprint := sha256.Sum256([]byte(key))
		mlog.Error(fmt.Sprintf("Denied due to throttling settings code=429 key_hash=%x", fingerprint[:8]))
		http.Error(w, "limit exceeded", http.StatusTooManyRequests)
	}

	return limited
}

// UserIdRateLimit applies the limiter using userId as the bucket key, but
// only when the limiter is configured to vary by user. It returns true when
// the request was rejected and a 429 response has been written to w.
// Otherwise it returns false, and it always returns false when varying by
// user is disabled.
func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
	return rl.useAuth && rl.RateLimitWriter(userId, w)
}

// RateLimitHandler wraps wrappedHandler so that each request is first charged
// against the bucket chosen by GenerateKey. Rejected requests get the 429
// response written by RateLimitWriter and never reach wrappedHandler.
func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		if rl.RateLimitWriter(rl.GenerateKey(r), w) {
			return
		}
		wrappedHandler.ServeHTTP(w, r)
	}
	return http.HandlerFunc(serve)
}

// setRateLimitHeaders exposes the limiter's state to the client.
//
// It writes the X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset
// and Retry-After headers. Fields with a negative value are skipped, and
// durations are rounded up to whole seconds.
//
// Adapted from https://github.com/throttled/throttled http.go. Unlike the
// upstream version, it uses Header.Set rather than Header.Add. RateLimitWriter
// (its caller in this file) is reachable from both RateLimitHandler and
// UserIdRateLimit, so if both run for the same response, Add would emit
// duplicate values for these single-valued headers. With Set, the most recent
// call is authoritative.
func setRateLimitHeaders(w http.ResponseWriter, result throttled.RateLimitResult) {
	header := w.Header()

	if v := result.Limit; v >= 0 {
		header.Set("X-RateLimit-Limit", strconv.Itoa(v))
	}

	if v := result.Remaining; v >= 0 {
		header.Set("X-RateLimit-Remaining", strconv.Itoa(v))
	}

	if v := result.ResetAfter; v >= 0 {
		header.Set("X-RateLimit-Reset", strconv.Itoa(int(math.Ceil(v.Seconds()))))
	}

	if v := result.RetryAfter; v >= 0 {
		header.Set("Retry-After", strconv.Itoa(int(math.Ceil(v.Seconds()))))
	}
}