diff options
-rw-r--r-- | api/websocket.go | 4 | ||||
-rw-r--r-- | api4/websocket.go | 4 | ||||
-rw-r--r-- | app/server.go | 11 | ||||
-rw-r--r-- | utils/api.go | 16 |
4 files changed, 17 insertions, 18 deletions
diff --git a/api/websocket.go b/api/websocket.go index e5e2390c7..518f81d58 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -18,12 +18,10 @@ func (api *API) InitWebSocket() { } func connect(c *Context, w http.ResponseWriter, r *http.Request) { - originChecker := utils.GetOriginChecker(r) - upgrader := websocket.Upgrader{ ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, - CheckOrigin: originChecker, + CheckOrigin: c.App.OriginChecker(), } ws, err := upgrader.Upgrade(w, r, nil) diff --git a/api4/websocket.go b/api4/websocket.go index c148ec3bf..6465e49e9 100644 --- a/api4/websocket.go +++ b/api4/websocket.go @@ -19,12 +19,10 @@ func (api *API) InitWebSocket() { } func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) { - originChecker := utils.GetOriginChecker(r) - upgrader := websocket.Upgrader{ ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, - CheckOrigin: originChecker, + CheckOrigin: c.App.OriginChecker(), } ws, err := upgrader.Upgrade(w, r, nil) diff --git a/app/server.go b/app/server.go index 0739d4989..eb2fa9b32 100644 --- a/app/server.go +++ b/app/server.go @@ -58,8 +58,8 @@ type CorsWrapper struct { } func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if len(*cw.config().ServiceSettings.AllowCorsFrom) > 0 { - if utils.OriginChecker(r) { + if allowed := *cw.config().ServiceSettings.AllowCorsFrom; allowed != "" { + if utils.CheckOrigin(r, allowed) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) if r.Method == "OPTIONS" { @@ -252,6 +252,13 @@ func (a *App) StopServer() { } } +func (a *App) OriginChecker() func(*http.Request) bool { + if allowed := *a.Config().ServiceSettings.AllowCorsFrom; allowed != "" { + return utils.OriginChecker(allowed) + } + return nil +} + // This is required to re-use the underlying connection and not take up file descriptors func consumeAndClose(r *http.Response) { if r.Body != nil { diff --git a/utils/api.go b/utils/api.go index a96e98254..48382d1fe 100644 --- a/utils/api.go +++ b/utils/api.go @@ -11,14 +11,12 @@ import ( "github.com/mattermost/mattermost-server/model" ) -type OriginCheckerProc func(*http.Request) bool - -func OriginChecker(r *http.Request) bool { +func CheckOrigin(r *http.Request, allowedOrigins string) bool { origin := r.Header.Get("Origin") - if *Cfg.ServiceSettings.AllowCorsFrom == "*" { + if allowedOrigins == "*" { return true } - for _, allowed := range strings.Split(*Cfg.ServiceSettings.AllowCorsFrom, " ") { + for _, allowed := range strings.Split(allowedOrigins, " ") { if allowed == origin { return true } @@ -26,12 +24,10 @@ func OriginChecker(r *http.Request) bool { return false } -func GetOriginChecker(r *http.Request) OriginCheckerProc { - if len(*Cfg.ServiceSettings.AllowCorsFrom) > 0 { - return OriginChecker +func OriginChecker(allowedOrigins string) func(*http.Request) bool { + return func(r *http.Request) bool { + return CheckOrigin(r, allowedOrigins) } - - return nil } func RenderWebError(err *model.AppError, w http.ResponseWriter, r *http.Request) { |