diff options
Diffstat (limited to 'vendor/github.com/go-ldap/ldap/conn_test.go')
-rw-r--r-- | vendor/github.com/go-ldap/ldap/conn_test.go | 295 |
1 files changed, 291 insertions, 4 deletions
diff --git a/vendor/github.com/go-ldap/ldap/conn_test.go b/vendor/github.com/go-ldap/ldap/conn_test.go index 8394e5339..10766bbd4 100644 --- a/vendor/github.com/go-ldap/ldap/conn_test.go +++ b/vendor/github.com/go-ldap/ldap/conn_test.go @@ -1,9 +1,14 @@ package ldap import ( + "bytes" + "errors" + "io" "net" "net/http" "net/http/httptest" + "runtime" + "sync" "testing" "time" @@ -27,19 +32,20 @@ func TestUnresponsiveConnection(t *testing.T) { defer conn.Close() // Mock a packet - messageID := conn.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID")) bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) packet.AppendChild(bindRequest) // Send packet and test response - channel, err := conn.sendMessage(packet) + msgCtx, err := conn.sendMessage(packet) if err != nil { t.Fatalf("error sending message: %v", err) } - packetResponse, ok := <-channel + defer conn.finishMessage(msgCtx) + + packetResponse, ok := <-msgCtx.responses if !ok { t.Fatalf("no PacketResponse in response channel") } @@ -51,3 +57,284 @@ func TestUnresponsiveConnection(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +// TestFinishMessage tests that we do not enter deadlock when a goroutine makes +// a request but does not handle all responses from the server. +func TestConn(t *testing.T) { + ptc := newPacketTranslatorConn() + defer ptc.Close() + + conn := NewConn(ptc, false) + conn.Start() + + // Test sending 5 different requests in series. Ensure that we can + // get a response packet from the underlying connection and also + // ensure that we can gracefully ignore unhandled responses. + for i := 0; i < 5; i++ { + t.Logf("serial request %d", i) + // Create a message and make sure we can receive responses. + msgCtx := testSendRequest(t, ptc, conn) + testReceiveResponse(t, ptc, msgCtx) + + // Send a few unhandled responses and finish the message. + testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) + t.Logf("serial request %d done", i) + } + + // Test sending 5 different requests in parallel. + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t.Logf("parallel request %d", i) + // Create a message and make sure we can receive responses. + msgCtx := testSendRequest(t, ptc, conn) + testReceiveResponse(t, ptc, msgCtx) + + // Send a few unhandled responses and finish the message. + testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) + t.Logf("parallel request %d done", i) + }(i) + } + wg.Wait() + + // We cannot run Close() in a defer because t.FailNow() will run it and + // it will block if the processMessage Loop is in a deadlock. + conn.Close() +} + +func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) { + var msgID int64 + runWithTimeout(t, time.Second, func() { + msgID = conn.nextMessageID() + }) + + requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID")) + + var err error + + runWithTimeout(t, time.Second, func() { + msgCtx, err = conn.sendMessage(requestPacket) + if err != nil { + t.Fatalf("unable to send request message: %s", err) + } + }) + + // We should now be able to get this request packet out from the other + // side. + runWithTimeout(t, time.Second, func() { + if _, err = ptc.ReceiveRequest(); err != nil { + t.Fatalf("unable to receive request packet: %s", err) + } + }) + + return msgCtx +} + +func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) { + // Send a mock response packet. + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) + + runWithTimeout(t, time.Second, func() { + if err := ptc.SendResponse(responsePacket); err != nil { + t.Fatalf("unable to send response packet: %s", err) + } + }) + + // We should be able to receive the packet from the connection. + runWithTimeout(t, time.Second, func() { + if _, ok := <-msgCtx.responses; !ok { + t.Fatal("response channel closed") + } + }) +} + +func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) { + // Send a mock response packet. + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) + + // Send extra responses but do not attempt to receive them on the + // client side. + for i := 0; i < numResponses; i++ { + runWithTimeout(t, time.Second, func() { + if err := ptc.SendResponse(responsePacket); err != nil { + t.Fatalf("unable to send response packet: %s", err) + } + }) + } + + // Finally, attempt to finish this message. + runWithTimeout(t, time.Second, func() { + conn.finishMessage(msgCtx) + }) +} + +func runWithTimeout(t *testing.T, timeout time.Duration, f func()) { + runtime.Gosched() + + done := make(chan struct{}) + go func() { + f() + close(done) + }() + + runtime.Gosched() + + select { + case <-done: // Success! + case <-time.After(timeout): + _, file, line, _ := runtime.Caller(1) + t.Fatalf("%s:%d timed out", file, line) + } +} + +// packetTranslatorConn is a helful type which can be used with various tests +// in this package. It implements the net.Conn interface to be used as an +// underlying connection for a *ldap.Conn. Most methods are no-ops but the +// Read() and Write() methods are able to translate ber-encoded packets for +// testing LDAP requests and responses. +// +// Test cases can simulate an LDAP server sending a response by calling the +// SendResponse() method with a ber-encoded LDAP response packet. Test cases +// can simulate an LDAP server receiving a request from a client by calling the +// ReceiveRequest() method which returns a ber-encoded LDAP request packet. +type packetTranslatorConn struct { + lock sync.Mutex + isClosed bool + + responseCond sync.Cond + requestCond sync.Cond + + responseBuf bytes.Buffer + requestBuf bytes.Buffer +} + +var errPacketTranslatorConnClosed = errors.New("connection closed") + +func newPacketTranslatorConn() *packetTranslatorConn { + conn := &packetTranslatorConn{} + conn.responseCond = sync.Cond{L: &conn.lock} + conn.requestCond = sync.Cond{L: &conn.lock} + + return conn +} + +// Read is called by the reader() loop to receive response packets. It will +// block until there are more packet bytes available or this connection is +// closed. +func (c *packetTranslatorConn) Read(b []byte) (n int, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + for !c.isClosed { + // Attempt to read data from the response buffer. If it fails + // with an EOF, wait and try again. + n, err = c.responseBuf.Read(b) + if err != io.EOF { + return n, err + } + + c.responseCond.Wait() + } + + return 0, errPacketTranslatorConnClosed +} + +// SendResponse writes the given response packet to the response buffer for +// this conection, signalling any goroutine waiting to read a response. +func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.isClosed { + return errPacketTranslatorConnClosed + } + + // Signal any goroutine waiting to read a response. + defer c.responseCond.Broadcast() + + // Writes to the buffer should always succeed. + c.responseBuf.Write(packet.Bytes()) + + return nil +} + +// Write is called by the processMessages() loop to send request packets. +func (c *packetTranslatorConn) Write(b []byte) (n int, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.isClosed { + return 0, errPacketTranslatorConnClosed + } + + // Signal any goroutine waiting to read a request. + defer c.requestCond.Broadcast() + + // Writes to the buffer should always succeed. + return c.requestBuf.Write(b) +} + +// ReceiveRequest attempts to read a request packet from this connection. It +// will block until it is able to read a full request packet or until this +// connection is closed. +func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) { + c.lock.Lock() + defer c.lock.Unlock() + + for !c.isClosed { + // Attempt to parse a request packet from the request buffer. + // If it fails with an unexpected EOF, wait and try again. + requestReader := bytes.NewReader(c.requestBuf.Bytes()) + packet, err := ber.ReadPacket(requestReader) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + c.requestCond.Wait() + case nil: + // Advance the request buffer by the number of bytes + // read to decode the request packet. + c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len()) + return packet, nil + default: + return nil, err + } + } + + return nil, errPacketTranslatorConnClosed +} + +// Close closes this connection causing Read() and Write() calls to fail. +func (c *packetTranslatorConn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + c.isClosed = true + c.responseCond.Broadcast() + c.requestCond.Broadcast() + + return nil +} + +func (c *packetTranslatorConn) LocalAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *packetTranslatorConn) RemoteAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *packetTranslatorConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error { + return nil +} |