summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--api/context.go45
-rw-r--r--api4/context.go45
-rw-r--r--app/authentication.go47
-rw-r--r--app/authentication_test.go52
-rw-r--r--app/ratelimit.go131
-rw-r--r--app/ratelimit_test.go67
-rw-r--r--app/server.go51
-rw-r--r--config/default.json1
-rw-r--r--model/config.go11
9 files changed, 355 insertions, 95 deletions
diff --git a/api/context.go b/api/context.go
index 34a87e633..84967659d 100644
--- a/api/context.go
+++ b/api/context.go
@@ -114,38 +114,14 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
metrics.IncrementHttpRequest()
}
- token := ""
- isTokenFromQueryString := false
-
- // Attempt to parse token out of the header
- authHeader := r.Header.Get(model.HEADER_AUTH)
- if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER {
- // Default session token
- token = authHeader[7:]
-
- } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN {
- // OAuth token
- token = authHeader[6:]
- }
-
- // Attempt to parse the token from the cookie
- if len(token) == 0 {
- if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil {
- token = cookie.Value
-
- if (h.requireSystemAdmin || h.requireUser) && !h.trustRequester {
- if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML {
- c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
- token = ""
- }
- }
- }
- }
+ token, tokenLocation := app.ParseAuthTokenFromRequest(r)
- // Attempt to parse token out of the query string
- if len(token) == 0 {
- token = r.URL.Query().Get("access_token")
- isTokenFromQueryString = true
+ // CSRF Check
+ if tokenLocation == app.TokenLocationCookie && (h.requireSystemAdmin || h.requireUser) && !h.trustRequester {
+ if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML {
+ c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
+ token = ""
+ }
}
c.SetSiteURLHeader(app.GetProtocol(r) + "://" + r.Host)
@@ -175,11 +151,16 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.requireUser || h.requireSystemAdmin {
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized)
}
- } else if !session.IsOAuth && isTokenFromQueryString {
+ } else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString {
c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized)
} else {
c.Session = *session
}
+
+ // Rate limit by UserID
+ if c.App.Srv.RateLimiter != nil && c.App.Srv.RateLimiter.UserIdRateLimit(c.Session.UserId, w) {
+ return
+ }
}
if h.isApi || h.isTeamIndependent {
diff --git a/api4/context.go b/api4/context.go
index b10ea7a9b..980897062 100644
--- a/api4/context.go
+++ b/api4/context.go
@@ -99,38 +99,14 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.IpAddress = utils.GetIpAddress(r)
c.Params = ApiParamsFromRequest(r)
- token := ""
- isTokenFromQueryString := false
-
- // Attempt to parse token out of the header
- authHeader := r.Header.Get(model.HEADER_AUTH)
- if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER {
- // Default session token
- token = authHeader[7:]
-
- } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN {
- // OAuth token
- token = authHeader[6:]
- }
-
- // Attempt to parse the token from the cookie
- if len(token) == 0 {
- if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil {
- token = cookie.Value
-
- if h.requireSession && !h.trustRequester {
- if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML {
- c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
- token = ""
- }
- }
- }
- }
+ token, tokenLocation := app.ParseAuthTokenFromRequest(r)
- // Attempt to parse token out of the query string
- if len(token) == 0 {
- token = r.URL.Query().Get("access_token")
- isTokenFromQueryString = true
+ // CSRF Check
+ if tokenLocation == app.TokenLocationCookie && h.requireSession && !h.trustRequester {
+ if r.Header.Get(model.HEADER_REQUESTED_WITH) != model.HEADER_REQUESTED_WITH_XML {
+ c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
+ token = ""
+ }
}
c.SetSiteURLHeader(app.GetProtocol(r) + "://" + r.Host)
@@ -153,11 +129,16 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.requireSession {
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized)
}
- } else if !session.IsOAuth && isTokenFromQueryString {
+ } else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString {
c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized)
} else {
c.Session = *session
}
+
+ // Rate limit by UserID
+ if c.App.Srv.RateLimiter != nil && c.App.Srv.RateLimiter.UserIdRateLimit(c.Session.UserId, w) {
+ return
+ }
}
c.Path = r.URL.Path
diff --git a/app/authentication.go b/app/authentication.go
index 91e3bf564..140bffd5a 100644
--- a/app/authentication.go
+++ b/app/authentication.go
@@ -11,6 +11,30 @@ import (
"github.com/mattermost/mattermost-server/utils"
)
+type TokenLocation int
+
+const (
+ TokenLocationNotFound = iota
+ TokenLocationHeader
+ TokenLocationCookie
+ TokenLocationQueryString
+)
+
+func (tl TokenLocation) String() string {
+ switch tl {
+ case TokenLocationNotFound:
+ return "Not Found"
+ case TokenLocationHeader:
+ return "Header"
+ case TokenLocationCookie:
+ return "Cookie"
+ case TokenLocationQueryString:
+ return "QueryString"
+ default:
+ return "Unknown"
+ }
+}
+
func (a *App) IsPasswordValid(password string) *model.AppError {
if utils.IsLicensed() && *utils.License().Features.PasswordRequirements {
return utils.IsPasswordValidWithSettings(password, &a.Config().PasswordSettings)
@@ -168,3 +192,26 @@ func (a *App) authenticateUser(user *model.User, password, mfaToken string) (*mo
}
}
}
+
+func ParseAuthTokenFromRequest(r *http.Request) (string, TokenLocation) {
+ authHeader := r.Header.Get(model.HEADER_AUTH)
+ if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER {
+ // Default session token
+ return authHeader[7:], TokenLocationHeader
+ } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN {
+ // OAuth token
+ return authHeader[6:], TokenLocationHeader
+ }
+
+ // Attempt to parse the token from the cookie
+ if cookie, err := r.Cookie(model.SESSION_COOKIE_TOKEN); err == nil {
+ return cookie.Value, TokenLocationCookie
+ }
+
+ // Attempt to parse token out of the query string
+ if token := r.URL.Query().Get("access_token"); token != "" {
+ return token, TokenLocationQueryString
+ }
+
+ return "", TokenLocationNotFound
+}
diff --git a/app/authentication_test.go b/app/authentication_test.go
new file mode 100644
index 000000000..f3014b1b8
--- /dev/null
+++ b/app/authentication_test.go
@@ -0,0 +1,52 @@
+// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseAuthTokenFromRequest(t *testing.T) {
+ cases := []struct {
+ header string
+ cookie string
+ query string
+ expectedToken string
+ expectedLocation TokenLocation
+ }{
+ {"", "", "", "", TokenLocationNotFound},
+ {"token mytoken", "", "", "mytoken", TokenLocationHeader},
+ {"BEARER mytoken", "", "", "mytoken", TokenLocationHeader},
+ {"", "mytoken", "", "mytoken", TokenLocationCookie},
+ {"", "", "mytoken", "mytoken", TokenLocationQueryString},
+ }
+
+ for testnum, tc := range cases {
+ pathname := "/test/here"
+ if tc.query != "" {
+ pathname += "?access_token=" + tc.query
+ }
+ req := httptest.NewRequest("GET", pathname, nil)
+ if tc.header != "" {
+ req.Header.Add(model.HEADER_AUTH, tc.header)
+ }
+ if tc.cookie != "" {
+ req.AddCookie(&http.Cookie{
+ Name: model.SESSION_COOKIE_TOKEN,
+ Value: tc.cookie,
+ })
+ }
+
+ token, location := ParseAuthTokenFromRequest(req)
+
+ require.Equal(t, tc.expectedToken, token, "Wrong token on test "+strconv.Itoa(testnum))
+ require.Equal(t, tc.expectedLocation, location, "Wrong location on test "+strconv.Itoa(testnum))
+ }
+}
diff --git a/app/ratelimit.go b/app/ratelimit.go
new file mode 100644
index 000000000..460088598
--- /dev/null
+++ b/app/ratelimit.go
@@ -0,0 +1,131 @@
+// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "math"
+ "net/http"
+ "strconv"
+ "strings"
+
+ l4g "github.com/alecthomas/log4go"
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/mattermost/mattermost-server/utils"
+ throttled "gopkg.in/throttled/throttled.v2"
+ "gopkg.in/throttled/throttled.v2/store/memstore"
+)
+
+type RateLimiter struct {
+ throttledRateLimiter *throttled.GCRARateLimiter
+ useAuth bool
+ useIP bool
+ header string
+}
+
+func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter {
+ store, err := memstore.New(*settings.MemoryStoreSize)
+ if err != nil {
+ l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store"))
+ return nil
+ }
+
+ quota := throttled.RateQuota{
+ MaxRate: throttled.PerSec(*settings.PerSec),
+ MaxBurst: *settings.MaxBurst,
+ }
+
+ throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
+ if err != nil {
+ l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter"))
+ return nil
+ }
+
+ return &RateLimiter{
+ throttledRateLimiter: throttledRateLimiter,
+ useAuth: *settings.VaryByUser,
+ useIP: *settings.VaryByRemoteAddr,
+ header: settings.VaryByHeader,
+ }
+}
+
+func (rl *RateLimiter) GenerateKey(r *http.Request) string {
+ key := ""
+
+ if rl.useAuth {
+ token, tokenLocation := ParseAuthTokenFromRequest(r)
+ if tokenLocation != TokenLocationNotFound {
+ key += token
+ } else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
+ key += utils.GetIpAddress(r)
+ }
+ } else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
+ key += utils.GetIpAddress(r)
+ }
+
+ // Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
+ // most common headers anyway.
+ if rl.header != "" {
+ key += strings.ToLower(r.Header.Get(rl.header))
+ }
+
+ return key
+}
+
+func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
+ limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
+ if err != nil {
+ l4g.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
+ return false
+ }
+
+ setRateLimitHeaders(w, context)
+
+ if limited {
+ l4g.Error("Denied due to throttling settings code=429 key=%v", key)
+ http.Error(w, "limit exceeded", 429)
+ }
+
+ return limited
+}
+
+func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
+ if rl.useAuth {
+ if rl.RateLimitWriter(userId, w) {
+ return true
+ }
+ }
+ return false
+}
+
+func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ key := rl.GenerateKey(r)
+ limited := rl.RateLimitWriter(key, w)
+
+ if !limited {
+ wrappedHandler.ServeHTTP(w, r)
+ }
+ })
+}
+
+// Copied from https://github.com/throttled/throttled http.go
+func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
+ if v := context.Limit; v >= 0 {
+ w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
+ }
+
+ if v := context.Remaining; v >= 0 {
+ w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
+ }
+
+ if v := context.ResetAfter; v >= 0 {
+ vi := int(math.Ceil(v.Seconds()))
+ w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
+ }
+
+ if v := context.RetryAfter; v >= 0 {
+ vi := int(math.Ceil(v.Seconds()))
+ w.Header().Add("Retry-After", strconv.Itoa(vi))
+ }
+}
diff --git a/app/ratelimit_test.go b/app/ratelimit_test.go
new file mode 100644
index 000000000..ddaa25710
--- /dev/null
+++ b/app/ratelimit_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/stretchr/testify/require"
+)
+
+func genRateLimitSettings(useAuth, useIP bool, header string) *model.RateLimitSettings {
+ return &model.RateLimitSettings{
+ Enable: model.NewBool(true),
+ PerSec: model.NewInt(10),
+ MaxBurst: model.NewInt(100),
+ MemoryStoreSize: model.NewInt(10000),
+ VaryByRemoteAddr: model.NewBool(useIP),
+ VaryByUser: model.NewBool(useAuth),
+ VaryByHeader: header,
+ }
+}
+
+func TestGenerateKey(t *testing.T) {
+ cases := []struct {
+ useAuth bool
+ useIP bool
+ header string
+ authTokenResult string
+ ipResult string
+ headerResult string
+ expectedKey string
+ }{
+ {false, false, "", "", "", "", ""},
+ {true, false, "", "resultkey", "notme", "notme", "resultkey"},
+ {false, true, "", "notme", "resultkey", "notme", "resultkey"},
+ {false, false, "myheader", "notme", "notme", "resultkey", "resultkey"},
+ {true, true, "", "resultkey", "ipaddr", "notme", "resultkey"},
+ {true, true, "", "", "ipaddr", "notme", "ipaddr"},
+ {true, true, "myheader", "resultkey", "ipaddr", "hadd", "resultkeyhadd"},
+ {true, true, "myheader", "", "ipaddr", "hadd", "ipaddrhadd"},
+ }
+
+ for testnum, tc := range cases {
+ req := httptest.NewRequest("GET", "/", nil)
+ if tc.authTokenResult != "" {
+ req.AddCookie(&http.Cookie{
+ Name: model.SESSION_COOKIE_TOKEN,
+ Value: tc.authTokenResult,
+ })
+ }
+ req.RemoteAddr = tc.ipResult + ":80"
+ if tc.headerResult != "" {
+ req.Header.Set(tc.header, tc.headerResult)
+ }
+
+ rateLimiter := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header))
+
+ key := rateLimiter.GenerateKey(req)
+
+ require.Equal(t, tc.expectedKey, key, "Wrong key on test "+strconv.Itoa(testnum))
+ }
+}
diff --git a/app/server.go b/app/server.go
index c008da3a1..2a94bf2c7 100644
--- a/app/server.go
+++ b/app/server.go
@@ -18,8 +18,6 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"golang.org/x/crypto/acme/autocert"
- "gopkg.in/throttled/throttled.v2"
- "gopkg.in/throttled/throttled.v2/store/memstore"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/store"
@@ -32,6 +30,7 @@ type Server struct {
Router *mux.Router
Server *http.Server
ListenAddr *net.TCPAddr
+ RateLimiter *RateLimiter
didFinishListen chan struct{}
}
@@ -84,10 +83,26 @@ func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
const TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN = time.Second
-type VaryBy struct{}
+type VaryBy struct {
+ useIP bool
+ useAuth bool
+}
func (m *VaryBy) Key(r *http.Request) string {
- return utils.GetIpAddress(r)
+ key := ""
+
+ if m.useAuth {
+ token, tokenLocation := ParseAuthTokenFromRequest(r)
+ if tokenLocation != TokenLocationNotFound {
+ key += token
+ } else if m.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
+ key += utils.GetIpAddress(r)
+ }
+ } else if m.useIP { // Only if Auth based is not enabed do we use a plain IP based
+ key = utils.GetIpAddress(r)
+ }
+
+ return key
}
func redirectHTTPToHTTPS(w http.ResponseWriter, r *http.Request) {
@@ -109,33 +124,9 @@ func (a *App) StartServer() {
if *a.Config().RateLimitSettings.Enable {
l4g.Info(utils.T("api.server.start_server.rate.info"))
- store, err := memstore.New(*a.Config().RateLimitSettings.MemoryStoreSize)
- if err != nil {
- l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store"))
- return
- }
-
- quota := throttled.RateQuota{
- MaxRate: throttled.PerSec(*a.Config().RateLimitSettings.PerSec),
- MaxBurst: *a.Config().RateLimitSettings.MaxBurst,
- }
-
- rateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
- if err != nil {
- l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter"))
- return
- }
-
- httpRateLimiter := throttled.HTTPRateLimiter{
- RateLimiter: rateLimiter,
- VaryBy: &VaryBy{},
- DeniedHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- l4g.Error("%v: Denied due to throttling settings code=429 ip=%v", r.URL.Path, utils.GetIpAddress(r))
- throttled.DefaultDeniedHandler.ServeHTTP(w, r)
- }),
- }
+ a.Srv.RateLimiter = NewRateLimiter(&a.Config().RateLimitSettings)
- handler = httpRateLimiter.RateLimit(handler)
+ handler = a.Srv.RateLimiter.RateLimitHandler(handler)
}
a.Srv.Server = &http.Server{
diff --git a/config/default.json b/config/default.json
index 4bcccdf19..e0fda74cd 100644
--- a/config/default.json
+++ b/config/default.json
@@ -180,6 +180,7 @@
"MaxBurst": 100,
"MemoryStoreSize": 10000,
"VaryByRemoteAddr": true,
+ "VaryByUser": false,
"VaryByHeader": ""
},
"PrivacySettings": {
diff --git a/model/config.go b/model/config.go
index 525fc71ed..b7888ab13 100644
--- a/model/config.go
+++ b/model/config.go
@@ -802,7 +802,8 @@ type RateLimitSettings struct {
PerSec *int
MaxBurst *int
MemoryStoreSize *int
- VaryByRemoteAddr bool
+ VaryByRemoteAddr *bool
+ VaryByUser *bool
VaryByHeader string
}
@@ -822,6 +823,14 @@ func (s *RateLimitSettings) SetDefaults() {
if s.MemoryStoreSize == nil {
s.MemoryStoreSize = NewInt(10000)
}
+
+ if s.VaryByRemoteAddr == nil {
+ s.VaryByRemoteAddr = NewBool(true)
+ }
+
+ if s.VaryByUser == nil {
+ s.VaryByUser = NewBool(false)
+ }
}
type PrivacySettings struct {