summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-ldap/ldap/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-ldap/ldap/conn.go')
-rw-r--r--vendor/github.com/go-ldap/ldap/conn.go147
1 files changed, 95 insertions, 52 deletions
diff --git a/vendor/github.com/go-ldap/ldap/conn.go b/vendor/github.com/go-ldap/ldap/conn.go
index 6aad628be..b5bd99adb 100644
--- a/vendor/github.com/go-ldap/ldap/conn.go
+++ b/vendor/github.com/go-ldap/ldap/conn.go
@@ -17,18 +17,27 @@ import (
)
const (
- MessageQuit = 0
- MessageRequest = 1
+ // MessageQuit causes the processMessages loop to exit
+ MessageQuit = 0
+ // MessageRequest sends a request to the server
+ MessageRequest = 1
+ // MessageResponse receives a response from the server
MessageResponse = 2
- MessageFinish = 3
- MessageTimeout = 4
+ // MessageFinish indicates the client considers a particular message ID to be finished
+ MessageFinish = 3
+ // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
+ MessageTimeout = 4
)
+// PacketResponse contains the packet or error encountered reading a response
type PacketResponse struct {
+ // Packet is the packet read from the server
Packet *ber.Packet
- Error error
+ // Error is an error encountered while reading
+ Error error
}
+// ReadPacket returns the packet or an error
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
@@ -36,11 +45,31 @@ func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
return pr.Packet, pr.Error
}
+type messageContext struct {
+ id int64
+ // close(done) should only be called from finishMessage()
+ done chan struct{}
+ // close(responses) should only be called from processMessages(), and only sent to from sendResponse()
+ responses chan *PacketResponse
+}
+
+// sendResponse should only be called within the processMessages() loop which
+// is also responsible for closing the responses channel.
+func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
+ select {
+ case msgCtx.responses <- packet:
+ // Successfully sent packet to message handler.
+ case <-msgCtx.done:
+ // The request handler is done and will not receive more
+ // packets.
+ }
+}
+
type messagePacket struct {
Op int
MessageID int64
Packet *ber.Packet
- Channel chan *PacketResponse
+ Context *messageContext
}
type sendMessageFlags uint
@@ -54,10 +83,11 @@ type Conn struct {
conn net.Conn
isTLS bool
isClosing bool
+ closeErr error
isStartingTLS bool
Debug debugging
chanConfirm chan bool
- chanResults map[int64]chan *PacketResponse
+ messageContexts map[int64]*messageContext
chanMessage chan *messagePacket
chanMessageID chan int64
wgSender sync.WaitGroup
@@ -111,16 +141,17 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
- conn: conn,
- chanConfirm: make(chan bool),
- chanMessageID: make(chan int64),
- chanMessage: make(chan *messagePacket, 10),
- chanResults: map[int64]chan *PacketResponse{},
- requestTimeout: 0,
- isTLS: isTLS,
+ conn: conn,
+ chanConfirm: make(chan bool),
+ chanMessageID: make(chan int64),
+ chanMessage: make(chan *messagePacket, 10),
+ messageContexts: map[int64]*messageContext{},
+ requestTimeout: 0,
+ isTLS: isTLS,
}
}
+// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
go l.reader()
go l.processMessages()
@@ -148,7 +179,7 @@ func (l *Conn) Close() {
l.wgClose.Wait()
}
-// Sets the time after a request is sent that a MessageTimeout triggers
+// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
l.requestTimeout = timeout
@@ -167,35 +198,31 @@ func (l *Conn) nextMessageID() int64 {
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
func (l *Conn) StartTLS(config *tls.Config) error {
- messageID := l.nextMessageID()
-
if l.isTLS {
return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
- packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
packet.AppendChild(request)
l.Debug.PrintPacket(packet)
- channel, err := l.sendMessageWithFlags(packet, startTLS)
+ msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
if err != nil {
return err
}
- if channel == nil {
- return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
- }
+ defer l.finishMessage(msgCtx)
- l.Debug.Printf("%d: waiting for response", messageID)
- defer l.finishMessage(messageID)
- packetResponse, ok := <-channel
+ l.Debug.Printf("%d: waiting for response", msgCtx.id)
+
+ packetResponse, ok := <-msgCtx.responses
if !ok {
- return NewError(ErrorNetwork, errors.New("ldap: channel closed"))
+ return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
- l.Debug.Printf("%d: got response %p", messageID, packet)
+ l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
return err
}
@@ -226,11 +253,11 @@ func (l *Conn) StartTLS(config *tls.Config) error {
return nil
}
-func (l *Conn) sendMessage(packet *ber.Packet) (chan *PacketResponse, error) {
+func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
return l.sendMessageWithFlags(packet, 0)
}
-func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *PacketResponse, error) {
+func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
if l.isClosing {
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
@@ -238,32 +265,38 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
if l.isStartingTLS {
l.messageMutex.Unlock()
- return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase."))
+ return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
}
if flags&startTLS != 0 {
if l.outstandingRequests != 0 {
l.messageMutex.Unlock()
return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
- } else {
- l.isStartingTLS = true
}
+ l.isStartingTLS = true
}
l.outstandingRequests++
l.messageMutex.Unlock()
- out := make(chan *PacketResponse)
+ responses := make(chan *PacketResponse)
+ messageID := packet.Children[0].Value.(int64)
message := &messagePacket{
Op: MessageRequest,
- MessageID: packet.Children[0].Value.(int64),
+ MessageID: messageID,
Packet: packet,
- Channel: out,
+ Context: &messageContext{
+ id: messageID,
+ done: make(chan struct{}),
+ responses: responses,
+ },
}
l.sendProcessMessage(message)
- return out, nil
+ return message.Context, nil
}
-func (l *Conn) finishMessage(messageID int64) {
+func (l *Conn) finishMessage(msgCtx *messageContext) {
+ close(msgCtx.done)
+
if l.isClosing {
return
}
@@ -277,7 +310,7 @@ func (l *Conn) finishMessage(messageID int64) {
message := &messagePacket{
Op: MessageFinish,
- MessageID: messageID,
+ MessageID: msgCtx.id,
}
l.sendProcessMessage(message)
}
@@ -297,10 +330,15 @@ func (l *Conn) processMessages() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in processMessages: %v", err)
}
- for messageID, channel := range l.chanResults {
+ for messageID, msgCtx := range l.messageContexts {
+ // If we are closing due to an error, inform anyone who
+ // is waiting about the error.
+ if l.isClosing && l.closeErr != nil {
+ msgCtx.sendResponse(&PacketResponse{Error: l.closeErr})
+ }
l.Debug.Printf("Closing channel for MessageID %d", messageID)
- close(channel)
- delete(l.chanResults, messageID)
+ close(msgCtx.responses)
+ delete(l.messageContexts, messageID)
}
close(l.chanMessageID)
l.chanConfirm <- true
@@ -324,15 +362,20 @@ func (l *Conn) processMessages() {
case MessageRequest:
// Add to message list and write to network
l.Debug.Printf("Sending message %d", message.MessageID)
- l.chanResults[message.MessageID] = message.Channel
buf := message.Packet.Bytes()
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
+ message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
+ close(message.Context.responses)
break
}
+ // Only add to messageContexts if we were able to
+ // successfully write the message.
+ l.messageContexts[message.MessageID] = message.Context
+
// Add timeout if defined
if l.requestTimeout > 0 {
go func() {
@@ -351,8 +394,8 @@ func (l *Conn) processMessages() {
}
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
- if chanResult, ok := l.chanResults[message.MessageID]; ok {
- chanResult <- &PacketResponse{message.Packet, nil}
+ if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
+ msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
} else {
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing)
ber.PrintPacket(message.Packet)
@@ -360,17 +403,17 @@ func (l *Conn) processMessages() {
case MessageTimeout:
// Handle the timeout by closing the channel
// All reads will return immediately
- if chanResult, ok := l.chanResults[message.MessageID]; ok {
- chanResult <- &PacketResponse{message.Packet, errors.New("ldap: connection timed out")}
+ if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
- delete(l.chanResults, message.MessageID)
- close(chanResult)
+ msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
+ delete(l.messageContexts, message.MessageID)
+ close(msgCtx.responses)
}
case MessageFinish:
l.Debug.Printf("Finished message %d", message.MessageID)
- if chanResult, ok := l.chanResults[message.MessageID]; ok {
- close(chanResult)
- delete(l.chanResults, message.MessageID)
+ if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
+ delete(l.messageContexts, message.MessageID)
+ close(msgCtx.responses)
}
}
}
@@ -397,6 +440,7 @@ func (l *Conn) reader() {
if err != nil {
// A read error is expected here if we are closing the connection...
if !l.isClosing {
+ l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err)
l.Debug.Printf("reader error: %s", err.Error())
}
return
@@ -419,6 +463,5 @@ func (l *Conn) reader() {
if !l.sendProcessMessage(message) {
return
}
-
}
}