diff options
author | Brad Howes <bradhowes@mac.com> | 2017-03-23 14:10:52 +0100 |
---|---|---|
committer | Christopher Speller <crspeller@gmail.com> | 2017-03-23 09:10:52 -0400 |
commit | 120f5a6f8a5f4ab05aace89ae710698cf68d0564 (patch) | |
tree | 1a787c7306c5f03f5d873e0ddb129d906e5487b9 | |
parent | 34cb70d005ba5ebea7398646db6e242baa81b701 (diff) | |
download | chat-120f5a6f8a5f4ab05aace89ae710698cf68d0564.tar.gz chat-120f5a6f8a5f4ab05aace89ae710698cf68d0564.tar.bz2 chat-120f5a6f8a5f4ab05aace89ae710698cf68d0564.zip |
Websocket CORS Support (#5667)
* Second attept at patching api/websocket.go for CORS support.
* Missing include
* Fixed whitespace formatting so that gofmt passes.
* Added tests for CORS filtering
-rw-r--r-- | api/websocket.go | 17 | ||||
-rw-r--r-- | api/websocket_test.go | 30 |
2 files changed, 46 insertions, 1 deletions
diff --git a/api/websocket.go b/api/websocket.go index 5c0858910..2de9abb0a 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -5,6 +5,7 @@ package api import ( "net/http" + "strings" l4g "github.com/alecthomas/log4go" "github.com/gorilla/websocket" @@ -19,11 +20,25 @@ func InitWebSocket() { app.HubStart() } +type OriginCheckerProc func(*http.Request) bool + +func OriginChecker(r *http.Request) bool { + origin := r.Header.Get("Origin") + return *utils.Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(origin, *utils.Cfg.ServiceSettings.AllowCorsFrom) +} + func connect(c *Context, w http.ResponseWriter, r *http.Request) { + + var originChecker OriginCheckerProc = nil + + if len(*utils.Cfg.ServiceSettings.AllowCorsFrom) > 0 { + originChecker = OriginChecker + } + upgrader := websocket.Upgrader{ ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB, - CheckOrigin: nil, + CheckOrigin: originChecker, } ws, err := upgrader.Upgrade(w, r, nil) diff --git a/api/websocket_test.go b/api/websocket_test.go index ab2959b03..d3d8fc4b2 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -316,6 +316,7 @@ func TestCreateDirectChannelWithSocket(t *testing.T) { func TestWebsocketOriginSecurity(t *testing.T) { Setup().InitBasic() + url := "ws://localhost" + utils.Cfg.ServiceSettings.ListenAddress // Should fail because origin doesn't match @@ -333,6 +334,35 @@ func TestWebsocketOriginSecurity(t *testing.T) { if err != nil { t.Fatal(err) } + + // Should succeed now because open CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "*" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.evil.com"}, + }) + if err != nil { + t.Fatal(err) + } + + // Should succeed now because matching CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "www.evil.com" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.evil.com"}, + }) + if err != nil { + t.Fatal(err) + } + + // Should fail because non-matching CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "www.good.com" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.evil.com"}, + }) + if err == nil { + t.Fatal("Should have errored because Origin contain AllowCorsFrom") + } + + *utils.Cfg.ServiceSettings.AllowCorsFrom = "" } func TestZZWebSocketTearDown(t *testing.T) { |