summaryrefslogtreecommitdiffstats
path: root/vendor/golang.org/x/crypto/ssh/handshake_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/crypto/ssh/handshake_test.go')
-rw-r--r--vendor/golang.org/x/crypto/ssh/handshake_test.go366
1 files changed, 189 insertions, 177 deletions
diff --git a/vendor/golang.org/x/crypto/ssh/handshake_test.go b/vendor/golang.org/x/crypto/ssh/handshake_test.go
index da53d3a0d..e61348fea 100644
--- a/vendor/golang.org/x/crypto/ssh/handshake_test.go
+++ b/vendor/golang.org/x/crypto/ssh/handshake_test.go
@@ -9,6 +9,7 @@ import (
"crypto/rand"
"errors"
"fmt"
+ "io"
"net"
"reflect"
"runtime"
@@ -58,14 +59,46 @@ func netPipe() (net.Conn, net.Conn, error) {
return c1, c2, nil
}
-func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
+// noiseTransport inserts ignore messages to check that the read loop
+// and the key exchange filters out these messages.
+type noiseTransport struct {
+ keyingTransport
+}
+
+func (t *noiseTransport) writePacket(p []byte) error {
+ ignore := []byte{msgIgnore}
+ if err := t.keyingTransport.writePacket(ignore); err != nil {
+ return err
+ }
+ debug := []byte{msgDebug, 1, 2, 3}
+ if err := t.keyingTransport.writePacket(debug); err != nil {
+ return err
+ }
+
+ return t.keyingTransport.writePacket(p)
+}
+
+func addNoiseTransport(t keyingTransport) keyingTransport {
+ return &noiseTransport{t}
+}
+
+// handshakePair creates two handshakeTransports connected with each
+// other. If the noise argument is true, both transports will try to
+// confuse the other side by sending ignore and debug messages.
+func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
a, b, err := netPipe()
if err != nil {
return nil, nil, err
}
- trC := newTransport(a, rand.Reader, true)
- trS := newTransport(b, rand.Reader, false)
+ var trC, trS keyingTransport
+
+ trC = newTransport(a, rand.Reader, true)
+ trS = newTransport(b, rand.Reader, false)
+ if noise {
+ trC = addNoiseTransport(trC)
+ trS = addNoiseTransport(trS)
+ }
clientConf.SetDefaults()
v := []byte("version")
@@ -77,6 +110,13 @@ func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTran
serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf)
+ if err := server.waitSession(); err != nil {
+ return nil, nil, fmt.Errorf("server.waitSession: %v", err)
+ }
+ if err := client.waitSession(); err != nil {
+ return nil, nil, fmt.Errorf("client.waitSession: %v", err)
+ }
+
return client, server, nil
}
@@ -84,8 +124,9 @@ func TestHandshakeBasic(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
- checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+
+ checker := &syncChecker{make(chan int, 10)}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
@@ -93,7 +134,13 @@ func TestHandshakeBasic(t *testing.T) {
defer trC.Close()
defer trS.Close()
+ <-checker.called
+
+ clientDone := make(chan int, 0)
+ gotHalf := make(chan int, 0)
+
go func() {
+ defer close(clientDone)
// Client writes a bunch of stuff, and does a key
// change in the middle. This should not confuse the
// handshake in progress
@@ -103,219 +150,144 @@ func TestHandshakeBasic(t *testing.T) {
t.Fatalf("sendPacket: %v", err)
}
if i == 5 {
+ <-gotHalf
// halfway through, we request a key change.
- err := trC.sendKexInit(subsequentKeyExchange)
- if err != nil {
- t.Fatalf("sendKexInit: %v", err)
- }
+ trC.requestKeyExchange()
+
+ // Wait until we can be sure the key
+ // change has really started before we
+ // write more.
+ <-checker.called
}
}
- trC.Close()
}()
// Server checks that client messages come in cleanly
i := 0
- for {
- p, err := trS.readPacket()
+ err = nil
+ for ; i < 10; i++ {
+ var p []byte
+ p, err = trS.readPacket()
if err != nil {
break
}
- if p[0] == msgNewKeys {
- continue
+ if i == 5 {
+ gotHalf <- 1
}
+
want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %q, want %q", i, p, want)
}
- i++
+ }
+ <-clientDone
+ if err != nil && err != io.EOF {
+ t.Fatalf("server error: %v", err)
}
if i != 10 {
t.Errorf("received %d messages, want 10.", i)
}
- // If all went well, we registered exactly 1 key change.
- if len(checker.calls) != 1 {
- t.Fatalf("got %d host key checks, want 1", len(checker.calls))
- }
-
- pub := testSigners["ecdsa"].PublicKey()
- want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
- if want != checker.calls[0] {
- t.Errorf("got %q want %q for host key check", checker.calls[0], want)
+ close(checker.called)
+ if _, ok := <-checker.called; ok {
+ // If all went well, we registered exactly 2 key changes: one
+ // that establishes the session, and one that we requested
+ // additionally.
+ t.Fatalf("got another host key checks after 2 handshakes")
}
}
-func TestHandshakeError(t *testing.T) {
+func TestForceFirstKex(t *testing.T) {
+ // like handshakePair, but must access the keyingTransport.
checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
+ clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+ a, b, err := netPipe()
if err != nil {
- t.Fatalf("handshakePair: %v", err)
+ t.Fatalf("netPipe: %v", err)
}
- defer trC.Close()
- defer trS.Close()
- // send a packet
- packet := []byte{msgRequestSuccess, 42}
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
+ var trC, trS keyingTransport
- // Now request a key change.
- err = trC.sendKexInit(subsequentKeyExchange)
- if err != nil {
- t.Errorf("sendKexInit: %v", err)
- }
+ trC = newTransport(a, rand.Reader, true)
- // the key change will fail, and afterwards we can't write.
- if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
- t.Errorf("writePacket after botched rekey succeeded.")
- }
+ // This is the disallowed packet:
+ trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
- readback, err := trS.readPacket()
- if err != nil {
- t.Fatalf("server closed too soon: %v", err)
- }
- if bytes.Compare(readback, packet) != 0 {
- t.Errorf("got %q want %q", readback, packet)
- }
- readback, err = trS.readPacket()
- if err == nil {
- t.Errorf("got a message %q after failed key change", readback)
- }
-}
+ // Rest of the setup.
+ trS = newTransport(b, rand.Reader, false)
+ clientConf.SetDefaults()
-func TestForceFirstKex(t *testing.T) {
- checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
- if err != nil {
- t.Fatalf("handshakePair: %v", err)
- }
+ v := []byte("version")
+ client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
- defer trC.Close()
- defer trS.Close()
+ serverConf := &ServerConfig{}
+ serverConf.AddHostKey(testSigners["ecdsa"])
+ serverConf.AddHostKey(testSigners["rsa"])
+ serverConf.SetDefaults()
+ server := newServerTransport(trS, v, v, serverConf)
- trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
+ defer client.Close()
+ defer server.Close()
// We setup the initial key exchange, but the remote side
// tries to send serviceRequestMsg in cleartext, which is
// disallowed.
- err = trS.sendKexInit(firstKeyExchange)
- if err == nil {
+ if err := server.waitSession(); err == nil {
t.Errorf("server first kex init should reject unexpected packet")
}
}
-func TestHandshakeTwice(t *testing.T) {
- checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+func TestHandshakeAutoRekeyWrite(t *testing.T) {
+ checker := &syncChecker{make(chan int, 10)}
+ clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+ clientConf.RekeyThreshold = 500
+ trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
-
defer trC.Close()
defer trS.Close()
- // Both sides should ask for the first key exchange first.
- err = trS.sendKexInit(firstKeyExchange)
- if err != nil {
- t.Errorf("server sendKexInit: %v", err)
- }
-
- err = trC.sendKexInit(firstKeyExchange)
- if err != nil {
- t.Errorf("client sendKexInit: %v", err)
- }
-
- sent := 0
- // send a packet
- packet := make([]byte, 5)
- packet[0] = msgRequestSuccess
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
- sent++
-
- // Send another packet. Use a fresh one, since writePacket destroys.
- packet = make([]byte, 5)
- packet[0] = msgRequestSuccess
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
- sent++
-
- // 2nd key change.
- err = trC.sendKexInit(subsequentKeyExchange)
- if err != nil {
- t.Errorf("sendKexInit: %v", err)
- }
-
- packet = make([]byte, 5)
- packet[0] = msgRequestSuccess
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
- sent++
-
- packet = make([]byte, 5)
- packet[0] = msgRequestSuccess
- for i := 0; i < sent; i++ {
- msg, err := trS.readPacket()
- if err != nil {
- t.Fatalf("server closed too soon: %v", err)
+ done := make(chan int, 1)
+ const numPacket = 5
+ go func() {
+ defer close(done)
+ j := 0
+ for ; j < numPacket; j++ {
+ if _, err := trS.readPacket(); err != nil {
+ break
+ }
}
- if bytes.Compare(msg, packet) != 0 {
- t.Errorf("packet %d: got %q want %q", i, msg, packet)
+ if j != numPacket {
+ t.Errorf("got %d, want 5 messages", j)
}
- }
- if len(checker.calls) != 2 {
- t.Errorf("got %d key changes, want 2", len(checker.calls))
- }
-}
+ }()
-func TestHandshakeAutoRekeyWrite(t *testing.T) {
- checker := &testChecker{}
- clientConf := &ClientConfig{HostKeyCallback: checker.Check}
- clientConf.RekeyThreshold = 500
- trC, trS, err := handshakePair(clientConf, "addr")
- if err != nil {
- t.Fatalf("handshakePair: %v", err)
- }
- defer trC.Close()
- defer trS.Close()
+ <-checker.called
- for i := 0; i < 5; i++ {
+ for i := 0; i < numPacket; i++ {
packet := make([]byte, 251)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
- }
-
- j := 0
- for ; j < 5; j++ {
- _, err := trS.readPacket()
- if err != nil {
- break
+ if i == 2 {
+ // Make sure the kex is in progress.
+ <-checker.called
}
- }
- if j != 5 {
- t.Errorf("got %d, want 5 messages", j)
- }
-
- if len(checker.calls) != 2 {
- t.Errorf("got %d key changes, wanted 2", len(checker.calls))
}
+ <-done
}
type syncChecker struct {
called chan int
}
-func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
- t.called <- 1
+func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+ c.called <- 1
return nil
}
@@ -326,7 +298,7 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
}
clientConf.RekeyThreshold = 500
- trC, trS, err := handshakePair(clientConf, "addr")
+ trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
@@ -338,12 +310,19 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
if err := trS.writePacket(packet); err != nil {
t.Fatalf("writePacket: %v", err)
}
+
// While we read out the packet, a key change will be
// initiated.
- if _, err := trC.readPacket(); err != nil {
- t.Fatalf("readPacket(client): %v", err)
- }
+ done := make(chan int, 1)
+ go func() {
+ defer close(done)
+ if _, err := trC.readPacket(); err != nil {
+ t.Fatalf("readPacket(client): %v", err)
+ }
+
+ }()
+ <-done
<-sync.called
}
@@ -357,6 +336,7 @@ type errorKeyingTransport struct {
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
return nil
}
+
func (n *errorKeyingTransport) getSessionID() []byte {
return nil
}
@@ -383,20 +363,32 @@ func (n *errorKeyingTransport) readPacket() ([]byte, error) {
func TestHandshakeErrorHandlingRead(t *testing.T) {
for i := 0; i < 20; i++ {
- testHandshakeErrorHandlingN(t, i, -1)
+ testHandshakeErrorHandlingN(t, i, -1, false)
}
}
func TestHandshakeErrorHandlingWrite(t *testing.T) {
for i := 0; i < 20; i++ {
- testHandshakeErrorHandlingN(t, -1, i)
+ testHandshakeErrorHandlingN(t, -1, i, false)
+ }
+}
+
+func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ testHandshakeErrorHandlingN(t, i, -1, true)
+ }
+}
+
+func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ testHandshakeErrorHandlingN(t, -1, i, true)
}
}
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
-func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
a, b := memPipe()
@@ -409,37 +401,57 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
serverConn.hostKeys = []Signer{key}
go serverConn.readLoop()
+ go serverConn.kexLoop()
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
go clientConn.readLoop()
+ go clientConn.kexLoop()
var wg sync.WaitGroup
- wg.Add(4)
for _, hs := range []packetConn{serverConn, clientConn} {
- go func(c packetConn) {
- for {
- err := c.writePacket(msg)
- if err != nil {
- break
+ if !coupled {
+ wg.Add(2)
+ go func(c packetConn) {
+ for i := 0; ; i++ {
+ str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
+ err := c.writePacket(Marshal(&serviceRequestMsg{str}))
+ if err != nil {
+ break
+ }
}
- }
- wg.Done()
- }(hs)
- go func(c packetConn) {
- for {
- _, err := c.readPacket()
- if err != nil {
- break
+ wg.Done()
+ c.Close()
+ }(hs)
+ go func(c packetConn) {
+ for {
+ _, err := c.readPacket()
+ if err != nil {
+ break
+ }
}
- }
- wg.Done()
- }(hs)
- }
+ wg.Done()
+ }(hs)
+ } else {
+ wg.Add(1)
+ go func(c packetConn) {
+ for {
+ _, err := c.readPacket()
+ if err != nil {
+ break
+ }
+ if err := c.writePacket(msg); err != nil {
+ break
+ }
+ }
+ wg.Done()
+ }(hs)
+ }
+ }
wg.Wait()
}
@@ -448,7 +460,7 @@ func TestDisconnect(t *testing.T) {
t.Skip("see golang.org/issue/7237")
}
checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}