1package rds
2
3import (
4	"fmt"
5	"io/ioutil"
6	"net/url"
7	"regexp"
8	"strings"
9	"testing"
10	"time"
11
12	"github.com/aws/aws-sdk-go/aws"
13	"github.com/aws/aws-sdk-go/aws/request"
14	"github.com/aws/aws-sdk-go/awstesting"
15	"github.com/aws/aws-sdk-go/awstesting/unit"
16)
17
18func TestPresignWithPresignNotSet(t *testing.T) {
19	reqs := map[string]*request.Request{}
20	svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
21
22	f := func() {
23		// Doesn't panic on nil input
24		req, _ := svc.CopyDBSnapshotRequest(nil)
25		req.Sign()
26	}
27	if paniced, p := awstesting.DidPanic(f); paniced {
28		t.Errorf("expect no panic, got %v", p)
29	}
30
31	reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
32		SourceRegion:               aws.String("us-west-1"),
33		SourceDBSnapshotIdentifier: aws.String("foo"),
34		TargetDBSnapshotIdentifier: aws.String("bar"),
35	})
36
37	reqs[opCreateDBInstanceReadReplica], _ = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
38		SourceRegion:               aws.String("us-west-1"),
39		SourceDBInstanceIdentifier: aws.String("foo"),
40		DBInstanceIdentifier:       aws.String("bar"),
41	})
42
43	for op, req := range reqs {
44		req.Sign()
45		b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
46		q, _ := url.ParseQuery(string(b))
47
48		u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
49
50		exp := fmt.Sprintf(`^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=us-west-2.+`, op)
51		if re, a := regexp.MustCompile(exp), u; !re.MatchString(a) {
52			t.Errorf("expect %s to match %s", re, a)
53		}
54	}
55}
56
57func TestPresignWithPresignSet(t *testing.T) {
58	reqs := map[string]*request.Request{}
59	svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
60
61	f := func() {
62		// Doesn't panic on nil input
63		req, _ := svc.CopyDBSnapshotRequest(nil)
64		req.Sign()
65	}
66	if paniced, p := awstesting.DidPanic(f); paniced {
67		t.Errorf("expect no panic, got %v", p)
68	}
69
70	reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
71		SourceRegion:               aws.String("us-west-1"),
72		SourceDBSnapshotIdentifier: aws.String("foo"),
73		TargetDBSnapshotIdentifier: aws.String("bar"),
74		PreSignedUrl:               aws.String("presignedURL"),
75	})
76
77	reqs[opCreateDBInstanceReadReplica], _ = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
78		SourceRegion:               aws.String("us-west-1"),
79		SourceDBInstanceIdentifier: aws.String("foo"),
80		DBInstanceIdentifier:       aws.String("bar"),
81		PreSignedUrl:               aws.String("presignedURL"),
82	})
83
84	for _, req := range reqs {
85		req.Sign()
86
87		b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
88		q, _ := url.ParseQuery(string(b))
89
90		u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
91		if e, a := "presignedURL", u; !strings.Contains(a, e) {
92			t.Errorf("expect %s to be in %s", e, a)
93		}
94	}
95}
96
97func TestPresignWithSourceNotSet(t *testing.T) {
98	reqs := map[string]*request.Request{}
99	svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
100
101	f := func() {
102		// Doesn't panic on nil input
103		req, _ := svc.CopyDBSnapshotRequest(nil)
104		req.Sign()
105	}
106	if paniced, p := awstesting.DidPanic(f); paniced {
107		t.Errorf("expect no panic, got %v", p)
108	}
109
110	reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
111		SourceDBSnapshotIdentifier: aws.String("foo"),
112		TargetDBSnapshotIdentifier: aws.String("bar"),
113	})
114
115	for _, req := range reqs {
116		_, err := req.Presign(5 * time.Minute)
117		if err != nil {
118			t.Fatal(err)
119		}
120	}
121}
122