diff options
Diffstat (limited to 'vendor/github.com/go-redis/redis')
24 files changed, 2238 insertions, 1129 deletions
diff --git a/vendor/github.com/go-redis/redis/CHANGELOG.md b/vendor/github.com/go-redis/redis/CHANGELOG.md new file mode 100644 index 000000000..7c40d5e38 --- /dev/null +++ b/vendor/github.com/go-redis/redis/CHANGELOG.md @@ -0,0 +1,21 @@ +# Changelog + +## 6.14 + +- Added Options.MinIdleConns. +- Added Options.MaxConnAge. +- PoolStats.FreeConns is renamed to PoolStats.IdleConns. +- Add Client.Do to simplify creating custom commands. +- Add Cmd.String, Cmd.Int, Cmd.Int64, Cmd.Uint64, Cmd.Float64, and Cmd.Bool helpers. +- Lower memory usage. + +## v6.13 + +- Ring got new options called `HashReplicas` and `Hash`. It is recommended to set `HashReplicas = 1000` for better keys distribution between shards. +- Cluster client was optimized to use much less memory when reloading cluster state. +- PubSub.ReceiveMessage is re-worked to not use ReceiveTimeout so it does not lose data when timeout occurres. In most cases it is recommended to use PubSub.Channel instead. +- Dialer.KeepAlive is set to 5 minutes by default. + +## v6.12 + +- ClusterClient got new option called `ClusterSlots` which allows to build cluster of normal Redis Servers that don't have cluster mode enabled. See https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup diff --git a/vendor/github.com/go-redis/redis/README.md b/vendor/github.com/go-redis/redis/README.md index 9f349764a..7d05b4466 100644 --- a/vendor/github.com/go-redis/redis/README.md +++ b/vendor/github.com/go-redis/redis/README.md @@ -15,6 +15,7 @@ Supports: - [Timeouts](https://godoc.org/github.com/go-redis/redis#Options). - [Redis Sentinel](https://godoc.org/github.com/go-redis/redis#NewFailoverClient). - [Redis Cluster](https://godoc.org/github.com/go-redis/redis#NewClusterClient). +- [Cluster of Redis Servers](https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup) without using cluster mode and Redis Sentinel. - [Ring](https://godoc.org/github.com/go-redis/redis#NewRing). - [Instrumentation](https://godoc.org/github.com/go-redis/redis#ex-package--Instrumentation). - [Cache friendly](https://github.com/go-redis/cache). @@ -86,25 +87,27 @@ Please go through [examples](https://godoc.org/github.com/go-redis/redis#pkg-exa Some corner cases: - SET key value EX 10 NX - set, err := client.SetNX("key", "value", 10*time.Second).Result() +```go +// SET key value EX 10 NX +set, err := client.SetNX("key", "value", 10*time.Second).Result() - SORT list LIMIT 0 2 ASC - vals, err := client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() +// SORT list LIMIT 0 2 ASC +vals, err := client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() - ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 - vals, err := client.ZRangeByScoreWithScores("zset", redis.ZRangeBy{ - Min: "-inf", - Max: "+inf", - Offset: 0, - Count: 2, - }).Result() +// ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 +vals, err := client.ZRangeByScoreWithScores("zset", redis.ZRangeBy{ + Min: "-inf", + Max: "+inf", + Offset: 0, + Count: 2, +}).Result() - ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM - vals, err := client.ZInterStore("out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2").Result() +// ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM +vals, err := client.ZInterStore("out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2").Result() - EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" - vals, err := client.Eval("return {KEYS[1],ARGV[1]}", []string{"key"}, "hello").Result() +// EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" +vals, err := client.Eval("return {KEYS[1],ARGV[1]}", []string{"key"}, "hello").Result() +``` ## Benchmark diff --git a/vendor/github.com/go-redis/redis/cluster.go b/vendor/github.com/go-redis/redis/cluster.go index 0c58c8532..6b05f1a56 100644 --- a/vendor/github.com/go-redis/redis/cluster.go +++ b/vendor/github.com/go-redis/redis/cluster.go @@ -8,7 +8,8 @@ import ( "math" "math/rand" "net" - "strings" + "runtime" + "sort" "sync" "sync/atomic" "time" @@ -30,7 +31,7 @@ type ClusterOptions struct { // The maximum number of retries before giving up. Command is retried // on network errors and MOVED/ASK redirects. - // Default is 8. + // Default is 8 retries. MaxRedirects int // Enables read-only commands on slave nodes. @@ -39,16 +40,25 @@ type ClusterOptions struct { // It automatically enables ReadOnly. RouteByLatency bool // Allows routing read-only commands to the random master or slave node. + // It automatically enables ReadOnly. RouteRandomly bool + // Optional function that returns cluster slots information. + // It is useful to manually create cluster of standalone Redis servers + // and load-balance read/write operations between master and slaves. + // It can use service like ZooKeeper to maintain configuration information + // and Cluster.ReloadState to manually trigger state reloading. + ClusterSlots func() ([]ClusterSlot, error) + // Following options are copied from Options struct. OnConnect func(*Conn) error + Password string + MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration - Password string DialTimeout time.Duration ReadTimeout time.Duration @@ -56,6 +66,8 @@ type ClusterOptions struct { // PoolSize applies per cluster node and not for the whole cluster. PoolSize int + MinIdleConns int + MaxConnAge time.Duration PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration @@ -70,10 +82,14 @@ func (opt *ClusterOptions) init() { opt.MaxRedirects = 8 } - if opt.RouteByLatency { + if opt.RouteByLatency || opt.RouteRandomly { opt.ReadOnly = true } + if opt.PoolSize == 0 { + opt.PoolSize = 5 * runtime.NumCPU() + } + switch opt.ReadTimeout { case -1: opt.ReadTimeout = 0 @@ -117,10 +133,11 @@ func (opt *ClusterOptions) clientOptions() *Options { ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - IdleTimeout: opt.IdleTimeout, - + PoolSize: opt.PoolSize, + MinIdleConns: opt.MinIdleConns, + MaxConnAge: opt.MaxConnAge, + PoolTimeout: opt.PoolTimeout, + IdleTimeout: opt.IdleTimeout, IdleCheckFrequency: disableIdleCheck, TLSConfig: opt.TLSConfig, @@ -160,10 +177,6 @@ func (n *clusterNode) Close() error { return n.Client.Close() } -func (n *clusterNode) Test() error { - return n.Client.ClusterInfo().Err() -} - func (n *clusterNode) updateLatency() { const probes = 10 @@ -330,7 +343,7 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { v, err := c.nodeCreateGroup.Do(addr, func() (interface{}, error) { node := newClusterNode(c.opt, addr) - return node, node.Test() + return node, nil }) c.mu.Lock() @@ -383,12 +396,31 @@ func (c *clusterNodes) Random() (*clusterNode, error) { //------------------------------------------------------------------------------ +type clusterSlot struct { + start, end int + nodes []*clusterNode +} + +type clusterSlotSlice []*clusterSlot + +func (p clusterSlotSlice) Len() int { + return len(p) +} + +func (p clusterSlotSlice) Less(i, j int) bool { + return p[i].start < p[j].start +} + +func (p clusterSlotSlice) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + type clusterState struct { nodes *clusterNodes Masters []*clusterNode Slaves []*clusterNode - slots [][]*clusterNode + slots []*clusterSlot generation uint32 createdAt time.Time @@ -400,19 +432,21 @@ func newClusterState( c := clusterState{ nodes: nodes, - slots: make([][]*clusterNode, hashtag.SlotNumber), + slots: make([]*clusterSlot, 0, len(slots)), generation: nodes.NextGeneration(), createdAt: time.Now(), } - isLoopbackOrigin := isLoopbackAddr(origin) + originHost, _, _ := net.SplitHostPort(origin) + isLoopbackOrigin := isLoopback(originHost) + for _, slot := range slots { var nodes []*clusterNode for i, slotNode := range slot.Nodes { addr := slotNode.Addr - if !isLoopbackOrigin && useOriginAddr(origin, addr) { - addr = origin + if !isLoopbackOrigin { + addr = replaceLoopbackHost(addr, originHost) } node, err := c.nodes.GetOrCreate(addr) @@ -430,11 +464,15 @@ func newClusterState( } } - for i := slot.Start; i <= slot.End; i++ { - c.slots[i] = nodes - } + c.slots = append(c.slots, &clusterSlot{ + start: slot.Start, + end: slot.End, + nodes: nodes, + }) } + sort.Sort(clusterSlotSlice(c.slots)) + time.AfterFunc(time.Minute, func() { nodes.GC(c.generation) }) @@ -442,6 +480,33 @@ func newClusterState( return &c, nil } +func replaceLoopbackHost(nodeAddr, originHost string) string { + nodeHost, nodePort, err := net.SplitHostPort(nodeAddr) + if err != nil { + return nodeAddr + } + + nodeIP := net.ParseIP(nodeHost) + if nodeIP == nil { + return nodeAddr + } + + if !nodeIP.IsLoopback() { + return nodeAddr + } + + // Use origin host which is not loopback and node port. + return net.JoinHostPort(originHost, nodePort) +} + +func isLoopback(host string) bool { + ip := net.ParseIP(host) + if ip == nil { + return true + } + return ip.IsLoopback() +} + func (c *clusterState) slotMasterNode(slot int) (*clusterNode, error) { nodes := c.slotNodes(slot) if len(nodes) > 0 { @@ -502,32 +567,24 @@ func (c *clusterState) slotRandomNode(slot int) *clusterNode { } func (c *clusterState) slotNodes(slot int) []*clusterNode { - if slot >= 0 && slot < len(c.slots) { - return c.slots[slot] + i := sort.Search(len(c.slots), func(i int) bool { + return c.slots[i].end >= slot + }) + if i >= len(c.slots) { + return nil + } + x := c.slots[i] + if slot >= x.start && slot <= x.end { + return x.nodes } return nil } func (c *clusterState) IsConsistent() bool { - if len(c.Masters) > len(c.Slaves) { - return false - } - - for _, master := range c.Masters { - s := master.Client.Info("replication").Val() - if !strings.Contains(s, "role:master") { - return false - } - } - - for _, slave := range c.Slaves { - s := slave.Client.Info("replication").Val() - if !strings.Contains(s, "role:slave") { - return false - } + if c.nodes.opt.ClusterSlots != nil { + return true } - - return true + return len(c.Masters) <= len(c.Slaves) } //------------------------------------------------------------------------------ @@ -555,7 +612,7 @@ func (c *clusterStateHolder) Reload() (*clusterState, error) { return nil, err } if !state.IsConsistent() { - c.LazyReload() + time.AfterFunc(time.Second, c.LazyReload) } return state, nil } @@ -614,6 +671,14 @@ func (c *clusterStateHolder) Get() (*clusterState, error) { return nil, errors.New("redis: cluster has no state") } +func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) { + state, err := c.Reload() + if err == nil { + return state, nil + } + return c.Get() +} + //------------------------------------------------------------------------------ // ClusterClient is a Redis Cluster client representing a pool of zero @@ -653,6 +718,8 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.init() _, _ = c.state.Reload() + _, _ = c.cmdsInfoCache.Get() + if opt.IdleCheckFrequency > 0 { go c.reaper(opt.IdleCheckFrequency) } @@ -660,6 +727,13 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { return c } +// ReloadState reloads cluster state. It calls ClusterSlots func +// to get cluster slots information. +func (c *ClusterClient) ReloadState() error { + _, err := c.state.Reload() + return err +} + func (c *ClusterClient) init() { c.cmdable.setProcessor(c.Process) } @@ -818,6 +892,7 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { } if internal.IsRetryableError(err, true) { + c.state.LazyReload() continue } @@ -853,6 +928,13 @@ func (c *ClusterClient) Close() error { return c.nodes.Close() } +// Do creates a Cmd from the args and processes the cmd. +func (c *ClusterClient) Do(args ...interface{}) *Cmd { + cmd := NewCmd(args...) + c.Process(cmd) + return cmd +} + func (c *ClusterClient) WrapProcess( fn func(oldProcess func(Cmder) error) func(Cmder) error, ) { @@ -904,12 +986,14 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error { } if internal.IsRetryableError(err, true) { - // Firstly retry the same node. + c.state.LazyReload() + + // First retry the same node. if attempt == 0 { continue } - // Secondly try random node. + // Second try random node. node, err = c.nodes.Random() if err != nil { break @@ -944,12 +1028,9 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error { // ForEachMaster concurrently calls the fn on each master node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { - state, err := c.state.Reload() + state, err := c.state.ReloadOrGet() if err != nil { - state, err = c.state.Get() - if err != nil { - return err - } + return err } var wg sync.WaitGroup @@ -980,12 +1061,9 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { // ForEachSlave concurrently calls the fn on each slave node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { - state, err := c.state.Reload() + state, err := c.state.ReloadOrGet() if err != nil { - state, err = c.state.Get() - if err != nil { - return err - } + return err } var wg sync.WaitGroup @@ -1016,12 +1094,9 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { // ForEachNode concurrently calls the fn on each known node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { - state, err := c.state.Reload() + state, err := c.state.ReloadOrGet() if err != nil { - state, err = c.state.Get() - if err != nil { - return err - } + return err } var wg sync.WaitGroup @@ -1071,7 +1146,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.Timeouts += s.Timeouts acc.TotalConns += s.TotalConns - acc.FreeConns += s.FreeConns + acc.IdleConns += s.IdleConns acc.StaleConns += s.StaleConns } @@ -1082,7 +1157,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.Timeouts += s.Timeouts acc.TotalConns += s.TotalConns - acc.FreeConns += s.FreeConns + acc.IdleConns += s.IdleConns acc.StaleConns += s.StaleConns } @@ -1090,6 +1165,14 @@ func (c *ClusterClient) PoolStats() *PoolStats { } func (c *ClusterClient) loadState() (*clusterState, error) { + if c.opt.ClusterSlots != nil { + slots, err := c.opt.ClusterSlots() + if err != nil { + return nil, err + } + return newClusterState(c.nodes, slots, "") + } + addrs, err := c.nodes.Addrs() if err != nil { return nil, err @@ -1196,7 +1279,7 @@ func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error { cmdsMap = failedCmds } - return firstCmdsErr(cmds) + return cmdsFirstErr(cmds) } func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, error) { @@ -1207,9 +1290,16 @@ func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, e } cmdsMap := make(map[*clusterNode][]Cmder) + cmdsAreReadOnly := c.cmdsAreReadOnly(cmds) for _, cmd := range cmds { - slot := c.cmdSlot(cmd) - node, err := state.slotMasterNode(slot) + var node *clusterNode + var err error + if cmdsAreReadOnly { + _, node, err = c.cmdSlotAndNode(cmd) + } else { + slot := c.cmdSlot(cmd) + node, err = state.slotMasterNode(slot) + } if err != nil { return nil, err } @@ -1218,6 +1308,16 @@ func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, e return cmdsMap, nil } +func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { + for _, cmd := range cmds { + cmdInfo := c.cmdInfo(cmd.Name()) + if cmdInfo == nil || !cmdInfo.ReadOnly { + return false + } + } + return true +} + func (c *ClusterClient) remapCmds(cmds []Cmder, failedCmds map[*clusterNode][]Cmder) { remappedCmds, err := c.mapCmdsByNode(cmds) if err != nil { @@ -1233,26 +1333,26 @@ func (c *ClusterClient) remapCmds(cmds []Cmder, failedCmds map[*clusterNode][]Cm func (c *ClusterClient) pipelineProcessCmds( node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { - _ = cn.SetWriteTimeout(c.opt.WriteTimeout) - - err := writeCmd(cn, cmds...) + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmds...) + }) if err != nil { setCmdsErr(cmds, err) failedCmds[node] = cmds return err } - // Set read timeout for all commands. - _ = cn.SetReadTimeout(c.opt.ReadTimeout) - - return c.pipelineReadCmds(cn, cmds, failedCmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(rd, cmds, failedCmds) + }) + return err } func (c *ClusterClient) pipelineReadCmds( - cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, + rd *proto.Reader, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { for _, cmd := range cmds { - err := cmd.readReply(cn) + err := cmd.readReply(rd) if err == nil { continue } @@ -1361,7 +1461,7 @@ func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error { } } - return firstCmdsErr(cmds) + return cmdsFirstErr(cmds) } func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { @@ -1376,35 +1476,37 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { func (c *ClusterClient) txPipelineProcessCmds( node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := txPipelineWriteMulti(cn, cmds); err != nil { + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) + }) + if err != nil { setCmdsErr(cmds, err) failedCmds[node] = cmds return err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - if err := c.txPipelineReadQueued(cn, cmds, failedCmds); err != nil { - setCmdsErr(cmds, err) - return err - } - - return pipelineReadCmds(cn, cmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + err := c.txPipelineReadQueued(rd, cmds, failedCmds) + if err != nil { + setCmdsErr(cmds, err) + return err + } + return pipelineReadCmds(rd, cmds) + }) + return err } func (c *ClusterClient) txPipelineReadQueued( - cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, + rd *proto.Reader, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { // Parse queued replies. var statusCmd StatusCmd - if err := statusCmd.readReply(cn); err != nil { + if err := statusCmd.readReply(rd); err != nil { return err } for _, cmd := range cmds { - err := statusCmd.readReply(cn) + err := statusCmd.readReply(rd) if err == nil { continue } @@ -1417,7 +1519,7 @@ func (c *ClusterClient) txPipelineReadQueued( } // Parse number of replies. - line, err := cn.Rd.ReadLine() + line, err := rd.ReadLine() if err != nil { if err == Nil { err = TxFailedErr @@ -1445,11 +1547,9 @@ func (c *ClusterClient) txPipelineReadQueued( } func (c *ClusterClient) pubSub(channels []string) *PubSub { - opt := c.opt.clientOptions() - var node *clusterNode - return &PubSub{ - opt: opt, + pubsub := &PubSub{ + opt: c.opt.clientOptions(), newConn: func(channels []string) (*pool.Conn, error) { if node == nil { @@ -1472,6 +1572,8 @@ func (c *ClusterClient) pubSub(channels []string) *PubSub { return node.Client.connPool.CloseConn(cn) }, } + pubsub.init() + return pubsub } // Subscribe subscribes the client to the specified channels. @@ -1494,43 +1596,6 @@ func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { return pubsub } -func useOriginAddr(originAddr, nodeAddr string) bool { - nodeHost, nodePort, err := net.SplitHostPort(nodeAddr) - if err != nil { - return false - } - - nodeIP := net.ParseIP(nodeHost) - if nodeIP == nil { - return false - } - - if !nodeIP.IsLoopback() { - return false - } - - _, originPort, err := net.SplitHostPort(originAddr) - if err != nil { - return false - } - - return nodePort == originPort -} - -func isLoopbackAddr(addr string) bool { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return false - } - - ip := net.ParseIP(host) - if ip == nil { - return false - } - - return ip.IsLoopback() -} - func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { for _, n := range nodes { if n == node { diff --git a/vendor/github.com/go-redis/redis/command.go b/vendor/github.com/go-redis/redis/command.go index 552c897bb..ca44d7c8b 100644 --- a/vendor/github.com/go-redis/redis/command.go +++ b/vendor/github.com/go-redis/redis/command.go @@ -1,16 +1,14 @@ package redis import ( - "bytes" "fmt" + "net" "strconv" "strings" "time" "github.com/go-redis/redis/internal" - "github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/proto" - "github.com/go-redis/redis/internal/util" ) type Cmder interface { @@ -18,13 +16,12 @@ type Cmder interface { Args() []interface{} stringArg(int) string - readReply(*pool.Conn) error + readReply(rd *proto.Reader) error setErr(error) readTimeout() *time.Duration Err() error - fmt.Stringer } func setCmdsErr(cmds []Cmder, e error) { @@ -35,7 +32,7 @@ func setCmdsErr(cmds []Cmder, e error) { } } -func firstCmdsErr(cmds []Cmder) error { +func cmdsFirstErr(cmds []Cmder) error { for _, cmd := range cmds { if err := cmd.Err(); err != nil { return err @@ -44,16 +41,14 @@ func firstCmdsErr(cmds []Cmder) error { return nil } -func writeCmd(cn *pool.Conn, cmds ...Cmder) error { - cn.Wb.Reset() +func writeCmd(wr *proto.Writer, cmds ...Cmder) error { for _, cmd := range cmds { - if err := cn.Wb.Append(cmd.Args()); err != nil { + err := wr.WriteArgs(cmd.Args()) + if err != nil { return err } } - - _, err := cn.Write(cn.Wb.Bytes()) - return err + return nil } func cmdString(cmd Cmder, val interface{}) string { @@ -165,20 +160,124 @@ func (cmd *Cmd) Result() (interface{}, error) { return cmd.val, cmd.err } -func (cmd *Cmd) String() string { - return cmdString(cmd, cmd.val) +func (cmd *Cmd) String() (string, error) { + if cmd.err != nil { + return "", cmd.err + } + switch val := cmd.val.(type) { + case string: + return val, nil + default: + err := fmt.Errorf("redis: unexpected type=%T for String", val) + return "", err + } } -func (cmd *Cmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = cn.Rd.ReadReply(sliceParser) +func (cmd *Cmd) Int() (int, error) { if cmd.err != nil { - return cmd.err + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return int(val), nil + case string: + return strconv.Atoi(val) + default: + err := fmt.Errorf("redis: unexpected type=%T for Int64", val) + return 0, err } - if b, ok := cmd.val.([]byte); ok { - // Bytes must be copied, because underlying memory is reused. - cmd.val = string(b) +} + +func (cmd *Cmd) Int64() (int64, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return val, nil + case string: + return strconv.ParseInt(val, 10, 64) + default: + err := fmt.Errorf("redis: unexpected type=%T for Int64", val) + return 0, err } - return nil +} + +func (cmd *Cmd) Uint64() (uint64, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return uint64(val), nil + case string: + return strconv.ParseUint(val, 10, 64) + default: + err := fmt.Errorf("redis: unexpected type=%T for Uint64", val) + return 0, err + } +} + +func (cmd *Cmd) Float64() (float64, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return float64(val), nil + case string: + return strconv.ParseFloat(val, 64) + default: + err := fmt.Errorf("redis: unexpected type=%T for Float64", val) + return 0, err + } +} + +func (cmd *Cmd) Bool() (bool, error) { + if cmd.err != nil { + return false, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return val != 0, nil + case string: + return strconv.ParseBool(val) + default: + err := fmt.Errorf("redis: unexpected type=%T for Bool", val) + return false, err + } +} + +func (cmd *Cmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadReply(sliceParser) + return cmd.err +} + +// Implements proto.MultiBulkParse +func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { + vals := make([]interface{}, 0, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(sliceParser) + if err != nil { + if err == Nil { + vals = append(vals, nil) + continue + } + if err, ok := err.(proto.RedisError); ok { + vals = append(vals, err) + continue + } + return nil, err + } + + switch v := v.(type) { + case string: + vals = append(vals, v) + default: + vals = append(vals, v) + } + } + return vals, nil } //------------------------------------------------------------------------------ @@ -209,9 +308,9 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *SliceCmd) readReply(cn *pool.Conn) error { +func (cmd *SliceCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(sliceParser) + v, cmd.err = rd.ReadArrayReply(sliceParser) if cmd.err != nil { return cmd.err } @@ -247,8 +346,8 @@ func (cmd *StatusCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StatusCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = cn.Rd.ReadStringReply() +func (cmd *StatusCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadString() return cmd.err } @@ -280,8 +379,8 @@ func (cmd *IntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *IntCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = cn.Rd.ReadIntReply() +func (cmd *IntCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadIntReply() return cmd.err } @@ -315,9 +414,9 @@ func (cmd *DurationCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *DurationCmd) readReply(cn *pool.Conn) error { +func (cmd *DurationCmd) readReply(rd *proto.Reader) error { var n int64 - n, cmd.err = cn.Rd.ReadIntReply() + n, cmd.err = rd.ReadIntReply() if cmd.err != nil { return cmd.err } @@ -353,9 +452,9 @@ func (cmd *TimeCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *TimeCmd) readReply(cn *pool.Conn) error { +func (cmd *TimeCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(timeParser) + v, cmd.err = rd.ReadArrayReply(timeParser) if cmd.err != nil { return cmd.err } @@ -363,6 +462,25 @@ func (cmd *TimeCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func timeParser(rd *proto.Reader, n int64) (interface{}, error) { + if n != 2 { + return nil, fmt.Errorf("got %d elements, expected 2", n) + } + + sec, err := rd.ReadInt() + if err != nil { + return nil, err + } + + microsec, err := rd.ReadInt() + if err != nil { + return nil, err + } + + return time.Unix(sec, microsec*1000), nil +} + //------------------------------------------------------------------------------ type BoolCmd struct { @@ -391,11 +509,9 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } -var ok = []byte("OK") - -func (cmd *BoolCmd) readReply(cn *pool.Conn) error { +func (cmd *BoolCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadReply(nil) + v, cmd.err = rd.ReadReply(nil) // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. // TODO: is this okay? @@ -411,8 +527,8 @@ func (cmd *BoolCmd) readReply(cn *pool.Conn) error { case int64: cmd.val = v == 1 return nil - case []byte: - cmd.val = bytes.Equal(v, ok) + case string: + cmd.val = v == "OK" return nil default: cmd.err = fmt.Errorf("got %T, wanted int64 or string", v) @@ -425,7 +541,7 @@ func (cmd *BoolCmd) readReply(cn *pool.Conn) error { type StringCmd struct { baseCmd - val []byte + val string } var _ Cmder = (*StringCmd)(nil) @@ -437,7 +553,7 @@ func NewStringCmd(args ...interface{}) *StringCmd { } func (cmd *StringCmd) Val() string { - return util.BytesToString(cmd.val) + return cmd.val } func (cmd *StringCmd) Result() (string, error) { @@ -445,7 +561,14 @@ func (cmd *StringCmd) Result() (string, error) { } func (cmd *StringCmd) Bytes() ([]byte, error) { - return cmd.val, cmd.err + return []byte(cmd.val), cmd.err +} + +func (cmd *StringCmd) Int() (int, error) { + if cmd.err != nil { + return 0, cmd.err + } + return strconv.Atoi(cmd.Val()) } func (cmd *StringCmd) Int64() (int64, error) { @@ -473,15 +596,15 @@ func (cmd *StringCmd) Scan(val interface{}) error { if cmd.err != nil { return cmd.err } - return proto.Scan(cmd.val, val) + return proto.Scan([]byte(cmd.val), val) } func (cmd *StringCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = cn.Rd.ReadBytesReply() +func (cmd *StringCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadString() return cmd.err } @@ -513,8 +636,8 @@ func (cmd *FloatCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *FloatCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = cn.Rd.ReadFloatReply() +func (cmd *FloatCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadFloatReply() return cmd.err } @@ -550,9 +673,9 @@ func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { return proto.ScanSlice(cmd.Val(), container) } -func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error { +func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(stringSliceParser) + v, cmd.err = rd.ReadArrayReply(stringSliceParser) if cmd.err != nil { return cmd.err } @@ -560,6 +683,22 @@ func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func stringSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + ss := make([]string, 0, n) + for i := int64(0); i < n; i++ { + s, err := rd.ReadString() + if err == Nil { + ss = append(ss, "") + } else if err != nil { + return nil, err + } else { + ss = append(ss, s) + } + } + return ss, nil +} + //------------------------------------------------------------------------------ type BoolSliceCmd struct { @@ -588,9 +727,9 @@ func (cmd *BoolSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error { +func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(boolSliceParser) + v, cmd.err = rd.ReadArrayReply(boolSliceParser) if cmd.err != nil { return cmd.err } @@ -598,6 +737,19 @@ func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func boolSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + bools := make([]bool, 0, n) + for i := int64(0); i < n; i++ { + n, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + bools = append(bools, n == 1) + } + return bools, nil +} + //------------------------------------------------------------------------------ type StringStringMapCmd struct { @@ -626,9 +778,9 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error { +func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(stringStringMapParser) + v, cmd.err = rd.ReadArrayReply(stringStringMapParser) if cmd.err != nil { return cmd.err } @@ -636,6 +788,25 @@ func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func stringStringMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]string, n/2) + for i := int64(0); i < n; i += 2 { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + value, err := rd.ReadString() + if err != nil { + return nil, err + } + + m[key] = value + } + return m, nil +} + //------------------------------------------------------------------------------ type StringIntMapCmd struct { @@ -664,9 +835,9 @@ func (cmd *StringIntMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error { +func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(stringIntMapParser) + v, cmd.err = rd.ReadArrayReply(stringIntMapParser) if cmd.err != nil { return cmd.err } @@ -674,6 +845,25 @@ func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func stringIntMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]int64, n/2) + for i := int64(0); i < n; i += 2 { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + n, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + m[key] = n + } + return m, nil +} + //------------------------------------------------------------------------------ type StringStructMapCmd struct { @@ -702,9 +892,9 @@ func (cmd *StringStructMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStructMapCmd) readReply(cn *pool.Conn) error { +func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(stringStructMapParser) + v, cmd.err = rd.ReadArrayReply(stringStructMapParser) if cmd.err != nil { return cmd.err } @@ -712,6 +902,380 @@ func (cmd *StringStructMapCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func stringStructMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]struct{}, n) + for i := int64(0); i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + m[key] = struct{}{} + } + return m, nil +} + +//------------------------------------------------------------------------------ + +type XMessage struct { + ID string + Values map[string]interface{} +} + +type XMessageSliceCmd struct { + baseCmd + + val []XMessage +} + +var _ Cmder = (*XMessageSliceCmd)(nil) + +func NewXMessageSliceCmd(args ...interface{}) *XMessageSliceCmd { + return &XMessageSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XMessageSliceCmd) Val() []XMessage { + return cmd.val +} + +func (cmd *XMessageSliceCmd) Result() ([]XMessage, error) { + return cmd.val, cmd.err +} + +func (cmd *XMessageSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(xMessageSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]XMessage) + return nil +} + +// Implements proto.MultiBulkParse +func xMessageSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + msgs := make([]XMessage, 0, n) + for i := int64(0); i < n; i++ { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + id, err := rd.ReadString() + if err != nil { + return nil, err + } + + v, err := rd.ReadArrayReply(stringInterfaceMapParser) + if err != nil { + return nil, err + } + + msgs = append(msgs, XMessage{ + ID: id, + Values: v.(map[string]interface{}), + }) + return nil, nil + }) + if err != nil { + return nil, err + } + } + return msgs, nil +} + +// Implements proto.MultiBulkParse +func stringInterfaceMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]interface{}, n/2) + for i := int64(0); i < n; i += 2 { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + value, err := rd.ReadString() + if err != nil { + return nil, err + } + + m[key] = value + } + return m, nil +} + +//------------------------------------------------------------------------------ + +type XStream struct { + Stream string + Messages []XMessage +} + +type XStreamSliceCmd struct { + baseCmd + + val []XStream +} + +var _ Cmder = (*XStreamSliceCmd)(nil) + +func NewXStreamSliceCmd(args ...interface{}) *XStreamSliceCmd { + return &XStreamSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XStreamSliceCmd) Val() []XStream { + return cmd.val +} + +func (cmd *XStreamSliceCmd) Result() ([]XStream, error) { + return cmd.val, cmd.err +} + +func (cmd *XStreamSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(xStreamSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]XStream) + return nil +} + +// Implements proto.MultiBulkParse +func xStreamSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + ret := make([]XStream, 0, n) + for i := int64(0); i < n; i++ { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + if n != 2 { + return nil, fmt.Errorf("got %d, wanted 2", n) + } + + stream, err := rd.ReadString() + if err != nil { + return nil, err + } + + v, err := rd.ReadArrayReply(xMessageSliceParser) + if err != nil { + return nil, err + } + + ret = append(ret, XStream{ + Stream: stream, + Messages: v.([]XMessage), + }) + return nil, nil + }) + if err != nil { + return nil, err + } + } + return ret, nil +} + +//------------------------------------------------------------------------------ + +type XPending struct { + Count int64 + Lower string + Higher string + Consumers map[string]int64 +} + +type XPendingCmd struct { + baseCmd + val *XPending +} + +var _ Cmder = (*XPendingCmd)(nil) + +func NewXPendingCmd(args ...interface{}) *XPendingCmd { + return &XPendingCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XPendingCmd) Val() *XPending { + return cmd.val +} + +func (cmd *XPendingCmd) Result() (*XPending, error) { + return cmd.val, cmd.err +} + +func (cmd *XPendingCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { + var info interface{} + info, cmd.err = rd.ReadArrayReply(xPendingParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = info.(*XPending) + return nil +} + +func xPendingParser(rd *proto.Reader, n int64) (interface{}, error) { + if n != 4 { + return nil, fmt.Errorf("got %d, wanted 4", n) + } + + count, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + lower, err := rd.ReadString() + if err != nil && err != Nil { + return nil, err + } + + higher, err := rd.ReadString() + if err != nil && err != Nil { + return nil, err + } + + pending := &XPending{ + Count: count, + Lower: lower, + Higher: higher, + } + _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + for i := int64(0); i < n; i++ { + _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + if n != 2 { + return nil, fmt.Errorf("got %d, wanted 2", n) + } + + consumerName, err := rd.ReadString() + if err != nil { + return nil, err + } + + consumerPending, err := rd.ReadInt() + if err != nil { + return nil, err + } + + if pending.Consumers == nil { + pending.Consumers = make(map[string]int64) + } + pending.Consumers[consumerName] = consumerPending + + return nil, nil + }) + if err != nil { + return nil, err + } + } + return nil, nil + }) + if err != nil && err != Nil { + return nil, err + } + + return pending, nil +} + +//------------------------------------------------------------------------------ + +type XPendingExt struct { + Id string + Consumer string + Idle time.Duration + RetryCount int64 +} + +type XPendingExtCmd struct { + baseCmd + val []XPendingExt +} + +var _ Cmder = (*XPendingExtCmd)(nil) + +func NewXPendingExtCmd(args ...interface{}) *XPendingExtCmd { + return &XPendingExtCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XPendingExtCmd) Val() []XPendingExt { + return cmd.val +} + +func (cmd *XPendingExtCmd) Result() ([]XPendingExt, error) { + return cmd.val, cmd.err +} + +func (cmd *XPendingExtCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { + var info interface{} + info, cmd.err = rd.ReadArrayReply(xPendingExtSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = info.([]XPendingExt) + return nil +} + +func xPendingExtSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + ret := make([]XPendingExt, 0, n) + for i := int64(0); i < n; i++ { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + if n != 4 { + return nil, fmt.Errorf("got %d, wanted 4", n) + } + + id, err := rd.ReadString() + if err != nil { + return nil, err + } + + consumer, err := rd.ReadString() + if err != nil && err != Nil { + return nil, err + } + + idle, err := rd.ReadIntReply() + if err != nil && err != Nil { + return nil, err + } + + retryCount, err := rd.ReadIntReply() + if err != nil && err != Nil { + return nil, err + } + + ret = append(ret, XPendingExt{ + Id: id, + Consumer: consumer, + Idle: time.Duration(idle) * time.Millisecond, + RetryCount: retryCount, + }) + return nil, nil + }) + if err != nil { + return nil, err + } + } + return ret, nil +} + +//------------------------------------------------------------------------------ + //------------------------------------------------------------------------------ type ZSliceCmd struct { @@ -740,9 +1304,9 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error { +func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(zSliceParser) + v, cmd.err = rd.ReadArrayReply(zSliceParser) if cmd.err != nil { return cmd.err } @@ -750,6 +1314,27 @@ func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func zSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + zz := make([]Z, n/2) + for i := int64(0); i < n; i += 2 { + var err error + + z := &zz[i/2] + + z.Member, err = rd.ReadString() + if err != nil { + return nil, err + } + + z.Score, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + } + return zz, nil +} + //------------------------------------------------------------------------------ type ScanCmd struct { @@ -782,8 +1367,8 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.page) } -func (cmd *ScanCmd) readReply(cn *pool.Conn) error { - cmd.page, cmd.cursor, cmd.err = cn.Rd.ReadScanReply() +func (cmd *ScanCmd) readReply(rd *proto.Reader) error { + cmd.page, cmd.cursor, cmd.err = rd.ReadScanReply() return cmd.err } @@ -833,9 +1418,9 @@ func (cmd *ClusterSlotsCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error { +func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(clusterSlotsParser) + v, cmd.err = rd.ReadArrayReply(clusterSlotsParser) if cmd.err != nil { return cmd.err } @@ -843,6 +1428,70 @@ func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func clusterSlotsParser(rd *proto.Reader, n int64) (interface{}, error) { + slots := make([]ClusterSlot, n) + for i := 0; i < len(slots); i++ { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n < 2 { + err := fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) + return nil, err + } + + start, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + end, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + nodes := make([]ClusterNode, n-2) + for j := 0; j < len(nodes); j++ { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n != 2 && n != 3 { + err := fmt.Errorf("got %d elements in cluster info address, expected 2 or 3", n) + return nil, err + } + + ip, err := rd.ReadString() + if err != nil { + return nil, err + } + + port, err := rd.ReadString() + if err != nil { + return nil, err + } + + nodes[j].Addr = net.JoinHostPort(ip, port) + + if n == 3 { + id, err := rd.ReadString() + if err != nil { + return nil, err + } + nodes[j].Id = id + } + } + + slots[i] = ClusterSlot{ + Start: int(start), + End: int(end), + Nodes: nodes, + } + } + return slots, nil +} + //------------------------------------------------------------------------------ // GeoLocation is used with GeoAdd to add geospatial location. @@ -924,9 +1573,9 @@ func (cmd *GeoLocationCmd) String() string { return cmdString(cmd, cmd.locations) } -func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error { +func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) + v, cmd.err = rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) if cmd.err != nil { return cmd.err } @@ -934,6 +1583,73 @@ func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error { return nil } +func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { + return func(rd *proto.Reader, n int64) (interface{}, error) { + var loc GeoLocation + var err error + + loc.Name, err = rd.ReadString() + if err != nil { + return nil, err + } + if q.WithDist { + loc.Dist, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + } + if q.WithGeoHash { + loc.GeoHash, err = rd.ReadIntReply() + if err != nil { + return nil, err + } + } + if q.WithCoord { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n != 2 { + return nil, fmt.Errorf("got %d coordinates, expected 2", n) + } + + loc.Longitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + loc.Latitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + } + + return &loc, nil + } +} + +func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { + return func(rd *proto.Reader, n int64) (interface{}, error) { + locs := make([]GeoLocation, 0, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(newGeoLocationParser(q)) + if err != nil { + return nil, err + } + switch vv := v.(type) { + case string: + locs = append(locs, GeoLocation{ + Name: vv, + }) + case *GeoLocation: + locs = append(locs, *vv) + default: + return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) + } + } + return locs, nil + } +} + //------------------------------------------------------------------------------ type GeoPos struct { @@ -966,9 +1682,9 @@ func (cmd *GeoPosCmd) String() string { return cmdString(cmd, cmd.positions) } -func (cmd *GeoPosCmd) readReply(cn *pool.Conn) error { +func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(geoPosSliceParser) + v, cmd.err = rd.ReadArrayReply(geoPosSliceParser) if cmd.err != nil { return cmd.err } @@ -976,6 +1692,44 @@ func (cmd *GeoPosCmd) readReply(cn *pool.Conn) error { return nil } +func geoPosSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + positions := make([]*GeoPos, 0, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(geoPosParser) + if err != nil { + if err == Nil { + positions = append(positions, nil) + continue + } + return nil, err + } + switch v := v.(type) { + case *GeoPos: + positions = append(positions, v) + default: + return nil, fmt.Errorf("got %T, expected *GeoPos", v) + } + } + return positions, nil +} + +func geoPosParser(rd *proto.Reader, n int64) (interface{}, error) { + var pos GeoPos + var err error + + pos.Longitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + + pos.Latitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + + return &pos, nil +} + //------------------------------------------------------------------------------ type CommandInfo struct { @@ -1014,9 +1768,9 @@ func (cmd *CommandsInfoCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error { +func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { var v interface{} - v, cmd.err = cn.Rd.ReadArrayReply(commandInfoSliceParser) + v, cmd.err = rd.ReadArrayReply(commandInfoSliceParser) if cmd.err != nil { return cmd.err } @@ -1024,6 +1778,74 @@ func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error { return nil } +// Implements proto.MultiBulkParse +func commandInfoSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]*CommandInfo, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(commandInfoParser) + if err != nil { + return nil, err + } + vv := v.(*CommandInfo) + m[vv.Name] = vv + + } + return m, nil +} + +func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { + var cmd CommandInfo + var err error + + if n != 6 { + return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6", n) + } + + cmd.Name, err = rd.ReadString() + if err != nil { + return nil, err + } + + arity, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.Arity = int8(arity) + + flags, err := rd.ReadReply(stringSliceParser) + if err != nil { + return nil, err + } + cmd.Flags = flags.([]string) + + firstKeyPos, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.FirstKeyPos = int8(firstKeyPos) + + lastKeyPos, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.LastKeyPos = int8(lastKeyPos) + + stepCount, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.StepCount = int8(stepCount) + + for _, flag := range cmd.Flags { + if flag == "readonly" { + cmd.ReadOnly = true + break + } + } + + return &cmd, nil +} + //------------------------------------------------------------------------------ type cmdsInfoCache struct { diff --git a/vendor/github.com/go-redis/redis/commands.go b/vendor/github.com/go-redis/redis/commands.go index c6a88154e..b259e3a8c 100644 --- a/vendor/github.com/go-redis/redis/commands.go +++ b/vendor/github.com/go-redis/redis/commands.go @@ -62,6 +62,7 @@ type Cmdable interface { TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) TxPipeline() Pipeliner + Command() *CommandsInfoCmd ClientGetName() *StringCmd Echo(message interface{}) *StringCmd Ping() *StatusCmd @@ -171,6 +172,26 @@ type Cmdable interface { SRem(key string, members ...interface{}) *IntCmd SUnion(keys ...string) *StringSliceCmd SUnionStore(destination string, keys ...string) *IntCmd + XAdd(a *XAddArgs) *StringCmd + XLen(stream string) *IntCmd + XRange(stream, start, stop string) *XMessageSliceCmd + XRangeN(stream, start, stop string, count int64) *XMessageSliceCmd + XRevRange(stream string, start, stop string) *XMessageSliceCmd + XRevRangeN(stream string, start, stop string, count int64) *XMessageSliceCmd + XRead(a *XReadArgs) *XStreamSliceCmd + XReadStreams(streams ...string) *XStreamSliceCmd + XGroupCreate(stream, group, start string) *StatusCmd + XGroupSetID(stream, group, start string) *StatusCmd + XGroupDestroy(stream, group string) *IntCmd + XGroupDelConsumer(stream, group, consumer string) *IntCmd + XReadGroup(a *XReadGroupArgs) *XStreamSliceCmd + XAck(stream, group string, ids ...string) *IntCmd + XPending(stream, group string) *XPendingCmd + XPendingExt(a *XPendingExtArgs) *XPendingExtCmd + XClaim(a *XClaimArgs) *XMessageSliceCmd + XClaimJustID(a *XClaimArgs) *StringSliceCmd + XTrim(key string, maxLen int64) *IntCmd + XTrimApprox(key string, maxLen int64) *IntCmd ZAdd(key string, members ...Z) *IntCmd ZAddNX(key string, members ...Z) *IntCmd ZAddXX(key string, members ...Z) *IntCmd @@ -209,6 +230,7 @@ type Cmdable interface { BgRewriteAOF() *StatusCmd BgSave() *StatusCmd ClientKill(ipPort string) *StatusCmd + ClientKillByFilter(keys ...string) *IntCmd ClientList() *StringCmd ClientPause(dur time.Duration) *BoolCmd ConfigGet(parameter string) *SliceCmd @@ -265,9 +287,9 @@ type Cmdable interface { GeoRadiusByMemberRO(key, member string, query *GeoRadiusQuery) *GeoLocationCmd GeoDist(key string, member1, member2, unit string) *FloatCmd GeoHash(key string, members ...string) *StringSliceCmd - Command() *CommandsInfoCmd ReadOnly() *StatusCmd ReadWrite() *StatusCmd + MemoryUsage(key string, samples ...int) *IntCmd } type StatefulCmdable interface { @@ -345,6 +367,12 @@ func (c *statefulCmdable) SwapDB(index1, index2 int) *StatusCmd { //------------------------------------------------------------------------------ +func (c *cmdable) Command() *CommandsInfoCmd { + cmd := NewCommandsInfoCmd("command") + c.process(cmd) + return cmd +} + func (c *cmdable) Del(keys ...string) *IntCmd { args := make([]interface{}, 1+len(keys)) args[0] = "del" @@ -411,7 +439,7 @@ func (c *cmdable) Migrate(host, port, key string, db int64, timeout time.Duratio db, formatMs(timeout), ) - cmd.setReadTimeout(readTimeout(timeout)) + cmd.setReadTimeout(timeout) c.process(cmd) return cmd } @@ -985,7 +1013,7 @@ func (c *cmdable) BLPop(timeout time.Duration, keys ...string) *StringSliceCmd { } args[len(args)-1] = formatSec(timeout) cmd := NewStringSliceCmd(args...) - cmd.setReadTimeout(readTimeout(timeout)) + cmd.setReadTimeout(timeout) c.process(cmd) return cmd } @@ -998,7 +1026,7 @@ func (c *cmdable) BRPop(timeout time.Duration, keys ...string) *StringSliceCmd { } args[len(keys)+1] = formatSec(timeout) cmd := NewStringSliceCmd(args...) - cmd.setReadTimeout(readTimeout(timeout)) + cmd.setReadTimeout(timeout) c.process(cmd) return cmd } @@ -1010,7 +1038,7 @@ func (c *cmdable) BRPopLPush(source, destination string, timeout time.Duration) destination, formatSec(timeout), ) - cmd.setReadTimeout(readTimeout(timeout)) + cmd.setReadTimeout(timeout) c.process(cmd) return cmd } @@ -1282,6 +1310,239 @@ func (c *cmdable) SUnionStore(destination string, keys ...string) *IntCmd { //------------------------------------------------------------------------------ +type XAddArgs struct { + Stream string + MaxLen int64 // MAXLEN N + MaxLenApprox int64 // MAXLEN ~ N + ID string + Values map[string]interface{} +} + +func (c *cmdable) XAdd(a *XAddArgs) *StringCmd { + args := make([]interface{}, 0, 6+len(a.Values)*2) + args = append(args, "xadd") + args = append(args, a.Stream) + if a.MaxLen > 0 { + args = append(args, "maxlen", a.MaxLen) + } else if a.MaxLenApprox > 0 { + args = append(args, "maxlen", "~", a.MaxLenApprox) + } + if a.ID != "" { + args = append(args, a.ID) + } else { + args = append(args, "*") + } + for k, v := range a.Values { + args = append(args, k) + args = append(args, v) + } + + cmd := NewStringCmd(args...) + c.process(cmd) + return cmd +} + +func (c *cmdable) XLen(stream string) *IntCmd { + cmd := NewIntCmd("xlen", stream) + c.process(cmd) + return cmd +} + +func (c *cmdable) XRange(stream, start, stop string) *XMessageSliceCmd { + cmd := NewXMessageSliceCmd("xrange", stream, start, stop) + c.process(cmd) + return cmd +} + +func (c *cmdable) XRangeN(stream, start, stop string, count int64) *XMessageSliceCmd { + cmd := NewXMessageSliceCmd("xrange", stream, start, stop, "count", count) + c.process(cmd) + return cmd +} + +func (c *cmdable) XRevRange(stream, start, stop string) *XMessageSliceCmd { + cmd := NewXMessageSliceCmd("xrevrange", stream, start, stop) + c.process(cmd) + return cmd +} + +func (c *cmdable) XRevRangeN(stream, start, stop string, count int64) *XMessageSliceCmd { + cmd := NewXMessageSliceCmd("xrevrange", stream, start, stop, "count", count) + c.process(cmd) + return cmd +} + +type XReadArgs struct { + Streams []string + Count int64 + Block time.Duration +} + +func (c *cmdable) XRead(a *XReadArgs) *XStreamSliceCmd { + args := make([]interface{}, 0, 5+len(a.Streams)) + args = append(args, "xread") + if a.Count > 0 { + args = append(args, "count") + args = append(args, a.Count) + } + if a.Block >= 0 { + args = append(args, "block") + args = append(args, int64(a.Block/time.Millisecond)) + } + args = append(args, "streams") + for _, s := range a.Streams { + args = append(args, s) + } + + cmd := NewXStreamSliceCmd(args...) + c.process(cmd) + return cmd +} + +func (c *cmdable) XReadStreams(streams ...string) *XStreamSliceCmd { + return c.XRead(&XReadArgs{ + Streams: streams, + Block: -1, + }) +} + +func (c *cmdable) XGroupCreate(stream, group, start string) *StatusCmd { + cmd := NewStatusCmd("xgroup", "create", stream, group, start) + c.process(cmd) + return cmd +} + +func (c *cmdable) XGroupSetID(stream, group, start string) *StatusCmd { + cmd := NewStatusCmd("xgroup", "setid", stream, group, start) + c.process(cmd) + return cmd +} + +func (c *cmdable) XGroupDestroy(stream, group string) *IntCmd { + cmd := NewIntCmd("xgroup", "destroy", stream, group) + c.process(cmd) + return cmd +} + +func (c *cmdable) XGroupDelConsumer(stream, group, consumer string) *IntCmd { + cmd := NewIntCmd("xgroup", "delconsumer", stream, group, consumer) + c.process(cmd) + return cmd +} + +type XReadGroupArgs struct { + Group string + Consumer string + Streams []string + Count int64 + Block time.Duration +} + +func (c *cmdable) XReadGroup(a *XReadGroupArgs) *XStreamSliceCmd { + args := make([]interface{}, 0, 8+len(a.Streams)) + args = append(args, "xreadgroup", "group", a.Group, a.Consumer) + if a.Count > 0 { + args = append(args, "count", a.Count) + } + if a.Block >= 0 { + args = append(args, "block", int64(a.Block/time.Millisecond)) + } + args = append(args, "streams") + for _, s := range a.Streams { + args = append(args, s) + } + + cmd := NewXStreamSliceCmd(args...) + c.process(cmd) + return cmd +} + +func (c *cmdable) XAck(stream, group string, ids ...string) *IntCmd { + args := []interface{}{"xack", stream, group} + for _, id := range ids { + args = append(args, id) + } + cmd := NewIntCmd(args...) + c.process(cmd) + return cmd +} + +func (c *cmdable) XPending(stream, group string) *XPendingCmd { + cmd := NewXPendingCmd("xpending", stream, group) + c.process(cmd) + return cmd +} + +type XPendingExtArgs struct { + Stream string + Group string + Start string + End string + Count int64 + Consumer string +} + +func (c *cmdable) XPendingExt(a *XPendingExtArgs) *XPendingExtCmd { + args := make([]interface{}, 0, 7) + args = append(args, "xpending", a.Stream, a.Group, a.Start, a.End, a.Count) + if a.Consumer != "" { + args = append(args, a.Consumer) + } + cmd := NewXPendingExtCmd(args...) + c.process(cmd) + return cmd +} + +type XClaimArgs struct { + Stream string + Group string + Consumer string + MinIdle time.Duration + Messages []string +} + +func (c *cmdable) XClaim(a *XClaimArgs) *XMessageSliceCmd { + args := xClaimArgs(a) + cmd := NewXMessageSliceCmd(args...) + c.process(cmd) + return cmd +} + +func (c *cmdable) XClaimJustID(a *XClaimArgs) *StringSliceCmd { + args := xClaimArgs(a) + args = append(args, "justid") + cmd := NewStringSliceCmd(args...) + c.process(cmd) + return cmd +} + +func xClaimArgs(a *XClaimArgs) []interface{} { + args := make([]interface{}, 0, 4+len(a.Messages)) + args = append(args, + "xclaim", + a.Stream, + a.Group, a.Consumer, + int64(a.MinIdle/time.Millisecond)) + for _, id := range a.Messages { + args = append(args, id) + } + return args +} + +func (c *cmdable) XTrim(key string, maxLen int64) *IntCmd { + cmd := NewIntCmd("xtrim", key, "maxlen", maxLen) + c.process(cmd) + return cmd +} + +func (c *cmdable) XTrimApprox(key string, maxLen int64) *IntCmd { + cmd := NewIntCmd("xtrim", key, "maxlen", "~", maxLen) + c.process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + // Z represents sorted set member. type Z struct { Score float64 @@ -1682,6 +1943,20 @@ func (c *cmdable) ClientKill(ipPort string) *StatusCmd { return cmd } +// ClientKillByFilter is new style synx, while the ClientKill is old +// CLIENT KILL <option> [value] ... <option> [value] +func (c *cmdable) ClientKillByFilter(keys ...string) *IntCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "client" + args[1] = "kill" + for i, key := range keys { + args[2+i] = key + } + cmd := NewIntCmd(args...) + c.process(cmd) + return cmd +} + func (c *cmdable) ClientList() *StringCmd { cmd := NewStringCmd("client", "list") c.process(cmd) @@ -2168,8 +2443,15 @@ func (c *cmdable) GeoPos(key string, members ...string) *GeoPosCmd { //------------------------------------------------------------------------------ -func (c *cmdable) Command() *CommandsInfoCmd { - cmd := NewCommandsInfoCmd("command") +func (c *cmdable) MemoryUsage(key string, samples ...int) *IntCmd { + args := []interface{}{"memory", "usage", key} + if len(samples) > 0 { + if len(samples) != 1 { + panic("MemoryUsage expects single sample count") + } + args = append(args, "SAMPLES", samples[0]) + } + cmd := NewIntCmd(args...) c.process(cmd) return cmd } diff --git a/vendor/github.com/go-redis/redis/internal/error.go b/vendor/github.com/go-redis/redis/internal/error.go index 7b419577e..bda97baa6 100644 --- a/vendor/github.com/go-redis/redis/internal/error.go +++ b/vendor/github.com/go-redis/redis/internal/error.go @@ -8,9 +8,15 @@ import ( "github.com/go-redis/redis/internal/proto" ) -func IsRetryableError(err error, retryNetError bool) bool { - if IsNetworkError(err) { - return retryNetError +func IsRetryableError(err error, retryTimeout bool) bool { + if err == io.EOF { + return true + } + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() { + return retryTimeout + } + return true } s := err.Error() if s == "ERR max number of clients reached" { @@ -19,6 +25,9 @@ func IsRetryableError(err error, retryNetError bool) bool { if strings.HasPrefix(s, "LOADING ") { return true } + if strings.HasPrefix(s, "READONLY ") { + return true + } if strings.HasPrefix(s, "CLUSTERDOWN ") { return true } @@ -30,24 +39,12 @@ func IsRedisError(err error) bool { return ok } -func IsNetworkError(err error) bool { - if err == io.EOF { - return true - } - _, ok := err.(net.Error) - return ok -} - -func IsReadOnlyError(err error) bool { - return strings.HasPrefix(err.Error(), "READONLY ") -} - func IsBadConn(err error, allowTimeout bool) bool { if err == nil { return false } if IsRedisError(err) { - return false + return strings.HasPrefix(err.Error(), "READONLY ") } if allowTimeout { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { diff --git a/vendor/github.com/go-redis/redis/internal/hashtag/hashtag.go b/vendor/github.com/go-redis/redis/internal/hashtag/hashtag.go index 8c7ebbfa6..22f5b3981 100644 --- a/vendor/github.com/go-redis/redis/internal/hashtag/hashtag.go +++ b/vendor/github.com/go-redis/redis/internal/hashtag/hashtag.go @@ -5,7 +5,7 @@ import ( "strings" ) -const SlotNumber = 16384 +const slotNumber = 16384 // CRC16 implementation according to CCITT standards. // Copyright 2001-2010 Georges Menie (www.menie.org) @@ -56,7 +56,7 @@ func Key(key string) string { } func RandomSlot() int { - return rand.Intn(SlotNumber) + return rand.Intn(slotNumber) } // hashSlot returns a consistent slot number between 0 and 16383 @@ -66,7 +66,7 @@ func Slot(key string) int { return RandomSlot() } key = Key(key) - return int(crc16sum(key)) % SlotNumber + return int(crc16sum(key)) % slotNumber } func crc16sum(key string) (crc uint16) { diff --git a/vendor/github.com/go-redis/redis/internal/pool/conn.go b/vendor/github.com/go-redis/redis/internal/pool/conn.go index 8af51d9de..1095bfe59 100644 --- a/vendor/github.com/go-redis/redis/internal/pool/conn.go +++ b/vendor/github.com/go-redis/redis/internal/pool/conn.go @@ -13,19 +13,21 @@ var noDeadline = time.Time{} type Conn struct { netConn net.Conn - Rd *proto.Reader - Wb *proto.WriteBuffer + rd *proto.Reader + rdLocked bool + wr *proto.Writer - Inited bool - usedAt atomic.Value + InitedAt time.Time + pooled bool + usedAt atomic.Value } func NewConn(netConn net.Conn) *Conn { cn := &Conn{ netConn: netConn, - Wb: proto.NewWriteBuffer(), } - cn.Rd = proto.NewReader(cn.netConn) + cn.rd = proto.NewReader(netConn) + cn.wr = proto.NewWriter(netConn) cn.SetUsedAt(time.Now()) return cn } @@ -40,14 +42,11 @@ func (cn *Conn) SetUsedAt(tm time.Time) { func (cn *Conn) SetNetConn(netConn net.Conn) { cn.netConn = netConn - cn.Rd.Reset(netConn) + cn.rd.Reset(netConn) + cn.wr.Reset(netConn) } -func (cn *Conn) IsStale(timeout time.Duration) bool { - return timeout > 0 && time.Since(cn.UsedAt()) > timeout -} - -func (cn *Conn) SetReadTimeout(timeout time.Duration) error { +func (cn *Conn) setReadTimeout(timeout time.Duration) error { now := time.Now() cn.SetUsedAt(now) if timeout > 0 { @@ -56,7 +55,7 @@ func (cn *Conn) SetReadTimeout(timeout time.Duration) error { return cn.netConn.SetReadDeadline(noDeadline) } -func (cn *Conn) SetWriteTimeout(timeout time.Duration) error { +func (cn *Conn) setWriteTimeout(timeout time.Duration) error { now := time.Now() cn.SetUsedAt(now) if timeout > 0 { @@ -73,6 +72,22 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.netConn.RemoteAddr() } +func (cn *Conn) WithReader(timeout time.Duration, fn func(rd *proto.Reader) error) error { + _ = cn.setReadTimeout(timeout) + return fn(cn.rd) +} + +func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error { + _ = cn.setWriteTimeout(timeout) + + firstErr := fn(cn.wr) + err := cn.wr.Flush() + if err != nil && firstErr == nil { + firstErr = err + } + return firstErr +} + func (cn *Conn) Close() error { return cn.netConn.Close() } diff --git a/vendor/github.com/go-redis/redis/internal/pool/pool.go b/vendor/github.com/go-redis/redis/internal/pool/pool.go index cab66904a..9cecee8ad 100644 --- a/vendor/github.com/go-redis/redis/internal/pool/pool.go +++ b/vendor/github.com/go-redis/redis/internal/pool/pool.go @@ -28,7 +28,6 @@ type Stats struct { Timeouts uint32 // number of times a wait timeout occurred TotalConns uint32 // number of total connections in the pool - FreeConns uint32 // deprecated - use IdleConns IdleConns uint32 // number of idle connections in the pool StaleConns uint32 // number of stale connections removed from the pool } @@ -53,6 +52,8 @@ type Options struct { OnClose func(*Conn) error PoolSize int + MinIdleConns int + MaxConnAge time.Duration PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration @@ -63,16 +64,16 @@ type ConnPool struct { dialErrorsNum uint32 // atomic - lastDialError error lastDialErrorMu sync.RWMutex + lastDialError error queue chan struct{} - connsMu sync.Mutex - conns []*Conn - - idleConnsMu sync.RWMutex - idleConns []*Conn + connsMu sync.Mutex + conns []*Conn + idleConns []*Conn + poolSize int + idleConnsLen int stats Stats @@ -90,6 +91,10 @@ func NewConnPool(opt *Options) *ConnPool { idleConns: make([]*Conn, 0, opt.PoolSize), } + for i := 0; i < opt.MinIdleConns; i++ { + p.checkMinIdleConns() + } + if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { go p.reaper(opt.IdleCheckFrequency) } @@ -97,19 +102,53 @@ func NewConnPool(opt *Options) *ConnPool { return p } +func (p *ConnPool) checkMinIdleConns() { + if p.opt.MinIdleConns == 0 { + return + } + if p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { + p.poolSize++ + p.idleConnsLen++ + go p.addIdleConn() + } +} + +func (p *ConnPool) addIdleConn() { + cn, err := p.newConn(true) + if err != nil { + return + } + + p.connsMu.Lock() + p.conns = append(p.conns, cn) + p.idleConns = append(p.idleConns, cn) + p.connsMu.Unlock() +} + func (p *ConnPool) NewConn() (*Conn, error) { - cn, err := p.newConn() + return p._NewConn(false) +} + +func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) { + cn, err := p.newConn(pooled) if err != nil { return nil, err } p.connsMu.Lock() p.conns = append(p.conns, cn) + if pooled { + if p.poolSize < p.opt.PoolSize { + p.poolSize++ + } else { + cn.pooled = false + } + } p.connsMu.Unlock() return cn, nil } -func (p *ConnPool) newConn() (*Conn, error) { +func (p *ConnPool) newConn(pooled bool) (*Conn, error) { if p.closed() { return nil, ErrClosed } @@ -127,7 +166,9 @@ func (p *ConnPool) newConn() (*Conn, error) { return nil, err } - return NewConn(netConn), nil + cn := NewConn(netConn) + cn.pooled = pooled + return cn, nil } func (p *ConnPool) tryDial() { @@ -174,16 +215,16 @@ func (p *ConnPool) Get() (*Conn, error) { } for { - p.idleConnsMu.Lock() + p.connsMu.Lock() cn := p.popIdle() - p.idleConnsMu.Unlock() + p.connsMu.Unlock() if cn == nil { break } - if cn.IsStale(p.opt.IdleTimeout) { - p.CloseConn(cn) + if p.isStaleConn(cn) { + _ = p.CloseConn(cn) continue } @@ -193,7 +234,7 @@ func (p *ConnPool) Get() (*Conn, error) { atomic.AddUint32(&p.stats.Misses, 1) - newcn, err := p.NewConn() + newcn, err := p._NewConn(true) if err != nil { p.freeTurn() return nil, err @@ -241,21 +282,21 @@ func (p *ConnPool) popIdle() *Conn { idx := len(p.idleConns) - 1 cn := p.idleConns[idx] p.idleConns = p.idleConns[:idx] - + p.idleConnsLen-- + p.checkMinIdleConns() return cn } func (p *ConnPool) Put(cn *Conn) { - buf := cn.Rd.PeekBuffered() - if buf != nil { - internal.Logf("connection has unread data: %.100q", buf) + if !cn.pooled { p.Remove(cn) return } - p.idleConnsMu.Lock() + p.connsMu.Lock() p.idleConns = append(p.idleConns, cn) - p.idleConnsMu.Unlock() + p.idleConnsLen++ + p.connsMu.Unlock() p.freeTurn() } @@ -275,6 +316,10 @@ func (p *ConnPool) removeConn(cn *Conn) { for i, c := range p.conns { if c == cn { p.conns = append(p.conns[:i], p.conns[i+1:]...) + if cn.pooled { + p.poolSize-- + p.checkMinIdleConns() + } break } } @@ -291,17 +336,17 @@ func (p *ConnPool) closeConn(cn *Conn) error { // Len returns total number of connections. func (p *ConnPool) Len() int { p.connsMu.Lock() - l := len(p.conns) + n := len(p.conns) p.connsMu.Unlock() - return l + return n } -// FreeLen returns number of idle connections. +// IdleLen returns number of idle connections. func (p *ConnPool) IdleLen() int { - p.idleConnsMu.RLock() - l := len(p.idleConns) - p.idleConnsMu.RUnlock() - return l + p.connsMu.Lock() + n := p.idleConnsLen + p.connsMu.Unlock() + return n } func (p *ConnPool) Stats() *Stats { @@ -312,7 +357,6 @@ func (p *ConnPool) Stats() *Stats { Timeouts: atomic.LoadUint32(&p.stats.Timeouts), TotalConns: uint32(p.Len()), - FreeConns: uint32(idleLen), IdleConns: uint32(idleLen), StaleConns: atomic.LoadUint32(&p.stats.StaleConns), } @@ -349,11 +393,10 @@ func (p *ConnPool) Close() error { } } p.conns = nil - p.connsMu.Unlock() - - p.idleConnsMu.Lock() + p.poolSize = 0 p.idleConns = nil - p.idleConnsMu.Unlock() + p.idleConnsLen = 0 + p.connsMu.Unlock() return firstErr } @@ -364,11 +407,12 @@ func (p *ConnPool) reapStaleConn() *Conn { } cn := p.idleConns[0] - if !cn.IsStale(p.opt.IdleTimeout) { + if !p.isStaleConn(cn) { return nil } p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) + p.idleConnsLen-- return cn } @@ -378,9 +422,9 @@ func (p *ConnPool) ReapStaleConns() (int, error) { for { p.getTurn() - p.idleConnsMu.Lock() + p.connsMu.Lock() cn := p.reapStaleConn() - p.idleConnsMu.Unlock() + p.connsMu.Unlock() if cn != nil { p.removeConn(cn) @@ -414,3 +458,19 @@ func (p *ConnPool) reaper(frequency time.Duration) { atomic.AddUint32(&p.stats.StaleConns, uint32(n)) } } + +func (p *ConnPool) isStaleConn(cn *Conn) bool { + if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { + return false + } + + now := time.Now() + if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { + return true + } + if p.opt.MaxConnAge > 0 && now.Sub(cn.InitedAt) >= p.opt.MaxConnAge { + return true + } + + return false +} diff --git a/vendor/github.com/go-redis/redis/internal/proto/reader.go b/vendor/github.com/go-redis/redis/internal/proto/reader.go index d5d695358..896b6f654 100644 --- a/vendor/github.com/go-redis/redis/internal/proto/reader.go +++ b/vendor/github.com/go-redis/redis/internal/proto/reader.go @@ -9,8 +9,6 @@ import ( "github.com/go-redis/redis/internal/util" ) -const bytesAllocLimit = 1024 * 1024 // 1mb - const ( ErrorReply = '-' StatusReply = '+' @@ -32,40 +30,23 @@ func (e RedisError) Error() string { return string(e) } type MultiBulkParse func(*Reader, int64) (interface{}, error) type Reader struct { - src *bufio.Reader - buf []byte + rd *bufio.Reader + _buf []byte } func NewReader(rd io.Reader) *Reader { return &Reader{ - src: bufio.NewReader(rd), - buf: make([]byte, 4096), + rd: bufio.NewReader(rd), + _buf: make([]byte, 64), } } func (r *Reader) Reset(rd io.Reader) { - r.src.Reset(rd) -} - -func (r *Reader) PeekBuffered() []byte { - if n := r.src.Buffered(); n != 0 { - b, _ := r.src.Peek(n) - return b - } - return nil -} - -func (r *Reader) ReadN(n int) ([]byte, error) { - b, err := readN(r.src, r.buf, n) - if err != nil { - return nil, err - } - r.buf = b - return b, nil + r.rd.Reset(rd) } func (r *Reader) ReadLine() ([]byte, error) { - line, isPrefix, err := r.src.ReadLine() + line, isPrefix, err := r.rd.ReadLine() if err != nil { return nil, err } @@ -91,11 +72,11 @@ func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { case ErrorReply: return nil, ParseErrorReply(line) case StatusReply: - return parseStatusValue(line), nil + return string(line[1:]), nil case IntReply: return util.ParseInt(line[1:], 10, 64) case StringReply: - return r.readTmpBytesValue(line) + return r.readStringReply(line) case ArrayReply: n, err := parseArrayLen(line) if err != nil { @@ -121,47 +102,42 @@ func (r *Reader) ReadIntReply() (int64, error) { } } -func (r *Reader) ReadTmpBytesReply() ([]byte, error) { +func (r *Reader) ReadString() (string, error) { line, err := r.ReadLine() if err != nil { - return nil, err + return "", err } switch line[0] { case ErrorReply: - return nil, ParseErrorReply(line) + return "", ParseErrorReply(line) case StringReply: - return r.readTmpBytesValue(line) + return r.readStringReply(line) case StatusReply: - return parseStatusValue(line), nil + return string(line[1:]), nil + case IntReply: + return string(line[1:]), nil default: - return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) + return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line) } } -func (r *Reader) ReadBytesReply() ([]byte, error) { - b, err := r.ReadTmpBytesReply() - if err != nil { - return nil, err +func (r *Reader) readStringReply(line []byte) (string, error) { + if isNilReply(line) { + return "", Nil } - cp := make([]byte, len(b)) - copy(cp, b) - return cp, nil -} -func (r *Reader) ReadStringReply() (string, error) { - b, err := r.ReadTmpBytesReply() + replyLen, err := strconv.Atoi(string(line[1:])) if err != nil { return "", err } - return string(b), nil -} -func (r *Reader) ReadFloatReply() (float64, error) { - b, err := r.ReadTmpBytesReply() + b := make([]byte, replyLen+2) + _, err = io.ReadFull(r.rd, b) if err != nil { - return 0, err + return "", err } - return util.ParseFloat(b, 64) + + return util.BytesToString(b[:replyLen]), nil } func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { @@ -219,7 +195,7 @@ func (r *Reader) ReadScanReply() ([]string, uint64, error) { keys := make([]string, n) for i := int64(0); i < n; i++ { - key, err := r.ReadStringReply() + key, err := r.ReadString() if err != nil { return nil, 0, err } @@ -229,25 +205,8 @@ func (r *Reader) ReadScanReply() ([]string, uint64, error) { return keys, cursor, err } -func (r *Reader) readTmpBytesValue(line []byte) ([]byte, error) { - if isNilReply(line) { - return nil, Nil - } - - replyLen, err := strconv.Atoi(string(line[1:])) - if err != nil { - return nil, err - } - - b, err := r.ReadN(replyLen + 2) - if err != nil { - return nil, err - } - return b[:replyLen], nil -} - func (r *Reader) ReadInt() (int64, error) { - b, err := r.ReadTmpBytesReply() + b, err := r.readTmpBytesReply() if err != nil { return 0, err } @@ -255,55 +214,62 @@ func (r *Reader) ReadInt() (int64, error) { } func (r *Reader) ReadUint() (uint64, error) { - b, err := r.ReadTmpBytesReply() + b, err := r.readTmpBytesReply() if err != nil { return 0, err } return util.ParseUint(b, 10, 64) } -// -------------------------------------------------------------------- - -func readN(r io.Reader, b []byte, n int) ([]byte, error) { - if n == 0 && b == nil { - return make([]byte, 0), nil +func (r *Reader) ReadFloatReply() (float64, error) { + b, err := r.readTmpBytesReply() + if err != nil { + return 0, err } + return util.ParseFloat(b, 64) +} - if cap(b) >= n { - b = b[:n] - _, err := io.ReadFull(r, b) - return b, err +func (r *Reader) readTmpBytesReply() ([]byte, error) { + line, err := r.ReadLine() + if err != nil { + return nil, err } - b = b[:cap(b)] - - pos := 0 - for pos < n { - diff := n - len(b) - if diff > bytesAllocLimit { - diff = bytesAllocLimit - } - b = append(b, make([]byte, diff)...) + switch line[0] { + case ErrorReply: + return nil, ParseErrorReply(line) + case StringReply: + return r._readTmpBytesReply(line) + case StatusReply: + return line[1:], nil + default: + return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) + } +} - nn, err := io.ReadFull(r, b[pos:]) - if err != nil { - return nil, err - } - pos += nn +func (r *Reader) _readTmpBytesReply(line []byte) ([]byte, error) { + if isNilReply(line) { + return nil, Nil } - return b, nil -} + replyLen, err := strconv.Atoi(string(line[1:])) + if err != nil { + return nil, err + } -func formatInt(n int64) string { - return strconv.FormatInt(n, 10) -} + buf := r.buf(replyLen + 2) + _, err = io.ReadFull(r.rd, buf) + if err != nil { + return nil, err + } -func formatUint(u uint64) string { - return strconv.FormatUint(u, 10) + return buf[:replyLen], nil } -func formatFloat(f float64) string { - return strconv.FormatFloat(f, 'f', -1, 64) +func (r *Reader) buf(n int) []byte { + if d := n - cap(r._buf); d > 0 { + r._buf = append(r._buf, make([]byte, d)...) + } + return r._buf[:n] } func isNilReply(b []byte) bool { @@ -316,10 +282,6 @@ func ParseErrorReply(line []byte) error { return RedisError(string(line[1:])) } -func parseStatusValue(line []byte) []byte { - return line[1:] -} - func parseArrayLen(line []byte) (int64, error) { if isNilReply(line) { return 0, Nil diff --git a/vendor/github.com/go-redis/redis/internal/proto/write_buffer.go b/vendor/github.com/go-redis/redis/internal/proto/write_buffer.go deleted file mode 100644 index cc4014fb4..000000000 --- a/vendor/github.com/go-redis/redis/internal/proto/write_buffer.go +++ /dev/null @@ -1,101 +0,0 @@ -package proto - -import ( - "encoding" - "fmt" - "strconv" -) - -type WriteBuffer struct { - b []byte -} - -func NewWriteBuffer() *WriteBuffer { - return &WriteBuffer{ - b: make([]byte, 0, 4096), - } -} - -func (w *WriteBuffer) Len() int { return len(w.b) } -func (w *WriteBuffer) Bytes() []byte { return w.b } -func (w *WriteBuffer) Reset() { w.b = w.b[:0] } - -func (w *WriteBuffer) Append(args []interface{}) error { - w.b = append(w.b, ArrayReply) - w.b = strconv.AppendUint(w.b, uint64(len(args)), 10) - w.b = append(w.b, '\r', '\n') - - for _, arg := range args { - if err := w.append(arg); err != nil { - return err - } - } - return nil -} - -func (w *WriteBuffer) append(val interface{}) error { - switch v := val.(type) { - case nil: - w.AppendString("") - case string: - w.AppendString(v) - case []byte: - w.AppendBytes(v) - case int: - w.AppendString(formatInt(int64(v))) - case int8: - w.AppendString(formatInt(int64(v))) - case int16: - w.AppendString(formatInt(int64(v))) - case int32: - w.AppendString(formatInt(int64(v))) - case int64: - w.AppendString(formatInt(v)) - case uint: - w.AppendString(formatUint(uint64(v))) - case uint8: - w.AppendString(formatUint(uint64(v))) - case uint16: - w.AppendString(formatUint(uint64(v))) - case uint32: - w.AppendString(formatUint(uint64(v))) - case uint64: - w.AppendString(formatUint(v)) - case float32: - w.AppendString(formatFloat(float64(v))) - case float64: - w.AppendString(formatFloat(v)) - case bool: - if v { - w.AppendString("1") - } else { - w.AppendString("0") - } - case encoding.BinaryMarshaler: - b, err := v.MarshalBinary() - if err != nil { - return err - } - w.AppendBytes(b) - default: - return fmt.Errorf( - "redis: can't marshal %T (consider implementing encoding.BinaryMarshaler)", val) - } - return nil -} - -func (w *WriteBuffer) AppendString(s string) { - w.b = append(w.b, StringReply) - w.b = strconv.AppendUint(w.b, uint64(len(s)), 10) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, s...) - w.b = append(w.b, '\r', '\n') -} - -func (w *WriteBuffer) AppendBytes(p []byte) { - w.b = append(w.b, StringReply) - w.b = strconv.AppendUint(w.b, uint64(len(p)), 10) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, p...) - w.b = append(w.b, '\r', '\n') -} diff --git a/vendor/github.com/go-redis/redis/internal/proto/writer.go b/vendor/github.com/go-redis/redis/internal/proto/writer.go new file mode 100644 index 000000000..d106ce0ee --- /dev/null +++ b/vendor/github.com/go-redis/redis/internal/proto/writer.go @@ -0,0 +1,159 @@ +package proto + +import ( + "bufio" + "encoding" + "fmt" + "io" + "strconv" + + "github.com/go-redis/redis/internal/util" +) + +type Writer struct { + wr *bufio.Writer + + lenBuf []byte + numBuf []byte +} + +func NewWriter(wr io.Writer) *Writer { + return &Writer{ + wr: bufio.NewWriter(wr), + + lenBuf: make([]byte, 64), + numBuf: make([]byte, 64), + } +} + +func (w *Writer) WriteArgs(args []interface{}) error { + err := w.wr.WriteByte(ArrayReply) + if err != nil { + return err + } + + err = w.writeLen(len(args)) + if err != nil { + return err + } + + for _, arg := range args { + err := w.writeArg(arg) + if err != nil { + return err + } + } + + return nil +} + +func (w *Writer) writeLen(n int) error { + w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10) + w.lenBuf = append(w.lenBuf, '\r', '\n') + _, err := w.wr.Write(w.lenBuf) + return err +} + +func (w *Writer) writeArg(v interface{}) error { + switch v := v.(type) { + case nil: + return w.string("") + case string: + return w.string(v) + case []byte: + return w.bytes(v) + case int: + return w.int(int64(v)) + case int8: + return w.int(int64(v)) + case int16: + return w.int(int64(v)) + case int32: + return w.int(int64(v)) + case int64: + return w.int(v) + case uint: + return w.uint(uint64(v)) + case uint8: + return w.uint(uint64(v)) + case uint16: + return w.uint(uint64(v)) + case uint32: + return w.uint(uint64(v)) + case uint64: + return w.uint(v) + case float32: + return w.float(float64(v)) + case float64: + return w.float(v) + case bool: + if v { + return w.int(1) + } else { + return w.int(0) + } + case encoding.BinaryMarshaler: + b, err := v.MarshalBinary() + if err != nil { + return err + } + return w.bytes(b) + default: + return fmt.Errorf( + "redis: can't marshal %T (implement encoding.BinaryMarshaler)", v) + } +} + +func (w *Writer) bytes(b []byte) error { + err := w.wr.WriteByte(StringReply) + if err != nil { + return err + } + + err = w.writeLen(len(b)) + if err != nil { + return err + } + + _, err = w.wr.Write(b) + if err != nil { + return err + } + + return w.crlf() +} + +func (w *Writer) string(s string) error { + return w.bytes(util.StringToBytes(s)) +} + +func (w *Writer) uint(n uint64) error { + w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10) + return w.bytes(w.numBuf) +} + +func (w *Writer) int(n int64) error { + w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10) + return w.bytes(w.numBuf) +} + +func (w *Writer) float(f float64) error { + w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64) + return w.bytes(w.numBuf) +} + +func (w *Writer) crlf() error { + err := w.wr.WriteByte('\r') + if err != nil { + return err + } + return w.wr.WriteByte('\n') +} + +func (w *Writer) Reset(wr io.Writer) { + w.wr.Reset(wr) +} + +func (w *Writer) Flush() error { + return w.wr.Flush() +} diff --git a/vendor/github.com/go-redis/redis/internal/util/safe.go b/vendor/github.com/go-redis/redis/internal/util/safe.go index cd8918330..1b3060ebc 100644 --- a/vendor/github.com/go-redis/redis/internal/util/safe.go +++ b/vendor/github.com/go-redis/redis/internal/util/safe.go @@ -5,3 +5,7 @@ package util func BytesToString(b []byte) string { return string(b) } + +func StringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/go-redis/redis/internal/util/unsafe.go b/vendor/github.com/go-redis/redis/internal/util/unsafe.go index 93a89c55c..c9868aac2 100644 --- a/vendor/github.com/go-redis/redis/internal/util/unsafe.go +++ b/vendor/github.com/go-redis/redis/internal/util/unsafe.go @@ -10,3 +10,13 @@ import ( func BytesToString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } + +// StringToBytes converts string to byte slice. +func StringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/go-redis/redis/options.go b/vendor/github.com/go-redis/redis/options.go index 35ce06195..2b5bcb58a 100644 --- a/vendor/github.com/go-redis/redis/options.go +++ b/vendor/github.com/go-redis/redis/options.go @@ -59,16 +59,24 @@ type Options struct { // Maximum number of socket connections. // Default is 10 connections per every CPU as reported by runtime.NumCPU. PoolSize int + // Minimum number of idle connections which is useful when establishing + // new connection is slow. + MinIdleConns int + // Connection age at which client retires (closes) the connection. + // Default is to not close aged connections. + MaxConnAge time.Duration // Amount of time client waits for connection if all connections // are busy before returning an error. // Default is ReadTimeout + 1 second. PoolTimeout time.Duration // Amount of time after which client closes idle connections. // Should be less than server's timeout. - // Default is 5 minutes. + // Default is 5 minutes. -1 disables idle timeout check. IdleTimeout time.Duration - // Frequency of idle checks. - // Default is 1 minute. -1 disables idle check. + // Frequency of idle checks made by idle connections reaper. + // Default is 1 minute. -1 disables idle connections reaper, + // but idle connections are still discarded by the client + // if IdleTimeout is set. IdleCheckFrequency time.Duration // Enables read only queries on slave nodes. @@ -84,12 +92,15 @@ func (opt *Options) init() { } if opt.Dialer == nil { opt.Dialer = func() (net.Conn, error) { - conn, err := net.DialTimeout(opt.Network, opt.Addr, opt.DialTimeout) - if opt.TLSConfig == nil || err != nil { - return conn, err + netDialer := &net.Dialer{ + Timeout: opt.DialTimeout, + KeepAlive: 5 * time.Minute, + } + if opt.TLSConfig == nil { + return netDialer.Dial(opt.Network, opt.Addr) + } else { + return tls.DialWithDialer(netDialer, opt.Network, opt.Addr, opt.TLSConfig) } - t := tls.Client(conn, opt.TLSConfig) - return t, t.Handshake() } } if opt.PoolSize == 0 { @@ -192,6 +203,8 @@ func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ Dialer: opt.Dialer, PoolSize: opt.PoolSize, + MinIdleConns: opt.MinIdleConns, + MaxConnAge: opt.MaxConnAge, PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, IdleCheckFrequency: opt.IdleCheckFrequency, diff --git a/vendor/github.com/go-redis/redis/parser.go b/vendor/github.com/go-redis/redis/parser.go deleted file mode 100644 index f0dc67f0e..000000000 --- a/vendor/github.com/go-redis/redis/parser.go +++ /dev/null @@ -1,394 +0,0 @@ -package redis - -import ( - "fmt" - "net" - "strconv" - "time" - - "github.com/go-redis/redis/internal/proto" -) - -// Implements proto.MultiBulkParse -func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { - vals := make([]interface{}, 0, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(sliceParser) - if err != nil { - if err == Nil { - vals = append(vals, nil) - continue - } - if err, ok := err.(proto.RedisError); ok { - vals = append(vals, err) - continue - } - return nil, err - } - - switch v := v.(type) { - case []byte: - vals = append(vals, string(v)) - default: - vals = append(vals, v) - } - } - return vals, nil -} - -// Implements proto.MultiBulkParse -func boolSliceParser(rd *proto.Reader, n int64) (interface{}, error) { - bools := make([]bool, 0, n) - for i := int64(0); i < n; i++ { - n, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - bools = append(bools, n == 1) - } - return bools, nil -} - -// Implements proto.MultiBulkParse -func stringSliceParser(rd *proto.Reader, n int64) (interface{}, error) { - ss := make([]string, 0, n) - for i := int64(0); i < n; i++ { - s, err := rd.ReadStringReply() - if err == Nil { - ss = append(ss, "") - } else if err != nil { - return nil, err - } else { - ss = append(ss, s) - } - } - return ss, nil -} - -// Implements proto.MultiBulkParse -func stringStringMapParser(rd *proto.Reader, n int64) (interface{}, error) { - m := make(map[string]string, n/2) - for i := int64(0); i < n; i += 2 { - key, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - - value, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - - m[key] = value - } - return m, nil -} - -// Implements proto.MultiBulkParse -func stringIntMapParser(rd *proto.Reader, n int64) (interface{}, error) { - m := make(map[string]int64, n/2) - for i := int64(0); i < n; i += 2 { - key, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - - n, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - - m[key] = n - } - return m, nil -} - -// Implements proto.MultiBulkParse -func stringStructMapParser(rd *proto.Reader, n int64) (interface{}, error) { - m := make(map[string]struct{}, n) - for i := int64(0); i < n; i++ { - key, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - - m[key] = struct{}{} - } - return m, nil -} - -// Implements proto.MultiBulkParse -func zSliceParser(rd *proto.Reader, n int64) (interface{}, error) { - zz := make([]Z, n/2) - for i := int64(0); i < n; i += 2 { - var err error - - z := &zz[i/2] - - z.Member, err = rd.ReadStringReply() - if err != nil { - return nil, err - } - - z.Score, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - return zz, nil -} - -// Implements proto.MultiBulkParse -func clusterSlotsParser(rd *proto.Reader, n int64) (interface{}, error) { - slots := make([]ClusterSlot, n) - for i := 0; i < len(slots); i++ { - n, err := rd.ReadArrayLen() - if err != nil { - return nil, err - } - if n < 2 { - err := fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) - return nil, err - } - - start, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - - end, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - - nodes := make([]ClusterNode, n-2) - for j := 0; j < len(nodes); j++ { - n, err := rd.ReadArrayLen() - if err != nil { - return nil, err - } - if n != 2 && n != 3 { - err := fmt.Errorf("got %d elements in cluster info address, expected 2 or 3", n) - return nil, err - } - - ip, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - - port, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - nodes[j].Addr = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) - - if n == 3 { - id, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - nodes[j].Id = id - } - } - - slots[i] = ClusterSlot{ - Start: int(start), - End: int(end), - Nodes: nodes, - } - } - return slots, nil -} - -func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd *proto.Reader, n int64) (interface{}, error) { - var loc GeoLocation - var err error - - loc.Name, err = rd.ReadStringReply() - if err != nil { - return nil, err - } - if q.WithDist { - loc.Dist, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - if q.WithGeoHash { - loc.GeoHash, err = rd.ReadIntReply() - if err != nil { - return nil, err - } - } - if q.WithCoord { - n, err := rd.ReadArrayLen() - if err != nil { - return nil, err - } - if n != 2 { - return nil, fmt.Errorf("got %d coordinates, expected 2", n) - } - - loc.Longitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - loc.Latitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - - return &loc, nil - } -} - -func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd *proto.Reader, n int64) (interface{}, error) { - locs := make([]GeoLocation, 0, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(newGeoLocationParser(q)) - if err != nil { - return nil, err - } - switch vv := v.(type) { - case []byte: - locs = append(locs, GeoLocation{ - Name: string(vv), - }) - case *GeoLocation: - locs = append(locs, *vv) - default: - return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) - } - } - return locs, nil - } -} - -func geoPosParser(rd *proto.Reader, n int64) (interface{}, error) { - var pos GeoPos - var err error - - pos.Longitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - - pos.Latitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - - return &pos, nil -} - -func geoPosSliceParser(rd *proto.Reader, n int64) (interface{}, error) { - positions := make([]*GeoPos, 0, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(geoPosParser) - if err != nil { - if err == Nil { - positions = append(positions, nil) - continue - } - return nil, err - } - switch v := v.(type) { - case *GeoPos: - positions = append(positions, v) - default: - return nil, fmt.Errorf("got %T, expected *GeoPos", v) - } - } - return positions, nil -} - -func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { - var cmd CommandInfo - var err error - - if n != 6 { - return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6", n) - } - - cmd.Name, err = rd.ReadStringReply() - if err != nil { - return nil, err - } - - arity, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.Arity = int8(arity) - - flags, err := rd.ReadReply(stringSliceParser) - if err != nil { - return nil, err - } - cmd.Flags = flags.([]string) - - firstKeyPos, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.FirstKeyPos = int8(firstKeyPos) - - lastKeyPos, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.LastKeyPos = int8(lastKeyPos) - - stepCount, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.StepCount = int8(stepCount) - - for _, flag := range cmd.Flags { - if flag == "readonly" { - cmd.ReadOnly = true - break - } - } - - return &cmd, nil -} - -// Implements proto.MultiBulkParse -func commandInfoSliceParser(rd *proto.Reader, n int64) (interface{}, error) { - m := make(map[string]*CommandInfo, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(commandInfoParser) - if err != nil { - return nil, err - } - vv := v.(*CommandInfo) - m[vv.Name] = vv - - } - return m, nil -} - -// Implements proto.MultiBulkParse -func timeParser(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d elements, expected 2", n) - } - - sec, err := rd.ReadInt() - if err != nil { - return nil, err - } - - microsec, err := rd.ReadInt() - if err != nil { - return nil, err - } - - return time.Unix(sec, microsec*1000), nil -} diff --git a/vendor/github.com/go-redis/redis/pipeline.go b/vendor/github.com/go-redis/redis/pipeline.go index 9349ef553..ba852283e 100644 --- a/vendor/github.com/go-redis/redis/pipeline.go +++ b/vendor/github.com/go-redis/redis/pipeline.go @@ -31,6 +31,7 @@ type Pipeline struct { closed bool } +// Process queues the cmd for later execution. func (c *Pipeline) Process(cmd Cmder) error { c.mu.Lock() c.cmds = append(c.cmds, cmd) diff --git a/vendor/github.com/go-redis/redis/pubsub.go b/vendor/github.com/go-redis/redis/pubsub.go index b56728f3e..b08f34ad2 100644 --- a/vendor/github.com/go-redis/redis/pubsub.go +++ b/vendor/github.com/go-redis/redis/pubsub.go @@ -2,20 +2,20 @@ package redis import ( "fmt" - "net" "sync" "time" "github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/pool" + "github.com/go-redis/redis/internal/proto" ) -// PubSub implements Pub/Sub commands as described in -// http://redis.io/topics/pubsub. It's NOT safe for concurrent use by -// multiple goroutines. +// PubSub implements Pub/Sub commands bas described in +// http://redis.io/topics/pubsub. Message receiving is NOT safe +// for concurrent use by multiple goroutines. // -// PubSub automatically resubscribes to the channels and patterns -// when Redis becomes unavailable. +// PubSub automatically reconnects to Redis Server and resubscribes +// to the channels in case of network errors. type PubSub struct { opt *Options @@ -27,11 +27,17 @@ type PubSub struct { channels map[string]struct{} patterns map[string]struct{} closed bool + exit chan struct{} cmd *Cmd chOnce sync.Once ch chan *Message + ping chan struct{} +} + +func (c *PubSub) init() { + c.exit = make(chan struct{}) } func (c *PubSub) conn() (*pool.Conn, error) { @@ -41,7 +47,7 @@ func (c *PubSub) conn() (*pool.Conn, error) { return cn, err } -func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { +func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { if c.closed { return nil, pool.ErrClosed } @@ -50,6 +56,9 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { return c.cn, nil } + channels := mapKeys(c.channels) + channels = append(channels, newChannels...) + cn, err := c.newConn(channels) if err != nil { return nil, err @@ -64,61 +73,81 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { return cn, nil } +func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { + return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmd) + }) +} + func (c *PubSub) resubscribe(cn *pool.Conn) error { var firstErr error + if len(c.channels) > 0 { - channels := make([]string, len(c.channels)) - i := 0 - for channel := range c.channels { - channels[i] = channel - i++ - } - if err := c._subscribe(cn, "subscribe", channels...); err != nil && firstErr == nil { + err := c._subscribe(cn, "subscribe", mapKeys(c.channels)) + if err != nil && firstErr == nil { firstErr = err } } + if len(c.patterns) > 0 { - patterns := make([]string, len(c.patterns)) - i := 0 - for pattern := range c.patterns { - patterns[i] = pattern - i++ - } - if err := c._subscribe(cn, "psubscribe", patterns...); err != nil && firstErr == nil { + err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns)) + if err != nil && firstErr == nil { firstErr = err } } + return firstErr } -func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { - args := make([]interface{}, 1+len(channels)) - args[0] = redisCmd - for i, channel := range channels { - args[1+i] = channel +func mapKeys(m map[string]struct{}) []string { + s := make([]string, len(m)) + i := 0 + for k := range m { + s[i] = k + i++ } - cmd := NewSliceCmd(args...) + return s +} - cn.SetWriteTimeout(c.opt.WriteTimeout) - return writeCmd(cn, cmd) +func (c *PubSub) _subscribe( + cn *pool.Conn, redisCmd string, channels []string, +) error { + args := make([]interface{}, 0, 1+len(channels)) + args = append(args, redisCmd) + for _, channel := range channels { + args = append(args, channel) + } + cmd := NewSliceCmd(args...) + return c.writeCmd(cn, cmd) } -func (c *PubSub) releaseConn(cn *pool.Conn, err error) { +func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { c.mu.Lock() - c._releaseConn(cn, err) + c._releaseConn(cn, err, allowTimeout) c.mu.Unlock() } -func (c *PubSub) _releaseConn(cn *pool.Conn, err error) { +func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) { if c.cn != cn { return } - if internal.IsBadConn(err, true) { - _ = c.closeTheCn() + if internal.IsBadConn(err, allowTimeout) { + c._reconnect(err) } } -func (c *PubSub) closeTheCn() error { +func (c *PubSub) _reconnect(reason error) { + _ = c._closeTheCn(reason) + _, _ = c._conn(nil) +} + +func (c *PubSub) _closeTheCn(reason error) error { + if c.cn == nil { + return nil + } + if !c.closed { + internal.Logf("redis: discarding bad PubSub connection: %s", reason) + } err := c.closeConn(c.cn) c.cn = nil return err @@ -132,25 +161,25 @@ func (c *PubSub) Close() error { return pool.ErrClosed } c.closed = true + close(c.exit) - if c.cn != nil { - return c.closeTheCn() - } - return nil + err := c._closeTheCn(pool.ErrClosed) + return err } // Subscribe the client to the specified channels. It returns // empty subscription if there are no channels. func (c *PubSub) Subscribe(channels ...string) error { c.mu.Lock() + defer c.mu.Unlock() + err := c.subscribe("subscribe", channels...) if c.channels == nil { c.channels = make(map[string]struct{}) } - for _, channel := range channels { - c.channels[channel] = struct{}{} + for _, s := range channels { + c.channels[s] = struct{}{} } - c.mu.Unlock() return err } @@ -158,14 +187,15 @@ func (c *PubSub) Subscribe(channels ...string) error { // empty subscription if there are no patterns. func (c *PubSub) PSubscribe(patterns ...string) error { c.mu.Lock() + defer c.mu.Unlock() + err := c.subscribe("psubscribe", patterns...) if c.patterns == nil { c.patterns = make(map[string]struct{}) } - for _, pattern := range patterns { - c.patterns[pattern] = struct{}{} + for _, s := range patterns { + c.patterns[s] = struct{}{} } - c.mu.Unlock() return err } @@ -173,11 +203,12 @@ func (c *PubSub) PSubscribe(patterns ...string) error { // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { c.mu.Lock() - err := c.subscribe("unsubscribe", channels...) + defer c.mu.Unlock() + for _, channel := range channels { delete(c.channels, channel) } - c.mu.Unlock() + err := c.subscribe("unsubscribe", channels...) return err } @@ -185,11 +216,12 @@ func (c *PubSub) Unsubscribe(channels ...string) error { // them if none is given. func (c *PubSub) PUnsubscribe(patterns ...string) error { c.mu.Lock() - err := c.subscribe("punsubscribe", patterns...) + defer c.mu.Unlock() + for _, pattern := range patterns { delete(c.patterns, pattern) } - c.mu.Unlock() + err := c.subscribe("punsubscribe", patterns...) return err } @@ -199,8 +231,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { return err } - err = c._subscribe(cn, redisCmd, channels...) - c._releaseConn(cn, err) + err = c._subscribe(cn, redisCmd, channels) + c._releaseConn(cn, err, false) return err } @@ -216,9 +248,8 @@ func (c *PubSub) Ping(payload ...string) error { return err } - cn.SetWriteTimeout(c.opt.WriteTimeout) - err = writeCmd(cn, cmd) - c.releaseConn(cn, err) + err = c.writeCmd(cn, cmd) + c.releaseConn(cn, err, false) return err } @@ -297,8 +328,8 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { } // ReceiveTimeout acts like Receive but returns an error if message -// is not received in time. This is low-level API and most clients -// should use ReceiveMessage. +// is not received in time. This is low-level API and in most cases +// Channel should be used instead. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { if c.cmd == nil { c.cmd = NewCmd() @@ -309,9 +340,11 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return nil, err } - cn.SetReadTimeout(timeout) - err = c.cmd.readReply(cn) - c.releaseConn(cn, err) + err = cn.WithReader(timeout, func(rd *proto.Reader) error { + return c.cmd.readReply(rd) + }) + + c.releaseConn(cn, err, timeout > 0) if err != nil { return nil, err } @@ -320,49 +353,23 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { } // Receive returns a message as a Subscription, Message, Pong or error. -// See PubSub example for details. This is low-level API and most clients -// should use ReceiveMessage. +// See PubSub example for details. This is low-level API and in most cases +// Channel should be used instead. func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } -// ReceiveMessage returns a Message or error ignoring Subscription or Pong -// messages. It automatically reconnects to Redis Server and resubscribes -// to channels in case of network errors. +// ReceiveMessage returns a Message or error ignoring Subscription and Pong +// messages. This is low-level API and in most cases Channel should be used +// instead. func (c *PubSub) ReceiveMessage() (*Message, error) { - return c.receiveMessage(5 * time.Second) -} - -func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { - var errNum uint for { - msgi, err := c.ReceiveTimeout(timeout) + msg, err := c.Receive() if err != nil { - if !internal.IsNetworkError(err) { - return nil, err - } - - errNum++ - if errNum < 3 { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - err := c.Ping() - if err != nil { - internal.Logf("PubSub.Ping failed: %s", err) - } - } - } else { - // 3 consequent errors - connection is broken or - // Redis Server is down. - // Sleep to not exceed max number of open connections. - time.Sleep(time.Second) - } - continue + return nil, err } - // Reset error number, because we received a message. - errNum = 0 - - switch msg := msgi.(type) { + switch msg := msg.(type) { case *Subscription: // Ignore. case *Pong: @@ -370,30 +377,93 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { case *Message: return msg, nil default: - return nil, fmt.Errorf("redis: unknown message: %T", msgi) + err := fmt.Errorf("redis: unknown message: %T", msg) + return nil, err } } } // Channel returns a Go channel for concurrently receiving messages. -// The channel is closed with PubSub. Receive or ReceiveMessage APIs -// can not be used after channel is created. +// It periodically sends Ping messages to test connection health. +// The channel is closed with PubSub. Receive* APIs can not be used +// after channel is created. func (c *PubSub) Channel() <-chan *Message { - c.chOnce.Do(func() { - c.ch = make(chan *Message, 100) - go func() { - for { - msg, err := c.ReceiveMessage() - if err != nil { - if err == pool.ErrClosed { - break - } - continue + c.chOnce.Do(c.initChannel) + return c.ch +} + +func (c *PubSub) initChannel() { + c.ch = make(chan *Message, 100) + c.ping = make(chan struct{}, 10) + + go func() { + var errCount int + for { + msg, err := c.Receive() + if err != nil { + if err == pool.ErrClosed { + close(c.ch) + return } + if errCount > 0 { + time.Sleep(c.retryBackoff(errCount)) + } + errCount++ + continue + } + errCount = 0 + + // Any message is as good as a ping. + select { + case c.ping <- struct{}{}: + default: + } + + switch msg := msg.(type) { + case *Subscription: + // Ignore. + case *Pong: + // Ignore. + case *Message: c.ch <- msg + default: + internal.Logf("redis: unknown message: %T", msg) } - close(c.ch) - }() - }) - return c.ch + } + }() + + go func() { + const timeout = 5 * time.Second + + timer := time.NewTimer(timeout) + timer.Stop() + + healthy := true + var pingErr error + for { + timer.Reset(timeout) + select { + case <-c.ping: + healthy = true + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + pingErr = c.Ping() + if healthy { + healthy = false + } else { + c.mu.Lock() + c._reconnect(pingErr) + c.mu.Unlock() + } + case <-c.exit: + return + } + } + }() +} + +func (c *PubSub) retryBackoff(attempt int) time.Duration { + return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) } diff --git a/vendor/github.com/go-redis/redis/redis.go b/vendor/github.com/go-redis/redis/redis.go index beb632e1e..3e72bf060 100644 --- a/vendor/github.com/go-redis/redis/redis.go +++ b/vendor/github.com/go-redis/redis/redis.go @@ -50,7 +50,7 @@ func (c *baseClient) newConn() (*pool.Conn, error) { return nil, err } - if !cn.Inited { + if cn.InitedAt.IsZero() { if err := c.initConn(cn); err != nil { _ = c.connPool.CloseConn(cn) return nil, err @@ -66,7 +66,7 @@ func (c *baseClient) getConn() (*pool.Conn, error) { return nil, err } - if !cn.Inited { + if cn.InitedAt.IsZero() { err := c.initConn(cn) if err != nil { c.connPool.Remove(cn) @@ -88,7 +88,7 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) bool { } func (c *baseClient) initConn(cn *pool.Conn) error { - cn.Inited = true + cn.InitedAt = time.Now() if c.opt.Password == "" && c.opt.DB == 0 && @@ -123,8 +123,17 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return nil } +// Do creates a Cmd from the args and processes the cmd. +func (c *baseClient) Do(args ...interface{}) *Cmd { + cmd := NewCmd(args...) + c.Process(cmd) + return cmd +} + // WrapProcess wraps function that processes Redis commands. -func (c *baseClient) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error) { +func (c *baseClient) WrapProcess( + fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error, +) { c.process = fn(c.process) } @@ -147,8 +156,10 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := writeCmd(cn, cmd); err != nil { + err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmd) + }) + if err != nil { c.releaseConn(cn, err) cmd.setErr(err) if internal.IsRetryableError(err, true) { @@ -157,8 +168,9 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - cn.SetReadTimeout(c.cmdTimeout(cmd)) - err = cmd.readReply(cn) + err = cn.WithReader(c.cmdTimeout(cmd), func(rd *proto.Reader) error { + return cmd.readReply(rd) + }) c.releaseConn(cn, err) if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) { continue @@ -176,9 +188,8 @@ func (c *baseClient) retryBackoff(attempt int) time.Duration { func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { if timeout := cmd.readTimeout(); timeout != nil { - return *timeout + return readTimeout(*timeout) } - return c.opt.ReadTimeout } @@ -244,24 +255,27 @@ func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) e break } } - return firstCmdsErr(cmds) + return cmdsFirstErr(cmds) } func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := writeCmd(cn, cmds...); err != nil { + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmds...) + }) + if err != nil { setCmdsErr(cmds, err) return true, err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - return true, pipelineReadCmds(cn, cmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + return pipelineReadCmds(rd, cmds) + }) + return true, err } -func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) error { +func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { for _, cmd := range cmds { - err := cmd.readReply(cn) + err := cmd.readReply(rd) if err != nil && !internal.IsRedisError(err) { return err } @@ -270,47 +284,50 @@ func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) error { } func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := txPipelineWriteMulti(cn, cmds); err != nil { + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) + }) + if err != nil { setCmdsErr(cmds, err) return true, err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - if err := c.txPipelineReadQueued(cn, cmds); err != nil { - setCmdsErr(cmds, err) - return false, err - } - - return false, pipelineReadCmds(cn, cmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + err := txPipelineReadQueued(rd, cmds) + if err != nil { + setCmdsErr(cmds, err) + return err + } + return pipelineReadCmds(rd, cmds) + }) + return false, err } -func txPipelineWriteMulti(cn *pool.Conn, cmds []Cmder) error { +func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { multiExec := make([]Cmder, 0, len(cmds)+2) multiExec = append(multiExec, NewStatusCmd("MULTI")) multiExec = append(multiExec, cmds...) multiExec = append(multiExec, NewSliceCmd("EXEC")) - return writeCmd(cn, multiExec...) + return writeCmd(wr, multiExec...) } -func (c *baseClient) txPipelineReadQueued(cn *pool.Conn, cmds []Cmder) error { +func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { // Parse queued replies. var statusCmd StatusCmd - if err := statusCmd.readReply(cn); err != nil { + err := statusCmd.readReply(rd) + if err != nil { return err } for _ = range cmds { - err := statusCmd.readReply(cn) + err = statusCmd.readReply(rd) if err != nil && !internal.IsRedisError(err) { return err } } // Parse number of replies. - line, err := cn.Rd.ReadLine() + line, err := rd.ReadLine() if err != nil { if err == Nil { err = TxFailedErr @@ -424,7 +441,7 @@ func (c *Client) TxPipeline() Pipeliner { } func (c *Client) pubSub() *PubSub { - return &PubSub{ + pubsub := &PubSub{ opt: c.opt, newConn: func(channels []string) (*pool.Conn, error) { @@ -432,6 +449,8 @@ func (c *Client) pubSub() *PubSub { }, closeConn: c.connPool.CloseConn, } + pubsub.init() + return pubsub } // Subscribe subscribes the client to the specified channels. diff --git a/vendor/github.com/go-redis/redis/result.go b/vendor/github.com/go-redis/redis/result.go index e086e8e34..e438f260b 100644 --- a/vendor/github.com/go-redis/redis/result.go +++ b/vendor/github.com/go-redis/redis/result.go @@ -53,7 +53,7 @@ func NewBoolResult(val bool, err error) *BoolCmd { // NewStringResult returns a StringCmd initialised with val and err for testing func NewStringResult(val string, err error) *StringCmd { var cmd StringCmd - cmd.val = []byte(val) + cmd.val = val cmd.setErr(err) return &cmd } diff --git a/vendor/github.com/go-redis/redis/ring.go b/vendor/github.com/go-redis/redis/ring.go index b47a1094e..3ded28060 100644 --- a/vendor/github.com/go-redis/redis/ring.go +++ b/vendor/github.com/go-redis/redis/ring.go @@ -16,7 +16,8 @@ import ( "github.com/go-redis/redis/internal/pool" ) -const nreplicas = 100 +// Hash is type of hash function used in consistent hash. +type Hash consistenthash.Hash var errRingShardsDown = errors.New("redis: all ring shards are down") @@ -30,6 +31,27 @@ type RingOptions struct { // Shard is considered down after 3 subsequent failed checks. HeartbeatFrequency time.Duration + // Hash function used in consistent hash. + // Default is crc32.ChecksumIEEE. + Hash Hash + + // Number of replicas in consistent hash. + // Default is 100 replicas. + // + // Higher number of replicas will provide less deviation, that is keys will be + // distributed to nodes more evenly. + // + // Following is deviation for common nreplicas: + // -------------------------------------------------------- + // | nreplicas | standard error | 99% confidence interval | + // | 10 | 0.3152 | (0.37, 1.98) | + // | 100 | 0.0997 | (0.76, 1.28) | + // | 1000 | 0.0316 | (0.92, 1.09) | + // -------------------------------------------------------- + // + // See https://arxiv.org/abs/1406.2294 for reference + HashReplicas int + // Following options are copied from Options struct. OnConnect func(*Conn) error @@ -46,6 +68,8 @@ type RingOptions struct { WriteTimeout time.Duration PoolSize int + MinIdleConns int + MaxConnAge time.Duration PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration @@ -56,6 +80,10 @@ func (opt *RingOptions) init() { opt.HeartbeatFrequency = 500 * time.Millisecond } + if opt.HashReplicas == 0 { + opt.HashReplicas = 100 + } + switch opt.MinRetryBackoff { case -1: opt.MinRetryBackoff = 0 @@ -82,6 +110,8 @@ func (opt *RingOptions) clientOptions() *Options { WriteTimeout: opt.WriteTimeout, PoolSize: opt.PoolSize, + MinIdleConns: opt.MinIdleConns, + MaxConnAge: opt.MaxConnAge, PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, IdleCheckFrequency: opt.IdleCheckFrequency, @@ -133,16 +163,21 @@ func (shard *ringShard) Vote(up bool) bool { //------------------------------------------------------------------------------ type ringShards struct { + opt *RingOptions + mu sync.RWMutex hash *consistenthash.Map shards map[string]*ringShard // read only list []*ringShard // read only + len int closed bool } -func newRingShards() *ringShards { +func newRingShards(opt *RingOptions) *ringShards { return &ringShards{ - hash: consistenthash.New(nreplicas, nil), + opt: opt, + + hash: newConsistentHash(opt), shards: make(map[string]*ringShard), } } @@ -238,18 +273,28 @@ func (c *ringShards) Heartbeat(frequency time.Duration) { // rebalance removes dead shards from the Ring. func (c *ringShards) rebalance() { - hash := consistenthash.New(nreplicas, nil) + hash := newConsistentHash(c.opt) + var shardsNum int for name, shard := range c.shards { if shard.IsUp() { hash.Add(name) + shardsNum++ } } c.mu.Lock() c.hash = hash + c.len = shardsNum c.mu.Unlock() } +func (c *ringShards) Len() int { + c.mu.RLock() + l := c.len + c.mu.RUnlock() + return l +} + func (c *ringShards) Close() error { c.mu.Lock() defer c.mu.Unlock() @@ -305,7 +350,7 @@ func NewRing(opt *RingOptions) *Ring { ring := &Ring{ opt: opt, - shards: newRingShards(), + shards: newRingShards(opt), } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) @@ -363,11 +408,16 @@ func (c *Ring) PoolStats() *PoolStats { acc.Misses += s.Misses acc.Timeouts += s.Timeouts acc.TotalConns += s.TotalConns - acc.FreeConns += s.FreeConns + acc.IdleConns += s.IdleConns } return &acc } +// Len returns the current number of shards in the ring. +func (c *Ring) Len() int { + return c.shards.Len() +} + // Subscribe subscribes the client to the specified channels. func (c *Ring) Subscribe(channels ...string) *PubSub { if len(channels) == 0 { @@ -466,7 +516,16 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) { return c.shards.GetByKey(firstKey) } -func (c *Ring) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error) { +// Do creates a Cmd from the args and processes the cmd. +func (c *Ring) Do(args ...interface{}) *Cmd { + cmd := NewCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *Ring) WrapProcess( + fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error, +) { c.ForEachShard(func(c *Client) error { c.WrapProcess(fn) return nil @@ -552,7 +611,7 @@ func (c *Ring) defaultProcessPipeline(cmds []Cmder) error { cmdsMap = failedCmdsMap } - return firstCmdsErr(cmds) + return cmdsFirstErr(cmds) } func (c *Ring) TxPipeline() Pipeliner { @@ -570,3 +629,7 @@ func (c *Ring) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Ring) Close() error { return c.shards.Close() } + +func newConsistentHash(opt *RingOptions) *consistenthash.Map { + return consistenthash.New(opt.HashReplicas, consistenthash.Hash(opt.Hash)) +} diff --git a/vendor/github.com/go-redis/redis/sentinel.go b/vendor/github.com/go-redis/redis/sentinel.go index 3cedf36ee..c5f71493d 100644 --- a/vendor/github.com/go-redis/redis/sentinel.go +++ b/vendor/github.com/go-redis/redis/sentinel.go @@ -29,13 +29,17 @@ type FailoverOptions struct { Password string DB int - MaxRetries int + MaxRetries int + MinRetryBackoff time.Duration + MaxRetryBackoff time.Duration DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration PoolSize int + MinIdleConns int + MaxConnAge time.Duration PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration @@ -92,7 +96,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { }, } c.baseClient.init() - c.setProcessor(c.Process) + c.cmdable.setProcessor(c.Process) return &c } @@ -116,7 +120,7 @@ func NewSentinelClient(opt *Options) *SentinelClient { } func (c *SentinelClient) PubSub() *PubSub { - return &PubSub{ + pubsub := &PubSub{ opt: c.opt, newConn: func(channels []string) (*pool.Conn, error) { @@ -124,6 +128,8 @@ func (c *SentinelClient) PubSub() *PubSub { }, closeConn: c.connPool.CloseConn, } + pubsub.init() + return pubsub } func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { @@ -180,10 +186,7 @@ func (d *sentinelFailover) MasterAddr() (string, error) { if err != nil { return "", err } - - if d._masterAddr != addr { - d.switchMaster(addr) - } + d._switchMaster(addr) return addr, nil } @@ -194,11 +197,11 @@ func (d *sentinelFailover) masterAddr() (string, error) { addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result() if err == nil { addr := net.JoinHostPort(addr[0], addr[1]) - internal.Logf("sentinel: master=%q addr=%q", d.masterName, addr) return addr, nil } - internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", d.masterName, err) + internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", + d.masterName, err) d._resetSentinel() } @@ -234,15 +237,23 @@ func (d *sentinelFailover) masterAddr() (string, error) { return "", errors.New("redis: all sentinels are unreachable") } -func (d *sentinelFailover) switchMaster(masterAddr string) { - internal.Logf( - "sentinel: new master=%q addr=%q", - d.masterName, masterAddr, - ) - _ = d.Pool().Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != masterAddr +func (c *sentinelFailover) switchMaster(addr string) { + c.mu.Lock() + c._switchMaster(addr) + c.mu.Unlock() +} + +func (c *sentinelFailover) _switchMaster(addr string) { + if c._masterAddr == addr { + return + } + + internal.Logf("sentinel: new master=%q addr=%q", + c.masterName, addr) + _ = c.Pool().Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr }) - d._masterAddr = masterAddr + c._masterAddr = addr } func (d *sentinelFailover) setSentinel(sentinel *SentinelClient) { @@ -292,27 +303,25 @@ func (d *sentinelFailover) discoverSentinels(sentinel *SentinelClient) { } func (d *sentinelFailover) listen(sentinel *SentinelClient) { - var pubsub *PubSub - for { - if pubsub == nil { - pubsub = sentinel.PubSub() + pubsub := sentinel.PubSub() + defer pubsub.Close() - if err := pubsub.Subscribe("+switch-master"); err != nil { - internal.Logf("sentinel: Subscribe failed: %s", err) - pubsub.Close() - d.resetSentinel() - return - } - } + err := pubsub.Subscribe("+switch-master") + if err != nil { + internal.Logf("sentinel: Subscribe failed: %s", err) + d.resetSentinel() + return + } + for { msg, err := pubsub.ReceiveMessage() if err != nil { - if err != pool.ErrClosed { - internal.Logf("sentinel: ReceiveMessage failed: %s", err) - pubsub.Close() + if err == pool.ErrClosed { + d.resetSentinel() + return } - d.resetSentinel() - return + internal.Logf("sentinel: ReceiveMessage failed: %s", err) + continue } switch msg.Channel { @@ -323,12 +332,7 @@ func (d *sentinelFailover) listen(sentinel *SentinelClient) { continue } addr := net.JoinHostPort(parts[3], parts[4]) - - d.mu.Lock() - if d._masterAddr != addr { - d.switchMaster(addr) - } - d.mu.Unlock() + d.switchMaster(addr) } } } diff --git a/vendor/github.com/go-redis/redis/tx.go b/vendor/github.com/go-redis/redis/tx.go index 6a753b6a0..6a7da99dd 100644 --- a/vendor/github.com/go-redis/redis/tx.go +++ b/vendor/github.com/go-redis/redis/tx.go @@ -29,6 +29,10 @@ func (c *Client) newTx() *Tx { return &tx } +// Watch prepares a transcaction and marks the keys to be watched +// for conditional execution if there are any keys. +// +// The transaction is automatically closed when the fn exits. func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { tx := c.newTx() if len(keys) > 0 { @@ -74,6 +78,7 @@ func (c *Tx) Unwatch(keys ...string) *StatusCmd { return cmd } +// Pipeline creates a new pipeline. It is more convenient to use Pipelined. func (c *Tx) Pipeline() Pipeliner { pipe := Pipeline{ exec: c.processTxPipeline, @@ -82,23 +87,24 @@ func (c *Tx) Pipeline() Pipeliner { return &pipe } -// Pipelined executes commands queued in the fn in a transaction -// and restores the connection state to normal. +// Pipelined executes commands queued in the fn in a transaction. // // When using WATCH, EXEC will execute commands only if the watched keys // were not modified, allowing for a check-and-set mechanism. // // Exec always returns list of commands. If transaction fails -// TxFailedErr is returned. Otherwise Exec returns error of the first +// TxFailedErr is returned. Otherwise Exec returns an error of the first // failed command or nil. func (c *Tx) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(fn) } +// TxPipelined is an alias for Pipelined. func (c *Tx) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipelined(fn) } +// TxPipeline is an alias for Pipeline. func (c *Tx) TxPipeline() Pipeliner { return c.Pipeline() } diff --git a/vendor/github.com/go-redis/redis/universal.go b/vendor/github.com/go-redis/redis/universal.go index 9e30c81d9..a60756246 100644 --- a/vendor/github.com/go-redis/redis/universal.go +++ b/vendor/github.com/go-redis/redis/universal.go @@ -12,35 +12,38 @@ type UniversalOptions struct { // of cluster/sentinel nodes. Addrs []string - // The sentinel master name. - // Only failover clients. - MasterName string - // Database to be selected after connecting to the server. // Only single-node and failover clients. DB int - // Only cluster clients. - - // Enables read only queries on slave nodes. - ReadOnly bool - - MaxRedirects int - RouteByLatency bool - - // Common options + // Common options. OnConnect func(*Conn) error - MaxRetries int Password string + MaxRetries int + MinRetryBackoff time.Duration + MaxRetryBackoff time.Duration DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration PoolSize int + MinIdleConns int + MaxConnAge time.Duration PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration TLSConfig *tls.Config + + // Only cluster clients. + + MaxRedirects int + ReadOnly bool + RouteByLatency bool + RouteRandomly bool + + // The sentinel master name. + // Only failover clients. + MasterName string } func (o *UniversalOptions) cluster() *ClusterOptions { @@ -49,22 +52,31 @@ func (o *UniversalOptions) cluster() *ClusterOptions { } return &ClusterOptions{ - Addrs: o.Addrs, + Addrs: o.Addrs, + OnConnect: o.OnConnect, + + Password: o.Password, + MaxRedirects: o.MaxRedirects, - RouteByLatency: o.RouteByLatency, ReadOnly: o.ReadOnly, + RouteByLatency: o.RouteByLatency, + RouteRandomly: o.RouteRandomly, + + MaxRetries: o.MaxRetries, + MinRetryBackoff: o.MinRetryBackoff, + MaxRetryBackoff: o.MaxRetryBackoff, - OnConnect: o.OnConnect, - MaxRetries: o.MaxRetries, - Password: o.Password, DialTimeout: o.DialTimeout, ReadTimeout: o.ReadTimeout, WriteTimeout: o.WriteTimeout, PoolSize: o.PoolSize, + MinIdleConns: o.MinIdleConns, + MaxConnAge: o.MaxConnAge, PoolTimeout: o.PoolTimeout, IdleTimeout: o.IdleTimeout, IdleCheckFrequency: o.IdleCheckFrequency, - TLSConfig: o.TLSConfig, + + TLSConfig: o.TLSConfig, } } @@ -76,19 +88,27 @@ func (o *UniversalOptions) failover() *FailoverOptions { return &FailoverOptions{ SentinelAddrs: o.Addrs, MasterName: o.MasterName, - DB: o.DB, + OnConnect: o.OnConnect, + + DB: o.DB, + Password: o.Password, + + MaxRetries: o.MaxRetries, + MinRetryBackoff: o.MinRetryBackoff, + MaxRetryBackoff: o.MaxRetryBackoff, + + DialTimeout: o.DialTimeout, + ReadTimeout: o.ReadTimeout, + WriteTimeout: o.WriteTimeout, - OnConnect: o.OnConnect, - MaxRetries: o.MaxRetries, - Password: o.Password, - DialTimeout: o.DialTimeout, - ReadTimeout: o.ReadTimeout, - WriteTimeout: o.WriteTimeout, PoolSize: o.PoolSize, + MinIdleConns: o.MinIdleConns, + MaxConnAge: o.MaxConnAge, PoolTimeout: o.PoolTimeout, IdleTimeout: o.IdleTimeout, IdleCheckFrequency: o.IdleCheckFrequency, - TLSConfig: o.TLSConfig, + + TLSConfig: o.TLSConfig, } } @@ -99,20 +119,28 @@ func (o *UniversalOptions) simple() *Options { } return &Options{ - Addr: addr, - DB: o.DB, + Addr: addr, + OnConnect: o.OnConnect, + + DB: o.DB, + Password: o.Password, + + MaxRetries: o.MaxRetries, + MinRetryBackoff: o.MinRetryBackoff, + MaxRetryBackoff: o.MaxRetryBackoff, + + DialTimeout: o.DialTimeout, + ReadTimeout: o.ReadTimeout, + WriteTimeout: o.WriteTimeout, - OnConnect: o.OnConnect, - MaxRetries: o.MaxRetries, - Password: o.Password, - DialTimeout: o.DialTimeout, - ReadTimeout: o.ReadTimeout, - WriteTimeout: o.WriteTimeout, PoolSize: o.PoolSize, + MinIdleConns: o.MinIdleConns, + MaxConnAge: o.MaxConnAge, PoolTimeout: o.PoolTimeout, IdleTimeout: o.IdleTimeout, IdleCheckFrequency: o.IdleCheckFrequency, - TLSConfig: o.TLSConfig, + + TLSConfig: o.TLSConfig, } } |