summaryrefslogtreecommitdiffstats
path: root/vendor/golang.org/x/crypto/ssh/handshake_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/crypto/ssh/handshake_test.go')
-rw-r--r--vendor/golang.org/x/crypto/ssh/handshake_test.go486
1 files changed, 486 insertions, 0 deletions
diff --git a/vendor/golang.org/x/crypto/ssh/handshake_test.go b/vendor/golang.org/x/crypto/ssh/handshake_test.go
new file mode 100644
index 000000000..da53d3a0d
--- /dev/null
+++ b/vendor/golang.org/x/crypto/ssh/handshake_test.go
@@ -0,0 +1,486 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "bytes"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "net"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+)
+
+type testChecker struct {
+ calls []string
+}
+
+func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+ if dialAddr == "bad" {
+ return fmt.Errorf("dialAddr is bad")
+ }
+
+ if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
+ return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
+ }
+
+ t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
+
+ return nil
+}
+
+// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
+// therefore is buffered (net.Pipe deadlocks if both sides start with
+// a write.)
+func netPipe() (net.Conn, net.Conn, error) {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return nil, nil, err
+ }
+ defer listener.Close()
+ c1, err := net.Dial("tcp", listener.Addr().String())
+ if err != nil {
+ return nil, nil, err
+ }
+
+ c2, err := listener.Accept()
+ if err != nil {
+ c1.Close()
+ return nil, nil, err
+ }
+
+ return c1, c2, nil
+}
+
+func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
+ a, b, err := netPipe()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ trC := newTransport(a, rand.Reader, true)
+ trS := newTransport(b, rand.Reader, false)
+ clientConf.SetDefaults()
+
+ v := []byte("version")
+ client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
+
+ serverConf := &ServerConfig{}
+ serverConf.AddHostKey(testSigners["ecdsa"])
+ serverConf.AddHostKey(testSigners["rsa"])
+ serverConf.SetDefaults()
+ server = newServerTransport(trS, v, v, serverConf)
+
+ return client, server, nil
+}
+
+func TestHandshakeBasic(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("see golang.org/issue/7237")
+ }
+ checker := &testChecker{}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+
+ defer trC.Close()
+ defer trS.Close()
+
+ go func() {
+ // Client writes a bunch of stuff, and does a key
+ // change in the middle. This should not confuse the
+ // handshake in progress
+ for i := 0; i < 10; i++ {
+ p := []byte{msgRequestSuccess, byte(i)}
+ if err := trC.writePacket(p); err != nil {
+ t.Fatalf("sendPacket: %v", err)
+ }
+ if i == 5 {
+ // halfway through, we request a key change.
+ err := trC.sendKexInit(subsequentKeyExchange)
+ if err != nil {
+ t.Fatalf("sendKexInit: %v", err)
+ }
+ }
+ }
+ trC.Close()
+ }()
+
+ // Server checks that client messages come in cleanly
+ i := 0
+ for {
+ p, err := trS.readPacket()
+ if err != nil {
+ break
+ }
+ if p[0] == msgNewKeys {
+ continue
+ }
+ want := []byte{msgRequestSuccess, byte(i)}
+ if bytes.Compare(p, want) != 0 {
+ t.Errorf("message %d: got %q, want %q", i, p, want)
+ }
+ i++
+ }
+ if i != 10 {
+ t.Errorf("received %d messages, want 10.", i)
+ }
+
+ // If all went well, we registered exactly 1 key change.
+ if len(checker.calls) != 1 {
+ t.Fatalf("got %d host key checks, want 1", len(checker.calls))
+ }
+
+ pub := testSigners["ecdsa"].PublicKey()
+ want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
+ if want != checker.calls[0] {
+ t.Errorf("got %q want %q for host key check", checker.calls[0], want)
+ }
+}
+
+func TestHandshakeError(t *testing.T) {
+ checker := &testChecker{}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+ defer trC.Close()
+ defer trS.Close()
+
+ // send a packet
+ packet := []byte{msgRequestSuccess, 42}
+ if err := trC.writePacket(packet); err != nil {
+ t.Errorf("writePacket: %v", err)
+ }
+
+ // Now request a key change.
+ err = trC.sendKexInit(subsequentKeyExchange)
+ if err != nil {
+ t.Errorf("sendKexInit: %v", err)
+ }
+
+ // the key change will fail, and afterwards we can't write.
+ if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
+ t.Errorf("writePacket after botched rekey succeeded.")
+ }
+
+ readback, err := trS.readPacket()
+ if err != nil {
+ t.Fatalf("server closed too soon: %v", err)
+ }
+ if bytes.Compare(readback, packet) != 0 {
+ t.Errorf("got %q want %q", readback, packet)
+ }
+ readback, err = trS.readPacket()
+ if err == nil {
+ t.Errorf("got a message %q after failed key change", readback)
+ }
+}
+
+func TestForceFirstKex(t *testing.T) {
+ checker := &testChecker{}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+
+ defer trC.Close()
+ defer trS.Close()
+
+ trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
+
+ // We setup the initial key exchange, but the remote side
+ // tries to send serviceRequestMsg in cleartext, which is
+ // disallowed.
+
+ err = trS.sendKexInit(firstKeyExchange)
+ if err == nil {
+ t.Errorf("server first kex init should reject unexpected packet")
+ }
+}
+
+func TestHandshakeTwice(t *testing.T) {
+ checker := &testChecker{}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+
+ defer trC.Close()
+ defer trS.Close()
+
+ // Both sides should ask for the first key exchange first.
+ err = trS.sendKexInit(firstKeyExchange)
+ if err != nil {
+ t.Errorf("server sendKexInit: %v", err)
+ }
+
+ err = trC.sendKexInit(firstKeyExchange)
+ if err != nil {
+ t.Errorf("client sendKexInit: %v", err)
+ }
+
+ sent := 0
+ // send a packet
+ packet := make([]byte, 5)
+ packet[0] = msgRequestSuccess
+ if err := trC.writePacket(packet); err != nil {
+ t.Errorf("writePacket: %v", err)
+ }
+ sent++
+
+ // Send another packet. Use a fresh one, since writePacket destroys.
+ packet = make([]byte, 5)
+ packet[0] = msgRequestSuccess
+ if err := trC.writePacket(packet); err != nil {
+ t.Errorf("writePacket: %v", err)
+ }
+ sent++
+
+ // 2nd key change.
+ err = trC.sendKexInit(subsequentKeyExchange)
+ if err != nil {
+ t.Errorf("sendKexInit: %v", err)
+ }
+
+ packet = make([]byte, 5)
+ packet[0] = msgRequestSuccess
+ if err := trC.writePacket(packet); err != nil {
+ t.Errorf("writePacket: %v", err)
+ }
+ sent++
+
+ packet = make([]byte, 5)
+ packet[0] = msgRequestSuccess
+ for i := 0; i < sent; i++ {
+ msg, err := trS.readPacket()
+ if err != nil {
+ t.Fatalf("server closed too soon: %v", err)
+ }
+
+ if bytes.Compare(msg, packet) != 0 {
+ t.Errorf("packet %d: got %q want %q", i, msg, packet)
+ }
+ }
+ if len(checker.calls) != 2 {
+ t.Errorf("got %d key changes, want 2", len(checker.calls))
+ }
+}
+
+func TestHandshakeAutoRekeyWrite(t *testing.T) {
+ checker := &testChecker{}
+ clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+ clientConf.RekeyThreshold = 500
+ trC, trS, err := handshakePair(clientConf, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+ defer trC.Close()
+ defer trS.Close()
+
+ for i := 0; i < 5; i++ {
+ packet := make([]byte, 251)
+ packet[0] = msgRequestSuccess
+ if err := trC.writePacket(packet); err != nil {
+ t.Errorf("writePacket: %v", err)
+ }
+ }
+
+ j := 0
+ for ; j < 5; j++ {
+ _, err := trS.readPacket()
+ if err != nil {
+ break
+ }
+ }
+
+ if j != 5 {
+ t.Errorf("got %d, want 5 messages", j)
+ }
+
+ if len(checker.calls) != 2 {
+ t.Errorf("got %d key changes, wanted 2", len(checker.calls))
+ }
+}
+
+type syncChecker struct {
+ called chan int
+}
+
+func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+ t.called <- 1
+ return nil
+}
+
+func TestHandshakeAutoRekeyRead(t *testing.T) {
+ sync := &syncChecker{make(chan int, 2)}
+ clientConf := &ClientConfig{
+ HostKeyCallback: sync.Check,
+ }
+ clientConf.RekeyThreshold = 500
+
+ trC, trS, err := handshakePair(clientConf, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+ defer trC.Close()
+ defer trS.Close()
+
+ packet := make([]byte, 501)
+ packet[0] = msgRequestSuccess
+ if err := trS.writePacket(packet); err != nil {
+ t.Fatalf("writePacket: %v", err)
+ }
+ // While we read out the packet, a key change will be
+ // initiated.
+ if _, err := trC.readPacket(); err != nil {
+ t.Fatalf("readPacket(client): %v", err)
+ }
+
+ <-sync.called
+}
+
+// errorKeyingTransport generates errors after a given number of
+// read/write operations.
+type errorKeyingTransport struct {
+ packetConn
+ readLeft, writeLeft int
+}
+
+func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+ return nil
+}
+func (n *errorKeyingTransport) getSessionID() []byte {
+ return nil
+}
+
+func (n *errorKeyingTransport) writePacket(packet []byte) error {
+ if n.writeLeft == 0 {
+ n.Close()
+ return errors.New("barf")
+ }
+
+ n.writeLeft--
+ return n.packetConn.writePacket(packet)
+}
+
+func (n *errorKeyingTransport) readPacket() ([]byte, error) {
+ if n.readLeft == 0 {
+ n.Close()
+ return nil, errors.New("barf")
+ }
+
+ n.readLeft--
+ return n.packetConn.readPacket()
+}
+
+func TestHandshakeErrorHandlingRead(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ testHandshakeErrorHandlingN(t, i, -1)
+ }
+}
+
+func TestHandshakeErrorHandlingWrite(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ testHandshakeErrorHandlingN(t, -1, i)
+ }
+}
+
+// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
+// handshakeTransport deadlocks, the go runtime will detect it and
+// panic.
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
+ msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
+
+ a, b := memPipe()
+ defer a.Close()
+ defer b.Close()
+
+ key := testSigners["ecdsa"]
+ serverConf := Config{RekeyThreshold: minRekeyThreshold}
+ serverConf.SetDefaults()
+ serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
+ serverConn.hostKeys = []Signer{key}
+ go serverConn.readLoop()
+
+ clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
+ clientConf.SetDefaults()
+ clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
+ clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
+ go clientConn.readLoop()
+
+ var wg sync.WaitGroup
+ wg.Add(4)
+
+ for _, hs := range []packetConn{serverConn, clientConn} {
+ go func(c packetConn) {
+ for {
+ err := c.writePacket(msg)
+ if err != nil {
+ break
+ }
+ }
+ wg.Done()
+ }(hs)
+ go func(c packetConn) {
+ for {
+ _, err := c.readPacket()
+ if err != nil {
+ break
+ }
+ }
+ wg.Done()
+ }(hs)
+ }
+
+ wg.Wait()
+}
+
+func TestDisconnect(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("see golang.org/issue/7237")
+ }
+ checker := &testChecker{}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+
+ defer trC.Close()
+ defer trS.Close()
+
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+ errMsg := &disconnectMsg{
+ Reason: 42,
+ Message: "such is life",
+ }
+ trC.writePacket(Marshal(errMsg))
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+
+ packet, err := trS.readPacket()
+ if err != nil {
+ t.Fatalf("readPacket 1: %v", err)
+ }
+ if packet[0] != msgRequestSuccess {
+ t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
+ }
+
+ _, err = trS.readPacket()
+ if err == nil {
+ t.Errorf("readPacket 2 succeeded")
+ } else if !reflect.DeepEqual(err, errMsg) {
+ t.Errorf("got error %#v, want %#v", err, errMsg)
+ }
+
+ _, err = trS.readPacket()
+ if err == nil {
+ t.Errorf("readPacket 3 succeeded")
+ }
+}