diff options
author | Christopher Speller <crspeller@gmail.com> | 2016-11-16 19:28:52 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-16 19:28:52 -0500 |
commit | 0135904f7d3e1c0e763adaefe267c736616e3d26 (patch) | |
tree | c27be7588f98eaea62e0bd0c0087f2b348da9738 /vendor/github.com | |
parent | 0b296dd8c2aefefe89787be5cc627d44cf431150 (diff) | |
download | chat-0135904f7d3e1c0e763adaefe267c736616e3d26.tar.gz chat-0135904f7d3e1c0e763adaefe267c736616e3d26.tar.bz2 chat-0135904f7d3e1c0e763adaefe267c736616e3d26.zip |
Upgrading server dependancies (#4566)
Diffstat (limited to 'vendor/github.com')
118 files changed, 5593 insertions, 1192 deletions
diff --git a/vendor/github.com/NYTimes/gziphandler/gzip.go b/vendor/github.com/NYTimes/gziphandler/gzip.go index 39e8c5e75..dad0eb747 100644 --- a/vendor/github.com/NYTimes/gziphandler/gzip.go +++ b/vendor/github.com/NYTimes/gziphandler/gzip.go @@ -181,7 +181,7 @@ func GzipHandler(h http.Handler) http.Handler { } // acceptsGzip returns true if the given HTTP request indicates that it will -// accept a gzippped response. +// accept a gzipped response. func acceptsGzip(r *http.Request) bool { acceptedEncodings, _ := parseEncodings(r.Header.Get(acceptEncoding)) return acceptedEncodings["gzip"] > 0.0 diff --git a/vendor/github.com/gorilla/handlers/.travis.yml b/vendor/github.com/gorilla/handlers/.travis.yml index 783020996..4ea1e7a1f 100644 --- a/vendor/github.com/gorilla/handlers/.travis.yml +++ b/vendor/github.com/gorilla/handlers/.travis.yml @@ -6,13 +6,13 @@ matrix: - go: 1.4 - go: 1.5 - go: 1.6 + - go: 1.7 + - go: tip + allow_failures: - go: tip - -install: - - go get golang.org/x/tools/cmd/vet script: - go get -t -v ./... - diff -u <(echo -n) <(gofmt -d .) - - go tool vet . + - go vet $(go list ./... | grep -v /vendor/) - go test -v -race ./... diff --git a/vendor/github.com/gorilla/handlers/compress.go b/vendor/github.com/gorilla/handlers/compress.go index 5e140c503..e8345d792 100644 --- a/vendor/github.com/gorilla/handlers/compress.go +++ b/vendor/github.com/gorilla/handlers/compress.go @@ -56,6 +56,9 @@ func (w *compressResponseWriter) Flush() { // CompressHandler gzip compresses HTTP responses for clients that support it // via the 'Accept-Encoding' header. +// +// Compressing TLS traffic may leak the page contents to an attacker if the +// page contains user input: http://security.stackexchange.com/a/102015/12208 func CompressHandler(h http.Handler) http.Handler { return CompressHandlerLevel(h, gzip.DefaultCompression) } diff --git a/vendor/github.com/gorilla/handlers/cors.go b/vendor/github.com/gorilla/handlers/cors.go index d4229a5d9..1f92d1ad4 100644 --- a/vendor/github.com/gorilla/handlers/cors.go +++ b/vendor/github.com/gorilla/handlers/cors.go @@ -112,6 +112,9 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set(corsAllowOriginHeader, origin) + if r.Method == corsOptionMethod { + return + } ch.h.ServeHTTP(w, r) } diff --git a/vendor/github.com/gorilla/handlers/cors_test.go b/vendor/github.com/gorilla/handlers/cors_test.go index ff7eebf48..c63913eee 100644 --- a/vendor/github.com/gorilla/handlers/cors_test.go +++ b/vendor/github.com/gorilla/handlers/cors_test.go @@ -104,6 +104,24 @@ func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing. } } +func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { + r := newRequest("OPTIONS", "http://www.example.com/") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, "GET") + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Options request must not be passed to next handler") + }) + + CORS()(testHandler).ServeHTTP(rr, r) + + if status := rr.Code; status != http.StatusOK { + t.Fatalf("bad status: got %v want %v", status, http.StatusOK) + } +} + func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { r := newRequest("OPTIONS", "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) diff --git a/vendor/github.com/gorilla/handlers/proxy_headers.go b/vendor/github.com/gorilla/handlers/proxy_headers.go index 268de9c6a..0be750fd7 100644 --- a/vendor/github.com/gorilla/handlers/proxy_headers.go +++ b/vendor/github.com/gorilla/handlers/proxy_headers.go @@ -8,9 +8,11 @@ import ( var ( // De-facto standard header keys. - xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") - xRealIP = http.CanonicalHeaderKey("X-Real-IP") - xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Scheme") + xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") + xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host") + xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") + xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme") + xRealIP = http.CanonicalHeaderKey("X-Real-IP") ) var ( @@ -28,9 +30,9 @@ var ( // ProxyHeaders inspects common reverse proxy headers and sets the corresponding // fields in the HTTP request struct. These are X-Forwarded-For and X-Real-IP -// for the remote (client) IP address, X-Forwarded-Proto for the scheme -// (http|https) and the RFC7239 Forwarded header, which may include both client -// IPs and schemes. +// for the remote (client) IP address, X-Forwarded-Proto or X-Forwarded-Scheme +// for the scheme (http|https) and the RFC7239 Forwarded header, which may +// include both client IPs and schemes. // // NOTE: This middleware should only be used when behind a reverse // proxy like nginx, HAProxy or Apache. Reverse proxies that don't (or are @@ -49,7 +51,10 @@ func ProxyHeaders(h http.Handler) http.Handler { if scheme := getScheme(r); scheme != "" { r.URL.Scheme = scheme } - + // Set the host with the value passed by the proxy + if r.Header.Get(xForwardedHost) != "" { + r.Host = r.Header.Get(xForwardedHost) + } // Call the next handler in the chain. h.ServeHTTP(w, r) } @@ -99,7 +104,9 @@ func getScheme(r *http.Request) string { // Retrieve the scheme from X-Forwarded-Proto. if proto := r.Header.Get(xForwardedProto); proto != "" { scheme = strings.ToLower(proto) - } else if proto := r.Header.Get(forwarded); proto != "" { + } else if proto = r.Header.Get(xForwardedScheme); proto != "" { + scheme = strings.ToLower(proto) + } else if proto = r.Header.Get(forwarded); proto != "" { // match should contain at least two elements if the protocol was // specified in the Forwarded header. The first element will always be // the 'proto=' capture, which we ignore. In the case of multiple proto diff --git a/vendor/github.com/gorilla/handlers/proxy_headers_test.go b/vendor/github.com/gorilla/handlers/proxy_headers_test.go index 85282ef7d..1bd78052d 100644 --- a/vendor/github.com/gorilla/handlers/proxy_headers_test.go +++ b/vendor/github.com/gorilla/handlers/proxy_headers_test.go @@ -47,6 +47,9 @@ func TestGetScheme(t *testing.T) { {xForwardedProto, "https", "https"}, {xForwardedProto, "http", "http"}, {xForwardedProto, "HTTP", "http"}, + {xForwardedScheme, "https", "https"}, + {xForwardedScheme, "http", "http"}, + {xForwardedScheme, "HTTP", "http"}, {forwarded, `For="[2001:db8:cafe::17]:4711`, ""}, // No proto {forwarded, `for=192.0.2.43, for=198.51.100.17;proto=https`, "https"}, // Multiple params before proto {forwarded, `for=172.32.10.15; proto=https;by=127.0.0.1`, "https"}, // Space before proto @@ -74,13 +77,17 @@ func TestProxyHeaders(t *testing.T) { r.Header.Set(xForwardedFor, "8.8.8.8") r.Header.Set(xForwardedProto, "https") - - var addr string - var proto string + r.Header.Set(xForwardedHost, "google.com") + var ( + addr string + proto string + host string + ) ProxyHeaders(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { addr = r.RemoteAddr proto = r.URL.Scheme + host = r.Host })).ServeHTTP(rr, r) if rr.Code != http.StatusOK { @@ -96,5 +103,9 @@ func TestProxyHeaders(t *testing.T) { t.Fatalf("wrong address: got %s want %s", proto, r.Header.Get(xForwardedProto)) } + if host != r.Header.Get(xForwardedHost) { + t.Fatalf("wrong address: got %s want %s", host, + r.Header.Get(xForwardedHost)) + } } diff --git a/vendor/github.com/gorilla/mux/.travis.yml b/vendor/github.com/gorilla/mux/.travis.yml index 4dcdacb65..f93ce56d1 100644 --- a/vendor/github.com/gorilla/mux/.travis.yml +++ b/vendor/github.com/gorilla/mux/.travis.yml @@ -8,10 +8,11 @@ matrix: - go: 1.4 - go: 1.5 - go: 1.6 + - go: 1.7 - go: tip install: - - go get golang.org/x/tools/cmd/vet + - # Skip script: - go get -t -v ./... diff --git a/vendor/github.com/gorilla/mux/README.md b/vendor/github.com/gorilla/mux/README.md index 9516c5191..fa79a6bc3 100644 --- a/vendor/github.com/gorilla/mux/README.md +++ b/vendor/github.com/gorilla/mux/README.md @@ -1,19 +1,43 @@ -mux +gorilla/mux === [![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux) [![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux) +![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png) + http://www.gorillatoolkit.org/pkg/mux -Package `gorilla/mux` implements a request router and dispatcher. +Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to +their respective handler. The name mux stands for "HTTP request multiplexer". Like the standard `http.ServeMux`, `mux.Router` matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are: +* It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`. * Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers. * URL hosts and paths can have variables with an optional regular expression. * Registered URLs can be built, or "reversed", which helps maintaining references to resources. * Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching. -* It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`. + +--- + +* [Install](#install) +* [Examples](#examples) +* [Matching Routes](#matching-routes) +* [Static Files](#static-files) +* [Registered URLs](#registered-urls) +* [Full Example](#full-example) + +--- + +## Install + +With a [correctly configured](https://golang.org/doc/install#testing) Go toolchain: + +```sh +go get -u github.com/gorilla/mux +``` + +## Examples Let's start registering a couple of URL paths and handlers: @@ -47,6 +71,8 @@ category := vars["category"] And this is all you need to know about the basic usage. More advanced options are explained below. +### Matching Routes + Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. They can also have variables: ```go @@ -118,7 +144,7 @@ Then register routes in the subrouter: ```go s.HandleFunc("/products/", ProductsHandler) s.HandleFunc("/products/{key}", ProductHandler) -s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler) +s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) ``` The three URL paths we registered above will only be tested if the domain is `www.example.com`, because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route. @@ -138,6 +164,37 @@ s.HandleFunc("/{key}/", ProductHandler) s.HandleFunc("/{key}/details", ProductDetailsHandler) ``` +### Static Files + +Note that the path provided to `PathPrefix()` represents a "wildcard": calling +`PathPrefix("/static/").Handler(...)` means that the handler will be passed any +request that matches "/static/*". This makes it easy to serve static files with mux: + +```go +func main() { + var dir string + + flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir") + flag.Parse() + r := mux.NewRouter() + + // This will serve files under http://localhost:8000/static/<filename> + r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir)))) + + srv := &http.Server{ + Handler: r, + Addr: "127.0.0.1:8000", + // Good practice: enforce timeouts for servers you create! + WriteTimeout: 15 * time.Second, + ReadTimeout: 15 * time.Second, + } + + log.Fatal(srv.ListenAndServe()) +} +``` + +### Registered URLs + Now let's see how to build registered URLs. Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling `Name()` on a route. For example: @@ -219,7 +276,7 @@ package main import ( "net/http" - + "log" "github.com/gorilla/mux" ) @@ -233,7 +290,7 @@ func main() { r.HandleFunc("/", YourHandler) // Bind to a port and pass our router in - http.ListenAndServe(":8000", r) + log.Fatal(http.ListenAndServe(":8000", r)) } ``` diff --git a/vendor/github.com/gorilla/mux/context_gorilla.go b/vendor/github.com/gorilla/mux/context_gorilla.go new file mode 100644 index 000000000..d7adaa8fa --- /dev/null +++ b/vendor/github.com/gorilla/mux/context_gorilla.go @@ -0,0 +1,26 @@ +// +build !go1.7 + +package mux + +import ( + "net/http" + + "github.com/gorilla/context" +) + +func contextGet(r *http.Request, key interface{}) interface{} { + return context.Get(r, key) +} + +func contextSet(r *http.Request, key, val interface{}) *http.Request { + if val == nil { + return r + } + + context.Set(r, key, val) + return r +} + +func contextClear(r *http.Request) { + context.Clear(r) +} diff --git a/vendor/github.com/gorilla/mux/context_gorilla_test.go b/vendor/github.com/gorilla/mux/context_gorilla_test.go new file mode 100644 index 000000000..ffaf384c0 --- /dev/null +++ b/vendor/github.com/gorilla/mux/context_gorilla_test.go @@ -0,0 +1,40 @@ +// +build !go1.7 + +package mux + +import ( + "net/http" + "testing" + + "github.com/gorilla/context" +) + +// Tests that the context is cleared or not cleared properly depending on +// the configuration of the router +func TestKeepContext(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/", func1).Name("func1") + + req, _ := http.NewRequest("GET", "http://localhost/", nil) + context.Set(req, "t", 1) + + res := new(http.ResponseWriter) + r.ServeHTTP(*res, req) + + if _, ok := context.GetOk(req, "t"); ok { + t.Error("Context should have been cleared at end of request") + } + + r.KeepContext = true + + req, _ = http.NewRequest("GET", "http://localhost/", nil) + context.Set(req, "t", 1) + + r.ServeHTTP(*res, req) + if _, ok := context.GetOk(req, "t"); !ok { + t.Error("Context should NOT have been cleared at end of request") + } + +} diff --git a/vendor/github.com/gorilla/mux/context_native.go b/vendor/github.com/gorilla/mux/context_native.go new file mode 100644 index 000000000..209cbea7d --- /dev/null +++ b/vendor/github.com/gorilla/mux/context_native.go @@ -0,0 +1,24 @@ +// +build go1.7 + +package mux + +import ( + "context" + "net/http" +) + +func contextGet(r *http.Request, key interface{}) interface{} { + return r.Context().Value(key) +} + +func contextSet(r *http.Request, key, val interface{}) *http.Request { + if val == nil { + return r + } + + return r.WithContext(context.WithValue(r.Context(), key, val)) +} + +func contextClear(r *http.Request) { + return +} diff --git a/vendor/github.com/gorilla/mux/context_native_test.go b/vendor/github.com/gorilla/mux/context_native_test.go new file mode 100644 index 000000000..c150edf01 --- /dev/null +++ b/vendor/github.com/gorilla/mux/context_native_test.go @@ -0,0 +1,32 @@ +// +build go1.7 + +package mux + +import ( + "context" + "net/http" + "testing" + "time" +) + +func TestNativeContextMiddleware(t *testing.T) { + withTimeout := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) + defer cancel() + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } + + r := NewRouter() + r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vars := Vars(r) + if vars["foo"] != "bar" { + t.Fatal("Expected foo var to be set") + } + }))) + + rec := NewRecorder() + req := newRequest("GET", "/path/bar") + r.ServeHTTP(rec, req) +} diff --git a/vendor/github.com/gorilla/mux/doc.go b/vendor/github.com/gorilla/mux/doc.go index 835f5342e..e9573dd8a 100644 --- a/vendor/github.com/gorilla/mux/doc.go +++ b/vendor/github.com/gorilla/mux/doc.go @@ -47,6 +47,10 @@ variable will be anything until the next slash. For example: r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler) r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) +Groups can be used inside patterns, as long as they are non-capturing (?:re). For example: + + r.HandleFunc("/articles/{category}/{sort:(?:asc|desc|new)}", ArticlesCategoryHandler) + The names are used to create a map of route variables which can be retrieved calling mux.Vars(): @@ -136,6 +140,31 @@ the inner routes use it as base for their paths: // "/products/{key}/details" s.HandleFunc("/{key}/details", ProductDetailsHandler) +Note that the path provided to PathPrefix() represents a "wildcard": calling +PathPrefix("/static/").Handler(...) means that the handler will be passed any +request that matches "/static/*". This makes it easy to serve static files with mux: + + func main() { + var dir string + + flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir") + flag.Parse() + r := mux.NewRouter() + + // This will serve files under http://localhost:8000/static/<filename> + r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir)))) + + srv := &http.Server{ + Handler: r, + Addr: "127.0.0.1:8000", + // Good practice: enforce timeouts for servers you create! + WriteTimeout: 15 * time.Second, + ReadTimeout: 15 * time.Second, + } + + log.Fatal(srv.ListenAndServe()) + } + Now let's see how to build registered URLs. Routes can be named. All routes that define a name can have their URLs built, diff --git a/vendor/github.com/gorilla/mux/mux.go b/vendor/github.com/gorilla/mux/mux.go index fbb7f19ad..d66ec3841 100644 --- a/vendor/github.com/gorilla/mux/mux.go +++ b/vendor/github.com/gorilla/mux/mux.go @@ -10,8 +10,7 @@ import ( "net/http" "path" "regexp" - - "github.com/gorilla/context" + "strings" ) // NewRouter returns a new router instance. @@ -48,8 +47,14 @@ type Router struct { namedRoutes map[string]*Route // See Router.StrictSlash(). This defines the flag for new routes. strictSlash bool - // If true, do not clear the request context after handling the request + // See Router.SkipClean(). This defines the flag for new routes. + skipClean bool + // If true, do not clear the request context after handling the request. + // This has no effect when go1.7+ is used, since the context is stored + // on the request itself. KeepContext bool + // see Router.UseEncodedPath(). This defines a flag for all routes. + useEncodedPath bool } // Match matches registered routes against the request. @@ -73,32 +78,38 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool { // When there is a match, the route variables can be retrieved calling // mux.Vars(request). func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // Clean path to canonical form and redirect. - if p := cleanPath(req.URL.Path); p != req.URL.Path { - - // Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query. - // This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue: - // http://code.google.com/p/go/issues/detail?id=5252 - url := *req.URL - url.Path = p - p = url.String() - - w.Header().Set("Location", p) - w.WriteHeader(http.StatusMovedPermanently) - return + if !r.skipClean { + path := req.URL.Path + if r.useEncodedPath { + path = getPath(req) + } + // Clean path to canonical form and redirect. + if p := cleanPath(path); p != path { + + // Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query. + // This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue: + // http://code.google.com/p/go/issues/detail?id=5252 + url := *req.URL + url.Path = p + p = url.String() + + w.Header().Set("Location", p) + w.WriteHeader(http.StatusMovedPermanently) + return + } } var match RouteMatch var handler http.Handler if r.Match(req, &match) { handler = match.Handler - setVars(req, match.Vars) - setCurrentRoute(req, match.Route) + req = setVars(req, match.Vars) + req = setCurrentRoute(req, match.Route) } if handler == nil { handler = http.NotFoundHandler() } if !r.KeepContext { - defer context.Clear(req) + defer contextClear(req) } handler.ServeHTTP(w, req) } @@ -133,6 +144,34 @@ func (r *Router) StrictSlash(value bool) *Router { return r } +// SkipClean defines the path cleaning behaviour for new routes. The initial +// value is false. Users should be careful about which routes are not cleaned +// +// When true, if the route path is "/path//to", it will remain with the double +// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/ +// +// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will +// become /fetch/http/xkcd.com/534 +func (r *Router) SkipClean(value bool) *Router { + r.skipClean = value + return r +} + +// UseEncodedPath tells the router to match the encoded original path +// to the routes. +// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". +// This behavior has the drawback of needing to match routes against +// r.RequestURI instead of r.URL.Path. Any modifications (such as http.StripPrefix) +// to r.URL.Path will not affect routing when this flag is on and thus may +// induce unintended behavior. +// +// If not called, the router will match the unencoded path to the routes. +// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to" +func (r *Router) UseEncodedPath() *Router { + r.useEncodedPath = true + return r +} + // ---------------------------------------------------------------------------- // parentRoute // ---------------------------------------------------------------------------- @@ -170,7 +209,7 @@ func (r *Router) buildVars(m map[string]string) map[string]string { // NewRoute registers an empty route. func (r *Router) NewRoute() *Route { - route := &Route{parent: r, strictSlash: r.strictSlash} + route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath} r.routes = append(r.routes, route) return route } @@ -268,6 +307,9 @@ func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error { if err == SkipRouter { continue } + if err != nil { + return err + } for _, sr := range t.matchers { if h, ok := sr.(*Router); ok { err := h.walk(walkFn, ancestors) @@ -308,7 +350,7 @@ const ( // Vars returns the route variables for the current request, if any. func Vars(r *http.Request) map[string]string { - if rv := context.Get(r, varsKey); rv != nil { + if rv := contextGet(r, varsKey); rv != nil { return rv.(map[string]string) } return nil @@ -320,28 +362,46 @@ func Vars(r *http.Request) map[string]string { // after the handler returns, unless the KeepContext option is set on the // Router. func CurrentRoute(r *http.Request) *Route { - if rv := context.Get(r, routeKey); rv != nil { + if rv := contextGet(r, routeKey); rv != nil { return rv.(*Route) } return nil } -func setVars(r *http.Request, val interface{}) { - if val != nil { - context.Set(r, varsKey, val) - } +func setVars(r *http.Request, val interface{}) *http.Request { + return contextSet(r, varsKey, val) } -func setCurrentRoute(r *http.Request, val interface{}) { - if val != nil { - context.Set(r, routeKey, val) - } +func setCurrentRoute(r *http.Request, val interface{}) *http.Request { + return contextSet(r, routeKey, val) } // ---------------------------------------------------------------------------- // Helpers // ---------------------------------------------------------------------------- +// getPath returns the escaped path if possible; doing what URL.EscapedPath() +// which was added in go1.5 does +func getPath(req *http.Request) string { + if req.RequestURI != "" { + // Extract the path from RequestURI (which is escaped unlike URL.Path) + // as detailed here as detailed in https://golang.org/pkg/net/url/#URL + // for < 1.5 server side workaround + // http://localhost/path/here?v=1 -> /path/here + path := req.RequestURI + path = strings.TrimPrefix(path, req.URL.Scheme+`://`) + path = strings.TrimPrefix(path, req.URL.Host) + if i := strings.LastIndex(path, "?"); i > -1 { + path = path[:i] + } + if i := strings.LastIndex(path, "#"); i > -1 { + path = path[:i] + } + return path + } + return req.URL.Path +} + // cleanPath returns the canonical path for p, eliminating . and .. elements. // Borrowed from the net/http package. func cleanPath(p string) string { @@ -357,6 +417,7 @@ func cleanPath(p string) string { if p[len(p)-1] == '/' && np != "/" { np += "/" } + return np } diff --git a/vendor/github.com/gorilla/mux/mux_test.go b/vendor/github.com/gorilla/mux/mux_test.go index a44d03f80..39a099c1e 100644 --- a/vendor/github.com/gorilla/mux/mux_test.go +++ b/vendor/github.com/gorilla/mux/mux_test.go @@ -5,12 +5,13 @@ package mux import ( + "bufio" + "bytes" + "errors" "fmt" "net/http" "strings" "testing" - - "github.com/gorilla/context" ) func (r *Route) GoString() string { @@ -32,8 +33,8 @@ type routeTest struct { vars map[string]string // the expected vars of the match host string // the expected host of the match path string // the expected path of the match - path_template string // the expected path template to match - host_template string // the expected host template to match + pathTemplate string // the expected path template to match + hostTemplate string // the expected host template to match shouldMatch bool // whether the request is expected to match the route at all shouldRedirect bool // whether the request should result in a redirect } @@ -115,124 +116,124 @@ func TestHost(t *testing.T) { shouldMatch: false, }, { - title: "Host route with pattern, match", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with pattern, additional capturing group, match", - route: new(Route).Host("aaa.{v1:[a-z]{2}(b|c)}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `aaa.{v1:[a-z]{2}(b|c)}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with pattern, wrong host in request URL", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: false, - }, - { - title: "Host route with multiple patterns, match", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: true, - }, - { - title: "Host route with multiple patterns, wrong host in request URL", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: false, - }, - { - title: "Host route with hyphenated name and pattern, match", - route: new(Route).Host("aaa.{v-1:[a-z]{3}}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v-1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `aaa.{v-1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with hyphenated name and pattern, additional capturing group, match", - route: new(Route).Host("aaa.{v-1:[a-z]{2}(b|c)}.ccc"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v-1": "bbb"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `aaa.{v-1:[a-z]{2}(b|c)}.ccc`, - shouldMatch: true, - }, - { - title: "Host route with multiple hyphenated names and patterns, match", - route: new(Route).Host("{v-1:[a-z]{3}}.{v-2:[a-z]{3}}.{v-3:[a-z]{3}}"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v-1": "aaa", "v-2": "bbb", "v-3": "ccc"}, - host: "aaa.bbb.ccc", - path: "", - host_template: `{v-1:[a-z]{3}}.{v-2:[a-z]{3}}.{v-3:[a-z]{3}}`, - shouldMatch: true, - }, - { - title: "Path route with single pattern with pipe, match", - route: new(Route).Path("/{category:a|b/c}"), - request: newRequest("GET", "http://localhost/a"), - vars: map[string]string{"category": "a"}, - host: "", - path: "/a", - path_template: `/{category:a|b/c}`, - shouldMatch: true, - }, - { - title: "Path route with single pattern with pipe, match", - route: new(Route).Path("/{category:a|b/c}"), - request: newRequest("GET", "http://localhost/b/c"), - vars: map[string]string{"category": "b/c"}, - host: "", - path: "/b/c", - path_template: `/{category:a|b/c}`, - shouldMatch: true, - }, - { - title: "Path route with multiple patterns with pipe, match", - route: new(Route).Path("/{category:a|b/c}/{product}/{id:[0-9]+}"), - request: newRequest("GET", "http://localhost/a/product_name/1"), - vars: map[string]string{"category": "a", "product": "product_name", "id": "1"}, - host: "", - path: "/a/product_name/1", - path_template: `/{category:a|b/c}/{product}/{id:[0-9]+}`, - shouldMatch: true, - }, - { - title: "Path route with multiple patterns with pipe, match", - route: new(Route).Path("/{category:a|b/c}/{product}/{id:[0-9]+}"), - request: newRequest("GET", "http://localhost/b/c/product_name/1"), - vars: map[string]string{"category": "b/c", "product": "product_name", "id": "1"}, - host: "", - path: "/b/c/product_name/1", - path_template: `/{category:a|b/c}/{product}/{id:[0-9]+}`, - shouldMatch: true, + title: "Host route with pattern, match", + route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v1": "bbb"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, + shouldMatch: true, + }, + { + title: "Host route with pattern, additional capturing group, match", + route: new(Route).Host("aaa.{v1:[a-z]{2}(?:b|c)}.ccc"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v1": "bbb"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `aaa.{v1:[a-z]{2}(?:b|c)}.ccc`, + shouldMatch: true, + }, + { + title: "Host route with pattern, wrong host in request URL", + route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"), + request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), + vars: map[string]string{"v1": "bbb"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, + shouldMatch: false, + }, + { + title: "Host route with multiple patterns, match", + route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, + shouldMatch: true, + }, + { + title: "Host route with multiple patterns, wrong host in request URL", + route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"), + request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), + vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, + shouldMatch: false, + }, + { + title: "Host route with hyphenated name and pattern, match", + route: new(Route).Host("aaa.{v-1:[a-z]{3}}.ccc"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v-1": "bbb"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `aaa.{v-1:[a-z]{3}}.ccc`, + shouldMatch: true, + }, + { + title: "Host route with hyphenated name and pattern, additional capturing group, match", + route: new(Route).Host("aaa.{v-1:[a-z]{2}(?:b|c)}.ccc"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v-1": "bbb"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `aaa.{v-1:[a-z]{2}(?:b|c)}.ccc`, + shouldMatch: true, + }, + { + title: "Host route with multiple hyphenated names and patterns, match", + route: new(Route).Host("{v-1:[a-z]{3}}.{v-2:[a-z]{3}}.{v-3:[a-z]{3}}"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v-1": "aaa", "v-2": "bbb", "v-3": "ccc"}, + host: "aaa.bbb.ccc", + path: "", + hostTemplate: `{v-1:[a-z]{3}}.{v-2:[a-z]{3}}.{v-3:[a-z]{3}}`, + shouldMatch: true, + }, + { + title: "Path route with single pattern with pipe, match", + route: new(Route).Path("/{category:a|b/c}"), + request: newRequest("GET", "http://localhost/a"), + vars: map[string]string{"category": "a"}, + host: "", + path: "/a", + pathTemplate: `/{category:a|b/c}`, + shouldMatch: true, + }, + { + title: "Path route with single pattern with pipe, match", + route: new(Route).Path("/{category:a|b/c}"), + request: newRequest("GET", "http://localhost/b/c"), + vars: map[string]string{"category": "b/c"}, + host: "", + path: "/b/c", + pathTemplate: `/{category:a|b/c}`, + shouldMatch: true, + }, + { + title: "Path route with multiple patterns with pipe, match", + route: new(Route).Path("/{category:a|b/c}/{product}/{id:[0-9]+}"), + request: newRequest("GET", "http://localhost/a/product_name/1"), + vars: map[string]string{"category": "a", "product": "product_name", "id": "1"}, + host: "", + path: "/a/product_name/1", + pathTemplate: `/{category:a|b/c}/{product}/{id:[0-9]+}`, + shouldMatch: true, + }, + { + title: "Path route with multiple patterns with pipe, match", + route: new(Route).Path("/{category:a|b/c}/{product}/{id:[0-9]+}"), + request: newRequest("GET", "http://localhost/b/c/product_name/1"), + vars: map[string]string{"category": "b/c", "product": "product_name", "id": "1"}, + host: "", + path: "/b/c/product_name/1", + pathTemplate: `/{category:a|b/c}/{product}/{id:[0-9]+}`, + shouldMatch: true, }, } for _, test := range tests { @@ -262,24 +263,48 @@ func TestPath(t *testing.T) { shouldMatch: true, }, { - title: "Path route, do not match with trailing slash in path", - route: new(Route).Path("/111/"), - request: newRequest("GET", "http://localhost/111"), - vars: map[string]string{}, - host: "", - path: "/111", - path_template: `/111/`, - shouldMatch: false, - }, - { - title: "Path route, do not match with trailing slash in request", - route: new(Route).Path("/111"), - request: newRequest("GET", "http://localhost/111/"), - vars: map[string]string{}, - host: "", - path: "/111/", - path_template: `/111`, - shouldMatch: false, + title: "Path route, do not match with trailing slash in path", + route: new(Route).Path("/111/"), + request: newRequest("GET", "http://localhost/111"), + vars: map[string]string{}, + host: "", + path: "/111", + pathTemplate: `/111/`, + shouldMatch: false, + }, + { + title: "Path route, do not match with trailing slash in request", + route: new(Route).Path("/111"), + request: newRequest("GET", "http://localhost/111/"), + vars: map[string]string{}, + host: "", + path: "/111/", + pathTemplate: `/111`, + shouldMatch: false, + }, + { + title: "Path route, match root with no host", + route: new(Route).Path("/"), + request: newRequest("GET", "/"), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: true, + }, + { + title: "Path route, match root with no host, App Engine format", + route: new(Route).Path("/"), + request: func() *http.Request { + r := newRequest("GET", "http://localhost/") + r.RequestURI = "/" + return r + }(), + vars: map[string]string{}, + host: "", + path: "/", + pathTemplate: `/`, + shouldMatch: true, }, { title: "Path route, wrong path in request in request URL", @@ -291,100 +316,111 @@ func TestPath(t *testing.T) { shouldMatch: false, }, { - title: "Path route with pattern, match", - route: new(Route).Path("/111/{v1:[0-9]{3}}/333"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222/333", - path_template: `/111/{v1:[0-9]{3}}/333`, - shouldMatch: true, - }, - { - title: "Path route with pattern, URL in request does not match", - route: new(Route).Path("/111/{v1:[0-9]{3}}/333"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222/333", - path_template: `/111/{v1:[0-9]{3}}/333`, - shouldMatch: false, - }, - { - title: "Path route with multiple patterns, match", - route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"}, - host: "", - path: "/111/222/333", - path_template: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}`, - shouldMatch: true, - }, - { - title: "Path route with multiple patterns, URL in request does not match", - route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"}, - host: "", - path: "/111/222/333", - path_template: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}`, - shouldMatch: false, - }, - { - title: "Path route with multiple patterns with pipe, match", - route: new(Route).Path("/{category:a|(b/c)}/{product}/{id:[0-9]+}"), - request: newRequest("GET", "http://localhost/a/product_name/1"), - vars: map[string]string{"category": "a", "product": "product_name", "id": "1"}, - host: "", - path: "/a/product_name/1", - path_template: `/{category:a|(b/c)}/{product}/{id:[0-9]+}`, - shouldMatch: true, - }, - { - title: "Path route with hyphenated name and pattern, match", - route: new(Route).Path("/111/{v-1:[0-9]{3}}/333"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v-1": "222"}, - host: "", - path: "/111/222/333", - path_template: `/111/{v-1:[0-9]{3}}/333`, - shouldMatch: true, - }, - { - title: "Path route with multiple hyphenated names and patterns, match", - route: new(Route).Path("/{v-1:[0-9]{3}}/{v-2:[0-9]{3}}/{v-3:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v-1": "111", "v-2": "222", "v-3": "333"}, - host: "", - path: "/111/222/333", - path_template: `/{v-1:[0-9]{3}}/{v-2:[0-9]{3}}/{v-3:[0-9]{3}}`, - shouldMatch: true, - }, - { - title: "Path route with multiple hyphenated names and patterns with pipe, match", - route: new(Route).Path("/{product-category:a|(b/c)}/{product-name}/{product-id:[0-9]+}"), - request: newRequest("GET", "http://localhost/a/product_name/1"), - vars: map[string]string{"product-category": "a", "product-name": "product_name", "product-id": "1"}, - host: "", - path: "/a/product_name/1", - path_template: `/{product-category:a|(b/c)}/{product-name}/{product-id:[0-9]+}`, - shouldMatch: true, - }, - { - title: "Path route with multiple hyphenated names and patterns with pipe and case insensitive, match", - route: new(Route).Path("/{type:(?i:daily|mini|variety)}-{date:\\d{4,4}-\\d{2,2}-\\d{2,2}}"), - request: newRequest("GET", "http://localhost/daily-2016-01-01"), - vars: map[string]string{"type": "daily", "date": "2016-01-01"}, - host: "", - path: "/daily-2016-01-01", - path_template: `/{type:(?i:daily|mini|variety)}-{date:\d{4,4}-\d{2,2}-\d{2,2}}`, - shouldMatch: true, + title: "Path route with pattern, match", + route: new(Route).Path("/111/{v1:[0-9]{3}}/333"), + request: newRequest("GET", "http://localhost/111/222/333"), + vars: map[string]string{"v1": "222"}, + host: "", + path: "/111/222/333", + pathTemplate: `/111/{v1:[0-9]{3}}/333`, + shouldMatch: true, + }, + { + title: "Path route with pattern, URL in request does not match", + route: new(Route).Path("/111/{v1:[0-9]{3}}/333"), + request: newRequest("GET", "http://localhost/111/aaa/333"), + vars: map[string]string{"v1": "222"}, + host: "", + path: "/111/222/333", + pathTemplate: `/111/{v1:[0-9]{3}}/333`, + shouldMatch: false, + }, + { + title: "Path route with multiple patterns, match", + route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/222/333"), + vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"}, + host: "", + path: "/111/222/333", + pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}`, + shouldMatch: true, + }, + { + title: "Path route with multiple patterns, URL in request does not match", + route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/aaa/333"), + vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"}, + host: "", + path: "/111/222/333", + pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}`, + shouldMatch: false, + }, + { + title: "Path route with multiple patterns with pipe, match", + route: new(Route).Path("/{category:a|(?:b/c)}/{product}/{id:[0-9]+}"), + request: newRequest("GET", "http://localhost/a/product_name/1"), + vars: map[string]string{"category": "a", "product": "product_name", "id": "1"}, + host: "", + path: "/a/product_name/1", + pathTemplate: `/{category:a|(?:b/c)}/{product}/{id:[0-9]+}`, + shouldMatch: true, + }, + { + title: "Path route with hyphenated name and pattern, match", + route: new(Route).Path("/111/{v-1:[0-9]{3}}/333"), + request: newRequest("GET", "http://localhost/111/222/333"), + vars: map[string]string{"v-1": "222"}, + host: "", + path: "/111/222/333", + pathTemplate: `/111/{v-1:[0-9]{3}}/333`, + shouldMatch: true, + }, + { + title: "Path route with multiple hyphenated names and patterns, match", + route: new(Route).Path("/{v-1:[0-9]{3}}/{v-2:[0-9]{3}}/{v-3:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/222/333"), + vars: map[string]string{"v-1": "111", "v-2": "222", "v-3": "333"}, + host: "", + path: "/111/222/333", + pathTemplate: `/{v-1:[0-9]{3}}/{v-2:[0-9]{3}}/{v-3:[0-9]{3}}`, + shouldMatch: true, + }, + { + title: "Path route with multiple hyphenated names and patterns with pipe, match", + route: new(Route).Path("/{product-category:a|(?:b/c)}/{product-name}/{product-id:[0-9]+}"), + request: newRequest("GET", "http://localhost/a/product_name/1"), + vars: map[string]string{"product-category": "a", "product-name": "product_name", "product-id": "1"}, + host: "", + path: "/a/product_name/1", + pathTemplate: `/{product-category:a|(?:b/c)}/{product-name}/{product-id:[0-9]+}`, + shouldMatch: true, + }, + { + title: "Path route with multiple hyphenated names and patterns with pipe and case insensitive, match", + route: new(Route).Path("/{type:(?i:daily|mini|variety)}-{date:\\d{4,4}-\\d{2,2}-\\d{2,2}}"), + request: newRequest("GET", "http://localhost/daily-2016-01-01"), + vars: map[string]string{"type": "daily", "date": "2016-01-01"}, + host: "", + path: "/daily-2016-01-01", + pathTemplate: `/{type:(?i:daily|mini|variety)}-{date:\d{4,4}-\d{2,2}-\d{2,2}}`, + shouldMatch: true, + }, + { + title: "Path route with empty match right after other match", + route: new(Route).Path(`/{v1:[0-9]*}{v2:[a-z]*}/{v3:[0-9]*}`), + request: newRequest("GET", "http://localhost/111/222"), + vars: map[string]string{"v1": "111", "v2": "", "v3": "222"}, + host: "", + path: "/111/222", + pathTemplate: `/{v1:[0-9]*}{v2:[a-z]*}/{v3:[0-9]*}`, + shouldMatch: true, }, } for _, test := range tests { testRoute(t, test) testTemplate(t, test) + testUseEscapedRoute(t, test) } } @@ -418,126 +454,128 @@ func TestPathPrefix(t *testing.T) { shouldMatch: false, }, { - title: "PathPrefix route with pattern, match", - route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222", - path_template: `/111/{v1:[0-9]{3}}`, - shouldMatch: true, - }, - { - title: "PathPrefix route with pattern, URL prefix in request does not match", - route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "222"}, - host: "", - path: "/111/222", - path_template: `/111/{v1:[0-9]{3}}`, - shouldMatch: false, - }, - { - title: "PathPrefix route with multiple patterns, match", - route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/222/333"), - vars: map[string]string{"v1": "111", "v2": "222"}, - host: "", - path: "/111/222", - path_template: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}`, - shouldMatch: true, - }, - { - title: "PathPrefix route with multiple patterns, URL prefix in request does not match", - route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"), - request: newRequest("GET", "http://localhost/111/aaa/333"), - vars: map[string]string{"v1": "111", "v2": "222"}, - host: "", - path: "/111/222", - path_template: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}`, - shouldMatch: false, + title: "PathPrefix route with pattern, match", + route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/222/333"), + vars: map[string]string{"v1": "222"}, + host: "", + path: "/111/222", + pathTemplate: `/111/{v1:[0-9]{3}}`, + shouldMatch: true, + }, + { + title: "PathPrefix route with pattern, URL prefix in request does not match", + route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/aaa/333"), + vars: map[string]string{"v1": "222"}, + host: "", + path: "/111/222", + pathTemplate: `/111/{v1:[0-9]{3}}`, + shouldMatch: false, + }, + { + title: "PathPrefix route with multiple patterns, match", + route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/222/333"), + vars: map[string]string{"v1": "111", "v2": "222"}, + host: "", + path: "/111/222", + pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}`, + shouldMatch: true, + }, + { + title: "PathPrefix route with multiple patterns, URL prefix in request does not match", + route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"), + request: newRequest("GET", "http://localhost/111/aaa/333"), + vars: map[string]string{"v1": "111", "v2": "222"}, + host: "", + path: "/111/222", + pathTemplate: `/{v1:[0-9]{3}}/{v2:[0-9]{3}}`, + shouldMatch: false, }, } for _, test := range tests { testRoute(t, test) testTemplate(t, test) + testUseEscapedRoute(t, test) } } func TestHostPath(t *testing.T) { tests := []routeTest{ { - title: "Host and Path route, match", - route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{}, - host: "", - path: "", - path_template: `/111/222/333`, - host_template: `aaa.bbb.ccc`, - shouldMatch: true, - }, - { - title: "Host and Path route, wrong host in request URL", - route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{}, - host: "", - path: "", - path_template: `/111/222/333`, - host_template: `aaa.bbb.ccc`, - shouldMatch: false, - }, - { - title: "Host and Path route with pattern, match", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb", "v2": "222"}, - host: "aaa.bbb.ccc", - path: "/111/222/333", - path_template: `/111/{v2:[0-9]{3}}/333`, - host_template: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: true, - }, - { - title: "Host and Path route with pattern, URL in request does not match", - route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "bbb", "v2": "222"}, - host: "aaa.bbb.ccc", - path: "/111/222/333", - path_template: `/111/{v2:[0-9]{3}}/333`, - host_template: `aaa.{v1:[a-z]{3}}.ccc`, - shouldMatch: false, - }, - { - title: "Host and Path route with multiple patterns, match", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"), - request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"}, - host: "aaa.bbb.ccc", - path: "/111/222/333", - path_template: `/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}`, - host_template: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: true, - }, - { - title: "Host and Path route with multiple patterns, URL in request does not match", - route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"), - request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), - vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"}, - host: "aaa.bbb.ccc", - path: "/111/222/333", - path_template: `/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}`, - host_template: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, - shouldMatch: false, + title: "Host and Path route, match", + route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{}, + host: "", + path: "", + pathTemplate: `/111/222/333`, + hostTemplate: `aaa.bbb.ccc`, + shouldMatch: true, + }, + { + title: "Host and Path route, wrong host in request URL", + route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"), + request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), + vars: map[string]string{}, + host: "", + path: "", + pathTemplate: `/111/222/333`, + hostTemplate: `aaa.bbb.ccc`, + shouldMatch: false, + }, + { + title: "Host and Path route with pattern, match", + route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v1": "bbb", "v2": "222"}, + host: "aaa.bbb.ccc", + path: "/111/222/333", + pathTemplate: `/111/{v2:[0-9]{3}}/333`, + hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, + shouldMatch: true, + }, + { + title: "Host and Path route with pattern, URL in request does not match", + route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"), + request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), + vars: map[string]string{"v1": "bbb", "v2": "222"}, + host: "aaa.bbb.ccc", + path: "/111/222/333", + pathTemplate: `/111/{v2:[0-9]{3}}/333`, + hostTemplate: `aaa.{v1:[a-z]{3}}.ccc`, + shouldMatch: false, + }, + { + title: "Host and Path route with multiple patterns, match", + route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"), + request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"), + vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"}, + host: "aaa.bbb.ccc", + path: "/111/222/333", + pathTemplate: `/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}`, + hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, + shouldMatch: true, + }, + { + title: "Host and Path route with multiple patterns, URL in request does not match", + route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"), + request: newRequest("GET", "http://aaa.222.ccc/111/222/333"), + vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"}, + host: "aaa.bbb.ccc", + path: "/111/222/333", + pathTemplate: `/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}`, + hostTemplate: `{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}`, + shouldMatch: false, }, } for _, test := range tests { testRoute(t, test) testTemplate(t, test) + testUseEscapedRoute(t, test) } } @@ -649,26 +687,26 @@ func TestQueries(t *testing.T) { shouldMatch: true, }, { - title: "Queries route, match with a query string", - route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://www.example.com/api?foo=bar&baz=ding"), - vars: map[string]string{}, - host: "", - path: "", - path_template: `/api`, - host_template: `www.example.com`, - shouldMatch: true, + title: "Queries route, match with a query string", + route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), + request: newRequest("GET", "http://www.example.com/api?foo=bar&baz=ding"), + vars: map[string]string{}, + host: "", + path: "", + pathTemplate: `/api`, + hostTemplate: `www.example.com`, + shouldMatch: true, }, { - title: "Queries route, match with a query string out of order", - route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), - request: newRequest("GET", "http://www.example.com/api?baz=ding&foo=bar"), - vars: map[string]string{}, - host: "", - path: "", - path_template: `/api`, - host_template: `www.example.com`, - shouldMatch: true, + title: "Queries route, match with a query string out of order", + route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), + request: newRequest("GET", "http://www.example.com/api?baz=ding&foo=bar"), + vars: map[string]string{}, + host: "", + path: "", + pathTemplate: `/api`, + hostTemplate: `www.example.com`, + shouldMatch: true, }, { title: "Queries route, bad query", @@ -744,7 +782,7 @@ func TestQueries(t *testing.T) { }, { title: "Queries route with regexp pattern with quantifier, additional capturing group", - route: new(Route).Queries("foo", "{v1:[0-9]{1}(a|b)}"), + route: new(Route).Queries("foo", "{v1:[0-9]{1}(?:a|b)}"), request: newRequest("GET", "http://localhost?foo=1a"), vars: map[string]string{"v1": "1a"}, host: "", @@ -789,7 +827,7 @@ func TestQueries(t *testing.T) { }, { title: "Queries route with hyphenated name and pattern with quantifier, additional capturing group", - route: new(Route).Queries("foo", "{v-1:[0-9]{1}(a|b)}"), + route: new(Route).Queries("foo", "{v-1:[0-9]{1}(?:a|b)}"), request: newRequest("GET", "http://localhost?foo=1a"), vars: map[string]string{"v-1": "1a"}, host: "", @@ -864,6 +902,7 @@ func TestQueries(t *testing.T) { for _, test := range tests { testRoute(t, test) testTemplate(t, test) + testUseEscapedRoute(t, test) } } @@ -948,10 +987,10 @@ func TestBuildVarsFunc(t *testing.T) { vars["v2"] = "a" return vars }), - request: newRequest("GET", "http://localhost/111/2"), - path: "/111/3a", - path_template: `/111/{v1:\d}{v2:.*}`, - shouldMatch: true, + request: newRequest("GET", "http://localhost/111/2"), + path: "/111/3a", + pathTemplate: `/111/{v1:\d}{v2:.*}`, + shouldMatch: true, }, { title: "BuildVarsFunc set on route and parent route", @@ -962,10 +1001,10 @@ func TestBuildVarsFunc(t *testing.T) { vars["v2"] = "b" return vars }), - request: newRequest("GET", "http://localhost/1/a"), - path: "/2/b", - path_template: `/{v1:\d}/{v2:\w}`, - shouldMatch: true, + request: newRequest("GET", "http://localhost/1/a"), + path: "/2/b", + pathTemplate: `/{v1:\d}/{v2:\w}`, + shouldMatch: true, }, } @@ -981,48 +1020,49 @@ func TestSubRouter(t *testing.T) { tests := []routeTest{ { - route: subrouter1.Path("/{v2:[a-z]+}"), - request: newRequest("GET", "http://aaa.google.com/bbb"), - vars: map[string]string{"v1": "aaa", "v2": "bbb"}, - host: "aaa.google.com", - path: "/bbb", - path_template: `/{v2:[a-z]+}`, - host_template: `{v1:[a-z]+}.google.com`, - shouldMatch: true, - }, - { - route: subrouter1.Path("/{v2:[a-z]+}"), - request: newRequest("GET", "http://111.google.com/111"), - vars: map[string]string{"v1": "aaa", "v2": "bbb"}, - host: "aaa.google.com", - path: "/bbb", - path_template: `/{v2:[a-z]+}`, - host_template: `{v1:[a-z]+}.google.com`, - shouldMatch: false, - }, - { - route: subrouter2.Path("/baz/{v2}"), - request: newRequest("GET", "http://localhost/foo/bar/baz/ding"), - vars: map[string]string{"v1": "bar", "v2": "ding"}, - host: "", - path: "/foo/bar/baz/ding", - path_template: `/foo/{v1}/baz/{v2}`, - shouldMatch: true, - }, - { - route: subrouter2.Path("/baz/{v2}"), - request: newRequest("GET", "http://localhost/foo/bar"), - vars: map[string]string{"v1": "bar", "v2": "ding"}, - host: "", - path: "/foo/bar/baz/ding", - path_template: `/foo/{v1}/baz/{v2}`, - shouldMatch: false, + route: subrouter1.Path("/{v2:[a-z]+}"), + request: newRequest("GET", "http://aaa.google.com/bbb"), + vars: map[string]string{"v1": "aaa", "v2": "bbb"}, + host: "aaa.google.com", + path: "/bbb", + pathTemplate: `/{v2:[a-z]+}`, + hostTemplate: `{v1:[a-z]+}.google.com`, + shouldMatch: true, + }, + { + route: subrouter1.Path("/{v2:[a-z]+}"), + request: newRequest("GET", "http://111.google.com/111"), + vars: map[string]string{"v1": "aaa", "v2": "bbb"}, + host: "aaa.google.com", + path: "/bbb", + pathTemplate: `/{v2:[a-z]+}`, + hostTemplate: `{v1:[a-z]+}.google.com`, + shouldMatch: false, + }, + { + route: subrouter2.Path("/baz/{v2}"), + request: newRequest("GET", "http://localhost/foo/bar/baz/ding"), + vars: map[string]string{"v1": "bar", "v2": "ding"}, + host: "", + path: "/foo/bar/baz/ding", + pathTemplate: `/foo/{v1}/baz/{v2}`, + shouldMatch: true, + }, + { + route: subrouter2.Path("/baz/{v2}"), + request: newRequest("GET", "http://localhost/foo/bar"), + vars: map[string]string{"v1": "bar", "v2": "ding"}, + host: "", + path: "/foo/bar/baz/ding", + pathTemplate: `/foo/{v1}/baz/{v2}`, + shouldMatch: false, }, } for _, test := range tests { testRoute(t, test) testTemplate(t, test) + testUseEscapedRoute(t, test) } } @@ -1119,6 +1159,40 @@ func TestStrictSlash(t *testing.T) { for _, test := range tests { testRoute(t, test) testTemplate(t, test) + testUseEscapedRoute(t, test) + } +} + +func TestUseEncodedPath(t *testing.T) { + r := NewRouter() + r.UseEncodedPath() + + tests := []routeTest{ + { + title: "Router with useEncodedPath, URL with encoded slash does match", + route: r.NewRoute().Path("/v1/{v1}/v2"), + request: newRequest("GET", "http://localhost/v1/1%2F2/v2"), + vars: map[string]string{"v1": "1%2F2"}, + host: "", + path: "/v1/1%2F2/v2", + pathTemplate: `/v1/{v1}/v2`, + shouldMatch: true, + }, + { + title: "Router with useEncodedPath, URL with encoded slash doesn't match", + route: r.NewRoute().Path("/v1/1/2/v2"), + request: newRequest("GET", "http://localhost/v1/1%2F2/v2"), + vars: map[string]string{"v1": "1%2F2"}, + host: "", + path: "/v1/1%2F2/v2", + pathTemplate: `/v1/1/2/v2`, + shouldMatch: false, + }, + } + + for _, test := range tests { + testRoute(t, test) + testTemplate(t, test) } } @@ -1197,6 +1271,42 @@ func TestWalkNested(t *testing.T) { } } +func TestWalkErrorRoute(t *testing.T) { + router := NewRouter() + router.Path("/g") + expectedError := errors.New("error") + err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error { + return expectedError + }) + if err != expectedError { + t.Errorf("Expected %v routes, found %v", expectedError, err) + } +} + +func TestWalkErrorMatcher(t *testing.T) { + router := NewRouter() + expectedError := router.Path("/g").Subrouter().Path("").GetError() + err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error { + return route.GetError() + }) + if err != expectedError { + t.Errorf("Expected %v routes, found %v", expectedError, err) + } +} + +func TestWalkErrorHandler(t *testing.T) { + handler := NewRouter() + expectedError := handler.Path("/path").Subrouter().Path("").GetError() + router := NewRouter() + router.Path("/g").Handler(handler) + err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error { + return route.GetError() + }) + if err != expectedError { + t.Errorf("Expected %v routes, found %v", expectedError, err) + } +} + func TestSubrouterErrorHandling(t *testing.T) { superRouterCalled := false subRouterCalled := false @@ -1294,56 +1404,31 @@ func testRoute(t *testing.T, test routeTest) { } } +func testUseEscapedRoute(t *testing.T, test routeTest) { + test.route.useEncodedPath = true + testRoute(t, test) +} + func testTemplate(t *testing.T, test routeTest) { route := test.route - path_template := test.path_template - if len(path_template) == 0 { - path_template = test.path - } - host_template := test.host_template - if len(host_template) == 0 { - host_template = test.host + pathTemplate := test.pathTemplate + if len(pathTemplate) == 0 { + pathTemplate = test.path } - - path_tmpl, path_err := route.GetPathTemplate() - if path_err == nil && path_tmpl != path_template { - t.Errorf("(%v) GetPathTemplate not equal: expected %v, got %v", test.title, path_template, path_tmpl) + hostTemplate := test.hostTemplate + if len(hostTemplate) == 0 { + hostTemplate = test.host } - host_tmpl, host_err := route.GetHostTemplate() - if host_err == nil && host_tmpl != host_template { - t.Errorf("(%v) GetHostTemplate not equal: expected %v, got %v", test.title, host_template, host_tmpl) + routePathTemplate, pathErr := route.GetPathTemplate() + if pathErr == nil && routePathTemplate != pathTemplate { + t.Errorf("(%v) GetPathTemplate not equal: expected %v, got %v", test.title, pathTemplate, routePathTemplate) } -} - -// Tests that the context is cleared or not cleared properly depending on -// the configuration of the router -func TestKeepContext(t *testing.T) { - func1 := func(w http.ResponseWriter, r *http.Request) {} - - r := NewRouter() - r.HandleFunc("/", func1).Name("func1") - - req, _ := http.NewRequest("GET", "http://localhost/", nil) - context.Set(req, "t", 1) - - res := new(http.ResponseWriter) - r.ServeHTTP(*res, req) - if _, ok := context.GetOk(req, "t"); ok { - t.Error("Context should have been cleared at end of request") + routeHostTemplate, hostErr := route.GetHostTemplate() + if hostErr == nil && routeHostTemplate != hostTemplate { + t.Errorf("(%v) GetHostTemplate not equal: expected %v, got %v", test.title, hostTemplate, routeHostTemplate) } - - r.KeepContext = true - - req, _ = http.NewRequest("GET", "http://localhost/", nil) - context.Set(req, "t", 1) - - r.ServeHTTP(*res, req) - if _, ok := context.GetOk(req, "t"); !ok { - t.Error("Context should NOT have been cleared at end of request") - } - } type TestA301ResponseWriter struct { @@ -1386,6 +1471,24 @@ func Test301Redirect(t *testing.T) { } } +func TestSkipClean(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + func2 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.SkipClean(true) + r.HandleFunc("/api/", func2).Name("func2") + r.HandleFunc("/", func1).Name("func1") + + req, _ := http.NewRequest("GET", "http://localhost//api/?abc=def", nil) + res := NewRecorder() + r.ServeHTTP(res, req) + + if len(res.HeaderMap["Location"]) != 0 { + t.Errorf("Shouldn't redirect since skip clean is disabled") + } +} + // https://plus.google.com/101022900381697718949/posts/eWy6DjFJ6uW func TestSubrouterHeader(t *testing.T) { expected := "func1 response" @@ -1443,11 +1546,42 @@ func stringMapEqual(m1, m2 map[string]string) bool { return true } -// newRequest is a helper function to create a new request with a method and url +// newRequest is a helper function to create a new request with a method and url. +// The request returned is a 'server' request as opposed to a 'client' one through +// simulated write onto the wire and read off of the wire. +// The differences between requests are detailed in the net/http package. func newRequest(method, url string) *http.Request { req, err := http.NewRequest(method, url, nil) if err != nil { panic(err) } + // extract the escaped original host+path from url + // http://localhost/path/here?v=1#frag -> //localhost/path/here + opaque := "" + if i := len(req.URL.Scheme); i > 0 { + opaque = url[i+1:] + } + + if i := strings.LastIndex(opaque, "?"); i > -1 { + opaque = opaque[:i] + } + if i := strings.LastIndex(opaque, "#"); i > -1 { + opaque = opaque[:i] + } + + // Escaped host+path workaround as detailed in https://golang.org/pkg/net/url/#URL + // for < 1.5 client side workaround + req.URL.Opaque = opaque + + // Simulate writing to wire + var buff bytes.Buffer + req.Write(&buff) + ioreader := bufio.NewReader(&buff) + + // Parse request off of 'wire' + req, err = http.ReadRequest(ioreader) + if err != nil { + panic(err) + } return req } diff --git a/vendor/github.com/gorilla/mux/old_test.go b/vendor/github.com/gorilla/mux/old_test.go index c385a2519..9bdc5e5d1 100644 --- a/vendor/github.com/gorilla/mux/old_test.go +++ b/vendor/github.com/gorilla/mux/old_test.go @@ -687,7 +687,7 @@ func TestNewRegexp(t *testing.T) { } for pattern, paths := range tests { - p, _ = newRouteRegexp(pattern, false, false, false, false) + p, _ = newRouteRegexp(pattern, false, false, false, false, false) for path, result := range paths { matches = p.regexp.FindStringSubmatch(path) if result == nil { diff --git a/vendor/github.com/gorilla/mux/regexp.go b/vendor/github.com/gorilla/mux/regexp.go index 08710bc98..fd8fe3956 100644 --- a/vendor/github.com/gorilla/mux/regexp.go +++ b/vendor/github.com/gorilla/mux/regexp.go @@ -24,7 +24,7 @@ import ( // Previously we accepted only Python-like identifiers for variable // names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that // name and pattern can't be empty, and names can't contain a colon. -func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash bool) (*routeRegexp, error) { +func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash, useEncodedPath bool) (*routeRegexp, error) { // Check if it is well-formed. idxs, errBraces := braceIndices(tpl) if errBraces != nil { @@ -111,14 +111,15 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash } // Done! return &routeRegexp{ - template: template, - matchHost: matchHost, - matchQuery: matchQuery, - strictSlash: strictSlash, - regexp: reg, - reverse: reverse.String(), - varsN: varsN, - varsR: varsR, + template: template, + matchHost: matchHost, + matchQuery: matchQuery, + strictSlash: strictSlash, + useEncodedPath: useEncodedPath, + regexp: reg, + reverse: reverse.String(), + varsN: varsN, + varsR: varsR, }, nil } @@ -133,6 +134,9 @@ type routeRegexp struct { matchQuery bool // The strictSlash value defined on the route, but disabled if PathPrefix was used. strictSlash bool + // Determines whether to use encoded path from getPath function or unencoded + // req.URL.Path for path matching + useEncodedPath bool // Expanded regexp. regexp *regexp.Regexp // Reverse template. @@ -149,8 +153,11 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool { if r.matchQuery { return r.matchQueryString(req) } - - return r.regexp.MatchString(req.URL.Path) + path := req.URL.Path + if r.useEncodedPath { + path = getPath(req) + } + return r.regexp.MatchString(path) } return r.regexp.MatchString(getHost(req)) @@ -253,14 +260,18 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) extractVars(host, matches, v.host.varsN, m.Vars) } } + path := req.URL.Path + if r.useEncodedPath { + path = getPath(req) + } // Store path variables. if v.path != nil { - matches := v.path.regexp.FindStringSubmatchIndex(req.URL.Path) + matches := v.path.regexp.FindStringSubmatchIndex(path) if len(matches) > 0 { - extractVars(req.URL.Path, matches, v.path.varsN, m.Vars) + extractVars(path, matches, v.path.varsN, m.Vars) // Check if we should redirect. if v.path.strictSlash { - p1 := strings.HasSuffix(req.URL.Path, "/") + p1 := strings.HasSuffix(path, "/") p2 := strings.HasSuffix(v.path.template, "/") if p1 != p2 { u, _ := url.Parse(req.URL.String()) @@ -299,14 +310,7 @@ func getHost(r *http.Request) string { } func extractVars(input string, matches []int, names []string, output map[string]string) { - matchesCount := 0 - prevEnd := -1 - for i := 2; i < len(matches) && matchesCount < len(names); i += 2 { - if prevEnd < matches[i+1] { - value := input[matches[i]:matches[i+1]] - output[names[matchesCount]] = value - prevEnd = matches[i+1] - matchesCount++ - } + for i, name := range names { + output[name] = input[matches[2*i+2]:matches[2*i+3]] } } diff --git a/vendor/github.com/gorilla/mux/route.go b/vendor/github.com/gorilla/mux/route.go index bf92af261..293b6d493 100644 --- a/vendor/github.com/gorilla/mux/route.go +++ b/vendor/github.com/gorilla/mux/route.go @@ -26,6 +26,11 @@ type Route struct { // If true, when the path pattern is "/path/", accessing "/path" will // redirect to the former and vice versa. strictSlash bool + // If true, when the path pattern is "/path//to", accessing "/path//to" + // will not redirect + skipClean bool + // If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to" + useEncodedPath bool // If true, this route never matches: it is only used to build URLs. buildOnly bool // The name used to build URLs. @@ -36,6 +41,10 @@ type Route struct { buildVarsFunc BuildVarsFunc } +func (r *Route) SkipClean() bool { + return r.skipClean +} + // Match matches the route against the request. func (r *Route) Match(req *http.Request, match *RouteMatch) bool { if r.buildOnly || r.err != nil { @@ -151,7 +160,7 @@ func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl } } - rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, matchQuery, r.strictSlash) + rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, matchQuery, r.strictSlash, r.useEncodedPath) if err != nil { return err } diff --git a/vendor/github.com/gorilla/websocket/.gitignore b/vendor/github.com/gorilla/websocket/.gitignore index 00268614f..ac710204f 100644 --- a/vendor/github.com/gorilla/websocket/.gitignore +++ b/vendor/github.com/gorilla/websocket/.gitignore @@ -20,3 +20,6 @@ _cgo_export.* _testmain.go *.exe + +.idea/ +*.iml
\ No newline at end of file diff --git a/vendor/github.com/gorilla/websocket/.travis.yml b/vendor/github.com/gorilla/websocket/.travis.yml index 66435ac0b..4ea1e7a1f 100644 --- a/vendor/github.com/gorilla/websocket/.travis.yml +++ b/vendor/github.com/gorilla/websocket/.travis.yml @@ -6,6 +6,7 @@ matrix: - go: 1.4 - go: 1.5 - go: 1.6 + - go: 1.7 - go: tip allow_failures: - go: tip diff --git a/vendor/github.com/gorilla/websocket/README.md b/vendor/github.com/gorilla/websocket/README.md index 9d71959ea..33c3d2be3 100644 --- a/vendor/github.com/gorilla/websocket/README.md +++ b/vendor/github.com/gorilla/websocket/README.md @@ -3,6 +3,9 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. +[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket) +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) + ### Documentation * [API Reference](http://godoc.org/github.com/gorilla/websocket) @@ -43,7 +46,7 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn <tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr> <tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr> <tr><td colspan="3">Other Features</tr></td> -<tr><td>Limit size of received message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.SetReadLimit">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=5082">No</a></td></tr> +<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr> <tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr> <tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr> </table> diff --git a/vendor/github.com/gorilla/websocket/bench_test.go b/vendor/github.com/gorilla/websocket/bench_test.go deleted file mode 100644 index f66fc36bc..000000000 --- a/vendor/github.com/gorilla/websocket/bench_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package websocket - -import ( - "testing" -) - -func BenchmarkMaskBytes(b *testing.B) { - var key [4]byte - data := make([]byte, 1024) - pos := 0 - for i := 0; i < b.N; i++ { - pos = maskBytes(key, pos, data) - } - b.SetBytes(int64(len(data))) -} diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go index 879d33ed3..78d932877 100644 --- a/vendor/github.com/gorilla/websocket/client.go +++ b/vendor/github.com/gorilla/websocket/client.go @@ -23,6 +23,8 @@ import ( // invalid. var ErrBadHandshake = errors.New("websocket: bad handshake") +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + // NewClient creates a new client connection using the given net connection. // The URL u specifies the host and request URI. Use requestHeader to specify // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies @@ -70,6 +72,17 @@ type Dialer struct { // Subprotocols specifies the client's requested subprotocols. Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar http.CookieJar } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -83,7 +96,6 @@ func parseURL(s string) (*url.URL, error) { // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - var u url.URL switch { case strings.HasPrefix(s, "ws://"): @@ -193,6 +205,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re Host: u.Host, } + // Set the cookies present in the cookie jar of the dialer + if d.Jar != nil { + for _, cookie := range d.Jar.Cookies(u) { + req.AddCookie(cookie) + } + } + // Set the request headers using the capitalization for names and values in // RFC examples. Although the capitalization shouldn't matter, there are // servers that depend on it. The Header.Set method is not used because the @@ -214,6 +233,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) default: @@ -221,6 +241,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } + if d.EnableCompression { + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + hostPort, hostNoPort := hostPortNoPort(u) var proxyURL *url.URL @@ -324,6 +348,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re if err != nil { return nil, nil, err } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + if resp.StatusCode != 101 || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || @@ -337,6 +368,20 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, resp, ErrBadHandshake } + for _, ext := range parseExtensions(req.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") diff --git a/vendor/github.com/gorilla/websocket/client_server_test.go b/vendor/github.com/gorilla/websocket/client_server_test.go index 3f7345dde..7d39da681 100644 --- a/vendor/github.com/gorilla/websocket/client_server_test.go +++ b/vendor/github.com/gorilla/websocket/client_server_test.go @@ -10,8 +10,8 @@ import ( "encoding/base64" "io" "io/ioutil" - "net" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/url" "reflect" @@ -21,9 +21,10 @@ import ( ) var cstUpgrader = Upgrader{ - Subprotocols: []string{"p0", "p1"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, + Subprotocols: []string{"p0", "p1"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { http.Error(w, reason.Error(), status) }, @@ -228,6 +229,54 @@ func TestDial(t *testing.T) { sendRecv(t, ws) } +func TestDialCookieJar(t *testing.T) { + s := newServer(t) + defer s.Close() + + jar, _ := cookiejar.New(nil) + d := cstDialer + d.Jar = jar + + u, _ := parseURL(s.URL) + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + } + + cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}} + d.Jar.SetCookies(u, cookies) + + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + var gorilla string + var sessionID string + for _, c := range d.Jar.Cookies(u) { + if c.Name == "gorilla" { + gorilla = c.Value + } + + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if gorilla != "ws" { + t.Error("Cookie not present in jar.") + } + + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + sendRecv(t, ws) +} + func TestDialTLS(t *testing.T) { s := newTLSServer(t) defer s.Close() @@ -243,11 +292,9 @@ func TestDialTLS(t *testing.T) { } } - u, _ := url.Parse(s.URL) d := cstDialer - d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) } d.TLSClientConfig = &tls.Config{RootCAs: certs} - ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil) + ws, _, err := d.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } @@ -267,7 +314,7 @@ func xTestDialTLSBadCert(t *testing.T) { } } -func xTestDialTLSNoVerify(t *testing.T) { +func TestDialTLSNoVerify(t *testing.T) { s := newTLSServer(t) defer s.Close() @@ -449,3 +496,17 @@ func TestHostHeader(t *testing.T) { sendRecv(t, ws) } + +func TestDialCompression(t *testing.T) { + s := newServer(t) + defer s.Close() + + dialer := cstDialer + dialer.EnableCompression = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} diff --git a/vendor/github.com/gorilla/websocket/compression.go b/vendor/github.com/gorilla/websocket/compression.go new file mode 100644 index 000000000..e2ac7617b --- /dev/null +++ b/vendor/github.com/gorilla/websocket/compression.go @@ -0,0 +1,85 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" +) + +func decompressNoContextTakeover(r io.Reader) io.Reader { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) +} + +func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { + tw := &truncWriter{w: w} + fw, err := flate.NewWriter(tw, 3) + return &flateWrapper{fw: fw, tw: tw}, err +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWrapper struct { + fw *flate.Writer + tw *truncWriter +} + +func (w *flateWrapper) Write(p []byte) (int, error) { + return w.fw.Write(p) +} + +func (w *flateWrapper) Close() error { + err1 := w.fw.Flush() + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} diff --git a/vendor/github.com/gorilla/websocket/compression_test.go b/vendor/github.com/gorilla/websocket/compression_test.go new file mode 100644 index 000000000..cad70fb51 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/compression_test.go @@ -0,0 +1,31 @@ +package websocket + +import ( + "bytes" + "io" + "testing" +) + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +func TestTruncWriter(t *testing.T) { + const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" + for n := 1; n <= 10; n++ { + var b bytes.Buffer + w := &truncWriter{w: nopCloser{&b}} + p := []byte(data) + for len(p) > 0 { + m := len(p) + if m > n { + m = n + } + w.Write(p[:m]) + p = p[m:] + } + if b.String() != data[:len(data)-len(w.p)] { + t.Errorf("%d: %q", n, b.String()) + } + } +} diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go index ed7736c49..b7a97bae9 100644 --- a/vendor/github.com/gorilla/websocket/conn.go +++ b/vendor/github.com/gorilla/websocket/conn.go @@ -10,19 +10,27 @@ import ( "errors" "io" "io/ioutil" - "math/rand" "net" "strconv" + "sync" "time" "unicode/utf8" ) const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask maxControlFramePayloadSize = 125 - finalBit = 1 << 7 - maskBit = 1 << 7 - writeWait = time.Second + + writeWait = time.Second defaultReadBufferSize = 4096 defaultWriteBufferSize = 4096 @@ -210,51 +218,41 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } -func maskBytes(key [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= key[pos&3] - pos++ - } - return pos & 3 -} - -func newMaskKey() [4]byte { - n := rand.Uint32() - return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} -} - -// Conn represents a WebSocket connection. type Conn struct { conn net.Conn isServer bool subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn and closeSent - closeSent bool // true if close message was sent - - // Message writer fields. - writeErr error - writeBuf []byte // frame is constructed in this buffer. - writePos int // end of data in writeBuf. - writeFrameType int // type of the current frame. - writeSeq int // incremented to invalidate message writers. - writeDeadline time.Time - isWriting bool // for best-effort concurrent write detection + mu chan bool // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) // Read fields readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. readFinal bool // true the current message has more frames. - readSeq int // incremented to invalidate message readers. readLength int64 // Message size. readLimit int64 // Maximum message size. readMaskPos int readMaskKey [4]byte handlePong func(string) error handlePing func(string) error + handleClose func(int, string) error readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.Reader } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -264,20 +262,23 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) if readBufferSize == 0 { readBufferSize = defaultReadBufferSize } + if readBufferSize < maxControlFramePayloadSize { + readBufferSize = maxControlFramePayloadSize + } if writeBufferSize == 0 { writeBufferSize = defaultWriteBufferSize } c := &Conn{ - isServer: isServer, - br: bufio.NewReaderSize(conn, readBufferSize), - conn: conn, - mu: mu, - readFinal: true, - writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), - writeFrameType: noFrame, - writePos: maxFrameHeaderSize, - } + isServer: isServer, + br: bufio.NewReaderSize(conn, readBufferSize), + conn: conn, + mu: mu, + readFinal: true, + writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), + enableWriteCompression: true, + } + c.SetCloseHandler(nil) c.SetPingHandler(nil) c.SetPongHandler(nil) return c @@ -305,29 +306,40 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { <-c.mu defer func() { c.mu <- true }() - if c.closeSent { - return ErrCloseSent - } else if frameType == CloseMessage { - c.closeSent = true + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err } c.conn.SetWriteDeadline(deadline) for _, buf := range bufs { if len(buf) > 0 { - n, err := c.conn.Write(buf) - if n != len(buf) { - // Close on partial write. - c.conn.Close() - } + _, err := c.conn.Write(buf) if err != nil { - return err + return c.writeFatal(err) } } } + + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } return nil } @@ -376,60 +388,104 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er } defer func() { c.mu <- true }() - if c.closeSent { - return ErrCloseSent - } else if messageType == CloseMessage { - c.closeSent = true + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err } c.conn.SetWriteDeadline(deadline) - n, err := c.conn.Write(buf) - if n != 0 && n != len(buf) { - c.conn.Close() + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) } - return hideTempErr(err) + return err } -// NextWriter returns a writer for the next message to send. The writer's -// Close method flushes the complete message to the network. +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { - if c.writeErr != nil { - return nil, c.writeErr + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil } - if c.writeFrameType != noFrame { - if err := c.flushFrame(true, nil); err != nil { + if !isControl(messageType) && !isData(messageType) { + return nil, errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return nil, err + } + + mw := &messageWriter{ + c: c, + frameType: messageType, + pos: maxFrameHeaderSize, + } + c.writer = mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w, err := c.newCompressionWriter(c.writer) + if err != nil { + c.writer = nil return nil, err } + mw.compress = true + c.writer = w } + return c.writer, nil +} - if !isControl(messageType) && !isData(messageType) { - return nil, errBadWriteOpCode - } +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} - c.writeFrameType = messageType - return messageWriter{c, c.writeSeq}, nil +func (w *messageWriter) fatal(err error) error { + if w.err != nil { + w.err = err + w.c.writer = nil + } + return err } -func (c *Conn) flushFrame(final bool, extra []byte) error { - length := c.writePos - maxFrameHeaderSize + len(extra) +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. - if isControl(c.writeFrameType) && + if isControl(w.frameType) && (!final || length > maxControlFramePayloadSize) { - c.writeSeq++ - c.writeFrameType = noFrame - c.writePos = maxFrameHeaderSize - return errInvalidControlFrame + return w.fatal(errInvalidControlFrame) } - b0 := byte(c.writeFrameType) + b0 := byte(w.frameType) if final { b0 |= finalBit } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + b1 := byte(0) if !c.isServer { b1 |= maskBit @@ -461,10 +517,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { if !c.isServer { key := newMaskKey() copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) - maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { - c.writeErr = errors.New("websocket: internal error, extra used in client mode") - return c.writeErr + return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) } } @@ -477,46 +532,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { } c.isWriting = true - c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) if !c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = false - // Setup for next frame. - c.writePos = maxFrameHeaderSize - c.writeFrameType = continuationFrame - if final { - c.writeSeq++ - c.writeFrameType = noFrame + if err != nil { + return w.fatal(err) } - return c.writeErr -} -type messageWriter struct { - c *Conn - seq int -} - -func (w messageWriter) err() error { - c := w.c - if c.writeSeq != w.seq { - return errWriteClosed - } - if c.writeErr != nil { - return c.writeErr + if final { + c.writer = nil + return nil } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame return nil } -func (w messageWriter) ncopy(max int) (int, error) { - n := len(w.c.writeBuf) - w.c.writePos +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos if n <= 0 { - if err := w.c.flushFrame(false, nil); err != nil { + if err := w.flushFrame(false, nil); err != nil { return 0, err } - n = len(w.c.writeBuf) - w.c.writePos + n = len(w.c.writeBuf) - w.pos } if n > max { n = max @@ -524,14 +568,14 @@ func (w messageWriter) ncopy(max int) (int, error) { return n, nil } -func (w messageWriter) write(final bool, p []byte) (int, error) { - if err := w.err(); err != nil { - return 0, err +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err } if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. - err := w.c.flushFrame(final, p) + err := w.flushFrame(false, p) if err != nil { return 0, err } @@ -544,20 +588,16 @@ func (w messageWriter) write(final bool, p []byte) (int, error) { if err != nil { return 0, err } - copy(w.c.writeBuf[w.c.writePos:], p[:n]) - w.c.writePos += n + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n p = p[n:] } return nn, nil } -func (w messageWriter) Write(p []byte) (int, error) { - return w.write(false, p) -} - -func (w messageWriter) WriteString(p string) (int, error) { - if err := w.err(); err != nil { - return 0, err +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err } nn := len(p) @@ -566,27 +606,27 @@ func (w messageWriter) WriteString(p string) (int, error) { if err != nil { return 0, err } - copy(w.c.writeBuf[w.c.writePos:], p[:n]) - w.c.writePos += n + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n p = p[n:] } return nn, nil } -func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { - if err := w.err(); err != nil { - return 0, err +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err } for { - if w.c.writePos == len(w.c.writeBuf) { - err = w.c.flushFrame(false, nil) + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) if err != nil { break } } var n int - n, err = r.Read(w.c.writeBuf[w.c.writePos:]) - w.c.writePos += n + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n nn += int64(n) if err != nil { if err == io.EOF { @@ -598,30 +638,36 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { return nn, err } -func (w messageWriter) Close() error { - if err := w.err(); err != nil { +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + if err := w.flushFrame(true, nil); err != nil { return err } - return w.c.flushFrame(true, nil) + w.err = errWriteClosed + return nil } // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { - wr, err := c.NextWriter(messageType) + w, err := c.NextWriter(messageType) if err != nil { return err } - w := wr.(messageWriter) - if _, err := w.write(true, data); err != nil { + if mw, ok := w.(*messageWriter); ok && c.isServer { + // Optimize write as a single frame. + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + err = mw.flushFrame(true, data) return err } - if c.writeSeq == w.seq { - if err := c.flushFrame(true, nil); err != nil { - return err - } + if _, err = w.Write(data); err != nil { + return err } - return nil + return w.Close() } // SetWriteDeadline sets the write deadline on the underlying network @@ -635,22 +681,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Read methods -// readFull is like io.ReadFull except that io.EOF is never returned. -func (c *Conn) readFull(p []byte) (err error) { - var n int - for n < len(p) && err == nil { - var nn int - nn, err = c.br.Read(p[n:]) - n += nn - } - if n == len(p) { - err = nil - } else if err == io.EOF { - err = errUnexpectedEOF - } - return -} - func (c *Conn) advanceFrame() (int, error) { // 1. Skip remainder of previous frame. @@ -663,19 +693,24 @@ func (c *Conn) advanceFrame() (int, error) { // 2. Read and parse first two bytes of frame header. - var b [8]byte - if err := c.readFull(b[:2]); err != nil { + p, err := c.read(2) + if err != nil { return noFrame, err } - final := b[0]&finalBit != 0 - frameType := int(b[0] & 0xf) - reserved := int((b[0] >> 4) & 0x7) - mask := b[1]&maskBit != 0 - c.readRemaining = int64(b[1] & 0x7f) + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + mask := p[1]&maskBit != 0 + c.readRemaining = int64(p[1] & 0x7f) + + c.readDecompress = false + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } - if reserved != 0 { - return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) } switch frameType { @@ -704,15 +739,17 @@ func (c *Conn) advanceFrame() (int, error) { switch c.readRemaining { case 126: - if err := c.readFull(b[:2]); err != nil { + p, err := c.read(2) + if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) + c.readRemaining = int64(binary.BigEndian.Uint16(p)) case 127: - if err := c.readFull(b[:8]); err != nil { + p, err := c.read(8) + if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) + c.readRemaining = int64(binary.BigEndian.Uint64(p)) } // 4. Handle frame masking. @@ -723,9 +760,11 @@ func (c *Conn) advanceFrame() (int, error) { if mask { c.readMaskPos = 0 - if err := c.readFull(c.readMaskKey[:]); err != nil { + p, err := c.read(len(c.readMaskKey)) + if err != nil { return noFrame, err } + copy(c.readMaskKey[:], p) } // 5. For text and binary messages, enforce read limit and return. @@ -745,9 +784,9 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { - payload = make([]byte, c.readRemaining) + payload, err = c.read(int(c.readRemaining)) c.readRemaining = 0 - if err := c.readFull(payload); err != nil { + if err != nil { return noFrame, err } if c.isServer { @@ -767,11 +806,9 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, err } case CloseMessage: - echoMessage := []byte{} closeCode := CloseNoStatusReceived closeText := "" if len(payload) >= 2 { - echoMessage = payload[:2] closeCode = int(binary.BigEndian.Uint16(payload)) if !isValidReceivedCloseCode(closeCode) { return noFrame, c.handleProtocolError("invalid close code") @@ -781,7 +818,9 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") } } - c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait)) + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } return noFrame, &CloseError{Code: closeCode, Text: closeText} } @@ -805,7 +844,7 @@ func (c *Conn) handleProtocolError(message string) error { // this method return the same error. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { - c.readSeq++ + c.messageReader = nil c.readLength = 0 for c.readErr == nil { @@ -815,7 +854,12 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { break } if frameType == TextMessage || frameType == BinaryMessage { - return frameType, messageReader{c, c.readSeq}, nil + c.messageReader = &messageReader{c} + var r io.Reader = c.messageReader + if c.readDecompress { + r = c.newDecompressionReader(r) + } + return frameType, r, nil } } @@ -830,51 +874,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { return noFrame, nil, c.readErr } -type messageReader struct { - c *Conn - seq int -} - -func (r messageReader) Read(b []byte) (int, error) { +type messageReader struct{ c *Conn } - if r.seq != r.c.readSeq { +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { return 0, io.EOF } - for r.c.readErr == nil { + for c.readErr == nil { - if r.c.readRemaining > 0 { - if int64(len(b)) > r.c.readRemaining { - b = b[:r.c.readRemaining] + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] } - n, err := r.c.br.Read(b) - r.c.readErr = hideTempErr(err) - if r.c.isServer { - r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) + n, err := c.br.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } - r.c.readRemaining -= int64(n) - if r.c.readRemaining > 0 && r.c.readErr == io.EOF { - r.c.readErr = errUnexpectedEOF + c.readRemaining -= int64(n) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF } - return n, r.c.readErr + return n, c.readErr } - if r.c.readFinal { - r.c.readSeq++ + if c.readFinal { + c.messageReader = nil return 0, io.EOF } - frameType, err := r.c.advanceFrame() + frameType, err := c.advanceFrame() switch { case err != nil: - r.c.readErr = hideTempErr(err) + c.readErr = hideTempErr(err) case frameType == TextMessage || frameType == BinaryMessage: - r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } } - err := r.c.readErr - if err == io.EOF && r.seq == r.c.readSeq { + err := c.readErr + if err == io.EOF && c.messageReader == r { err = errUnexpectedEOF } return 0, err @@ -907,6 +948,34 @@ func (c *Conn) SetReadLimit(limit int64) { c.readLimit = limit } +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close frame +// back to the peer. +func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := []byte{} + if code != CloseNoStatusReceived { + message = FormatCloseMessage(code, "") + } + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + // SetPingHandler sets the handler for ping messages received from the peer. // The appData argument to h is the PING frame application data. The default // ping handler sends a pong to the peer. @@ -925,6 +994,11 @@ func (c *Conn) SetPingHandler(h func(appData string) error) { c.handlePing = h } +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + // SetPongHandler sets the handler for pong messages received from the peer. // The appData argument to h is the PONG frame application data. The default // pong handler does nothing. @@ -941,6 +1015,13 @@ func (c *Conn) UnderlyingConn() net.Conn { return c.conn } +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. func FormatCloseMessage(closeCode int, text string) []byte { buf := make([]byte, 2+len(text)) diff --git a/vendor/github.com/gorilla/websocket/conn_read.go b/vendor/github.com/gorilla/websocket/conn_read.go new file mode 100644 index 000000000..1ea15059e --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_read.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.5 + +package websocket + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} diff --git a/vendor/github.com/gorilla/websocket/conn_read_legacy.go b/vendor/github.com/gorilla/websocket/conn_read_legacy.go new file mode 100644 index 000000000..018541cf6 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_read_legacy.go @@ -0,0 +1,21 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.5 + +package websocket + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + if len(p) > 0 { + // advance over the bytes just read + io.ReadFull(c.br, p) + } + return p, err +} diff --git a/vendor/github.com/gorilla/websocket/conn_test.go b/vendor/github.com/gorilla/websocket/conn_test.go index 0243c1154..7431383b1 100644 --- a/vendor/github.com/gorilla/websocket/conn_test.go +++ b/vendor/github.com/gorilla/websocket/conn_test.go @@ -26,12 +26,27 @@ type fakeNetConn struct { } func (c fakeNetConn) Close() error { return nil } -func (c fakeNetConn) LocalAddr() net.Addr { return nil } -func (c fakeNetConn) RemoteAddr() net.Addr { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + func TestFraming(t *testing.T) { frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} var readChunkers = []struct { @@ -42,66 +57,78 @@ func TestFraming(t *testing.T) { {"one", iotest.OneByteReader}, {"asis", func(r io.Reader) io.Reader { return r }}, } - writeBuf := make([]byte, 65537) for i := range writeBuf { writeBuf[i] = byte(i) } + var writers = []struct { + name string + f func(w io.Writer, n int) (int, error) + }{ + {"iocopy", func(w io.Writer, n int) (int, error) { + nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) + return int(nn), err + }}, + {"write", func(w io.Writer, n int) (int, error) { + return w.Write(writeBuf[:n]) + }}, + {"string", func(w io.Writer, n int) (int, error) { + return io.WriteString(w, string(writeBuf[:n])) + }}, + } - for _, isServer := range []bool{true, false} { - for _, chunker := range readChunkers { - - var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) - - for _, n := range frameSizes { - for _, iocopy := range []bool{true, false} { - name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy) + for _, compress := range []bool{false, true} { + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { - w, err := wc.NextWriter(TextMessage) - if err != nil { - t.Errorf("%s: wc.NextWriter() returned %v", name, err) - continue - } - var nn int - if iocopy { - var n64 int64 - n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n])) - nn = int(n64) - } else { - nn, err = w.Write(writeBuf[:n]) - } - if err != nil || nn != n { - t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) - continue - } - err = w.Close() - if err != nil { - t.Errorf("%s: w.Close() returned %v", name, err) - continue - } + var connBuf bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) + rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + if compress { + wc.newCompressionWriter = compressNoContextTakeover + rc.newDecompressionReader = decompressNoContextTakeover + } + for _, n := range frameSizes { + for _, writer := range writers { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + nn, err := writer.f(w, n) + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } - opCode, r, err := rc.NextReader() - if err != nil || opCode != TextMessage { - t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) - continue - } - rbuf, err := ioutil.ReadAll(r) - if err != nil { - t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) - continue - } + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + rbuf, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } - if len(rbuf) != n { - t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) - continue - } + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } - for i, b := range rbuf { - if byte(i) != b { - t.Errorf("%s: bad byte at offset %d", name, i) - break + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } } } } @@ -146,7 +173,7 @@ func TestControl(t *testing.T) { } } -func TestCloseBeforeFinalFrame(t *testing.T) { +func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} @@ -233,6 +260,32 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } } +func TestWriteAfterMessageWriterClose(t *testing.T) { + wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) + w, _ := wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + if err := w.Close(); err != nil { + t.Fatalf("unxpected error closing message writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } + + w, _ = wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + + // close w by getting next writer + _, err := wc.NextWriter(BinaryMessage) + if err != nil { + t.Fatalf("unexpected error getting next writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } +} + func TestReadLimit(t *testing.T) { const readLimit = 512 @@ -267,6 +320,16 @@ func TestReadLimit(t *testing.T) { } } +func TestAddrs(t *testing.T) { + c := newConn(&fakeNetConn{}, true, 1024, 1024) + if c.LocalAddr() != localAddr { + t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) + } + if c.RemoteAddr() != remoteAddr { + t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) + } +} + func TestUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} diff --git a/vendor/github.com/gorilla/websocket/doc.go b/vendor/github.com/gorilla/websocket/doc.go index c901a7a94..610acf712 100644 --- a/vendor/github.com/gorilla/websocket/doc.go +++ b/vendor/github.com/gorilla/websocket/doc.go @@ -149,4 +149,25 @@ // The deprecated Upgrade function does not enforce an origin policy. It's the // application's responsibility to check the Origin header before calling // Upgrade. +// +// Compression [Experimental] +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. If compression was successfully negotiated with the connection's +// peer, any message received in compressed form will be automatically +// decompressed. All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(true) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. package websocket diff --git a/vendor/github.com/gorilla/websocket/examples/autobahn/server.go b/vendor/github.com/gorilla/websocket/examples/autobahn/server.go index d96ac84db..e98563be9 100644 --- a/vendor/github.com/gorilla/websocket/examples/autobahn/server.go +++ b/vendor/github.com/gorilla/websocket/examples/autobahn/server.go @@ -8,17 +8,19 @@ package main import ( "errors" "flag" - "github.com/gorilla/websocket" "io" "log" "net/http" "time" "unicode/utf8" + + "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + EnableCompression: true, CheckOrigin: func(r *http.Request) bool { return true }, diff --git a/vendor/github.com/gorilla/websocket/examples/chat/README.md b/vendor/github.com/gorilla/websocket/examples/chat/README.md index 5df3cf1a3..47c82f908 100644 --- a/vendor/github.com/gorilla/websocket/examples/chat/README.md +++ b/vendor/github.com/gorilla/websocket/examples/chat/README.md @@ -1,8 +1,8 @@ # Chat Example This application shows how to use use the -[websocket](https://github.com/gorilla/websocket) package and -[jQuery](http://jquery.com) to implement a simple web chat application. +[websocket](https://github.com/gorilla/websocket) package to implement a simple +web chat application. ## Running the example @@ -18,3 +18,85 @@ using the following commands. $ go run *.go To use the chat example, open http://localhost:8080/ in your browser. + +## Server + +The server application defines two types, `Client` and `Hub`. The server +creates an instance of the `Client` type for each websocket connection. A +`Client` acts as an intermediary between the websocket connection and a single +instance of the `Hub` type. The `Hub` maintains a set of registered clients and +broadcasts messages to the clients. + +The application runs one goroutine for the `Hub` and two goroutines for each +`Client`. The goroutines communicate with each other using channels. The `Hub` +has channels for registering clients, unregistering clients and broadcasting +messages. A `Client` has a buffered channel of outbound messages. One of the +client's goroutines reads messages from this channel and writes the messages to +the websocket. The other client goroutine reads messages from the websocket and +sends them to the hub. + +### Hub + +The code for the `Hub` type is in +[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). +The application's `main` function starts the hub's `run` method as a goroutine. +Clients send requests to the hub using the `register`, `unregister` and +`broadcast` channels. + +The hub registers clients by adding the client pointer as a key in the +`clients` map. The map value is always true. + +The unregister code is a little more complicated. In addition to deleting the +client pointer from the `clients` map, the hub closes the clients's `send` +channel to signal the client that no more messages will be sent to the client. + +The hub handles messages by looping over the registered clients and sending the +message to the client's `send` channel. If the client's `send` buffer is full, +then the hub assumes that the client is dead or stuck. In this case, the hub +unregisters the client and closes the websocket. + +### Client + +The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). + +The `serveWs` function is registered by the application's `main` function as +an HTTP handler. The handler upgrades the HTTP connection to the WebSocket +protocol, creates a client, registers the client with the hub and schedules the +client to be unregistered using a defer statement. + +Next, the HTTP handler starts the client's `writePump` method as a goroutine. +This method transfers messages from the client's send channel to the websocket +connection. The writer method exits when the channel is closed by the hub or +there's an error writing to the websocket connection. + +Finally, the HTTP handler calls the client's `readPump` method. This method +transfers inbound messages from the websocket to the hub. + +WebSocket connections [support one concurrent reader and one concurrent +writer](https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency). The +application ensures that these concurrency requirements are met by executing +all reads from the `readPump` goroutine and all writes from the `writePump` +goroutine. + +To improve efficiency under high load, the `writePump` function coalesces +pending chat messages in the `send` channel to a single WebSocket message. This +reduces the number of system calls and the amount of data sent over the +network. + +## Frontend + +The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). + +On document load, the script checks for websocket functionality in the browser. +If websocket functionality is available, then the script opens a connection to +the server and registers a callback to handle messages from the server. The +callback appends the message to the chat log using the appendLog function. + +To allow the user to manually scroll through the chat log without interruption +from new messages, the `appendLog` function checks the scroll position before +adding new content. If the chat log is scrolled to the bottom, then the +function scrolls new content into view after adding the content. Otherwise, the +scroll position is not changed. + +The form handler writes the user input to the websocket and clears the input +field. diff --git a/vendor/github.com/gorilla/websocket/examples/chat/client.go b/vendor/github.com/gorilla/websocket/examples/chat/client.go new file mode 100644 index 000000000..26468477c --- /dev/null +++ b/vendor/github.com/gorilla/websocket/examples/chat/client.go @@ -0,0 +1,134 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "log" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 512 +) + +var ( + newline = []byte{'\n'} + space = []byte{' '} +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +// Client is a middleman between the websocket connection and the hub. +type Client struct { + hub *Hub + + // The websocket connection. + conn *websocket.Conn + + // Buffered channel of outbound messages. + send chan []byte +} + +// readPump pumps messages from the websocket connection to the hub. +// +// The application runs readPump in a per-connection goroutine. The application +// ensures that there is at most one reader on a connection by executing all +// reads from this goroutine. +func (c *Client) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Printf("error: %v", err) + } + break + } + message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + c.hub.broadcast <- message + } +} + +// writePump pumps messages from the hub to the websocket connection. +// +// A goroutine running writePump is started for each connection. The +// application ensures that there is at most one writer to a connection by +// executing all writes from this goroutine. +func (c *Client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // The hub closed the channel. + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(message) + + // Add queued chat messages to the current websocket message. + n := len(c.send) + for i := 0; i < n; i++ { + w.Write(newline) + w.Write(<-c.send) + } + + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } +} + +// serveWs handles websocket requests from the peer. +func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} + client.hub.register <- client + go client.writePump() + client.readPump() +} diff --git a/vendor/github.com/gorilla/websocket/examples/chat/conn.go b/vendor/github.com/gorilla/websocket/examples/chat/conn.go deleted file mode 100644 index 40fd38c2c..000000000 --- a/vendor/github.com/gorilla/websocket/examples/chat/conn.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "github.com/gorilla/websocket" - "log" - "net/http" - "time" -) - -const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - - // Maximum message size allowed from peer. - maxMessageSize = 512 -) - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - -// connection is an middleman between the websocket connection and the hub. -type connection struct { - // The websocket connection. - ws *websocket.Conn - - // Buffered channel of outbound messages. - send chan []byte -} - -// readPump pumps messages from the websocket connection to the hub. -func (c *connection) readPump() { - defer func() { - h.unregister <- c - c.ws.Close() - }() - c.ws.SetReadLimit(maxMessageSize) - c.ws.SetReadDeadline(time.Now().Add(pongWait)) - c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - for { - _, message, err := c.ws.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { - log.Printf("error: %v", err) - } - break - } - h.broadcast <- message - } -} - -// write writes a message with the given message type and payload. -func (c *connection) write(mt int, payload []byte) error { - c.ws.SetWriteDeadline(time.Now().Add(writeWait)) - return c.ws.WriteMessage(mt, payload) -} - -// writePump pumps messages from the hub to the websocket connection. -func (c *connection) writePump() { - ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - c.ws.Close() - }() - for { - select { - case message, ok := <-c.send: - if !ok { - c.write(websocket.CloseMessage, []byte{}) - return - } - if err := c.write(websocket.TextMessage, message); err != nil { - return - } - case <-ticker.C: - if err := c.write(websocket.PingMessage, []byte{}); err != nil { - return - } - } - } -} - -// serveWs handles websocket requests from the peer. -func serveWs(w http.ResponseWriter, r *http.Request) { - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(err) - return - } - c := &connection{send: make(chan []byte, 256), ws: ws} - h.register <- c - go c.writePump() - c.readPump() -} diff --git a/vendor/github.com/gorilla/websocket/examples/chat/home.html b/vendor/github.com/gorilla/websocket/examples/chat/home.html index 29599225c..7262918ec 100644 --- a/vendor/github.com/gorilla/websocket/examples/chat/home.html +++ b/vendor/github.com/gorilla/websocket/examples/chat/home.html @@ -2,47 +2,53 @@ <html lang="en"> <head> <title>Chat Example</title> -<script src="//ajax.googleapis.com/ajax/libs/jquery/2.0.3/jquery.min.js"></script> <script type="text/javascript"> - $(function() { - +window.onload = function () { var conn; - var msg = $("#msg"); - var log = $("#log"); + var msg = document.getElementById("msg"); + var log = document.getElementById("log"); - function appendLog(msg) { - var d = log[0] - var doScroll = d.scrollTop == d.scrollHeight - d.clientHeight; - msg.appendTo(log) + function appendLog(item) { + var doScroll = log.scrollTop === log.scrollHeight - log.clientHeight; + log.appendChild(item); if (doScroll) { - d.scrollTop = d.scrollHeight - d.clientHeight; + log.scrollTop = log.scrollHeight - log.clientHeight; } } - $("#form").submit(function() { + document.getElementById("form").onsubmit = function () { if (!conn) { return false; } - if (!msg.val()) { + if (!msg.value) { return false; } - conn.send(msg.val()); - msg.val(""); - return false - }); + conn.send(msg.value); + msg.value = ""; + return false; + }; if (window["WebSocket"]) { conn = new WebSocket("ws://{{$}}/ws"); - conn.onclose = function(evt) { - appendLog($("<div><b>Connection closed.</b></div>")) - } - conn.onmessage = function(evt) { - appendLog($("<div/>").text(evt.data)) - } + conn.onclose = function (evt) { + var item = document.createElement("div"); + item.innerHTML = "<b>Connection closed.</b>"; + appendLog(item); + }; + conn.onmessage = function (evt) { + var messages = evt.data.split('\n'); + for (var i = 0; i < messages.length; i++) { + var item = document.createElement("div"); + item.innerText = messages[i]; + appendLog(item); + } + }; } else { - appendLog($("<div><b>Your browser does not support WebSockets.</b></div>")) + var item = document.createElement("div"); + item.innerHTML = "<b>Your browser does not support WebSockets.</b>"; + appendLog(item); } - }); +}; </script> <style type="text/css"> html { diff --git a/vendor/github.com/gorilla/websocket/examples/chat/hub.go b/vendor/github.com/gorilla/websocket/examples/chat/hub.go index 449ba753d..7f07ea079 100644 --- a/vendor/github.com/gorilla/websocket/examples/chat/hub.go +++ b/vendor/github.com/gorilla/websocket/examples/chat/hub.go @@ -4,46 +4,48 @@ package main -// hub maintains the set of active connections and broadcasts messages to the -// connections. -type hub struct { - // Registered connections. - connections map[*connection]bool +// hub maintains the set of active clients and broadcasts messages to the +// clients. +type Hub struct { + // Registered clients. + clients map[*Client]bool - // Inbound messages from the connections. + // Inbound messages from the clients. broadcast chan []byte - // Register requests from the connections. - register chan *connection + // Register requests from the clients. + register chan *Client - // Unregister requests from connections. - unregister chan *connection + // Unregister requests from clients. + unregister chan *Client } -var h = hub{ - broadcast: make(chan []byte), - register: make(chan *connection), - unregister: make(chan *connection), - connections: make(map[*connection]bool), +func newHub() *Hub { + return &Hub{ + broadcast: make(chan []byte), + register: make(chan *Client), + unregister: make(chan *Client), + clients: make(map[*Client]bool), + } } -func (h *hub) run() { +func (h *Hub) run() { for { select { - case c := <-h.register: - h.connections[c] = true - case c := <-h.unregister: - if _, ok := h.connections[c]; ok { - delete(h.connections, c) - close(c.send) + case client := <-h.register: + h.clients[client] = true + case client := <-h.unregister: + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) } - case m := <-h.broadcast: - for c := range h.connections { + case message := <-h.broadcast: + for client := range h.clients { select { - case c.send <- m: + case client.send <- message: default: - close(c.send) - delete(h.connections, c) + close(client.send) + delete(h.clients, client) } } } diff --git a/vendor/github.com/gorilla/websocket/examples/chat/main.go b/vendor/github.com/gorilla/websocket/examples/chat/main.go index 3c4448d72..a865ffec5 100644 --- a/vendor/github.com/gorilla/websocket/examples/chat/main.go +++ b/vendor/github.com/gorilla/websocket/examples/chat/main.go @@ -12,9 +12,10 @@ import ( ) var addr = flag.String("addr", ":8080", "http service address") -var homeTempl = template.Must(template.ParseFiles("home.html")) +var homeTemplate = template.Must(template.ParseFiles("home.html")) func serveHome(w http.ResponseWriter, r *http.Request) { + log.Println(r.URL) if r.URL.Path != "/" { http.Error(w, "Not found", 404) return @@ -24,14 +25,17 @@ func serveHome(w http.ResponseWriter, r *http.Request) { return } w.Header().Set("Content-Type", "text/html; charset=utf-8") - homeTempl.Execute(w, r.Host) + homeTemplate.Execute(w, r.Host) } func main() { flag.Parse() - go h.run() + hub := newHub() + go hub.run() http.HandleFunc("/", serveHome) - http.HandleFunc("/ws", serveWs) + http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + serveWs(hub, w, r) + }) err := http.ListenAndServe(*addr, nil) if err != nil { log.Fatal("ListenAndServe: ", err) diff --git a/vendor/github.com/gorilla/websocket/examples/command/main.go b/vendor/github.com/gorilla/websocket/examples/command/main.go index f3f022edb..438fb8328 100644 --- a/vendor/github.com/gorilla/websocket/examples/command/main.go +++ b/vendor/github.com/gorilla/websocket/examples/command/main.go @@ -36,6 +36,9 @@ const ( // Send pings to peer with this period. Must be less than pongWait. pingPeriod = (pongWait * 9) / 10 + + // Time to wait before force close on connection. + closeGracePeriod = 10 * time.Second ) func pumpStdin(ws *websocket.Conn, w io.Writer) { @@ -57,19 +60,24 @@ func pumpStdin(ws *websocket.Conn, w io.Writer) { func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { defer func() { - ws.Close() - close(done) }() s := bufio.NewScanner(r) for s.Scan() { ws.SetWriteDeadline(time.Now().Add(writeWait)) if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil { + ws.Close() break } } if s.Err() != nil { log.Println("scan:", s.Err()) } + close(done) + + ws.SetWriteDeadline(time.Now().Add(writeWait)) + ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + time.Sleep(closeGracePeriod) + ws.Close() } func ping(ws *websocket.Conn, done chan struct{}) { diff --git a/vendor/github.com/gorilla/websocket/mask.go b/vendor/github.com/gorilla/websocket/mask.go new file mode 100644 index 000000000..6758a2cb7 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask.go @@ -0,0 +1,61 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +package websocket + +import ( + "math/rand" + "unsafe" +) + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +func maskBytes(key [4]byte, pos int, b []byte) int { + + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/mask_test.go b/vendor/github.com/gorilla/websocket/mask_test.go new file mode 100644 index 000000000..de0602993 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask_test.go @@ -0,0 +1,73 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// Require 1.7 for sub-bencmarks +// +build go1.7 + +package websocket + +import ( + "fmt" + "testing" +) + +func maskBytesByByte(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} + +func notzero(b []byte) int { + for i := range b { + if b[i] != 0 { + return i + } + } + return -1 +} + +func TestMaskBytes(t *testing.T) { + key := [4]byte{1, 2, 3, 4} + for size := 1; size <= 1024; size++ { + for align := 0; align < wordSize; align++ { + for pos := 0; pos < 4; pos++ { + b := make([]byte, size+align)[align:] + maskBytes(key, pos, b) + maskBytesByByte(key, pos, b) + if i := notzero(b); i >= 0 { + t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) + } + } + } + } +} + +func BenchmarkMaskBytes(b *testing.B) { + for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { + b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { + for _, align := range []int{wordSize / 2} { + b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { + for _, fn := range []struct { + name string + fn func(key [4]byte, pos int, b []byte) int + }{ + {"byte", maskBytesByByte}, + {"word", maskBytes}, + } { + b.Run(fn.name, func(b *testing.B) { + key := newMaskKey() + data := make([]byte, size+align)[align:] + for i := 0; i < b.N; i++ { + fn.fn(key, 0, data) + } + b.SetBytes(int64(len(data))) + }) + } + }) + } + }) + } +} diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go index 8d7137de9..aaedebdbe 100644 --- a/vendor/github.com/gorilla/websocket/server.go +++ b/vendor/github.com/gorilla/websocket/server.go @@ -46,6 +46,12 @@ type Upgrader struct { // CheckOrigin is nil, the host in the Origin header must not be set or // must match the host of the request. CheckOrigin func(r *http.Request) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -53,6 +59,7 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in if u.Error != nil { u.Error(w, r, status, err) } else { + w.Header().Set("Sec-Websocket-Version", "13") http.Error(w, http.StatusText(status), status) } return nil, err @@ -99,7 +106,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if r.Method != "GET" { return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") } - if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") } @@ -126,6 +138,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade subprotocol := u.selectSubprotocol(r, responseHeader) + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + var ( netConn net.Conn br *bufio.Reader @@ -151,6 +175,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.subprotocol = subprotocol + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + p := c.writeBuf[:0] p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) @@ -160,6 +189,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade p = append(p, c.subprotocol...) p = append(p, "\r\n"...) } + if compress { + p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { continue diff --git a/vendor/github.com/gorilla/websocket/util.go b/vendor/github.com/gorilla/websocket/util.go index ffdc265ed..9a4908df2 100644 --- a/vendor/github.com/gorilla/websocket/util.go +++ b/vendor/github.com/gorilla/websocket/util.go @@ -13,19 +13,6 @@ import ( "strings" ) -// tokenListContainsValue returns true if the 1#token header with the given -// name contains token. -func tokenListContainsValue(header http.Header, name string, value string) bool { - for _, v := range header[name] { - for _, s := range strings.Split(v, ",") { - if strings.EqualFold(value, strings.TrimSpace(s)) { - return true - } - } - } - return false -} - var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func computeAcceptKey(challengeKey string) string { @@ -42,3 +29,186 @@ func generateChallengeKey() (string, error) { } return base64.StdEncoding.EncodeToString(p), nil } + +// Octet types from RFC 2616. +var octetTypes [256]byte + +const ( + isTokenOctet = 1 << iota + isSpaceOctet +) + +func init() { + // From RFC 2616 + // + // OCTET = <any 8-bit sequence of data> + // CHAR = <any US-ASCII character (octets 0 - 127)> + // CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)> + // CR = <US-ASCII CR, carriage return (13)> + // LF = <US-ASCII LF, linefeed (10)> + // SP = <US-ASCII SP, space (32)> + // HT = <US-ASCII HT, horizontal-tab (9)> + // <"> = <US-ASCII double-quote mark (34)> + // CRLF = CR LF + // LWS = [CRLF] 1*( SP | HT ) + // TEXT = <any OCTET except CTLs, but including LWS> + // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT + // token = 1*<any CHAR except CTLs or separators> + // qdtext = <any TEXT except <">> + + for c := 0; c < 256; c++ { + var t byte + isCtl := c <= 31 || c == 127 + isChar := 0 <= c && c <= 127 + isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 + if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { + t |= isSpaceOctet + } + if isChar && !isCtl && !isSeparator { + t |= isTokenOctet + } + octetTypes[c] = t + } +} + +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isSpaceOctet == 0 { + break + } + } + return s[i:] +} + +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isTokenOctet == 0 { + break + } + } + return s[:i], s[i:] +} + +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j += 1 + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j += 1 + } + } + return "", "" + } + } + return "", "" +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains token. +func tokenListContainsValue(header http.Header, name string, value string) bool { +headers: + for _, s := range header[name] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + continue headers + } + if strings.EqualFold(t, value) { + return true + } + if s == "" { + continue headers + } + s = s[1:] + } + } + return false +} + +// parseExtensiosn parses WebSocket extensions from a header. +func parseExtensions(header http.Header) []map[string]string { + + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} diff --git a/vendor/github.com/gorilla/websocket/util_test.go b/vendor/github.com/gorilla/websocket/util_test.go index 91f70ceb0..610e613c0 100644 --- a/vendor/github.com/gorilla/websocket/util_test.go +++ b/vendor/github.com/gorilla/websocket/util_test.go @@ -6,6 +6,7 @@ package websocket import ( "net/http" + "reflect" "testing" ) @@ -32,3 +33,42 @@ func TestTokenListContainsValue(t *testing.T) { } } } + +var parseExtensionTests = []struct { + value string + extensions []map[string]string +}{ + {`foo`, []map[string]string{map[string]string{"": "foo"}}}, + {`foo, bar; baz=2`, []map[string]string{ + map[string]string{"": "foo"}, + map[string]string{"": "bar", "baz": "2"}}}, + {`foo; bar="b,a;z"`, []map[string]string{ + map[string]string{"": "foo", "bar": "b,a;z"}}}, + {`foo , bar; baz = 2`, []map[string]string{ + map[string]string{"": "foo"}, + map[string]string{"": "bar", "baz": "2"}}}, + {`foo, bar; baz=2 junk`, []map[string]string{ + map[string]string{"": "foo"}}}, + {`foo junk, bar; baz=2 junk`, nil}, + {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ + map[string]string{"": "mux", "max-channels": "4", "flow-control": ""}, + map[string]string{"": "deflate-stream"}}}, + {`permessage-foo; x="10"`, []map[string]string{ + map[string]string{"": "permessage-foo", "x": "10"}}}, + {`permessage-foo; use_y, permessage-foo`, []map[string]string{ + map[string]string{"": "permessage-foo", "use_y": ""}, + map[string]string{"": "permessage-foo"}}}, + {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ + map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, + map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}}, +} + +func TestParseExtensions(t *testing.T) { + for _, tt := range parseExtensionTests { + h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} + extensions := parseExtensions(h) + if !reflect.DeepEqual(extensions, tt.extensions) { + t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions) + } + } +} diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 8e1aee9f0..ca88dc8c6 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -32,6 +32,10 @@ var ( ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") ) type drv struct{} @@ -115,6 +119,9 @@ type conn struct { // Whether to always send []byte parameters over as binary. Enables single // round-trip mode for non-prepared Query calls. binaryParameters bool + + // If true this connection is in the middle of a COPY + inCopy bool } // Handle driver-side settings in parsed connection string. @@ -598,11 +605,16 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } // done return case 'E': err = parseError(r) - case 'T', 'D', 'I': + case 'I': + res = emptyRows + case 'T', 'D': // ignore any results default: cn.bad = true @@ -666,6 +678,20 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { } } +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertId +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { @@ -743,25 +769,31 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inCopy = true + } + return s, err } return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { - if cn.bad { - return driver.ErrBadConn - } + // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) + // Ensure that cn.c.Close is always run. Since error handling is done with + // panics and cn.errRecover, the Close must be in a defer. + defer func() { + cerr := cn.c.Close() + if err == nil { + err = cerr + } + }() + // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - err = cn.sendSimpleMessage('X') - if err != nil { - return err - } - - return cn.c.Close() + return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface @@ -769,6 +801,9 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err if cn.bad { return nil, driver.ErrBadConn } + if cn.inCopy { + return nil, errCopyInProgress + } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is @@ -1472,12 +1507,23 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) } return + case 'T': + rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + return io.EOF default: errorf("unexpected message after execute: %q", t) } } } +func (rs *rows) HasNextResultSet() bool { + return !rs.done +} + +func (rs *rows) NextResultSet() error { + return nil +} + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // @@ -1720,6 +1766,9 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } return res, commandTag, err case 'E': err = parseError(r) @@ -1728,6 +1777,9 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co cn.bad = true errorf("unexpected %q after error %s", t, err) } + if t == 'I' { + res = emptyRows + } // ignore any results default: cn.bad = true diff --git a/vendor/github.com/lib/pq/conn_test.go b/vendor/github.com/lib/pq/conn_test.go index 592860f8a..183e6dcd6 100644 --- a/vendor/github.com/lib/pq/conn_test.go +++ b/vendor/github.com/lib/pq/conn_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "fmt" "io" + "net" "os" "reflect" "strings" @@ -385,10 +386,16 @@ func TestEmptyQuery(t *testing.T) { db := openTestConn(t) defer db.Close() - _, err := db.Exec("") + res, err := db.Exec("") if err != nil { t.Fatal(err) } + if _, err := res.RowsAffected(); err != errNoRowsAffected { + t.Fatalf("expected %s, got %v", errNoRowsAffected, err) + } + if _, err := res.LastInsertId(); err != errNoLastInsertId { + t.Fatalf("expected %s, got %v", errNoLastInsertId, err) + } rows, err := db.Query("") if err != nil { t.Fatal(err) @@ -411,10 +418,16 @@ func TestEmptyQuery(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = stmt.Exec() + res, err = stmt.Exec() if err != nil { t.Fatal(err) } + if _, err := res.RowsAffected(); err != errNoRowsAffected { + t.Fatalf("expected %s, got %v", errNoRowsAffected, err) + } + if _, err := res.LastInsertId(); err != errNoLastInsertId { + t.Fatalf("expected %s, got %v", errNoLastInsertId, err) + } rows, err = stmt.Query() if err != nil { t.Fatal(err) @@ -653,6 +666,40 @@ func TestBadConn(t *testing.T) { } } +// TestCloseBadConn tests that the underlying connection can be closed with +// Close after an error. +func TestCloseBadConn(t *testing.T) { + nc, err := net.Dial("tcp", "localhost:5432") + if err != nil { + t.Fatal(err) + } + cn := conn{c: nc} + func() { + defer cn.errRecover(&err) + panic(io.EOF) + }() + // Verify we can write before closing. + if _, err := nc.Write(nil); err != nil { + t.Fatal(err) + } + // First close should close the connection. + if err := cn.Close(); err != nil { + t.Fatal(err) + } + // Verify write after closing fails. + if _, err := nc.Write(nil); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), "use of closed network connection") { + t.Fatalf("expected use of closed network connection error, got %s", err) + } + // Verify second close fails. + if err := cn.Close(); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), "use of closed network connection") { + t.Fatalf("expected use of closed network connection error, got %s", err) + } +} + func TestErrorOnExec(t *testing.T) { db := openTestConn(t) defer db.Close() diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index 101f11133..86a7127e1 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -13,6 +13,7 @@ var ( errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") + errCopyInProgress = errors.New("pq: COPY in progress") ) // CopyIn creates a COPY FROM statement which can be prepared with @@ -258,6 +259,7 @@ func (ci *copyin) Close() (err error) { } <-ci.done + ci.cn.inCopy = false if ci.isErrorSet() { err = ci.err diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index 19798dfc9..6d252ecee 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -89,8 +89,10 @@ provided connection parameters. The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html is supported, but on Windows PGPASSFILE must be specified explicitly. + Queries + database/sql does not dictate any specific format for parameter markers in query strings, and pq uses the Postgres-native ordinal markers, as shown above. The same marker can be reused for the same parameter: @@ -114,8 +116,29 @@ For more details on RETURNING, see the Postgres documentation: For additional instructions on querying see the documentation for the database/sql package. + +Data Types + + +Parameters pass through driver.DefaultParameterConverter before they are handled +by this package. When the binary_parameters connection option is enabled, +[]byte values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + + Errors + pq may return errors of type *pq.Error which can be interrogated for error details: if err, ok := err.(*pq.Error); ok { diff --git a/vendor/github.com/lib/pq/go18_test.go b/vendor/github.com/lib/pq/go18_test.go new file mode 100644 index 000000000..df3e496b5 --- /dev/null +++ b/vendor/github.com/lib/pq/go18_test.go @@ -0,0 +1,68 @@ +// +build go1.8 + +package pq + +import "testing" + +func TestMultipleSimpleQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("select 1; set time zone default; select 2; select 3") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + var i int + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 1 { + t.Fatalf("expected 1, got %d", i) + } + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } + } + + // Make sure that if we ignore a result we can still query. + + rows, err = db.Query("select 4; select 5") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 4 { + t.Fatalf("expected 4, got %d", i) + } + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 5 { + t.Fatalf("expected 5, got %d", i) + } + } + if rows.NextResultSet() { + t.Fatal("unexpected result set") + } +} diff --git a/vendor/github.com/lib/pq/issues_test.go b/vendor/github.com/lib/pq/issues_test.go new file mode 100644 index 000000000..3a330a0a9 --- /dev/null +++ b/vendor/github.com/lib/pq/issues_test.go @@ -0,0 +1,26 @@ +package pq + +import "testing" + +func TestIssue494(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)` + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + if _, err := txn.Prepare(CopyIn("t", "i")); err != nil { + t.Fatal(err) + } + + if _, err := txn.Query("SELECT 1"); err == nil { + t.Fatal("expected error") + } +} diff --git a/vendor/github.com/miekg/dns/README.md b/vendor/github.com/miekg/dns/README.md index 83b4183eb..0e3356cb9 100644 --- a/vendor/github.com/miekg/dns/README.md +++ b/vendor/github.com/miekg/dns/README.md @@ -1,4 +1,4 @@ -[![Build Status](https://travis-ci.org/miekg/dns.svg?branch=master)](https://travis-ci.org/miekg/dns) +[![Build Status](https://travis-ci.org/miekg/dns.svg?branch=master)](https://travis-ci.org/miekg/dns) [![](https://godoc.org/github.com/miekg/dns?status.svg)](https://godoc.org/github.com/miekg/dns) # Alternative (more granular) approach to a DNS library @@ -50,6 +50,9 @@ A not-so-up-to-date-list-that-may-be-actually-current: * https://dnslookup.org * https://github.com/looterz/grimd * https://github.com/phamhongviet/serf-dns +* https://github.com/mehrdadrad/mylg +* https://github.com/bamarni/dockness +* https://github.com/fffaraz/microdns Send pull request if you want to be listed here. diff --git a/vendor/github.com/miekg/dns/client.go b/vendor/github.com/miekg/dns/client.go index 1302e4e04..0db7f7bf6 100644 --- a/vendor/github.com/miekg/dns/client.go +++ b/vendor/github.com/miekg/dns/client.go @@ -39,7 +39,7 @@ type Client struct { } // Exchange performs a synchronous UDP query. It sends the message m to the address -// contained in a and waits for an reply. Exchange does not retry a failed query, nor +// contained in a and waits for a reply. Exchange does not retry a failed query, nor // will it fall back to TCP in case of truncation. // See client.Exchange for more information on setting larger buffer sizes. func Exchange(m *Msg, a string) (r *Msg, err error) { @@ -93,8 +93,8 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { return r, err } -// Exchange performs an synchronous query. It sends the message m to the address -// contained in a and waits for an reply. Basic use pattern with a *dns.Client: +// Exchange performs a synchronous query. It sends the message m to the address +// contained in a and waits for a reply. Basic use pattern with a *dns.Client: // // c := new(dns.Client) // in, rtt, err := c.Exchange(message, "127.0.0.1:53") diff --git a/vendor/github.com/miekg/dns/dane.go b/vendor/github.com/miekg/dns/dane.go new file mode 100644 index 000000000..cdaa833ff --- /dev/null +++ b/vendor/github.com/miekg/dns/dane.go @@ -0,0 +1,44 @@ +package dns + +import ( + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "encoding/hex" + "errors" + "io" +) + +// CertificateToDANE converts a certificate to a hex string as used in the TLSA or SMIMEA records. +func CertificateToDANE(selector, matchingType uint8, cert *x509.Certificate) (string, error) { + switch matchingType { + case 0: + switch selector { + case 0: + return hex.EncodeToString(cert.Raw), nil + case 1: + return hex.EncodeToString(cert.RawSubjectPublicKeyInfo), nil + } + case 1: + h := sha256.New() + switch selector { + case 0: + io.WriteString(h, string(cert.Raw)) + return hex.EncodeToString(h.Sum(nil)), nil + case 1: + io.WriteString(h, string(cert.RawSubjectPublicKeyInfo)) + return hex.EncodeToString(h.Sum(nil)), nil + } + case 2: + h := sha512.New() + switch selector { + case 0: + io.WriteString(h, string(cert.Raw)) + return hex.EncodeToString(h.Sum(nil)), nil + case 1: + io.WriteString(h, string(cert.RawSubjectPublicKeyInfo)) + return hex.EncodeToString(h.Sum(nil)), nil + } + } + return "", errors.New("dns: bad MatchingType or Selector") +} diff --git a/vendor/github.com/miekg/dns/dnssec_keyscan.go b/vendor/github.com/miekg/dns/dnssec_keyscan.go index c0b54dc76..9ff3a617e 100644 --- a/vendor/github.com/miekg/dns/dnssec_keyscan.go +++ b/vendor/github.com/miekg/dns/dnssec_keyscan.go @@ -14,7 +14,7 @@ import ( // NewPrivateKey returns a PrivateKey by parsing the string s. // s should be in the same form of the BIND private key files. func (k *DNSKEY) NewPrivateKey(s string) (crypto.PrivateKey, error) { - if s[len(s)-1] != '\n' { // We need a closing newline + if s == "" || s[len(s)-1] != '\n' { // We need a closing newline return k.ReadPrivateKey(strings.NewReader(s+"\n"), "") } return k.ReadPrivateKey(strings.NewReader(s), "") diff --git a/vendor/github.com/miekg/dns/doc.go b/vendor/github.com/miekg/dns/doc.go index f3555e433..e38753d7d 100644 --- a/vendor/github.com/miekg/dns/doc.go +++ b/vendor/github.com/miekg/dns/doc.go @@ -203,7 +203,7 @@ RFC 6895 sets aside a range of type codes for private use. This range is 65,280 - 65,534 (0xFF00 - 0xFFFE). When experimenting with new Resource Records these can be used, before requesting an official type code from IANA. -see http://miek.nl/posts/2014/Sep/21/Private%20RRs%20and%20IDN%20in%20Go%20DNS/ for more +see http://miek.nl/2014/September/21/idn-and-private-rr-in-go-dns/ for more information. EDNS0 diff --git a/vendor/github.com/miekg/dns/edns.go b/vendor/github.com/miekg/dns/edns.go index 7a58aa9b1..0550aaa39 100644 --- a/vendor/github.com/miekg/dns/edns.go +++ b/vendor/github.com/miekg/dns/edns.go @@ -128,8 +128,18 @@ func (rr *OPT) Do() bool { } // SetDo sets the DO (DNSSEC OK) bit. -func (rr *OPT) SetDo() { - rr.Hdr.Ttl |= _DO +// If we pass an argument, set the DO bit to that value. +// It is possible to pass 2 or more arguments. Any arguments after the 1st is silently ignored. +func (rr *OPT) SetDo(do ...bool) { + if len(do) == 1 { + if do[0] { + rr.Hdr.Ttl |= _DO + } else { + rr.Hdr.Ttl &^= _DO + } + } else { + rr.Hdr.Ttl |= _DO + } } // EDNS0 defines an EDNS0 Option. An OPT RR can have multiple options appended to it. diff --git a/vendor/github.com/miekg/dns/edns_test.go b/vendor/github.com/miekg/dns/edns_test.go index 5fd75abb4..c290b0c8a 100644 --- a/vendor/github.com/miekg/dns/edns_test.go +++ b/vendor/github.com/miekg/dns/edns_test.go @@ -7,10 +7,46 @@ func TestOPTTtl(t *testing.T) { e.Hdr.Name = "." e.Hdr.Rrtype = TypeOPT + // verify the default setting of DO=0 if e.Do() { t.Errorf("DO bit should be zero") } + // There are 6 possible invocations of SetDo(): + // + // 1. Starting with DO=0, using SetDo() + // 2. Starting with DO=0, using SetDo(true) + // 3. Starting with DO=0, using SetDo(false) + // 4. Starting with DO=1, using SetDo() + // 5. Starting with DO=1, using SetDo(true) + // 6. Starting with DO=1, using SetDo(false) + + // verify that invoking SetDo() sets DO=1 (TEST #1) + e.SetDo() + if !e.Do() { + t.Errorf("DO bit should be non-zero") + } + // verify that using SetDo(true) works when DO=1 (TEST #5) + e.SetDo(true) + if !e.Do() { + t.Errorf("DO bit should still be non-zero") + } + // verify that we can use SetDo(false) to set DO=0 (TEST #6) + e.SetDo(false) + if e.Do() { + t.Errorf("DO bit should be zero") + } + // verify that if we call SetDo(false) when DO=0 that it is unchanged (TEST #3) + e.SetDo(false) + if e.Do() { + t.Errorf("DO bit should still be zero") + } + // verify that using SetDo(true) works for DO=0 (TEST #2) + e.SetDo(true) + if !e.Do() { + t.Errorf("DO bit should be non-zero") + } + // verify that using SetDo() works for DO=1 (TEST #4) e.SetDo() if !e.Do() { t.Errorf("DO bit should be non-zero") diff --git a/vendor/github.com/miekg/dns/msg.go b/vendor/github.com/miekg/dns/msg.go index ec2f7ab7b..a9acd1e9f 100644 --- a/vendor/github.com/miekg/dns/msg.go +++ b/vendor/github.com/miekg/dns/msg.go @@ -58,7 +58,7 @@ var ( ErrTruncated error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired. ) -// Id, by default, returns a 16 bits random number to be used as a +// Id by default, returns a 16 bits random number to be used as a // message id. The random provided should be good enough. This being a // variable the function can be reassigned to a custom function. // For instance, to make it return a static value: diff --git a/vendor/github.com/miekg/dns/parse_test.go b/vendor/github.com/miekg/dns/parse_test.go index 3b38dba65..ca467a227 100644 --- a/vendor/github.com/miekg/dns/parse_test.go +++ b/vendor/github.com/miekg/dns/parse_test.go @@ -1375,6 +1375,27 @@ func TestParseTLSA(t *testing.T) { } } +func TestParseSMIMEA(t *testing.T) { + lt := map[string]string{ + "2e85e1db3e62be6ea._smimecert.example.com.\t3600\tIN\tSMIMEA\t1 1 2 bd80f334566928fc18f58df7e4928c1886f48f71ca3fd41cd9b1854aca7c2180aaacad2819612ed68e7bd3701cc39be7f2529b017c0bc6a53e8fb3f0c7d48070": "2e85e1db3e62be6ea._smimecert.example.com.\t3600\tIN\tSMIMEA\t1 1 2 bd80f334566928fc18f58df7e4928c1886f48f71ca3fd41cd9b1854aca7c2180aaacad2819612ed68e7bd3701cc39be7f2529b017c0bc6a53e8fb3f0c7d48070", + "2e85e1db3e62be6ea._smimecert.example.com.\t3600\tIN\tSMIMEA\t0 0 1 cdcf0fc66b182928c5217ddd42c826983f5a4b94160ee6c1c9be62d38199f710": "2e85e1db3e62be6ea._smimecert.example.com.\t3600\tIN\tSMIMEA\t0 0 1 cdcf0fc66b182928c5217ddd42c826983f5a4b94160ee6c1c9be62d38199f710", + "2e85e1db3e62be6ea._smimecert.example.com.\t3600\tIN\tSMIMEA\t3 0 2 499a1eda2af8828b552cdb9d80c3744a25872fddd73f3898d8e4afa3549595d2dd4340126e759566fe8c26b251fa0c887ba4869f011a65f7e79967c2eb729f5b": "2e85e1db3e62be6ea._smimecert.example.com.\t3600\tIN\tSMIMEA\t3 0 2 499a1eda2af8828b552cdb9d80c3744a25872fddd73f3898d8e4afa3549595d2dd4340126e759566fe8c26b251fa0c887ba4869f011a65f7e79967c2eb729f5b", + "2e85e1db3e62be6eb._smimecert.example.com.\t3600\tIN\tSMIMEA\t3 0 2 499a1eda2af8828b552cdb9d80c3744a25872fddd73f3898d8e4afa3549595d2dd4340126e759566fe8 c26b251fa0c887ba4869f01 1a65f7e79967c2eb729f5b": "2e85e1db3e62be6eb._smimecert.example.com.\t3600\tIN\tSMIMEA\t3 0 2 499a1eda2af8828b552cdb9d80c3744a25872fddd73f3898d8e4afa3549595d2dd4340126e759566fe8c26b251fa0c887ba4869f011a65f7e79967c2eb729f5b", + } + for i, o := range lt { + rr, err := NewRR(i) + if err != nil { + t.Error("failed to parse RR: ", err) + continue + } + if rr.String() != o { + t.Errorf("`%s' should be equal to\n`%s', but is `%s'", o, o, rr.String()) + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + func TestParseSSHFP(t *testing.T) { lt := []string{ "test.example.org.\t300\tSSHFP\t1 2 (\n" + diff --git a/vendor/github.com/miekg/dns/privaterr_test.go b/vendor/github.com/miekg/dns/privaterr_test.go index 5f177aa47..72ec8f5c0 100644 --- a/vendor/github.com/miekg/dns/privaterr_test.go +++ b/vendor/github.com/miekg/dns/privaterr_test.go @@ -7,7 +7,7 @@ import ( "github.com/miekg/dns" ) -const TypeISBN uint16 = 0x0F01 +const TypeISBN uint16 = 0xFF00 // A crazy new RR type :) type ISBN struct { @@ -101,7 +101,7 @@ func TestPrivateByteSlice(t *testing.T) { } } -const TypeVERSION uint16 = 0x0F02 +const TypeVERSION uint16 = 0xFF01 type VERSION struct { x string diff --git a/vendor/github.com/miekg/dns/scan.go b/vendor/github.com/miekg/dns/scan.go index 0e83797fb..d34597ba3 100644 --- a/vendor/github.com/miekg/dns/scan.go +++ b/vendor/github.com/miekg/dns/scan.go @@ -627,6 +627,7 @@ func zlexer(s *scan, c chan lex) { if stri > 0 { l.value = zString l.token = string(str[:stri]) + l.tokenUpper = strings.ToUpper(l.token) l.length = stri debug.Printf("[4 %+v]", l.token) c <- l @@ -663,6 +664,7 @@ func zlexer(s *scan, c chan lex) { owner = true l.value = zNewline l.token = "\n" + l.tokenUpper = l.token l.length = 1 l.comment = string(com[:comi]) debug.Printf("[3 %+v %+v]", l.token, l.comment) @@ -696,6 +698,7 @@ func zlexer(s *scan, c chan lex) { } l.value = zNewline l.token = "\n" + l.tokenUpper = l.token l.length = 1 debug.Printf("[1 %+v]", l.token) c <- l @@ -740,6 +743,7 @@ func zlexer(s *scan, c chan lex) { if stri != 0 { l.value = zString l.token = string(str[:stri]) + l.tokenUpper = strings.ToUpper(l.token) l.length = stri debug.Printf("[%+v]", l.token) @@ -750,6 +754,7 @@ func zlexer(s *scan, c chan lex) { // send quote itself as separate token l.value = zQuote l.token = "\"" + l.tokenUpper = l.token l.length = 1 c <- l quote = !quote @@ -775,6 +780,7 @@ func zlexer(s *scan, c chan lex) { brace-- if brace < 0 { l.token = "extra closing brace" + l.tokenUpper = l.token l.err = true debug.Printf("[%+v]", l.token) c <- l @@ -799,6 +805,7 @@ func zlexer(s *scan, c chan lex) { if stri > 0 { // Send remainder l.token = string(str[:stri]) + l.tokenUpper = strings.ToUpper(l.token) l.length = stri l.value = zString debug.Printf("[%+v]", l.token) diff --git a/vendor/github.com/miekg/dns/scan_rr.go b/vendor/github.com/miekg/dns/scan_rr.go index e521dc063..675fc80d8 100644 --- a/vendor/github.com/miekg/dns/scan_rr.go +++ b/vendor/github.com/miekg/dns/scan_rr.go @@ -1746,6 +1746,41 @@ func setTLSA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { return rr, nil, c1 } +func setSMIMEA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(SMIMEA) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + i, e := strconv.Atoi(l.token) + if e != nil || l.err { + return nil, &ParseError{f, "bad SMIMEA Usage", l}, "" + } + rr.Usage = uint8(i) + <-c // zBlank + l = <-c + i, e = strconv.Atoi(l.token) + if e != nil || l.err { + return nil, &ParseError{f, "bad SMIMEA Selector", l}, "" + } + rr.Selector = uint8(i) + <-c // zBlank + l = <-c + i, e = strconv.Atoi(l.token) + if e != nil || l.err { + return nil, &ParseError{f, "bad SMIMEA MatchingType", l}, "" + } + rr.MatchingType = uint8(i) + // So this needs be e2 (i.e. different than e), because...??t + s, e2, c1 := endingToString(c, "bad SMIMEA Certificate", f) + if e2 != nil { + return nil, e2, c1 + } + rr.Certificate = s + return rr, nil, c1 +} + func setRFC3597(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { rr := new(RFC3597) rr.Hdr = h @@ -2128,6 +2163,7 @@ var typeToparserFunc = map[uint16]parserFunc{ TypeRP: {setRP, false}, TypeRRSIG: {setRRSIG, true}, TypeRT: {setRT, false}, + TypeSMIMEA: {setSMIMEA, true}, TypeSOA: {setSOA, false}, TypeSPF: {setSPF, true}, TypeSRV: {setSRV, false}, diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go index 2b4bff49f..c34801aef 100644 --- a/vendor/github.com/miekg/dns/server.go +++ b/vendor/github.com/miekg/dns/server.go @@ -147,7 +147,7 @@ func (mux *ServeMux) match(q string, t uint16) Handler { b[i] |= ('a' - 'A') } } - if h, ok := mux.z[string(b[:l])]; ok { // 'causes garbage, might want to change the map key + if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key if t != TypeDS { return h } diff --git a/vendor/github.com/miekg/dns/smimea.go b/vendor/github.com/miekg/dns/smimea.go new file mode 100644 index 000000000..3a4bb5700 --- /dev/null +++ b/vendor/github.com/miekg/dns/smimea.go @@ -0,0 +1,47 @@ +package dns + +import ( + "crypto/sha256" + "crypto/x509" + "encoding/hex" +) + +// Sign creates a SMIMEA record from an SSL certificate. +func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (err error) { + r.Hdr.Rrtype = TypeSMIMEA + r.Usage = uint8(usage) + r.Selector = uint8(selector) + r.MatchingType = uint8(matchingType) + + r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) + if err != nil { + return err + } + return nil +} + +// Verify verifies a SMIMEA record against an SSL certificate. If it is OK +// a nil error is returned. +func (r *SMIMEA) Verify(cert *x509.Certificate) error { + c, err := CertificateToDANE(r.Selector, r.MatchingType, cert) + if err != nil { + return err // Not also ErrSig? + } + if r.Certificate == c { + return nil + } + return ErrSig // ErrSig, really? +} + +// SIMEAName returns the ownername of a SMIMEA resource record as per the +// format specified in RFC 'draft-ietf-dane-smime-12' Section 2 and 3 +func SMIMEAName(email_address string, domain_name string) (string, error) { + hasher := sha256.New() + hasher.Write([]byte(email_address)) + + // RFC Section 3: "The local-part is hashed using the SHA2-256 + // algorithm with the hash truncated to 28 octets and + // represented in its hexadecimal representation to become the + // left-most label in the prepared domain name" + return hex.EncodeToString(hasher.Sum(nil)[:28]) + "." + "_smimecert." + domain_name, nil +} diff --git a/vendor/github.com/miekg/dns/tlsa.go b/vendor/github.com/miekg/dns/tlsa.go index 34fe6615a..431e2fb5a 100644 --- a/vendor/github.com/miekg/dns/tlsa.go +++ b/vendor/github.com/miekg/dns/tlsa.go @@ -1,50 +1,11 @@ package dns import ( - "crypto/sha256" - "crypto/sha512" "crypto/x509" - "encoding/hex" - "errors" - "io" "net" "strconv" ) -// CertificateToDANE converts a certificate to a hex string as used in the TLSA record. -func CertificateToDANE(selector, matchingType uint8, cert *x509.Certificate) (string, error) { - switch matchingType { - case 0: - switch selector { - case 0: - return hex.EncodeToString(cert.Raw), nil - case 1: - return hex.EncodeToString(cert.RawSubjectPublicKeyInfo), nil - } - case 1: - h := sha256.New() - switch selector { - case 0: - io.WriteString(h, string(cert.Raw)) - return hex.EncodeToString(h.Sum(nil)), nil - case 1: - io.WriteString(h, string(cert.RawSubjectPublicKeyInfo)) - return hex.EncodeToString(h.Sum(nil)), nil - } - case 2: - h := sha512.New() - switch selector { - case 0: - io.WriteString(h, string(cert.Raw)) - return hex.EncodeToString(h.Sum(nil)), nil - case 1: - io.WriteString(h, string(cert.RawSubjectPublicKeyInfo)) - return hex.EncodeToString(h.Sum(nil)), nil - } - } - return "", errors.New("dns: bad TLSA MatchingType or TLSA Selector") -} - // Sign creates a TLSA record from an SSL certificate. func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (err error) { r.Hdr.Rrtype = TypeTLSA diff --git a/vendor/github.com/miekg/dns/types.go b/vendor/github.com/miekg/dns/types.go index 5059d1a79..f63a18b33 100644 --- a/vendor/github.com/miekg/dns/types.go +++ b/vendor/github.com/miekg/dns/types.go @@ -70,6 +70,7 @@ const ( TypeNSEC3 uint16 = 50 TypeNSEC3PARAM uint16 = 51 TypeTLSA uint16 = 52 + TypeSMIMEA uint16 = 53 TypeHIP uint16 = 55 TypeNINFO uint16 = 56 TypeRKEY uint16 = 57 @@ -1047,6 +1048,28 @@ func (rr *TLSA) String() string { " " + rr.Certificate } +type SMIMEA struct { + Hdr RR_Header + Usage uint8 + Selector uint8 + MatchingType uint8 + Certificate string `dns:"hex"` +} + +func (rr *SMIMEA) String() string { + s := rr.Hdr.String() + + strconv.Itoa(int(rr.Usage)) + + " " + strconv.Itoa(int(rr.Selector)) + + " " + strconv.Itoa(int(rr.MatchingType)) + + // Every Nth char needs a space on this output. If we output + // this as one giant line, we can't read it can in because in some cases + // the cert length overflows scan.maxTok (2048). + sx := splitN(rr.Certificate, 1024) // conservative value here + s += " " + strings.Join(sx, " ") + return s +} + type HIP struct { Hdr RR_Header HitLength uint8 @@ -1247,3 +1270,25 @@ func copyIP(ip net.IP) net.IP { copy(p, ip) return p } + +// SplitN splits a string into N sized string chunks. +// This might become an exported function once. +func splitN(s string, n int) []string { + if len(s) < n { + return []string{s} + } + sx := []string{} + p, i := 0, n + for { + if i <= len(s) { + sx = append(sx, s[p:i]) + } else { + sx = append(sx, s[p:]) + break + + } + p, i = p+n, i+n + } + + return sx +} diff --git a/vendor/github.com/miekg/dns/types_test.go b/vendor/github.com/miekg/dns/types_test.go index 118612946..c117cfbc7 100644 --- a/vendor/github.com/miekg/dns/types_test.go +++ b/vendor/github.com/miekg/dns/types_test.go @@ -40,3 +40,35 @@ func TestCmToM(t *testing.T) { t.Error("9, 9") } } + +func TestSplitN(t *testing.T) { + xs := splitN("abc", 5) + if len(xs) != 1 && xs[0] != "abc" { + t.Errorf("Failure to split abc") + } + + s := "" + for i := 0; i < 255; i++ { + s += "a" + } + + xs = splitN(s, 255) + if len(xs) != 1 && xs[0] != s { + t.Errorf("failure to split 255 char long string") + } + + s += "b" + xs = splitN(s, 255) + if len(xs) != 2 || xs[1] != "b" { + t.Errorf("failure to split 256 char long string: %d", len(xs)) + } + + // Make s longer + for i := 0; i < 255; i++ { + s += "a" + } + xs = splitN(s, 255) + if len(xs) != 3 || xs[2] != "a" { + t.Errorf("failure to split 510 char long string: %d", len(xs)) + } +} diff --git a/vendor/github.com/miekg/dns/zmsg.go b/vendor/github.com/miekg/dns/zmsg.go index 346d3102d..c561370e7 100644 --- a/vendor/github.com/miekg/dns/zmsg.go +++ b/vendor/github.com/miekg/dns/zmsg.go @@ -1085,6 +1085,32 @@ func (rr *SIG) pack(msg []byte, off int, compression map[string]int, compress bo return off, nil } +func (rr *SMIMEA) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint8(rr.Usage, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Selector, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.MatchingType, msg, off) + if err != nil { + return off, err + } + off, err = packStringHex(rr.Certificate, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + func (rr *SOA) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { off, err := rr.Hdr.pack(msg, off, compression, compress) if err != nil { @@ -2907,6 +2933,44 @@ func unpackSIG(h RR_Header, msg []byte, off int) (RR, int, error) { return rr, off, err } +func unpackSMIMEA(h RR_Header, msg []byte, off int) (RR, int, error) { + rr := new(SMIMEA) + rr.Hdr = h + if noRdata(h) { + return rr, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr.Usage, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Selector, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.MatchingType, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Certificate, off, err = unpackStringHex(msg, off, rdStart+int(rr.Hdr.Rdlength)) + if err != nil { + return rr, off, err + } + return rr, off, err +} + func unpackSOA(h RR_Header, msg []byte, off int) (RR, int, error) { rr := new(SOA) rr.Hdr = h @@ -3447,6 +3511,7 @@ var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){ TypeRRSIG: unpackRRSIG, TypeRT: unpackRT, TypeSIG: unpackSIG, + TypeSMIMEA: unpackSMIMEA, TypeSOA: unpackSOA, TypeSPF: unpackSPF, TypeSRV: unpackSRV, diff --git a/vendor/github.com/miekg/dns/ztypes.go b/vendor/github.com/miekg/dns/ztypes.go index a4ecbb0cc..3c052773e 100644 --- a/vendor/github.com/miekg/dns/ztypes.go +++ b/vendor/github.com/miekg/dns/ztypes.go @@ -62,6 +62,7 @@ var TypeToRR = map[uint16]func() RR{ TypeRRSIG: func() RR { return new(RRSIG) }, TypeRT: func() RR { return new(RT) }, TypeSIG: func() RR { return new(SIG) }, + TypeSMIMEA: func() RR { return new(SMIMEA) }, TypeSOA: func() RR { return new(SOA) }, TypeSPF: func() RR { return new(SPF) }, TypeSRV: func() RR { return new(SRV) }, @@ -141,6 +142,7 @@ var TypeToString = map[uint16]string{ TypeRT: "RT", TypeReserved: "Reserved", TypeSIG: "SIG", + TypeSMIMEA: "SMIMEA", TypeSOA: "SOA", TypeSPF: "SPF", TypeSRV: "SRV", @@ -213,6 +215,7 @@ func (rr *RP) Header() *RR_Header { return &rr.Hdr } func (rr *RRSIG) Header() *RR_Header { return &rr.Hdr } func (rr *RT) Header() *RR_Header { return &rr.Hdr } func (rr *SIG) Header() *RR_Header { return &rr.Hdr } +func (rr *SMIMEA) Header() *RR_Header { return &rr.Hdr } func (rr *SOA) Header() *RR_Header { return &rr.Hdr } func (rr *SPF) Header() *RR_Header { return &rr.Hdr } func (rr *SRV) Header() *RR_Header { return &rr.Hdr } @@ -514,6 +517,14 @@ func (rr *RT) len() int { l += len(rr.Host) + 1 return l } +func (rr *SMIMEA) len() int { + l := rr.Hdr.len() + l += 1 // Usage + l += 1 // Selector + l += 1 // MatchingType + l += len(rr.Certificate)/2 + 1 + return l +} func (rr *SOA) len() int { l := rr.Hdr.len() l += len(rr.Ns) + 1 @@ -780,6 +791,9 @@ func (rr *RRSIG) copy() RR { func (rr *RT) copy() RR { return &RT{*rr.Hdr.copyHeader(), rr.Preference, rr.Host} } +func (rr *SMIMEA) copy() RR { + return &SMIMEA{*rr.Hdr.copyHeader(), rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate} +} func (rr *SOA) copy() RR { return &SOA{*rr.Hdr.copyHeader(), rr.Ns, rr.Mbox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.Minttl} } diff --git a/vendor/github.com/nicksnyder/go-i18n/.travis.yml b/vendor/github.com/nicksnyder/go-i18n/.travis.yml index 8558bb44f..1202d43d6 100644 --- a/vendor/github.com/nicksnyder/go-i18n/.travis.yml +++ b/vendor/github.com/nicksnyder/go-i18n/.travis.yml @@ -5,4 +5,6 @@ go: - 1.3 - 1.4 - 1.5 + - 1.6 + - 1.7 - tip diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/constants_command.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/constants_command.go new file mode 100644 index 000000000..85b1ac18e --- /dev/null +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/constants_command.go @@ -0,0 +1,230 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "text/template" + "unicode" + + "github.com/nicksnyder/go-i18n/i18n/bundle" + "github.com/nicksnyder/go-i18n/i18n/language" + "github.com/nicksnyder/go-i18n/i18n/translation" +) + +type constantsCommand struct { + translationFiles []string + packageName string + outdir string +} + +type templateConstants struct { + ID string + Name string + Comments []string +} + +type templateHeader struct { + PackageName string + Constants []templateConstants +} + +var constTemplate = template.Must(template.New("").Parse(`// DON'T CHANGE THIS FILE MANUALLY +// This file was generated using the command: +// $ goi18n constants + +package {{.PackageName}} +{{range .Constants}} +// {{.Name}} is the identifier for the following localizable string template(s):{{range .Comments}} +// {{.}}{{end}} +const {{.Name}} = "{{.ID}}" +{{end}}`)) + +func (cc *constantsCommand) execute() error { + if len(cc.translationFiles) != 1 { + return fmt.Errorf("need one translation file") + } + + bundle := bundle.New() + + if err := bundle.LoadTranslationFile(cc.translationFiles[0]); err != nil { + return fmt.Errorf("failed to load translation file %s because %s\n", cc.translationFiles[0], err) + } + + translations := bundle.Translations() + lang := translations[bundle.LanguageTags()[0]] + + // create an array of id to organize + keys := make([]string, len(lang)) + i := 0 + + for id := range lang { + keys[i] = id + i++ + } + sort.Strings(keys) + + tmpl := &templateHeader{ + PackageName: cc.packageName, + Constants: make([]templateConstants, len(keys)), + } + + for i, id := range keys { + tmpl.Constants[i].ID = id + tmpl.Constants[i].Name = toCamelCase(id) + tmpl.Constants[i].Comments = toComments(lang[id]) + } + + filename := filepath.Join(cc.outdir, cc.packageName+".go") + f, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create file %s because %s", filename, err) + } + + defer f.Close() + + if err = constTemplate.Execute(f, tmpl); err != nil { + return fmt.Errorf("failed to write file %s because %s", filename, err) + } + + return nil +} + +func (cc *constantsCommand) parse(arguments []string) { + flags := flag.NewFlagSet("constants", flag.ExitOnError) + flags.Usage = usageConstants + + packageName := flags.String("package", "R", "") + outdir := flags.String("outdir", ".", "") + + flags.Parse(arguments) + + cc.translationFiles = flags.Args() + cc.packageName = *packageName + cc.outdir = *outdir +} + +func (cc *constantsCommand) SetArgs(args []string) { + cc.translationFiles = args +} + +func usageConstants() { + fmt.Printf(`Generate constant file from translation file. + +Usage: + + goi18n constants [options] [file] + +Translation files: + + A translation file contains the strings and translations for a single language. + + Translation file names must have a suffix of a supported format (e.g. .json) and + contain a valid language tag as defined by RFC 5646 (e.g. en-us, fr, zh-hant, etc.). + +Options: + + -package name + goi18n generates the constant file under the package name. + Default: R + + -outdir directory + goi18n writes the constant file to this directory. + Default: . + +`) + os.Exit(1) +} + +// commonInitialisms is a set of common initialisms. +// Only add entries that are highly unlikely to be non-initialisms. +// For instance, "ID" is fine (Freudian code is rare), but "AND" is not. +// https://github.com/golang/lint/blob/master/lint.go +var commonInitialisms = map[string]bool{ + "API": true, + "ASCII": true, + "CPU": true, + "CSS": true, + "DNS": true, + "EOF": true, + "GUID": true, + "HTML": true, + "HTTP": true, + "HTTPS": true, + "ID": true, + "IP": true, + "JSON": true, + "LHS": true, + "QPS": true, + "RAM": true, + "RHS": true, + "RPC": true, + "SLA": true, + "SMTP": true, + "SQL": true, + "SSH": true, + "TCP": true, + "TLS": true, + "TTL": true, + "UDP": true, + "UI": true, + "UID": true, + "UUID": true, + "URI": true, + "URL": true, + "UTF8": true, + "VM": true, + "XML": true, + "XSRF": true, + "XSS": true, +} + +func toCamelCase(id string) string { + var result string + + r := regexp.MustCompile(`[\-\.\_\s]`) + words := r.Split(id, -1) + + for _, w := range words { + upper := strings.ToUpper(w) + if commonInitialisms[upper] { + result += upper + continue + } + + if len(w) > 0 { + u := []rune(w) + u[0] = unicode.ToUpper(u[0]) + result += string(u) + } + } + return result +} + +func toComments(trans translation.Translation) []string { + var result []string + data := trans.MarshalInterface().(map[string]interface{}) + + t := data["translation"] + + switch v := reflect.ValueOf(t); v.Kind() { + case reflect.Map: + for _, k := range []language.Plural{"zero", "one", "two", "few", "many", "other"} { + vt := v.MapIndex(reflect.ValueOf(k)) + if !vt.IsValid() { + continue + } + result = append(result, string(k)+": "+strconv.Quote(fmt.Sprint(vt.Interface()))) + } + default: + result = append(result, strconv.Quote(fmt.Sprint(t))) + } + + return result +} diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/constants_command_test.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/constants_command_test.go new file mode 100644 index 000000000..43dea3f38 --- /dev/null +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/constants_command_test.go @@ -0,0 +1,42 @@ +package main + +import "testing" + +func TestConstantsExecute(t *testing.T) { + resetDir(t, "testdata/output") + + cc := &constantsCommand{ + translationFiles: []string{"testdata/input/en-us.constants.json"}, + packageName: "R", + outdir: "testdata/output", + } + + if err := cc.execute(); err != nil { + t.Fatal(err) + } + + expectEqualFiles(t, "testdata/output/R.go", "testdata/expected/R.go") +} + +func TestToCamelCase(t *testing.T) { + expectEqual := func(test, expected string) { + result := toCamelCase(test) + if result != expected { + t.Fatalf("failed toCamelCase the test %s was expected %s but the result was %s", test, expected, result) + } + } + + expectEqual("", "") + expectEqual("a", "A") + expectEqual("_", "") + expectEqual("__code__", "Code") + expectEqual("test", "Test") + expectEqual("test_one", "TestOne") + expectEqual("test.two", "TestTwo") + expectEqual("test_alpha_beta", "TestAlphaBeta") + expectEqual("word word", "WordWord") + expectEqual("test_id", "TestID") + expectEqual("tcp_name", "TCPName") + expectEqual("こんにちは", "こんにちは") + expectEqual("test_a", "TestA") +} diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/doc.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/doc.go index 10d244217..97c7a7fb6 100644 --- a/vendor/github.com/nicksnyder/go-i18n/goi18n/doc.go +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/doc.go @@ -5,11 +5,22 @@ // // Help documentation: // -// goi18n formats and merges translation files. +// goi18n manages translation files. // // Usage: // -// goi18n [options] [files...] +// goi18n merge Merge translation files +// goi18n constants Generate constant file from translation file +// +// For more details execute: +// +// goi18n [command] -help +// +// Merge translation files. +// +// Usage: +// +// goi18n merge [options] [files...] // // Translation files: // @@ -56,4 +67,27 @@ // Supported formats: json, yaml // Default: json // +// Generate constant file from translation file. +// +// Usage: +// +// goi18n constants [options] [file] +// +// Translation files: +// +// A translation file contains the strings and translations for a single language. +// +// Translation file names must have a suffix of a supported format (e.g. .json) and +// contain a valid language tag as defined by RFC 5646 (e.g. en-us, fr, zh-hant, etc.). +// +// Options: +// +// -package name +// goi18n generates the constant file under the package name. +// Default: R +// +// -outdir directory +// goi18n writes the constant file to this directory. +// Default: . +// package main diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/gendoc.sh b/vendor/github.com/nicksnyder/go-i18n/goi18n/gendoc.sh index 094f479a5..f30df34e6 100644 --- a/vendor/github.com/nicksnyder/go-i18n/goi18n/gendoc.sh +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/gendoc.sh @@ -6,5 +6,7 @@ echo "// goi18n -help" >> doc.go echo "//" >> doc.go echo "// Help documentation:" >> doc.go echo "//" >> doc.go -goi18n -help | sed -e 's/^/\/\/ /' >> doc.go +goi18n | sed -e 's/^/\/\/ /' >> doc.go +goi18n merge -help | sed -e 's/^/\/\/ /' >> doc.go +goi18n constants -help | sed -e 's/^/\/\/ /' >> doc.go echo "package main" >> doc.go diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/goi18n.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/goi18n.go index f57ea8175..3bd763f47 100644 --- a/vendor/github.com/nicksnyder/go-i18n/goi18n/goi18n.go +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/goi18n.go @@ -6,77 +6,50 @@ import ( "os" ) -func usage() { - fmt.Printf(`goi18n formats and merges translation files. - -Usage: - - goi18n [options] [files...] - -Translation files: - - A translation file contains the strings and translations for a single language. - - Translation file names must have a suffix of a supported format (e.g. .json) and - contain a valid language tag as defined by RFC 5646 (e.g. en-us, fr, zh-hant, etc.). - - For each language represented by at least one input translation file, goi18n will produce 2 output files: - - xx-yy.all.format - This file contains all strings for the language (translated and untranslated). - Use this file when loading strings at runtime. +type command interface { + execute() error + parse(arguments []string) +} - xx-yy.untranslated.format - This file contains the strings that have not been translated for this language. - The translations for the strings in this file will be extracted from the source language. - After they are translated, merge them back into xx-yy.all.format using goi18n. +func main() { + flag.Usage = usage -Merging: + if len(os.Args) == 1 { + usage() + } - goi18n will merge multiple translation files for the same language. - Duplicate translations will be merged into the existing translation. - Non-empty fields in the duplicate translation will overwrite those fields in the existing translation. - Empty fields in the duplicate translation are ignored. + var cmd command + + switch os.Args[1] { + case "merge": + cmd = &mergeCommand{} + cmd.parse(os.Args[2:]) + case "constants": + cmd = &constantsCommand{} + cmd.parse(os.Args[2:]) + default: + cmd = &mergeCommand{} + cmd.parse(os.Args[1:]) + } -Adding a new language: + if err := cmd.execute(); err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } +} - To produce translation files for a new language, create an empty translation file with the - appropriate name and pass it in to goi18n. +func usage() { + fmt.Printf(`goi18n manages translation files. -Options: +Usage: - -sourceLanguage tag - goi18n uses the strings from this language to seed the translations for other languages. - Default: en-us + goi18n merge Merge translation files + goi18n constants Generate constant file from translation file - -outdir directory - goi18n writes the output translation files to this directory. - Default: . +For more details execute: - -format format - goi18n encodes the output translation files in this format. - Supported formats: json, yaml - Default: json + goi18n [command] -help `) os.Exit(1) } - -func main() { - flag.Usage = usage - sourceLanguage := flag.String("sourceLanguage", "en-us", "") - outdir := flag.String("outdir", ".", "") - format := flag.String("format", "json", "") - flag.Parse() - - mc := &mergeCommand{ - translationFiles: flag.Args(), - sourceLanguageTag: *sourceLanguage, - outdir: *outdir, - format: *format, - } - if err := mc.execute(); err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } -} diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/merge.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/merge_command.go index 1317fe958..1d34ac438 100644 --- a/vendor/github.com/nicksnyder/go-i18n/goi18n/merge.go +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/merge_command.go @@ -2,23 +2,26 @@ package main import ( "encoding/json" + "flag" "fmt" - "gopkg.in/yaml.v2" "io/ioutil" + "os" "path/filepath" "reflect" "sort" + "gopkg.in/yaml.v2" + "github.com/nicksnyder/go-i18n/i18n/bundle" "github.com/nicksnyder/go-i18n/i18n/language" "github.com/nicksnyder/go-i18n/i18n/translation" ) type mergeCommand struct { - translationFiles []string - sourceLanguageTag string - outdir string - format string + translationFiles []string + sourceLanguage string + outdir string + format string } func (mc *mergeCommand) execute() error { @@ -26,8 +29,8 @@ func (mc *mergeCommand) execute() error { return fmt.Errorf("need at least one translation file to parse") } - if lang := language.Parse(mc.sourceLanguageTag); lang == nil { - return fmt.Errorf("invalid source locale: %s", mc.sourceLanguageTag) + if lang := language.Parse(mc.sourceLanguage); lang == nil { + return fmt.Errorf("invalid source locale: %s", mc.sourceLanguage) } marshal, err := newMarshalFunc(mc.format) @@ -43,7 +46,7 @@ func (mc *mergeCommand) execute() error { } translations := bundle.Translations() - sourceLanguageTag := language.NormalizeTag(mc.sourceLanguageTag) + sourceLanguageTag := language.NormalizeTag(mc.sourceLanguage) sourceTranslations := translations[sourceLanguageTag] if sourceTranslations == nil { return fmt.Errorf("no translations found for source locale %s", sourceLanguageTag) @@ -78,6 +81,26 @@ func (mc *mergeCommand) execute() error { return nil } +func (mc *mergeCommand) parse(arguments []string) { + flags := flag.NewFlagSet("merge", flag.ExitOnError) + flags.Usage = usageMerge + + sourceLanguage := flags.String("sourceLanguage", "en-us", "") + outdir := flags.String("outdir", ".", "") + format := flags.String("format", "json", "") + + flags.Parse(arguments) + + mc.translationFiles = flags.Args() + mc.sourceLanguage = *sourceLanguage + mc.outdir = *outdir + mc.format = *format +} + +func (mc *mergeCommand) SetArgs(args []string) { + mc.translationFiles = args +} + type marshalFunc func(interface{}) ([]byte, error) func (mc *mergeCommand) writeFile(label string, translations []translation.Translation, localeID string, marshal marshalFunc) error { @@ -125,3 +148,59 @@ func marshalInterface(translations []translation.Translation) []interface{} { } return mi } + +func usageMerge() { + fmt.Printf(`Merge translation files. + +Usage: + + goi18n merge [options] [files...] + +Translation files: + + A translation file contains the strings and translations for a single language. + + Translation file names must have a suffix of a supported format (e.g. .json) and + contain a valid language tag as defined by RFC 5646 (e.g. en-us, fr, zh-hant, etc.). + + For each language represented by at least one input translation file, goi18n will produce 2 output files: + + xx-yy.all.format + This file contains all strings for the language (translated and untranslated). + Use this file when loading strings at runtime. + + xx-yy.untranslated.format + This file contains the strings that have not been translated for this language. + The translations for the strings in this file will be extracted from the source language. + After they are translated, merge them back into xx-yy.all.format using goi18n. + +Merging: + + goi18n will merge multiple translation files for the same language. + Duplicate translations will be merged into the existing translation. + Non-empty fields in the duplicate translation will overwrite those fields in the existing translation. + Empty fields in the duplicate translation are ignored. + +Adding a new language: + + To produce translation files for a new language, create an empty translation file with the + appropriate name and pass it in to goi18n. + +Options: + + -sourceLanguage tag + goi18n uses the strings from this language to seed the translations for other languages. + Default: en-us + + -outdir directory + goi18n writes the output translation files to this directory. + Default: . + + -format format + goi18n encodes the output translation files in this format. + Supported formats: json, yaml + Default: json + +`) + os.Exit(1) +} diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/merge_test.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/merge_command_test.go index f0d0d47a1..37e46518b 100644 --- a/vendor/github.com/nicksnyder/go-i18n/goi18n/merge_test.go +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/merge_command_test.go @@ -33,10 +33,10 @@ func testMergeExecute(t *testing.T, files []string) { resetDir(t, "testdata/output") mc := &mergeCommand{ - translationFiles: files, - sourceLanguageTag: "en-us", - outdir: "testdata/output", - format: "json", + translationFiles: files, + sourceLanguage: "en-us", + outdir: "testdata/output", + format: "json", } if err := mc.execute(); err != nil { t.Fatal(err) diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/testdata/expected/R.go b/vendor/github.com/nicksnyder/go-i18n/goi18n/testdata/expected/R.go new file mode 100644 index 000000000..9b5334a7e --- /dev/null +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/testdata/expected/R.go @@ -0,0 +1,38 @@ +// DON'T CHANGE THIS FILE MANUALLY +// This file was generated using the command: +// $ goi18n constants + +package R + +// DDays is the identifier for the following localizable string template(s): +// one: "{{.Count}} day" +// other: "{{.Count}} days" +const DDays = "d_days" + +// MyHeightInMeters is the identifier for the following localizable string template(s): +// one: "I am {{.Count}} meter tall." +// other: "I am {{.Count}} meters tall." +const MyHeightInMeters = "my_height_in_meters" + +// PersonGreeting is the identifier for the following localizable string template(s): +// "Hello {{.Person}}" +const PersonGreeting = "person_greeting" + +// PersonUnreadEmailCount is the identifier for the following localizable string template(s): +// one: "{{.Person}} has {{.Count}} unread email." +// other: "{{.Person}} has {{.Count}} unread emails." +const PersonUnreadEmailCount = "person_unread_email_count" + +// PersonUnreadEmailCountTimeframe is the identifier for the following localizable string template(s): +// one: "{{.Person}} has {{.Count}} unread email in the past {{.Timeframe}}." +// other: "{{.Person}} has {{.Count}} unread emails in the past {{.Timeframe}}." +const PersonUnreadEmailCountTimeframe = "person_unread_email_count_timeframe" + +// ProgramGreeting is the identifier for the following localizable string template(s): +// "Hello world" +const ProgramGreeting = "program_greeting" + +// YourUnreadEmailCount is the identifier for the following localizable string template(s): +// one: "You have {{.Count}} unread email." +// other: "You have {{.Count}} unread emails." +const YourUnreadEmailCount = "your_unread_email_count" diff --git a/vendor/github.com/nicksnyder/go-i18n/goi18n/testdata/input/en-us.constants.json b/vendor/github.com/nicksnyder/go-i18n/goi18n/testdata/input/en-us.constants.json new file mode 100644 index 000000000..5aedc235a --- /dev/null +++ b/vendor/github.com/nicksnyder/go-i18n/goi18n/testdata/input/en-us.constants.json @@ -0,0 +1,45 @@ +[ + { + "id": "d_days", + "translation": { + "one": "{{.Count}} day", + "other": "{{.Count}} days" + } + }, + { + "id": "my_height_in_meters", + "translation": { + "one": "I am {{.Count}} meter tall.", + "other": "I am {{.Count}} meters tall." + } + }, + { + "id": "person_greeting", + "translation": "Hello {{.Person}}" + }, + { + "id": "person_unread_email_count", + "translation": { + "one": "{{.Person}} has {{.Count}} unread email.", + "other": "{{.Person}} has {{.Count}} unread emails." + } + }, + { + "id": "person_unread_email_count_timeframe", + "translation": { + "one": "{{.Person}} has {{.Count}} unread email in the past {{.Timeframe}}.", + "other": "{{.Person}} has {{.Count}} unread emails in the past {{.Timeframe}}." + } + }, + { + "id": "program_greeting", + "translation": "Hello world" + }, + { + "id": "your_unread_email_count", + "translation": { + "one": "You have {{.Count}} unread email.", + "other": "You have {{.Count}} unread emails." + } + } +]
\ No newline at end of file diff --git a/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle.go b/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle.go index e93db95d7..8e46fa296 100644 --- a/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle.go +++ b/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle.go @@ -260,6 +260,11 @@ func (b *Bundle) translate(lang *language.Language, translationID string, args . dataMap["Count"] = count data = dataMap } + } else { + dataMap := toMap(data) + if c, ok := dataMap["Count"]; ok { + count = c + } } p, _ := lang.Plural(count) diff --git a/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle_test.go b/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle_test.go index b9c0a0593..b241ad1d4 100644 --- a/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle_test.go +++ b/vendor/github.com/nicksnyder/go-i18n/i18n/bundle/bundle_test.go @@ -270,6 +270,26 @@ func BenchmarkTranslatePluralWithMap(b *testing.B) { } } +func BenchmarkTranslatePluralWithMapAndCountField(b *testing.B) { + data := map[string]interface{}{ + "Person": "Bob", + "Count": 26, + } + + translationTemplate := map[string]interface{}{ + "one": "{{.Person}} is {{.Count}} year old.", + "other": "{{.Person}} is {{.Count}} years old.", + } + expected := "Bob is 26 years old." + + tf := createBenchmarkTranslateFunc(b, translationTemplate, nil, expected) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tf(data) + } +} + func BenchmarkTranslatePluralWithStruct(b *testing.B) { data := struct{ Person string }{Person: "Bob"} tf := createBenchmarkPluralTranslateFunc(b) diff --git a/vendor/github.com/nicksnyder/go-i18n/i18n/example_test.go b/vendor/github.com/nicksnyder/go-i18n/i18n/example_test.go index d2d9706a7..305c5b3df 100644 --- a/vendor/github.com/nicksnyder/go-i18n/i18n/example_test.go +++ b/vendor/github.com/nicksnyder/go-i18n/i18n/example_test.go @@ -30,6 +30,15 @@ func Example() { fmt.Println(T("person_unread_email_count", 1, bobStruct)) fmt.Println(T("person_unread_email_count", 2, bobStruct)) + type Count struct{ Count int } + fmt.Println(T("your_unread_email_count", Count{0})) + fmt.Println(T("your_unread_email_count", Count{1})) + fmt.Println(T("your_unread_email_count", Count{2})) + + fmt.Println(T("your_unread_email_count", map[string]interface{}{"Count": 0})) + fmt.Println(T("your_unread_email_count", map[string]interface{}{"Count": "1"})) + fmt.Println(T("your_unread_email_count", map[string]interface{}{"Count": "3.14"})) + fmt.Println(T("person_unread_email_count_timeframe", 3, map[string]interface{}{ "Person": "Bob", "Timeframe": T("d_days", 0), @@ -43,6 +52,22 @@ func Example() { "Timeframe": T("d_days", 2), })) + fmt.Println(T("person_unread_email_count_timeframe", 1, map[string]interface{}{ + "Count": 30, + "Person": "Bob", + "Timeframe": T("d_days", 0), + })) + fmt.Println(T("person_unread_email_count_timeframe", 2, map[string]interface{}{ + "Count": 20, + "Person": "Bob", + "Timeframe": T("d_days", 1), + })) + fmt.Println(T("person_unread_email_count_timeframe", 3, map[string]interface{}{ + "Count": 10, + "Person": "Bob", + "Timeframe": T("d_days", 2), + })) + // Output: // Hello world // Hello Bob @@ -57,7 +82,16 @@ func Example() { // Bob has 0 unread emails. // Bob has 1 unread email. // Bob has 2 unread emails. + // You have 0 unread emails. + // You have 1 unread email. + // You have 2 unread emails. + // You have 0 unread emails. + // You have 1 unread email. + // You have 3.14 unread emails. // Bob has 3 unread emails in the past 0 days. // Bob has 3 unread emails in the past 1 day. // Bob has 3 unread emails in the past 2 days. + // Bob has 1 unread email in the past 0 days. + // Bob has 2 unread emails in the past 1 day. + // Bob has 3 unread emails in the past 2 days. } diff --git a/vendor/github.com/nicksnyder/go-i18n/i18n/i18n.go b/vendor/github.com/nicksnyder/go-i18n/i18n/i18n.go index f96842966..c478ff6ea 100644 --- a/vendor/github.com/nicksnyder/go-i18n/i18n/i18n.go +++ b/vendor/github.com/nicksnyder/go-i18n/i18n/i18n.go @@ -69,9 +69,15 @@ import ( // If translationID is a non-plural form, then the first variadic argument may be a map[string]interface{} // or struct that contains template data. // -// If translationID is a plural form, then the first variadic argument must be an integer type +// If translationID is a plural form, the function accepts two parameter signatures +// 1. T(count int, data struct{}) +// The first variadic argument must be an integer type // (int, int8, int16, int32, int64) or a float formatted as a string (e.g. "123.45"). -// The second variadic argument may be a map[string]interface{} or struct that contains template data. +// The second variadic argument may be a map[string]interface{} or struct{} that contains template data. +// 2. T(data struct{}) +// data must be a struct{} or map[string]interface{} that contains a Count field and the template data, +// Count field must be an integer type (int, int8, int16, int32, int64) +// or a float formatted as a string (e.g. "123.45"). type TranslateFunc func(translationID string, args ...interface{}) string // IdentityTfunc returns a TranslateFunc that always returns the translationID passed to it. diff --git a/vendor/github.com/rsc/letsencrypt/README b/vendor/github.com/rsc/letsencrypt/README index 98a875f37..575ae16a8 100644 --- a/vendor/github.com/rsc/letsencrypt/README +++ b/vendor/github.com/rsc/letsencrypt/README @@ -7,6 +7,31 @@ servers that can prove control over the given domain's DNS records or the servers pointed at by those records. +Warning + +Like any other random code you find on the internet, this package should not +be relied upon in important, production systems without thorough testing to +ensure that it meets your needs. + +In the long term you should be using +https://golang.org/x/crypto/acme/autocert instead of this package. Send +improvements there, not here. + +This is a package that I wrote for my own personal web sites (swtch.com, +rsc.io) in a hurry when my paid-for SSL certificate was expiring. It has no +tests, has barely been used, and there is some anecdotal evidence that it +does not properly renew certificates in a timely fashion, so servers that +run for more than 3 months may run into trouble. I don't run this code +anymore: to simplify maintenance, I moved the sites off of Ubuntu VMs and +onto Google App Engine, configured with inexpensive long-term certificates +purchased from cheapsslsecurity.com. + +This package was interesting primarily as an example of how simple the API +for using LetsEncrypt.org could be made, in contrast to the low-level +implementations that existed at the time. In that respect, it helped inform +the design of the golang.org/x/crypto/acme/autocert package. + + Quick Start A complete HTTP/HTTPS web server using TLS certificates from diff --git a/vendor/github.com/rsc/letsencrypt/lets.go b/vendor/github.com/rsc/letsencrypt/lets.go index c0168b56a..f112af31c 100644 --- a/vendor/github.com/rsc/letsencrypt/lets.go +++ b/vendor/github.com/rsc/letsencrypt/lets.go @@ -8,6 +8,30 @@ // that can prove control over the given domain's DNS records or // the servers pointed at by those records. // +// Warning +// +// Like any other random code you find on the internet, this package should +// not be relied upon in important, production systems without thorough testing +// to ensure that it meets your needs. +// +// In the long term you should be using +// https://golang.org/x/crypto/acme/autocert instead of this package. +// Send improvements there, not here. +// +// This is a package that I wrote for my own personal web sites (swtch.com, rsc.io) +// in a hurry when my paid-for SSL certificate was expiring. It has no tests, +// has barely been used, and there is some anecdotal evidence that it does +// not properly renew certificates in a timely fashion, so servers that run for +// more than 3 months may run into trouble. +// I don't run this code anymore: to simplify maintenance, I moved the sites +// off of Ubuntu VMs and onto Google App Engine, configured with inexpensive +// long-term certificates purchased from cheapsslsecurity.com. +// +// This package was interesting primarily as an example of how simple the API +// for using LetsEncrypt.org could be made, in contrast to the low-level +// implementations that existed at the time. In that respect, it helped inform +// the design of the golang.org/x/crypto/acme/autocert package. +// // Quick Start // // A complete HTTP/HTTPS web server using TLS certificates from LetsEncrypt.org, diff --git a/vendor/github.com/xenolf/lego/.travis.yml b/vendor/github.com/xenolf/lego/.travis.yml index f1af03bd6..e37f07962 100644 --- a/vendor/github.com/xenolf/lego/.travis.yml +++ b/vendor/github.com/xenolf/lego/.travis.yml @@ -3,6 +3,10 @@ go: - 1.6.3 - 1.7 - tip +services: + - memcached +env: + - MEMCACHED_HOSTS=localhost:11211 install: - go get -t ./... script: diff --git a/vendor/github.com/xenolf/lego/Dockerfile b/vendor/github.com/xenolf/lego/Dockerfile index 3749dfcee..c03964076 100644 --- a/vendor/github.com/xenolf/lego/Dockerfile +++ b/vendor/github.com/xenolf/lego/Dockerfile @@ -7,7 +7,7 @@ RUN apk update && apk add ca-certificates go git && \ go get -u github.com/xenolf/lego && \ cd /go/src/github.com/xenolf/lego && \ go build -o /usr/bin/lego . && \ - apk del ca-certificates go git && \ + apk del go git && \ rm -rf /var/cache/apk/* && \ rm -rf /go diff --git a/vendor/github.com/xenolf/lego/acme/client.go b/vendor/github.com/xenolf/lego/acme/client.go index 5eae8d26a..9f837af36 100644 --- a/vendor/github.com/xenolf/lego/acme/client.go +++ b/vendor/github.com/xenolf/lego/acme/client.go @@ -97,7 +97,7 @@ func NewClient(caDirURL string, user User, keyType KeyType) (*Client, error) { return &Client{directory: dir, user: user, jws: jws, keyType: keyType, solvers: solvers}, nil } -// SetChallengeProvider specifies a custom provider that will make the solution available +// SetChallengeProvider specifies a custom provider p that can solve the given challenge type. func (c *Client) SetChallengeProvider(challenge Challenge, p ChallengeProvider) error { switch challenge { case HTTP01: @@ -115,6 +115,9 @@ func (c *Client) SetChallengeProvider(challenge Challenge, p ChallengeProvider) // SetHTTPAddress specifies a custom interface:port to be used for HTTP based challenges. // If this option is not used, the default port 80 and all interfaces will be used. // To only specify a port and no interface use the ":port" notation. +// +// NOTE: This REPLACES any custom HTTP provider previously set by calling +// c.SetChallengeProvider with the default HTTP challenge provider. func (c *Client) SetHTTPAddress(iface string) error { host, port, err := net.SplitHostPort(iface) if err != nil { @@ -131,6 +134,9 @@ func (c *Client) SetHTTPAddress(iface string) error { // SetTLSAddress specifies a custom interface:port to be used for TLS based challenges. // If this option is not used, the default port 443 and all interfaces will be used. // To only specify a port and no interface use the ":port" notation. +// +// NOTE: This REPLACES any custom TLS-SNI provider previously set by calling +// c.SetChallengeProvider with the default TLS-SNI challenge provider. func (c *Client) SetTLSAddress(iface string) error { host, port, err := net.SplitHostPort(iface) if err != nil { @@ -347,7 +353,7 @@ DNSNames: // your issued certificate as a bundle. // This function will never return a partial certificate. If one domain in the list fails, // the whole certificate will fail. -func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey) (CertificateResource, map[string]error) { +func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey, mustStaple bool) (CertificateResource, map[string]error) { if bundle { logf("[INFO][%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", ")) } else { @@ -368,7 +374,7 @@ func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto logf("[INFO][%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", ")) - cert, err := c.requestCertificate(challenges, bundle, privKey) + cert, err := c.requestCertificate(challenges, bundle, privKey, mustStaple) if err != nil { for _, chln := range challenges { failures[chln.Domain] = err @@ -404,7 +410,7 @@ func (c *Client) RevokeCertificate(certificate []byte) error { // If bundle is true, the []byte contains both the issuer certificate and // your issued certificate as a bundle. // For private key reuse the PrivateKey property of the passed in CertificateResource should be non-nil. -func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (CertificateResource, error) { +func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple bool) (CertificateResource, error) { // Input certificate is PEM encoded. Decode it here as we may need the decoded // cert later on in the renewal process. The input may be a bundle or a single certificate. certificates, err := parsePEMBundle(cert.Certificate) @@ -421,50 +427,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (Certif timeLeft := x509Cert.NotAfter.Sub(time.Now().UTC()) logf("[INFO][%s] acme: Trying renewal with %d hours remaining", cert.Domain, int(timeLeft.Hours())) - // The first step of renewal is to check if we get a renewed cert - // directly from the cert URL. - resp, err := httpGet(cert.CertURL) - if err != nil { - return CertificateResource{}, err - } - defer resp.Body.Close() - serverCertBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - return CertificateResource{}, err - } - - serverCert, err := x509.ParseCertificate(serverCertBytes) - if err != nil { - return CertificateResource{}, err - } - - // If the server responds with a different certificate we are effectively renewed. - // TODO: Further test if we can actually use the new certificate (Our private key works) - if !x509Cert.Equal(serverCert) { - logf("[INFO][%s] acme: Server responded with renewed certificate", cert.Domain) - issuedCert := pemEncode(derCertificateBytes(serverCertBytes)) - // If bundle is true, we want to return a certificate bundle. - // To do this, we need the issuer certificate. - if bundle { - // The issuer certificate link is always supplied via an "up" link - // in the response headers of a new certificate. - links := parseLinks(resp.Header["Link"]) - issuerCert, err := c.getIssuerCertificate(links["up"]) - if err != nil { - // If we fail to acquire the issuer cert, return the issued certificate - do not fail. - logf("[ERROR][%s] acme: Could not bundle issuer certificate: %v", cert.Domain, err) - } else { - // Success - append the issuer cert to the issued cert. - issuerCert = pemEncode(derCertificateBytes(issuerCert)) - issuedCert = append(issuedCert, issuerCert...) - } - } - - cert.Certificate = issuedCert - return cert, nil - } - - // If the certificate is the same, then we need to request a new certificate. + // We always need to request a new certificate to renew. // Start by checking to see if the certificate was based off a CSR, and // use that if it's defined. if len(cert.CSR) > 0 { @@ -499,7 +462,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (Certif domains = append(domains, x509Cert.Subject.CommonName) } - newCert, failures := c.ObtainCertificate(domains, bundle, privKey) + newCert, failures := c.ObtainCertificate(domains, bundle, privKey, mustStaple) return newCert, failures[cert.Domain] } @@ -600,7 +563,7 @@ func (c *Client) getChallenges(domains []string) ([]authorizationResource, map[s return challenges, failures } -func (c *Client) requestCertificate(authz []authorizationResource, bundle bool, privKey crypto.PrivateKey) (CertificateResource, error) { +func (c *Client) requestCertificate(authz []authorizationResource, bundle bool, privKey crypto.PrivateKey, mustStaple bool) (CertificateResource, error) { if len(authz) == 0 { return CertificateResource{}, errors.New("Passed no authorizations to requestCertificate!") } @@ -621,7 +584,7 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool, } // TODO: should the CSR be customizable? - csr, err := generateCsr(privKey, commonName.Domain, san) + csr, err := generateCsr(privKey, commonName.Domain, san, mustStaple) if err != nil { return CertificateResource{}, err } diff --git a/vendor/github.com/xenolf/lego/acme/crypto.go b/vendor/github.com/xenolf/lego/acme/crypto.go index af97f5d1e..c63b23b99 100644 --- a/vendor/github.com/xenolf/lego/acme/crypto.go +++ b/vendor/github.com/xenolf/lego/acme/crypto.go @@ -20,6 +20,8 @@ import ( "strings" "time" + "encoding/asn1" + "golang.org/x/crypto/ocsp" ) @@ -47,6 +49,12 @@ const ( OCSPServerFailed = ocsp.ServerFailed ) +// Constants for OCSP must staple +var ( + tlsFeatureExtensionOID = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 24} + ocspMustStapleFeature = []byte{0x30, 0x03, 0x02, 0x01, 0x05} +) + // GetOCSPForCert takes a PEM encoded cert or cert bundle returning the raw OCSP response, // the parsed response, and an error, if any. The returned []byte can be passed directly // into the OCSPStaple property of a tls.Certificate. If the bundle only contains the @@ -206,7 +214,7 @@ func generatePrivateKey(keyType KeyType) (crypto.PrivateKey, error) { return nil, fmt.Errorf("Invalid KeyType: %s", keyType) } -func generateCsr(privateKey crypto.PrivateKey, domain string, san []string) ([]byte, error) { +func generateCsr(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) { template := x509.CertificateRequest{ Subject: pkix.Name{ CommonName: domain, @@ -217,6 +225,13 @@ func generateCsr(privateKey crypto.PrivateKey, domain string, san []string) ([]b template.DNSNames = san } + if mustStaple { + template.Extensions = append(template.Extensions, pkix.Extension{ + Id: tlsFeatureExtensionOID, + Value: ocspMustStapleFeature, + }) + } + return x509.CreateCertificateRequest(rand.Reader, &template, privateKey) } diff --git a/vendor/github.com/xenolf/lego/acme/crypto_test.go b/vendor/github.com/xenolf/lego/acme/crypto_test.go index d2fc5088b..6f43835fb 100644 --- a/vendor/github.com/xenolf/lego/acme/crypto_test.go +++ b/vendor/github.com/xenolf/lego/acme/crypto_test.go @@ -24,7 +24,7 @@ func TestGenerateCSR(t *testing.T) { t.Fatal("Error generating private key:", err) } - csr, err := generateCsr(key, "fizz.buzz", nil) + csr, err := generateCsr(key, "fizz.buzz", nil, true) if err != nil { t.Error("Error generating CSR:", err) } diff --git a/vendor/github.com/xenolf/lego/acme/dns_challenge.go b/vendor/github.com/xenolf/lego/acme/dns_challenge.go index c5fd354a1..30f2170ff 100644 --- a/vendor/github.com/xenolf/lego/acme/dns_challenge.go +++ b/vendor/github.com/xenolf/lego/acme/dns_challenge.go @@ -23,14 +23,37 @@ var ( fqdnToZone = map[string]string{} ) -var RecursiveNameservers = []string{ +const defaultResolvConf = "/etc/resolv.conf" + +var defaultNameservers = []string{ "google-public-dns-a.google.com:53", "google-public-dns-b.google.com:53", } +var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) + // DNSTimeout is used to override the default DNS timeout of 10 seconds. var DNSTimeout = 10 * time.Second +// getNameservers attempts to get systems nameservers before falling back to the defaults +func getNameservers(path string, defaults []string) []string { + config, err := dns.ClientConfigFromFile(path) + if err != nil || len(config.Servers) == 0 { + return defaults + } + + systemNameservers := []string{} + for _, server := range config.Servers { + // ensure all servers have a port number + if _, _, err := net.SplitHostPort(server); err != nil { + systemNameservers = append(systemNameservers, net.JoinHostPort(server, "53")) + } else { + systemNameservers = append(systemNameservers, server) + } + } + return systemNameservers +} + // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) { keyAuthShaBytes := sha256.Sum256([]byte(keyAuth)) @@ -75,7 +98,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error { fqdn, value, _ := DNS01Record(domain, keyAuth) - logf("[INFO][%s] Checking DNS record propagation...", domain) + logf("[INFO][%s] Checking DNS record propagation using %+v", domain, RecursiveNameservers) var timeout, interval time.Duration switch provider := s.provider.(type) { diff --git a/vendor/github.com/xenolf/lego/acme/dns_challenge_test.go b/vendor/github.com/xenolf/lego/acme/dns_challenge_test.go index 6e448854b..597aaac17 100644 --- a/vendor/github.com/xenolf/lego/acme/dns_challenge_test.go +++ b/vendor/github.com/xenolf/lego/acme/dns_challenge_test.go @@ -85,6 +85,15 @@ var checkAuthoritativeNssTestsErr = []struct { }, } +var checkResolvConfServersTests = []struct { + fixture string + expected []string + defaults []string +}{ + {"testdata/resolv.conf.1", []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"}, []string{"127.0.0.1:53"}}, + {"testdata/resolv.conf.nonexistant", []string{"127.0.0.1:53"}, []string{"127.0.0.1:53"}}, +} + func TestDNSValidServerResponse(t *testing.T) { PreCheckDNS = func(fqdn, value string) (bool, error) { return true, nil @@ -183,3 +192,15 @@ func TestCheckAuthoritativeNssErr(t *testing.T) { } } } + +func TestResolveConfServers(t *testing.T) { + for _, tt := range checkResolvConfServersTests { + result := getNameservers(tt.fixture, tt.defaults) + + sort.Strings(result) + sort.Strings(tt.expected) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("#%s: expected %q; got %q", tt.fixture, tt.expected, result) + } + } +} diff --git a/vendor/github.com/xenolf/lego/acme/http_challenge_server.go b/vendor/github.com/xenolf/lego/acme/http_challenge_server.go index 42541380c..64c6a8280 100644 --- a/vendor/github.com/xenolf/lego/acme/http_challenge_server.go +++ b/vendor/github.com/xenolf/lego/acme/http_challenge_server.go @@ -63,7 +63,7 @@ func (s *HTTPProviderServer) serve(domain, token, keyAuth string) { w.Write([]byte(keyAuth)) logf("[INFO][%s] Served key authentication", domain) } else { - logf("[INFO] Received request for domain %s with method %s", r.Host, r.Method) + logf("[WARN] Received request for domain %s with method %s but the domain did not match any challenge. Please ensure your are passing the HOST header properly.", r.Host, r.Method) w.Write([]byte("TEST")) } }) diff --git a/vendor/github.com/xenolf/lego/acme/http_challenge_test.go b/vendor/github.com/xenolf/lego/acme/http_challenge_test.go index fdd8f4d27..7400f56d4 100644 --- a/vendor/github.com/xenolf/lego/acme/http_challenge_test.go +++ b/vendor/github.com/xenolf/lego/acme/http_challenge_test.go @@ -51,7 +51,7 @@ func TestHTTPChallengeInvalidPort(t *testing.T) { if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { t.Errorf("Solve error: got %v, want error", err) - } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { + } else if want, want18 := "invalid port 123456", "123456: invalid port"; !strings.HasSuffix(err.Error(), want) && !strings.HasSuffix(err.Error(), want18) { t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) } } diff --git a/vendor/github.com/xenolf/lego/acme/messages.go b/vendor/github.com/xenolf/lego/acme/messages.go index 0efeae674..0f6514c3f 100644 --- a/vendor/github.com/xenolf/lego/acme/messages.go +++ b/vendor/github.com/xenolf/lego/acme/messages.go @@ -13,17 +13,10 @@ type directory struct { RevokeCertURL string `json:"revoke-cert"` } -type recoveryKeyMessage struct { - Length int `json:"length,omitempty"` - Client jose.JsonWebKey `json:"client,omitempty"` - Server jose.JsonWebKey `json:"client,omitempty"` -} - type registrationMessage struct { Resource string `json:"resource"` Contact []string `json:"contact"` Delete bool `json:"delete,omitempty"` - // RecoveryKey recoveryKeyMessage `json:"recoveryKey,omitempty"` } // Registration is returned by the ACME server after the registration @@ -36,7 +29,6 @@ type Registration struct { Agreement string `json:"agreement,omitempty"` Authorizations string `json:"authorizations,omitempty"` Certificates string `json:"certificates,omitempty"` - // RecoveryKey recoveryKeyMessage `json:"recoveryKey,omitempty"` } // RegistrationResource represents all important informations about a registration diff --git a/vendor/github.com/xenolf/lego/acme/testdata/resolv.conf.1 b/vendor/github.com/xenolf/lego/acme/testdata/resolv.conf.1 new file mode 100644 index 000000000..3098f99b5 --- /dev/null +++ b/vendor/github.com/xenolf/lego/acme/testdata/resolv.conf.1 @@ -0,0 +1,5 @@ +domain company.com +nameserver 10.200.3.249 +nameserver 10.200.3.250:5353 +nameserver 2001:4860:4860::8844 +nameserver [10.0.0.1]:5353 diff --git a/vendor/github.com/xenolf/lego/acme/tls_sni_challenge_test.go b/vendor/github.com/xenolf/lego/acme/tls_sni_challenge_test.go index 3aec74565..83b2833a9 100644 --- a/vendor/github.com/xenolf/lego/acme/tls_sni_challenge_test.go +++ b/vendor/github.com/xenolf/lego/acme/tls_sni_challenge_test.go @@ -59,7 +59,7 @@ func TestTLSSNIChallengeInvalidPort(t *testing.T) { if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { t.Errorf("Solve error: got %v, want error", err) - } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { + } else if want, want18 := "invalid port 123456", "123456: invalid port"; !strings.HasSuffix(err.Error(), want) && !strings.HasSuffix(err.Error(), want18) { t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) } } diff --git a/vendor/github.com/xenolf/lego/cli.go b/vendor/github.com/xenolf/lego/cli.go index abdcf47de..9fac2dd59 100644 --- a/vendor/github.com/xenolf/lego/cli.go +++ b/vendor/github.com/xenolf/lego/cli.go @@ -64,6 +64,10 @@ func main() { Name: "no-bundle", Usage: "Do not create a certificate bundle by adding the issuers certificate to the new certificate.", }, + cli.BoolFlag{ + Name: "must-staple", + Usage: "Include the OCSP must staple TLS extension in the CSR and generated certificate. Only works if the CSR is generated by lego.", + }, }, }, { @@ -89,6 +93,10 @@ func main() { Name: "no-bundle", Usage: "Do not create a certificate bundle by adding the issuers certificate to the new certificate.", }, + cli.BoolFlag{ + Name: "must-staple", + Usage: "Include the OCSP must staple TLS extension in the CSR and generated certificate. Only works if the CSR is generated by lego.", + }, }, }, { @@ -138,6 +146,10 @@ func main() { Name: "webroot", Usage: "Set the webroot folder to use for HTTP based challenges to write directly in a file in .well-known/acme-challenge", }, + cli.StringSliceFlag{ + Name: "memcached-host", + Usage: "Set the memcached host(s) to use for HTTP based challenges. Challenges will be written to all specified hosts.", + }, cli.StringFlag{ Name: "http", Usage: "Set the port and interface to use for HTTP based challenges to listen on. Supported: interface:port or :port", @@ -189,21 +201,26 @@ Here is an example bash command using the CloudFlare DNS provider: w := tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0) fmt.Fprintln(w, "Valid providers and their associated credential environment variables:") fmt.Fprintln(w) + fmt.Fprintln(w, "\tazure:\tAZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_SUBSCRIPTION_ID, AZURE_TENANT_ID, AZURE_RESROUCE_GROUP") + fmt.Fprintln(w, "\tauroradns:\tAURORA_USER_ID, AURORA_KEY, AURORA_ENDPOINT") fmt.Fprintln(w, "\tcloudflare:\tCLOUDFLARE_EMAIL, CLOUDFLARE_API_KEY") fmt.Fprintln(w, "\tdigitalocean:\tDO_AUTH_TOKEN") fmt.Fprintln(w, "\tdnsimple:\tDNSIMPLE_EMAIL, DNSIMPLE_API_KEY") fmt.Fprintln(w, "\tdnsmadeeasy:\tDNSMADEEASY_API_KEY, DNSMADEEASY_API_SECRET") + fmt.Fprintln(w, "\texoscale:\tEXOSCALE_API_KEY, EXOSCALE_API_SECRET, EXOSCALE_ENDPOINT") fmt.Fprintln(w, "\tgandi:\tGANDI_API_KEY") fmt.Fprintln(w, "\tgcloud:\tGCE_PROJECT") fmt.Fprintln(w, "\tlinode:\tLINODE_API_KEY") fmt.Fprintln(w, "\tmanual:\tnone") fmt.Fprintln(w, "\tnamecheap:\tNAMECHEAP_API_USER, NAMECHEAP_API_KEY") + fmt.Fprintln(w, "\trackspace:\tRACKSPACE_USER, RACKSPACE_API_KEY") fmt.Fprintln(w, "\trfc2136:\tRFC2136_TSIG_KEY, RFC2136_TSIG_SECRET,\n\t\tRFC2136_TSIG_ALGORITHM, RFC2136_NAMESERVER") fmt.Fprintln(w, "\troute53:\tAWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION") fmt.Fprintln(w, "\tdyn:\tDYN_CUSTOMER_NAME, DYN_USER_NAME, DYN_PASSWORD") fmt.Fprintln(w, "\tvultr:\tVULTR_API_KEY") fmt.Fprintln(w, "\tovh:\tOVH_ENDPOINT, OVH_APPLICATION_KEY, OVH_APPLICATION_SECRET, OVH_CONSUMER_KEY") fmt.Fprintln(w, "\tpdns:\tPDNS_API_KEY, PDNS_API_URL") + fmt.Fprintln(w, "\tdnspod:\tDNSPOD_API_KEY") w.Flush() fmt.Println(` diff --git a/vendor/github.com/xenolf/lego/cli_handlers.go b/vendor/github.com/xenolf/lego/cli_handlers.go index 29a1166d8..45e781246 100644 --- a/vendor/github.com/xenolf/lego/cli_handlers.go +++ b/vendor/github.com/xenolf/lego/cli_handlers.go @@ -15,20 +15,27 @@ import ( "github.com/urfave/cli" "github.com/xenolf/lego/acme" + "github.com/xenolf/lego/providers/dns/auroradns" + "github.com/xenolf/lego/providers/dns/azure" "github.com/xenolf/lego/providers/dns/cloudflare" "github.com/xenolf/lego/providers/dns/digitalocean" "github.com/xenolf/lego/providers/dns/dnsimple" "github.com/xenolf/lego/providers/dns/dnsmadeeasy" + "github.com/xenolf/lego/providers/dns/dnspod" "github.com/xenolf/lego/providers/dns/dyn" + "github.com/xenolf/lego/providers/dns/exoscale" "github.com/xenolf/lego/providers/dns/gandi" "github.com/xenolf/lego/providers/dns/googlecloud" "github.com/xenolf/lego/providers/dns/linode" "github.com/xenolf/lego/providers/dns/namecheap" + "github.com/xenolf/lego/providers/dns/ns1" "github.com/xenolf/lego/providers/dns/ovh" "github.com/xenolf/lego/providers/dns/pdns" + "github.com/xenolf/lego/providers/dns/rackspace" "github.com/xenolf/lego/providers/dns/rfc2136" "github.com/xenolf/lego/providers/dns/route53" "github.com/xenolf/lego/providers/dns/vultr" + "github.com/xenolf/lego/providers/http/memcached" "github.com/xenolf/lego/providers/http/webroot" ) @@ -99,6 +106,18 @@ func setup(c *cli.Context) (*Configuration, *Account, *acme.Client) { // infer that the user also wants to exclude all other challenges client.ExcludeChallenges([]acme.Challenge{acme.DNS01, acme.TLSSNI01}) } + if c.GlobalIsSet("memcached-host") { + provider, err := memcached.NewMemcachedProvider(c.GlobalStringSlice("memcached-host")) + if err != nil { + logger().Fatal(err) + } + + client.SetChallengeProvider(acme.HTTP01, provider) + + // --memcached-host=foo:11211 indicates that the user specifically want to do a HTTP challenge + // infer that the user also wants to exclude all other challenges + client.ExcludeChallenges([]acme.Challenge{acme.DNS01, acme.TLSSNI01}) + } if c.GlobalIsSet("http") { if strings.Index(c.GlobalString("http"), ":") == -1 { logger().Fatalf("The --http switch only accepts interface:port or :port for its argument.") @@ -117,6 +136,10 @@ func setup(c *cli.Context) (*Configuration, *Account, *acme.Client) { var err error var provider acme.ChallengeProvider switch c.GlobalString("dns") { + case "azure": + provider, err = azure.NewDNSProvider() + case "auroradns": + provider, err = auroradns.NewDNSProvider() case "cloudflare": provider, err = cloudflare.NewDNSProvider() case "digitalocean": @@ -125,6 +148,8 @@ func setup(c *cli.Context) (*Configuration, *Account, *acme.Client) { provider, err = dnsimple.NewDNSProvider() case "dnsmadeeasy": provider, err = dnsmadeeasy.NewDNSProvider() + case "exoscale": + provider, err = exoscale.NewDNSProvider() case "dyn": provider, err = dyn.NewDNSProvider() case "gandi": @@ -137,6 +162,8 @@ func setup(c *cli.Context) (*Configuration, *Account, *acme.Client) { provider, err = acme.NewDNSProviderManual() case "namecheap": provider, err = namecheap.NewDNSProvider() + case "rackspace": + provider, err = rackspace.NewDNSProvider() case "route53": provider, err = route53.NewDNSProvider() case "rfc2136": @@ -147,6 +174,10 @@ func setup(c *cli.Context) (*Configuration, *Account, *acme.Client) { provider, err = ovh.NewDNSProvider() case "pdns": provider, err = pdns.NewDNSProvider() + case "ns1": + provider, err = ns1.NewDNSProvider() + case "dnspod": + provider, err = dnspod.NewDNSProvider() } if err != nil { @@ -320,7 +351,7 @@ func run(c *cli.Context) error { if hasDomains { // obtain a certificate, generating a new private key - cert, failures = client.ObtainCertificate(c.GlobalStringSlice("domains"), !c.Bool("no-bundle"), nil) + cert, failures = client.ObtainCertificate(c.GlobalStringSlice("domains"), !c.Bool("no-bundle"), nil, c.Bool("must-staple")) } else { // read the CSR csr, err := readCSRFile(c.GlobalString("csr")) @@ -433,7 +464,7 @@ func renew(c *cli.Context) error { certRes.Certificate = certBytes - newCert, err := client.RenewCertificate(certRes, !c.Bool("no-bundle")) + newCert, err := client.RenewCertificate(certRes, !c.Bool("no-bundle"), c.Bool("must-staple")) if err != nil { logger().Fatalf("%s", err.Error()) } diff --git a/vendor/github.com/xenolf/lego/providers/dns/auroradns/auroradns.go b/vendor/github.com/xenolf/lego/providers/dns/auroradns/auroradns.go new file mode 100644 index 000000000..55b48f9b4 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/auroradns/auroradns.go @@ -0,0 +1,141 @@ +package auroradns + +import ( + "fmt" + "github.com/edeckers/auroradnsclient" + "github.com/edeckers/auroradnsclient/records" + "github.com/edeckers/auroradnsclient/zones" + "github.com/xenolf/lego/acme" + "os" + "sync" +) + +// DNSProvider describes a provider for AuroraDNS +type DNSProvider struct { + recordIDs map[string]string + recordIDsMu sync.Mutex + client *auroradnsclient.AuroraDNSClient +} + +// NewDNSProvider returns a DNSProvider instance configured for AuroraDNS. +// Credentials must be passed in the environment variables: AURORA_USER_ID +// and AURORA_KEY. +func NewDNSProvider() (*DNSProvider, error) { + userID := os.Getenv("AURORA_USER_ID") + key := os.Getenv("AURORA_KEY") + + endpoint := os.Getenv("AURORA_ENDPOINT") + if endpoint == "" { + endpoint = "https://api.auroradns.eu" + } + + return NewDNSProviderCredentials(endpoint, userID, key) +} + +// NewDNSProviderCredentials uses the supplied credentials to return a +// DNSProvider instance configured for AuroraDNS. +func NewDNSProviderCredentials(baseURL string, userID string, key string) (*DNSProvider, error) { + client, err := auroradnsclient.NewAuroraDNSClient(baseURL, userID, key) + if err != nil { + return nil, err + } + + return &DNSProvider{ + client: client, + recordIDs: make(map[string]string), + }, nil +} + +func (provider *DNSProvider) getZoneInformationByName(name string) (zones.ZoneRecord, error) { + zs, err := provider.client.GetZones() + + if err != nil { + return zones.ZoneRecord{}, err + } + + for _, element := range zs { + if element.Name == name { + return element, nil + } + } + + return zones.ZoneRecord{}, fmt.Errorf("Could not find Zone record") +} + +// Present creates a record with a secret +func (provider *DNSProvider) Present(domain, token, keyAuth string) error { + fqdn, value, _ := acme.DNS01Record(domain, keyAuth) + + authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers) + if err != nil { + return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err) + } + + // 1. Aurora will happily create the TXT record when it is provided a fqdn, + // but it will only appear in the control panel and will not be + // propagated to DNS servers. Extract and use subdomain instead. + // 2. A trailing dot in the fqdn will cause Aurora to add a trailing dot to + // the subdomain, resulting in _acme-challenge..<domain> rather + // than _acme-challenge.<domain> + + subdomain := fqdn[0 : len(fqdn)-len(authZone)-1] + + authZone = acme.UnFqdn(authZone) + + zoneRecord, err := provider.getZoneInformationByName(authZone) + + reqData := + records.CreateRecordRequest{ + RecordType: "TXT", + Name: subdomain, + Content: value, + TTL: 300, + } + + respData, err := provider.client.CreateRecord(zoneRecord.ID, reqData) + if err != nil { + return fmt.Errorf("Could not create record: '%s'.", err) + } + + provider.recordIDsMu.Lock() + provider.recordIDs[fqdn] = respData.ID + provider.recordIDsMu.Unlock() + + return nil +} + +// CleanUp removes a given record that was generated by Present +func (provider *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + + provider.recordIDsMu.Lock() + recordID, ok := provider.recordIDs[fqdn] + provider.recordIDsMu.Unlock() + + if !ok { + return fmt.Errorf("Unknown recordID for '%s'", fqdn) + } + + authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers) + if err != nil { + return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err) + } + + authZone = acme.UnFqdn(authZone) + + zoneRecord, err := provider.getZoneInformationByName(authZone) + if err != nil { + return err + } + + _, err = provider.client.RemoveRecord(zoneRecord.ID, recordID) + if err != nil { + return err + } + + provider.recordIDsMu.Lock() + delete(provider.recordIDs, fqdn) + provider.recordIDsMu.Unlock() + + return nil +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/auroradns/auroradns_test.go b/vendor/github.com/xenolf/lego/providers/dns/auroradns/auroradns_test.go new file mode 100644 index 000000000..f4df7fa61 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/auroradns/auroradns_test.go @@ -0,0 +1,148 @@ +package auroradns + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +var fakeAuroraDNSUserId = "asdf1234" +var fakeAuroraDNSKey = "key" + +func TestAuroraDNSPresent(t *testing.T) { + var requestReceived bool + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.URL.Path == "/zones" { + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `[{ + "id": "c56a4180-65aa-42ec-a945-5fd21dec0538", + "name": "example.com" + }]`) + return + } + + requestReceived = true + + if got, want := r.Method, "POST"; got != want { + t.Errorf("Expected method to be '%s' but got '%s'", want, got) + } + + if got, want := r.URL.Path, "/zones/c56a4180-65aa-42ec-a945-5fd21dec0538/records"; got != want { + t.Errorf("Expected path to be '%s' but got '%s'", want, got) + } + + if got, want := r.Header.Get("Content-Type"), "application/json"; got != want { + t.Errorf("Expected Content-Type to be '%s' but got '%s'", want, got) + } + + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("Error reading request body: %v", err) + } + + if got, want := string(reqBody), + `{"type":"TXT","name":"_acme-challenge","content":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":300}`; got != want { + + t.Errorf("Expected body data to be: `%s` but got `%s`", want, got) + } + + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{ + "id": "c56a4180-65aa-42ec-a945-5fd21dec0538", + "type": "TXT", + "name": "_acme-challenge", + "ttl": 300 + }`) + })) + + defer mock.Close() + + auroraProvider, err := NewDNSProviderCredentials(mock.URL, fakeAuroraDNSUserId, fakeAuroraDNSKey) + if auroraProvider == nil { + t.Fatal("Expected non-nil AuroraDNS provider, but was nil") + } + + if err != nil { + t.Fatalf("Expected no error creating provider, but got: %v", err) + } + + err = auroraProvider.Present("example.com", "", "foobar") + if err != nil { + t.Fatalf("Expected no error creating TXT record, but got: %v", err) + } + + if !requestReceived { + t.Error("Expected request to be received by mock backend, but it wasn't") + } +} + +func TestAuroraDNSCleanUp(t *testing.T) { + var requestReceived bool + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.URL.Path == "/zones" { + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `[{ + "id": "c56a4180-65aa-42ec-a945-5fd21dec0538", + "name": "example.com" + }]`) + return + } + + if r.Method == "POST" && r.URL.Path == "/zones/c56a4180-65aa-42ec-a945-5fd21dec0538/records" { + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{ + "id": "ec56a4180-65aa-42ec-a945-5fd21dec0538", + "type": "TXT", + "name": "_acme-challenge", + "ttl": 300 + }`) + return + } + + requestReceived = true + + if got, want := r.Method, "DELETE"; got != want { + t.Errorf("Expected method to be '%s' but got '%s'", want, got) + } + + if got, want := r.URL.Path, + "/zones/c56a4180-65aa-42ec-a945-5fd21dec0538/records/ec56a4180-65aa-42ec-a945-5fd21dec0538"; got != want { + t.Errorf("Expected path to be '%s' but got '%s'", want, got) + } + + if got, want := r.Header.Get("Content-Type"), "application/json"; got != want { + t.Errorf("Expected Content-Type to be '%s' but got '%s'", want, got) + } + + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{}`) + })) + defer mock.Close() + + auroraProvider, err := NewDNSProviderCredentials(mock.URL, fakeAuroraDNSUserId, fakeAuroraDNSKey) + if auroraProvider == nil { + t.Fatal("Expected non-nil AuroraDNS provider, but was nil") + } + + if err != nil { + t.Fatalf("Expected no error creating provider, but got: %v", err) + } + + err = auroraProvider.Present("example.com", "", "foobar") + if err != nil { + t.Fatalf("Expected no error creating TXT record, but got: %v", err) + } + + err = auroraProvider.CleanUp("example.com", "", "foobar") + if err != nil { + t.Fatalf("Expected no error removing TXT record, but got: %v", err) + } + + if !requestReceived { + t.Error("Expected request to be received by mock backend, but it wasn't") + } +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/azure/azure.go b/vendor/github.com/xenolf/lego/providers/dns/azure/azure.go new file mode 100644 index 000000000..6742e4f56 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/azure/azure.go @@ -0,0 +1,142 @@ +// Package azure implements a DNS provider for solving the DNS-01 +// challenge using azure DNS. +// Azure doesn't like trailing dots on domain names, most of the acme code does. +package azure + +import ( + "fmt" + "os" + "time" + + "github.com/Azure/azure-sdk-for-go/arm/dns" + + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/to" + "github.com/xenolf/lego/acme" + "strings" +) + +// DNSProvider is an implementation of the acme.ChallengeProvider interface +type DNSProvider struct { + clientId string + clientSecret string + subscriptionId string + tenantId string + resourceGroup string +} + +// NewDNSProvider returns a DNSProvider instance configured for azure. +// Credentials must be passed in the environment variables: AZURE_CLIENT_ID, +// AZURE_CLIENT_SECRET, AZURE_SUBSCRIPTION_ID, AZURE_TENANT_ID +func NewDNSProvider() (*DNSProvider, error) { + clientId := os.Getenv("AZURE_CLIENT_ID") + clientSecret := os.Getenv("AZURE_CLIENT_SECRET") + subscriptionId := os.Getenv("AZURE_SUBSCRIPTION_ID") + tenantId := os.Getenv("AZURE_TENANT_ID") + resourceGroup := os.Getenv("AZURE_RESOURCE_GROUP") + return NewDNSProviderCredentials(clientId, clientSecret, subscriptionId, tenantId, resourceGroup) +} + +// NewDNSProviderCredentials uses the supplied credentials to return a +// DNSProvider instance configured for azure. +func NewDNSProviderCredentials(clientId, clientSecret, subscriptionId, tenantId, resourceGroup string) (*DNSProvider, error) { + if clientId == "" || clientSecret == "" || subscriptionId == "" || tenantId == "" || resourceGroup == "" { + return nil, fmt.Errorf("Azure configuration missing") + } + + return &DNSProvider{ + clientId: clientId, + clientSecret: clientSecret, + subscriptionId: subscriptionId, + tenantId: tenantId, + resourceGroup: resourceGroup, + }, nil +} + +// Timeout returns the timeout and interval to use when checking for DNS +// propagation. Adjusting here to cope with spikes in propagation times. +func (c *DNSProvider) Timeout() (timeout, interval time.Duration) { + return 120 * time.Second, 2 * time.Second +} + +// Present creates a TXT record to fulfil the dns-01 challenge +func (c *DNSProvider) Present(domain, token, keyAuth string) error { + fqdn, value, _ := acme.DNS01Record(domain, keyAuth) + zone, err := c.getHostedZoneID(fqdn) + if err != nil { + return err + } + + rsc := dns.NewRecordSetsClient(c.subscriptionId) + rsc.Authorizer, err = c.newServicePrincipalTokenFromCredentials(azure.PublicCloud.ResourceManagerEndpoint) + relative := toRelativeRecord(fqdn, acme.ToFqdn(zone)) + rec := dns.RecordSet{ + Name: &relative, + Properties: &dns.RecordSetProperties{ + TTL: to.Int64Ptr(60), + TXTRecords: &[]dns.TxtRecord{dns.TxtRecord{Value: &[]string{value}}}, + }, + } + _, err = rsc.CreateOrUpdate(c.resourceGroup, zone, relative, dns.TXT, rec, "", "") + + if err != nil { + return err + } + + return nil +} + +// Returns the relative record to the domain +func toRelativeRecord(domain, zone string) string { + return acme.UnFqdn(strings.TrimSuffix(domain, zone)) +} + +// CleanUp removes the TXT record matching the specified parameters +func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + + zone, err := c.getHostedZoneID(fqdn) + if err != nil { + return err + } + + relative := toRelativeRecord(fqdn, acme.ToFqdn(zone)) + rsc := dns.NewRecordSetsClient(c.subscriptionId) + rsc.Authorizer, err = c.newServicePrincipalTokenFromCredentials(azure.PublicCloud.ResourceManagerEndpoint) + _, err = rsc.Delete(c.resourceGroup, zone, relative, dns.TXT, "", "") + if err != nil { + return err + } + + return nil +} + +// Checks that azure has a zone for this domain name. +func (c *DNSProvider) getHostedZoneID(fqdn string) (string, error) { + authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers) + if err != nil { + return "", err + } + + // Now we want to to Azure and get the zone. + dc := dns.NewZonesClient(c.subscriptionId) + dc.Authorizer, err = c.newServicePrincipalTokenFromCredentials(azure.PublicCloud.ResourceManagerEndpoint) + zone, err := dc.Get(c.resourceGroup, acme.UnFqdn(authZone)) + + if err != nil { + return "", err + } + + // zone.Name shouldn't have a trailing dot(.) + return to.String(zone.Name), nil +} + +// NewServicePrincipalTokenFromCredentials creates a new ServicePrincipalToken using values of the +// passed credentials map. +func (c *DNSProvider) newServicePrincipalTokenFromCredentials(scope string) (*azure.ServicePrincipalToken, error) { + oauthConfig, err := azure.PublicCloud.OAuthConfigForTenant(c.tenantId) + if err != nil { + panic(err) + } + return azure.NewServicePrincipalToken(*oauthConfig, c.clientId, c.clientSecret, scope) +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/azure/azure_test.go b/vendor/github.com/xenolf/lego/providers/dns/azure/azure_test.go new file mode 100644 index 000000000..db55f578a --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/azure/azure_test.go @@ -0,0 +1,89 @@ +package azure + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + azureLiveTest bool + azureClientID string + azureClientSecret string + azureSubscriptionID string + azureTenantID string + azureResourceGroup string + azureDomain string +) + +func init() { + azureClientID = os.Getenv("AZURE_CLIENT_ID") + azureClientSecret = os.Getenv("AZURE_CLIENT_SECRET") + azureSubscriptionID = os.Getenv("AZURE_SUBSCRIPTION_ID") + azureTenantID = os.Getenv("AZURE_TENANT_ID") + azureResourceGroup = os.Getenv("AZURE_RESOURCE_GROUP") + azureDomain = os.Getenv("AZURE_DOMAIN") + if len(azureClientID) > 0 && len(azureClientSecret) > 0 { + azureLiveTest = true + } +} + +func restoreAzureEnv() { + os.Setenv("AZURE_CLIENT_ID", azureClientID) + os.Setenv("AZURE_SUBSCRIPTION_ID", azureSubscriptionID) +} + +func TestNewDNSProviderValid(t *testing.T) { + if !azureLiveTest { + t.Skip("skipping live test (requires credentials)") + } + os.Setenv("AZURE_CLIENT_ID", "") + _, err := NewDNSProviderCredentials(azureClientID, azureClientSecret, azureSubscriptionID, azureTenantID, azureResourceGroup) + assert.NoError(t, err) + restoreAzureEnv() +} + +func TestNewDNSProviderValidEnv(t *testing.T) { + if !azureLiveTest { + t.Skip("skipping live test (requires credentials)") + } + os.Setenv("AZURE_CLIENT_ID", "other") + _, err := NewDNSProvider() + assert.NoError(t, err) + restoreAzureEnv() +} + +func TestNewDNSProviderMissingCredErr(t *testing.T) { + os.Setenv("AZURE_SUBSCRIPTION_ID", "") + _, err := NewDNSProvider() + assert.EqualError(t, err, "Azure configuration missing") + restoreAzureEnv() +} + +func TestLiveAzurePresent(t *testing.T) { + if !azureLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderCredentials(azureClientID, azureClientSecret, azureSubscriptionID, azureTenantID, azureResourceGroup) + assert.NoError(t, err) + + err = provider.Present(azureDomain, "", "123d==") + assert.NoError(t, err) +} + +func TestLiveAzureCleanUp(t *testing.T) { + if !azureLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderCredentials(azureClientID, azureClientSecret, azureSubscriptionID, azureTenantID, azureResourceGroup) + time.Sleep(time.Second * 1) + + assert.NoError(t, err) + + err = provider.CleanUp(azureDomain, "", "123d==") + assert.NoError(t, err) +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/dnspod/dnspod.go b/vendor/github.com/xenolf/lego/providers/dns/dnspod/dnspod.go new file mode 100644 index 000000000..0ce08a8bb --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/dnspod/dnspod.go @@ -0,0 +1,146 @@ +// Package dnspod implements a DNS provider for solving the DNS-01 challenge +// using dnspod DNS. +package dnspod + +import ( + "fmt" + "os" + "strings" + + "github.com/decker502/dnspod-go" + "github.com/xenolf/lego/acme" +) + +// DNSProvider is an implementation of the acme.ChallengeProvider interface. +type DNSProvider struct { + client *dnspod.Client +} + +// NewDNSProvider returns a DNSProvider instance configured for dnspod. +// Credentials must be passed in the environment variables: DNSPOD_API_KEY. +func NewDNSProvider() (*DNSProvider, error) { + key := os.Getenv("DNSPOD_API_KEY") + return NewDNSProviderCredentials(key) +} + +// NewDNSProviderCredentials uses the supplied credentials to return a +// DNSProvider instance configured for dnspod. +func NewDNSProviderCredentials(key string) (*DNSProvider, error) { + if key == "" { + return nil, fmt.Errorf("dnspod credentials missing") + } + + params := dnspod.CommonParams{LoginToken: key, Format: "json"} + return &DNSProvider{ + client: dnspod.NewClient(params), + }, nil +} + +// Present creates a TXT record to fulfil the dns-01 challenge. +func (c *DNSProvider) Present(domain, token, keyAuth string) error { + fqdn, value, ttl := acme.DNS01Record(domain, keyAuth) + zoneID, zoneName, err := c.getHostedZone(domain) + if err != nil { + return err + } + + recordAttributes := c.newTxtRecord(zoneName, fqdn, value, ttl) + _, _, err = c.client.Domains.CreateRecord(zoneID, *recordAttributes) + if err != nil { + return fmt.Errorf("dnspod API call failed: %v", err) + } + + return nil +} + +// CleanUp removes the TXT record matching the specified parameters. +func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + + records, err := c.findTxtRecords(domain, fqdn) + if err != nil { + return err + } + + zoneID, _, err := c.getHostedZone(domain) + if err != nil { + return err + } + + for _, rec := range records { + _, err := c.client.Domains.DeleteRecord(zoneID, rec.ID) + if err != nil { + return err + } + } + return nil +} + +func (c *DNSProvider) getHostedZone(domain string) (string, string, error) { + zones, _, err := c.client.Domains.List() + if err != nil { + return "", "", fmt.Errorf("dnspod API call failed: %v", err) + } + + authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers) + if err != nil { + return "", "", err + } + + var hostedZone dnspod.Domain + for _, zone := range zones { + if zone.Name == acme.UnFqdn(authZone) { + hostedZone = zone + } + } + + if hostedZone.ID == 0 { + return "", "", fmt.Errorf("Zone %s not found in dnspod for domain %s", authZone, domain) + + } + + return fmt.Sprintf("%v", hostedZone.ID), hostedZone.Name, nil +} + +func (c *DNSProvider) newTxtRecord(zone, fqdn, value string, ttl int) *dnspod.Record { + name := c.extractRecordName(fqdn, zone) + + return &dnspod.Record{ + Type: "TXT", + Name: name, + Value: value, + Line: "默认", + TTL: "600", + } +} + +func (c *DNSProvider) findTxtRecords(domain, fqdn string) ([]dnspod.Record, error) { + zoneID, zoneName, err := c.getHostedZone(domain) + if err != nil { + return nil, err + } + + var records []dnspod.Record + result, _, err := c.client.Domains.ListRecords(zoneID, "") + if err != nil { + return records, fmt.Errorf("dnspod API call has failed: %v", err) + } + + recordName := c.extractRecordName(fqdn, zoneName) + + for _, record := range result { + if record.Name == recordName { + records = append(records, record) + } + } + + return records, nil +} + +func (c *DNSProvider) extractRecordName(fqdn, domain string) string { + name := acme.UnFqdn(fqdn) + if idx := strings.Index(name, "."+domain); idx != -1 { + return name[:idx] + } + return name +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/dnspod/dnspod_test.go b/vendor/github.com/xenolf/lego/providers/dns/dnspod/dnspod_test.go new file mode 100644 index 000000000..3311eb0a6 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/dnspod/dnspod_test.go @@ -0,0 +1,72 @@ +package dnspod + +import ( + "github.com/stretchr/testify/assert" + "os" + "testing" + "time" +) + +var ( + dnspodLiveTest bool + dnspodAPIKey string + dnspodDomain string +) + +func init() { + dnspodAPIKey = os.Getenv("DNSPOD_API_KEY") + dnspodDomain = os.Getenv("DNSPOD_DOMAIN") + if len(dnspodAPIKey) > 0 && len(dnspodDomain) > 0 { + dnspodLiveTest = true + } +} + +func restorednspodEnv() { + os.Setenv("DNSPOD_API_KEY", dnspodAPIKey) +} + +func TestNewDNSProviderValid(t *testing.T) { + os.Setenv("DNSPOD_API_KEY", "") + _, err := NewDNSProviderCredentials("123") + assert.NoError(t, err) + restorednspodEnv() +} +func TestNewDNSProviderValidEnv(t *testing.T) { + os.Setenv("DNSPOD_API_KEY", "123") + _, err := NewDNSProvider() + assert.NoError(t, err) + restorednspodEnv() +} + +func TestNewDNSProviderMissingCredErr(t *testing.T) { + os.Setenv("DNSPOD_API_KEY", "") + _, err := NewDNSProvider() + assert.EqualError(t, err, "dnspod credentials missing") + restorednspodEnv() +} + +func TestLivednspodPresent(t *testing.T) { + if !dnspodLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderCredentials(dnspodAPIKey) + assert.NoError(t, err) + + err = provider.Present(dnspodDomain, "", "123d==") + assert.NoError(t, err) +} + +func TestLivednspodCleanUp(t *testing.T) { + if !dnspodLiveTest { + t.Skip("skipping live test") + } + + time.Sleep(time.Second * 1) + + provider, err := NewDNSProviderCredentials(dnspodAPIKey) + assert.NoError(t, err) + + err = provider.CleanUp(dnspodDomain, "", "123d==") + assert.NoError(t, err) +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/exoscale/exoscale.go b/vendor/github.com/xenolf/lego/providers/dns/exoscale/exoscale.go new file mode 100644 index 000000000..3b6b58d08 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/exoscale/exoscale.go @@ -0,0 +1,132 @@ +// Package exoscale implements a DNS provider for solving the DNS-01 challenge +// using exoscale DNS. +package exoscale + +import ( + "errors" + "fmt" + "os" + + "github.com/pyr/egoscale/src/egoscale" + "github.com/xenolf/lego/acme" +) + +// DNSProvider is an implementation of the acme.ChallengeProvider interface. +type DNSProvider struct { + client *egoscale.Client +} + +// Credentials must be passed in the environment variables: +// EXOSCALE_API_KEY, EXOSCALE_API_SECRET, EXOSCALE_ENDPOINT. +func NewDNSProvider() (*DNSProvider, error) { + key := os.Getenv("EXOSCALE_API_KEY") + secret := os.Getenv("EXOSCALE_API_SECRET") + endpoint := os.Getenv("EXOSCALE_ENDPOINT") + return NewDNSProviderClient(key, secret, endpoint) +} + +// Uses the supplied parameters to return a DNSProvider instance +// configured for Exoscale. +func NewDNSProviderClient(key, secret, endpoint string) (*DNSProvider, error) { + if key == "" || secret == "" { + return nil, fmt.Errorf("Exoscale credentials missing") + } + if endpoint == "" { + endpoint = "https://api.exoscale.ch/dns" + } + + return &DNSProvider{ + client: egoscale.NewClient(endpoint, key, secret), + }, nil +} + +// Present creates a TXT record to fulfil the dns-01 challenge. +func (c *DNSProvider) Present(domain, token, keyAuth string) error { + fqdn, value, ttl := acme.DNS01Record(domain, keyAuth) + zone, recordName, err := c.FindZoneAndRecordName(fqdn, domain) + if err != nil { + return err + } + + recordId, err := c.FindExistingRecordId(zone, recordName) + if err != nil { + return err + } + + record := egoscale.DNSRecord{ + Name: recordName, + Ttl: ttl, + Content: value, + RecordType: "TXT", + } + + if recordId == 0 { + _, err := c.client.CreateRecord(zone, record) + if err != nil { + return errors.New("Error while creating DNS record: " + err.Error()) + } + } else { + record.Id = recordId + _, err := c.client.UpdateRecord(zone, record) + if err != nil { + return errors.New("Error while updating DNS record: " + err.Error()) + } + } + + return nil +} + +// CleanUp removes the record matching the specified parameters. +func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + zone, recordName, err := c.FindZoneAndRecordName(fqdn, domain) + if err != nil { + return err + } + + recordId, err := c.FindExistingRecordId(zone, recordName) + if err != nil { + return err + } + + if recordId != 0 { + record := egoscale.DNSRecord{ + Id: recordId, + } + + err = c.client.DeleteRecord(zone, record) + if err != nil { + return errors.New("Error while deleting DNS record: " + err.Error()) + } + } + + return nil +} + +// Query Exoscale to find an existing record for this name. +// Returns nil if no record could be found +func (c *DNSProvider) FindExistingRecordId(zone, recordName string) (int64, error) { + responses, err := c.client.GetRecords(zone) + if err != nil { + return -1, errors.New("Error while retrievening DNS records: " + err.Error()) + } + for _, response := range responses { + if response.Record.Name == recordName { + return response.Record.Id, nil + } + } + return 0, nil +} + +// Extract DNS zone and DNS entry name +func (c *DNSProvider) FindZoneAndRecordName(fqdn, domain string) (string, string, error) { + zone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers) + if err != nil { + return "", "", err + } + zone = acme.UnFqdn(zone) + name := acme.UnFqdn(fqdn) + name = name[:len(name)-len("."+zone)] + + return zone, name, nil +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/exoscale/exoscale_test.go b/vendor/github.com/xenolf/lego/providers/dns/exoscale/exoscale_test.go new file mode 100644 index 000000000..343dd56f8 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/exoscale/exoscale_test.go @@ -0,0 +1,103 @@ +package exoscale + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + exoscaleLiveTest bool + exoscaleAPIKey string + exoscaleAPISecret string + exoscaleDomain string +) + +func init() { + exoscaleAPISecret = os.Getenv("EXOSCALE_API_SECRET") + exoscaleAPIKey = os.Getenv("EXOSCALE_API_KEY") + exoscaleDomain = os.Getenv("EXOSCALE_DOMAIN") + if len(exoscaleAPIKey) > 0 && len(exoscaleAPISecret) > 0 && len(exoscaleDomain) > 0 { + exoscaleLiveTest = true + } +} + +func restoreExoscaleEnv() { + os.Setenv("EXOSCALE_API_KEY", exoscaleAPIKey) + os.Setenv("EXOSCALE_API_SECRET", exoscaleAPISecret) +} + +func TestNewDNSProviderValid(t *testing.T) { + os.Setenv("EXOSCALE_API_KEY", "") + os.Setenv("EXOSCALE_API_SECRET", "") + _, err := NewDNSProviderClient("example@example.com", "123", "") + assert.NoError(t, err) + restoreExoscaleEnv() +} +func TestNewDNSProviderValidEnv(t *testing.T) { + os.Setenv("EXOSCALE_API_KEY", "example@example.com") + os.Setenv("EXOSCALE_API_SECRET", "123") + _, err := NewDNSProvider() + assert.NoError(t, err) + restoreExoscaleEnv() +} + +func TestNewDNSProviderMissingCredErr(t *testing.T) { + os.Setenv("EXOSCALE_API_KEY", "") + os.Setenv("EXOSCALE_API_SECRET", "") + _, err := NewDNSProvider() + assert.EqualError(t, err, "Exoscale credentials missing") + restoreExoscaleEnv() +} + +func TestExtractRootRecordName(t *testing.T) { + provider, err := NewDNSProviderClient("example@example.com", "123", "") + assert.NoError(t, err) + + zone, recordName, err := provider.FindZoneAndRecordName("_acme-challenge.bar.com.", "bar.com") + assert.NoError(t, err) + assert.Equal(t, "bar.com", zone) + assert.Equal(t, "_acme-challenge", recordName) +} + +func TestExtractSubRecordName(t *testing.T) { + provider, err := NewDNSProviderClient("example@example.com", "123", "") + assert.NoError(t, err) + + zone, recordName, err := provider.FindZoneAndRecordName("_acme-challenge.foo.bar.com.", "foo.bar.com") + assert.NoError(t, err) + assert.Equal(t, "bar.com", zone) + assert.Equal(t, "_acme-challenge.foo", recordName) +} + +func TestLiveExoscalePresent(t *testing.T) { + if !exoscaleLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderClient(exoscaleAPIKey, exoscaleAPISecret, "") + assert.NoError(t, err) + + err = provider.Present(exoscaleDomain, "", "123d==") + assert.NoError(t, err) + + // Present Twice to handle create / update + err = provider.Present(exoscaleDomain, "", "123d==") + assert.NoError(t, err) +} + +func TestLiveExoscaleCleanUp(t *testing.T) { + if !exoscaleLiveTest { + t.Skip("skipping live test") + } + + time.Sleep(time.Second * 1) + + provider, err := NewDNSProviderClient(exoscaleAPIKey, exoscaleAPISecret, "") + assert.NoError(t, err) + + err = provider.CleanUp(exoscaleDomain, "", "123d==") + assert.NoError(t, err) +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/gandi/gandi_test.go b/vendor/github.com/xenolf/lego/providers/dns/gandi/gandi_test.go index 15919e2eb..451333ca1 100644 --- a/vendor/github.com/xenolf/lego/providers/dns/gandi/gandi_test.go +++ b/vendor/github.com/xenolf/lego/providers/dns/gandi/gandi_test.go @@ -141,7 +141,7 @@ func TestDNSProviderLive(t *testing.T) { } // complete the challenge bundle := false - _, failures := client.ObtainCertificate([]string{domain}, bundle, nil) + _, failures := client.ObtainCertificate([]string{domain}, bundle, nil, false) if len(failures) > 0 { t.Fatal(failures) } @@ -496,7 +496,7 @@ var serverResponses = map[string]string{ </member> <member> <name>id</name> -<value><int>3333333333</int></value> +<value><int>333333333</int></value> </member> <member> <name>value</name> diff --git a/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud.go b/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud.go index b8d9951c9..ea6c0875c 100644 --- a/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud.go +++ b/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud.go @@ -68,6 +68,16 @@ func (c *DNSProvider) Present(domain, token, keyAuth string) error { Additions: []*dns.ResourceRecordSet{rec}, } + // Look for existing records. + list, err := c.client.ResourceRecordSets.List(c.project, zone).Name(fqdn).Type("TXT").Do() + if err != nil { + return err + } + if len(list.Rrsets) > 0 { + // Attempt to delete the existing records when adding our new one. + change.Deletions = list.Rrsets + } + chg, err := c.client.Changes.Create(c.project, zone, change).Do() if err != nil { return err diff --git a/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud_test.go b/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud_test.go index d73788163..75a10d9a4 100644 --- a/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud_test.go +++ b/vendor/github.com/xenolf/lego/providers/dns/googlecloud/googlecloud_test.go @@ -70,6 +70,20 @@ func TestLiveGoogleCloudPresent(t *testing.T) { assert.NoError(t, err) } +func TestLiveGoogleCloudPresentMultiple(t *testing.T) { + if !gcloudLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderCredentials(gcloudProject) + assert.NoError(t, err) + + // Check that we're able to create multiple entries + err = provider.Present(gcloudDomain, "1", "123d==") + err = provider.Present(gcloudDomain, "2", "123d==") + assert.NoError(t, err) +} + func TestLiveGoogleCloudCleanUp(t *testing.T) { if !gcloudLiveTest { t.Skip("skipping live test") diff --git a/vendor/github.com/xenolf/lego/providers/dns/ns1/ns1.go b/vendor/github.com/xenolf/lego/providers/dns/ns1/ns1.go new file mode 100644 index 000000000..105d73f89 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/ns1/ns1.go @@ -0,0 +1,97 @@ +// Package ns1 implements a DNS provider for solving the DNS-01 challenge +// using NS1 DNS. +package ns1 + +import ( + "fmt" + "net/http" + "os" + "time" + + "github.com/xenolf/lego/acme" + "gopkg.in/ns1/ns1-go.v2/rest" + "gopkg.in/ns1/ns1-go.v2/rest/model/dns" +) + +// DNSProvider is an implementation of the acme.ChallengeProvider interface. +type DNSProvider struct { + client *rest.Client +} + +// NewDNSProvider returns a DNSProvider instance configured for NS1. +// Credentials must be passed in the environment variables: NS1_API_KEY. +func NewDNSProvider() (*DNSProvider, error) { + key := os.Getenv("NS1_API_KEY") + if key == "" { + return nil, fmt.Errorf("NS1 credentials missing") + } + return NewDNSProviderCredentials(key) +} + +// NewDNSProviderCredentials uses the supplied credentials to return a +// DNSProvider instance configured for NS1. +func NewDNSProviderCredentials(key string) (*DNSProvider, error) { + if key == "" { + return nil, fmt.Errorf("NS1 credentials missing") + } + + httpClient := &http.Client{Timeout: time.Second * 10} + client := rest.NewClient(httpClient, rest.SetAPIKey(key)) + + return &DNSProvider{client}, nil +} + +// Present creates a TXT record to fulfil the dns-01 challenge. +func (c *DNSProvider) Present(domain, token, keyAuth string) error { + fqdn, value, ttl := acme.DNS01Record(domain, keyAuth) + + zone, err := c.getHostedZone(domain) + if err != nil { + return err + } + + record := c.newTxtRecord(zone, fqdn, value, ttl) + _, err = c.client.Records.Create(record) + if err != nil && err != rest.ErrRecordExists { + return err + } + + return nil +} + +// CleanUp removes the TXT record matching the specified parameters. +func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + + zone, err := c.getHostedZone(domain) + if err != nil { + return err + } + + name := acme.UnFqdn(fqdn) + _, err = c.client.Records.Delete(zone.Zone, name, "TXT") + return err +} + +func (c *DNSProvider) getHostedZone(domain string) (*dns.Zone, error) { + zone, _, err := c.client.Zones.Get(domain) + if err != nil { + return nil, err + } + + return zone, nil +} + +func (c *DNSProvider) newTxtRecord(zone *dns.Zone, fqdn, value string, ttl int) *dns.Record { + name := acme.UnFqdn(fqdn) + + return &dns.Record{ + Type: "TXT", + Zone: zone.Zone, + Domain: name, + TTL: ttl, + Answers: []*dns.Answer{ + {Rdata: []string{value}}, + }, + } +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/ns1/ns1_test.go b/vendor/github.com/xenolf/lego/providers/dns/ns1/ns1_test.go new file mode 100644 index 000000000..eb9150dde --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/ns1/ns1_test.go @@ -0,0 +1,67 @@ +package ns1 + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + liveTest bool + apiKey string + domain string +) + +func init() { + apiKey = os.Getenv("NS1_API_KEY") + domain = os.Getenv("NS1_DOMAIN") + if len(apiKey) > 0 && len(domain) > 0 { + liveTest = true + } +} + +func restoreNS1Env() { + os.Setenv("NS1_API_KEY", apiKey) +} + +func TestNewDNSProviderValid(t *testing.T) { + os.Setenv("NS1_API_KEY", "") + _, err := NewDNSProviderCredentials("123") + assert.NoError(t, err) + restoreNS1Env() +} + +func TestNewDNSProviderMissingCredErr(t *testing.T) { + os.Setenv("NS1_API_KEY", "") + _, err := NewDNSProvider() + assert.EqualError(t, err, "NS1 credentials missing") + restoreNS1Env() +} + +func TestLivePresent(t *testing.T) { + if !liveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderCredentials(apiKey) + assert.NoError(t, err) + + err = provider.Present(domain, "", "123d==") + assert.NoError(t, err) +} + +func TestLiveCleanUp(t *testing.T) { + if !liveTest { + t.Skip("skipping live test") + } + + time.Sleep(time.Second * 1) + + provider, err := NewDNSProviderCredentials(apiKey) + assert.NoError(t, err) + + err = provider.CleanUp(domain, "", "123d==") + assert.NoError(t, err) +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/rackspace/rackspace.go b/vendor/github.com/xenolf/lego/providers/dns/rackspace/rackspace.go new file mode 100644 index 000000000..2b106a27e --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/rackspace/rackspace.go @@ -0,0 +1,284 @@ +// Package rackspace implements a DNS provider for solving the DNS-01 +// challenge using rackspace DNS. +package rackspace + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/xenolf/lego/acme" +) + +// rackspaceAPIURL represents the Identity API endpoint to call +var rackspaceAPIURL = "https://identity.api.rackspacecloud.com/v2.0/tokens" + +// DNSProvider is an implementation of the acme.ChallengeProvider interface +// used to store the reusable token and DNS API endpoint +type DNSProvider struct { + token string + cloudDNSEndpoint string +} + +// NewDNSProvider returns a DNSProvider instance configured for Rackspace. +// Credentials must be passed in the environment variables: RACKSPACE_USER +// and RACKSPACE_API_KEY. +func NewDNSProvider() (*DNSProvider, error) { + user := os.Getenv("RACKSPACE_USER") + key := os.Getenv("RACKSPACE_API_KEY") + return NewDNSProviderCredentials(user, key) +} + +// NewDNSProviderCredentials uses the supplied credentials to return a +// DNSProvider instance configured for Rackspace. It authenticates against +// the API, also grabbing the DNS Endpoint. +func NewDNSProviderCredentials(user, key string) (*DNSProvider, error) { + if user == "" || key == "" { + return nil, fmt.Errorf("Rackspace credentials missing") + } + + type APIKeyCredentials struct { + Username string `json:"username"` + APIKey string `json:"apiKey"` + } + + type Auth struct { + APIKeyCredentials `json:"RAX-KSKEY:apiKeyCredentials"` + } + + type RackspaceAuthData struct { + Auth `json:"auth"` + } + + type RackspaceIdentity struct { + Access struct { + ServiceCatalog []struct { + Endpoints []struct { + PublicURL string `json:"publicURL"` + TenantID string `json:"tenantId"` + } `json:"endpoints"` + Name string `json:"name"` + } `json:"serviceCatalog"` + Token struct { + ID string `json:"id"` + } `json:"token"` + } `json:"access"` + } + + authData := RackspaceAuthData{ + Auth: Auth{ + APIKeyCredentials: APIKeyCredentials{ + Username: user, + APIKey: key, + }, + }, + } + + body, err := json.Marshal(authData) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", rackspaceAPIURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + client := http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("Error querying Rackspace Identity API: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Rackspace Authentication failed. Response code: %d", resp.StatusCode) + } + + var rackspaceIdentity RackspaceIdentity + err = json.NewDecoder(resp.Body).Decode(&rackspaceIdentity) + if err != nil { + return nil, err + } + + // Iterate through the Service Catalog to get the DNS Endpoint + var dnsEndpoint string + for _, service := range rackspaceIdentity.Access.ServiceCatalog { + if service.Name == "cloudDNS" { + dnsEndpoint = service.Endpoints[0].PublicURL + break + } + } + if dnsEndpoint == "" { + return nil, fmt.Errorf("Failed to populate DNS endpoint, check Rackspace API for changes.") + } + + return &DNSProvider{ + token: rackspaceIdentity.Access.Token.ID, + cloudDNSEndpoint: dnsEndpoint, + }, nil +} + +// Present creates a TXT record to fulfil the dns-01 challenge +func (c *DNSProvider) Present(domain, token, keyAuth string) error { + fqdn, value, _ := acme.DNS01Record(domain, keyAuth) + zoneID, err := c.getHostedZoneID(fqdn) + if err != nil { + return err + } + + rec := RackspaceRecords{ + RackspaceRecord: []RackspaceRecord{{ + Name: acme.UnFqdn(fqdn), + Type: "TXT", + Data: value, + TTL: 300, + }}, + } + + body, err := json.Marshal(rec) + if err != nil { + return err + } + + _, err = c.makeRequest("POST", fmt.Sprintf("/domains/%d/records", zoneID), bytes.NewReader(body)) + if err != nil { + return err + } + + return nil +} + +// CleanUp removes the TXT record matching the specified parameters +func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + zoneID, err := c.getHostedZoneID(fqdn) + if err != nil { + return err + } + + record, err := c.findTxtRecord(fqdn, zoneID) + if err != nil { + return err + } + + _, err = c.makeRequest("DELETE", fmt.Sprintf("/domains/%d/records?id=%s", zoneID, record.ID), nil) + if err != nil { + return err + } + + return nil +} + +// getHostedZoneID performs a lookup to get the DNS zone which needs +// modifying for a given FQDN +func (c *DNSProvider) getHostedZoneID(fqdn string) (int, error) { + // HostedZones represents the response when querying Rackspace DNS zones + type ZoneSearchResponse struct { + TotalEntries int `json:"totalEntries"` + HostedZones []struct { + ID int `json:"id"` + Name string `json:"name"` + } `json:"domains"` + } + + authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers) + if err != nil { + return 0, err + } + + result, err := c.makeRequest("GET", fmt.Sprintf("/domains?name=%s", acme.UnFqdn(authZone)), nil) + if err != nil { + return 0, err + } + + var zoneSearchResponse ZoneSearchResponse + err = json.Unmarshal(result, &zoneSearchResponse) + if err != nil { + return 0, err + } + + // If nothing was returned, or for whatever reason more than 1 was returned (the search uses exact match, so should not occur) + if zoneSearchResponse.TotalEntries != 1 { + return 0, fmt.Errorf("Found %d zones for %s in Rackspace for domain %s", zoneSearchResponse.TotalEntries, authZone, fqdn) + } + + return zoneSearchResponse.HostedZones[0].ID, nil +} + +// findTxtRecord searches a DNS zone for a TXT record with a specific name +func (c *DNSProvider) findTxtRecord(fqdn string, zoneID int) (*RackspaceRecord, error) { + result, err := c.makeRequest("GET", fmt.Sprintf("/domains/%d/records?type=TXT&name=%s", zoneID, acme.UnFqdn(fqdn)), nil) + if err != nil { + return nil, err + } + + var records RackspaceRecords + err = json.Unmarshal(result, &records) + if err != nil { + return nil, err + } + + recordsLength := len(records.RackspaceRecord) + switch recordsLength { + case 1: + break + case 0: + return nil, fmt.Errorf("No TXT record found for %s", fqdn) + default: + return nil, fmt.Errorf("More than 1 TXT record found for %s", fqdn) + } + + return &records.RackspaceRecord[0], nil +} + +// makeRequest is a wrapper function used for making DNS API requests +func (c *DNSProvider) makeRequest(method, uri string, body io.Reader) (json.RawMessage, error) { + url := c.cloudDNSEndpoint + uri + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("X-Auth-Token", c.token) + req.Header.Set("Content-Type", "application/json") + + client := http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("Error querying DNS API: %v", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + return nil, fmt.Errorf("Request failed for %s %s. Response code: %d", method, url, resp.StatusCode) + } + + var r json.RawMessage + err = json.NewDecoder(resp.Body).Decode(&r) + if err != nil { + return nil, fmt.Errorf("JSON decode failed for %s %s. Response code: %d", method, url, resp.StatusCode) + } + + return r, nil +} + +// RackspaceRecords is the list of records sent/recieved from the DNS API +type RackspaceRecords struct { + RackspaceRecord []RackspaceRecord `json:"records"` +} + +// RackspaceRecord represents a Rackspace DNS record +type RackspaceRecord struct { + Name string `json:"name"` + Type string `json:"type"` + Data string `json:"data"` + TTL int `json:"ttl,omitempty"` + ID string `json:"id,omitempty"` +} diff --git a/vendor/github.com/xenolf/lego/providers/dns/rackspace/rackspace_test.go b/vendor/github.com/xenolf/lego/providers/dns/rackspace/rackspace_test.go new file mode 100644 index 000000000..22c979cad --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/dns/rackspace/rackspace_test.go @@ -0,0 +1,220 @@ +package rackspace + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + rackspaceLiveTest bool + rackspaceUser string + rackspaceAPIKey string + rackspaceDomain string + testAPIURL string +) + +func init() { + rackspaceUser = os.Getenv("RACKSPACE_USER") + rackspaceAPIKey = os.Getenv("RACKSPACE_API_KEY") + rackspaceDomain = os.Getenv("RACKSPACE_DOMAIN") + if len(rackspaceUser) > 0 && len(rackspaceAPIKey) > 0 && len(rackspaceDomain) > 0 { + rackspaceLiveTest = true + } +} + +func testRackspaceEnv() { + rackspaceAPIURL = testAPIURL + os.Setenv("RACKSPACE_USER", "testUser") + os.Setenv("RACKSPACE_API_KEY", "testKey") +} + +func liveRackspaceEnv() { + rackspaceAPIURL = "https://identity.api.rackspacecloud.com/v2.0/tokens" + os.Setenv("RACKSPACE_USER", rackspaceUser) + os.Setenv("RACKSPACE_API_KEY", rackspaceAPIKey) +} + +func startTestServers() (identityAPI, dnsAPI *httptest.Server) { + dnsAPI = httptest.NewServer(dnsMux()) + dnsEndpoint := dnsAPI.URL + "/123456" + + identityAPI = httptest.NewServer(identityHandler(dnsEndpoint)) + testAPIURL = identityAPI.URL + "/" + return +} + +func closeTestServers(identityAPI, dnsAPI *httptest.Server) { + identityAPI.Close() + dnsAPI.Close() +} + +func identityHandler(dnsEndpoint string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + resp, found := jsonMap[string(reqBody)] + if !found { + w.WriteHeader(http.StatusBadRequest) + return + } + resp = strings.Replace(resp, "https://dns.api.rackspacecloud.com/v1.0/123456", dnsEndpoint, 1) + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, resp) + }) +} + +func dnsMux() *http.ServeMux { + mux := http.NewServeMux() + + // Used by `getHostedZoneID()` finding `zoneID` "?name=example.com" + mux.HandleFunc("/123456/domains", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("name") == "example.com" { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, jsonMap["zoneDetails"]) + return + } + w.WriteHeader(http.StatusBadRequest) + return + }) + + mux.HandleFunc("/123456/domains/112233/records", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + // Used by `Present()` creating the TXT record + case http.MethodPost: + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + resp, found := jsonMap[string(reqBody)] + if !found { + w.WriteHeader(http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusAccepted) + fmt.Fprintf(w, resp) + // Used by `findTxtRecord()` finding `record.ID` "?type=TXT&name=_acme-challenge.example.com" + case http.MethodGet: + if r.URL.Query().Get("type") == "TXT" && r.URL.Query().Get("name") == "_acme-challenge.example.com" { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, jsonMap["recordDetails"]) + return + } + w.WriteHeader(http.StatusBadRequest) + return + // Used by `CleanUp()` deleting the TXT record "?id=445566" + case http.MethodDelete: + if r.URL.Query().Get("id") == "TXT-654321" { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, jsonMap["recordDelete"]) + return + } + w.WriteHeader(http.StatusBadRequest) + } + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Not Found for Request: (%+v)\n\n", r) + }) + + return mux +} + +func TestNewDNSProviderMissingCredErr(t *testing.T) { + testRackspaceEnv() + _, err := NewDNSProviderCredentials("", "") + assert.EqualError(t, err, "Rackspace credentials missing") +} + +func TestOfflineRackspaceValid(t *testing.T) { + testRackspaceEnv() + provider, err := NewDNSProviderCredentials(os.Getenv("RACKSPACE_USER"), os.Getenv("RACKSPACE_API_KEY")) + + assert.NoError(t, err) + assert.Equal(t, provider.token, "testToken", "The token should match") +} + +func TestOfflineRackspacePresent(t *testing.T) { + testRackspaceEnv() + provider, err := NewDNSProvider() + + if assert.NoError(t, err) { + err = provider.Present("example.com", "token", "keyAuth") + assert.NoError(t, err) + } +} + +func TestOfflineRackspaceCleanUp(t *testing.T) { + testRackspaceEnv() + provider, err := NewDNSProvider() + + if assert.NoError(t, err) { + err = provider.CleanUp("example.com", "token", "keyAuth") + assert.NoError(t, err) + } +} + +func TestNewDNSProviderValidEnv(t *testing.T) { + if !rackspaceLiveTest { + t.Skip("skipping live test") + } + + liveRackspaceEnv() + provider, err := NewDNSProvider() + assert.NoError(t, err) + assert.Contains(t, provider.cloudDNSEndpoint, "https://dns.api.rackspacecloud.com/v1.0/", "The endpoint URL should contain the base") +} + +func TestRackspacePresent(t *testing.T) { + if !rackspaceLiveTest { + t.Skip("skipping live test") + } + + liveRackspaceEnv() + provider, err := NewDNSProvider() + assert.NoError(t, err) + + err = provider.Present(rackspaceDomain, "", "112233445566==") + assert.NoError(t, err) +} + +func TestRackspaceCleanUp(t *testing.T) { + if !rackspaceLiveTest { + t.Skip("skipping live test") + } + + time.Sleep(time.Second * 15) + + liveRackspaceEnv() + provider, err := NewDNSProvider() + assert.NoError(t, err) + + err = provider.CleanUp(rackspaceDomain, "", "112233445566==") + assert.NoError(t, err) +} + +func TestMain(m *testing.M) { + identityAPI, dnsAPI := startTestServers() + defer closeTestServers(identityAPI, dnsAPI) + os.Exit(m.Run()) +} + +var jsonMap = map[string]string{ + `{"auth":{"RAX-KSKEY:apiKeyCredentials":{"username":"testUser","apiKey":"testKey"}}}`: `{"access":{"token":{"id":"testToken","expires":"1970-01-01T00:00:00.000Z","tenant":{"id":"123456","name":"123456"},"RAX-AUTH:authenticatedBy":["APIKEY"]},"serviceCatalog":[{"type":"rax:dns","endpoints":[{"publicURL":"https://dns.api.rackspacecloud.com/v1.0/123456","tenantId":"123456"}],"name":"cloudDNS"}],"user":{"id":"fakeUseID","name":"testUser"}}}`, + "zoneDetails": `{"domains":[{"name":"example.com","id":112233,"emailAddress":"hostmaster@example.com","updated":"1970-01-01T00:00:00.000+0000","created":"1970-01-01T00:00:00.000+0000"}],"totalEntries":1}`, + `{"records":[{"name":"_acme-challenge.example.com","type":"TXT","data":"pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM","ttl":300}]}`: `{"request":"{\"records\":[{\"name\":\"_acme-challenge.example.com\",\"type\":\"TXT\",\"data\":\"pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM\",\"ttl\":300}]}","status":"RUNNING","verb":"POST","jobId":"00000000-0000-0000-0000-0000000000","callbackUrl":"https://dns.api.rackspacecloud.com/v1.0/123456/status/00000000-0000-0000-0000-0000000000","requestUrl":"https://dns.api.rackspacecloud.com/v1.0/123456/domains/112233/records"}`, + "recordDetails": `{"records":[{"name":"_acme-challenge.example.com","id":"TXT-654321","type":"TXT","data":"pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM","ttl":300,"updated":"1970-01-01T00:00:00.000+0000","created":"1970-01-01T00:00:00.000+0000"}]}`, + "recordDelete": `{"status":"RUNNING","verb":"DELETE","jobId":"00000000-0000-0000-0000-0000000000","callbackUrl":"https://dns.api.rackspacecloud.com/v1.0/123456/status/00000000-0000-0000-0000-0000000000","requestUrl":"https://dns.api.rackspacecloud.com/v1.0/123456/domains/112233/recordsid=TXT-654321"}`, +} diff --git a/vendor/github.com/xenolf/lego/providers/http/memcached/README.md b/vendor/github.com/xenolf/lego/providers/http/memcached/README.md new file mode 100644 index 000000000..f14d216df --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/http/memcached/README.md @@ -0,0 +1,15 @@ +# Memcached http provider + +Publishes challenges into memcached where they can be retrieved by nginx. Allows +specifying multiple memcached servers and the responses will be published to all +of them, making it easier to verify when your domain is hosted on a cluster of +servers. + +Example nginx config: + +``` + location /.well-known/acme-challenge/ { + set $memcached_key "$uri"; + memcached_pass 127.0.0.1:11211; + } +``` diff --git a/vendor/github.com/xenolf/lego/providers/http/memcached/memcached.go b/vendor/github.com/xenolf/lego/providers/http/memcached/memcached.go new file mode 100644 index 000000000..9c5f6c0b4 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/http/memcached/memcached.go @@ -0,0 +1,59 @@ +// Package webroot implements a HTTP provider for solving the HTTP-01 challenge using web server's root path. +package memcached + +import ( + "fmt" + "path" + + "github.com/rainycape/memcache" + "github.com/xenolf/lego/acme" +) + +// HTTPProvider implements ChallengeProvider for `http-01` challenge +type MemcachedProvider struct { + hosts []string +} + +// NewHTTPProvider returns a HTTPProvider instance with a configured webroot path +func NewMemcachedProvider(hosts []string) (*MemcachedProvider, error) { + if len(hosts) == 0 { + return nil, fmt.Errorf("No memcached hosts provided") + } + + c := &MemcachedProvider{ + hosts: hosts, + } + + return c, nil +} + +// Present makes the token available at `HTTP01ChallengePath(token)` by creating a file in the given webroot path +func (w *MemcachedProvider) Present(domain, token, keyAuth string) error { + var errs []error + + challengePath := path.Join("/", acme.HTTP01ChallengePath(token)) + for _, host := range w.hosts { + mc, err := memcache.New(host) + if err != nil { + errs = append(errs, err) + continue + } + mc.Add(&memcache.Item{ + Key: challengePath, + Value: []byte(keyAuth), + Expiration: 60, + }) + } + + if len(errs) == len(w.hosts) { + return fmt.Errorf("Unable to store key in any of the memcache hosts -> %v", errs) + } + + return nil +} + +// CleanUp removes the file created for the challenge +func (w *MemcachedProvider) CleanUp(domain, token, keyAuth string) error { + // Memcached will clean up itself, that's what expiration is for. + return nil +} diff --git a/vendor/github.com/xenolf/lego/providers/http/memcached/memcached_test.go b/vendor/github.com/xenolf/lego/providers/http/memcached/memcached_test.go new file mode 100644 index 000000000..287a33304 --- /dev/null +++ b/vendor/github.com/xenolf/lego/providers/http/memcached/memcached_test.go @@ -0,0 +1,111 @@ +package memcached + +import ( + "os" + "path" + "strings" + "testing" + + "github.com/rainycape/memcache" + "github.com/stretchr/testify/assert" + "github.com/xenolf/lego/acme" +) + +var ( + memcachedHosts []string +) + +const ( + domain = "lego.test" + token = "foo" + keyAuth = "bar" +) + +func init() { + memcachedHostsStr := os.Getenv("MEMCACHED_HOSTS") + if len(memcachedHostsStr) > 0 { + memcachedHosts = strings.Split(memcachedHostsStr, ",") + } +} + +func TestNewMemcachedProviderEmpty(t *testing.T) { + emptyHosts := make([]string, 0) + _, err := NewMemcachedProvider(emptyHosts) + assert.EqualError(t, err, "No memcached hosts provided") +} + +func TestNewMemcachedProviderValid(t *testing.T) { + if len(memcachedHosts) == 0 { + t.Skip("Skipping memcached tests") + } + _, err := NewMemcachedProvider(memcachedHosts) + assert.NoError(t, err) +} + +func TestMemcachedPresentSingleHost(t *testing.T) { + if len(memcachedHosts) == 0 { + t.Skip("Skipping memcached tests") + } + p, err := NewMemcachedProvider(memcachedHosts[0:1]) + assert.NoError(t, err) + + challengePath := path.Join("/", acme.HTTP01ChallengePath(token)) + + err = p.Present(domain, token, keyAuth) + assert.NoError(t, err) + mc, err := memcache.New(memcachedHosts[0]) + assert.NoError(t, err) + i, err := mc.Get(challengePath) + assert.NoError(t, err) + assert.Equal(t, i.Value, []byte(keyAuth)) +} + +func TestMemcachedPresentMultiHost(t *testing.T) { + if len(memcachedHosts) <= 1 { + t.Skip("Skipping memcached multi-host tests") + } + p, err := NewMemcachedProvider(memcachedHosts) + assert.NoError(t, err) + + challengePath := path.Join("/", acme.HTTP01ChallengePath(token)) + + err = p.Present(domain, token, keyAuth) + assert.NoError(t, err) + for _, host := range memcachedHosts { + mc, err := memcache.New(host) + assert.NoError(t, err) + i, err := mc.Get(challengePath) + assert.NoError(t, err) + assert.Equal(t, i.Value, []byte(keyAuth)) + } +} + +func TestMemcachedPresentPartialFailureMultiHost(t *testing.T) { + if len(memcachedHosts) == 0 { + t.Skip("Skipping memcached tests") + } + hosts := append(memcachedHosts, "5.5.5.5:11211") + p, err := NewMemcachedProvider(hosts) + assert.NoError(t, err) + + challengePath := path.Join("/", acme.HTTP01ChallengePath(token)) + + err = p.Present(domain, token, keyAuth) + assert.NoError(t, err) + for _, host := range memcachedHosts { + mc, err := memcache.New(host) + assert.NoError(t, err) + i, err := mc.Get(challengePath) + assert.NoError(t, err) + assert.Equal(t, i.Value, []byte(keyAuth)) + } +} + +func TestMemcachedCleanup(t *testing.T) { + if len(memcachedHosts) == 0 { + t.Skip("Skipping memcached tests") + } + p, err := NewMemcachedProvider(memcachedHosts) + assert.NoError(t, err) + assert.NoError(t, p.CleanUp(domain, token, keyAuth)) +} |