1// Copyright 2015 beego Author. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package context 16 17import ( 18 "bytes" 19 "compress/flate" 20 "compress/gzip" 21 "compress/zlib" 22 "io" 23 "net/http" 24 "os" 25 "strconv" 26 "strings" 27 "sync" 28) 29 30var ( 31 //Default size==20B same as nginx 32 defaultGzipMinLength = 20 33 //Content will only be compressed if content length is either unknown or greater than gzipMinLength. 34 gzipMinLength = defaultGzipMinLength 35 //The compression level used for deflate compression. (0-9). 36 gzipCompressLevel int 37 //List of HTTP methods to compress. If not set, only GET requests are compressed. 38 includedMethods map[string]bool 39 getMethodOnly bool 40) 41 42// InitGzip init the gzipcompress 43func InitGzip(minLength, compressLevel int, methods []string) { 44 if minLength >= 0 { 45 gzipMinLength = minLength 46 } 47 gzipCompressLevel = compressLevel 48 if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { 49 gzipCompressLevel = flate.BestSpeed 50 } 51 getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") 52 includedMethods = make(map[string]bool, len(methods)) 53 for _, v := range methods { 54 includedMethods[strings.ToUpper(v)] = true 55 } 56} 57 58type resetWriter interface { 59 io.Writer 60 Reset(w io.Writer) 61} 62 63type nopResetWriter struct { 64 io.Writer 65} 66 67func (n nopResetWriter) Reset(w io.Writer) { 68 //do nothing 69} 70 71type acceptEncoder struct { 72 name string 73 levelEncode func(int) resetWriter 74 customCompressLevelPool *sync.Pool 75 bestCompressionPool *sync.Pool 76} 77 78func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { 79 if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { 80 return nopResetWriter{wr} 81 } 82 var rwr resetWriter 83 switch level { 84 case flate.BestSpeed: 85 rwr = ac.customCompressLevelPool.Get().(resetWriter) 86 case flate.BestCompression: 87 rwr = ac.bestCompressionPool.Get().(resetWriter) 88 default: 89 rwr = ac.levelEncode(level) 90 } 91 rwr.Reset(wr) 92 return rwr 93} 94 95func (ac acceptEncoder) put(wr resetWriter, level int) { 96 if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { 97 return 98 } 99 wr.Reset(nil) 100 101 //notice 102 //compressionLevel==BestCompression DOES NOT MATTER 103 //sync.Pool will not memory leak 104 105 switch level { 106 case gzipCompressLevel: 107 ac.customCompressLevelPool.Put(wr) 108 case flate.BestCompression: 109 ac.bestCompressionPool.Put(wr) 110 } 111} 112 113var ( 114 noneCompressEncoder = acceptEncoder{"", nil, nil, nil} 115 gzipCompressEncoder = acceptEncoder{ 116 name: "gzip", 117 levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, 118 customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, 119 bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, 120 } 121 122 //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed 123 //deflate 124 //The "zlib" format defined in RFC 1950 [31] in combination with 125 //the "deflate" compression mechanism described in RFC 1951 [29]. 126 deflateCompressEncoder = acceptEncoder{ 127 name: "deflate", 128 levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, 129 customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, 130 bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, 131 } 132) 133 134var ( 135 encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore 136 "gzip": gzipCompressEncoder, 137 "deflate": deflateCompressEncoder, 138 "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip 139 "identity": noneCompressEncoder, // identity means none-compress 140 } 141) 142 143// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) 144func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { 145 return writeLevel(encoding, writer, file, flate.BestCompression) 146} 147 148// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) 149func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { 150 if encoding == "" || len(content) < gzipMinLength { 151 _, err := writer.Write(content) 152 return false, "", err 153 } 154 return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) 155} 156 157// writeLevel reads from reader,writes to writer by specific encoding and compress level 158// the compress level is defined by deflate package 159func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { 160 var outputWriter resetWriter 161 var err error 162 var ce = noneCompressEncoder 163 164 if cf, ok := encoderMap[encoding]; ok { 165 ce = cf 166 } 167 encoding = ce.name 168 outputWriter = ce.encode(writer, level) 169 defer ce.put(outputWriter, level) 170 171 _, err = io.Copy(outputWriter, reader) 172 if err != nil { 173 return false, "", err 174 } 175 176 switch outputWriter.(type) { 177 case io.WriteCloser: 178 outputWriter.(io.WriteCloser).Close() 179 } 180 return encoding != "", encoding, nil 181} 182 183// ParseEncoding will extract the right encoding for response 184// the Accept-Encoding's sec is here: 185// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 186func ParseEncoding(r *http.Request) string { 187 if r == nil { 188 return "" 189 } 190 if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { 191 return parseEncoding(r) 192 } 193 return "" 194} 195 196type q struct { 197 name string 198 value float64 199} 200 201func parseEncoding(r *http.Request) string { 202 acceptEncoding := r.Header.Get("Accept-Encoding") 203 if acceptEncoding == "" { 204 return "" 205 } 206 var lastQ q 207 for _, v := range strings.Split(acceptEncoding, ",") { 208 v = strings.TrimSpace(v) 209 if v == "" { 210 continue 211 } 212 vs := strings.Split(v, ";") 213 var cf acceptEncoder 214 var ok bool 215 if cf, ok = encoderMap[vs[0]]; !ok { 216 continue 217 } 218 if len(vs) == 1 { 219 return cf.name 220 } 221 if len(vs) == 2 { 222 f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) 223 if f == 0 { 224 continue 225 } 226 if f > lastQ.value { 227 lastQ = q{cf.name, f} 228 } 229 } 230 } 231 return lastQ.name 232} 233