1package graceful
2
3import (
4	"crypto/tls"
5	"log"
6	"net"
7	"net/http"
8	"os"
9	"sync"
10	"time"
11)
12
// Server wraps an http.Server with graceful connection handling.
// It may be used directly in the same way as http.Server, or may
// be constructed with the global functions in this package.
//
// Example:
//	srv := &graceful.Server{
//		Timeout: 5 * time.Second,
//		Server: &http.Server{Addr: ":1234", Handler: handler},
//	}
//	srv.ListenAndServe()
type Server struct {
	*http.Server

	// Timeout is the duration to allow outstanding requests to survive
	// before forcefully terminating them. A value of 0 (or any value <= 0)
	// means shutdown waits indefinitely for all connections to close.
	Timeout time.Duration

	// ListenLimit, when non-zero, makes Serve wrap its listener with
	// LimitListener(listener, ListenLimit).
	// NOTE(review): LimitListener is defined elsewhere; presumably it caps
	// simultaneously accepted connections rather than requests — confirm.
	ListenLimit int

	// TCPKeepAlive sets the TCP keep-alive timeouts on accepted
	// connections. It prunes dead TCP connections (e.g. closing
	// laptop mid-download). A value of 0 disables it. It only applies to
	// listeners this package creates (ListenAndServe, ListenTLS,
	// ListenAndServeTLSConfig); a listener passed directly to Serve is
	// used as-is.
	TCPKeepAlive time.Duration

	// ConnState specifies an optional callback function that is
	// called when a client connection changes state. This is a proxy
	// to the underlying http.Server's ConnState, and the original
	// must not be set directly: Serve overwrites http.Server.ConnState
	// with its own tracking hook, which forwards to this field.
	// The callback runs while the server's internal stop lock is held,
	// so it must not call Stop (that would deadlock).
	ConnState func(net.Conn, http.ConnState)

	// BeforeShutdown is an optional callback function that is called
	// before the listener is closed. Returns true if shutdown is allowed;
	// returning false vetoes this shutdown request, the server keeps
	// serving, and Interrupted is reset to false.
	BeforeShutdown func() bool

	// ShutdownInitiated is an optional callback function that is called
	// when shutdown is initiated, after the listener has been closed. It
	// can be used to notify the client side of long lived connections
	// (e.g. websockets) to reconnect.
	ShutdownInitiated func()

	// NoSignalHandling prevents graceful from automatically shutting down
	// on SIGINT and SIGTERM. If set to true, you must shut down the server
	// manually with Stop().
	// NOTE(review): the signal set is decided by signalNotify, which is
	// defined elsewhere — confirm it is exactly SIGINT and SIGTERM.
	NoSignalHandling bool

	// Logger receives the server's log output: errors on startup and stop,
	// and shutdown progress messages. It is ignored when LogFunc is set.
	// If neither is set, log output is discarded.
	Logger *log.Logger

	// LogFunc can be assigned with a logging function of your choice, allowing
	// you to use whatever logging approach you would like. It takes
	// precedence over Logger.
	LogFunc func(format string, args ...interface{})

	// Interrupted is true if the server is handling a shutdown request
	// (a SIGINT or SIGTERM signal, or a call to Stop) and is thus shutting
	// down.
	// NOTE(review): it is written by the interrupt-handling goroutine
	// without synchronization, so reading it concurrently is a data race.
	Interrupted bool

	// interrupt signals the listener to stop serving connections,
	// and the server to shut down. Created lazily (buffer of 1) by
	// interruptChan.
	interrupt chan os.Signal

	// stopLock is used to protect against concurrent calls to Stop. It
	// also serializes the forwarded ConnState callback, the grace-period
	// wait in shutdown, and the forced close in manageConnections.
	stopLock sync.Mutex

	// stopChan is the channel on which callers may block while waiting for
	// the server to stop. It is created lazily by StopChan and closed by
	// shutdown.
	stopChan chan struct{}

	// chanLock is used to protect access to the various channel constructors
	// (lazy creation of interrupt and stopChan, and the close of stopChan).
	chanLock sync.RWMutex

	// connections holds all connections managed by graceful. Owned by the
	// manageConnections goroutine, which reinitializes it when it starts.
	connections map[net.Conn]struct{}

	// idleConnections holds all idle connections managed by graceful.
	// Owned by the manageConnections goroutine, like connections.
	idleConnections map[net.Conn]struct{}
}
89
90// Run serves the http.Handler with graceful shutdown enabled.
91//
92// timeout is the duration to wait until killing active requests and stopping the server.
93// If timeout is 0, the server never times out. It waits for all active requests to finish.
94func Run(addr string, timeout time.Duration, n http.Handler) {
95	srv := &Server{
96		Timeout:      timeout,
97		TCPKeepAlive: 3 * time.Minute,
98		Server:       &http.Server{Addr: addr, Handler: n},
99		// Logger:       DefaultLogger(),
100	}
101
102	if err := srv.ListenAndServe(); err != nil {
103		if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
104			srv.logf("%s", err)
105			os.Exit(1)
106		}
107	}
108
109}
110
111// RunWithErr is an alternative version of Run function which can return error.
112//
113// Unlike Run this version will not exit the program if an error is encountered but will
114// return it instead.
115func RunWithErr(addr string, timeout time.Duration, n http.Handler) error {
116	srv := &Server{
117		Timeout:      timeout,
118		TCPKeepAlive: 3 * time.Minute,
119		Server:       &http.Server{Addr: addr, Handler: n},
120		Logger:       DefaultLogger(),
121	}
122
123	return srv.ListenAndServe()
124}
125
126// ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled.
127//
128// timeout is the duration to wait until killing active requests and stopping the server.
129// If timeout is 0, the server never times out. It waits for all active requests to finish.
130func ListenAndServe(server *http.Server, timeout time.Duration) error {
131	srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
132	return srv.ListenAndServe()
133}
134
135// ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled.
136func (srv *Server) ListenAndServe() error {
137	// Create the listener so we can control their lifetime
138	addr := srv.Addr
139	if addr == "" {
140		addr = ":http"
141	}
142	conn, err := srv.newTCPListener(addr)
143	if err != nil {
144		return err
145	}
146
147	return srv.Serve(conn)
148}
149
150// ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled.
151//
152// timeout is the duration to wait until killing active requests and stopping the server.
153// If timeout is 0, the server never times out. It waits for all active requests to finish.
154func ListenAndServeTLS(server *http.Server, certFile, keyFile string, timeout time.Duration) error {
155	srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
156	return srv.ListenAndServeTLS(certFile, keyFile)
157}
158
159// ListenTLS is a convenience method that creates an https listener using the
160// provided cert and key files. Use this method if you need access to the
161// listener object directly. When ready, pass it to the Serve method.
162func (srv *Server) ListenTLS(certFile, keyFile string) (net.Listener, error) {
163	// Create the listener ourselves so we can control its lifetime
164	addr := srv.Addr
165	if addr == "" {
166		addr = ":https"
167	}
168
169	config := &tls.Config{}
170	if srv.TLSConfig != nil {
171		*config = *srv.TLSConfig
172	}
173
174	var err error
175	if certFile != "" && keyFile != "" {
176		config.Certificates = make([]tls.Certificate, 1)
177		config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
178		if err != nil {
179			return nil, err
180		}
181	}
182
183	// Enable http2
184	enableHTTP2ForTLSConfig(config)
185
186	conn, err := srv.newTCPListener(addr)
187	if err != nil {
188		return nil, err
189	}
190
191	srv.TLSConfig = config
192
193	tlsListener := tls.NewListener(conn, config)
194	return tlsListener, nil
195}
196
// enableHTTP2ForTLSConfig explicitly enables http/2 for a TLS Config by
// appending "h2" to t.NextProtos when it is not already present. This is due
// to changes in Go 1.7 where http servers are no longer automatically
// configured to enable http/2 if the server's TLSConfig is set.
// See https://github.com/golang/go/issues/15908
//
// t must be non-nil; it is modified in place.
func enableHTTP2ForTLSConfig(t *tls.Config) {
	if TLSConfigHasHTTP2Enabled(t) {
		return
	}
	t.NextProtos = append(t.NextProtos, "h2")
}

