diff options
Diffstat (limited to 'vendor/golang.org/x/crypto/ssh/handshake_test.go')
-rw-r--r-- | vendor/golang.org/x/crypto/ssh/handshake_test.go | 366 |
1 files changed, 189 insertions, 177 deletions
diff --git a/vendor/golang.org/x/crypto/ssh/handshake_test.go b/vendor/golang.org/x/crypto/ssh/handshake_test.go index da53d3a0d..e61348fea 100644 --- a/vendor/golang.org/x/crypto/ssh/handshake_test.go +++ b/vendor/golang.org/x/crypto/ssh/handshake_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "errors" "fmt" + "io" "net" "reflect" "runtime" @@ -58,14 +59,46 @@ func netPipe() (net.Conn, net.Conn, error) { return c1, c2, nil } -func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { +// noiseTransport inserts ignore messages to check that the read loop +// and the key exchange filters out these messages. +type noiseTransport struct { + keyingTransport +} + +func (t *noiseTransport) writePacket(p []byte) error { + ignore := []byte{msgIgnore} + if err := t.keyingTransport.writePacket(ignore); err != nil { + return err + } + debug := []byte{msgDebug, 1, 2, 3} + if err := t.keyingTransport.writePacket(debug); err != nil { + return err + } + + return t.keyingTransport.writePacket(p) +} + +func addNoiseTransport(t keyingTransport) keyingTransport { + return &noiseTransport{t} +} + +// handshakePair creates two handshakeTransports connected with each +// other. If the noise argument is true, both transports will try to +// confuse the other side by sending ignore and debug messages. +func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { a, b, err := netPipe() if err != nil { return nil, nil, err } - trC := newTransport(a, rand.Reader, true) - trS := newTransport(b, rand.Reader, false) + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + if noise { + trC = addNoiseTransport(trC) + trS = addNoiseTransport(trS) + } clientConf.SetDefaults() v := []byte("version") @@ -77,6 +110,13 @@ func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTran serverConf.SetDefaults() server = newServerTransport(trS, v, v, serverConf) + if err := server.waitSession(); err != nil { + return nil, nil, fmt.Errorf("server.waitSession: %v", err) + } + if err := client.waitSession(); err != nil { + return nil, nil, fmt.Errorf("client.waitSession: %v", err) + } + return client, server, nil } @@ -84,8 +124,9 @@ func TestHandshakeBasic(t *testing.T) { if runtime.GOOS == "plan9" { t.Skip("see golang.org/issue/7237") } - checker := &testChecker{} - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + + checker := &syncChecker{make(chan int, 10)} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -93,7 +134,13 @@ func TestHandshakeBasic(t *testing.T) { defer trC.Close() defer trS.Close() + <-checker.called + + clientDone := make(chan int, 0) + gotHalf := make(chan int, 0) + go func() { + defer close(clientDone) // Client writes a bunch of stuff, and does a key // change in the middle. This should not confuse the // handshake in progress @@ -103,219 +150,144 @@ func TestHandshakeBasic(t *testing.T) { t.Fatalf("sendPacket: %v", err) } if i == 5 { + <-gotHalf // halfway through, we request a key change. - err := trC.sendKexInit(subsequentKeyExchange) - if err != nil { - t.Fatalf("sendKexInit: %v", err) - } + trC.requestKeyExchange() + + // Wait until we can be sure the key + // change has really started before we + // write more. + <-checker.called } } - trC.Close() }() // Server checks that client messages come in cleanly i := 0 - for { - p, err := trS.readPacket() + err = nil + for ; i < 10; i++ { + var p []byte + p, err = trS.readPacket() if err != nil { break } - if p[0] == msgNewKeys { - continue + if i == 5 { + gotHalf <- 1 } + want := []byte{msgRequestSuccess, byte(i)} if bytes.Compare(p, want) != 0 { t.Errorf("message %d: got %q, want %q", i, p, want) } - i++ + } + <-clientDone + if err != nil && err != io.EOF { + t.Fatalf("server error: %v", err) } if i != 10 { t.Errorf("received %d messages, want 10.", i) } - // If all went well, we registered exactly 1 key change. - if len(checker.calls) != 1 { - t.Fatalf("got %d host key checks, want 1", len(checker.calls)) - } - - pub := testSigners["ecdsa"].PublicKey() - want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) - if want != checker.calls[0] { - t.Errorf("got %q want %q for host key check", checker.calls[0], want) + close(checker.called) + if _, ok := <-checker.called; ok { + // If all went well, we registered exactly 2 key changes: one + // that establishes the session, and one that we requested + // additionally. + t.Fatalf("got another host key checks after 2 handshakes") } } -func TestHandshakeError(t *testing.T) { +func TestForceFirstKex(t *testing.T) { + // like handshakePair, but must access the keyingTransport. checker := &testChecker{} - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + a, b, err := netPipe() if err != nil { - t.Fatalf("handshakePair: %v", err) + t.Fatalf("netPipe: %v", err) } - defer trC.Close() - defer trS.Close() - // send a packet - packet := []byte{msgRequestSuccess, 42} - if err := trC.writePacket(packet); err != nil { - t.Errorf("writePacket: %v", err) - } + var trC, trS keyingTransport - // Now request a key change. - err = trC.sendKexInit(subsequentKeyExchange) - if err != nil { - t.Errorf("sendKexInit: %v", err) - } + trC = newTransport(a, rand.Reader, true) - // the key change will fail, and afterwards we can't write. - if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { - t.Errorf("writePacket after botched rekey succeeded.") - } + // This is the disallowed packet: + trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) - readback, err := trS.readPacket() - if err != nil { - t.Fatalf("server closed too soon: %v", err) - } - if bytes.Compare(readback, packet) != 0 { - t.Errorf("got %q want %q", readback, packet) - } - readback, err = trS.readPacket() - if err == nil { - t.Errorf("got a message %q after failed key change", readback) - } -} + // Rest of the setup. + trS = newTransport(b, rand.Reader, false) + clientConf.SetDefaults() -func TestForceFirstKex(t *testing.T) { - checker := &testChecker{} - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") - if err != nil { - t.Fatalf("handshakePair: %v", err) - } + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) - defer trC.Close() - defer trS.Close() + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server := newServerTransport(trS, v, v, serverConf) - trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) + defer client.Close() + defer server.Close() // We setup the initial key exchange, but the remote side // tries to send serviceRequestMsg in cleartext, which is // disallowed. - err = trS.sendKexInit(firstKeyExchange) - if err == nil { + if err := server.waitSession(); err == nil { t.Errorf("server first kex init should reject unexpected packet") } } -func TestHandshakeTwice(t *testing.T) { - checker := &testChecker{} - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &syncChecker{make(chan int, 10)} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } - defer trC.Close() defer trS.Close() - // Both sides should ask for the first key exchange first. - err = trS.sendKexInit(firstKeyExchange) - if err != nil { - t.Errorf("server sendKexInit: %v", err) - } - - err = trC.sendKexInit(firstKeyExchange) - if err != nil { - t.Errorf("client sendKexInit: %v", err) - } - - sent := 0 - // send a packet - packet := make([]byte, 5) - packet[0] = msgRequestSuccess - if err := trC.writePacket(packet); err != nil { - t.Errorf("writePacket: %v", err) - } - sent++ - - // Send another packet. Use a fresh one, since writePacket destroys. - packet = make([]byte, 5) - packet[0] = msgRequestSuccess - if err := trC.writePacket(packet); err != nil { - t.Errorf("writePacket: %v", err) - } - sent++ - - // 2nd key change. - err = trC.sendKexInit(subsequentKeyExchange) - if err != nil { - t.Errorf("sendKexInit: %v", err) - } - - packet = make([]byte, 5) - packet[0] = msgRequestSuccess - if err := trC.writePacket(packet); err != nil { - t.Errorf("writePacket: %v", err) - } - sent++ - - packet = make([]byte, 5) - packet[0] = msgRequestSuccess - for i := 0; i < sent; i++ { - msg, err := trS.readPacket() - if err != nil { - t.Fatalf("server closed too soon: %v", err) + done := make(chan int, 1) + const numPacket = 5 + go func() { + defer close(done) + j := 0 + for ; j < numPacket; j++ { + if _, err := trS.readPacket(); err != nil { + break + } } - if bytes.Compare(msg, packet) != 0 { - t.Errorf("packet %d: got %q want %q", i, msg, packet) + if j != numPacket { + t.Errorf("got %d, want 5 messages", j) } - } - if len(checker.calls) != 2 { - t.Errorf("got %d key changes, want 2", len(checker.calls)) - } -} + }() -func TestHandshakeAutoRekeyWrite(t *testing.T) { - checker := &testChecker{} - clientConf := &ClientConfig{HostKeyCallback: checker.Check} - clientConf.RekeyThreshold = 500 - trC, trS, err := handshakePair(clientConf, "addr") - if err != nil { - t.Fatalf("handshakePair: %v", err) - } - defer trC.Close() - defer trS.Close() + <-checker.called - for i := 0; i < 5; i++ { + for i := 0; i < numPacket; i++ { packet := make([]byte, 251) packet[0] = msgRequestSuccess if err := trC.writePacket(packet); err != nil { t.Errorf("writePacket: %v", err) } - } - - j := 0 - for ; j < 5; j++ { - _, err := trS.readPacket() - if err != nil { - break + if i == 2 { + // Make sure the kex is in progress. + <-checker.called } - } - if j != 5 { - t.Errorf("got %d, want 5 messages", j) - } - - if len(checker.calls) != 2 { - t.Errorf("got %d key changes, wanted 2", len(checker.calls)) } + <-done } type syncChecker struct { called chan int } -func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { - t.called <- 1 +func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + c.called <- 1 return nil } @@ -326,7 +298,7 @@ func TestHandshakeAutoRekeyRead(t *testing.T) { } clientConf.RekeyThreshold = 500 - trC, trS, err := handshakePair(clientConf, "addr") + trC, trS, err := handshakePair(clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -338,12 +310,19 @@ func TestHandshakeAutoRekeyRead(t *testing.T) { if err := trS.writePacket(packet); err != nil { t.Fatalf("writePacket: %v", err) } + // While we read out the packet, a key change will be // initiated. - if _, err := trC.readPacket(); err != nil { - t.Fatalf("readPacket(client): %v", err) - } + done := make(chan int, 1) + go func() { + defer close(done) + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + }() + <-done <-sync.called } @@ -357,6 +336,7 @@ type errorKeyingTransport struct { func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { return nil } + func (n *errorKeyingTransport) getSessionID() []byte { return nil } @@ -383,20 +363,32 @@ func (n *errorKeyingTransport) readPacket() ([]byte, error) { func TestHandshakeErrorHandlingRead(t *testing.T) { for i := 0; i < 20; i++ { - testHandshakeErrorHandlingN(t, i, -1) + testHandshakeErrorHandlingN(t, i, -1, false) } } func TestHandshakeErrorHandlingWrite(t *testing.T) { for i := 0; i < 20; i++ { - testHandshakeErrorHandlingN(t, -1, i) + testHandshakeErrorHandlingN(t, -1, i, false) + } +} + +func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, true) + } +} + +func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, true) } } // testHandshakeErrorHandlingN runs handshakes, injecting errors. If // handshakeTransport deadlocks, the go runtime will detect it and // panic. -func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { +func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) a, b := memPipe() @@ -409,37 +401,57 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) serverConn.hostKeys = []Signer{key} go serverConn.readLoop() + go serverConn.kexLoop() clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} clientConf.SetDefaults() clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} go clientConn.readLoop() + go clientConn.kexLoop() var wg sync.WaitGroup - wg.Add(4) for _, hs := range []packetConn{serverConn, clientConn} { - go func(c packetConn) { - for { - err := c.writePacket(msg) - if err != nil { - break + if !coupled { + wg.Add(2) + go func(c packetConn) { + for i := 0; ; i++ { + str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) + err := c.writePacket(Marshal(&serviceRequestMsg{str})) + if err != nil { + break + } } - } - wg.Done() - }(hs) - go func(c packetConn) { - for { - _, err := c.readPacket() - if err != nil { - break + wg.Done() + c.Close() + }(hs) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } } - } - wg.Done() - }(hs) - } + wg.Done() + }(hs) + } else { + wg.Add(1) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + if err := c.writePacket(msg); err != nil { + break + } + } + wg.Done() + }(hs) + } + } wg.Wait() } @@ -448,7 +460,7 @@ func TestDisconnect(t *testing.T) { t.Skip("see golang.org/issue/7237") } checker := &testChecker{} - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } |