summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/tylerb/graceful/graceful.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/tylerb/graceful/graceful.go')
-rw-r--r--vendor/github.com/tylerb/graceful/graceful.go487
1 files changed, 487 insertions, 0 deletions
diff --git a/vendor/github.com/tylerb/graceful/graceful.go b/vendor/github.com/tylerb/graceful/graceful.go
new file mode 100644
index 000000000..a5e2395e0
--- /dev/null
+++ b/vendor/github.com/tylerb/graceful/graceful.go
@@ -0,0 +1,487 @@
+package graceful
+
+import (
+ "crypto/tls"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+ "time"
+)
+
+// Server wraps an http.Server with graceful connection handling.
+// It may be used directly in the same way as http.Server, or may
+// be constructed with the global functions in this package.
+//
+// Example:
+// srv := &graceful.Server{
+// Timeout: 5 * time.Second,
+// Server: &http.Server{Addr: ":1234", Handler: handler},
+// }
+// srv.ListenAndServe()
+type Server struct {
+ *http.Server
+
+ // Timeout is the duration to allow outstanding requests to survive
+ // before forcefully terminating them.
+ Timeout time.Duration
+
+ // Limit the number of outstanding requests
+ ListenLimit int
+
+ // TCPKeepAlive sets the TCP keep-alive timeouts on accepted
+ // connections. It prunes dead TCP connections ( e.g. closing
+ // laptop mid-download)
+ TCPKeepAlive time.Duration
+
+ // ConnState specifies an optional callback function that is
+ // called when a client connection changes state. This is a proxy
+ // to the underlying http.Server's ConnState, and the original
+ // must not be set directly.
+ ConnState func(net.Conn, http.ConnState)
+
+ // BeforeShutdown is an optional callback function that is called
+ // before the listener is closed. Returns true if shutdown is allowed
+ BeforeShutdown func() bool
+
+ // ShutdownInitiated is an optional callback function that is called
+ // when shutdown is initiated. It can be used to notify the client
+ // side of long lived connections (e.g. websockets) to reconnect.
+ ShutdownInitiated func()
+
+ // NoSignalHandling prevents graceful from automatically shutting down
+ // on SIGINT and SIGTERM. If set to true, you must shut down the server
+ // manually with Stop().
+ NoSignalHandling bool
+
+ // Logger used to notify of errors on startup and on stop.
+ Logger *log.Logger
+
+ // LogFunc can be assigned with a logging function of your choice, allowing
+ // you to use whatever logging approach you would like
+ LogFunc func(format string, args ...interface{})
+
+ // Interrupted is true if the server is handling a SIGINT or SIGTERM
+ // signal and is thus shutting down.
+ Interrupted bool
+
+ // interrupt signals the listener to stop serving connections,
+ // and the server to shut down.
+ interrupt chan os.Signal
+
+ // stopLock is used to protect against concurrent calls to Stop
+ stopLock sync.Mutex
+
+ // stopChan is the channel on which callers may block while waiting for
+ // the server to stop.
+ stopChan chan struct{}
+
+ // chanLock is used to protect access to the various channel constructors.
+ chanLock sync.RWMutex
+
+ // connections holds all connections managed by graceful
+ connections map[net.Conn]struct{}
+
+ // idleConnections holds all idle connections managed by graceful
+ idleConnections map[net.Conn]struct{}
+}
+
+// Run serves the http.Handler with graceful shutdown enabled.
+//
+// timeout is the duration to wait until killing active requests and stopping the server.
+// If timeout is 0, the server never times out. It waits for all active requests to finish.
+func Run(addr string, timeout time.Duration, n http.Handler) {
+ srv := &Server{
+ Timeout: timeout,
+ TCPKeepAlive: 3 * time.Minute,
+ Server: &http.Server{Addr: addr, Handler: n},
+ // Logger: DefaultLogger(),
+ }
+
+ if err := srv.ListenAndServe(); err != nil {
+ if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
+ srv.logf("%s", err)
+ os.Exit(1)
+ }
+ }
+
+}
+
+// RunWithErr is an alternative version of Run function which can return error.
+//
+// Unlike Run this version will not exit the program if an error is encountered but will
+// return it instead.
+func RunWithErr(addr string, timeout time.Duration, n http.Handler) error {
+ srv := &Server{
+ Timeout: timeout,
+ TCPKeepAlive: 3 * time.Minute,
+ Server: &http.Server{Addr: addr, Handler: n},
+ Logger: DefaultLogger(),
+ }
+
+ return srv.ListenAndServe()
+}
+
+// ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled.
+//
+// timeout is the duration to wait until killing active requests and stopping the server.
+// If timeout is 0, the server never times out. It waits for all active requests to finish.
+func ListenAndServe(server *http.Server, timeout time.Duration) error {
+ srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
+ return srv.ListenAndServe()
+}
+
+// ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled.
+func (srv *Server) ListenAndServe() error {
+ // Create the listener so we can control their lifetime
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":http"
+ }
+ conn, err := srv.newTCPListener(addr)
+ if err != nil {
+ return err
+ }
+
+ return srv.Serve(conn)
+}
+
+// ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled.
+//
+// timeout is the duration to wait until killing active requests and stopping the server.
+// If timeout is 0, the server never times out. It waits for all active requests to finish.
+func ListenAndServeTLS(server *http.Server, certFile, keyFile string, timeout time.Duration) error {
+ srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
+ return srv.ListenAndServeTLS(certFile, keyFile)
+}
+
+// ListenTLS is a convenience method that creates an https listener using the
+// provided cert and key files. Use this method if you need access to the
+// listener object directly. When ready, pass it to the Serve method.
+func (srv *Server) ListenTLS(certFile, keyFile string) (net.Listener, error) {
+ // Create the listener ourselves so we can control its lifetime
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":https"
+ }
+
+ config := &tls.Config{}
+ if srv.TLSConfig != nil {
+ *config = *srv.TLSConfig
+ }
+
+ var err error
+ config.Certificates = make([]tls.Certificate, 1)
+ config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return nil, err
+ }
+
+ // Enable http2
+ enableHTTP2ForTLSConfig(config)
+
+ conn, err := srv.newTCPListener(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ srv.TLSConfig = config
+
+ tlsListener := tls.NewListener(conn, config)
+ return tlsListener, nil
+}
+
+// Enable HTTP2ForTLSConfig explicitly enables http/2 for a TLS Config. This is due to changes in Go 1.7 where
+// http servers are no longer automatically configured to enable http/2 if the server's TLSConfig is set.
+// See https://github.com/golang/go/issues/15908
+func enableHTTP2ForTLSConfig(t *tls.Config) {
+
+ if TLSConfigHasHTTP2Enabled(t) {
+ return
+ }
+
+ t.NextProtos = append(t.NextProtos, "h2")
+}
+
+// TLSConfigHasHTTP2Enabled checks to see if a given TLS Config has http2 enabled.
+func TLSConfigHasHTTP2Enabled(t *tls.Config) bool {
+ for _, value := range t.NextProtos {
+ if value == "h2" {
+ return true
+ }
+ }
+ return false
+}
+
+// ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled.
+func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
+ l, err := srv.ListenTLS(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+
+ return srv.Serve(l)
+}
+
+// ListenAndServeTLSConfig can be used with an existing TLS config and is equivalent to
+// http.Server.ListenAndServeTLS with graceful shutdown enabled,
+func (srv *Server) ListenAndServeTLSConfig(config *tls.Config) error {
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":https"
+ }
+
+ conn, err := srv.newTCPListener(addr)
+ if err != nil {
+ return err
+ }
+
+ srv.TLSConfig = config
+
+ tlsListener := tls.NewListener(conn, config)
+ return srv.Serve(tlsListener)
+}
+
+// Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
+//
+// timeout is the duration to wait until killing active requests and stopping the server.
+// If timeout is 0, the server never times out. It waits for all active requests to finish.
+func Serve(server *http.Server, l net.Listener, timeout time.Duration) error {
+ srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
+
+ return srv.Serve(l)
+}
+
+// Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
+func (srv *Server) Serve(listener net.Listener) error {
+
+ if srv.ListenLimit != 0 {
+ listener = LimitListener(listener, srv.ListenLimit)
+ }
+
+ // Make our stopchan
+ srv.StopChan()
+
+ // Track connection state
+ add := make(chan net.Conn)
+ idle := make(chan net.Conn)
+ active := make(chan net.Conn)
+ remove := make(chan net.Conn)
+
+ srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
+ switch state {
+ case http.StateNew:
+ add <- conn
+ case http.StateActive:
+ active <- conn
+ case http.StateIdle:
+ idle <- conn
+ case http.StateClosed, http.StateHijacked:
+ remove <- conn
+ }
+
+ srv.stopLock.Lock()
+ defer srv.stopLock.Unlock()
+
+ if srv.ConnState != nil {
+ srv.ConnState(conn, state)
+ }
+ }
+
+ // Manage open connections
+ shutdown := make(chan chan struct{})
+ kill := make(chan struct{})
+ go srv.manageConnections(add, idle, active, remove, shutdown, kill)
+
+ interrupt := srv.interruptChan()
+ // Set up the interrupt handler
+ if !srv.NoSignalHandling {
+ signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
+ }
+ quitting := make(chan struct{})
+ go srv.handleInterrupt(interrupt, quitting, listener)
+
+ // Serve with graceful listener.
+ // Execution blocks here until listener.Close() is called, above.
+ err := srv.Server.Serve(listener)
+ if err != nil {
+ // If the underlying listening is closed, Serve returns an error
+ // complaining about listening on a closed socket. This is expected, so
+ // let's ignore the error if we are the ones who explicitly closed the
+ // socket.
+ select {
+ case <-quitting:
+ err = nil
+ default:
+ }
+ }
+
+ srv.shutdown(shutdown, kill)
+
+ return err
+}
+
+// Stop instructs the type to halt operations and close
+// the stop channel when it is finished.
+//
+// timeout is grace period for which to wait before shutting
+// down the server. The timeout value passed here will override the
+// timeout given when constructing the server, as this is an explicit
+// command to stop the server.
+func (srv *Server) Stop(timeout time.Duration) {
+ srv.stopLock.Lock()
+ defer srv.stopLock.Unlock()
+
+ srv.Timeout = timeout
+ interrupt := srv.interruptChan()
+ interrupt <- syscall.SIGINT
+}
+
+// StopChan gets the stop channel which will block until
+// stopping has completed, at which point it is closed.
+// Callers should never close the stop channel.
+func (srv *Server) StopChan() <-chan struct{} {
+ srv.chanLock.Lock()
+ defer srv.chanLock.Unlock()
+
+ if srv.stopChan == nil {
+ srv.stopChan = make(chan struct{})
+ }
+ return srv.stopChan
+}
+
+// DefaultLogger returns the logger used by Run, RunWithErr, ListenAndServe, ListenAndServeTLS and Serve.
+// The logger outputs to STDERR by default.
+func DefaultLogger() *log.Logger {
+ return log.New(os.Stderr, "[graceful] ", 0)
+}
+
+func (srv *Server) manageConnections(add, idle, active, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) {
+ var done chan struct{}
+ srv.connections = map[net.Conn]struct{}{}
+ srv.idleConnections = map[net.Conn]struct{}{}
+ for {
+ select {
+ case conn := <-add:
+ srv.connections[conn] = struct{}{}
+ case conn := <-idle:
+ srv.idleConnections[conn] = struct{}{}
+ case conn := <-active:
+ delete(srv.idleConnections, conn)
+ case conn := <-remove:
+ delete(srv.connections, conn)
+ delete(srv.idleConnections, conn)
+ if done != nil && len(srv.connections) == 0 {
+ done <- struct{}{}
+ return
+ }
+ case done = <-shutdown:
+ if len(srv.connections) == 0 && len(srv.idleConnections) == 0 {
+ done <- struct{}{}
+ return
+ }
+ // a shutdown request has been received. if we have open idle
+ // connections, we must close all of them now. this prevents idle
+ // connections from holding the server open while waiting for them to
+ // hit their idle timeout.
+ for k := range srv.idleConnections {
+ if err := k.Close(); err != nil {
+ srv.logf("[ERROR] %s", err)
+ }
+ }
+ case <-kill:
+ srv.stopLock.Lock()
+ defer srv.stopLock.Unlock()
+
+ srv.Server.ConnState = nil
+ for k := range srv.connections {
+ if err := k.Close(); err != nil {
+ srv.logf("[ERROR] %s", err)
+ }
+ }
+ return
+ }
+ }
+}
+
+func (srv *Server) interruptChan() chan os.Signal {
+ srv.chanLock.Lock()
+ defer srv.chanLock.Unlock()
+
+ if srv.interrupt == nil {
+ srv.interrupt = make(chan os.Signal, 1)
+ }
+
+ return srv.interrupt
+}
+
+func (srv *Server) handleInterrupt(interrupt chan os.Signal, quitting chan struct{}, listener net.Listener) {
+ for _ = range interrupt {
+ if srv.Interrupted {
+ srv.logf("already shutting down")
+ continue
+ }
+ srv.logf("shutdown initiated")
+ srv.Interrupted = true
+ if srv.BeforeShutdown != nil {
+ if !srv.BeforeShutdown() {
+ srv.Interrupted = false
+ continue
+ }
+ }
+
+ close(quitting)
+ srv.SetKeepAlivesEnabled(false)
+ if err := listener.Close(); err != nil {
+ srv.logf("[ERROR] %s", err)
+ }
+
+ if srv.ShutdownInitiated != nil {
+ srv.ShutdownInitiated()
+ }
+ }
+}
+
+func (srv *Server) logf(format string, args ...interface{}) {
+ if srv.LogFunc != nil {
+ srv.LogFunc(format, args...)
+ } else if srv.Logger != nil {
+ srv.Logger.Printf(format, args...)
+ }
+}
+
+func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) {
+ // Request done notification
+ done := make(chan struct{})
+ shutdown <- done
+
+ if srv.Timeout > 0 {
+ select {
+ case <-done:
+ case <-time.After(srv.Timeout):
+ close(kill)
+ }
+ } else {
+ <-done
+ }
+ // Close the stopChan to wake up any blocked goroutines.
+ srv.chanLock.Lock()
+ if srv.stopChan != nil {
+ close(srv.stopChan)
+ }
+ srv.chanLock.Unlock()
+}
+
+func (srv *Server) newTCPListener(addr string) (net.Listener, error) {
+ conn, err := net.Listen("tcp", addr)
+ if err != nil {
+ return conn, err
+ }
+ if srv.TCPKeepAlive != 0 {
+ conn = keepAliveListener{conn, srv.TCPKeepAlive}
+ }
+ return conn, nil
+}