1// +build go1.7
2
3package ec2_test
4
5import (
6	"bytes"
7	"context"
8	"io/ioutil"
9	"net/http"
10	"net/url"
11	"regexp"
12	"testing"
13
14	"github.com/aws/aws-sdk-go/aws"
15	sdkclient "github.com/aws/aws-sdk-go/aws/client"
16	"github.com/aws/aws-sdk-go/aws/request"
17	"github.com/aws/aws-sdk-go/awstesting/unit"
18	"github.com/aws/aws-sdk-go/service/ec2"
19)
20
21func TestCopySnapshotPresignedURL(t *testing.T) {
22	svc := ec2.New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
23
24	func() {
25		defer func() {
26			if r := recover(); r != nil {
27				t.Fatalf("expect CopySnapshotRequest with nill")
28			}
29		}()
30		// Doesn't panic on nil input
31		req, _ := svc.CopySnapshotRequest(nil)
32		req.Sign()
33	}()
34
35	req, _ := svc.CopySnapshotRequest(&ec2.CopySnapshotInput{
36		SourceRegion:     aws.String("us-west-1"),
37		SourceSnapshotId: aws.String("snap-id"),
38	})
39	req.Sign()
40
41	b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
42	q, _ := url.ParseQuery(string(b))
43	u, _ := url.QueryUnescape(q.Get("PresignedUrl"))
44	if e, a := "us-west-2", q.Get("DestinationRegion"); e != a {
45		t.Errorf("expect %v, got %v", e, a)
46	}
47	if e, a := "us-west-1", q.Get("SourceRegion"); e != a {
48		t.Errorf("expect %v, got %v", e, a)
49	}
50
51	r := regexp.MustCompile(`^https://ec2\.us-west-1\.amazonaws\.com/.+&DestinationRegion=us-west-2`)
52	if !r.MatchString(u) {
53		t.Errorf("expect %v to match, got %v", r.String(), u)
54	}
55}
56
57func TestNoCustomRetryerWithMaxRetries(t *testing.T) {
58	cases := map[string]struct {
59		Config           aws.Config
60		ExpectMaxRetries int
61	}{
62		"With custom retrier": {
63			Config: aws.Config{
64				Retryer: sdkclient.DefaultRetryer{
65					NumMaxRetries: 10,
66				},
67			},
68			ExpectMaxRetries: 10,
69		},
70		"with max retries": {
71			Config: aws.Config{
72				MaxRetries: aws.Int(10),
73			},
74			ExpectMaxRetries: 10,
75		},
76		"no options set": {
77			ExpectMaxRetries: sdkclient.DefaultRetryerMaxNumRetries,
78		},
79	}
80
81	for name, c := range cases {
82		t.Run(name, func(t *testing.T) {
83			client := ec2.New(unit.Session, &aws.Config{
84				DisableParamValidation: aws.Bool(true),
85			}, c.Config.Copy())
86			client.ModifyNetworkInterfaceAttributeWithContext(context.Background(), nil, checkRetryerMaxRetries(t, c.ExpectMaxRetries))
87			client.AssignPrivateIpAddressesWithContext(context.Background(), nil, checkRetryerMaxRetries(t, c.ExpectMaxRetries))
88		})
89	}
90
91}
92
93func checkRetryerMaxRetries(t *testing.T, maxRetries int) func(*request.Request) {
94	return func(r *request.Request) {
95		r.Handlers.Send.Clear()
96		r.Handlers.Send.PushBack(func(rr *request.Request) {
97			if e, a := maxRetries, rr.Retryer.MaxRetries(); e != a {
98				t.Errorf("%s, expect %v max retries, got %v", rr.Operation.Name, e, a)
99			}
100			rr.HTTPResponse = &http.Response{
101				StatusCode: 200,
102				Header:     http.Header{},
103				Body:       ioutil.NopCloser(&bytes.Buffer{}),
104			}
105		})
106	}
107}
108