summaryrefslogtreecommitdiffstats
path: root/app
diff options
context:
space:
mode:
authorChristopher Speller <crspeller@gmail.com>2018-01-31 09:49:15 -0800
committerGitHub <noreply@github.com>2018-01-31 09:49:15 -0800
commit1262d254736229618582f0963c9c30c4e66efb98 (patch)
treec2375b6c6b143dc59c24d590eb59c5d49d17247e /app
parente0ee73ef9963ab398bcc6011795ad23e8e003147 (diff)
downloadchat-1262d254736229618582f0963c9c30c4e66efb98.tar.gz
chat-1262d254736229618582f0963c9c30c4e66efb98.tar.bz2
chat-1262d254736229618582f0963c9c30c4e66efb98.zip
User based rate limiting (#8152)
Diffstat (limited to 'app')
-rw-r--r--app/authentication.go47
-rw-r--r--app/authentication_test.go52
-rw-r--r--app/ratelimit.go131
-rw-r--r--app/ratelimit_test.go67
-rw-r--r--app/server.go51
5 files changed, 318 insertions, 30 deletions
diff --git a/app/authentication.go b/app/authentication.go
index 91e3bf564..140bffd5a 100644
--- a/app/authentication.go
+++ b/app/authentication.go
@@ -11,6 +11,30 @@ import (
"github.com/mattermost/mattermost-server/utils"
)
+type TokenLocation int
+
+const (
+ TokenLocationNotFound = iota
+ TokenLocationHeader
+ TokenLocationCookie
+ TokenLocationQueryString
+)
+
+func (tl TokenLocation) String() string {
+ switch tl {
+ case TokenLocationNotFound:
+ return "Not Found"
+ case TokenLocationHeader:
+ return "Header"
+ case TokenLocationCookie:
+ return "Cookie"
+ case TokenLocationQueryString:
+ return "QueryString"
+ default:
+ return "Unknown"
+ }
+}
+
func (a *App) IsPasswordValid(password string) *model.AppError {
if utils.IsLicensed() && *utils.License().Features.PasswordRequirements {
return utils.IsPasswordValidWithSettings(password, &a.Config().PasswordSettings)
@@ -168,3 +192,26 @@ func (a *App) authenticateUser(user *model.User, password, mfaToken string) (*mo
}
}
}
+
+func ParseAuthTokenFromRequest(r *http.Request) (string, TokenLocation) {
+ authHeader := r.Header.Get(model.HEADER_AUTH)
+ if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER {
+ // Default session token
+ return authHeader[7:], TokenLocationHeader
+ } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN {
+ // OAuth token
+ return authHeader[6:], TokenLocationHeader
+ }
+
+ // Attempt to parse the token from the cookie
+ if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil {
+ return cookie.Value, TokenLocationCookie
+ }
+
+ // Attempt to parse token out of the query string
+ if token := r.URL.Query().Get("access_token"); token != "" {
+ return token, TokenLocationQueryString
+ }
+
+ return "", TokenLocationNotFound
+}
diff --git a/app/authentication_test.go b/app/authentication_test.go
new file mode 100644
index 000000000..f3014b1b8
--- /dev/null
+++ b/app/authentication_test.go
@@ -0,0 +1,52 @@
+// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseAuthTokenFromRequest(t *testing.T) {
+ cases := []struct {
+ header string
+ cookie string
+ query string
+ expectedToken string
+ expectedLocation TokenLocation
+ }{
+ {"", "", "", "", TokenLocationNotFound},
+ {"token mytoken", "", "", "mytoken", TokenLocationHeader},
+ {"BEARER mytoken", "", "", "mytoken", TokenLocationHeader},
+ {"", "mytoken", "", "mytoken", TokenLocationCookie},
+ {"", "", "mytoken", "mytoken", TokenLocationQueryString},
+ }
+
+ for testnum, tc := range cases {
+ pathname := "/test/here"
+ if tc.query != "" {
+ pathname += "?access_token=" + tc.query
+ }
+ req := httptest.NewRequest("GET", pathname, nil)
+ if tc.header != "" {
+ req.Header.Add(model.HEADER_AUTH, tc.header)
+ }
+ if tc.cookie != "" {
+ req.AddCookie(&http.Cookie{
+ Name: model.SESSION_COOKIE_TOKEN,
+ Value: tc.cookie,
+ })
+ }
+
+ token, location := ParseAuthTokenFromRequest(req)
+
+ require.Equal(t, tc.expectedToken, token, "Wrong token on test "+strconv.Itoa(testnum))
+ require.Equal(t, tc.expectedLocation, location, "Wrong location on test "+strconv.Itoa(testnum))
+ }
+}
diff --git a/app/ratelimit.go b/app/ratelimit.go
new file mode 100644
index 000000000..460088598
--- /dev/null
+++ b/app/ratelimit.go
@@ -0,0 +1,131 @@
+// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "math"
+ "net/http"
+ "strconv"
+ "strings"
+
+ l4g "github.com/alecthomas/log4go"
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/mattermost/mattermost-server/utils"
+ throttled "gopkg.in/throttled/throttled.v2"
+ "gopkg.in/throttled/throttled.v2/store/memstore"
+)
+
+type RateLimiter struct {
+ throttledRateLimiter *throttled.GCRARateLimiter
+ useAuth bool
+ useIP bool
+ header string
+}
+
+func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter {
+ store, err := memstore.New(*settings.MemoryStoreSize)
+ if err != nil {
+ l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store"))
+ return nil
+ }
+
+ quota := throttled.RateQuota{
+ MaxRate: throttled.PerSec(*settings.PerSec),
+ MaxBurst: *settings.MaxBurst,
+ }
+
+ throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
+ if err != nil {
+ l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter"))
+ return nil
+ }
+
+ return &RateLimiter{
+ throttledRateLimiter: throttledRateLimiter,
+ useAuth: *settings.VaryByUser,
+ useIP: *settings.VaryByRemoteAddr,
+ header: settings.VaryByHeader,
+ }
+}
+
+func (rl *RateLimiter) GenerateKey(r *http.Request) string {
+ key := ""
+
+ if rl.useAuth {
+ token, tokenLocation := ParseAuthTokenFromRequest(r)
+ if tokenLocation != TokenLocationNotFound {
+ key += token
+ } else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
+ key += utils.GetIpAddress(r)
+ }
+ } else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
+ key += utils.GetIpAddress(r)
+ }
+
+ // Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
+ // most common headers anyway.
+ if rl.header != "" {
+ key += strings.ToLower(r.Header.Get(rl.header))
+ }
+
+ return key
+}
+
+func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
+ limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
+ if err != nil {
+ l4g.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
+ return false
+ }
+
+ setRateLimitHeaders(w, context)
+
+ if limited {
+ l4g.Error("Denied due to throttling settings code=429 key=%v", key)
+ http.Error(w, "limit exceeded", 429)
+ }
+
+ return limited
+}
+
+func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
+ if rl.useAuth {
+ if rl.RateLimitWriter(userId, w) {
+ return true
+ }
+ }
+ return false
+}
+
+func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ key := rl.GenerateKey(r)
+ limited := rl.RateLimitWriter(key, w)
+
+ if !limited {
+ wrappedHandler.ServeHTTP(w, r)
+ }
+ })
+}
+
+// Copied from https://github.com/throttled/throttled http.go
+func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
+ if v := context.Limit; v >= 0 {
+ w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
+ }
+
+ if v := context.Remaining; v >= 0 {
+ w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
+ }
+
+ if v := context.ResetAfter; v >= 0 {
+ vi := int(math.Ceil(v.Seconds()))
+ w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
+ }
+
+ if v := context.RetryAfter; v >= 0 {
+ vi := int(math.Ceil(v.Seconds()))
+ w.Header().Add("Retry-After", strconv.Itoa(vi))
+ }
+}
diff --git a/app/ratelimit_test.go b/app/ratelimit_test.go
new file mode 100644
index 000000000..ddaa25710
--- /dev/null
+++ b/app/ratelimit_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/stretchr/testify/require"
+)
+
+func genRateLimitSettings(useAuth, useIP bool, header string) *model.RateLimitSettings {
+ return &model.RateLimitSettings{
+ Enable: model.NewBool(true),
+ PerSec: model.NewInt(10),
+ MaxBurst: model.NewInt(100),
+ MemoryStoreSize: model.NewInt(10000),
+ VaryByRemoteAddr: model.NewBool(useIP),
+ VaryByUser: model.NewBool(useAuth),
+ VaryByHeader: header,
+ }
+}
+
+func TestGenerateKey(t *testing.T) {
+ cases := []struct {
+ useAuth bool
+ useIP bool
+ header string
+ authTokenResult string
+ ipResult string
+ headerResult string
+ expectedKey string
+ }{
+ {false, false, "", "", "", "", ""},
+ {true, false, "", "resultkey", "notme", "notme", "resultkey"},
+ {false, true, "", "notme", "resultkey", "notme", "resultkey"},
+ {false, false, "myheader", "notme", "notme", "resultkey", "resultkey"},
+ {true, true, "", "resultkey", "ipaddr", "notme", "resultkey"},
+ {true, true, "", "", "ipaddr", "notme", "ipaddr"},
+ {true, true, "myheader", "resultkey", "ipaddr", "hadd", "resultkeyhadd"},
+ {true, true, "myheader", "", "ipaddr", "hadd", "ipaddrhadd"},
+ }
+
+ for testnum, tc := range cases {
+ req := httptest.NewRequest("GET", "/", nil)
+ if tc.authTokenResult != "" {
+ req.AddCookie(&http.Cookie{
+ Name: model.SESSION_COOKIE_TOKEN,
+ Value: tc.authTokenResult,
+ })
+ }
+ req.RemoteAddr = tc.ipResult + ":80"
+ if tc.headerResult != "" {
+ req.Header.Set(tc.header, tc.headerResult)
+ }
+
+ rateLimiter := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header))
+
+ key := rateLimiter.GenerateKey(req)
+
+ require.Equal(t, tc.expectedKey, key, "Wrong key on test "+strconv.Itoa(testnum))
+ }
+}
diff --git a/app/server.go b/app/server.go
index c008da3a1..2a94bf2c7 100644
--- a/app/server.go
+++ b/app/server.go
@@ -18,8 +18,6 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"golang.org/x/crypto/acme/autocert"
- "gopkg.in/throttled/throttled.v2"
- "gopkg.in/throttled/throttled.v2/store/memstore"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/store"
@@ -32,6 +30,7 @@ type Server struct {
Router *mux.Router
Server *http.Server
ListenAddr *net.TCPAddr
+ RateLimiter *RateLimiter
didFinishListen chan struct{}
}
@@ -84,10 +83,26 @@ func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
const TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN = time.Second
-type VaryBy struct{}
+type VaryBy struct {
+ useIP bool
+ useAuth bool
+}
func (m *VaryBy) Key(r *http.Request) string {
- return utils.GetIpAddress(r)
+ key := ""
+
+ if m.useAuth {
+ token, tokenLocation := ParseAuthTokenFromRequest(r)
+ if tokenLocation != TokenLocationNotFound {
+ key += token
+ } else if m.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
+ key += utils.GetIpAddress(r)
+ }
+ } else if m.useIP { // Only if Auth based is not enabed do we use a plain IP based
+ key = utils.GetIpAddress(r)
+ }
+
+ return key
}
func redirectHTTPToHTTPS(w http.ResponseWriter, r *http.Request) {
@@ -109,33 +124,9 @@ func (a *App) StartServer() {
if *a.Config().RateLimitSettings.Enable {
l4g.Info(utils.T("api.server.start_server.rate.info"))
- store, err := memstore.New(*a.Config().RateLimitSettings.MemoryStoreSize)
- if err != nil {
- l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store"))
- return
- }
-
- quota := throttled.RateQuota{
- MaxRate: throttled.PerSec(*a.Config().RateLimitSettings.PerSec),
- MaxBurst: *a.Config().RateLimitSettings.MaxBurst,
- }
-
- rateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
- if err != nil {
- l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter"))
- return
- }
-
- httpRateLimiter := throttled.HTTPRateLimiter{
- RateLimiter: rateLimiter,
- VaryBy: &VaryBy{},
- DeniedHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- l4g.Error("%v: Denied due to throttling settings code=429 ip=%v", r.URL.Path, utils.GetIpAddress(r))
- throttled.DefaultDeniedHandler.ServeHTTP(w, r)
- }),
- }
+ a.Srv.RateLimiter = NewRateLimiter(&a.Config().RateLimitSettings)
- handler = httpRateLimiter.RateLimit(handler)
+ handler = a.Srv.RateLimiter.RateLimitHandler(handler)
}
a.Srv.Server = &http.Server{