1// Copyright (c) 2015-2019 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
2// resty source code and usage is governed by a MIT style
3// license that can be found in the LICENSE file.
4
5package resty
6
7import (
8	"errors"
9	"fmt"
10	"net"
11	"net/http"
12	"strings"
13)
14
type (
	// RedirectPolicy regulates the redirects followed by the resty client.
	// Objects implementing the RedirectPolicy interface can be registered
	// with SetRedirectPolicy, e.g.
	// 		resty.SetRedirectPolicy(NoRedirectPolicy())
	//
	// Apply should return nil to continue the redirect journey; otherwise it
	// should return an error to stop the redirect.
	RedirectPolicy interface {
		Apply(req *http.Request, via []*http.Request) error
	}

	// RedirectPolicyFunc is an adapter that allows the use of ordinary
	// functions as a RedirectPolicy. If f is a function with the appropriate
	// signature, RedirectPolicyFunc(f) is a RedirectPolicy that calls f.
	RedirectPolicyFunc func(*http.Request, []*http.Request) error
)

// Compile-time check that RedirectPolicyFunc satisfies RedirectPolicy.
var _ RedirectPolicy = RedirectPolicyFunc(nil)

// Apply calls f(req, via) and returns its result unchanged.
func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error {
	return f(req, via)
}
34
35// NoRedirectPolicy is used to disable redirects in the HTTP client
36// 		resty.SetRedirectPolicy(NoRedirectPolicy())
37func NoRedirectPolicy() RedirectPolicy {
38	return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
39		return errors.New("auto redirect is disabled")
40	})
41}
42
43// FlexibleRedirectPolicy is convenient method to create No of redirect policy for HTTP client.
44// 		resty.SetRedirectPolicy(FlexibleRedirectPolicy(20))
45func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy {
46	return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
47		if len(via) >= noOfRedirect {
48			return fmt.Errorf("stopped after %d redirects", noOfRedirect)
49		}
50		checkHostAndAddHeaders(req, via[0])
51		return nil
52	})
53}
54
55// DomainCheckRedirectPolicy is convenient method to define domain name redirect rule in resty client.
56// Redirect is allowed for only mentioned host in the policy.
57// 		resty.SetRedirectPolicy(DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net"))
58func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy {
59	hosts := make(map[string]bool)
60	for _, h := range hostnames {
61		hosts[strings.ToLower(h)] = true
62	}
63
64	fn := RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
65		if ok := hosts[getHostname(req.URL.Host)]; !ok {
66			return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy")
67		}
68
69		return nil
70	})
71
72	return fn
73}
74
75//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
76// Package Unexported methods
77//_______________________________________________________________________
78
// getHostname returns the lower-cased hostname part of a URL host value such
// as "host", "host:port", "[ipv6]" or "[ipv6]:port".
//
// If host contains a colon but cannot be split into host and port, it is
// kept as-is instead of collapsing to "". Collapsing would make distinct
// unparsable hosts (e.g. port-less IPv6 literals) compare equal. Brackets
// around an IPv6 literal are removed so "[::1]" and "[::1]:8080" agree.
func getHostname(host string) string {
	if strings.Index(host, ":") > 0 {
		if h, _, err := net.SplitHostPort(host); err == nil {
			host = h
		}
	}
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return strings.ToLower(host)
}
86
87// By default Golang will not redirect request headers
88// after go throughing various discussion comments from thread
89// https://github.com/golang/go/issues/4800
90// Resty will add all the headers during a redirect for the same host
91func checkHostAndAddHeaders(cur *http.Request, pre *http.Request) {
92	curHostname := getHostname(cur.URL.Host)
93	preHostname := getHostname(pre.URL.Host)
94	if strings.EqualFold(curHostname, preHostname) {
95		for key, val := range pre.Header {
96			cur.Header[key] = val
97		}
98	} else { // only library User-Agent header is added
99		cur.Header.Set(hdrUserAgentKey, hdrUserAgentValue)
100	}
101}
102