1package handlers
2
3import (
4	"log"
5	"net/http"
6	"runtime/debug"
7)
8
// RecoveryHandlerLogger is the minimal logging contract needed by the
// recovery handler. Any value with a matching Println method, such as a
// *log.Logger, satisfies it.
type RecoveryHandlerLogger interface {
	// Println logs the given values.
	Println(v ...interface{})
}
13
14type recoveryHandler struct {
15	handler    http.Handler
16	logger     RecoveryHandlerLogger
17	printStack bool
18}
19
// RecoveryOption is a functional option that adjusts the configuration of
// a recovery handler, for example which logger it reports to or whether
// stack traces are printed when a panic is recovered.
type RecoveryOption func(h http.Handler)
24
25func parseRecoveryOptions(h http.Handler, opts ...RecoveryOption) http.Handler {
26	for _, option := range opts {
27		option(h)
28	}
29
30	return h
31}
32
33// RecoveryHandler is HTTP middleware that recovers from a panic,
34// logs the panic, writes http.StatusInternalServerError, and
35// continues to the next handler.
36//
37// Example:
38//
39//  r := mux.NewRouter()
40//  r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
41//  	panic("Unexpected error!")
42//  })
43//
44//  http.ListenAndServe(":1123", handlers.RecoveryHandler()(r))
45func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler {
46	return func(h http.Handler) http.Handler {
47		r := &recoveryHandler{handler: h}
48		return parseRecoveryOptions(r, opts...)
49	}
50}
51
52// RecoveryLogger is a functional option to override
53// the default logger
54func RecoveryLogger(logger RecoveryHandlerLogger) RecoveryOption {
55	return func(h http.Handler) {
56		r := h.(*recoveryHandler)
57		r.logger = logger
58	}
59}
60
61// PrintRecoveryStack is a functional option to enable
62// or disable printing stack traces on panic.
63func PrintRecoveryStack(print bool) RecoveryOption {
64	return func(h http.Handler) {
65		r := h.(*recoveryHandler)
66		r.printStack = print
67	}
68}
69
70func (h recoveryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
71	defer func() {
72		if err := recover(); err != nil {
73			w.WriteHeader(http.StatusInternalServerError)
74			h.log(err)
75		}
76	}()
77
78	h.handler.ServeHTTP(w, req)
79}
80
81func (h recoveryHandler) log(v ...interface{}) {
82	if h.logger != nil {
83		h.logger.Println(v...)
84	} else {
85		log.Println(v...)
86	}
87
88	if h.printStack {
89		debug.PrintStack()
90	}
91}
92