diff options
Diffstat (limited to 'vendor/github.com/hashicorp/memberlist/net_test.go')
-rw-r--r-- | vendor/github.com/hashicorp/memberlist/net_test.go | 787 |
1 files changed, 787 insertions, 0 deletions
diff --git a/vendor/github.com/hashicorp/memberlist/net_test.go b/vendor/github.com/hashicorp/memberlist/net_test.go new file mode 100644 index 000000000..80d3ebb36 --- /dev/null +++ b/vendor/github.com/hashicorp/memberlist/net_test.go @@ -0,0 +1,787 @@ +package memberlist + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "reflect" + "strings" + "testing" + "time" + + "github.com/hashicorp/go-msgpack/codec" +) + +// As a regression we left this test very low-level and network-ey, even after +// we abstracted the transport. We added some basic network-free transport tests +// in transport_test.go to prove that we didn't hard code some network stuff +// outside of NetTransport. + +func TestHandleCompoundPing(t *testing.T) { + m := GetMemberlist(t) + m.config.EnableCompression = false + defer m.Shutdown() + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + if udp == nil { + t.Fatalf("no udp listener") + } + + // Encode a ping + ping := ping{SeqNo: 42} + buf, err := encode(pingMsg, ping) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Make a compound message + compound := makeCompoundMessage([][]byte{buf.Bytes(), buf.Bytes(), buf.Bytes()}) + + // Send compound version + addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort} + udp.WriteTo(compound.Bytes(), addr) + + // Wait for responses + doneCh := make(chan struct{}, 1) + go func() { + select { + case <-doneCh: + case <-time.After(2 * time.Second): + panic("timeout") + } + }() + + for i := 0; i < 3; i++ { + in := make([]byte, 1500) + n, _, err := udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + msgType := messageType(in[0]) + if msgType != ackRespMsg { + t.Fatalf("bad response %v", in) + } + + var ack ackResp + if err := decode(in[1:], &ack); err != nil { + t.Fatalf("unexpected err %s", err) + } + + if ack.SeqNo != 42 { + t.Fatalf("bad sequence no") + } + } + + doneCh <- struct{}{} +} + +func TestHandlePing(t *testing.T) { + m := GetMemberlist(t) + m.config.EnableCompression = false + defer m.Shutdown() + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + if udp == nil { + t.Fatalf("no udp listener") + } + + // Encode a ping + ping := ping{SeqNo: 42} + buf, err := encode(pingMsg, ping) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Send + addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort} + udp.WriteTo(buf.Bytes(), addr) + + // Wait for response + doneCh := make(chan struct{}, 1) + go func() { + select { + case <-doneCh: + case <-time.After(2 * time.Second): + panic("timeout") + } + }() + + in := make([]byte, 1500) + n, _, err := udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + msgType := messageType(in[0]) + if msgType != ackRespMsg { + t.Fatalf("bad response %v", in) + } + + var ack ackResp + if err := decode(in[1:], &ack); err != nil { + t.Fatalf("unexpected err %s", err) + } + + if ack.SeqNo != 42 { + t.Fatalf("bad sequence no") + } + + doneCh <- struct{}{} +} + +func TestHandlePing_WrongNode(t *testing.T) { + m := GetMemberlist(t) + m.config.EnableCompression = false + defer m.Shutdown() + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + if udp == nil { + t.Fatalf("no udp listener") + } + + // Encode a ping, wrong node! + ping := ping{SeqNo: 42, Node: m.config.Name + "-bad"} + buf, err := encode(pingMsg, ping) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Send + addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort} + udp.WriteTo(buf.Bytes(), addr) + + // Wait for response + udp.SetDeadline(time.Now().Add(50 * time.Millisecond)) + in := make([]byte, 1500) + _, _, err = udp.ReadFrom(in) + + // Should get an i/o timeout + if err == nil { + t.Fatalf("expected err %s", err) + } +} + +func TestHandleIndirectPing(t *testing.T) { + m := GetMemberlist(t) + m.config.EnableCompression = false + defer m.Shutdown() + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + if udp == nil { + t.Fatalf("no udp listener") + } + + // Encode an indirect ping + ind := indirectPingReq{ + SeqNo: 100, + Target: net.ParseIP(m.config.BindAddr), + Port: uint16(m.config.BindPort), + } + buf, err := encode(indirectPingMsg, &ind) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Send + addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort} + udp.WriteTo(buf.Bytes(), addr) + + // Wait for response + doneCh := make(chan struct{}, 1) + go func() { + select { + case <-doneCh: + case <-time.After(2 * time.Second): + panic("timeout") + } + }() + + in := make([]byte, 1500) + n, _, err := udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + msgType := messageType(in[0]) + if msgType != ackRespMsg { + t.Fatalf("bad response %v", in) + } + + var ack ackResp + if err := decode(in[1:], &ack); err != nil { + t.Fatalf("unexpected err %s", err) + } + + if ack.SeqNo != 100 { + t.Fatalf("bad sequence no") + } + + doneCh <- struct{}{} +} + +func TestTCPPing(t *testing.T) { + var tcp *net.TCPListener + var tcpAddr *net.TCPAddr + for port := 60000; port < 61000; port++ { + tcpAddr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: port} + tcpLn, err := net.ListenTCP("tcp", tcpAddr) + if err == nil { + tcp = tcpLn + break + } + } + if tcp == nil { + t.Fatalf("no tcp listener") + } + + // Note that tcp gets closed in the last test, so we avoid a deferred + // Close() call here. + + m := GetMemberlist(t) + defer m.Shutdown() + pingTimeout := m.config.ProbeInterval + pingTimeMax := m.config.ProbeInterval + 10*time.Millisecond + + // Do a normal round trip. + pingOut := ping{SeqNo: 23, Node: "mongo"} + go func() { + tcp.SetDeadline(time.Now().Add(pingTimeMax)) + conn, err := tcp.AcceptTCP() + if err != nil { + t.Fatalf("failed to connect: %s", err) + } + defer conn.Close() + + msgType, _, dec, err := m.readStream(conn) + if err != nil { + t.Fatalf("failed to read ping: %s", err) + } + + if msgType != pingMsg { + t.Fatalf("expecting ping, got message type (%d)", msgType) + } + + var pingIn ping + if err := dec.Decode(&pingIn); err != nil { + t.Fatalf("failed to decode ping: %s", err) + } + + if pingIn.SeqNo != pingOut.SeqNo { + t.Fatalf("sequence number isn't correct (%d) vs (%d)", pingIn.SeqNo, pingOut.SeqNo) + } + + if pingIn.Node != pingOut.Node { + t.Fatalf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node) + } + + ack := ackResp{pingIn.SeqNo, nil} + out, err := encode(ackRespMsg, &ack) + if err != nil { + t.Fatalf("failed to encode ack: %s", err) + } + + err = m.rawSendMsgStream(conn, out.Bytes()) + if err != nil { + t.Fatalf("failed to send ack: %s", err) + } + }() + deadline := time.Now().Add(pingTimeout) + didContact, err := m.sendPingAndWaitForAck(tcpAddr.String(), pingOut, deadline) + if err != nil { + t.Fatalf("error trying to ping: %s", err) + } + if !didContact { + t.Fatalf("expected successful ping") + } + + // Make sure a mis-matched sequence number is caught. + go func() { + tcp.SetDeadline(time.Now().Add(pingTimeMax)) + conn, err := tcp.AcceptTCP() + if err != nil { + t.Fatalf("failed to connect: %s", err) + } + defer conn.Close() + + _, _, dec, err := m.readStream(conn) + if err != nil { + t.Fatalf("failed to read ping: %s", err) + } + + var pingIn ping + if err := dec.Decode(&pingIn); err != nil { + t.Fatalf("failed to decode ping: %s", err) + } + + ack := ackResp{pingIn.SeqNo + 1, nil} + out, err := encode(ackRespMsg, &ack) + if err != nil { + t.Fatalf("failed to encode ack: %s", err) + } + + err = m.rawSendMsgStream(conn, out.Bytes()) + if err != nil { + t.Fatalf("failed to send ack: %s", err) + } + }() + deadline = time.Now().Add(pingTimeout) + didContact, err = m.sendPingAndWaitForAck(tcpAddr.String(), pingOut, deadline) + if err == nil || !strings.Contains(err.Error(), "Sequence number") { + t.Fatalf("expected an error from mis-matched sequence number") + } + if didContact { + t.Fatalf("expected failed ping") + } + + // Make sure an unexpected message type is handled gracefully. + go func() { + tcp.SetDeadline(time.Now().Add(pingTimeMax)) + conn, err := tcp.AcceptTCP() + if err != nil { + t.Fatalf("failed to connect: %s", err) + } + defer conn.Close() + + _, _, _, err = m.readStream(conn) + if err != nil { + t.Fatalf("failed to read ping: %s", err) + } + + bogus := indirectPingReq{} + out, err := encode(indirectPingMsg, &bogus) + if err != nil { + t.Fatalf("failed to encode bogus msg: %s", err) + } + + err = m.rawSendMsgStream(conn, out.Bytes()) + if err != nil { + t.Fatalf("failed to send bogus msg: %s", err) + } + }() + deadline = time.Now().Add(pingTimeout) + didContact, err = m.sendPingAndWaitForAck(tcpAddr.String(), pingOut, deadline) + if err == nil || !strings.Contains(err.Error(), "Unexpected msgType") { + t.Fatalf("expected an error from bogus message") + } + if didContact { + t.Fatalf("expected failed ping") + } + + // Make sure failed I/O respects the deadline. In this case we try the + // common case of the receiving node being totally down. + tcp.Close() + deadline = time.Now().Add(pingTimeout) + startPing := time.Now() + didContact, err = m.sendPingAndWaitForAck(tcpAddr.String(), pingOut, deadline) + pingTime := time.Now().Sub(startPing) + if err != nil { + t.Fatalf("expected no error during ping on closed socket, got: %s", err) + } + if didContact { + t.Fatalf("expected failed ping") + } + if pingTime > pingTimeMax { + t.Fatalf("took too long to fail ping, %9.6f", pingTime.Seconds()) + } +} + +func TestTCPPushPull(t *testing.T) { + m := GetMemberlist(t) + defer m.Shutdown() + m.nodes = append(m.nodes, &nodeState{ + Node: Node{ + Name: "Test 0", + Addr: net.ParseIP(m.config.BindAddr), + Port: uint16(m.config.BindPort), + }, + Incarnation: 0, + State: stateSuspect, + StateChange: time.Now().Add(-1 * time.Second), + }) + + addr := fmt.Sprintf("%s:%d", m.config.BindAddr, m.config.BindPort) + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + defer conn.Close() + + localNodes := make([]pushNodeState, 3) + localNodes[0].Name = "Test 0" + localNodes[0].Addr = net.ParseIP(m.config.BindAddr) + localNodes[0].Port = uint16(m.config.BindPort) + localNodes[0].Incarnation = 1 + localNodes[0].State = stateAlive + localNodes[1].Name = "Test 1" + localNodes[1].Addr = net.ParseIP(m.config.BindAddr) + localNodes[1].Port = uint16(m.config.BindPort) + localNodes[1].Incarnation = 1 + localNodes[1].State = stateAlive + localNodes[2].Name = "Test 2" + localNodes[2].Addr = net.ParseIP(m.config.BindAddr) + localNodes[2].Port = uint16(m.config.BindPort) + localNodes[2].Incarnation = 1 + localNodes[2].State = stateAlive + + // Send our node state + header := pushPullHeader{Nodes: 3} + hd := codec.MsgpackHandle{} + enc := codec.NewEncoder(conn, &hd) + + // Send the push/pull indicator + conn.Write([]byte{byte(pushPullMsg)}) + + if err := enc.Encode(&header); err != nil { + t.Fatalf("unexpected err %s", err) + } + for i := 0; i < header.Nodes; i++ { + if err := enc.Encode(&localNodes[i]); err != nil { + t.Fatalf("unexpected err %s", err) + } + } + + // Read the message type + var msgType messageType + if err := binary.Read(conn, binary.BigEndian, &msgType); err != nil { + t.Fatalf("unexpected err %s", err) + } + + var bufConn io.Reader = conn + msghd := codec.MsgpackHandle{} + dec := codec.NewDecoder(bufConn, &msghd) + + // Check if we have a compressed message + if msgType == compressMsg { + var c compress + if err := dec.Decode(&c); err != nil { + t.Fatalf("unexpected err %s", err) + } + decomp, err := decompressBuffer(&c) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Reset the message type + msgType = messageType(decomp[0]) + + // Create a new bufConn + bufConn = bytes.NewReader(decomp[1:]) + + // Create a new decoder + dec = codec.NewDecoder(bufConn, &hd) + } + + // Quit if not push/pull + if msgType != pushPullMsg { + t.Fatalf("bad message type") + } + + if err := dec.Decode(&header); err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Allocate space for the transfer + remoteNodes := make([]pushNodeState, header.Nodes) + + // Try to decode all the states + for i := 0; i < header.Nodes; i++ { + if err := dec.Decode(&remoteNodes[i]); err != nil { + t.Fatalf("unexpected err %s", err) + } + } + + if len(remoteNodes) != 1 { + t.Fatalf("bad response") + } + + n := &remoteNodes[0] + if n.Name != "Test 0" { + t.Fatalf("bad name") + } + if bytes.Compare(n.Addr, net.ParseIP(m.config.BindAddr)) != 0 { + t.Fatal("bad addr") + } + if n.Incarnation != 0 { + t.Fatal("bad incarnation") + } + if n.State != stateSuspect { + t.Fatal("bad state") + } +} + +func TestSendMsg_Piggyback(t *testing.T) { + m := GetMemberlist(t) + defer m.Shutdown() + + // Add a message to be broadcast + a := alive{ + Incarnation: 10, + Node: "rand", + Addr: []byte{127, 0, 0, 255}, + Meta: nil, + } + m.encodeAndBroadcast("rand", aliveMsg, &a) + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + // Encode a ping + ping := ping{SeqNo: 42} + buf, err := encode(pingMsg, ping) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Send + addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort} + udp.WriteTo(buf.Bytes(), addr) + + // Wait for response + doneCh := make(chan struct{}, 1) + go func() { + select { + case <-doneCh: + case <-time.After(2 * time.Second): + panic("timeout") + } + }() + + in := make([]byte, 1500) + n, _, err := udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + msgType := messageType(in[0]) + if msgType != compoundMsg { + t.Fatalf("bad response %v", in) + } + + // get the parts + trunc, parts, err := decodeCompoundMessage(in[1:]) + if trunc != 0 { + t.Fatalf("unexpected truncation") + } + if len(parts) != 2 { + t.Fatalf("unexpected parts %v", parts) + } + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + var ack ackResp + if err := decode(parts[0][1:], &ack); err != nil { + t.Fatalf("unexpected err %s", err) + } + + if ack.SeqNo != 42 { + t.Fatalf("bad sequence no") + } + + var aliveout alive + if err := decode(parts[1][1:], &aliveout); err != nil { + t.Fatalf("unexpected err %s", err) + } + + if aliveout.Node != "rand" || aliveout.Incarnation != 10 { + t.Fatalf("bad mesg") + } + + doneCh <- struct{}{} +} + +func TestEncryptDecryptState(t *testing.T) { + state := []byte("this is our internal state...") + config := &Config{ + SecretKey: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + ProtocolVersion: ProtocolVersionMax, + } + + m, err := Create(config) + if err != nil { + t.Fatalf("err: %s", err) + } + defer m.Shutdown() + + crypt, err := m.encryptLocalState(state) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Create reader, seek past the type byte + buf := bytes.NewReader(crypt) + buf.Seek(1, 0) + + plain, err := m.decryptRemoteState(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(state, plain) { + t.Fatalf("Decrypt failed: %v", plain) + } +} + +func TestRawSendUdp_CRC(t *testing.T) { + m := GetMemberlist(t) + m.config.EnableCompression = false + defer m.Shutdown() + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + if udp == nil { + t.Fatalf("no udp listener") + } + + // Pass a nil node with no nodes registered, should result in no checksum + payload := []byte{3, 3, 3, 3} + m.rawSendMsgPacket(udp.LocalAddr().String(), nil, payload) + + in := make([]byte, 1500) + n, _, err := udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + if len(in) != 4 { + t.Fatalf("bad: %v", in) + } + + // Pass a non-nil node with PMax >= 5, should result in a checksum + m.rawSendMsgPacket(udp.LocalAddr().String(), &Node{PMax: 5}, payload) + + in = make([]byte, 1500) + n, _, err = udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + if len(in) != 9 { + t.Fatalf("bad: %v", in) + } + + // Register a node with PMax >= 5 to be looked up, should result in a checksum + m.nodeMap["127.0.0.1"] = &nodeState{ + Node: Node{PMax: 5}, + } + m.rawSendMsgPacket(udp.LocalAddr().String(), nil, payload) + + in = make([]byte, 1500) + n, _, err = udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + if len(in) != 9 { + t.Fatalf("bad: %v", in) + } +} + +func TestIngestPacket_CRC(t *testing.T) { + m := GetMemberlist(t) + m.config.EnableCompression = false + defer m.Shutdown() + + var udp *net.UDPConn + for port := 60000; port < 61000; port++ { + udpAddr := fmt.Sprintf("127.0.0.1:%d", port) + udpLn, err := net.ListenPacket("udp", udpAddr) + if err == nil { + udp = udpLn.(*net.UDPConn) + break + } + } + + if udp == nil { + t.Fatalf("no udp listener") + } + + // Get a message with a checksum + payload := []byte{3, 3, 3, 3} + m.rawSendMsgPacket(udp.LocalAddr().String(), &Node{PMax: 5}, payload) + + in := make([]byte, 1500) + n, _, err := udp.ReadFrom(in) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + in = in[0:n] + + if len(in) != 9 { + t.Fatalf("bad: %v", in) + } + + // Corrupt the checksum + in[1] <<= 1 + + logs := &bytes.Buffer{} + logger := log.New(logs, "", 0) + m.logger = logger + m.ingestPacket(in, udp.LocalAddr(), time.Now()) + + if !strings.Contains(logs.String(), "invalid checksum") { + t.Fatalf("bad: %s", logs.String()) + } +} |