From 120f5a6f8a5f4ab05aace89ae710698cf68d0564 Mon Sep 17 00:00:00 2001 From: Brad Howes Date: Thu, 23 Mar 2017 14:10:52 +0100 Subject: Websocket CORS Support (#5667) * Second attept at patching api/websocket.go for CORS support. * Missing include * Fixed whitespace formatting so that gofmt passes. * Added tests for CORS filtering --- api/websocket.go | 17 ++++++++++++++++- api/websocket_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) (limited to 'api') diff --git a/api/websocket.go b/api/websocket.go index 5c0858910..2de9abb0a 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -5,6 +5,7 @@ package api import ( "net/http" + "strings" l4g "github.com/alecthomas/log4go" "github.com/gorilla/websocket" @@ -19,11 +20,25 @@ func InitWebSocket() { app.HubStart() } +type OriginCheckerProc func(*http.Request) bool + +func OriginChecker(r *http.Request) bool { + origin := r.Header.Get("Origin") + return *utils.Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(origin, *utils.Cfg.ServiceSettings.AllowCorsFrom) +} + func connect(c *Context, w http.ResponseWriter, r *http.Request) { + + var originChecker OriginCheckerProc = nil + + if len(*utils.Cfg.ServiceSettings.AllowCorsFrom) > 0 { + originChecker = OriginChecker + } + upgrader := websocket.Upgrader{ ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, - CheckOrigin: nil, + CheckOrigin: originChecker, } ws, err := upgrader.Upgrade(w, r, nil) diff --git a/api/websocket_test.go b/api/websocket_test.go index ab2959b03..d3d8fc4b2 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -316,6 +316,7 @@ func TestCreateDirectChannelWithSocket(t *testing.T) { func TestWebsocketOriginSecurity(t *testing.T) { Setup().InitBasic() + url := "ws://localhost" + utils.Cfg.ServiceSettings.ListenAddress // Should fail because origin doesn't match @@ -333,6 +334,35 @@ func TestWebsocketOriginSecurity(t *testing.T) { if err != nil { t.Fatal(err) } + + // Should succeed now because open CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "*" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.evil.com"}, + }) + if err != nil { + t.Fatal(err) + } + + // Should succeed now because matching CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "www.evil.com" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.evil.com"}, + }) + if err != nil { + t.Fatal(err) + } + + // Should fail because non-matching CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "www.good.com" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.evil.com"}, + }) + if err == nil { + t.Fatal("Should have errored because Origin contain AllowCorsFrom") + } + + *utils.Cfg.ServiceSettings.AllowCorsFrom = "" } func TestZZWebSocketTearDown(t *testing.T) { -- cgit v1.2.3-1-g7c22