summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--app/ratelimit.go11
-rw-r--r--app/ratelimit_test.go17
-rw-r--r--app/server.go9
3 files changed, 28 insertions, 9 deletions
diff --git a/app/ratelimit.go b/app/ratelimit.go
index 460088598..13508f36f 100644
--- a/app/ratelimit.go
+++ b/app/ratelimit.go
@@ -12,6 +12,7 @@ import (
l4g "github.com/alecthomas/log4go"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/utils"
+ "github.com/pkg/errors"
throttled "gopkg.in/throttled/throttled.v2"
"gopkg.in/throttled/throttled.v2/store/memstore"
)
@@ -23,11 +24,10 @@ type RateLimiter struct {
header string
}
-func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter {
+func NewRateLimiter(settings *model.RateLimitSettings) (*RateLimiter, error) {
store, err := memstore.New(*settings.MemoryStoreSize)
if err != nil {
- l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store"))
- return nil
+ return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
}
quota := throttled.RateQuota{
@@ -37,8 +37,7 @@ func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter {
throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
if err != nil {
- l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter"))
- return nil
+ return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
}
return &RateLimiter{
@@ -46,7 +45,7 @@ func NewRateLimiter(settings *model.RateLimitSettings) *RateLimiter {
useAuth: *settings.VaryByUser,
useIP: *settings.VaryByRemoteAddr,
header: settings.VaryByHeader,
- }
+ }, nil
}
func (rl *RateLimiter) GenerateKey(r *http.Request) string {
diff --git a/app/ratelimit_test.go b/app/ratelimit_test.go
index ddaa25710..fb157b2b0 100644
--- a/app/ratelimit_test.go
+++ b/app/ratelimit_test.go
@@ -25,6 +25,21 @@ func genRateLimitSettings(useAuth, useIP bool, header string) *model.RateLimitSe
}
}
+func TestNewRateLimiterSuccess(t *testing.T) {
+ settings := genRateLimitSettings(false, false, "")
+ rateLimiter, err := NewRateLimiter(settings)
+ require.NotNil(t, rateLimiter)
+ require.NoError(t, err)
+}
+
+func TestNewRateLimiterFailure(t *testing.T) {
+ invalidSettings := genRateLimitSettings(false, false, "")
+ invalidSettings.MaxBurst = model.NewInt(-100)
+ rateLimiter, err := NewRateLimiter(invalidSettings)
+ require.Nil(t, rateLimiter)
+ require.Error(t, err)
+}
+
func TestGenerateKey(t *testing.T) {
cases := []struct {
useAuth bool
@@ -58,7 +73,7 @@ func TestGenerateKey(t *testing.T) {
req.Header.Set(tc.header, tc.headerResult)
}
- rateLimiter := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header))
+ rateLimiter, _ := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header))
key := rateLimiter.GenerateKey(req)
diff --git a/app/server.go b/app/server.go
index 2a94bf2c7..1659908b6 100644
--- a/app/server.go
+++ b/app/server.go
@@ -124,9 +124,14 @@ func (a *App) StartServer() {
if *a.Config().RateLimitSettings.Enable {
l4g.Info(utils.T("api.server.start_server.rate.info"))
- a.Srv.RateLimiter = NewRateLimiter(&a.Config().RateLimitSettings)
+ rateLimiter, err := NewRateLimiter(&a.Config().RateLimitSettings)
+ if err != nil {
+ l4g.Critical(err.Error())
+ return
+ }
- handler = a.Srv.RateLimiter.RateLimitHandler(handler)
+ a.Srv.RateLimiter = rateLimiter
+ handler = rateLimiter.RateLimitHandler(handler)
}
a.Srv.Server = &http.Server{