// TLSConfigHasHTTP2Enabled reports whether t has http2 enabled, i.e. whether
// "h2" appears in t.NextProtos. A nil config reports false.
func TLSConfigHasHTTP2Enabled(t *tls.Config) bool {
	if t == nil {
		return false
	}
	for _, proto := range t.NextProtos {
		if proto == "h2" {
			return true
		}
	}
	return false
}
218
219// ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled.
220func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
221	l, err := srv.ListenTLS(certFile, keyFile)
222	if err != nil {
223		return err
224	}
225
226	return srv.Serve(l)
227}
228
229// ListenAndServeTLSConfig can be used with an existing TLS config and is equivalent to
230// http.Server.ListenAndServeTLS with graceful shutdown enabled,
231func (srv *Server) ListenAndServeTLSConfig(config *tls.Config) error {
232	addr := srv.Addr
233	if addr == "" {
234		addr = ":https"
235	}
236
237	conn, err := srv.newTCPListener(addr)
238	if err != nil {
239		return err
240	}
241
242	srv.TLSConfig = config
243
244	tlsListener := tls.NewListener(conn, config)
245	return srv.Serve(tlsListener)
246}
247
248// Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
249//
250// timeout is the duration to wait until killing active requests and stopping the server.
251// If timeout is 0, the server never times out. It waits for all active requests to finish.
252func Serve(server *http.Server, l net.Listener, timeout time.Duration) error {
253	srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
254
255	return srv.Serve(l)
256}
257
// Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
//
// It replaces srv.Server.ConnState with a hook that reports every connection
// state change to the manageConnections goroutine and then forwards it to
// srv.ConnState. It starts the interrupt handler and blocks in
// http.Server.Serve until the listener is closed or serving fails.
//
// It then runs shutdown, which waits for tracked connections to finish
// (bounded by srv.Timeout when positive) and closes the stop channel. If the
// interrupt handler closed the listener, the error from http.Server.Serve is
// discarded and nil is returned.
//
// NOTE(review): Serve looks unsafe to call twice on the same Server, for
// three reasons:
//   - StopChan hands back the channel a previous run already closed.
//   - shutdown closes it a second time.
//   - Interrupted is never reset after a completed shutdown.
// Confirm before reusing a Server value.
func (srv *Server) Serve(listener net.Listener) error {

	// Optionally cap the listener (LimitListener is defined elsewhere).
	if srv.ListenLimit != 0 {
		listener = LimitListener(listener, srv.ListenLimit)
	}

	// Make our stopchan now, so StopChan callers and shutdown share one channel.
	srv.StopChan()

	// Track connection state. These channels are unbuffered, so each
	// ConnState callback blocks until manageConnections receives the change.
	// NOTE(review): once manageConnections has returned, nothing receives on
	// them, so any later callback would block forever.
	add := make(chan net.Conn)
	idle := make(chan net.Conn)
	active := make(chan net.Conn)
	remove := make(chan net.Conn)

	srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
		switch state {
		case http.StateNew:
			add <- conn
		case http.StateActive:
			active <- conn
		case http.StateIdle:
			idle <- conn
		case http.StateClosed, http.StateHijacked:
			remove <- conn
		}

		// The user's callback runs under stopLock. A callback that calls Stop
		// would therefore deadlock. While shutdown holds stopLock for the
		// grace period, callbacks wait here (after the tracking send above
		// has already completed).
		srv.stopLock.Lock()
		defer srv.stopLock.Unlock()

		if srv.ConnState != nil {
			srv.ConnState(conn, state)
		}
	}

	// Manage open connections. shutdown carries the channel on which
	// manageConnections reports that all connections are gone; kill is closed
	// to force-close whatever remains.
	shutdown := make(chan chan struct{})
	kill := make(chan struct{})
	go srv.manageConnections(add, idle, active, remove, shutdown, kill)

	interrupt := srv.interruptChan()
	// Set up the interrupt handler
	if !srv.NoSignalHandling {
		signalNotify(interrupt)
	}
	// handleInterrupt closes quitting just before it closes the listener.
	quitting := make(chan struct{})
	go srv.handleInterrupt(interrupt, quitting, listener)

	// Serve with graceful listener.
	// Execution blocks here until the listener is closed (handleInterrupt
	// does this when a shutdown request is accepted) or serving fails.
	err := srv.Server.Serve(listener)
	if err != nil {
		// If the underlying listening is closed, Serve returns an error
		// complaining about listening on a closed socket. This is expected, so
		// let's ignore the error if we are the ones who explicitly closed the
		// socket.
		select {
		case <-quitting:
			err = nil
		default:
		}
	}

	// Drain (or, after srv.Timeout, force-close) connections and close the
	// stop channel.
	srv.shutdown(shutdown, kill)

	return err
}
326
327// Stop instructs the type to halt operations and close
328// the stop channel when it is finished.
329//
330// timeout is grace period for which to wait before shutting
331// down the server. The timeout value passed here will override the
332// timeout given when constructing the server, as this is an explicit
333// command to stop the server.
334func (srv *Server) Stop(timeout time.Duration) {
335	srv.stopLock.Lock()
336	defer srv.stopLock.Unlock()
337
338	srv.Timeout = timeout
339	sendSignalInt(srv.interruptChan())
340}
341
342// StopChan gets the stop channel which will block until
343// stopping has completed, at which point it is closed.
344// Callers should never close the stop channel.
345func (srv *Server) StopChan() <-chan struct{} {
346	srv.chanLock.Lock()
347	defer srv.chanLock.Unlock()
348
349	if srv.stopChan == nil {
350		srv.stopChan = make(chan struct{})
351	}
352	return srv.stopChan
353}
354
// DefaultLogger returns a new logger that writes to STDERR with the
// "[graceful] " prefix and no date/time flags. RunWithErr, ListenAndServe,
// ListenAndServeTLS and Serve install it on the Server they construct.
func DefaultLogger() *log.Logger {
	const prefix = "[graceful] "
	return log.New(os.Stderr, prefix, 0)
}
360
// manageConnections is the goroutine that owns srv.connections and
// srv.idleConnections. It applies the connection state changes sent by the
// ConnState hook that Serve installs, and it drives the shutdown protocol:
//
//   - add / idle / active / remove update the two connection sets.
//   - A value received on shutdown is the channel on which to report
//     completion. Idle connections are closed immediately. Completion is
//     reported at once if nothing is tracked, or else when the last tracked
//     connection is removed.
//   - kill (closed by shutdown when the grace period expires) force-closes
//     every tracked connection and ends the goroutine.
func (srv *Server) manageConnections(add, idle, active, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) {
	// done stays nil until a shutdown has been requested.
	var done chan struct{}
	srv.connections = map[net.Conn]struct{}{}
	srv.idleConnections = map[net.Conn]struct{}{}
	for {
		select {
		case conn := <-add:
			srv.connections[conn] = struct{}{}
			srv.idleConnections[conn] = struct{}{} // Newly-added connections are considered idle until they become active.
		case conn := <-idle:
			srv.idleConnections[conn] = struct{}{}
		case conn := <-active:
			delete(srv.idleConnections, conn)
		case conn := <-remove:
			delete(srv.connections, conn)
			delete(srv.idleConnections, conn)
			// During shutdown, the last connection going away completes it.
			// NOTE(review): this send blocks unless done is buffered or
			// shutdown() is still waiting on it.
			if done != nil && len(srv.connections) == 0 {
				done <- struct{}{}
				return
			}
		case done = <-shutdown:
			if len(srv.connections) == 0 && len(srv.idleConnections) == 0 {
				done <- struct{}{}
				return
			}
			// a shutdown request has been received. if we have open idle
			// connections, we must close all of them now. this prevents idle
			// connections from holding the server open while waiting for them to
			// hit their idle timeout.
			for k := range srv.idleConnections {
				if err := k.Close(); err != nil {
					srv.logf("[ERROR] %s", err)
				}
			}
		case <-kill:
			// Grace period expired: force-close everything still tracked.
			// stopLock is taken so this does not interleave with the
			// ConnState wrapper's forwarding. The deferred unlock is fine
			// here because the case returns.
			srv.stopLock.Lock()
			defer srv.stopLock.Unlock()

			// Stop further ConnState callbacks: nothing will receive on the
			// tracking channels after this goroutine returns.
			// NOTE(review): this write is not synchronized with net/http's
			// reads of ConnState; likely flagged by -race.
			srv.Server.ConnState = nil
			for k := range srv.connections {
				if err := k.Close(); err != nil {
					srv.logf("[ERROR] %s", err)
				}
			}
			return
		}
	}
}
409
410func (srv *Server) interruptChan() chan os.Signal {
411	srv.chanLock.Lock()
412	defer srv.chanLock.Unlock()
413
414	if srv.interrupt == nil {
415		srv.interrupt = make(chan os.Signal, 1)
416	}
417
418	return srv.interrupt
419}
420
421func (srv *Server) handleInterrupt(interrupt chan os.Signal, quitting chan struct{}, listener net.Listener) {
422	for _ = range interrupt {
423		if srv.Interrupted {
424			srv.logf("already shutting down")
425			continue
426		}
427		srv.logf("shutdown initiated")
428		srv.Interrupted = true
429		if srv.BeforeShutdown != nil {
430			if !srv.BeforeShutdown() {
431				srv.Interrupted = false
432				continue
433			}
434		}
435
436		close(quitting)
437		srv.SetKeepAlivesEnabled(false)
438		if err := listener.Close(); err != nil {
439			srv.logf("[ERROR] %s", err)
440		}
441
442		if srv.ShutdownInitiated != nil {
443			srv.ShutdownInitiated()
444		}
445	}
446}
447
448func (srv *Server) logf(format string, args ...interface{}) {
449	if srv.LogFunc != nil {
450		srv.LogFunc(format, args...)
451	} else if srv.Logger != nil {
452		srv.Logger.Printf(format, args...)
453	}
454}
455
456func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) {
457	// Request done notification
458	done := make(chan struct{})
459	shutdown <- done
460
461	srv.stopLock.Lock()
462	defer srv.stopLock.Unlock()
463	if srv.Timeout > 0 {
464		select {
465		case <-done:
466		case <-time.After(srv.Timeout):
467			close(kill)
468		}
469	} else {
470		<-done
471	}
472	// Close the stopChan to wake up any blocked goroutines.
473	srv.chanLock.Lock()
474	if srv.stopChan != nil {
475		close(srv.stopChan)
476	}
477	srv.chanLock.Unlock()
478}
479
480func (srv *Server) newTCPListener(addr string) (net.Listener, error) {
481	conn, err := net.Listen("tcp", addr)
482	if err != nil {
483		return conn, err
484	}
485	if srv.TCPKeepAlive != 0 {
486		conn = keepAliveListener{conn, srv.TCPKeepAlive}
487	}
488	return conn, nil
489}
490