1package graceful 2 3import ( 4 "crypto/tls" 5 "log" 6 "net" 7 "net/http" 8 "os" 9 "sync" 10 "time" 11) 12 13// Server wraps an http.Server with graceful connection handling. 14// It may be used directly in the same way as http.Server, or may 15// be constructed with the global functions in this package. 16// 17// Example: 18// srv := &graceful.Server{ 19// Timeout: 5 * time.Second, 20// Server: &http.Server{Addr: ":1234", Handler: handler}, 21// } 22// srv.ListenAndServe() 23type Server struct { 24 *http.Server 25 26 // Timeout is the duration to allow outstanding requests to survive 27 // before forcefully terminating them. 28 Timeout time.Duration 29 30 // Limit the number of outstanding requests 31 ListenLimit int 32 33 // TCPKeepAlive sets the TCP keep-alive timeouts on accepted 34 // connections. It prunes dead TCP connections ( e.g. closing 35 // laptop mid-download) 36 TCPKeepAlive time.Duration 37 38 // ConnState specifies an optional callback function that is 39 // called when a client connection changes state. This is a proxy 40 // to the underlying http.Server's ConnState, and the original 41 // must not be set directly. 42 ConnState func(net.Conn, http.ConnState) 43 44 // BeforeShutdown is an optional callback function that is called 45 // before the listener is closed. Returns true if shutdown is allowed 46 BeforeShutdown func() bool 47 48 // ShutdownInitiated is an optional callback function that is called 49 // when shutdown is initiated. It can be used to notify the client 50 // side of long lived connections (e.g. websockets) to reconnect. 51 ShutdownInitiated func() 52 53 // NoSignalHandling prevents graceful from automatically shutting down 54 // on SIGINT and SIGTERM. If set to true, you must shut down the server 55 // manually with Stop(). 56 NoSignalHandling bool 57 58 // Logger used to notify of errors on startup and on stop. 59 Logger *log.Logger 60 61 // LogFunc can be assigned with a logging function of your choice, allowing 62 // you to use whatever logging approach you would like 63 LogFunc func(format string, args ...interface{}) 64 65 // Interrupted is true if the server is handling a SIGINT or SIGTERM 66 // signal and is thus shutting down. 67 Interrupted bool 68 69 // interrupt signals the listener to stop serving connections, 70 // and the server to shut down. 71 interrupt chan os.Signal 72 73 // stopLock is used to protect against concurrent calls to Stop 74 stopLock sync.Mutex 75 76 // stopChan is the channel on which callers may block while waiting for 77 // the server to stop. 78 stopChan chan struct{} 79 80 // chanLock is used to protect access to the various channel constructors. 81 chanLock sync.RWMutex 82 83 // connections holds all connections managed by graceful 84 connections map[net.Conn]struct{} 85 86 // idleConnections holds all idle connections managed by graceful 87 idleConnections map[net.Conn]struct{} 88} 89 90// Run serves the http.Handler with graceful shutdown enabled. 91// 92// timeout is the duration to wait until killing active requests and stopping the server. 93// If timeout is 0, the server never times out. It waits for all active requests to finish. 94func Run(addr string, timeout time.Duration, n http.Handler) { 95 srv := &Server{ 96 Timeout: timeout, 97 TCPKeepAlive: 3 * time.Minute, 98 Server: &http.Server{Addr: addr, Handler: n}, 99 // Logger: DefaultLogger(), 100 } 101 102 if err := srv.ListenAndServe(); err != nil { 103 if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") { 104 srv.logf("%s", err) 105 os.Exit(1) 106 } 107 } 108 109} 110 111// RunWithErr is an alternative version of Run function which can return error. 112// 113// Unlike Run this version will not exit the program if an error is encountered but will 114// return it instead. 115func RunWithErr(addr string, timeout time.Duration, n http.Handler) error { 116 srv := &Server{ 117 Timeout: timeout, 118 TCPKeepAlive: 3 * time.Minute, 119 Server: &http.Server{Addr: addr, Handler: n}, 120 Logger: DefaultLogger(), 121 } 122 123 return srv.ListenAndServe() 124} 125 126// ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled. 127// 128// timeout is the duration to wait until killing active requests and stopping the server. 129// If timeout is 0, the server never times out. It waits for all active requests to finish. 130func ListenAndServe(server *http.Server, timeout time.Duration) error { 131 srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} 132 return srv.ListenAndServe() 133} 134 135// ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled. 136func (srv *Server) ListenAndServe() error { 137 // Create the listener so we can control their lifetime 138 addr := srv.Addr 139 if addr == "" { 140 addr = ":http" 141 } 142 conn, err := srv.newTCPListener(addr) 143 if err != nil { 144 return err 145 } 146 147 return srv.Serve(conn) 148} 149 150// ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled. 151// 152// timeout is the duration to wait until killing active requests and stopping the server. 153// If timeout is 0, the server never times out. It waits for all active requests to finish. 154func ListenAndServeTLS(server *http.Server, certFile, keyFile string, timeout time.Duration) error { 155 srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} 156 return srv.ListenAndServeTLS(certFile, keyFile) 157} 158 159// ListenTLS is a convenience method that creates an https listener using the 160// provided cert and key files. Use this method if you need access to the 161// listener object directly. When ready, pass it to the Serve method. 162func (srv *Server) ListenTLS(certFile, keyFile string) (net.Listener, error) { 163 // Create the listener ourselves so we can control its lifetime 164 addr := srv.Addr 165 if addr == "" { 166 addr = ":https" 167 } 168 169 config := &tls.Config{} 170 if srv.TLSConfig != nil { 171 *config = *srv.TLSConfig 172 } 173 174 var err error 175 if certFile != "" && keyFile != "" { 176 config.Certificates = make([]tls.Certificate, 1) 177 config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) 178 if err != nil { 179 return nil, err 180 } 181 } 182 183 // Enable http2 184 enableHTTP2ForTLSConfig(config) 185 186 conn, err := srv.newTCPListener(addr) 187 if err != nil { 188 return nil, err 189 } 190 191 srv.TLSConfig = config 192 193 tlsListener := tls.NewListener(conn, config) 194 return tlsListener, nil 195} 196 197// Enable HTTP2ForTLSConfig explicitly enables http/2 for a TLS Config. This is due to changes in Go 1.7 where 198// http servers are no longer automatically configured to enable http/2 if the server's TLSConfig is set. 199// See https://github.com/golang/go/issues/15908 200func enableHTTP2ForTLSConfig(t *tls.Config) { 201 202 if TLSConfigHasHTTP2Enabled(t) { 203 return 204 } 205 206 t.NextProtos = append(t.NextProtos, "h2") 207} 208 209// TLSConfigHasHTTP2Enabled checks to see if a given TLS Config has http2 enabled. 210func TLSConfigHasHTTP2Enabled(t *tls.Config) bool { 211 for _, value := range t.NextProtos { 212 if value == "h2" { 213 return true 214 } 215 } 216 return false 217} 218 219// ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled. 220func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { 221 l, err := srv.ListenTLS(certFile, keyFile) 222 if err != nil { 223 return err 224 } 225 226 return srv.Serve(l) 227} 228 229// ListenAndServeTLSConfig can be used with an existing TLS config and is equivalent to 230// http.Server.ListenAndServeTLS with graceful shutdown enabled, 231func (srv *Server) ListenAndServeTLSConfig(config *tls.Config) error { 232 addr := srv.Addr 233 if addr == "" { 234 addr = ":https" 235 } 236 237 conn, err := srv.newTCPListener(addr) 238 if err != nil { 239 return err 240 } 241 242 srv.TLSConfig = config 243 244 tlsListener := tls.NewListener(conn, config) 245 return srv.Serve(tlsListener) 246} 247 248// Serve is equivalent to http.Server.Serve with graceful shutdown enabled. 249// 250// timeout is the duration to wait until killing active requests and stopping the server. 251// If timeout is 0, the server never times out. It waits for all active requests to finish. 252func Serve(server *http.Server, l net.Listener, timeout time.Duration) error { 253 srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} 254 255 return srv.Serve(l) 256} 257 258// Serve is equivalent to http.Server.Serve with graceful shutdown enabled. 259func (srv *Server) Serve(listener net.Listener) error { 260 261 if srv.ListenLimit != 0 { 262 listener = LimitListener(listener, srv.ListenLimit) 263 } 264 265 // Make our stopchan 266 srv.StopChan() 267 268 // Track connection state 269 add := make(chan net.Conn) 270 idle := make(chan net.Conn) 271 active := make(chan net.Conn) 272 remove := make(chan net.Conn) 273 274 srv.Server.ConnState = func(conn net.Conn, state http.ConnState) { 275 switch state { 276 case http.StateNew: 277 add <- conn 278 case http.StateActive: 279 active <- conn 280 case http.StateIdle: 281 idle <- conn 282 case http.StateClosed, http.StateHijacked: 283 remove <- conn 284 } 285 286 srv.stopLock.Lock() 287 defer srv.stopLock.Unlock() 288 289 if srv.ConnState != nil { 290 srv.ConnState(conn, state) 291 } 292 } 293 294 // Manage open connections 295 shutdown := make(chan chan struct{}) 296 kill := make(chan struct{}) 297 go srv.manageConnections(add, idle, active, remove, shutdown, kill) 298 299 interrupt := srv.interruptChan() 300 // Set up the interrupt handler 301 if !srv.NoSignalHandling { 302 signalNotify(interrupt) 303 } 304 quitting := make(chan struct{}) 305 go srv.handleInterrupt(interrupt, quitting, listener) 306 307 // Serve with graceful listener. 308 // Execution blocks here until listener.Close() is called, above. 309 err := srv.Server.Serve(listener) 310 if err != nil { 311 // If the underlying listening is closed, Serve returns an error 312 // complaining about listening on a closed socket. This is expected, so 313 // let's ignore the error if we are the ones who explicitly closed the 314 // socket. 315 select { 316 case <-quitting: 317 err = nil 318 default: 319 } 320 } 321 322 srv.shutdown(shutdown, kill) 323 324 return err 325} 326 327// Stop instructs the type to halt operations and close 328// the stop channel when it is finished. 329// 330// timeout is grace period for which to wait before shutting 331// down the server. The timeout value passed here will override the 332// timeout given when constructing the server, as this is an explicit 333// command to stop the server. 334func (srv *Server) Stop(timeout time.Duration) { 335 srv.stopLock.Lock() 336 defer srv.stopLock.Unlock() 337 338 srv.Timeout = timeout 339 sendSignalInt(srv.interruptChan()) 340} 341 342// StopChan gets the stop channel which will block until 343// stopping has completed, at which point it is closed. 344// Callers should never close the stop channel. 345func (srv *Server) StopChan() <-chan struct{} { 346 srv.chanLock.Lock() 347 defer srv.chanLock.Unlock() 348 349 if srv.stopChan == nil { 350 srv.stopChan = make(chan struct{}) 351 } 352 return srv.stopChan 353} 354 355// DefaultLogger returns the logger used by Run, RunWithErr, ListenAndServe, ListenAndServeTLS and Serve. 356// The logger outputs to STDERR by default. 357func DefaultLogger() *log.Logger { 358 return log.New(os.Stderr, "[graceful] ", 0) 359} 360 361func (srv *Server) manageConnections(add, idle, active, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) { 362 var done chan struct{} 363 srv.connections = map[net.Conn]struct{}{} 364 srv.idleConnections = map[net.Conn]struct{}{} 365 for { 366 select { 367 case conn := <-add: 368 srv.connections[conn] = struct{}{} 369 srv.idleConnections[conn] = struct{}{} // Newly-added connections are considered idle until they become active. 370 case conn := <-idle: 371 srv.idleConnections[conn] = struct{}{} 372 case conn := <-active: 373 delete(srv.idleConnections, conn) 374 case conn := <-remove: 375 delete(srv.connections, conn) 376 delete(srv.idleConnections, conn) 377 if done != nil && len(srv.connections) == 0 { 378 done <- struct{}{} 379 return 380 } 381 case done = <-shutdown: 382 if len(srv.connections) == 0 && len(srv.idleConnections) == 0 { 383 done <- struct{}{} 384 return 385 } 386 // a shutdown request has been received. if we have open idle 387 // connections, we must close all of them now. this prevents idle 388 // connections from holding the server open while waiting for them to 389 // hit their idle timeout. 390 for k := range srv.idleConnections { 391 if err := k.Close(); err != nil { 392 srv.logf("[ERROR] %s", err) 393 } 394 } 395 case <-kill: 396 srv.stopLock.Lock() 397 defer srv.stopLock.Unlock() 398 399 srv.Server.ConnState = nil 400 for k := range srv.connections { 401 if err := k.Close(); err != nil { 402 srv.logf("[ERROR] %s", err) 403 } 404 } 405 return 406 } 407 } 408} 409 410func (srv *Server) interruptChan() chan os.Signal { 411 srv.chanLock.Lock() 412 defer srv.chanLock.Unlock() 413 414 if srv.interrupt == nil { 415 srv.interrupt = make(chan os.Signal, 1) 416 } 417 418 return srv.interrupt 419} 420 421func (srv *Server) handleInterrupt(interrupt chan os.Signal, quitting chan struct{}, listener net.Listener) { 422 for _ = range interrupt { 423 if srv.Interrupted { 424 srv.logf("already shutting down") 425 continue 426 } 427 srv.logf("shutdown initiated") 428 srv.Interrupted = true 429 if srv.BeforeShutdown != nil { 430 if !srv.BeforeShutdown() { 431 srv.Interrupted = false 432 continue 433 } 434 } 435 436 close(quitting) 437 srv.SetKeepAlivesEnabled(false) 438 if err := listener.Close(); err != nil { 439 srv.logf("[ERROR] %s", err) 440 } 441 442 if srv.ShutdownInitiated != nil { 443 srv.ShutdownInitiated() 444 } 445 } 446} 447 448func (srv *Server) logf(format string, args ...interface{}) { 449 if srv.LogFunc != nil { 450 srv.LogFunc(format, args...) 451 } else if srv.Logger != nil { 452 srv.Logger.Printf(format, args...) 453 } 454} 455 456func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) { 457 // Request done notification 458 done := make(chan struct{}) 459 shutdown <- done 460 461 srv.stopLock.Lock() 462 defer srv.stopLock.Unlock() 463 if srv.Timeout > 0 { 464 select { 465 case <-done: 466 case <-time.After(srv.Timeout): 467 close(kill) 468 } 469 } else { 470 <-done 471 } 472 // Close the stopChan to wake up any blocked goroutines. 473 srv.chanLock.Lock() 474 if srv.stopChan != nil { 475 close(srv.stopChan) 476 } 477 srv.chanLock.Unlock() 478} 479 480func (srv *Server) newTCPListener(addr string) (net.Listener, error) { 481 conn, err := net.Listen("tcp", addr) 482 if err != nil { 483 return conn, err 484 } 485 if srv.TCPKeepAlive != 0 { 486 conn = keepAliveListener{conn, srv.TCPKeepAlive} 487 } 488 return conn, nil 489} 490