From bae26ec268aef4e85d5055f1b83c6b3992bf178f Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Thu, 26 Jul 2018 08:31:22 -0700 Subject: MM-11160 Adding proper CORS support. (#9152) * Adding proper CORS support. * Better CORS tests. --- Gopkg.lock | 8 +- api4/apitestlib.go | 13 +- api4/cors_test.go | 150 +++++++++++++ app/server.go | 55 ++--- config/default.json | 3 + model/config.go | 15 ++ vendor/github.com/rs/cors/.travis.yml | 8 + vendor/github.com/rs/cors/LICENSE | 19 ++ vendor/github.com/rs/cors/README.md | 105 +++++++++ vendor/github.com/rs/cors/cors.go | 407 ++++++++++++++++++++++++++++++++++ vendor/github.com/rs/cors/go.mod | 1 + vendor/github.com/rs/cors/utils.go | 70 ++++++ 12 files changed, 819 insertions(+), 35 deletions(-) create mode 100644 api4/cors_test.go create mode 100644 vendor/github.com/rs/cors/.travis.yml create mode 100644 vendor/github.com/rs/cors/LICENSE create mode 100644 vendor/github.com/rs/cors/README.md create mode 100644 vendor/github.com/rs/cors/cors.go create mode 100644 vendor/github.com/rs/cors/go.mod create mode 100644 vendor/github.com/rs/cors/utils.go diff --git a/Gopkg.lock b/Gopkg.lock index 11b29ae3d..cf6036903 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -456,6 +456,12 @@ ] revision = "ae68e2d4c00fed4943b5f6698d504a5fe083da8a" +[[projects]] + name = "github.com/rs/cors" + packages = ["."] + revision = "ca016a06a5753f8ba03029c0aa5e54afb1bf713f" + version = "v1.4.0" + [[projects]] branch = "go1" name = "github.com/rwcarlsen/goexif" @@ -773,6 +779,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "70c0a23e6c8edc03d3c11a5227b2a43caea33f6170bfde2cf7698dcddf529036" + inputs-digest = "254723dc4684977f7472e728e16f347059481f6b0ef917375f31e23d1b6c6ad0" solver-name = "gps-cdcl" solver-version = 1 diff --git a/api4/apitestlib.go b/api4/apitestlib.go index ff7d47b26..fce44cfa1 100644 --- a/api4/apitestlib.go +++ b/api4/apitestlib.go @@ -75,7 +75,7 @@ func StopTestStore() { } } -func setupTestHelper(enterprise bool) *TestHelper { +func setupTestHelper(enterprise bool, updateConfig func(*model.Config)) *TestHelper { permConfig, err := os.Open(utils.FindConfigFile("config.json")) if err != nil { panic(err) @@ -115,6 +115,9 @@ func setupTestHelper(enterprise bool) *TestHelper { if testStore != nil { th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" }) } + if updateConfig != nil { + th.App.UpdateConfig(updateConfig) + } serverErr := th.App.StartServer() if serverErr != nil { panic(serverErr) @@ -161,11 +164,15 @@ func setupTestHelper(enterprise bool) *TestHelper { } func SetupEnterprise() *TestHelper { - return setupTestHelper(true) + return setupTestHelper(true, nil) } func Setup() *TestHelper { - return setupTestHelper(false) + return setupTestHelper(false, nil) +} + +func SetupConfig(updateConfig func(cfg *model.Config)) *TestHelper { + return setupTestHelper(false, updateConfig) } func (me *TestHelper) TearDown() { diff --git a/api4/cors_test.go b/api4/cors_test.go new file mode 100644 index 000000000..74702b284 --- /dev/null +++ b/api4/cors_test.go @@ -0,0 +1,150 @@ +package api4 + +import ( + "fmt" + "net/http" + "testing" + + "github.com/mattermost/mattermost-server/model" + "github.com/stretchr/testify/assert" +) + +const ( + acAllowOrigin = "Access-Control-Allow-Origin" + acExposeHeaders = "Access-Control-Expose-Headers" + acMaxAge = "Access-Control-Max-Age" + acAllowCredentials = "Access-Control-Allow-Credentials" + acAllowMethods = "Access-Control-Allow-Methods" + acAllowHeaders = "Access-Control-Allow-Headers" +) + +func TestCORSRequestHandling(t *testing.T) { + for name, testcase := range map[string]struct { + AllowCorsFrom string + CorsExposedHeaders string + CorsAllowCredentials bool + ModifyRequest func(req *http.Request) + ExpectedAllowOrigin string + ExpectedExposeHeaders string + ExpectedAllowCredentials string + }{ + "NoCORS": { + "", + "", + false, + func(req *http.Request) { + }, + "", + "", + "", + }, + "CORSEnabled": { + "http://somewhere.com", + "", + false, + func(req *http.Request) { + }, + "", + "", + "", + }, + "CORSEnabledStarOrigin": { + "*", + "", + false, + func(req *http.Request) { + req.Header.Set("Origin", "http://pre-release.mattermost.com") + }, + "*", + "", + "", + }, + "CORSEnabledStarNoOrigin": { // CORS spec requires this, not a bug. + "*", + "", + false, + func(req *http.Request) { + }, + "", + "", + "", + }, + "CORSEnabledMatching": { + "http://mattermost.com", + "", + false, + func(req *http.Request) { + req.Header.Set("Origin", "http://mattermost.com") + }, + "http://mattermost.com", + "", + "", + }, + "CORSEnabledMultiple": { + "http://spinmint.com http://mattermost.com", + "", + false, + func(req *http.Request) { + req.Header.Set("Origin", "http://mattermost.com") + }, + "http://mattermost.com", + "", + "", + }, + "CORSEnabledWithCredentials": { + "http://mattermost.com", + "", + true, + func(req *http.Request) { + req.Header.Set("Origin", "http://mattermost.com") + }, + "http://mattermost.com", + "", + "true", + }, + "CORSEnabledWithHeaders": { + "http://mattermost.com", + "x-my-special-header x-blueberry", + true, + func(req *http.Request) { + req.Header.Set("Origin", "http://mattermost.com") + }, + "http://mattermost.com", + "X-My-Special-Header, X-Blueberry", + "true", + }, + } { + t.Run(name, func(t *testing.T) { + th := SetupConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.AllowCorsFrom = testcase.AllowCorsFrom + *cfg.ServiceSettings.CorsExposedHeaders = testcase.CorsExposedHeaders + *cfg.ServiceSettings.CorsAllowCredentials = testcase.CorsAllowCredentials + }) + defer th.TearDown() + + port := th.App.Srv.ListenAddr.Port + host := fmt.Sprintf("http://localhost:%v", port) + url := fmt.Sprintf("%v/api/v4/system/ping", host) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + testcase.ModifyRequest(req) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, testcase.ExpectedAllowOrigin, resp.Header.Get(acAllowOrigin)) + assert.Equal(t, testcase.ExpectedExposeHeaders, resp.Header.Get(acExposeHeaders)) + assert.Equal(t, "", resp.Header.Get(acMaxAge)) + assert.Equal(t, testcase.ExpectedAllowCredentials, resp.Header.Get(acAllowCredentials)) + assert.Equal(t, "", resp.Header.Get(acAllowMethods)) + assert.Equal(t, "", resp.Header.Get(acAllowHeaders)) + }) + } + +} diff --git a/app/server.go b/app/server.go index 769690295..6b2e244d8 100644 --- a/app/server.go +++ b/app/server.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/pkg/errors" + "github.com/rs/cors" "golang.org/x/crypto/acme/autocert" "github.com/mattermost/mattermost-server/mlog" @@ -44,7 +45,7 @@ type Server struct { didFinishListen chan struct{} } -var allowedMethods []string = []string{ +var corsAllowedMethods []string = []string{ "POST", "GET", "OPTIONS", @@ -61,35 +62,6 @@ func (rl *RecoveryLogger) Println(i ...interface{}) { mlog.Error(fmt.Sprint(i)) } -type CorsWrapper struct { - config model.ConfigFunc - router *mux.Router -} - -func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if allowed := *cw.config().ServiceSettings.AllowCorsFrom; allowed != "" { - if utils.CheckOrigin(r, allowed) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - - if r.Method == "OPTIONS" { - w.Header().Set( - "Access-Control-Allow-Methods", - strings.Join(allowedMethods, ", ")) - - w.Header().Set( - "Access-Control-Allow-Headers", - r.Header.Get("Access-Control-Request-Headers")) - } - } - } - - if r.Method == "OPTIONS" { - return - } - - cw.router.ServeHTTP(w, r) -} - const TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN = time.Second // golang.org/x/crypto/acme/autocert/autocert.go @@ -114,7 +86,28 @@ func stripPort(hostport string) string { func (a *App) StartServer() error { mlog.Info("Starting Server...") - var handler http.Handler = &CorsWrapper{a.Config, a.Srv.RootRouter} + var handler http.Handler = a.Srv.RootRouter + if allowedOrigins := *a.Config().ServiceSettings.AllowCorsFrom; allowedOrigins != "" { + exposedCorsHeaders := *a.Config().ServiceSettings.CorsExposedHeaders + allowCredentials := *a.Config().ServiceSettings.CorsAllowCredentials + debug := *a.Config().ServiceSettings.CorsDebug + corsWrapper := cors.New(cors.Options{ + AllowedOrigins: strings.Fields(allowedOrigins), + AllowedMethods: corsAllowedMethods, + AllowedHeaders: []string{"*"}, + ExposedHeaders: strings.Fields(exposedCorsHeaders), + MaxAge: 86400, + AllowCredentials: allowCredentials, + Debug: debug, + }) + + // If we have debugging of CORS turned on then forward messages to logs + if debug { + corsWrapper.Log = a.Log.StdLog(mlog.String("source", "cors")) + } + + handler = corsWrapper.Handler(handler) + } if *a.Config().RateLimitSettings.Enable { mlog.Info("RateLimiter is enabled") diff --git a/config/default.json b/config/default.json index add9a71d6..d0c18e60c 100644 --- a/config/default.json +++ b/config/default.json @@ -33,6 +33,9 @@ "EnforceMultifactorAuthentication": false, "EnableUserAccessTokens": false, "AllowCorsFrom": "", + "CorsExposedHeaders": "", + "CorsAllowCredentials": false, + "CorsDebug": false, "AllowCookiesForSubdomains": false, "SessionLengthWebInDays": 30, "SessionLengthMobileInDays": 30, diff --git a/model/config.go b/model/config.go index f36b321ec..1d0f40901 100644 --- a/model/config.go +++ b/model/config.go @@ -203,6 +203,9 @@ type ServiceSettings struct { EnforceMultifactorAuthentication *bool EnableUserAccessTokens *bool AllowCorsFrom *string + CorsExposedHeaders *string + CorsAllowCredentials *bool + CorsDebug *bool AllowCookiesForSubdomains *bool SessionLengthWebInDays *int SessionLengthMobileInDays *int @@ -413,6 +416,18 @@ func (s *ServiceSettings) SetDefaults() { s.AllowCorsFrom = NewString(SERVICE_SETTINGS_DEFAULT_ALLOW_CORS_FROM) } + if s.CorsExposedHeaders == nil { + s.CorsExposedHeaders = NewString("") + } + + if s.CorsAllowCredentials == nil { + s.CorsAllowCredentials = NewBool(false) + } + + if s.CorsDebug == nil { + s.CorsDebug = NewBool(false) + } + if s.AllowCookiesForSubdomains == nil { s.AllowCookiesForSubdomains = NewBool(false) } diff --git a/vendor/github.com/rs/cors/.travis.yml b/vendor/github.com/rs/cors/.travis.yml new file mode 100644 index 000000000..17e5e50d5 --- /dev/null +++ b/vendor/github.com/rs/cors/.travis.yml @@ -0,0 +1,8 @@ +language: go +go: +- 1.9 +- "1.10" +- tip +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/rs/cors/LICENSE b/vendor/github.com/rs/cors/LICENSE new file mode 100644 index 000000000..d8e2df5a4 --- /dev/null +++ b/vendor/github.com/rs/cors/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2014 Olivier Poitrey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/rs/cors/README.md b/vendor/github.com/rs/cors/README.md new file mode 100644 index 000000000..425ed9624 --- /dev/null +++ b/vendor/github.com/rs/cors/README.md @@ -0,0 +1,105 @@ +# Go CORS handler [![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/cors) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/cors/master/LICENSE) [![build](https://img.shields.io/travis/rs/cors.svg?style=flat)](https://travis-ci.org/rs/cors) [![Coverage](http://gocover.io/_badge/github.com/rs/cors)](http://gocover.io/github.com/rs/cors) + +CORS is a `net/http` handler implementing [Cross Origin Resource Sharing W3 specification](http://www.w3.org/TR/cors/) in Golang. + +## Getting Started + +After installing Go and setting up your [GOPATH](http://golang.org/doc/code.html#GOPATH), create your first `.go` file. We'll call it `server.go`. + +```go +package main + +import ( + "net/http" + + "github.com/rs/cors" +) + +func main() { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{\"hello\": \"world\"}")) + }) + + // cors.Default() setup the middleware with default options being + // all origins accepted with simple methods (GET, POST). See + // documentation below for more options. + handler := cors.Default().Handler(mux) + http.ListenAndServe(":8080", handler) +} +``` + +Install `cors`: + + go get github.com/rs/cors + +Then run your server: + + go run server.go + +The server now runs on `localhost:8080`: + + $ curl -D - -H 'Origin: http://foo.com' http://localhost:8080/ + HTTP/1.1 200 OK + Access-Control-Allow-Origin: foo.com + Content-Type: application/json + Date: Sat, 25 Oct 2014 03:43:57 GMT + Content-Length: 18 + + {"hello": "world"} + +### More Examples + +* `net/http`: [examples/nethttp/server.go](https://github.com/rs/cors/blob/master/examples/nethttp/server.go) +* [Goji](https://goji.io): [examples/goji/server.go](https://github.com/rs/cors/blob/master/examples/goji/server.go) +* [Martini](http://martini.codegangsta.io): [examples/martini/server.go](https://github.com/rs/cors/blob/master/examples/martini/server.go) +* [Negroni](https://github.com/codegangsta/negroni): [examples/negroni/server.go](https://github.com/rs/cors/blob/master/examples/negroni/server.go) +* [Alice](https://github.com/justinas/alice): [examples/alice/server.go](https://github.com/rs/cors/blob/master/examples/alice/server.go) +* [HttpRouter](https://github.com/julienschmidt/httprouter): [examples/httprouter/server.go](https://github.com/rs/cors/blob/master/examples/httprouter/server.go) +* [Gorilla](http://www.gorillatoolkit.org/pkg/mux): [examples/gorilla/server.go](https://github.com/rs/cors/blob/master/examples/gorilla/server.go) +* [Buffalo](https://gobuffalo.io): [examples/buffalo/server.go](https://github.com/rs/cors/blob/master/examples/buffalo/server.go) +* [Gin](https://gin-gonic.github.io/gin): [examples/gin/server.go](https://github.com/rs/cors/blob/master/examples/gin/server.go) + +## Parameters + +Parameters are passed to the middleware thru the `cors.New` method as follow: + +```go +c := cors.New(cors.Options{ + AllowedOrigins: []string{"http://foo.com", "http://foo.com:8080"}, + AllowCredentials: true, + // Enable Debugging for testing, consider disabling in production + Debug: true, +}) + +// Insert the middleware +handler = c.Handler(handler) +``` + +* **AllowedOrigins** `[]string`: A list of origins a cross-domain request can be executed from. If the special `*` value is present in the list, all origins will be allowed. An origin may contain a wildcard (`*`) to replace 0 or more characters (i.e.: `http://*.domain.com`). Usage of wildcards implies a small performance penality. Only one wildcard can be used per origin. The default value is `*`. +* **AllowOriginFunc** `func (origin string) bool`: A custom function to validate the origin. It take the origin as argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` is ignored +* **AllowedMethods** `[]string`: A list of methods the client is allowed to use with cross-domain requests. Default value is simple methods (`GET` and `POST`). +* **AllowedHeaders** `[]string`: A list of non simple headers the client is allowed to use with cross-domain requests. +* **ExposedHeaders** `[]string`: Indicates which headers are safe to expose to the API of a CORS API specification +* **AllowCredentials** `bool`: Indicates whether the request can include user credentials like cookies, HTTP authentication or client side SSL certificates. The default is `false`. +* **MaxAge** `int`: Indicates how long (in seconds) the results of a preflight request can be cached. The default is `0` which stands for no max age. +* **OptionsPassthrough** `bool`: Instructs preflight to let other potential next handlers to process the `OPTIONS` method. Turn this on if your application handles `OPTIONS`. +* **Debug** `bool`: Debugging flag adds additional output to debug server side CORS issues. + +See [API documentation](http://godoc.org/github.com/rs/cors) for more info. + +## Benchmarks + + BenchmarkWithout 20000000 64.6 ns/op 8 B/op 1 allocs/op + BenchmarkDefault 3000000 469 ns/op 114 B/op 2 allocs/op + BenchmarkAllowedOrigin 3000000 608 ns/op 114 B/op 2 allocs/op + BenchmarkPreflight 20000000 73.2 ns/op 0 B/op 0 allocs/op + BenchmarkPreflightHeader 20000000 73.6 ns/op 0 B/op 0 allocs/op + BenchmarkParseHeaderList 2000000 847 ns/op 184 B/op 6 allocs/op + BenchmarkParse…Single 5000000 290 ns/op 32 B/op 3 allocs/op + BenchmarkParse…Normalized 2000000 776 ns/op 160 B/op 6 allocs/op + +## Licenses + +All source code is licensed under the [MIT License](https://raw.github.com/rs/cors/master/LICENSE). diff --git a/vendor/github.com/rs/cors/cors.go b/vendor/github.com/rs/cors/cors.go new file mode 100644 index 000000000..0aa4f51f1 --- /dev/null +++ b/vendor/github.com/rs/cors/cors.go @@ -0,0 +1,407 @@ +/* +Package cors is net/http handler to handle CORS related requests +as defined by http://www.w3.org/TR/cors/ + +You can configure it by passing an option struct to cors.New: + + c := cors.New(cors.Options{ + AllowedOrigins: []string{"foo.com"}, + AllowedMethods: []string{"GET", "POST", "DELETE"}, + AllowCredentials: true, + }) + +Then insert the handler in the chain: + + handler = c.Handler(handler) + +See Options documentation for more options. + +The resulting handler is a standard net/http handler. +*/ +package cors + +import ( + "log" + "net/http" + "os" + "strconv" + "strings" +) + +// Options is a configuration container to setup the CORS middleware. +type Options struct { + // AllowedOrigins is a list of origins a cross-domain request can be executed from. + // If the special "*" value is present in the list, all origins will be allowed. + // An origin may contain a wildcard (*) to replace 0 or more characters + // (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty. + // Only one wildcard can be used per origin. + // Default value is ["*"] + AllowedOrigins []string + // AllowOriginFunc is a custom function to validate the origin. It take the origin + // as argument and returns true if allowed or false otherwise. If this option is + // set, the content of AllowedOrigins is ignored. + AllowOriginFunc func(origin string) bool + // AllowedMethods is a list of methods the client is allowed to use with + // cross-domain requests. Default value is simple methods (HEAD, GET and POST). + AllowedMethods []string + // AllowedHeaders is list of non simple headers the client is allowed to use with + // cross-domain requests. + // If the special "*" value is present in the list, all headers will be allowed. + // Default value is [] but "Origin" is always appended to the list. + AllowedHeaders []string + // ExposedHeaders indicates which headers are safe to expose to the API of a CORS + // API specification + ExposedHeaders []string + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached + MaxAge int + // AllowCredentials indicates whether the request can include user credentials like + // cookies, HTTP authentication or client side SSL certificates. + AllowCredentials bool + // OptionsPassthrough instructs preflight to let other potential next handlers to + // process the OPTIONS method. Turn this on if your application handles OPTIONS. + OptionsPassthrough bool + // Debugging flag adds additional output to debug server side CORS issues + Debug bool +} + +// Cors http handler +type Cors struct { + // Debug logger + Log *log.Logger + // Normalized list of plain allowed origins + allowedOrigins []string + // List of allowed origins containing wildcards + allowedWOrigins []wildcard + // Optional origin validator function + allowOriginFunc func(origin string) bool + // Normalized list of allowed headers + allowedHeaders []string + // Normalized list of allowed methods + allowedMethods []string + // Normalized list of exposed headers + exposedHeaders []string + maxAge int + // Set to true when allowed origins contains a "*" + allowedOriginsAll bool + // Set to true when allowed headers contains a "*" + allowedHeadersAll bool + allowCredentials bool + optionPassthrough bool +} + +// New creates a new Cors handler with the provided options. +func New(options Options) *Cors { + c := &Cors{ + exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), + allowOriginFunc: options.AllowOriginFunc, + allowCredentials: options.AllowCredentials, + maxAge: options.MaxAge, + optionPassthrough: options.OptionsPassthrough, + } + if options.Debug { + c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) + } + + // Normalize options + // Note: for origins and methods matching, the spec requires a case-sensitive matching. + // As it may error prone, we chose to ignore the spec here. + + // Allowed Origins + if len(options.AllowedOrigins) == 0 { + if options.AllowOriginFunc == nil { + // Default is all origins + c.allowedOriginsAll = true + } + } else { + c.allowedOrigins = []string{} + c.allowedWOrigins = []wildcard{} + for _, origin := range options.AllowedOrigins { + // Normalize + origin = strings.ToLower(origin) + if origin == "*" { + // If "*" is present in the list, turn the whole list into a match all + c.allowedOriginsAll = true + c.allowedOrigins = nil + c.allowedWOrigins = nil + break + } else if i := strings.IndexByte(origin, '*'); i >= 0 { + // Split the origin in two: start and end string without the * + w := wildcard{origin[0:i], origin[i+1:]} + c.allowedWOrigins = append(c.allowedWOrigins, w) + } else { + c.allowedOrigins = append(c.allowedOrigins, origin) + } + } + } + + // Allowed Headers + if len(options.AllowedHeaders) == 0 { + // Use sensible defaults + c.allowedHeaders = []string{"Origin", "Accept", "Content-Type", "X-Requested-With"} + } else { + // Origin is always appended as some browsers will always request for this header at preflight + c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey) + for _, h := range options.AllowedHeaders { + if h == "*" { + c.allowedHeadersAll = true + c.allowedHeaders = nil + break + } + } + } + + // Allowed Methods + if len(options.AllowedMethods) == 0 { + // Default is spec's "simple" methods + c.allowedMethods = []string{"GET", "POST", "HEAD"} + } else { + c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper) + } + + return c +} + +// Default creates a new Cors handler with default options. +func Default() *Cors { + return New(Options{}) +} + +// AllowAll create a new Cors handler with permissive configuration allowing all +// origins with all standard methods with any header and credentials. +func AllowAll() *Cors { + return New(Options{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"}, + AllowedHeaders: []string{"*"}, + AllowCredentials: true, + }) +} + +// Handler apply the CORS specification on the request, and add relevant CORS headers +// as necessary. +func (c *Cors) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + c.logf("Handler: Preflight request") + c.handlePreflight(w, r) + // Preflight requests are standalone and should stop the chain as some other + // middleware may not handle OPTIONS requests correctly. One typical example + // is authentication middleware ; OPTIONS requests won't carry authentication + // headers (see #1) + if c.optionPassthrough { + h.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusOK) + } + } else { + c.logf("Handler: Actual request") + c.handleActualRequest(w, r) + h.ServeHTTP(w, r) + } + }) +} + +// HandlerFunc provides Martini compatible handler +func (c *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + c.logf("HandlerFunc: Preflight request") + c.handlePreflight(w, r) + } else { + c.logf("HandlerFunc: Actual request") + c.handleActualRequest(w, r) + } +} + +// Negroni compatible interface +func (c *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + c.logf("ServeHTTP: Preflight request") + c.handlePreflight(w, r) + // Preflight requests are standalone and should stop the chain as some other + // middleware may not handle OPTIONS requests correctly. One typical example + // is authentication middleware ; OPTIONS requests won't carry authentication + // headers (see #1) + if c.optionPassthrough { + next(w, r) + } else { + w.WriteHeader(http.StatusOK) + } + } else { + c.logf("ServeHTTP: Actual request") + c.handleActualRequest(w, r) + next(w, r) + } +} + +// handlePreflight handles pre-flight CORS requests +func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { + headers := w.Header() + origin := r.Header.Get("Origin") + + if r.Method != http.MethodOptions { + c.logf(" Preflight aborted: %s!=OPTIONS", r.Method) + return + } + // Always set Vary headers + // see https://github.com/rs/cors/issues/10, + // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + + if origin == "" { + c.logf(" Preflight aborted: empty origin") + return + } + if !c.isOriginAllowed(origin) { + c.logf(" Preflight aborted: origin '%s' not allowed", origin) + return + } + + reqMethod := r.Header.Get("Access-Control-Request-Method") + if !c.isMethodAllowed(reqMethod) { + c.logf(" Preflight aborted: method '%s' not allowed", reqMethod) + return + } + reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers")) + if !c.areHeadersAllowed(reqHeaders) { + c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders) + return + } + if c.allowedOriginsAll && !c.allowCredentials { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Access-Control-Allow-Origin", origin) + } + // Spec says: Since the list of methods can be unbounded, simply returning the method indicated + // by Access-Control-Request-Method (if supported) can be enough + headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod)) + if len(reqHeaders) > 0 { + + // Spec says: Since the list of headers can be unbounded, simply returning supported headers + // from Access-Control-Request-Headers can be enough + headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) + } + if c.allowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + if c.maxAge > 0 { + headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge)) + } + c.logf(" Preflight response headers: %v", headers) +} + +// handleActualRequest handles simple cross-origin requests, actual request or redirects +func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { + headers := w.Header() + origin := r.Header.Get("Origin") + + if r.Method == http.MethodOptions { + c.logf(" Actual request no headers added: method == %s", r.Method) + return + } + // Always set Vary, see https://github.com/rs/cors/issues/10 + headers.Add("Vary", "Origin") + if origin == "" { + c.logf(" Actual request no headers added: missing origin") + return + } + if !c.isOriginAllowed(origin) { + c.logf(" Actual request no headers added: origin '%s' not allowed", origin) + return + } + + // Note that spec does define a way to specifically disallow a simple method like GET or + // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the + // spec doesn't instruct to check the allowed methods for simple cross-origin requests. + // We think it's a nice feature to be able to have control on those methods though. + if !c.isMethodAllowed(r.Method) { + c.logf(" Actual request no headers added: method '%s' not allowed", r.Method) + + return + } + if c.allowedOriginsAll && !c.allowCredentials { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Access-Control-Allow-Origin", origin) + } + if len(c.exposedHeaders) > 0 { + headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) + } + if c.allowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + c.logf(" Actual response added headers: %v", headers) +} + +// convenience method. checks if debugging is turned on before printing +func (c *Cors) logf(format string, a ...interface{}) { + if c.Log != nil { + c.Log.Printf(format, a...) + } +} + +// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests +// on the endpoint +func (c *Cors) isOriginAllowed(origin string) bool { + if c.allowOriginFunc != nil { + return c.allowOriginFunc(origin) + } + if c.allowedOriginsAll { + return true + } + origin = strings.ToLower(origin) + for _, o := range c.allowedOrigins { + if o == origin { + return true + } + } + for _, w := range c.allowedWOrigins { + if w.match(origin) { + return true + } + } + return false +} + +// isMethodAllowed checks if a given method can be used as part of a cross-domain request +// on the endpoing +func (c *Cors) isMethodAllowed(method string) bool { + if len(c.allowedMethods) == 0 { + // If no method allowed, always return false, even for preflight request + return false + } + method = strings.ToUpper(method) + if method == http.MethodOptions { + // Always allow preflight requests + return true + } + for _, m := range c.allowedMethods { + if m == method { + return true + } + } + return false +} + +// areHeadersAllowed checks if a given list of headers are allowed to used within +// a cross-domain request. +func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { + if c.allowedHeadersAll || len(requestedHeaders) == 0 { + return true + } + for _, header := range requestedHeaders { + header = http.CanonicalHeaderKey(header) + found := false + for _, h := range c.allowedHeaders { + if h == header { + found = true + } + } + if !found { + return false + } + } + return true +} diff --git a/vendor/github.com/rs/cors/go.mod b/vendor/github.com/rs/cors/go.mod new file mode 100644 index 000000000..0a4c65210 --- /dev/null +++ b/vendor/github.com/rs/cors/go.mod @@ -0,0 +1 @@ +module github.com/rs/cors diff --git a/vendor/github.com/rs/cors/utils.go b/vendor/github.com/rs/cors/utils.go new file mode 100644 index 000000000..c7a0aa060 --- /dev/null +++ b/vendor/github.com/rs/cors/utils.go @@ -0,0 +1,70 @@ +package cors + +import "strings" + +const toLower = 'a' - 'A' + +type converter func(string) string + +type wildcard struct { + prefix string + suffix string +} + +func (w wildcard) match(s string) bool { + return len(s) >= len(w.prefix+w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) +} + +// convert converts a list of string using the passed converter function +func convert(s []string, c converter) []string { + out := []string{} + for _, i := range s { + out = append(out, c(i)) + } + return out +} + +// parseHeaderList tokenize + normalize a string containing a list of headers +func parseHeaderList(headerList string) []string { + l := len(headerList) + h := make([]byte, 0, l) + upper := true + // Estimate the number headers in order to allocate the right splice size + t := 0 + for i := 0; i < l; i++ { + if headerList[i] == ',' { + t++ + } + } + headers := make([]string, 0, t) + for i := 0; i < l; i++ { + b := headerList[i] + if b >= 'a' && b <= 'z' { + if upper { + h = append(h, b-toLower) + } else { + h = append(h, b) + } + } else if b >= 'A' && b <= 'Z' { + if !upper { + h = append(h, b+toLower) + } else { + h = append(h, b) + } + } else if b == '-' || b == '_' || (b >= '0' && b <= '9') { + h = append(h, b) + } + + if b == ' ' || b == ',' || i == l-1 { + if len(h) > 0 { + // Flush the found header + headers = append(headers, string(h)) + h = h[:0] + upper = true + } + } else { + upper = b == '-' || b == '_' + } + } + return headers +} -- cgit v1.2.3-1-g7c22