summaryrefslogtreecommitdiffstats
path: root/api/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'api/server.go')
-rw-r--r--api/server.go67
1 files changed, 48 insertions, 19 deletions
diff --git a/api/server.go b/api/server.go
index 6e0ca49f0..b16ad6e8e 100644
--- a/api/server.go
+++ b/api/server.go
@@ -10,8 +10,8 @@ import (
"github.com/gorilla/mux"
"github.com/mattermost/platform/store"
"github.com/mattermost/platform/utils"
- "gopkg.in/throttled/throttled.v1"
- throttledStore "gopkg.in/throttled/throttled.v1/store"
+ "gopkg.in/throttled/throttled.v2"
+ "gopkg.in/throttled/throttled.v2/store/memstore"
"net/http"
"strings"
"time"
@@ -39,6 +39,31 @@ func NewServer() {
Srv.Router.NotFoundHandler = http.HandlerFunc(Handle404)
}
+type VaryBy struct{}
+
+func (m *VaryBy) Key(r *http.Request) string {
+ return GetIpAddress(r)
+}
+
+func initalizeThrottledVaryBy() *throttled.VaryBy {
+ vary := throttled.VaryBy{}
+
+ if utils.Cfg.RateLimitSettings.VaryByRemoteAddr {
+ vary.RemoteAddr = true
+ }
+
+ if len(utils.Cfg.RateLimitSettings.VaryByHeader) > 0 {
+ vary.Headers = strings.Fields(utils.Cfg.RateLimitSettings.VaryByHeader)
+
+ if utils.Cfg.RateLimitSettings.VaryByRemoteAddr {
+ l4g.Warn(utils.T("api.server.start_server.rate.warn"))
+ vary.RemoteAddr = false
+ }
+ }
+
+ return &vary
+}
+
func StartServer() {
l4g.Info(utils.T("api.server.start_server.starting.info"))
l4g.Info(utils.T("api.server.start_server.listening.info"), utils.Cfg.ServiceSettings.ListenAddress)
@@ -48,29 +73,33 @@ func StartServer() {
if utils.Cfg.RateLimitSettings.EnableRateLimiter {
l4g.Info(utils.T("api.server.start_server.rate.info"))
- vary := throttled.VaryBy{}
-
- if utils.Cfg.RateLimitSettings.VaryByRemoteAddr {
- vary.RemoteAddr = true
+ store, err := memstore.New(utils.Cfg.RateLimitSettings.MemoryStoreSize)
+ if err != nil {
+ l4g.Critical(utils.T("api.server.start_server.rate_limiting_memory_store"))
+ return
}
- if len(utils.Cfg.RateLimitSettings.VaryByHeader) > 0 {
- vary.Headers = strings.Fields(utils.Cfg.RateLimitSettings.VaryByHeader)
-
- if utils.Cfg.RateLimitSettings.VaryByRemoteAddr {
- l4g.Warn(utils.T("api.server.start_server.rate.warn"))
- vary.RemoteAddr = false
- }
+ quota := throttled.RateQuota{
+ MaxRate: throttled.PerSec(utils.Cfg.RateLimitSettings.PerSec),
+ MaxBurst: 100,
}
- th := throttled.RateLimit(throttled.PerSec(utils.Cfg.RateLimitSettings.PerSec), &vary, throttledStore.NewMemStore(utils.Cfg.RateLimitSettings.MemoryStoreSize))
+ rateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
+ if err != nil {
+ l4g.Critical(utils.T("api.server.start_server.rate_limiting_rate_limiter"))
+ return
+ }
- th.DeniedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- l4g.Error("%v: code=429 ip=%v", r.URL.Path, GetIpAddress(r))
- throttled.DefaultDeniedHandler.ServeHTTP(w, r)
- })
+ httpRateLimiter := throttled.HTTPRateLimiter{
+ RateLimiter: rateLimiter,
+ VaryBy: &VaryBy{},
+ DeniedHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ l4g.Error("%v: Denied due to throttling settings code=429 ip=%v", r.URL.Path, GetIpAddress(r))
+ throttled.DefaultDeniedHandler.ServeHTTP(w, r)
+ }),
+ }
- handler = th.Throttle(&CorsWrapper{Srv.Router})
+ handler = httpRateLimiter.RateLimit(handler)
}
go func() {