diff options
Diffstat (limited to 'Godeps/_workspace/src/github.com/garyburd/redigo/redis/conn.go')
-rw-r--r-- | Godeps/_workspace/src/github.com/garyburd/redigo/redis/conn.go | 182 |
1 files changed, 147 insertions, 35 deletions
diff --git a/Godeps/_workspace/src/github.com/garyburd/redigo/redis/conn.go b/Godeps/_workspace/src/github.com/garyburd/redigo/redis/conn.go index ac0e971c4..6a3819f1d 100644 --- a/Godeps/_workspace/src/github.com/garyburd/redigo/redis/conn.go +++ b/Godeps/_workspace/src/github.com/garyburd/redigo/redis/conn.go @@ -21,6 +21,8 @@ import ( "fmt" "io" "net" + "net/url" + "regexp" "strconv" "sync" "time" @@ -51,56 +53,164 @@ type conn struct { numScratch [40]byte } -// Dial connects to the Redis server at the given network and address. -func Dial(network, address string) (Conn, error) { - dialer := xDialer{} - return dialer.Dial(network, address) -} - // DialTimeout acts like Dial but takes timeouts for establishing the // connection to the server, writing a command and reading a reply. +// +// Deprecated: Use Dial with options instead. func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { - netDialer := net.Dialer{Timeout: connectTimeout} - dialer := xDialer{ - NetDial: netDialer.Dial, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - } - return dialer.Dial(network, address) + return Dial(network, address, + DialConnectTimeout(connectTimeout), + DialReadTimeout(readTimeout), + DialWriteTimeout(writeTimeout)) +} + +// DialOption specifies an option for dialing a Redis server. +type DialOption struct { + f func(*dialOptions) +} + +type dialOptions struct { + readTimeout time.Duration + writeTimeout time.Duration + dial func(network, addr string) (net.Conn, error) + db int + password string +} + +// DialReadTimeout specifies the timeout for reading a single command reply. +func DialReadTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + do.readTimeout = d + }} +} + +// DialWriteTimeout specifies the timeout for writing a single command. +func DialWriteTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + do.writeTimeout = d + }} } -// A Dialer specifies options for connecting to a Redis server. -type xDialer struct { - // NetDial specifies the dial function for creating TCP connections. If - // NetDial is nil, then net.Dial is used. - NetDial func(network, addr string) (net.Conn, error) +// DialConnectTimeout specifies the timeout for connecting to the Redis server. +func DialConnectTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + dialer := net.Dialer{Timeout: d} + do.dial = dialer.Dial + }} +} + +// DialNetDial specifies a custom dial function for creating TCP +// connections. If this option is left out, then net.Dial is +// used. DialNetDial overrides DialConnectTimeout. +func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { + return DialOption{func(do *dialOptions) { + do.dial = dial + }} +} - // ReadTimeout specifies the timeout for reading a single command - // reply. If ReadTimeout is zero, then no timeout is used. - ReadTimeout time.Duration +// DialDatabase specifies the database to select when dialing a connection. +func DialDatabase(db int) DialOption { + return DialOption{func(do *dialOptions) { + do.db = db + }} +} - // WriteTimeout specifies the timeout for writing a single command. If - // WriteTimeout is zero, then no timeout is used. - WriteTimeout time.Duration +// DialPassword specifies the password to use when connecting to +// the Redis server. +func DialPassword(password string) DialOption { + return DialOption{func(do *dialOptions) { + do.password = password + }} } -// Dial connects to the Redis server at address on the named network. -func (d *xDialer) Dial(network, address string) (Conn, error) { - dial := d.NetDial - if dial == nil { - dial = net.Dial +// Dial connects to the Redis server at the given network and +// address using the specified options. +func Dial(network, address string, options ...DialOption) (Conn, error) { + do := dialOptions{ + dial: net.Dial, } - netConn, err := dial(network, address) + for _, option := range options { + option.f(&do) + } + + netConn, err := do.dial(network, address) if err != nil { return nil, err } - return &conn{ + c := &conn{ conn: netConn, bw: bufio.NewWriter(netConn), br: bufio.NewReader(netConn), - readTimeout: d.ReadTimeout, - writeTimeout: d.WriteTimeout, - }, nil + readTimeout: do.readTimeout, + writeTimeout: do.writeTimeout, + } + + if do.password != "" { + if _, err := c.Do("AUTH", do.password); err != nil { + netConn.Close() + return nil, err + } + } + + if do.db != 0 { + if _, err := c.Do("SELECT", do.db); err != nil { + netConn.Close() + return nil, err + } + } + + return c, nil +} + +var pathDBRegexp = regexp.MustCompile(`/(\d+)\z`) + +// DialURL connects to a Redis server at the given URL using the Redis +// URI scheme. URLs should follow the draft IANA specification for the +// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis). +func DialURL(rawurl string, options ...DialOption) (Conn, error) { + u, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + + if u.Scheme != "redis" { + return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) + } + + // As per the IANA draft spec, the host defaults to localhost and + // the port defaults to 6379. + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // assume port is missing + host = u.Host + port = "6379" + } + if host == "" { + host = "localhost" + } + address := net.JoinHostPort(host, port) + + if u.User != nil { + password, isSet := u.User.Password() + if isSet { + options = append(options, DialPassword(password)) + } + } + + match := pathDBRegexp.FindStringSubmatch(u.Path) + if len(match) == 2 { + db, err := strconv.Atoi(match[1]) + if err != nil { + return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) + } + if db != 0 { + options = append(options, DialDatabase(db)) + } + } else if u.Path != "" { + return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) + } + + return Dial("tcp", address, options...) } // NewConn returns a new Redigo connection for the given net connection. @@ -417,7 +527,9 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { } if cmd != "" { - c.writeCommand(cmd, args) + if err := c.writeCommand(cmd, args); err != nil { + return nil, c.fatal(err) + } } if err := c.bw.Flush(); err != nil { |