From 1262d254736229618582f0963c9c30c4e66efb98 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Wed, 31 Jan 2018 09:49:15 -0800 Subject: User based rate limiting (#8152) --- app/server.go | 51 +++++++++++++++++++++------------------------------ 1 file changed, 21 insertions(+), 30 deletions(-) (limited to 'app/server.go') diff --git a/app/server.go b/app/server.go index c008da3a1..2a94bf2c7 100644 --- a/app/server.go +++ b/app/server.go @@ -18,8 +18,6 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" "golang.org/x/crypto/acme/autocert" - "gopkg.in/throttled/throttled.v2" - "gopkg.in/throttled/throttled.v2/store/memstore" "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/store" @@ -32,6 +30,7 @@ type Server struct { Router *mux.Router Server *http.Server ListenAddr *net.TCPAddr + RateLimiter *RateLimiter didFinishListen chan struct{} } @@ -84,10 +83,26 @@ func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { const TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN = time.Second -type VaryBy struct{} +type VaryBy struct { + useIP bool + useAuth bool +} func (m *VaryBy) Key(r *http.Request) string { - return utils.GetIpAddress(r) + key := "" + + if m.useAuth { + token, tokenLocation := ParseAuthTokenFromRequest(r) + if tokenLocation != TokenLocationNotFound { + key += token + } else if m.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP + key += utils.GetIpAddress(r) + } + } else if m.useIP { // Only if Auth based is not enabed do we use a plain IP based + key = utils.GetIpAddress(r) + } + + return key } func redirectHTTPToHTTPS(w http.ResponseWriter, r *http.Request) { @@ -109,33 +124,9 @@ func (a *App) StartServer() { if *a.Config().RateLimitSettings.Enable { l4g.Info(utils.T("api.server.start_server.rate.info")) - store, err := memstore.New(*a.Config().RateLimitSettings.MemoryStoreSize) - if err != nil { - l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store")) - return - } - - quota := throttled.RateQuota{ - MaxRate: throttled.PerSec(*a.Config().RateLimitSettings.PerSec), - MaxBurst: *a.Config().RateLimitSettings.MaxBurst, - } - - rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) - if err != nil { - l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter")) - return - } - - httpRateLimiter := throttled.HTTPRateLimiter{ - RateLimiter: rateLimiter, - VaryBy: &VaryBy{}, - DeniedHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l4g.Error("%v: Denied due to throttling settings code=429 ip=%v", r.URL.Path, utils.GetIpAddress(r)) - throttled.DefaultDeniedHandler.ServeHTTP(w, r) - }), - } + a.Srv.RateLimiter = NewRateLimiter(&a.Config().RateLimitSettings) - handler = httpRateLimiter.RateLimit(handler) + handler = a.Srv.RateLimiter.RateLimitHandler(handler) } a.Srv.Server = &http.Server{ -- cgit v1.2.3-1-g7c22