summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/tylerb/graceful/graceful_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/tylerb/graceful/graceful_test.go')
-rw-r--r--vendor/github.com/tylerb/graceful/graceful_test.go692
1 files changed, 692 insertions, 0 deletions
diff --git a/vendor/github.com/tylerb/graceful/graceful_test.go b/vendor/github.com/tylerb/graceful/graceful_test.go
new file mode 100644
index 000000000..b9c49336b
--- /dev/null
+++ b/vendor/github.com/tylerb/graceful/graceful_test.go
@@ -0,0 +1,692 @@
+package graceful
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "reflect"
+ "strings"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+)
+
+const (
+ // The tests will run a test server on this port.
+ port = 9654
+ concurrentRequestN = 8
+ killTime = 500 * time.Millisecond
+ timeoutTime = 1000 * time.Millisecond
+ waitTime = 100 * time.Millisecond
+)
+
+func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup, once *sync.Once) {
+ defer wg.Done()
+ client := http.Client{}
+ r, err := client.Get(fmt.Sprintf("http://localhost:%d", port))
+ if shouldErr && err == nil {
+ once.Do(func() {
+ t.Error("Expected an error but none was encountered.")
+ })
+ } else if shouldErr && err != nil {
+ if checkErr(t, err, once) {
+ return
+ }
+ }
+ if r != nil && r.StatusCode != expected {
+ once.Do(func() {
+ t.Errorf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode)
+ })
+ } else if r == nil {
+ once.Do(func() {
+ t.Error("No response when a response was expected.")
+ })
+ }
+}
+
+func checkErr(t *testing.T, err error, once *sync.Once) bool {
+ if err.(*url.Error).Err == io.EOF {
+ return true
+ }
+ var errno syscall.Errno
+ switch oe := err.(*url.Error).Err.(type) {
+ case *net.OpError:
+ switch e := oe.Err.(type) {
+ case syscall.Errno:
+ errno = e
+ case *os.SyscallError:
+ errno = e.Err.(syscall.Errno)
+ }
+ if errno == syscall.ECONNREFUSED {
+ return true
+ } else if err != nil {
+ once.Do(func() {
+ t.Error("Error on Get:", err)
+ })
+ }
+ default:
+ if strings.Contains(err.Error(), "transport closed before response was received") {
+ return true
+ }
+ if strings.Contains(err.Error(), "server closed connection") {
+ return true
+ }
+ fmt.Printf("unknown err: %s, %#v\n", err, err)
+ }
+ return false
+}
+
+func createListener(sleep time.Duration) (*http.Server, net.Listener, error) {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
+ time.Sleep(sleep)
+ rw.WriteHeader(http.StatusOK)
+ })
+
+ server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux}
+ l, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
+ return server, l, err
+}
+
+func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) {
+ defer wg.Done()
+ var once sync.Once
+
+ for i := 0; i < concurrentRequestN; i++ {
+ wg.Add(1)
+ go runQuery(t, http.StatusOK, false, wg, &once)
+ }
+
+ time.Sleep(waitTime)
+ c <- os.Interrupt
+ time.Sleep(waitTime)
+
+ for i := 0; i < concurrentRequestN; i++ {
+ wg.Add(1)
+ go runQuery(t, 0, true, wg, &once)
+ }
+}
+
+func TestGracefulRun(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime / 2)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{Timeout: killTime, Server: server, interrupt: c}
+ srv.Serve(l)
+ }()
+
+ wg.Add(1)
+ go launchTestQueries(t, &wg, c)
+}
+
+func TestGracefulRunLimitKeepAliveListener(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime / 2)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{
+ Timeout: killTime,
+ ListenLimit: concurrentRequestN,
+ TCPKeepAlive: 1 * time.Second,
+ Server: server,
+ interrupt: c,
+ }
+ srv.Serve(l)
+ }()
+
+ wg.Add(1)
+ go launchTestQueries(t, &wg, c)
+}
+
+func TestGracefulRunTimesOut(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime * 10)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{Timeout: killTime, Server: server, interrupt: c}
+ srv.Serve(l)
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var once sync.Once
+
+ for i := 0; i < concurrentRequestN; i++ {
+ wg.Add(1)
+ go runQuery(t, 0, true, &wg, &once)
+ }
+
+ time.Sleep(waitTime)
+ c <- os.Interrupt
+ time.Sleep(waitTime)
+
+ for i := 0; i < concurrentRequestN; i++ {
+ wg.Add(1)
+ go runQuery(t, 0, true, &wg, &once)
+ }
+ }()
+}
+
+func TestGracefulRunDoesntTimeOut(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime * 2)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{Timeout: 0, Server: server, interrupt: c}
+ srv.Serve(l)
+ }()
+
+ wg.Add(1)
+ go launchTestQueries(t, &wg, c)
+}
+
+func TestGracefulRunDoesntTimeOutAfterConnectionCreated(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{Timeout: 0, Server: server, interrupt: c}
+ srv.Serve(l)
+ }()
+ time.Sleep(waitTime)
+
+ // Make a sample first request. The connection will be left idle.
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port))
+ if err != nil {
+ panic(fmt.Sprintf("first request failed: %v", err))
+ }
+ resp.Body.Close()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ // With idle connections improperly handled, the server doesn't wait for this
+ // to complete and the request fails. It should be allowed to complete successfully.
+ _, err := http.Get(fmt.Sprintf("http://localhost:%d", port))
+ if err != nil {
+ t.Errorf("Get failed: %v", err)
+ }
+ }()
+
+ // Ensure the request goes out
+ time.Sleep(waitTime)
+ c <- os.Interrupt
+ wg.Wait()
+}
+
+func TestGracefulRunNoRequests(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime * 2)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{Timeout: 0, Server: server, interrupt: c}
+ srv.Serve(l)
+ }()
+
+ c <- os.Interrupt
+}
+
+func TestGracefulForwardsConnState(t *testing.T) {
+ var stateLock sync.Mutex
+ states := make(map[http.ConnState]int)
+ connState := func(conn net.Conn, state http.ConnState) {
+ stateLock.Lock()
+ states[state]++
+ stateLock.Unlock()
+ }
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ expected := map[http.ConnState]int{
+ http.StateNew: concurrentRequestN,
+ http.StateActive: concurrentRequestN,
+ http.StateClosed: concurrentRequestN,
+ }
+
+ c := make(chan os.Signal, 1)
+ server, l, err := createListener(killTime / 2)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv := &Server{
+ ConnState: connState,
+ Timeout: killTime,
+ Server: server,
+ interrupt: c,
+ }
+ srv.Serve(l)
+ }()
+
+ wg.Add(1)
+ go launchTestQueries(t, &wg, c)
+ wg.Wait()
+
+ stateLock.Lock()
+ if !reflect.DeepEqual(states, expected) {
+ t.Errorf("Incorrect connection state tracking.\n actual: %v\nexpected: %v\n", states, expected)
+ }
+ stateLock.Unlock()
+}
+
+func TestGracefulExplicitStop(t *testing.T) {
+ server, l, err := createListener(1 * time.Millisecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ srv := &Server{Timeout: killTime, Server: server}
+
+ go func() {
+ go srv.Serve(l)
+ time.Sleep(waitTime)
+ srv.Stop(killTime)
+ }()
+
+ // block on the stopChan until the server has shut down
+ select {
+ case <-srv.StopChan():
+ case <-time.After(timeoutTime):
+ t.Fatal("Timed out while waiting for explicit stop to complete")
+ }
+}
+
+func TestGracefulExplicitStopOverride(t *testing.T) {
+ server, l, err := createListener(1 * time.Millisecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ srv := &Server{Timeout: killTime, Server: server}
+
+ go func() {
+ go srv.Serve(l)
+ time.Sleep(waitTime)
+ srv.Stop(killTime / 2)
+ }()
+
+ // block on the stopChan until the server has shut down
+ select {
+ case <-srv.StopChan():
+ case <-time.After(killTime):
+ t.Fatal("Timed out while waiting for explicit stop to complete")
+ }
+}
+
+func TestBeforeShutdownAndShutdownInitiatedCallbacks(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ server, l, err := createListener(1 * time.Millisecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ beforeShutdownCalled := make(chan struct{})
+ cb1 := func() bool { close(beforeShutdownCalled); return true }
+ shutdownInitiatedCalled := make(chan struct{})
+ cb2 := func() { close(shutdownInitiatedCalled) }
+
+ wg.Add(2)
+ srv := &Server{Server: server, BeforeShutdown: cb1, ShutdownInitiated: cb2}
+ go func() {
+ defer wg.Done()
+ srv.Serve(l)
+ }()
+ go func() {
+ defer wg.Done()
+ time.Sleep(waitTime)
+ srv.Stop(killTime)
+ }()
+
+ beforeShutdown := false
+ shutdownInitiated := false
+ for i := 0; i < 2; i++ {
+ select {
+ case <-beforeShutdownCalled:
+ beforeShutdownCalled = nil
+ beforeShutdown = true
+ case <-shutdownInitiatedCalled:
+ shutdownInitiatedCalled = nil
+ shutdownInitiated = true
+ case <-time.After(killTime):
+ t.Fatal("Timed out while waiting for ShutdownInitiated callback to be called")
+ }
+ }
+
+ if !beforeShutdown {
+ t.Fatal("beforeShutdown should be true")
+ }
+ if !shutdownInitiated {
+ t.Fatal("shutdownInitiated should be true")
+ }
+}
+
+func TestBeforeShutdownCanceled(t *testing.T) {
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ server, l, err := createListener(1 * time.Millisecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ beforeShutdownCalled := make(chan struct{})
+ cb1 := func() bool { close(beforeShutdownCalled); return false }
+ shutdownInitiatedCalled := make(chan struct{})
+ cb2 := func() { close(shutdownInitiatedCalled) }
+
+ srv := &Server{Server: server, BeforeShutdown: cb1, ShutdownInitiated: cb2}
+ go func() {
+ srv.Serve(l)
+ wg.Done()
+ }()
+ go func() {
+ time.Sleep(waitTime)
+ srv.Stop(killTime)
+ }()
+
+ beforeShutdown := false
+ shutdownInitiated := false
+ timeouted := false
+
+ for i := 0; i < 2; i++ {
+ select {
+ case <-beforeShutdownCalled:
+ beforeShutdownCalled = nil
+ beforeShutdown = true
+ case <-shutdownInitiatedCalled:
+ shutdownInitiatedCalled = nil
+ shutdownInitiated = true
+ case <-time.After(killTime):
+ timeouted = true
+ }
+ }
+
+ if !beforeShutdown {
+ t.Fatal("beforeShutdown should be true")
+ }
+ if !timeouted {
+ t.Fatal("timeouted should be true")
+ }
+ if shutdownInitiated {
+ t.Fatal("shutdownInitiated shouldn't be true")
+ }
+
+ srv.BeforeShutdown = func() bool { return true }
+ srv.Stop(killTime)
+
+ wg.Wait()
+}
+
+func hijackingListener(srv *Server) (*http.Server, net.Listener, error) {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
+ conn, bufrw, err := rw.(http.Hijacker).Hijack()
+ if err != nil {
+ http.Error(rw, "webserver doesn't support hijacking", http.StatusInternalServerError)
+ return
+ }
+
+ defer conn.Close()
+
+ bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
+ bufrw.Flush()
+ })
+
+ server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux}
+ l, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
+ return server, l, err
+}
+
+func TestNotifyClosed(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan os.Signal, 1)
+ srv := &Server{Timeout: killTime, interrupt: c}
+ server, l, err := hijackingListener(srv)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ srv.Server = server
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ srv.Serve(l)
+ }()
+
+ var once sync.Once
+ for i := 0; i < concurrentRequestN; i++ {
+ wg.Add(1)
+ runQuery(t, http.StatusOK, false, &wg, &once)
+ }
+
+ srv.Stop(0)
+
+ // block on the stopChan until the server has shut down
+ select {
+ case <-srv.StopChan():
+ case <-time.After(timeoutTime):
+ t.Fatal("Timed out while waiting for explicit stop to complete")
+ }
+
+ if len(srv.connections) > 0 {
+ t.Fatal("hijacked connections should not be managed")
+ }
+
+}
+
+func TestStopDeadlock(t *testing.T) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c := make(chan struct{})
+ server, l, err := createListener(1 * time.Millisecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ srv := &Server{Server: server, NoSignalHandling: true}
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ time.Sleep(waitTime)
+ srv.Serve(l)
+ }()
+ go func() {
+ defer wg.Done()
+ srv.Stop(0)
+ close(c)
+ }()
+
+ select {
+ case <-c:
+ l.Close()
+ case <-time.After(timeoutTime):
+ t.Fatal("Timed out while waiting for explicit stop to complete")
+ }
+}
+
+// Run with --race
+func TestStopRace(t *testing.T) {
+ server, l, err := createListener(1 * time.Millisecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ srv := &Server{Timeout: killTime, Server: server}
+
+ go func() {
+ go srv.Serve(l)
+ srv.Stop(killTime)
+ }()
+ srv.Stop(0)
+ select {
+ case <-srv.StopChan():
+ case <-time.After(timeoutTime):
+ t.Fatal("Timed out while waiting for explicit stop to complete")
+ }
+}
+
+func TestInterruptLog(t *testing.T) {
+ c := make(chan os.Signal, 1)
+
+ server, l, err := createListener(killTime * 10)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var buf bytes.Buffer
+ var tbuf bytes.Buffer
+ logger := log.New(&buf, "", 0)
+ expected := log.New(&tbuf, "", 0)
+
+ srv := &Server{Timeout: killTime, Server: server, Logger: logger, interrupt: c}
+ go func() { srv.Serve(l) }()
+
+ stop := srv.StopChan()
+ c <- os.Interrupt
+ expected.Print("shutdown initiated")
+
+ <-stop
+
+ if buf.String() != tbuf.String() {
+ t.Fatal("shutdown log incorrect - got '" + buf.String() + "'")
+ }
+}
+
+func TestMultiInterrupts(t *testing.T) {
+ c := make(chan os.Signal, 1)
+
+ server, l, err := createListener(killTime * 10)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var wg sync.WaitGroup
+ var bu bytes.Buffer
+ buf := SyncBuffer{&wg, &bu}
+ var tbuf bytes.Buffer
+ logger := log.New(&buf, "", 0)
+ expected := log.New(&tbuf, "", 0)
+
+ srv := &Server{Timeout: killTime, Server: server, Logger: logger, interrupt: c}
+ go func() { srv.Serve(l) }()
+
+ stop := srv.StopChan()
+ buf.Add(1 + 10) // Expecting 11 log calls
+ c <- os.Interrupt
+ expected.Printf("shutdown initiated")
+ for i := 0; i < 10; i++ {
+ c <- os.Interrupt
+ expected.Printf("already shutting down")
+ }
+
+ <-stop
+
+ wg.Wait()
+ bb, bt := buf.Bytes(), tbuf.Bytes()
+ for i, b := range bb {
+ if b != bt[i] {
+ t.Fatal(fmt.Sprintf("shutdown log incorrect - got '%s', expected '%s'", buf.String(), tbuf.String()))
+ }
+ }
+}
+
+func TestLogFunc(t *testing.T) {
+ c := make(chan os.Signal, 1)
+
+ server, l, err := createListener(killTime * 10)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var called bool
+ srv := &Server{Timeout: killTime, Server: server,
+ LogFunc: func(format string, args ...interface{}) {
+ called = true
+ }, interrupt: c}
+ stop := srv.StopChan()
+ go func() { srv.Serve(l) }()
+ c <- os.Interrupt
+ <-stop
+
+ if called != true {
+ t.Fatal("Expected LogFunc to be called.")
+ }
+}
+
+// SyncBuffer calls Done on the embedded wait group after each call to Write.
+type SyncBuffer struct {
+ *sync.WaitGroup
+ *bytes.Buffer
+}
+
+func (buf *SyncBuffer) Write(b []byte) (int, error) {
+ defer buf.Done()
+ return buf.Buffer.Write(b)
+}