summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/hashicorp/memberlist/util.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/hashicorp/memberlist/util.go')
-rw-r--r--vendor/github.com/hashicorp/memberlist/util.go296
1 files changed, 296 insertions, 0 deletions
diff --git a/vendor/github.com/hashicorp/memberlist/util.go b/vendor/github.com/hashicorp/memberlist/util.go
new file mode 100644
index 000000000..a4f926e3a
--- /dev/null
+++ b/vendor/github.com/hashicorp/memberlist/util.go
@@ -0,0 +1,296 @@
+package memberlist
+
+import (
+ "bytes"
+ "compress/lzw"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "math"
+ "math/rand"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/hashicorp/go-msgpack/codec"
+ "github.com/sean-/seed"
+)
+
+// pushPullScale is the minimum number of nodes
+// before we start scaling the push/pull timing. The scale
+// effect is the log2(Nodes) - log2(pushPullScale). This means
+// that the 33rd node will cause us to double the interval,
+// while the 65th will triple it.
+const pushPullScaleThreshold = 32
+
+const (
+ // Constant litWidth 2-8
+ lzwLitWidth = 8
+)
+
+func init() {
+ seed.Init()
+}
+
+// Decode reverses the encode operation on a byte slice input
+func decode(buf []byte, out interface{}) error {
+ r := bytes.NewReader(buf)
+ hd := codec.MsgpackHandle{}
+ dec := codec.NewDecoder(r, &hd)
+ return dec.Decode(out)
+}
+
+// Encode writes an encoded object to a new bytes buffer
+func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) {
+ buf := bytes.NewBuffer(nil)
+ buf.WriteByte(uint8(msgType))
+ hd := codec.MsgpackHandle{}
+ enc := codec.NewEncoder(buf, &hd)
+ err := enc.Encode(in)
+ return buf, err
+}
+
+// Returns a random offset between 0 and n
+func randomOffset(n int) int {
+ if n == 0 {
+ return 0
+ }
+ return int(rand.Uint32() % uint32(n))
+}
+
+// suspicionTimeout computes the timeout that should be used when
+// a node is suspected
+func suspicionTimeout(suspicionMult, n int, interval time.Duration) time.Duration {
+ nodeScale := math.Max(1.0, math.Log10(math.Max(1.0, float64(n))))
+ // multiply by 1000 to keep some precision because time.Duration is an int64 type
+ timeout := time.Duration(suspicionMult) * time.Duration(nodeScale*1000) * interval / 1000
+ return timeout
+}
+
+// retransmitLimit computes the limit of retransmissions
+func retransmitLimit(retransmitMult, n int) int {
+ nodeScale := math.Ceil(math.Log10(float64(n + 1)))
+ limit := retransmitMult * int(nodeScale)
+ return limit
+}
+
+// shuffleNodes randomly shuffles the input nodes using the Fisher-Yates shuffle
+func shuffleNodes(nodes []*nodeState) {
+ n := len(nodes)
+ for i := n - 1; i > 0; i-- {
+ j := rand.Intn(i + 1)
+ nodes[i], nodes[j] = nodes[j], nodes[i]
+ }
+}
+
+// pushPushScale is used to scale the time interval at which push/pull
+// syncs take place. It is used to prevent network saturation as the
+// cluster size grows
+func pushPullScale(interval time.Duration, n int) time.Duration {
+ // Don't scale until we cross the threshold
+ if n <= pushPullScaleThreshold {
+ return interval
+ }
+
+ multiplier := math.Ceil(math.Log2(float64(n))-math.Log2(pushPullScaleThreshold)) + 1.0
+ return time.Duration(multiplier) * interval
+}
+
+// moveDeadNodes moves nodes that are dead and beyond the gossip to the dead interval
+// to the end of the slice and returns the index of the first moved node.
+func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int {
+ numDead := 0
+ n := len(nodes)
+ for i := 0; i < n-numDead; i++ {
+ if nodes[i].State != stateDead {
+ continue
+ }
+
+ // Respect the gossip to the dead interval
+ if time.Since(nodes[i].StateChange) <= gossipToTheDeadTime {
+ continue
+ }
+
+ // Move this node to the end
+ nodes[i], nodes[n-numDead-1] = nodes[n-numDead-1], nodes[i]
+ numDead++
+ i--
+ }
+ return n - numDead
+}
+
+// kRandomNodes is used to select up to k random nodes, excluding any nodes where
+// the filter function returns true. It is possible that less than k nodes are
+// returned.
+func kRandomNodes(k int, nodes []*nodeState, filterFn func(*nodeState) bool) []*nodeState {
+ n := len(nodes)
+ kNodes := make([]*nodeState, 0, k)
+OUTER:
+ // Probe up to 3*n times, with large n this is not necessary
+ // since k << n, but with small n we want search to be
+ // exhaustive
+ for i := 0; i < 3*n && len(kNodes) < k; i++ {
+ // Get random node
+ idx := randomOffset(n)
+ node := nodes[idx]
+
+ // Give the filter a shot at it.
+ if filterFn != nil && filterFn(node) {
+ continue OUTER
+ }
+
+ // Check if we have this node already
+ for j := 0; j < len(kNodes); j++ {
+ if node == kNodes[j] {
+ continue OUTER
+ }
+ }
+
+ // Append the node
+ kNodes = append(kNodes, node)
+ }
+ return kNodes
+}
+
+// makeCompoundMessage takes a list of messages and generates
+// a single compound message containing all of them
+func makeCompoundMessage(msgs [][]byte) *bytes.Buffer {
+ // Create a local buffer
+ buf := bytes.NewBuffer(nil)
+
+ // Write out the type
+ buf.WriteByte(uint8(compoundMsg))
+
+ // Write out the number of message
+ buf.WriteByte(uint8(len(msgs)))
+
+ // Add the message lengths
+ for _, m := range msgs {
+ binary.Write(buf, binary.BigEndian, uint16(len(m)))
+ }
+
+ // Append the messages
+ for _, m := range msgs {
+ buf.Write(m)
+ }
+
+ return buf
+}
+
+// decodeCompoundMessage splits a compound message and returns
+// the slices of individual messages. Also returns the number
+// of truncated messages and any potential error
+func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
+ if len(buf) < 1 {
+ err = fmt.Errorf("missing compound length byte")
+ return
+ }
+ numParts := uint8(buf[0])
+ buf = buf[1:]
+
+ // Check we have enough bytes
+ if len(buf) < int(numParts*2) {
+ err = fmt.Errorf("truncated len slice")
+ return
+ }
+
+ // Decode the lengths
+ lengths := make([]uint16, numParts)
+ for i := 0; i < int(numParts); i++ {
+ lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2])
+ }
+ buf = buf[numParts*2:]
+
+ // Split each message
+ for idx, msgLen := range lengths {
+ if len(buf) < int(msgLen) {
+ trunc = int(numParts) - idx
+ return
+ }
+
+ // Extract the slice, seek past on the buffer
+ slice := buf[:msgLen]
+ buf = buf[msgLen:]
+ parts = append(parts, slice)
+ }
+ return
+}
+
+// Given a string of the form "host", "host:port",
+// "ipv6::addr" or "[ipv6::address]:port",
+// return true if the string includes a port.
+func hasPort(s string) bool {
+ last := strings.LastIndex(s, ":")
+ if last == -1 {
+ return false
+ }
+ if s[0] == '[' {
+ return s[last-1] == ']'
+ }
+ return strings.Index(s, ":") == last
+}
+
+// compressPayload takes an opaque input buffer, compresses it
+// and wraps it in a compress{} message that is encoded.
+func compressPayload(inp []byte) (*bytes.Buffer, error) {
+ var buf bytes.Buffer
+ compressor := lzw.NewWriter(&buf, lzw.LSB, lzwLitWidth)
+
+ _, err := compressor.Write(inp)
+ if err != nil {
+ return nil, err
+ }
+
+ // Ensure we flush everything out
+ if err := compressor.Close(); err != nil {
+ return nil, err
+ }
+
+ // Create a compressed message
+ c := compress{
+ Algo: lzwAlgo,
+ Buf: buf.Bytes(),
+ }
+ return encode(compressMsg, &c)
+}
+
+// decompressPayload is used to unpack an encoded compress{}
+// message and return its payload uncompressed
+func decompressPayload(msg []byte) ([]byte, error) {
+ // Decode the message
+ var c compress
+ if err := decode(msg, &c); err != nil {
+ return nil, err
+ }
+ return decompressBuffer(&c)
+}
+
+// decompressBuffer is used to decompress the buffer of
+// a single compress message, handling multiple algorithms
+func decompressBuffer(c *compress) ([]byte, error) {
+ // Verify the algorithm
+ if c.Algo != lzwAlgo {
+ return nil, fmt.Errorf("Cannot decompress unknown algorithm %d", c.Algo)
+ }
+
+ // Create a uncompressor
+ uncomp := lzw.NewReader(bytes.NewReader(c.Buf), lzw.LSB, lzwLitWidth)
+ defer uncomp.Close()
+
+ // Read all the data
+ var b bytes.Buffer
+ _, err := io.Copy(&b, uncomp)
+ if err != nil {
+ return nil, err
+ }
+
+ // Return the uncompressed bytes
+ return b.Bytes(), nil
+}
+
+// joinHostPort returns the host:port form of an address, for use with a
+// transport.
+func joinHostPort(host string, port uint16) string {
+ return net.JoinHostPort(host, strconv.Itoa(int(port)))
+}