// Copyright 2018 Adam Tauber
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package colly

import (
	"compress/gzip"
	"crypto/sha1"
	"encoding/gob"
	"encoding/hex"
	"io"
	"io/ioutil"
	"math/rand"
	"net/http"
	"os"
	"path"
	"regexp"
	"sync"
	"time"

	"github.com/gobwas/glob"
)

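// httpBackend manages the HTTP client, the registered rate limiting rules and
// the optional on-disk response cache used by the collector.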
type httpBackend struct {
	LimitRules []*LimitRule
	Client     *http.Client
	lock       *sync.RWMutex
}

// LimitRule provides connection restrictions for domains.
// Both DomainRegexp and DomainGlob can be used to specify
// the domain patterns the rule applies to, but at least one of them is required.
// There are two kinds of limitations:
//  - Parallelism: the maximum number of concurrent requests to the matching domains
//  - Delay: the amount of time to wait between requests (parallelism is 1 in this case)
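//
// A rule is usually registered through the collector's Limit method;
// a minimal sketch (error checking omitted):
//
//	c := colly.NewCollector()
//	c.Limit(&colly.LimitRule{
//		DomainGlob:  "*.example.com",
//		Parallelism: 2,
//		RandomDelay: 5 * time.Second,
//	})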
type LimitRule struct {
	// DomainRegexp is a regular expression to match against domains
	DomainRegexp string
	// DomainGlob is a glob pattern to match against domains
	DomainGlob string
	// Delay is the duration to wait before creating a new request to the matching domains
	Delay time.Duration
	// RandomDelay is an extra randomized duration, added to Delay, to wait before creating a new request
	RandomDelay time.Duration
	// Parallelism is the maximum number of allowed concurrent requests to the matching domains
	Parallelism    int
	waitChan       chan bool
	compiledRegexp *regexp.Regexp
	compiledGlob   glob.Glob
}

// Init initializes the private members of LimitRule
func (r *LimitRule) Init() error {
	// waitChan acts as a semaphore that bounds the number of concurrent
	// requests; without an explicit Parallelism it allows one request at a time.
	waitChanSize := 1
	if r.Parallelism > 1 {
		waitChanSize = r.Parallelism
	}
	r.waitChan = make(chan bool, waitChanSize)
	hasPattern := false
	if r.DomainRegexp != "" {
		c, err := regexp.Compile(r.DomainRegexp)
		if err != nil {
			return err
		}
		r.compiledRegexp = c
		hasPattern = true
	}
	if r.DomainGlob != "" {
		c, err := glob.Compile(r.DomainGlob)
		if err != nil {
			return err
		}
		r.compiledGlob = c
		hasPattern = true
	}
	if !hasPattern {
		return ErrNoPattern
	}
	return nil
}

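// Init sets up the backend with a default HTTP client that uses the given
// cookie jar and a 10 second timeout, and seeds the random number generator
// used for LimitRule.RandomDelay.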
func (h *httpBackend) Init(jar http.CookieJar) {
	rand.Seed(time.Now().UnixNano())
	h.Client = &http.Client{
		Jar:     jar,
		Timeout: 10 * time.Second,
	}
	h.lock = &sync.RWMutex{}
}

// Match checks whether the given domain matches the rule
func (r *LimitRule) Match(domain string) bool {
	match := false
	if r.compiledRegexp != nil && r.compiledRegexp.MatchString(domain) {
		match = true
	}
	if r.compiledGlob != nil && r.compiledGlob.Match(domain) {
		match = true
	}
	return match
}

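// GetMatchingRule returns the first registered LimitRule that matches the
// given domain, or nil if none matches.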
func (h *httpBackend) GetMatchingRule(domain string) *LimitRule {
	if h.LimitRules == nil {
		return nil
	}
	h.lock.RLock()
	defer h.lock.RUnlock()
	for _, r := range h.LimitRules {
		if r.Match(domain) {
			return r
		}
	}
	return nil
}

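// Cache performs the request through the on-disk cache rooted at cacheDir.
// Cached GET responses are served from disk; responses with status codes of
// 500 or above are refetched and never stored. Non-GET requests, or an empty
// cacheDir, bypass the cache entirely.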
func (h *httpBackend) Cache(request *http.Request, bodySize int, cacheDir string) (*Response, error) {
	if cacheDir == "" || request.Method != "GET" {
		return h.Do(request, bodySize)
	}
	// Cache entries are stored in subdirectories keyed by the first two
	// characters of the SHA-1 hash of the request URL.
	sum := sha1.Sum([]byte(request.URL.String()))
	hash := hex.EncodeToString(sum[:])
	dir := path.Join(cacheDir, hash[:2])
	filename := path.Join(dir, hash)
	if file, err := os.Open(filename); err == nil {
		resp := new(Response)
		err := gob.NewDecoder(file).Decode(resp)
		file.Close()
		if resp.StatusCode < 500 {
			return resp, err
		}
	}
	resp, err := h.Do(request, bodySize)
	if err != nil || resp.StatusCode >= 500 {
		return resp, err
	}
	if _, err := os.Stat(dir); err != nil {
		if err := os.MkdirAll(dir, 0750); err != nil {
			return resp, err
		}
	}
	// Write to a temporary file first, then rename it into place so readers
	// never see a partially written cache entry.
	file, err := os.Create(filename + "~")
	if err != nil {
		return resp, err
	}
	if err := gob.NewEncoder(file).Encode(resp); err != nil {
		file.Close()
		return resp, err
	}
	file.Close()
	return resp, os.Rename(filename+"~", filename)
}

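// Do executes the request, honoring any matching LimitRule (parallelism and
// delay) and transparently decompressing gzip-encoded response bodies. A
// positive bodySize caps the number of body bytes read.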
func (h *httpBackend) Do(request *http.Request, bodySize int) (*Response, error) {
	r := h.GetMatchingRule(request.URL.Host)
	if r != nil {
		// Acquire a slot from the rule's semaphore; release it after the
		// configured delay so the rate limit is respected.
		r.waitChan <- true
		defer func(r *LimitRule) {
			randomDelay := time.Duration(0)
			if r.RandomDelay != 0 {
				randomDelay = time.Duration(rand.Int63n(int64(r.RandomDelay)))
			}
			time.Sleep(r.Delay + randomDelay)
			<-r.waitChan
		}(r)
	}

	res, err := h.Client.Do(request)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	if res.Request != nil {
		// Reflect any redirects followed by the client in the original request.
		*request = *res.Request
	}

	var bodyReader io.Reader = res.Body
	if bodySize > 0 {
		bodyReader = io.LimitReader(bodyReader, int64(bodySize))
	}
	if !res.Uncompressed && res.Header.Get("Content-Encoding") == "gzip" {
		bodyReader, err = gzip.NewReader(bodyReader)
		if err != nil {
			return nil, err
		}
	}
	body, err := ioutil.ReadAll(bodyReader)
	if err != nil {
		return nil, err
	}
	return &Response{
		StatusCode: res.StatusCode,
		Body:       body,
		Headers:    &res.Header,
	}, nil
}

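// Limit registers rule on the backend and initializes it.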
func (h *httpBackend) Limit(rule *LimitRule) error {
	h.lock.Lock()
	if h.LimitRules == nil {
		h.LimitRules = make([]*LimitRule, 0, 8)
	}
	h.LimitRules = append(h.LimitRules, rule)
	h.lock.Unlock()
	return rule.Init()
}

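// Limits registers multiple rules, stopping at the first one that fails to
// initialize.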
func (h *httpBackend) Limits(rules []*LimitRule) error {
	for _, r := range rules {
		if err := h.Limit(r); err != nil {
			return err
		}
	}
	return nil
}