summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/go-redis/redis/internal/pool
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-redis/redis/internal/pool')
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/bench_test.go80
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/conn.go78
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/main_test.go35
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/pool.go367
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/pool_single.go55
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/pool_sticky.go123
-rw-r--r--vendor/github.com/go-redis/redis/internal/pool/pool_test.go241
7 files changed, 979 insertions, 0 deletions
diff --git a/vendor/github.com/go-redis/redis/internal/pool/bench_test.go b/vendor/github.com/go-redis/redis/internal/pool/bench_test.go
new file mode 100644
index 000000000..e0bb52446
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/bench_test.go
@@ -0,0 +1,80 @@
+package pool_test
+
+import (
+ "testing"
+ "time"
+
+ "github.com/go-redis/redis/internal/pool"
+)
+
+func benchmarkPoolGetPut(b *testing.B, poolSize int) {
+ connPool := pool.NewConnPool(&pool.Options{
+ Dialer: dummyDialer,
+ PoolSize: poolSize,
+ PoolTimeout: time.Second,
+ IdleTimeout: time.Hour,
+ IdleCheckFrequency: time.Hour,
+ })
+
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ cn, _, err := connPool.Get()
+ if err != nil {
+ b.Fatal(err)
+ }
+ if err = connPool.Put(cn); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+}
+
+func BenchmarkPoolGetPut10Conns(b *testing.B) {
+ benchmarkPoolGetPut(b, 10)
+}
+
+func BenchmarkPoolGetPut100Conns(b *testing.B) {
+ benchmarkPoolGetPut(b, 100)
+}
+
+func BenchmarkPoolGetPut1000Conns(b *testing.B) {
+ benchmarkPoolGetPut(b, 1000)
+}
+
+func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
+ connPool := pool.NewConnPool(&pool.Options{
+ Dialer: dummyDialer,
+ PoolSize: poolSize,
+ PoolTimeout: time.Second,
+ IdleTimeout: time.Hour,
+ IdleCheckFrequency: time.Hour,
+ })
+
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ cn, _, err := connPool.Get()
+ if err != nil {
+ b.Fatal(err)
+ }
+ if err := connPool.Remove(cn); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+}
+
+func BenchmarkPoolGetRemove10Conns(b *testing.B) {
+ benchmarkPoolGetRemove(b, 10)
+}
+
+func BenchmarkPoolGetRemove100Conns(b *testing.B) {
+ benchmarkPoolGetRemove(b, 100)
+}
+
+func BenchmarkPoolGetRemove1000Conns(b *testing.B) {
+ benchmarkPoolGetRemove(b, 1000)
+}
diff --git a/vendor/github.com/go-redis/redis/internal/pool/conn.go b/vendor/github.com/go-redis/redis/internal/pool/conn.go
new file mode 100644
index 000000000..8af51d9de
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/conn.go
@@ -0,0 +1,78 @@
+package pool
+
+import (
+ "net"
+ "sync/atomic"
+ "time"
+
+ "github.com/go-redis/redis/internal/proto"
+)
+
+var noDeadline = time.Time{}
+
+type Conn struct {
+ netConn net.Conn
+
+ Rd *proto.Reader
+ Wb *proto.WriteBuffer
+
+ Inited bool
+ usedAt atomic.Value
+}
+
+func NewConn(netConn net.Conn) *Conn {
+ cn := &Conn{
+ netConn: netConn,
+ Wb: proto.NewWriteBuffer(),
+ }
+ cn.Rd = proto.NewReader(cn.netConn)
+ cn.SetUsedAt(time.Now())
+ return cn
+}
+
+func (cn *Conn) UsedAt() time.Time {
+ return cn.usedAt.Load().(time.Time)
+}
+
+func (cn *Conn) SetUsedAt(tm time.Time) {
+ cn.usedAt.Store(tm)
+}
+
+func (cn *Conn) SetNetConn(netConn net.Conn) {
+ cn.netConn = netConn
+ cn.Rd.Reset(netConn)
+}
+
+func (cn *Conn) IsStale(timeout time.Duration) bool {
+ return timeout > 0 && time.Since(cn.UsedAt()) > timeout
+}
+
+func (cn *Conn) SetReadTimeout(timeout time.Duration) error {
+ now := time.Now()
+ cn.SetUsedAt(now)
+ if timeout > 0 {
+ return cn.netConn.SetReadDeadline(now.Add(timeout))
+ }
+ return cn.netConn.SetReadDeadline(noDeadline)
+}
+
+func (cn *Conn) SetWriteTimeout(timeout time.Duration) error {
+ now := time.Now()
+ cn.SetUsedAt(now)
+ if timeout > 0 {
+ return cn.netConn.SetWriteDeadline(now.Add(timeout))
+ }
+ return cn.netConn.SetWriteDeadline(noDeadline)
+}
+
+func (cn *Conn) Write(b []byte) (int, error) {
+ return cn.netConn.Write(b)
+}
+
+func (cn *Conn) RemoteAddr() net.Addr {
+ return cn.netConn.RemoteAddr()
+}
+
+func (cn *Conn) Close() error {
+ return cn.netConn.Close()
+}
diff --git a/vendor/github.com/go-redis/redis/internal/pool/main_test.go b/vendor/github.com/go-redis/redis/internal/pool/main_test.go
new file mode 100644
index 000000000..43afe3fa9
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/main_test.go
@@ -0,0 +1,35 @@
+package pool_test
+
+import (
+ "net"
+ "sync"
+ "testing"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+func TestGinkgoSuite(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "pool")
+}
+
+func perform(n int, cbs ...func(int)) {
+ var wg sync.WaitGroup
+ for _, cb := range cbs {
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func(cb func(int), i int) {
+ defer GinkgoRecover()
+ defer wg.Done()
+
+ cb(i)
+ }(cb, i)
+ }
+ }
+ wg.Wait()
+}
+
+func dummyDialer() (net.Conn, error) {
+ return &net.TCPConn{}, nil
+}
diff --git a/vendor/github.com/go-redis/redis/internal/pool/pool.go b/vendor/github.com/go-redis/redis/internal/pool/pool.go
new file mode 100644
index 000000000..a4e650847
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/pool.go
@@ -0,0 +1,367 @@
+package pool
+
+import (
+ "errors"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/go-redis/redis/internal"
+)
+
+var ErrClosed = errors.New("redis: client is closed")
+var ErrPoolTimeout = errors.New("redis: connection pool timeout")
+
+var timers = sync.Pool{
+ New: func() interface{} {
+ t := time.NewTimer(time.Hour)
+ t.Stop()
+ return t
+ },
+}
+
+// Stats contains pool state information and accumulated stats.
+type Stats struct {
+ Requests uint32 // number of times a connection was requested by the pool
+ Hits uint32 // number of times free connection was found in the pool
+ Timeouts uint32 // number of times a wait timeout occurred
+
+ TotalConns uint32 // the number of total connections in the pool
+ FreeConns uint32 // the number of free connections in the pool
+}
+
+type Pooler interface {
+ NewConn() (*Conn, error)
+ CloseConn(*Conn) error
+
+ Get() (*Conn, bool, error)
+ Put(*Conn) error
+ Remove(*Conn) error
+
+ Len() int
+ FreeLen() int
+ Stats() *Stats
+
+ Close() error
+}
+
+type Options struct {
+ Dialer func() (net.Conn, error)
+ OnClose func(*Conn) error
+
+ PoolSize int
+ PoolTimeout time.Duration
+ IdleTimeout time.Duration
+ IdleCheckFrequency time.Duration
+}
+
+type ConnPool struct {
+ opt *Options
+
+ dialErrorsNum uint32 // atomic
+ _lastDialError atomic.Value
+
+ queue chan struct{}
+
+ connsMu sync.Mutex
+ conns []*Conn
+
+ freeConnsMu sync.Mutex
+ freeConns []*Conn
+
+ stats Stats
+
+ _closed uint32 // atomic
+}
+
+var _ Pooler = (*ConnPool)(nil)
+
+func NewConnPool(opt *Options) *ConnPool {
+ p := &ConnPool{
+ opt: opt,
+
+ queue: make(chan struct{}, opt.PoolSize),
+ conns: make([]*Conn, 0, opt.PoolSize),
+ freeConns: make([]*Conn, 0, opt.PoolSize),
+ }
+ if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
+ go p.reaper(opt.IdleCheckFrequency)
+ }
+ return p
+}
+
+func (p *ConnPool) NewConn() (*Conn, error) {
+ if p.closed() {
+ return nil, ErrClosed
+ }
+
+ if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
+ return nil, p.lastDialError()
+ }
+
+ netConn, err := p.opt.Dialer()
+ if err != nil {
+ p.setLastDialError(err)
+ if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
+ go p.tryDial()
+ }
+ return nil, err
+ }
+
+ cn := NewConn(netConn)
+ p.connsMu.Lock()
+ p.conns = append(p.conns, cn)
+ p.connsMu.Unlock()
+
+ return cn, nil
+}
+
+func (p *ConnPool) tryDial() {
+ for {
+ conn, err := p.opt.Dialer()
+ if err != nil {
+ p.setLastDialError(err)
+ time.Sleep(time.Second)
+ continue
+ }
+
+ atomic.StoreUint32(&p.dialErrorsNum, 0)
+ _ = conn.Close()
+ return
+ }
+}
+
+func (p *ConnPool) setLastDialError(err error) {
+ p._lastDialError.Store(err)
+}
+
+func (p *ConnPool) lastDialError() error {
+ return p._lastDialError.Load().(error)
+}
+
+// Get returns existed connection from the pool or creates a new one.
+func (p *ConnPool) Get() (*Conn, bool, error) {
+ if p.closed() {
+ return nil, false, ErrClosed
+ }
+
+ atomic.AddUint32(&p.stats.Requests, 1)
+
+ select {
+ case p.queue <- struct{}{}:
+ default:
+ timer := timers.Get().(*time.Timer)
+ timer.Reset(p.opt.PoolTimeout)
+
+ select {
+ case p.queue <- struct{}{}:
+ if !timer.Stop() {
+ <-timer.C
+ }
+ timers.Put(timer)
+ case <-timer.C:
+ timers.Put(timer)
+ atomic.AddUint32(&p.stats.Timeouts, 1)
+ return nil, false, ErrPoolTimeout
+ }
+ }
+
+ for {
+ p.freeConnsMu.Lock()
+ cn := p.popFree()
+ p.freeConnsMu.Unlock()
+
+ if cn == nil {
+ break
+ }
+
+ if cn.IsStale(p.opt.IdleTimeout) {
+ p.CloseConn(cn)
+ continue
+ }
+
+ atomic.AddUint32(&p.stats.Hits, 1)
+ return cn, false, nil
+ }
+
+ newcn, err := p.NewConn()
+ if err != nil {
+ <-p.queue
+ return nil, false, err
+ }
+
+ return newcn, true, nil
+}
+
+func (p *ConnPool) popFree() *Conn {
+ if len(p.freeConns) == 0 {
+ return nil
+ }
+
+ idx := len(p.freeConns) - 1
+ cn := p.freeConns[idx]
+ p.freeConns = p.freeConns[:idx]
+ return cn
+}
+
+func (p *ConnPool) Put(cn *Conn) error {
+ if data := cn.Rd.PeekBuffered(); data != nil {
+ internal.Logf("connection has unread data: %q", data)
+ return p.Remove(cn)
+ }
+ p.freeConnsMu.Lock()
+ p.freeConns = append(p.freeConns, cn)
+ p.freeConnsMu.Unlock()
+ <-p.queue
+ return nil
+}
+
+func (p *ConnPool) Remove(cn *Conn) error {
+ _ = p.CloseConn(cn)
+ <-p.queue
+ return nil
+}
+
+func (p *ConnPool) CloseConn(cn *Conn) error {
+ p.connsMu.Lock()
+ for i, c := range p.conns {
+ if c == cn {
+ p.conns = append(p.conns[:i], p.conns[i+1:]...)
+ break
+ }
+ }
+ p.connsMu.Unlock()
+
+ return p.closeConn(cn)
+}
+
+func (p *ConnPool) closeConn(cn *Conn) error {
+ if p.opt.OnClose != nil {
+ _ = p.opt.OnClose(cn)
+ }
+ return cn.Close()
+}
+
+// Len returns total number of connections.
+func (p *ConnPool) Len() int {
+ p.connsMu.Lock()
+ l := len(p.conns)
+ p.connsMu.Unlock()
+ return l
+}
+
+// FreeLen returns number of free connections.
+func (p *ConnPool) FreeLen() int {
+ p.freeConnsMu.Lock()
+ l := len(p.freeConns)
+ p.freeConnsMu.Unlock()
+ return l
+}
+
+func (p *ConnPool) Stats() *Stats {
+ return &Stats{
+ Requests: atomic.LoadUint32(&p.stats.Requests),
+ Hits: atomic.LoadUint32(&p.stats.Hits),
+ Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
+ TotalConns: uint32(p.Len()),
+ FreeConns: uint32(p.FreeLen()),
+ }
+}
+
+func (p *ConnPool) closed() bool {
+ return atomic.LoadUint32(&p._closed) == 1
+}
+
+func (p *ConnPool) Filter(fn func(*Conn) bool) error {
+ var firstErr error
+ p.connsMu.Lock()
+ for _, cn := range p.conns {
+ if fn(cn) {
+ if err := p.closeConn(cn); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+ }
+ p.connsMu.Unlock()
+ return firstErr
+}
+
+func (p *ConnPool) Close() error {
+ if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
+ return ErrClosed
+ }
+
+ var firstErr error
+ p.connsMu.Lock()
+ for _, cn := range p.conns {
+ if err := p.closeConn(cn); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+ p.conns = nil
+ p.connsMu.Unlock()
+
+ p.freeConnsMu.Lock()
+ p.freeConns = nil
+ p.freeConnsMu.Unlock()
+
+ return firstErr
+}
+
+func (p *ConnPool) reapStaleConn() bool {
+ if len(p.freeConns) == 0 {
+ return false
+ }
+
+ cn := p.freeConns[0]
+ if !cn.IsStale(p.opt.IdleTimeout) {
+ return false
+ }
+
+ p.CloseConn(cn)
+ p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...)
+
+ return true
+}
+
+func (p *ConnPool) ReapStaleConns() (int, error) {
+ var n int
+ for {
+ p.queue <- struct{}{}
+ p.freeConnsMu.Lock()
+
+ reaped := p.reapStaleConn()
+
+ p.freeConnsMu.Unlock()
+ <-p.queue
+
+ if reaped {
+ n++
+ } else {
+ break
+ }
+ }
+ return n, nil
+}
+
+func (p *ConnPool) reaper(frequency time.Duration) {
+ ticker := time.NewTicker(frequency)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ if p.closed() {
+ break
+ }
+ n, err := p.ReapStaleConns()
+ if err != nil {
+ internal.Logf("ReapStaleConns failed: %s", err)
+ continue
+ }
+ s := p.Stats()
+ internal.Logf(
+ "reaper: removed %d stale conns (TotalConns=%d FreeConns=%d Requests=%d Hits=%d Timeouts=%d)",
+ n, s.TotalConns, s.FreeConns, s.Requests, s.Hits, s.Timeouts,
+ )
+ }
+}
diff --git a/vendor/github.com/go-redis/redis/internal/pool/pool_single.go b/vendor/github.com/go-redis/redis/internal/pool/pool_single.go
new file mode 100644
index 000000000..ff91279b3
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/pool_single.go
@@ -0,0 +1,55 @@
+package pool
+
+type SingleConnPool struct {
+ cn *Conn
+}
+
+var _ Pooler = (*SingleConnPool)(nil)
+
+func NewSingleConnPool(cn *Conn) *SingleConnPool {
+ return &SingleConnPool{
+ cn: cn,
+ }
+}
+
+func (p *SingleConnPool) NewConn() (*Conn, error) {
+ panic("not implemented")
+}
+
+func (p *SingleConnPool) CloseConn(*Conn) error {
+ panic("not implemented")
+}
+
+func (p *SingleConnPool) Get() (*Conn, bool, error) {
+ return p.cn, false, nil
+}
+
+func (p *SingleConnPool) Put(cn *Conn) error {
+ if p.cn != cn {
+ panic("p.cn != cn")
+ }
+ return nil
+}
+
+func (p *SingleConnPool) Remove(cn *Conn) error {
+ if p.cn != cn {
+ panic("p.cn != cn")
+ }
+ return nil
+}
+
+func (p *SingleConnPool) Len() int {
+ return 1
+}
+
+func (p *SingleConnPool) FreeLen() int {
+ return 0
+}
+
+func (p *SingleConnPool) Stats() *Stats {
+ return nil
+}
+
+func (p *SingleConnPool) Close() error {
+ return nil
+}
diff --git a/vendor/github.com/go-redis/redis/internal/pool/pool_sticky.go b/vendor/github.com/go-redis/redis/internal/pool/pool_sticky.go
new file mode 100644
index 000000000..17f163858
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/pool_sticky.go
@@ -0,0 +1,123 @@
+package pool
+
+import "sync"
+
+type StickyConnPool struct {
+ pool *ConnPool
+ reusable bool
+
+ cn *Conn
+ closed bool
+ mu sync.Mutex
+}
+
+var _ Pooler = (*StickyConnPool)(nil)
+
+func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool {
+ return &StickyConnPool{
+ pool: pool,
+ reusable: reusable,
+ }
+}
+
+func (p *StickyConnPool) NewConn() (*Conn, error) {
+ panic("not implemented")
+}
+
+func (p *StickyConnPool) CloseConn(*Conn) error {
+ panic("not implemented")
+}
+
+func (p *StickyConnPool) Get() (*Conn, bool, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.closed {
+ return nil, false, ErrClosed
+ }
+ if p.cn != nil {
+ return p.cn, false, nil
+ }
+
+ cn, _, err := p.pool.Get()
+ if err != nil {
+ return nil, false, err
+ }
+ p.cn = cn
+ return cn, true, nil
+}
+
+func (p *StickyConnPool) putUpstream() (err error) {
+ err = p.pool.Put(p.cn)
+ p.cn = nil
+ return err
+}
+
+func (p *StickyConnPool) Put(cn *Conn) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.closed {
+ return ErrClosed
+ }
+ return nil
+}
+
+func (p *StickyConnPool) removeUpstream() error {
+ err := p.pool.Remove(p.cn)
+ p.cn = nil
+ return err
+}
+
+func (p *StickyConnPool) Remove(cn *Conn) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.closed {
+ return nil
+ }
+ return p.removeUpstream()
+}
+
+func (p *StickyConnPool) Len() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.cn == nil {
+ return 0
+ }
+ return 1
+}
+
+func (p *StickyConnPool) FreeLen() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.cn == nil {
+ return 1
+ }
+ return 0
+}
+
+func (p *StickyConnPool) Stats() *Stats {
+ return nil
+}
+
+func (p *StickyConnPool) Close() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.closed {
+ return ErrClosed
+ }
+ p.closed = true
+ var err error
+ if p.cn != nil {
+ if p.reusable {
+ err = p.putUpstream()
+ } else {
+ err = p.removeUpstream()
+ }
+ }
+ return err
+}
diff --git a/vendor/github.com/go-redis/redis/internal/pool/pool_test.go b/vendor/github.com/go-redis/redis/internal/pool/pool_test.go
new file mode 100644
index 000000000..68c9a1bef
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/internal/pool/pool_test.go
@@ -0,0 +1,241 @@
+package pool_test
+
+import (
+ "testing"
+ "time"
+
+ "github.com/go-redis/redis/internal/pool"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("ConnPool", func() {
+ var connPool *pool.ConnPool
+
+ BeforeEach(func() {
+ connPool = pool.NewConnPool(&pool.Options{
+ Dialer: dummyDialer,
+ PoolSize: 10,
+ PoolTimeout: time.Hour,
+ IdleTimeout: time.Millisecond,
+ IdleCheckFrequency: time.Millisecond,
+ })
+ })
+
+ AfterEach(func() {
+ connPool.Close()
+ })
+
+ It("should unblock client when conn is removed", func() {
+ // Reserve one connection.
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+
+ // Reserve all other connections.
+ var cns []*pool.Conn
+ for i := 0; i < 9; i++ {
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ cns = append(cns, cn)
+ }
+
+ started := make(chan bool, 1)
+ done := make(chan bool, 1)
+ go func() {
+ defer GinkgoRecover()
+
+ started <- true
+ _, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ done <- true
+
+ err = connPool.Put(cn)
+ Expect(err).NotTo(HaveOccurred())
+ }()
+ <-started
+
+ // Check that Get is blocked.
+ select {
+ case <-done:
+ Fail("Get is not blocked")
+ default:
+ // ok
+ }
+
+ err = connPool.Remove(cn)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Check that Ping is unblocked.
+ select {
+ case <-done:
+ // ok
+ case <-time.After(time.Second):
+ Fail("Get is not unblocked")
+ }
+
+ for _, cn := range cns {
+ err = connPool.Put(cn)
+ Expect(err).NotTo(HaveOccurred())
+ }
+ })
+})
+
+var _ = Describe("conns reaper", func() {
+ const idleTimeout = time.Minute
+
+ var connPool *pool.ConnPool
+ var conns, idleConns, closedConns []*pool.Conn
+
+ BeforeEach(func() {
+ conns = nil
+ closedConns = nil
+
+ connPool = pool.NewConnPool(&pool.Options{
+ Dialer: dummyDialer,
+ PoolSize: 10,
+ PoolTimeout: time.Second,
+ IdleTimeout: idleTimeout,
+ IdleCheckFrequency: time.Hour,
+
+ OnClose: func(cn *pool.Conn) error {
+ closedConns = append(closedConns, cn)
+ return nil
+ },
+ })
+
+ // add stale connections
+ idleConns = nil
+ for i := 0; i < 3; i++ {
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ cn.SetUsedAt(time.Now().Add(-2 * idleTimeout))
+ conns = append(conns, cn)
+ idleConns = append(idleConns, cn)
+ }
+
+ // add fresh connections
+ for i := 0; i < 3; i++ {
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ conns = append(conns, cn)
+ }
+
+ for _, cn := range conns {
+ Expect(connPool.Put(cn)).NotTo(HaveOccurred())
+ }
+
+ Expect(connPool.Len()).To(Equal(6))
+ Expect(connPool.FreeLen()).To(Equal(6))
+
+ n, err := connPool.ReapStaleConns()
+ Expect(err).NotTo(HaveOccurred())
+ Expect(n).To(Equal(3))
+ })
+
+ AfterEach(func() {
+ _ = connPool.Close()
+ Expect(connPool.Len()).To(Equal(0))
+ Expect(connPool.FreeLen()).To(Equal(0))
+ Expect(len(closedConns)).To(Equal(len(conns)))
+ Expect(closedConns).To(ConsistOf(conns))
+ })
+
+ It("reaps stale connections", func() {
+ Expect(connPool.Len()).To(Equal(3))
+ Expect(connPool.FreeLen()).To(Equal(3))
+ })
+
+ It("does not reap fresh connections", func() {
+ n, err := connPool.ReapStaleConns()
+ Expect(err).NotTo(HaveOccurred())
+ Expect(n).To(Equal(0))
+ })
+
+ It("stale connections are closed", func() {
+ Expect(len(closedConns)).To(Equal(len(idleConns)))
+ Expect(closedConns).To(ConsistOf(idleConns))
+ })
+
+ It("pool is functional", func() {
+ for j := 0; j < 3; j++ {
+ var freeCns []*pool.Conn
+ for i := 0; i < 3; i++ {
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ Expect(cn).NotTo(BeNil())
+ freeCns = append(freeCns, cn)
+ }
+
+ Expect(connPool.Len()).To(Equal(3))
+ Expect(connPool.FreeLen()).To(Equal(0))
+
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ Expect(cn).NotTo(BeNil())
+ conns = append(conns, cn)
+
+ Expect(connPool.Len()).To(Equal(4))
+ Expect(connPool.FreeLen()).To(Equal(0))
+
+ err = connPool.Remove(cn)
+ Expect(err).NotTo(HaveOccurred())
+
+ Expect(connPool.Len()).To(Equal(3))
+ Expect(connPool.FreeLen()).To(Equal(0))
+
+ for _, cn := range freeCns {
+ err := connPool.Put(cn)
+ Expect(err).NotTo(HaveOccurred())
+ }
+
+ Expect(connPool.Len()).To(Equal(3))
+ Expect(connPool.FreeLen()).To(Equal(3))
+ }
+ })
+})
+
+var _ = Describe("race", func() {
+ var connPool *pool.ConnPool
+ var C, N int
+
+ BeforeEach(func() {
+ C, N = 10, 1000
+ if testing.Short() {
+ C = 4
+ N = 100
+ }
+ })
+
+ AfterEach(func() {
+ connPool.Close()
+ })
+
+ It("does not happen on Get, Put, and Remove", func() {
+ connPool = pool.NewConnPool(&pool.Options{
+ Dialer: dummyDialer,
+ PoolSize: 10,
+ PoolTimeout: time.Minute,
+ IdleTimeout: time.Millisecond,
+ IdleCheckFrequency: time.Millisecond,
+ })
+
+ perform(C, func(id int) {
+ for i := 0; i < N; i++ {
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ if err == nil {
+ Expect(connPool.Put(cn)).NotTo(HaveOccurred())
+ }
+ }
+ }, func(id int) {
+ for i := 0; i < N; i++ {
+ cn, _, err := connPool.Get()
+ Expect(err).NotTo(HaveOccurred())
+ if err == nil {
+ Expect(connPool.Remove(cn)).NotTo(HaveOccurred())
+ }
+ }
+ })
+ })
+})