From 1262d254736229618582f0963c9c30c4e66efb98 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Wed, 31 Jan 2018 09:49:15 -0800 Subject: User based rate limiting (#8152) --- api/context.go | 45 +++++----------- api4/context.go | 45 +++++----------- app/authentication.go | 47 ++++++++++++++++ app/authentication_test.go | 52 ++++++++++++++++++ app/ratelimit.go | 131 +++++++++++++++++++++++++++++++++++++++++++++ app/ratelimit_test.go | 67 +++++++++++++++++++++++ app/server.go | 51 ++++++++---------- config/default.json | 1 + model/config.go | 11 +++- 9 files changed, 355 insertions(+), 95 deletions(-) create mode 100644 app/authentication_test.go create mode 100644 app/ratelimit.go create mode 100644 app/ratelimit_test.go diff --git a/api/context.go b/api/context.go index 34a87e633..84967659d 100644 --- a/api/context.go +++ b/api/context.go @@ -114,38 +114,14 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { metrics.IncrementHttpRequest() } - token := "" - isTokenFromQueryString := false - - // Attempt to parse token out of the header - authHeader := r.Header.Get(model.HEADER_AUTH) - if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER { - // Default session token - token = authHeader[7:] - - } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN { - // OAuth token - token = authHeader[6:] - } - - // Attempt to parse the token from the cookie - if len(token) == 0 { - if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil { - token = cookie.Value - - if (h.requireSystemAdmin || h.requireUser) && !h.trustRequester { - if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML { - c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized) - token = "" - } - } - } - } + token, tokenLocation := app.ParseAuthTokenFromRequest(r) - // Attempt to parse token out of the query string - if len(token) == 0 { - token = r.URL.Query().Get("access_token") - isTokenFromQueryString = true + // CSRF Check + if tokenLocation == app.TokenLocationCookie && (h.requireSystemAdmin || h.requireUser) && !h.trustRequester { + if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML { + c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized) + token = "" + } } c.SetSiteURLHeader(app.GetProtocol(r) + "://" + r.Host) @@ -175,11 +151,16 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if h.requireUser || h.requireSystemAdmin { c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized) } - } else if !session.IsOAuth && isTokenFromQueryString { + } else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString { c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized) } else { c.Session = *session } + + // Rate limit by UserID + if c.App.Srv.RateLimiter != nil && c.App.Srv.RateLimiter.UserIdRateLimit(c.Session.UserId, w) { + return + } } if h.isApi || h.isTeamIndependent { diff --git a/api4/context.go b/api4/context.go index b10ea7a9b..980897062 100644 --- a/api4/context.go +++ b/api4/context.go @@ -99,38 +99,14 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.IpAddress = utils.GetIpAddress(r) c.Params = ApiParamsFromRequest(r) - token := "" - isTokenFromQueryString := false - - // Attempt to parse token out of the header - authHeader := r.Header.Get(model.HEADER_AUTH) - if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER { - // Default session token - token = authHeader[7:] - - } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN { - // OAuth token - token = authHeader[6:] - } - - // Attempt to parse the token from the cookie - if len(token) == 0 { - if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil { - token = cookie.Value - - if h.requireSession && !h.trustRequester { - if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML { - c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized) - token = "" - } - } - } - } + token, tokenLocation := app.ParseAuthTokenFromRequest(r) - // Attempt to parse token out of the query string - if len(token) == 0 { - token = r.URL.Query().Get("access_token") - isTokenFromQueryString = true + // CSRF Check + if tokenLocation == app.TokenLocationCookie && h.requireSession && !h.trustRequester { + if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML { + c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized) + token = "" + } } c.SetSiteURLHeader(app.GetProtocol(r) + "://" + r.Host) @@ -153,11 +129,16 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if h.requireSession { c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized) } - } else if !session.IsOAuth && isTokenFromQueryString { + } else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString { c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized) } else { c.Session = *session } + + // Rate limit by UserID + if c.App.Srv.RateLimiter != nil && c.App.Srv.RateLimiter.UserIdRateLimit(c.Session.UserId, w) { + return + } } c.Path = r.URL.Path diff --git a/app/authentication.go b/app/authentication.go index 91e3bf564..140bffd5a 100644 --- a/app/authentication.go +++ b/app/authentication.go @@ -11,6 +11,30 @@ import ( "github.com/mattermost/mattermost-server/utils" ) +type TokenLocation int + +const ( + TokenLocationNotFound = iota + TokenLocationHeader + TokenLocationCookie + TokenLocationQueryString +) + +func (tl TokenLocation) String() string { + switch tl { + case TokenLocationNotFound: + return "Not Found" + case TokenLocationHeader: + return "Header" + case TokenLocationCookie: + return "Cookie" + case TokenLocationQueryString: + return "QueryString" + default: + return "Unknown" + } +} + func (a *App) IsPasswordValid(password string) *model.AppError { if utils.IsLicensed() && *utils.License().Features.PasswordRequirements { return utils.IsPasswordValidWithSettings(password, &a.Config().PasswordSettings) @@ -168,3 +192,26 @@ func (a *App) authenticateUser(user *model.User, password, mfaToken string) (*mo } } } + +func ParseAuthTokenFromRequest(r *http.Request) (string, TokenLocation) { + authHeader := r.Header.Get(model.HEADER_AUTH) + if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER { + // Default session token + return authHeader[7:], TokenLocationHeader + } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN { + // OAuth token + return authHeader[6:], TokenLocationHeader + } + + // Attempt to parse the token from the cookie + if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil { + return cookie.Value, TokenLocationCookie + } + + // Attempt to parse token out of the query string + if token := r.URL.Query().Get("access_token"); token != "" { + return token, TokenLocationQueryString + } + + return "", TokenLocationNotFound +} diff --git a/app/authentication_test.go b/app/authentication_test.go new file mode 100644 index 000000000..f3014b1b8 --- /dev/null +++ b/app/authentication_test.go @@ -0,0 +1,52 @@ +// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/mattermost/mattermost-server/model" + "github.com/stretchr/testify/require" +) + +func TestParseAuthTokenFromRequest(t *testing.T) { + cases := []struct { + header string + cookie string + query string + expectedToken string + expectedLocation TokenLocation + }{ + {"", "", "", "", TokenLocationNotFound}, + {"token mytoken", "", "", "mytoken", TokenLocationHeader}, + {"BEARER mytoken", "", "", "mytoken", TokenLocationHeader}, + {"", "mytoken", "", "mytoken", TokenLocationCookie}, + {"", "", "mytoken", "mytoken", TokenLocationQueryString}, + } + + for testnum, tc := range cases { + pathname := "/test/here" + if tc.query != "" { + pathname += "?access_token=" + tc.query + } + req := httptest.NewRequest("GET", pathname, nil) + if tc.header != "" { + req.Header.Add(model.HEADER_AUTH, tc.header) + } + if tc.cookie != "" { + req.AddCookie(&http.Cookie{ + Name: model.SESSION_COOKIE_TOKEN, + Value: tc.cookie, + }) + } + + token, location := ParseAuthTokenFromRequest(req) + + require.Equal(t, tc.expectedToken, token, "Wrong token on test "+strconv.Itoa(testnum)) + require.Equal(t, tc.expectedLocation, location, "Wrong location on test "+strconv.Itoa(testnum)) + } +} diff --git a/app/ratelimit.go b/app/ratelimit.go new file mode 100644 index 000000000..460088598 --- /dev/null +++ b/app/ratelimit.go @@ -0,0 +1,131 @@ +// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "math" + "net/http" + "strconv" + "strings" + + l4g "github.com/alecthomas/log4go" + "github.com/mattermost/mattermost-server/model" + "github.com/mattermost/mattermost-server/utils" + throttled "gopkg.in/throttled/throttled.v2" + "gopkg.in/throttled/throttled.v2/store/memstore" +) + +type RateLimiter struct { + throttledRateLimiter *throttled.GCRARateLimiter + useAuth bool + useIP bool + header string +} + +func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter { + store, err := memstore.New(*settings.MemoryStoreSize) + if err != nil { + l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store")) + return nil + } + + quota := throttled.RateQuota{ + MaxRate: throttled.PerSec(*settings.PerSec), + MaxBurst: *settings.MaxBurst, + } + + throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota) + if err != nil { + l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter")) + return nil + } + + return &RateLimiter{ + throttledRateLimiter: throttledRateLimiter, + useAuth: *settings.VaryByUser, + useIP: *settings.VaryByRemoteAddr, + header: settings.VaryByHeader, + } +} + +func (rl *RateLimiter) GenerateKey(r *http.Request) string { + key := "" + + if rl.useAuth { + token, tokenLocation := ParseAuthTokenFromRequest(r) + if tokenLocation != TokenLocationNotFound { + key += token + } else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP + key += utils.GetIpAddress(r) + } + } else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based + key += utils.GetIpAddress(r) + } + + // Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the + // most common headers anyway. + if rl.header != "" { + key += strings.ToLower(r.Header.Get(rl.header)) + } + + return key +} + +func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool { + limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1) + if err != nil { + l4g.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error()) + return false + } + + setRateLimitHeaders(w, context) + + if limited { + l4g.Error("Denied due to throttling settings code=429 key=%v", key) + http.Error(w, "limit exceeded", 429) + } + + return limited +} + +func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool { + if rl.useAuth { + if rl.RateLimitWriter(userId, w) { + return true + } + } + return false +} + +func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := rl.GenerateKey(r) + limited := rl.RateLimitWriter(key, w) + + if !limited { + wrappedHandler.ServeHTTP(w, r) + } + }) +} + +// Copied from https://github.com/throttled/throttled http.go +func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) { + if v := context.Limit; v >= 0 { + w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v)) + } + + if v := context.Remaining; v >= 0 { + w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v)) + } + + if v := context.ResetAfter; v >= 0 { + vi := int(math.Ceil(v.Seconds())) + w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi)) + } + + if v := context.RetryAfter; v >= 0 { + vi := int(math.Ceil(v.Seconds())) + w.Header().Add("Retry-After", strconv.Itoa(vi)) + } +} diff --git a/app/ratelimit_test.go b/app/ratelimit_test.go new file mode 100644 index 000000000..ddaa25710 --- /dev/null +++ b/app/ratelimit_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/mattermost/mattermost-server/model" + "github.com/stretchr/testify/require" +) + +func genRateLimitSettings(useAuth, useIP bool, header string) *model.RateLimitSettings { + return &model.RateLimitSettings{ + Enable: model.NewBool(true), + PerSec: model.NewInt(10), + MaxBurst: model.NewInt(100), + MemoryStoreSize: model.NewInt(10000), + VaryByRemoteAddr: model.NewBool(useIP), + VaryByUser: model.NewBool(useAuth), + VaryByHeader: header, + } +} + +func TestGenerateKey(t *testing.T) { + cases := []struct { + useAuth bool + useIP bool + header string + authTokenResult string + ipResult string + headerResult string + expectedKey string + }{ + {false, false, "", "", "", "", ""}, + {true, false, "", "resultkey", "notme", "notme", "resultkey"}, + {false, true, "", "notme", "resultkey", "notme", "resultkey"}, + {false, false, "myheader", "notme", "notme", "resultkey", "resultkey"}, + {true, true, "", "resultkey", "ipaddr", "notme", "resultkey"}, + {true, true, "", "", "ipaddr", "notme", "ipaddr"}, + {true, true, "myheader", "resultkey", "ipaddr", "hadd", "resultkeyhadd"}, + {true, true, "myheader", "", "ipaddr", "hadd", "ipaddrhadd"}, + } + + for testnum, tc := range cases { + req := httptest.NewRequest("GET", "/", nil) + if tc.authTokenResult != "" { + req.AddCookie(&http.Cookie{ + Name: model.SESSION_COOKIE_TOKEN, + Value: tc.authTokenResult, + }) + } + req.RemoteAddr = tc.ipResult + ":80" + if tc.headerResult != "" { + req.Header.Set(tc.header, tc.headerResult) + } + + rateLimiter := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header)) + + key := rateLimiter.GenerateKey(req) + + require.Equal(t, tc.expectedKey, key, "Wrong key on test "+strconv.Itoa(testnum)) + } +} diff --git a/app/server.go b/app/server.go index c008da3a1..2a94bf2c7 100644 --- a/app/server.go +++ b/app/server.go @@ -18,8 +18,6 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" "golang.org/x/crypto/acme/autocert" - "gopkg.in/throttled/throttled.v2" - "gopkg.in/throttled/throttled.v2/store/memstore" "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/store" @@ -32,6 +30,7 @@ type Server struct { Router *mux.Router Server *http.Server ListenAddr *net.TCPAddr + RateLimiter *RateLimiter didFinishListen chan struct{} } @@ -84,10 +83,26 @@ func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { const TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN = time.Second -type VaryBy struct{} +type VaryBy struct { + useIP bool + useAuth bool +} func (m *VaryBy) Key(r *http.Request) string { - return utils.GetIpAddress(r) + key := "" + + if m.useAuth { + token, tokenLocation := ParseAuthTokenFromRequest(r) + if tokenLocation != TokenLocationNotFound { + key += token + } else if m.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP + key += utils.GetIpAddress(r) + } + } else if m.useIP { // Only if Auth based is not enabed do we use a plain IP based + key = utils.GetIpAddress(r) + } + + return key } func redirectHTTPToHTTPS(w http.ResponseWriter, r *http.Request) { @@ -109,33 +124,9 @@ func (a *App) StartServer() { if *a.Config().RateLimitSettings.Enable { l4g.Info(utils.T("api.server.start_server.rate.info")) - store, err := memstore.New(*a.Config().RateLimitSettings.MemoryStoreSize) - if err != nil { - l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store")) - return - } - - quota := throttled.RateQuota{ - MaxRate: throttled.PerSec(*a.Config().RateLimitSettings.PerSec), - MaxBurst: *a.Config().RateLimitSettings.MaxBurst, - } - - rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) - if err != nil { - l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter")) - return - } - - httpRateLimiter := throttled.HTTPRateLimiter{ - RateLimiter: rateLimiter, - VaryBy: &VaryBy{}, - DeniedHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l4g.Error("%v: Denied due to throttling settings code=429 ip=%v", r.URL.Path, utils.GetIpAddress(r)) - throttled.DefaultDeniedHandler.ServeHTTP(w, r) - }), - } + a.Srv.RateLimiter = NewRateLimiter(&a.Config().RateLimitSettings) - handler = httpRateLimiter.RateLimit(handler) + handler = a.Srv.RateLimiter.RateLimitHandler(handler) } a.Srv.Server = &http.Server{ diff --git a/config/default.json b/config/default.json index 4bcccdf19..e0fda74cd 100644 --- a/config/default.json +++ b/config/default.json @@ -180,6 +180,7 @@ "MaxBurst": 100, "MemoryStoreSize": 10000, "VaryByRemoteAddr": true, + "VaryByUser": false, "VaryByHeader": "" }, "PrivacySettings": { diff --git a/model/config.go b/model/config.go index 525fc71ed..b7888ab13 100644 --- a/model/config.go +++ b/model/config.go @@ -802,7 +802,8 @@ type RateLimitSettings struct { PerSec *int MaxBurst *int MemoryStoreSize *int - VaryByRemoteAddr bool + VaryByRemoteAddr *bool + VaryByUser *bool VaryByHeader string } @@ -822,6 +823,14 @@ func (s *RateLimitSettings) SetDefaults() { if s.MemoryStoreSize == nil { s.MemoryStoreSize = NewInt(10000) } + + if s.VaryByRemoteAddr == nil { + s.VaryByRemoteAddr = NewBool(true) + } + + if s.VaryByUser == nil { + s.VaryByUser = NewBool(false) + } } type PrivacySettings struct { -- cgit v1.2.3-1-g7c22