From 5c3c909c8541f26ae09577338d2302bed2a2f3a9 Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 13 Jul 2017 14:02:33 -0700 Subject: Tweak WebSocket header-processing (#6929) * fix * consolidate code --- api/websocket_test.go | 9 +++++++++ app/server.go | 5 ++--- utils/api.go | 10 +++++++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/api/websocket_test.go b/api/websocket_test.go index a65ebc02e..18e1a6426 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -362,6 +362,15 @@ func TestWebsocketOriginSecurity(t *testing.T) { t.Fatal("Should have errored because Origin contain AllowCorsFrom") } + // Should fail because non-matching CORS + *utils.Cfg.ServiceSettings.AllowCorsFrom = "http://www.good.com" + _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{ + "Origin": []string{"http://www.good.co"}, + }) + if err == nil { + t.Fatal("Should have errored because Origin does not match host! SECURITY ISSUE!") + } + *utils.Cfg.ServiceSettings.AllowCorsFrom = "" } diff --git a/app/server.go b/app/server.go index a5090a597..a5b2dbda9 100644 --- a/app/server.go +++ b/app/server.go @@ -53,9 +53,8 @@ type CorsWrapper struct { func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(*utils.Cfg.ServiceSettings.AllowCorsFrom) > 0 { - origin := r.Header.Get("Origin") - if *utils.Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(*utils.Cfg.ServiceSettings.AllowCorsFrom, origin) { - w.Header().Set("Access-Control-Allow-Origin", origin) + if utils.OriginChecker(r) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) if r.Method == "OPTIONS" { w.Header().Set( diff --git a/utils/api.go b/utils/api.go index 663f53c16..d175e0c13 100644 --- a/utils/api.go +++ b/utils/api.go @@ -15,7 +15,15 @@ type OriginCheckerProc func(*http.Request) bool func OriginChecker(r *http.Request) bool { origin := r.Header.Get("Origin") - return *Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(*Cfg.ServiceSettings.AllowCorsFrom, origin) + if *Cfg.ServiceSettings.AllowCorsFrom == "*" { + return true + } + for _, allowed := range strings.Split(*Cfg.ServiceSettings.AllowCorsFrom, " ") { + if allowed == origin { + return true + } + } + return false } func GetOriginChecker(r *http.Request) OriginCheckerProc { -- cgit v1.2.3-1-g7c22