From 8f91c777559748fa6e857d9fc1f4ae079a532813 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Mon, 3 Oct 2016 16:03:15 -0400 Subject: Adding ability to serve TLS directly from Mattermost server (#4119) --- vendor/github.com/tylerb/graceful/graceful_test.go | 692 +++++++++++++++++++++ 1 file changed, 692 insertions(+) create mode 100644 vendor/github.com/tylerb/graceful/graceful_test.go (limited to 'vendor/github.com/tylerb/graceful/graceful_test.go') diff --git a/vendor/github.com/tylerb/graceful/graceful_test.go b/vendor/github.com/tylerb/graceful/graceful_test.go new file mode 100644 index 000000000..b9c49336b --- /dev/null +++ b/vendor/github.com/tylerb/graceful/graceful_test.go @@ -0,0 +1,692 @@ +package graceful + +import ( + "bytes" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "os" + "reflect" + "strings" + "sync" + "syscall" + "testing" + "time" +) + +const ( + // The tests will run a test server on this port. + port = 9654 + concurrentRequestN = 8 + killTime = 500 * time.Millisecond + timeoutTime = 1000 * time.Millisecond + waitTime = 100 * time.Millisecond +) + +func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup, once *sync.Once) { + defer wg.Done() + client := http.Client{} + r, err := client.Get(fmt.Sprintf("http://localhost:%d", port)) + if shouldErr && err == nil { + once.Do(func() { + t.Error("Expected an error but none was encountered.") + }) + } else if shouldErr && err != nil { + if checkErr(t, err, once) { + return + } + } + if r != nil && r.StatusCode != expected { + once.Do(func() { + t.Errorf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode) + }) + } else if r == nil { + once.Do(func() { + t.Error("No response when a response was expected.") + }) + } +} + +func checkErr(t *testing.T, err error, once *sync.Once) bool { + if err.(*url.Error).Err == io.EOF { + return true + } + var errno syscall.Errno + switch oe := err.(*url.Error).Err.(type) { + case *net.OpError: + switch e := oe.Err.(type) { + case syscall.Errno: + errno = e + case *os.SyscallError: + errno = e.Err.(syscall.Errno) + } + if errno == syscall.ECONNREFUSED { + return true + } else if err != nil { + once.Do(func() { + t.Error("Error on Get:", err) + }) + } + default: + if strings.Contains(err.Error(), "transport closed before response was received") { + return true + } + if strings.Contains(err.Error(), "server closed connection") { + return true + } + fmt.Printf("unknown err: %s, %#v\n", err, err) + } + return false +} + +func createListener(sleep time.Duration) (*http.Server, net.Listener, error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + time.Sleep(sleep) + rw.WriteHeader(http.StatusOK) + }) + + server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} + l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + return server, l, err +} + +func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) { + defer wg.Done() + var once sync.Once + + for i := 0; i < concurrentRequestN; i++ { + wg.Add(1) + go runQuery(t, http.StatusOK, false, wg, &once) + } + + time.Sleep(waitTime) + c <- os.Interrupt + time.Sleep(waitTime) + + for i := 0; i < concurrentRequestN; i++ { + wg.Add(1) + go runQuery(t, 0, true, wg, &once) + } +} + +func TestGracefulRun(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime / 2) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{Timeout: killTime, Server: server, interrupt: c} + srv.Serve(l) + }() + + wg.Add(1) + go launchTestQueries(t, &wg, c) +} + +func TestGracefulRunLimitKeepAliveListener(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime / 2) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{ + Timeout: killTime, + ListenLimit: concurrentRequestN, + TCPKeepAlive: 1 * time.Second, + Server: server, + interrupt: c, + } + srv.Serve(l) + }() + + wg.Add(1) + go launchTestQueries(t, &wg, c) +} + +func TestGracefulRunTimesOut(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime * 10) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{Timeout: killTime, Server: server, interrupt: c} + srv.Serve(l) + }() + + wg.Add(1) + go func() { + defer wg.Done() + var once sync.Once + + for i := 0; i < concurrentRequestN; i++ { + wg.Add(1) + go runQuery(t, 0, true, &wg, &once) + } + + time.Sleep(waitTime) + c <- os.Interrupt + time.Sleep(waitTime) + + for i := 0; i < concurrentRequestN; i++ { + wg.Add(1) + go runQuery(t, 0, true, &wg, &once) + } + }() +} + +func TestGracefulRunDoesntTimeOut(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime * 2) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{Timeout: 0, Server: server, interrupt: c} + srv.Serve(l) + }() + + wg.Add(1) + go launchTestQueries(t, &wg, c) +} + +func TestGracefulRunDoesntTimeOutAfterConnectionCreated(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{Timeout: 0, Server: server, interrupt: c} + srv.Serve(l) + }() + time.Sleep(waitTime) + + // Make a sample first request. The connection will be left idle. + resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) + if err != nil { + panic(fmt.Sprintf("first request failed: %v", err)) + } + resp.Body.Close() + + wg.Add(1) + go func() { + defer wg.Done() + + // With idle connections improperly handled, the server doesn't wait for this + // to complete and the request fails. It should be allowed to complete successfully. + _, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) + if err != nil { + t.Errorf("Get failed: %v", err) + } + }() + + // Ensure the request goes out + time.Sleep(waitTime) + c <- os.Interrupt + wg.Wait() +} + +func TestGracefulRunNoRequests(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime * 2) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{Timeout: 0, Server: server, interrupt: c} + srv.Serve(l) + }() + + c <- os.Interrupt +} + +func TestGracefulForwardsConnState(t *testing.T) { + var stateLock sync.Mutex + states := make(map[http.ConnState]int) + connState := func(conn net.Conn, state http.ConnState) { + stateLock.Lock() + states[state]++ + stateLock.Unlock() + } + + var wg sync.WaitGroup + defer wg.Wait() + + expected := map[http.ConnState]int{ + http.StateNew: concurrentRequestN, + http.StateActive: concurrentRequestN, + http.StateClosed: concurrentRequestN, + } + + c := make(chan os.Signal, 1) + server, l, err := createListener(killTime / 2) + if err != nil { + t.Fatal(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + srv := &Server{ + ConnState: connState, + Timeout: killTime, + Server: server, + interrupt: c, + } + srv.Serve(l) + }() + + wg.Add(1) + go launchTestQueries(t, &wg, c) + wg.Wait() + + stateLock.Lock() + if !reflect.DeepEqual(states, expected) { + t.Errorf("Incorrect connection state tracking.\n actual: %v\nexpected: %v\n", states, expected) + } + stateLock.Unlock() +} + +func TestGracefulExplicitStop(t *testing.T) { + server, l, err := createListener(1 * time.Millisecond) + if err != nil { + t.Fatal(err) + } + + srv := &Server{Timeout: killTime, Server: server} + + go func() { + go srv.Serve(l) + time.Sleep(waitTime) + srv.Stop(killTime) + }() + + // block on the stopChan until the server has shut down + select { + case <-srv.StopChan(): + case <-time.After(timeoutTime): + t.Fatal("Timed out while waiting for explicit stop to complete") + } +} + +func TestGracefulExplicitStopOverride(t *testing.T) { + server, l, err := createListener(1 * time.Millisecond) + if err != nil { + t.Fatal(err) + } + + srv := &Server{Timeout: killTime, Server: server} + + go func() { + go srv.Serve(l) + time.Sleep(waitTime) + srv.Stop(killTime / 2) + }() + + // block on the stopChan until the server has shut down + select { + case <-srv.StopChan(): + case <-time.After(killTime): + t.Fatal("Timed out while waiting for explicit stop to complete") + } +} + +func TestBeforeShutdownAndShutdownInitiatedCallbacks(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + server, l, err := createListener(1 * time.Millisecond) + if err != nil { + t.Fatal(err) + } + + beforeShutdownCalled := make(chan struct{}) + cb1 := func() bool { close(beforeShutdownCalled); return true } + shutdownInitiatedCalled := make(chan struct{}) + cb2 := func() { close(shutdownInitiatedCalled) } + + wg.Add(2) + srv := &Server{Server: server, BeforeShutdown: cb1, ShutdownInitiated: cb2} + go func() { + defer wg.Done() + srv.Serve(l) + }() + go func() { + defer wg.Done() + time.Sleep(waitTime) + srv.Stop(killTime) + }() + + beforeShutdown := false + shutdownInitiated := false + for i := 0; i < 2; i++ { + select { + case <-beforeShutdownCalled: + beforeShutdownCalled = nil + beforeShutdown = true + case <-shutdownInitiatedCalled: + shutdownInitiatedCalled = nil + shutdownInitiated = true + case <-time.After(killTime): + t.Fatal("Timed out while waiting for ShutdownInitiated callback to be called") + } + } + + if !beforeShutdown { + t.Fatal("beforeShutdown should be true") + } + if !shutdownInitiated { + t.Fatal("shutdownInitiated should be true") + } +} + +func TestBeforeShutdownCanceled(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1) + + server, l, err := createListener(1 * time.Millisecond) + if err != nil { + t.Fatal(err) + } + + beforeShutdownCalled := make(chan struct{}) + cb1 := func() bool { close(beforeShutdownCalled); return false } + shutdownInitiatedCalled := make(chan struct{}) + cb2 := func() { close(shutdownInitiatedCalled) } + + srv := &Server{Server: server, BeforeShutdown: cb1, ShutdownInitiated: cb2} + go func() { + srv.Serve(l) + wg.Done() + }() + go func() { + time.Sleep(waitTime) + srv.Stop(killTime) + }() + + beforeShutdown := false + shutdownInitiated := false + timeouted := false + + for i := 0; i < 2; i++ { + select { + case <-beforeShutdownCalled: + beforeShutdownCalled = nil + beforeShutdown = true + case <-shutdownInitiatedCalled: + shutdownInitiatedCalled = nil + shutdownInitiated = true + case <-time.After(killTime): + timeouted = true + } + } + + if !beforeShutdown { + t.Fatal("beforeShutdown should be true") + } + if !timeouted { + t.Fatal("timeouted should be true") + } + if shutdownInitiated { + t.Fatal("shutdownInitiated shouldn't be true") + } + + srv.BeforeShutdown = func() bool { return true } + srv.Stop(killTime) + + wg.Wait() +} + +func hijackingListener(srv *Server) (*http.Server, net.Listener, error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + conn, bufrw, err := rw.(http.Hijacker).Hijack() + if err != nil { + http.Error(rw, "webserver doesn't support hijacking", http.StatusInternalServerError) + return + } + + defer conn.Close() + + bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") + bufrw.Flush() + }) + + server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} + l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + return server, l, err +} + +func TestNotifyClosed(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan os.Signal, 1) + srv := &Server{Timeout: killTime, interrupt: c} + server, l, err := hijackingListener(srv) + if err != nil { + t.Fatal(err) + } + + srv.Server = server + + wg.Add(1) + go func() { + defer wg.Done() + srv.Serve(l) + }() + + var once sync.Once + for i := 0; i < concurrentRequestN; i++ { + wg.Add(1) + runQuery(t, http.StatusOK, false, &wg, &once) + } + + srv.Stop(0) + + // block on the stopChan until the server has shut down + select { + case <-srv.StopChan(): + case <-time.After(timeoutTime): + t.Fatal("Timed out while waiting for explicit stop to complete") + } + + if len(srv.connections) > 0 { + t.Fatal("hijacked connections should not be managed") + } + +} + +func TestStopDeadlock(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + c := make(chan struct{}) + server, l, err := createListener(1 * time.Millisecond) + if err != nil { + t.Fatal(err) + } + + srv := &Server{Server: server, NoSignalHandling: true} + + wg.Add(2) + go func() { + defer wg.Done() + time.Sleep(waitTime) + srv.Serve(l) + }() + go func() { + defer wg.Done() + srv.Stop(0) + close(c) + }() + + select { + case <-c: + l.Close() + case <-time.After(timeoutTime): + t.Fatal("Timed out while waiting for explicit stop to complete") + } +} + +// Run with --race +func TestStopRace(t *testing.T) { + server, l, err := createListener(1 * time.Millisecond) + if err != nil { + t.Fatal(err) + } + + srv := &Server{Timeout: killTime, Server: server} + + go func() { + go srv.Serve(l) + srv.Stop(killTime) + }() + srv.Stop(0) + select { + case <-srv.StopChan(): + case <-time.After(timeoutTime): + t.Fatal("Timed out while waiting for explicit stop to complete") + } +} + +func TestInterruptLog(t *testing.T) { + c := make(chan os.Signal, 1) + + server, l, err := createListener(killTime * 10) + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + var tbuf bytes.Buffer + logger := log.New(&buf, "", 0) + expected := log.New(&tbuf, "", 0) + + srv := &Server{Timeout: killTime, Server: server, Logger: logger, interrupt: c} + go func() { srv.Serve(l) }() + + stop := srv.StopChan() + c <- os.Interrupt + expected.Print("shutdown initiated") + + <-stop + + if buf.String() != tbuf.String() { + t.Fatal("shutdown log incorrect - got '" + buf.String() + "'") + } +} + +func TestMultiInterrupts(t *testing.T) { + c := make(chan os.Signal, 1) + + server, l, err := createListener(killTime * 10) + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + var bu bytes.Buffer + buf := SyncBuffer{&wg, &bu} + var tbuf bytes.Buffer + logger := log.New(&buf, "", 0) + expected := log.New(&tbuf, "", 0) + + srv := &Server{Timeout: killTime, Server: server, Logger: logger, interrupt: c} + go func() { srv.Serve(l) }() + + stop := srv.StopChan() + buf.Add(1 + 10) // Expecting 11 log calls + c <- os.Interrupt + expected.Printf("shutdown initiated") + for i := 0; i < 10; i++ { + c <- os.Interrupt + expected.Printf("already shutting down") + } + + <-stop + + wg.Wait() + bb, bt := buf.Bytes(), tbuf.Bytes() + for i, b := range bb { + if b != bt[i] { + t.Fatal(fmt.Sprintf("shutdown log incorrect - got '%s', expected '%s'", buf.String(), tbuf.String())) + } + } +} + +func TestLogFunc(t *testing.T) { + c := make(chan os.Signal, 1) + + server, l, err := createListener(killTime * 10) + if err != nil { + t.Fatal(err) + } + var called bool + srv := &Server{Timeout: killTime, Server: server, + LogFunc: func(format string, args ...interface{}) { + called = true + }, interrupt: c} + stop := srv.StopChan() + go func() { srv.Serve(l) }() + c <- os.Interrupt + <-stop + + if called != true { + t.Fatal("Expected LogFunc to be called.") + } +} + +// SyncBuffer calls Done on the embedded wait group after each call to Write. +type SyncBuffer struct { + *sync.WaitGroup + *bytes.Buffer +} + +func (buf *SyncBuffer) Write(b []byte) (int, error) { + defer buf.Done() + return buf.Buffer.Write(b) +} -- cgit v1.2.3-1-g7c22