summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--api/websocket_test.go9
-rw-r--r--app/server.go5
-rw-r--r--utils/api.go10
3 files changed, 20 insertions, 4 deletions
diff --git a/api/websocket_test.go b/api/websocket_test.go
index a65ebc02e..18e1a6426 100644
--- a/api/websocket_test.go
+++ b/api/websocket_test.go
@@ -362,6 +362,15 @@ func TestWebsocketOriginSecurity(t *testing.T) {
t.Fatal("Should have errored because Origin contain AllowCorsFrom")
}
+ // Should fail because non-matching CORS
+ *utils.Cfg.ServiceSettings.AllowCorsFrom = "http://www.good.com"
+ _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
+ "Origin": []string{"http://www.good.co"},
+ })
+ if err == nil {
+ t.Fatal("Should have errored because Origin does not match host! SECURITY ISSUE!")
+ }
+
*utils.Cfg.ServiceSettings.AllowCorsFrom = ""
}
diff --git a/app/server.go b/app/server.go
index a5090a597..a5b2dbda9 100644
--- a/app/server.go
+++ b/app/server.go
@@ -53,9 +53,8 @@ type CorsWrapper struct {
func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(*utils.Cfg.ServiceSettings.AllowCorsFrom) > 0 {
- origin := r.Header.Get("Origin")
- if *utils.Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(*utils.Cfg.ServiceSettings.AllowCorsFrom, origin) {
- w.Header().Set("Access-Control-Allow-Origin", origin)
+ if utils.OriginChecker(r) {
+ w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
if r.Method == "OPTIONS" {
w.Header().Set(
diff --git a/utils/api.go b/utils/api.go
index 663f53c16..d175e0c13 100644
--- a/utils/api.go
+++ b/utils/api.go
@@ -15,7 +15,15 @@ type OriginCheckerProc func(*http.Request) bool
func OriginChecker(r *http.Request) bool {
origin := r.Header.Get("Origin")
- return *Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(*Cfg.ServiceSettings.AllowCorsFrom, origin)
+ if *Cfg.ServiceSettings.AllowCorsFrom == "*" {
+ return true
+ }
+ for _, allowed := range strings.Split(*Cfg.ServiceSettings.AllowCorsFrom, " ") {
+ if allowed == origin {
+ return true
+ }
+ }
+ return false
}
func GetOriginChecker(r *http.Request) OriginCheckerProc {