1// Copyright 2013 Martini Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"): you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations
13// under the License.
14
15package macaron
16
17import (
18	"bufio"
19	"errors"
20	"net"
21	"net/http"
22)
23
24// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about
25// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter
26// if the functionality calls for it.
27type ResponseWriter interface {
28	http.ResponseWriter
29	http.Flusher
30	// Status returns the status code of the response or 0 if the response has not been written.
31	Status() int
32	// Written returns whether or not the ResponseWriter has been written.
33	Written() bool
34	// Size returns the size of the response body.
35	Size() int
36	// Before allows for a function to be called before the ResponseWriter has been written to. This is
37	// useful for setting headers or any other operations that must happen before a response has been written.
38	Before(BeforeFunc)
39}
40
41// BeforeFunc is a function that is called before the ResponseWriter has been written to.
42type BeforeFunc func(ResponseWriter)
43
44// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter
45func NewResponseWriter(method string, rw http.ResponseWriter) ResponseWriter {
46	return &responseWriter{method, rw, 0, 0, nil}
47}
48
49type responseWriter struct {
50	method string
51	http.ResponseWriter
52	status      int
53	size        int
54	beforeFuncs []BeforeFunc
55}
56
57func (rw *responseWriter) WriteHeader(s int) {
58	rw.callBefore()
59
60	// Avoid panic if status code is not a valid HTTP status code
61	if s < 100 || s > 999 {
62		rw.ResponseWriter.WriteHeader(500)
63		rw.status = 500
64		return
65	}
66
67	rw.ResponseWriter.WriteHeader(s)
68	rw.status = s
69}
70
71func (rw *responseWriter) Write(b []byte) (size int, err error) {
72	if !rw.Written() {
73		// The status will be StatusOK if WriteHeader has not been called yet
74		rw.WriteHeader(http.StatusOK)
75	}
76	if rw.method != "HEAD" {
77		size, err = rw.ResponseWriter.Write(b)
78		rw.size += size
79	}
80	return size, err
81}
82
83func (rw *responseWriter) Status() int {
84	return rw.status
85}
86
87func (rw *responseWriter) Size() int {
88	return rw.size
89}
90
91func (rw *responseWriter) Written() bool {
92	return rw.status != 0
93}
94
95func (rw *responseWriter) Before(before BeforeFunc) {
96	rw.beforeFuncs = append(rw.beforeFuncs, before)
97}
98
99func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
100	hijacker, ok := rw.ResponseWriter.(http.Hijacker)
101	if !ok {
102		return nil, nil, errors.New("the ResponseWriter doesn't support the Hijacker interface")
103	}
104	return hijacker.Hijack()
105}
106
107func (rw *responseWriter) callBefore() {
108	for i := len(rw.beforeFuncs) - 1; i >= 0; i-- {
109		rw.beforeFuncs[i](rw)
110	}
111}
112
113func (rw *responseWriter) Flush() {
114	flusher, ok := rw.ResponseWriter.(http.Flusher)
115	if ok {
116		flusher.Flush()
117	}
118}
119