summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-ldap/ldap/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-ldap/ldap/conn.go')
-rw-r--r--vendor/github.com/go-ldap/ldap/conn.go73
1 files changed, 38 insertions, 35 deletions
diff --git a/vendor/github.com/go-ldap/ldap/conn.go b/vendor/github.com/go-ldap/ldap/conn.go
index b5bd99adb..eb28eb472 100644
--- a/vendor/github.com/go-ldap/ldap/conn.go
+++ b/vendor/github.com/go-ldap/ldap/conn.go
@@ -11,6 +11,7 @@ import (
"log"
"net"
"sync"
+ "sync/atomic"
"time"
"gopkg.in/asn1-ber.v1"
@@ -82,20 +83,18 @@ const (
type Conn struct {
conn net.Conn
isTLS bool
- isClosing bool
- closeErr error
+ closing uint32
+ closeErr atomicValue
isStartingTLS bool
Debug debugging
- chanConfirm chan bool
+ chanConfirm chan struct{}
messageContexts map[int64]*messageContext
chanMessage chan *messagePacket
chanMessageID chan int64
- wgSender sync.WaitGroup
wgClose sync.WaitGroup
- once sync.Once
outstandingRequests uint
messageMutex sync.Mutex
- requestTimeout time.Duration
+ requestTimeout int64
}
var _ Client = &Conn{}
@@ -142,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
conn: conn,
- chanConfirm: make(chan bool),
+ chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
chanMessage: make(chan *messagePacket, 10),
messageContexts: map[int64]*messageContext{},
@@ -158,12 +157,22 @@ func (l *Conn) Start() {
l.wgClose.Add(1)
}
+// isClosing returns whether or not we're currently closing.
+func (l *Conn) isClosing() bool {
+ return atomic.LoadUint32(&l.closing) == 1
+}
+
+// setClosing sets the closing value to true
+func (l *Conn) setClosing() bool {
+ return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
+}
+
// Close closes the connection.
func (l *Conn) Close() {
- l.once.Do(func() {
- l.isClosing = true
- l.wgSender.Wait()
+ l.messageMutex.Lock()
+ defer l.messageMutex.Unlock()
+ if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm
@@ -171,27 +180,25 @@ func (l *Conn) Close() {
l.Debug.Printf("Closing network connection")
if err := l.conn.Close(); err != nil {
- log.Print(err)
+ log.Println(err)
}
l.wgClose.Done()
- })
+ }
l.wgClose.Wait()
}
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
- l.requestTimeout = timeout
+ atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
}
// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
- if l.chanMessageID != nil {
- if messageID, ok := <-l.chanMessageID; ok {
- return messageID
- }
+ if messageID, ok := <-l.chanMessageID; ok {
+ return messageID
}
return 0
}
@@ -258,7 +265,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
}
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
- if l.isClosing {
+ if l.isClosing() {
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
l.messageMutex.Lock()
@@ -297,7 +304,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
func (l *Conn) finishMessage(msgCtx *messageContext) {
close(msgCtx.done)
- if l.isClosing {
+ if l.isClosing() {
return
}
@@ -316,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
}
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
- if l.isClosing {
+ l.messageMutex.Lock()
+ defer l.messageMutex.Unlock()
+ if l.isClosing() {
return false
}
- l.wgSender.Add(1)
l.chanMessage <- message
- l.wgSender.Done()
return true
}
@@ -333,15 +340,14 @@ func (l *Conn) processMessages() {
for messageID, msgCtx := range l.messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
- if l.isClosing && l.closeErr != nil {
- msgCtx.sendResponse(&PacketResponse{Error: l.closeErr})
+ if l.isClosing() && l.closeErr.Load() != nil {
+ msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
delete(l.messageContexts, messageID)
}
close(l.chanMessageID)
- l.chanConfirm <- true
close(l.chanConfirm)
}()
@@ -350,11 +356,7 @@ func (l *Conn) processMessages() {
select {
case l.chanMessageID <- messageID:
messageID++
- case message, ok := <-l.chanMessage:
- if !ok {
- l.Debug.Printf("Shutting down - message channel is closed")
- return
- }
+ case message := <-l.chanMessage:
switch message.Op {
case MessageQuit:
l.Debug.Printf("Shutting down - quit message received")
@@ -377,14 +379,15 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context
// Add timeout if defined
- if l.requestTimeout > 0 {
+ requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
+ if requestTimeout > 0 {
go func() {
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
}
}()
- time.Sleep(l.requestTimeout)
+ time.Sleep(requestTimeout)
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
@@ -397,7 +400,7 @@ func (l *Conn) processMessages() {
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
} else {
- log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing)
+ log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing())
ber.PrintPacket(message.Packet)
}
case MessageTimeout:
@@ -439,8 +442,8 @@ func (l *Conn) reader() {
packet, err := ber.ReadPacket(l.conn)
if err != nil {
// A read error is expected here if we are closing the connection...
- if !l.isClosing {
- l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err)
+ if !l.isClosing() {
+ l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
l.Debug.Printf("reader error: %s", err.Error())
}
return