1// Package httpdown provides http.ConnState enabled graceful termination of
2// http.Server.
3package httpdown
4
5import (
6	"crypto/tls"
7	"fmt"
8	"net"
9	"net/http"
10	"os"
11	"os/signal"
12	"sync"
13	"syscall"
14	"time"
15
16	"github.com/facebookgo/clock"
17	"github.com/facebookgo/stats"
18)
19
// Fallback shutdown timeouts applied by HTTP.Serve when the corresponding
// HTTP field is left at its zero value.
const (
	// defaultStopTimeout replaces a zero HTTP.StopTimeout.
	defaultStopTimeout = 1 * time.Minute
	// defaultKillTimeout replaces a zero HTTP.KillTimeout.
	defaultKillTimeout = 1 * time.Minute
)
24
// A Server encapsulates the process of accepting new connections and serving
// them, and of gracefully shutting down the listener without dropping active
// connections.
type Server interface {
	// Wait waits for the serving loop to finish. This happens either when
	// Stop is called, in which case Wait returns nil, or when the serving loop
	// fails, in which case that error is returned. You must call Wait after
	// calling Serve or ListenAndServe.
	Wait() error

	// Stop stops the listener and then waits for open connections to close.
	// Connections still open after the configured StopTimeout are force
	// closed, and Stop gives up waiting after a further KillTimeout, so it may
	// return while clients are still connected. It returns the error, if any,
	// from closing the listener.
	Stop() error
}
38
39// HTTP defines the configuration for serving a http.Server. Multiple calls to
40// Serve or ListenAndServe can be made on the same HTTP instance. The default
41// timeouts of 1 minute each result in a maximum of 2 minutes before a Stop()
42// returns.
43type HTTP struct {
44	// StopTimeout is the duration before we begin force closing connections.
45	// Defaults to 1 minute.
46	StopTimeout time.Duration
47
48	// KillTimeout is the duration before which we completely give up and abort
49	// even though we still have connected clients. This is useful when a large
50	// number of client connections exist and closing them can take a long time.
51	// Note, this is in addition to the StopTimeout. Defaults to 1 minute.
52	KillTimeout time.Duration
53
54	// Stats is optional. If provided, it will be used to record various metrics.
55	Stats stats.Client
56
57	// Clock allows for testing timing related functionality. Do not specify this
58	// in production code.
59	Clock clock.Clock
60}
61
62// Serve provides the low-level API which is useful if you're creating your own
63// net.Listener.
64func (h HTTP) Serve(s *http.Server, l net.Listener) Server {
65	stopTimeout := h.StopTimeout
66	if stopTimeout == 0 {
67		stopTimeout = defaultStopTimeout
68	}
69	killTimeout := h.KillTimeout
70	if killTimeout == 0 {
71		killTimeout = defaultKillTimeout
72	}
73	klock := h.Clock
74	if klock == nil {
75		klock = clock.New()
76	}
77
78	ss := &server{
79		stopTimeout:  stopTimeout,
80		killTimeout:  killTimeout,
81		stats:        h.Stats,
82		clock:        klock,
83		oldConnState: s.ConnState,
84		listener:     l,
85		server:       s,
86		serveDone:    make(chan struct{}),
87		serveErr:     make(chan error, 1),
88		new:          make(chan net.Conn),
89		active:       make(chan net.Conn),
90		idle:         make(chan net.Conn),
91		closed:       make(chan net.Conn),
92		stop:         make(chan chan struct{}),
93		kill:         make(chan chan struct{}),
94	}
95	s.ConnState = ss.connState
96	go ss.manage()
97	go ss.serve()
98	return ss
99}
100
101// ListenAndServe returns a Server for the given http.Server. It is equivalent
102// to ListenAndServe from the standard library, but returns immediately.
103// Requests will be accepted in a background goroutine. If the http.Server has
104// a non-nil TLSConfig, a TLS enabled listener will be setup.
105func (h HTTP) ListenAndServe(s *http.Server) (Server, error) {
106	addr := s.Addr
107	if addr == "" {
108		if s.TLSConfig == nil {
109			addr = ":http"
110		} else {
111			addr = ":https"
112		}
113	}
114	l, err := net.Listen("tcp", addr)
115	if err != nil {
116		stats.BumpSum(h.Stats, "listen.error", 1)
117		return nil, err
118	}
119	if s.TLSConfig != nil {
120		l = tls.NewListener(l, s.TLSConfig)
121	}
122	return h.Serve(s, l), nil
123}
124
125// server manages the serving process and allows for gracefully stopping it.
126type server struct {
127	stopTimeout time.Duration
128	killTimeout time.Duration
129	stats       stats.Client
130	clock       clock.Clock
131
132	oldConnState func(net.Conn, http.ConnState)
133	server       *http.Server
134	serveDone    chan struct{}
135	serveErr     chan error
136	listener     net.Listener
137
138	new    chan net.Conn
139	active chan net.Conn
140	idle   chan net.Conn
141	closed chan net.Conn
142	stop   chan chan struct{}
143	kill   chan chan struct{}
144
145	stopOnce sync.Once
146	stopErr  error
147}
148
149func (s *server) connState(c net.Conn, cs http.ConnState) {
150	if s.oldConnState != nil {
151		s.oldConnState(c, cs)
152	}
153
154	switch cs {
155	case http.StateNew:
156		s.new <- c
157	case http.StateActive:
158		s.active <- c
159	case http.StateIdle:
160		s.idle <- c
161	case http.StateHijacked, http.StateClosed:
162		s.closed <- c
163	}
164}
165
166func (s *server) manage() {
167	defer func() {
168		close(s.new)
169		close(s.active)
170		close(s.idle)
171		close(s.closed)
172		close(s.stop)
173		close(s.kill)
174	}()
175
176	var stopDone chan struct{}
177
178	conns := map[net.Conn]http.ConnState{}
179	var countNew, countActive, countIdle float64
180
181	// decConn decrements the count associated with the current state of the
182	// given connection.
183	decConn := func(c net.Conn) {
184		switch conns[c] {
185		default:
186			panic(fmt.Errorf("unknown existing connection: %s", c))
187		case http.StateNew:
188			countNew--
189		case http.StateActive:
190			countActive--
191		case http.StateIdle:
192			countIdle--
193		}
194	}
195
196	// setup a ticker to report various values every minute. if we don't have a
197	// Stats implementation provided, we Stop it so it never ticks.
198	statsTicker := s.clock.Ticker(time.Minute)
199	if s.stats == nil {
200		statsTicker.Stop()
201	}
202
203	for {
204		select {
205		case <-statsTicker.C:
206			// we'll only get here when s.stats is not nil
207			s.stats.BumpAvg("http-state.new", countNew)
208			s.stats.BumpAvg("http-state.active", countActive)
209			s.stats.BumpAvg("http-state.idle", countIdle)
210			s.stats.BumpAvg("http-state.total", countNew+countActive+countIdle)
211		case c := <-s.new:
212			conns[c] = http.StateNew
213			countNew++
214		case c := <-s.active:
215			decConn(c)
216			countActive++
217
218			conns[c] = http.StateActive
219		case c := <-s.idle:
220			decConn(c)
221			countIdle++
222
223			conns[c] = http.StateIdle
224
225			// if we're already stopping, close it
226			if stopDone != nil {
227				c.Close()
228			}
229		case c := <-s.closed:
230			stats.BumpSum(s.stats, "conn.closed", 1)
231			decConn(c)
232			delete(conns, c)
233
234			// if we're waiting to stop and are all empty, we just closed the last
235			// connection and we're done.
236			if stopDone != nil && len(conns) == 0 {
237				close(stopDone)
238				return
239			}
240		case stopDone = <-s.stop:
241			// if we're already all empty, we're already done
242			if len(conns) == 0 {
243				close(stopDone)
244				return
245			}
246
247			// close current idle connections right away
248			for c, cs := range conns {
249				if cs == http.StateIdle {
250					c.Close()
251				}
252			}
253
254			// continue the loop and wait for all the ConnState updates which will
255			// eventually close(stopDone) and return from this goroutine.
256
257		case killDone := <-s.kill:
258			// force close all connections
259			stats.BumpSum(s.stats, "kill.conn.count", float64(len(conns)))
260			for c := range conns {
261				c.Close()
262			}
263
264			// don't block the kill.
265			close(killDone)
266
267			// continue the loop and we wait for all the ConnState updates and will
268			// return from this goroutine when we're all done. otherwise we'll try to
269			// send those ConnState updates on closed channels.
270
271		}
272	}
273}
274
275func (s *server) serve() {
276	stats.BumpSum(s.stats, "serve", 1)
277	s.serveErr <- s.server.Serve(s.listener)
278	close(s.serveDone)
279	close(s.serveErr)
280}
281
282func (s *server) Wait() error {
283	if err := <-s.serveErr; !isUseOfClosedError(err) {
284		return err
285	}
286	return nil
287}
288
289func (s *server) Stop() error {
290	s.stopOnce.Do(func() {
291		defer stats.BumpTime(s.stats, "stop.time").End()
292		stats.BumpSum(s.stats, "stop", 1)
293
294		// first disable keep-alive for new connections
295		s.server.SetKeepAlivesEnabled(false)
296
297		// then close the listener so new connections can't connect come thru
298		closeErr := s.listener.Close()
299		<-s.serveDone
300
301		// then trigger the background goroutine to stop and wait for it
302		stopDone := make(chan struct{})
303		s.stop <- stopDone
304
305		// wait for stop
306		select {
307		case <-stopDone:
308		case <-s.clock.After(s.stopTimeout):
309			defer stats.BumpTime(s.stats, "kill.time").End()
310			stats.BumpSum(s.stats, "kill", 1)
311
312			// stop timed out, wait for kill
313			killDone := make(chan struct{})
314			s.kill <- killDone
315			select {
316			case <-killDone:
317			case <-s.clock.After(s.killTimeout):
318				// kill timed out, give up
319				stats.BumpSum(s.stats, "kill.timeout", 1)
320			}
321		}
322
323		if closeErr != nil && !isUseOfClosedError(closeErr) {
324			stats.BumpSum(s.stats, "listener.close.error", 1)
325			s.stopErr = closeErr
326		}
327	})
328	return s.stopErr
329}
330
// isUseOfClosedError reports whether err is the "use of closed network
// connection" error, either directly or wrapped in a *net.OpError. A nil
// error is never considered a match.
func isUseOfClosedError(err error) bool {
	const closedMsg = "use of closed network connection"
	switch e := err.(type) {
	case nil:
		return false
	case *net.OpError:
		return e.Err.Error() == closedMsg
	default:
		return e.Error() == closedMsg
	}
}
340
341// ListenAndServe is a convenience function to serve and wait for a SIGTERM
342// or SIGINT before shutting down.
343func ListenAndServe(s *http.Server, hd *HTTP) error {
344	if hd == nil {
345		hd = &HTTP{}
346	}
347	hs, err := hd.ListenAndServe(s)
348	if err != nil {
349		return err
350	}
351
352	waiterr := make(chan error, 1)
353	go func() {
354		defer close(waiterr)
355		waiterr <- hs.Wait()
356	}()
357
358	signals := make(chan os.Signal, 10)
359	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
360
361	select {
362	case err := <-waiterr:
363		if err != nil {
364			return err
365		}
366	case <-signals:
367		signal.Stop(signals)
368		if err := hs.Stop(); err != nil {
369			return err
370		}
371		if err := <-waiterr; err != nil {
372			return err
373		}
374	}
375	return nil
376}
377