diff options
Diffstat (limited to 'vendor/github.com/go-redis/redis/cluster.go')
-rw-r--r-- | vendor/github.com/go-redis/redis/cluster.go | 244 |
1 files changed, 187 insertions, 57 deletions
diff --git a/vendor/github.com/go-redis/redis/cluster.go b/vendor/github.com/go-redis/redis/cluster.go index 4a2951157..0c58c8532 100644 --- a/vendor/github.com/go-redis/redis/cluster.go +++ b/vendor/github.com/go-redis/redis/cluster.go @@ -2,11 +2,13 @@ package redis import ( "context" + "crypto/tls" "errors" "fmt" "math" "math/rand" "net" + "strings" "sync" "sync/atomic" "time" @@ -34,6 +36,7 @@ type ClusterOptions struct { // Enables read-only commands on slave nodes. ReadOnly bool // Allows routing read-only commands to the closest master or slave node. + // It automatically enables ReadOnly. RouteByLatency bool // Allows routing read-only commands to the random master or slave node. RouteRandomly bool @@ -56,6 +59,8 @@ type ClusterOptions struct { PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration + + TLSConfig *tls.Config } func (opt *ClusterOptions) init() { @@ -117,6 +122,8 @@ func (opt *ClusterOptions) clientOptions() *Options { IdleTimeout: opt.IdleTimeout, IdleCheckFrequency: disableIdleCheck, + + TLSConfig: opt.TLSConfig, } } @@ -145,6 +152,10 @@ func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { return &node } +func (n *clusterNode) String() string { + return n.Client.String() +} + func (n *clusterNode) Close() error { return n.Client.Close() } @@ -215,7 +226,7 @@ type clusterNodes struct { nodeCreateGroup singleflight.Group - generation uint32 + _generation uint32 // atomic } func newClusterNodes(opt *ClusterOptions) *clusterNodes { @@ -272,8 +283,7 @@ func (c *clusterNodes) Addrs() ([]string, error) { } func (c *clusterNodes) NextGeneration() uint32 { - c.generation++ - return c.generation + return atomic.AddUint32(&c._generation, 1) } // GC removes unused nodes. @@ -296,10 +306,9 @@ func (c *clusterNodes) GC(generation uint32) { } } -func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { +func (c *clusterNodes) Get(addr string) (*clusterNode, error) { var node *clusterNode var err error - c.mu.RLock() if c.closed { err = pool.ErrClosed @@ -307,6 +316,11 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { node = c.allNodes[addr] } c.mu.RUnlock() + return node, err +} + +func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { + node, err := c.Get(addr) if err != nil { return nil, err } @@ -371,20 +385,25 @@ func (c *clusterNodes) Random() (*clusterNode, error) { type clusterState struct { nodes *clusterNodes - masters []*clusterNode - slaves []*clusterNode + Masters []*clusterNode + Slaves []*clusterNode slots [][]*clusterNode generation uint32 + createdAt time.Time } -func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*clusterState, error) { +func newClusterState( + nodes *clusterNodes, slots []ClusterSlot, origin string, +) (*clusterState, error) { c := clusterState{ - nodes: nodes, - generation: nodes.NextGeneration(), + nodes: nodes, slots: make([][]*clusterNode, hashtag.SlotNumber), + + generation: nodes.NextGeneration(), + createdAt: time.Now(), } isLoopbackOrigin := isLoopbackAddr(origin) @@ -392,7 +411,7 @@ func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (* var nodes []*clusterNode for i, slotNode := range slot.Nodes { addr := slotNode.Addr - if !isLoopbackOrigin && isLoopbackAddr(addr) { + if !isLoopbackOrigin && useOriginAddr(origin, addr) { addr = origin } @@ -405,9 +424,9 @@ func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (* nodes = append(nodes, node) if i == 0 { - c.masters = appendNode(c.masters, node) + c.Masters = appendUniqueNode(c.Masters, node) } else { - c.slaves = appendNode(c.slaves, node) + c.Slaves = appendUniqueNode(c.Slaves, node) } } @@ -489,6 +508,28 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode { return nil } +func (c *clusterState) IsConsistent() bool { + if len(c.Masters) > len(c.Slaves) { + return false + } + + for _, master := range c.Masters { + s := master.Client.Info("replication").Val() + if !strings.Contains(s, "role:master") { + return false + } + } + + for _, slave := range c.Slaves { + s := slave.Client.Info("replication").Val() + if !strings.Contains(s, "role:slave") { + return false + } + } + + return true +} + //------------------------------------------------------------------------------ type clusterStateHolder struct { @@ -496,8 +537,8 @@ type clusterStateHolder struct { state atomic.Value - lastErrMu sync.RWMutex - lastErr error + firstErrMu sync.RWMutex + firstErr error reloading uint32 // atomic } @@ -508,12 +549,25 @@ func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder } } -func (c *clusterStateHolder) Load() (*clusterState, error) { +func (c *clusterStateHolder) Reload() (*clusterState, error) { + state, err := c.reload() + if err != nil { + return nil, err + } + if !state.IsConsistent() { + c.LazyReload() + } + return state, nil +} + +func (c *clusterStateHolder) reload() (*clusterState, error) { state, err := c.load() if err != nil { - c.lastErrMu.Lock() - c.lastErr = err - c.lastErrMu.Unlock() + c.firstErrMu.Lock() + if c.firstErr == nil { + c.firstErr = err + } + c.firstErrMu.Unlock() return nil, err } c.state.Store(state) @@ -527,9 +581,15 @@ func (c *clusterStateHolder) LazyReload() { go func() { defer atomic.StoreUint32(&c.reloading, 0) - _, err := c.Load() - if err == nil { - time.Sleep(time.Second) + for { + state, err := c.reload() + if err != nil { + return + } + time.Sleep(100 * time.Millisecond) + if state.IsConsistent() { + return + } } }() } @@ -537,12 +597,16 @@ func (c *clusterStateHolder) LazyReload() { func (c *clusterStateHolder) Get() (*clusterState, error) { v := c.state.Load() if v != nil { - return v.(*clusterState), nil + state := v.(*clusterState) + if time.Since(state.createdAt) > time.Minute { + c.LazyReload() + } + return state, nil } - c.lastErrMu.RLock() - err := c.lastErr - c.lastErrMu.RUnlock() + c.firstErrMu.RLock() + err := c.firstErr + c.firstErrMu.RUnlock() if err != nil { return nil, err } @@ -576,19 +640,19 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { opt.init() c := &ClusterClient{ - opt: opt, - nodes: newClusterNodes(opt), - cmdsInfoCache: newCmdsInfoCache(), + opt: opt, + nodes: newClusterNodes(opt), } c.state = newClusterStateHolder(c.loadState) + c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.process = c.defaultProcess c.processPipeline = c.defaultProcessPipeline c.processTxPipeline = c.defaultProcessTxPipeline - c.cmdable.setProcessor(c.Process) + c.init() - _, _ = c.state.Load() + _, _ = c.state.Reload() if opt.IdleCheckFrequency > 0 { go c.reaper(opt.IdleCheckFrequency) } @@ -596,6 +660,10 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { return c } +func (c *ClusterClient) init() { + c.cmdable.setProcessor(c.Process) +} + func (c *ClusterClient) Context() context.Context { if c.ctx != nil { return c.ctx @@ -614,6 +682,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient { func (c *ClusterClient) copy() *ClusterClient { cp := *c + cp.init() return &cp } @@ -626,17 +695,39 @@ func (c *ClusterClient) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) } -func (c *ClusterClient) cmdInfo(name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Do(func() (map[string]*CommandInfo, error) { - node, err := c.nodes.Random() +func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { + addrs, err := c.nodes.Addrs() + if err != nil { + return nil, err + } + + var firstErr error + for _, addr := range addrs { + node, err := c.nodes.Get(addr) if err != nil { return nil, err } - return node.Client.Command().Result() - }) + if node == nil { + continue + } + + info, err := node.Client.Command().Result() + if err == nil { + return info, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, firstErr +} + +func (c *ClusterClient) cmdInfo(name string) *CommandInfo { + cmdsInfo, err := c.cmdsInfoCache.Get() if err != nil { return nil } + info := cmdsInfo[name] if info == nil { internal.Logf("info for cmd=%s not found", name) @@ -700,13 +791,14 @@ func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) { func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { - return fmt.Errorf("redis: keys don't hash to the same slot") + return fmt.Errorf("redis: Watch requires at least one key") } slot := hashtag.Slot(keys[0]) for _, key := range keys[1:] { if hashtag.Slot(key) != slot { - return fmt.Errorf("redis: Watch requires all keys to be in the same slot") + err := fmt.Errorf("redis: Watch requires all keys to be in the same slot") + return err } } @@ -812,6 +904,12 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error { } if internal.IsRetryableError(err, true) { + // Firstly retry the same node. + if attempt == 0 { + continue + } + + // Secondly try random node. node, err = c.nodes.Random() if err != nil { break @@ -846,14 +944,17 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error { // ForEachMaster concurrently calls the fn on each master node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { - state, err := c.state.Get() + state, err := c.state.Reload() if err != nil { - return err + state, err = c.state.Get() + if err != nil { + return err + } } var wg sync.WaitGroup errCh := make(chan error, 1) - for _, master := range state.masters { + for _, master := range state.Masters { wg.Add(1) go func(node *clusterNode) { defer wg.Done() @@ -879,14 +980,17 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { // ForEachSlave concurrently calls the fn on each slave node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { - state, err := c.state.Get() + state, err := c.state.Reload() if err != nil { - return err + state, err = c.state.Get() + if err != nil { + return err + } } var wg sync.WaitGroup errCh := make(chan error, 1) - for _, slave := range state.slaves { + for _, slave := range state.Slaves { wg.Add(1) go func(node *clusterNode) { defer wg.Done() @@ -912,9 +1016,12 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { // ForEachNode concurrently calls the fn on each known node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { - state, err := c.state.Get() + state, err := c.state.Reload() if err != nil { - return err + state, err = c.state.Get() + if err != nil { + return err + } } var wg sync.WaitGroup @@ -930,11 +1037,11 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { } } - for _, node := range state.masters { + for _, node := range state.Masters { wg.Add(1) go worker(node) } - for _, node := range state.slaves { + for _, node := range state.Slaves { wg.Add(1) go worker(node) } @@ -957,7 +1064,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { return &acc } - for _, node := range state.masters { + for _, node := range state.Masters { s := node.Client.connPool.Stats() acc.Hits += s.Hits acc.Misses += s.Misses @@ -968,7 +1075,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.StaleConns += s.StaleConns } - for _, node := range state.slaves { + for _, node := range state.Slaves { s := node.Client.connPool.Stats() acc.Hits += s.Hits acc.Misses += s.Misses @@ -1065,7 +1172,7 @@ func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { - cn, _, err := node.Client.getConn() + cn, err := node.Client.getConn() if err != nil { if err == pool.ErrClosed { c.remapCmds(cmds, failedCmds) @@ -1077,9 +1184,9 @@ func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error { err = c.pipelineProcessCmds(node, cn, cmds, failedCmds) if err == nil || internal.IsRedisError(err) { - _ = node.Client.connPool.Put(cn) + node.Client.connPool.Put(cn) } else { - _ = node.Client.connPool.Remove(cn) + node.Client.connPool.Remove(cn) } } @@ -1229,7 +1336,7 @@ func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { - cn, _, err := node.Client.getConn() + cn, err := node.Client.getConn() if err != nil { if err == pool.ErrClosed { c.remapCmds(cmds, failedCmds) @@ -1241,9 +1348,9 @@ func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error { err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) if err == nil || internal.IsRedisError(err) { - _ = node.Client.connPool.Put(cn) + node.Client.connPool.Put(cn) } else { - _ = node.Client.connPool.Remove(cn) + node.Client.connPool.Remove(cn) } } @@ -1387,6 +1494,29 @@ func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { return pubsub } +func useOriginAddr(originAddr, nodeAddr string) bool { + nodeHost, nodePort, err := net.SplitHostPort(nodeAddr) + if err != nil { + return false + } + + nodeIP := net.ParseIP(nodeHost) + if nodeIP == nil { + return false + } + + if !nodeIP.IsLoopback() { + return false + } + + _, originPort, err := net.SplitHostPort(originAddr) + if err != nil { + return false + } + + return nodePort == originPort +} + func isLoopbackAddr(addr string) bool { host, _, err := net.SplitHostPort(addr) if err != nil { @@ -1401,7 +1531,7 @@ func isLoopbackAddr(addr string) bool { return ip.IsLoopback() } -func appendNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { +func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { for _, n := range nodes { if n == node { return nodes |