1// Copyright 2015 beego Author. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package context
16
17import (
18	"bytes"
19	"compress/flate"
20	"compress/gzip"
21	"compress/zlib"
22	"io"
23	"net/http"
24	"os"
25	"strconv"
26	"strings"
27	"sync"
28)
29
// Response-compression settings shared by WriteBody, WriteFile and
// ParseEncoding. InitGzip overwrites them at configuration time.
var (
	// defaultGzipMinLength is the built-in minimum body size in bytes
	// (20 bytes, the same default nginx uses).
	defaultGzipMinLength = 20

	// gzipMinLength is the smallest body, in bytes, that WriteBody will
	// compress; bodies shorter than this are written uncompressed.
	gzipMinLength = defaultGzipMinLength

	// gzipCompressLevel is the flate compression level (0-9) WriteBody uses
	// for both gzip and deflate output.
	gzipCompressLevel int

	// includedMethods is the set of upper-cased HTTP methods whose responses
	// may be compressed. getMethodOnly is true when InitGzip was given no
	// methods or only GET.
	includedMethods map[string]bool
	getMethodOnly   bool
)
41
42// InitGzip init the gzipcompress
43func InitGzip(minLength, compressLevel int, methods []string) {
44	if minLength >= 0 {
45		gzipMinLength = minLength
46	}
47	gzipCompressLevel = compressLevel
48	if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression {
49		gzipCompressLevel = flate.BestSpeed
50	}
51	getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET")
52	includedMethods = make(map[string]bool, len(methods))
53	for _, v := range methods {
54		includedMethods[strings.ToUpper(v)] = true
55	}
56}
57
// resetWriter is an io.Writer whose destination can be swapped with Reset,
// which lets a compressor be reused for another output instead of rebuilt.
type resetWriter interface {
	io.Writer
	Reset(dst io.Writer)
}

// nopResetWriter adapts a plain io.Writer to resetWriter. It performs no
// encoding: every Write goes straight to the wrapped writer.
type nopResetWriter struct {
	io.Writer
}

// Reset is a no-op; the wrapped writer is never replaced.
func (nopResetWriter) Reset(io.Writer) {}
70
71type acceptEncoder struct {
72	name                    string
73	levelEncode             func(int) resetWriter
74	customCompressLevelPool *sync.Pool
75	bestCompressionPool     *sync.Pool
76}
77
78func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
79	if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
80		return nopResetWriter{wr}
81	}
82	var rwr resetWriter
83	switch level {
84	case flate.BestSpeed:
85		rwr = ac.customCompressLevelPool.Get().(resetWriter)
86	case flate.BestCompression:
87		rwr = ac.bestCompressionPool.Get().(resetWriter)
88	default:
89		rwr = ac.levelEncode(level)
90	}
91	rwr.Reset(wr)
92	return rwr
93}
94
95func (ac acceptEncoder) put(wr resetWriter, level int) {
96	if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
97		return
98	}
99	wr.Reset(nil)
100
101	//notice
102	//compressionLevel==BestCompression DOES NOT MATTER
103	//sync.Pool will not memory leak
104
105	switch level {
106	case gzipCompressLevel:
107		ac.customCompressLevelPool.Put(wr)
108	case flate.BestCompression:
109		ac.bestCompressionPool.Put(wr)
110	}
111}
112
113var (
114	noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
115	gzipCompressEncoder = acceptEncoder{
116		name:                    "gzip",
117		levelEncode:             func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
118		customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }},
119		bestCompressionPool:     &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }},
120	}
121
122	//according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
123	//deflate
124	//The "zlib" format defined in RFC 1950 [31] in combination with
125	//the "deflate" compression mechanism described in RFC 1951 [29].
126	deflateCompressEncoder = acceptEncoder{
127		name:                    "deflate",
128		levelEncode:             func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
129		customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }},
130		bestCompressionPool:     &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }},
131	}
132)
133
134var (
135	encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore
136		"gzip":     gzipCompressEncoder,
137		"deflate":  deflateCompressEncoder,
138		"*":        gzipCompressEncoder, // * means any compress will accept,we prefer gzip
139		"identity": noneCompressEncoder, // identity means none-compress
140	}
141)
142
143// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate)
144func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) {
145	return writeLevel(encoding, writer, file, flate.BestCompression)
146}
147
148// WriteBody reads  writes content to writer by the specific encoding(gzip/deflate)
149func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
150	if encoding == "" || len(content) < gzipMinLength {
151		_, err := writer.Write(content)
152		return false, "", err
153	}
154	return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel)
155}
156
157// writeLevel reads from reader,writes to writer by specific encoding and compress level
158// the compress level is defined by deflate package
159func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) {
160	var outputWriter resetWriter
161	var err error
162	var ce = noneCompressEncoder
163
164	if cf, ok := encoderMap[encoding]; ok {
165		ce = cf
166	}
167	encoding = ce.name
168	outputWriter = ce.encode(writer, level)
169	defer ce.put(outputWriter, level)
170
171	_, err = io.Copy(outputWriter, reader)
172	if err != nil {
173		return false, "", err
174	}
175
176	switch outputWriter.(type) {
177	case io.WriteCloser:
178		outputWriter.(io.WriteCloser).Close()
179	}
180	return encoding != "", encoding, nil
181}
182
183// ParseEncoding will extract the right encoding for response
184// the Accept-Encoding's sec is here:
185// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3
186func ParseEncoding(r *http.Request) string {
187	if r == nil {
188		return ""
189	}
190	if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] {
191		return parseEncoding(r)
192	}
193	return ""
194}
195
196type q struct {
197	name  string
198	value float64
199}
200
201func parseEncoding(r *http.Request) string {
202	acceptEncoding := r.Header.Get("Accept-Encoding")
203	if acceptEncoding == "" {
204		return ""
205	}
206	var lastQ q
207	for _, v := range strings.Split(acceptEncoding, ",") {
208		v = strings.TrimSpace(v)
209		if v == "" {
210			continue
211		}
212		vs := strings.Split(v, ";")
213		var cf acceptEncoder
214		var ok bool
215		if cf, ok = encoderMap[vs[0]]; !ok {
216			continue
217		}
218		if len(vs) == 1 {
219			return cf.name
220		}
221		if len(vs) == 2 {
222			f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
223			if f == 0 {
224				continue
225			}
226			if f > lastQ.value {
227				lastQ = q{cf.name, f}
228			}
229		}
230	}
231	return lastQ.name
232}
233