summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-ldap/ldap/conn_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-ldap/ldap/conn_test.go')
-rw-r--r--vendor/github.com/go-ldap/ldap/conn_test.go295
1 files changed, 291 insertions, 4 deletions
diff --git a/vendor/github.com/go-ldap/ldap/conn_test.go b/vendor/github.com/go-ldap/ldap/conn_test.go
index 8394e5339..10766bbd4 100644
--- a/vendor/github.com/go-ldap/ldap/conn_test.go
+++ b/vendor/github.com/go-ldap/ldap/conn_test.go
@@ -1,9 +1,14 @@
package ldap
import (
+ "bytes"
+ "errors"
+ "io"
"net"
"net/http"
"net/http/httptest"
+ "runtime"
+ "sync"
"testing"
"time"
@@ -27,19 +32,20 @@ func TestUnresponsiveConnection(t *testing.T) {
defer conn.Close()
// Mock a packet
- messageID := conn.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
- packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
packet.AppendChild(bindRequest)
// Send packet and test response
- channel, err := conn.sendMessage(packet)
+ msgCtx, err := conn.sendMessage(packet)
if err != nil {
t.Fatalf("error sending message: %v", err)
}
- packetResponse, ok := <-channel
+ defer conn.finishMessage(msgCtx)
+
+ packetResponse, ok := <-msgCtx.responses
if !ok {
t.Fatalf("no PacketResponse in response channel")
}
@@ -51,3 +57,284 @@ func TestUnresponsiveConnection(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
+
+// TestFinishMessage tests that we do not enter deadlock when a goroutine makes
+// a request but does not handle all responses from the server.
+func TestConn(t *testing.T) {
+ ptc := newPacketTranslatorConn()
+ defer ptc.Close()
+
+ conn := NewConn(ptc, false)
+ conn.Start()
+
+ // Test sending 5 different requests in series. Ensure that we can
+ // get a response packet from the underlying connection and also
+ // ensure that we can gracefully ignore unhandled responses.
+ for i := 0; i < 5; i++ {
+ t.Logf("serial request %d", i)
+ // Create a message and make sure we can receive responses.
+ msgCtx := testSendRequest(t, ptc, conn)
+ testReceiveResponse(t, ptc, msgCtx)
+
+ // Send a few unhandled responses and finish the message.
+ testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
+ t.Logf("serial request %d done", i)
+ }
+
+ // Test sending 5 different requests in parallel.
+ var wg sync.WaitGroup
+ for i := 0; i < 5; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ t.Logf("parallel request %d", i)
+ // Create a message and make sure we can receive responses.
+ msgCtx := testSendRequest(t, ptc, conn)
+ testReceiveResponse(t, ptc, msgCtx)
+
+ // Send a few unhandled responses and finish the message.
+ testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
+ t.Logf("parallel request %d done", i)
+ }(i)
+ }
+ wg.Wait()
+
+ // We cannot run Close() in a defer because t.FailNow() will run it and
+ // it will block if the processMessage Loop is in a deadlock.
+ conn.Close()
+}
+
+func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
+ var msgID int64
+ runWithTimeout(t, time.Second, func() {
+ msgID = conn.nextMessageID()
+ })
+
+ requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))
+
+ var err error
+
+ runWithTimeout(t, time.Second, func() {
+ msgCtx, err = conn.sendMessage(requestPacket)
+ if err != nil {
+ t.Fatalf("unable to send request message: %s", err)
+ }
+ })
+
+ // We should now be able to get this request packet out from the other
+ // side.
+ runWithTimeout(t, time.Second, func() {
+ if _, err = ptc.ReceiveRequest(); err != nil {
+ t.Fatalf("unable to receive request packet: %s", err)
+ }
+ })
+
+ return msgCtx
+}
+
+func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
+ // Send a mock response packet.
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
+
+ runWithTimeout(t, time.Second, func() {
+ if err := ptc.SendResponse(responsePacket); err != nil {
+ t.Fatalf("unable to send response packet: %s", err)
+ }
+ })
+
+ // We should be able to receive the packet from the connection.
+ runWithTimeout(t, time.Second, func() {
+ if _, ok := <-msgCtx.responses; !ok {
+ t.Fatal("response channel closed")
+ }
+ })
+}
+
+func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
+ // Send a mock response packet.
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
+
+ // Send extra responses but do not attempt to receive them on the
+ // client side.
+ for i := 0; i < numResponses; i++ {
+ runWithTimeout(t, time.Second, func() {
+ if err := ptc.SendResponse(responsePacket); err != nil {
+ t.Fatalf("unable to send response packet: %s", err)
+ }
+ })
+ }
+
+ // Finally, attempt to finish this message.
+ runWithTimeout(t, time.Second, func() {
+ conn.finishMessage(msgCtx)
+ })
+}
+
+func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
+ runtime.Gosched()
+
+ done := make(chan struct{})
+ go func() {
+ f()
+ close(done)
+ }()
+
+ runtime.Gosched()
+
+ select {
+ case <-done: // Success!
+ case <-time.After(timeout):
+ _, file, line, _ := runtime.Caller(1)
+ t.Fatalf("%s:%d timed out", file, line)
+ }
+}
+
+// packetTranslatorConn is a helful type which can be used with various tests
+// in this package. It implements the net.Conn interface to be used as an
+// underlying connection for a *ldap.Conn. Most methods are no-ops but the
+// Read() and Write() methods are able to translate ber-encoded packets for
+// testing LDAP requests and responses.
+//
+// Test cases can simulate an LDAP server sending a response by calling the
+// SendResponse() method with a ber-encoded LDAP response packet. Test cases
+// can simulate an LDAP server receiving a request from a client by calling the
+// ReceiveRequest() method which returns a ber-encoded LDAP request packet.
+type packetTranslatorConn struct {
+ lock sync.Mutex
+ isClosed bool
+
+ responseCond sync.Cond
+ requestCond sync.Cond
+
+ responseBuf bytes.Buffer
+ requestBuf bytes.Buffer
+}
+
+var errPacketTranslatorConnClosed = errors.New("connection closed")
+
+func newPacketTranslatorConn() *packetTranslatorConn {
+ conn := &packetTranslatorConn{}
+ conn.responseCond = sync.Cond{L: &conn.lock}
+ conn.requestCond = sync.Cond{L: &conn.lock}
+
+ return conn
+}
+
+// Read is called by the reader() loop to receive response packets. It will
+// block until there are more packet bytes available or this connection is
+// closed.
+func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ for !c.isClosed {
+ // Attempt to read data from the response buffer. If it fails
+ // with an EOF, wait and try again.
+ n, err = c.responseBuf.Read(b)
+ if err != io.EOF {
+ return n, err
+ }
+
+ c.responseCond.Wait()
+ }
+
+ return 0, errPacketTranslatorConnClosed
+}
+
+// SendResponse writes the given response packet to the response buffer for
+// this conection, signalling any goroutine waiting to read a response.
+func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.isClosed {
+ return errPacketTranslatorConnClosed
+ }
+
+ // Signal any goroutine waiting to read a response.
+ defer c.responseCond.Broadcast()
+
+ // Writes to the buffer should always succeed.
+ c.responseBuf.Write(packet.Bytes())
+
+ return nil
+}
+
+// Write is called by the processMessages() loop to send request packets.
+func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.isClosed {
+ return 0, errPacketTranslatorConnClosed
+ }
+
+ // Signal any goroutine waiting to read a request.
+ defer c.requestCond.Broadcast()
+
+ // Writes to the buffer should always succeed.
+ return c.requestBuf.Write(b)
+}
+
+// ReceiveRequest attempts to read a request packet from this connection. It
+// will block until it is able to read a full request packet or until this
+// connection is closed.
+func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ for !c.isClosed {
+ // Attempt to parse a request packet from the request buffer.
+ // If it fails with an unexpected EOF, wait and try again.
+ requestReader := bytes.NewReader(c.requestBuf.Bytes())
+ packet, err := ber.ReadPacket(requestReader)
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ c.requestCond.Wait()
+ case nil:
+ // Advance the request buffer by the number of bytes
+ // read to decode the request packet.
+ c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
+ return packet, nil
+ default:
+ return nil, err
+ }
+ }
+
+ return nil, errPacketTranslatorConnClosed
+}
+
+// Close closes this connection causing Read() and Write() calls to fail.
+func (c *packetTranslatorConn) Close() error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ c.isClosed = true
+ c.responseCond.Broadcast()
+ c.requestCond.Broadcast()
+
+ return nil
+}
+
+func (c *packetTranslatorConn) LocalAddr() net.Addr {
+ return (*net.TCPAddr)(nil)
+}
+
+func (c *packetTranslatorConn) RemoteAddr() net.Addr {
+ return (*net.TCPAddr)(nil)
+}
+
+func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
+ return nil
+}
+
+func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
+ return nil
+}
+
+func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
+ return nil
+}