Diffstat (limited to 'vendor/github.com/go-redis/redis/cluster.go')
 vendor/github.com/go-redis/redis/cluster.go | 244
 1 file changed, 187 insertions(+), 57 deletions(-)
diff --git a/vendor/github.com/go-redis/redis/cluster.go b/vendor/github.com/go-redis/redis/cluster.go
index 4a2951157..0c58c8532 100644
--- a/vendor/github.com/go-redis/redis/cluster.go
+++ b/vendor/github.com/go-redis/redis/cluster.go
@@ -2,11 +2,13 @@ package redis
import (
"context"
+ "crypto/tls"
"errors"
"fmt"
"math"
"math/rand"
"net"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -34,6 +36,7 @@ type ClusterOptions struct {
// Enables read-only commands on slave nodes.
ReadOnly bool
// Allows routing read-only commands to the closest master or slave node.
+ // It automatically enables ReadOnly.
RouteByLatency bool
// Allows routing read-only commands to the random master or slave node.
RouteRandomly bool
@@ -56,6 +59,8 @@ type ClusterOptions struct {
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
+
+ TLSConfig *tls.Config
}
func (opt *ClusterOptions) init() {
@@ -117,6 +122,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: disableIdleCheck,
+
+ TLSConfig: opt.TLSConfig,
}
}
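
Note: the new TLSConfig field added above is simply forwarded from ClusterOptions into each node's client Options, so every node connection uses TLS. A minimal usage sketch under that assumption; the addresses, ServerName, and the surrounding program are illustrative and not part of this change:

package main

import (
	"crypto/tls"
	"fmt"

	"github.com/go-redis/redis"
)

func main() {
	// TLSConfig set on ClusterOptions is passed down to each node's Options
	// (see the hunk above). Addresses and ServerName are illustrative.
	client := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{"redis-node-1:6380", "redis-node-2:6380"},
		TLSConfig: &tls.Config{
			ServerName: "redis-node-1",
		},
	})
	defer client.Close()

	fmt.Println(client.Ping().Err())
}
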
@@ -145,6 +152,10 @@ func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode {
return &node
}
+func (n *clusterNode) String() string {
+ return n.Client.String()
+}
+
func (n *clusterNode) Close() error {
return n.Client.Close()
}
@@ -215,7 +226,7 @@ type clusterNodes struct {
nodeCreateGroup singleflight.Group
- generation uint32
+ _generation uint32 // atomic
}
func newClusterNodes(opt *ClusterOptions) *clusterNodes {
@@ -272,8 +283,7 @@ func (c *clusterNodes) Addrs() ([]string, error) {
}
func (c *clusterNodes) NextGeneration() uint32 {
- c.generation++
- return c.generation
+ return atomic.AddUint32(&c._generation, 1)
}
// GC removes unused nodes.
@@ -296,10 +306,9 @@ func (c *clusterNodes) GC(generation uint32) {
}
}
-func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
+func (c *clusterNodes) Get(addr string) (*clusterNode, error) {
var node *clusterNode
var err error
-
c.mu.RLock()
if c.closed {
err = pool.ErrClosed
@@ -307,6 +316,11 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
node = c.allNodes[addr]
}
c.mu.RUnlock()
+ return node, err
+}
+
+func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
+ node, err := c.Get(addr)
if err != nil {
return nil, err
}
@@ -371,20 +385,25 @@ func (c *clusterNodes) Random() (*clusterNode, error) {
type clusterState struct {
nodes *clusterNodes
- masters []*clusterNode
- slaves []*clusterNode
+ Masters []*clusterNode
+ Slaves []*clusterNode
slots [][]*clusterNode
generation uint32
+ createdAt time.Time
}
-func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*clusterState, error) {
+func newClusterState(
+ nodes *clusterNodes, slots []ClusterSlot, origin string,
+) (*clusterState, error) {
c := clusterState{
- nodes: nodes,
- generation: nodes.NextGeneration(),
+ nodes: nodes,
slots: make([][]*clusterNode, hashtag.SlotNumber),
+
+ generation: nodes.NextGeneration(),
+ createdAt: time.Now(),
}
isLoopbackOrigin := isLoopbackAddr(origin)
@@ -392,7 +411,7 @@ func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*
var nodes []*clusterNode
for i, slotNode := range slot.Nodes {
addr := slotNode.Addr
- if !isLoopbackOrigin && isLoopbackAddr(addr) {
+ if !isLoopbackOrigin && useOriginAddr(origin, addr) {
addr = origin
}
@@ -405,9 +424,9 @@ func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*
nodes = append(nodes, node)
if i == 0 {
- c.masters = appendNode(c.masters, node)
+ c.Masters = appendUniqueNode(c.Masters, node)
} else {
- c.slaves = appendNode(c.slaves, node)
+ c.Slaves = appendUniqueNode(c.Slaves, node)
}
}
@@ -489,6 +508,28 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
return nil
}
+func (c *clusterState) IsConsistent() bool {
+ if len(c.Masters) > len(c.Slaves) {
+ return false
+ }
+
+ for _, master := range c.Masters {
+ s := master.Client.Info("replication").Val()
+ if !strings.Contains(s, "role:master") {
+ return false
+ }
+ }
+
+ for _, slave := range c.Slaves {
+ s := slave.Client.Info("replication").Val()
+ if !strings.Contains(s, "role:slave") {
+ return false
+ }
+ }
+
+ return true
+}
+
//------------------------------------------------------------------------------
type clusterStateHolder struct {
@@ -496,8 +537,8 @@ type clusterStateHolder struct {
state atomic.Value
- lastErrMu sync.RWMutex
- lastErr error
+ firstErrMu sync.RWMutex
+ firstErr error
reloading uint32 // atomic
}
@@ -508,12 +549,25 @@ func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder
}
}
-func (c *clusterStateHolder) Load() (*clusterState, error) {
+func (c *clusterStateHolder) Reload() (*clusterState, error) {
+ state, err := c.reload()
+ if err != nil {
+ return nil, err
+ }
+ if !state.IsConsistent() {
+ c.LazyReload()
+ }
+ return state, nil
+}
+
+func (c *clusterStateHolder) reload() (*clusterState, error) {
state, err := c.load()
if err != nil {
- c.lastErrMu.Lock()
- c.lastErr = err
- c.lastErrMu.Unlock()
+ c.firstErrMu.Lock()
+ if c.firstErr == nil {
+ c.firstErr = err
+ }
+ c.firstErrMu.Unlock()
return nil, err
}
c.state.Store(state)
@@ -527,9 +581,15 @@ func (c *clusterStateHolder) LazyReload() {
go func() {
defer atomic.StoreUint32(&c.reloading, 0)
- _, err := c.Load()
- if err == nil {
- time.Sleep(time.Second)
+ for {
+ state, err := c.reload()
+ if err != nil {
+ return
+ }
+ time.Sleep(100 * time.Millisecond)
+ if state.IsConsistent() {
+ return
+ }
}
}()
}
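
Note: LazyReload lets only one background refresh run at a time: an atomic flag is claimed before the goroutine starts, cleared by the defer shown above, and the loop keeps reloading until the state looks consistent. A minimal sketch of that guard pattern, using assumed names (holder, lazyReload) rather than the library's actual types:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type holder struct {
	reloading uint32 // atomic flag: 1 while a reload goroutine is running
}

func (h *holder) lazyReload(reload func() (consistent bool, err error)) {
	// Only the caller that swaps the flag from 0 to 1 starts a goroutine;
	// everyone else returns immediately.
	if !atomic.CompareAndSwapUint32(&h.reloading, 0, 1) {
		return
	}
	go func() {
		defer atomic.StoreUint32(&h.reloading, 0)
		for {
			ok, err := reload()
			if err != nil {
				return
			}
			time.Sleep(100 * time.Millisecond)
			if ok {
				return // stop retrying once the reloaded state looks consistent
			}
		}
	}()
}

func main() {
	var h holder
	h.lazyReload(func() (bool, error) {
		fmt.Println("reloading cluster state")
		return true, nil
	})
	time.Sleep(200 * time.Millisecond) // let the background goroutine finish
}
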
@@ -537,12 +597,16 @@ func (c *clusterStateHolder) LazyReload() {
func (c *clusterStateHolder) Get() (*clusterState, error) {
v := c.state.Load()
if v != nil {
- return v.(*clusterState), nil
+ state := v.(*clusterState)
+ if time.Since(state.createdAt) > time.Minute {
+ c.LazyReload()
+ }
+ return state, nil
}
- c.lastErrMu.RLock()
- err := c.lastErr
- c.lastErrMu.RUnlock()
+ c.firstErrMu.RLock()
+ err := c.firstErr
+ c.firstErrMu.RUnlock()
if err != nil {
return nil, err
}
@@ -576,19 +640,19 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt.init()
c := &ClusterClient{
- opt: opt,
- nodes: newClusterNodes(opt),
- cmdsInfoCache: newCmdsInfoCache(),
+ opt: opt,
+ nodes: newClusterNodes(opt),
}
c.state = newClusterStateHolder(c.loadState)
+ c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
c.process = c.defaultProcess
c.processPipeline = c.defaultProcessPipeline
c.processTxPipeline = c.defaultProcessTxPipeline
- c.cmdable.setProcessor(c.Process)
+ c.init()
- _, _ = c.state.Load()
+ _, _ = c.state.Reload()
if opt.IdleCheckFrequency > 0 {
go c.reaper(opt.IdleCheckFrequency)
}
@@ -596,6 +660,10 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
return c
}
+func (c *ClusterClient) init() {
+ c.cmdable.setProcessor(c.Process)
+}
+
func (c *ClusterClient) Context() context.Context {
if c.ctx != nil {
return c.ctx
@@ -614,6 +682,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
func (c *ClusterClient) copy() *ClusterClient {
cp := *c
+ cp.init()
return &cp
}
@@ -626,17 +695,39 @@ func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
-func (c *ClusterClient) cmdInfo(name string) *CommandInfo {
- cmdsInfo, err := c.cmdsInfoCache.Do(func() (map[string]*CommandInfo, error) {
- node, err := c.nodes.Random()
+func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
+ addrs, err := c.nodes.Addrs()
+ if err != nil {
+ return nil, err
+ }
+
+ var firstErr error
+ for _, addr := range addrs {
+ node, err := c.nodes.Get(addr)
if err != nil {
return nil, err
}
- return node.Client.Command().Result()
- })
+ if node == nil {
+ continue
+ }
+
+ info, err := node.Client.Command().Result()
+ if err == nil {
+ return info, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ return nil, firstErr
+}
+
+func (c *ClusterClient) cmdInfo(name string) *CommandInfo {
+ cmdsInfo, err := c.cmdsInfoCache.Get()
if err != nil {
return nil
}
+
info := cmdsInfo[name]
if info == nil {
internal.Logf("info for cmd=%s not found", name)
@@ -700,13 +791,14 @@ func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
- return fmt.Errorf("redis: keys don't hash to the same slot")
+ return fmt.Errorf("redis: Watch requires at least one key")
}
slot := hashtag.Slot(keys[0])
for _, key := range keys[1:] {
if hashtag.Slot(key) != slot {
- return fmt.Errorf("redis: Watch requires all keys to be in the same slot")
+ err := fmt.Errorf("redis: Watch requires all keys to be in the same slot")
+ return err
}
}
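
Note: because Watch on a cluster requires every key to hash to the same slot, callers typically group keys with a hash tag. A hedged usage sketch based on the standard go-redis transaction pattern; the key names and addresses are illustrative:

package main

import (
	"fmt"

	"github.com/go-redis/redis"
)

func main() {
	client := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{":7000", ":7001", ":7002"}, // illustrative
	})

	// Both keys carry the hash tag {user:1}, so they hash to the same slot
	// and satisfy the same-slot requirement enforced above.
	const visits = "{user:1}:visits"
	err := client.Watch(func(tx *redis.Tx) error {
		n, err := tx.Get(visits).Int64()
		if err != nil && err != redis.Nil {
			return err
		}
		_, err = tx.Pipelined(func(pipe redis.Pipeliner) error {
			pipe.Set(visits, n+1, 0)
			return nil
		})
		return err
	}, visits, "{user:1}:profile")
	fmt.Println("watch err:", err)
}
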
@@ -812,6 +904,12 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
}
if internal.IsRetryableError(err, true) {
+ // Firstly retry the same node.
+ if attempt == 0 {
+ continue
+ }
+
+ // Secondly try random node.
node, err = c.nodes.Random()
if err != nil {
break
@@ -846,14 +944,17 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
// ForEachMaster concurrently calls the fn on each master node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
- state, err := c.state.Get()
+ state, err := c.state.Reload()
if err != nil {
- return err
+ state, err = c.state.Get()
+ if err != nil {
+ return err
+ }
}
var wg sync.WaitGroup
errCh := make(chan error, 1)
- for _, master := range state.masters {
+ for _, master := range state.Masters {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
@@ -879,14 +980,17 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
// ForEachSlave concurrently calls the fn on each slave node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
- state, err := c.state.Get()
+ state, err := c.state.Reload()
if err != nil {
- return err
+ state, err = c.state.Get()
+ if err != nil {
+ return err
+ }
}
var wg sync.WaitGroup
errCh := make(chan error, 1)
- for _, slave := range state.slaves {
+ for _, slave := range state.Slaves {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
@@ -912,9 +1016,12 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
// ForEachNode concurrently calls the fn on each known node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
- state, err := c.state.Get()
+ state, err := c.state.Reload()
if err != nil {
- return err
+ state, err = c.state.Get()
+ if err != nil {
+ return err
+ }
}
var wg sync.WaitGroup
@@ -930,11 +1037,11 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
}
}
- for _, node := range state.masters {
+ for _, node := range state.Masters {
wg.Add(1)
go worker(node)
}
- for _, node := range state.slaves {
+ for _, node := range state.Slaves {
wg.Add(1)
go worker(node)
}
@@ -957,7 +1064,7 @@ func (c *ClusterClient) PoolStats() *PoolStats {
return &acc
}
- for _, node := range state.masters {
+ for _, node := range state.Masters {
s := node.Client.connPool.Stats()
acc.Hits += s.Hits
acc.Misses += s.Misses
@@ -968,7 +1075,7 @@ func (c *ClusterClient) PoolStats() *PoolStats {
acc.StaleConns += s.StaleConns
}
- for _, node := range state.slaves {
+ for _, node := range state.Slaves {
s := node.Client.connPool.Stats()
acc.Hits += s.Hits
acc.Misses += s.Misses
@@ -1065,7 +1172,7 @@ func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error {
failedCmds := make(map[*clusterNode][]Cmder)
for node, cmds := range cmdsMap {
- cn, _, err := node.Client.getConn()
+ cn, err := node.Client.getConn()
if err != nil {
if err == pool.ErrClosed {
c.remapCmds(cmds, failedCmds)
@@ -1077,9 +1184,9 @@ func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error {
err = c.pipelineProcessCmds(node, cn, cmds, failedCmds)
if err == nil || internal.IsRedisError(err) {
- _ = node.Client.connPool.Put(cn)
+ node.Client.connPool.Put(cn)
} else {
- _ = node.Client.connPool.Remove(cn)
+ node.Client.connPool.Remove(cn)
}
}
@@ -1229,7 +1336,7 @@ func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error {
failedCmds := make(map[*clusterNode][]Cmder)
for node, cmds := range cmdsMap {
- cn, _, err := node.Client.getConn()
+ cn, err := node.Client.getConn()
if err != nil {
if err == pool.ErrClosed {
c.remapCmds(cmds, failedCmds)
@@ -1241,9 +1348,9 @@ func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error {
err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds)
if err == nil || internal.IsRedisError(err) {
- _ = node.Client.connPool.Put(cn)
+ node.Client.connPool.Put(cn)
} else {
- _ = node.Client.connPool.Remove(cn)
+ node.Client.connPool.Remove(cn)
}
}
@@ -1387,6 +1494,29 @@ func (c *ClusterClient) PSubscribe(channels ...string) *PubSub {
return pubsub
}
+func useOriginAddr(originAddr, nodeAddr string) bool {
+ nodeHost, nodePort, err := net.SplitHostPort(nodeAddr)
+ if err != nil {
+ return false
+ }
+
+ nodeIP := net.ParseIP(nodeHost)
+ if nodeIP == nil {
+ return false
+ }
+
+ if !nodeIP.IsLoopback() {
+ return false
+ }
+
+ _, originPort, err := net.SplitHostPort(originAddr)
+ if err != nil {
+ return false
+ }
+
+ return nodePort == originPort
+}
+
func isLoopbackAddr(addr string) bool {
host, _, err := net.SplitHostPort(addr)
if err != nil {
@@ -1401,7 +1531,7 @@ func isLoopbackAddr(addr string) bool {
return ip.IsLoopback()
}
-func appendNode(nodes []*clusterNode, node *clusterNode) []*clusterNode {
+func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode {
for _, n := range nodes {
if n == node {
return nodes