1// Copyright 2018 Adam Tauber 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package colly 16 17import ( 18 "crypto/sha1" 19 "encoding/gob" 20 "encoding/hex" 21 "io" 22 "io/ioutil" 23 "math/rand" 24 "net/http" 25 "os" 26 "path" 27 "regexp" 28 "sync" 29 "time" 30 31 "compress/gzip" 32 33 "github.com/gobwas/glob" 34) 35 36type httpBackend struct { 37 LimitRules []*LimitRule 38 Client *http.Client 39 lock *sync.RWMutex 40} 41 42// LimitRule provides connection restrictions for domains. 43// Both DomainRegexp and DomainGlob can be used to specify 44// the included domains patterns, but at least one is required. 45// There can be two kind of limitations: 46// - Parallelism: Set limit for the number of concurrent requests to matching domains 47// - Delay: Wait specified amount of time between requests (parallelism is 1 in this case) 48type LimitRule struct { 49 // DomainRegexp is a regular expression to match against domains 50 DomainRegexp string 51 // DomainRegexp is a glob pattern to match against domains 52 DomainGlob string 53 // Delay is the duration to wait before creating a new request to the matching domains 54 Delay time.Duration 55 // RandomDelay is the extra randomized duration to wait added to Delay before creating a new request 56 RandomDelay time.Duration 57 // Parallelism is the number of the maximum allowed concurrent requests of the matching domains 58 Parallelism int 59 waitChan chan bool 60 compiledRegexp *regexp.Regexp 61 compiledGlob glob.Glob 62} 63 64// Init initializes the private members of LimitRule 65func (r *LimitRule) Init() error { 66 waitChanSize := 1 67 if r.Parallelism > 1 { 68 waitChanSize = r.Parallelism 69 } 70 r.waitChan = make(chan bool, waitChanSize) 71 hasPattern := false 72 if r.DomainRegexp != "" { 73 c, err := regexp.Compile(r.DomainRegexp) 74 if err != nil { 75 return err 76 } 77 r.compiledRegexp = c 78 hasPattern = true 79 } 80 if r.DomainGlob != "" { 81 c, err := glob.Compile(r.DomainGlob) 82 if err != nil { 83 return err 84 } 85 r.compiledGlob = c 86 hasPattern = true 87 } 88 if !hasPattern { 89 return ErrNoPattern 90 } 91 return nil 92} 93 94func (h *httpBackend) Init(jar http.CookieJar) { 95 rand.Seed(time.Now().UnixNano()) 96 h.Client = &http.Client{ 97 Jar: jar, 98 Timeout: 10 * time.Second, 99 } 100 h.lock = &sync.RWMutex{} 101} 102 103// Match checks that the domain parameter triggers the rule 104func (r *LimitRule) Match(domain string) bool { 105 match := false 106 if r.compiledRegexp != nil && r.compiledRegexp.MatchString(domain) { 107 match = true 108 } 109 if r.compiledGlob != nil && r.compiledGlob.Match(domain) { 110 match = true 111 } 112 return match 113} 114 115func (h *httpBackend) GetMatchingRule(domain string) *LimitRule { 116 if h.LimitRules == nil { 117 return nil 118 } 119 h.lock.RLock() 120 defer h.lock.RUnlock() 121 for _, r := range h.LimitRules { 122 if r.Match(domain) { 123 return r 124 } 125 } 126 return nil 127} 128 129func (h *httpBackend) Cache(request *http.Request, bodySize int, cacheDir string) (*Response, error) { 130 if cacheDir == "" || request.Method != "GET" { 131 return h.Do(request, bodySize) 132 } 133 sum := sha1.Sum([]byte(request.URL.String())) 134 hash := hex.EncodeToString(sum[:]) 135 dir := path.Join(cacheDir, hash[:2]) 136 filename := path.Join(dir, hash) 137 if file, err := os.Open(filename); err == nil { 138 resp := new(Response) 139 err := gob.NewDecoder(file).Decode(resp) 140 file.Close() 141 if resp.StatusCode < 500 { 142 return resp, err 143 } 144 } 145 resp, err := h.Do(request, bodySize) 146 if err != nil || resp.StatusCode >= 500 { 147 return resp, err 148 } 149 if _, err := os.Stat(dir); err != nil { 150 if err := os.MkdirAll(dir, 0750); err != nil { 151 return resp, err 152 } 153 } 154 file, err := os.Create(filename + "~") 155 if err != nil { 156 return resp, err 157 } 158 if err := gob.NewEncoder(file).Encode(resp); err != nil { 159 file.Close() 160 return resp, err 161 } 162 file.Close() 163 return resp, os.Rename(filename+"~", filename) 164} 165 166func (h *httpBackend) Do(request *http.Request, bodySize int) (*Response, error) { 167 r := h.GetMatchingRule(request.URL.Host) 168 if r != nil { 169 r.waitChan <- true 170 defer func(r *LimitRule) { 171 randomDelay := time.Duration(0) 172 if r.RandomDelay != 0 { 173 randomDelay = time.Duration(rand.Int63n(int64(r.RandomDelay))) 174 } 175 time.Sleep(r.Delay + randomDelay) 176 <-r.waitChan 177 }(r) 178 } 179 180 res, err := h.Client.Do(request) 181 if err != nil { 182 return nil, err 183 } 184 if res.Request != nil { 185 *request = *res.Request 186 } 187 188 var bodyReader io.Reader = res.Body 189 if bodySize > 0 { 190 bodyReader = io.LimitReader(bodyReader, int64(bodySize)) 191 } 192 if !res.Uncompressed && res.Header.Get("Content-Encoding") == "gzip" { 193 bodyReader, err = gzip.NewReader(bodyReader) 194 if err != nil { 195 return nil, err 196 } 197 } 198 body, err := ioutil.ReadAll(bodyReader) 199 defer res.Body.Close() 200 if err != nil { 201 return nil, err 202 } 203 return &Response{ 204 StatusCode: res.StatusCode, 205 Body: body, 206 Headers: &res.Header, 207 }, nil 208} 209 210func (h *httpBackend) Limit(rule *LimitRule) error { 211 h.lock.Lock() 212 if h.LimitRules == nil { 213 h.LimitRules = make([]*LimitRule, 0, 8) 214 } 215 h.LimitRules = append(h.LimitRules, rule) 216 h.lock.Unlock() 217 return rule.Init() 218} 219 220func (h *httpBackend) Limits(rules []*LimitRule) error { 221 for _, r := range rules { 222 if err := h.Limit(r); err != nil { 223 return err 224 } 225 } 226 return nil 227} 228