diff options
Diffstat (limited to 'Godeps/_workspace/src/github.com/gorilla/websocket/client.go')
-rw-r--r-- | Godeps/_workspace/src/github.com/gorilla/websocket/client.go | 246 |
1 files changed, 158 insertions, 88 deletions
diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/client.go b/Godeps/_workspace/src/github.com/gorilla/websocket/client.go index 5bc27e193..3bf9b2e84 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/client.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/client.go @@ -5,6 +5,7 @@ package websocket import ( + "bufio" "bytes" "crypto/tls" "errors" @@ -30,50 +31,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake") // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etc. +// +// Deprecated: Use Dialer instead. func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { - challengeKey, err := generateChallengeKey() - if err != nil { - return nil, nil, err + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, } - acceptKey := computeAcceptKey(challengeKey) - - c = newConn(netConn, false, readBufSize, writeBufSize) - p := c.writeBuf[:0] - p = append(p, "GET "...) - p = append(p, u.RequestURI()...) - p = append(p, " HTTP/1.1\r\nHost: "...) - p = append(p, u.Host...) - // "Upgrade" is capitalized for servers that do not use case insensitive - // comparisons on header tokens. - p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...) - p = append(p, challengeKey...) - p = append(p, "\r\n"...) - for k, vs := range requestHeader { - for _, v := range vs { - p = append(p, k...) - p = append(p, ": "...) - p = append(p, v...) - p = append(p, "\r\n"...) - } - } - p = append(p, "\r\n"...) - - if _, err := netConn.Write(p); err != nil { - return nil, nil, err - } - - resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) - if err != nil { - return nil, nil, err - } - if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || - resp.Header.Get("Sec-Websocket-Accept") != acceptKey { - return nil, resp, ErrBadHandshake - } - c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - return c, resp, nil + return d.Dial(u.String(), requestHeader) } // A Dialer contains options for connecting to WebSocket server. @@ -82,6 +50,12 @@ type Dialer struct { // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. TLSClientConfig *tls.Config @@ -99,17 +73,15 @@ type Dialer struct { var errMalformedURL = errors.New("malformed ws or wss URL") -// parseURL parses the URL. The url.Parse function is not used here because -// url.Parse mangles the path. +// parseURL parses the URL. +// +// This function is a replacement for the standard library url.Parse function. +// In Go 1.4 and earlier, url.Parse loses information from the path. func parseURL(s string) (*url.URL, error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - // - // We don't use the net/url parser here because the dialer interface does - // not provide a way for applications to work around percent deocding in - // the net/url parser. var u url.URL switch { @@ -130,6 +102,12 @@ func parseURL(s string) (*url.URL, error) { u.Opaque = s[i:] } + if strings.Contains(u.Host, "@") { + // Don't bother parsing user information because user information is + // not allowed in websocket URIs. + return nil, errMalformedURL + } + return &u, nil } @@ -139,9 +117,12 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { hostNoPort = hostNoPort[:i] } else { - if u.Scheme == "wss" { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": hostPort += ":443" - } else { + default: hostPort += ":80" } } @@ -149,7 +130,9 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { } // DefaultDialer is a dialer with all fields set to the default zero values. -var DefaultDialer *Dialer +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, +} // Dial creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). @@ -161,15 +144,91 @@ var DefaultDialer *Dialer // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + + if d == nil { + d = &Dialer{ + Proxy: http.ProxyFromEnvironment, + } + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + u, err := parseURL(urlStr) if err != nil { return nil, nil, err } + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + default: + req.Header[k] = vs + } + } + hostPort, hostNoPort := hostPortNoPort(u) - if d == nil { - d = &Dialer{} + var proxyURL *url.URL + // Check wether the proxy method has been configured + if d.Proxy != nil { + proxyURL, err = d.Proxy(req) + } + if err != nil { + return nil, nil, err + } + + var targetHostPort string + if proxyURL != nil { + targetHostPort, _ = hostPortNoPort(proxyURL) + } else { + targetHostPort = hostPort } var deadline time.Time @@ -183,7 +242,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netDial = netDialer.Dial } - netConn, err := netDial("tcp", hostPort) + netConn, err := netDial("tcp", targetHostPort) if err != nil { return nil, nil, err } @@ -198,7 +257,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - if u.Scheme == "wss" { + if proxyURL != nil { + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: hostPort}, + Host: hostPort, + Header: make(http.Header), + } + + connectReq.Write(netConn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(netConn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + return nil, nil, errors.New(f[1]) + } + } + + if u.Scheme == "https" { cfg := d.TLSClientConfig if cfg == nil { cfg = &tls.Config{ServerName: hostNoPort} @@ -219,45 +302,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - if len(d.Subprotocols) > 0 { - h := http.Header{} - for k, v := range requestHeader { - h[k] = v - } - h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", ")) - requestHeader = h - } - - if len(requestHeader["Host"]) > 0 { - // This can be used to supply a Host: header which is different from - // the dial address. - u.Host = requestHeader.Get("Host") + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) - // Drop "Host" header - h := http.Header{} - for k, v := range requestHeader { - if k == "Host" { - continue - } - h[k] = v - } - requestHeader = h + if err := req.Write(netConn); err != nil { + return nil, nil, err } - conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) - + resp, err := http.ReadResponse(conn.br, req) if err != nil { - if err == ErrBadHandshake { - // Before closing the network connection on return from this - // function, slurp up some of the response to aid application - // debugging. - buf := make([]byte, 1024) - n, _ := io.ReadFull(resp.Body, buf) - resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) - } - return nil, resp, err + return nil, nil, err + } + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake } + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. return conn, resp, nil |