1package protocol_test
2
3import (
4	"net/http"
5	"net/url"
6	"testing"
7
8	"github.com/aws/aws-sdk-go/aws/client/metadata"
9	"github.com/aws/aws-sdk-go/aws/request"
10	"github.com/aws/aws-sdk-go/awstesting"
11	"github.com/aws/aws-sdk-go/private/protocol"
12	"github.com/aws/aws-sdk-go/private/protocol/ec2query"
13	"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
14	"github.com/aws/aws-sdk-go/private/protocol/query"
15	"github.com/aws/aws-sdk-go/private/protocol/rest"
16	"github.com/aws/aws-sdk-go/private/protocol/restjson"
17	"github.com/aws/aws-sdk-go/private/protocol/restxml"
18)
19
// xmlData is a fill hook (see checkForLeak, which assigns it to the response
// reader's FillData field) that stamps XML framing onto a chunk of payload.
//
// If set is false, the opening tags "<B><A>" are copied to the start of b.
// If size is 0, the closing tags "</A></B>" are copied into b so that they
// end at offset delta. In every other case b is left untouched.
//
// NOTE(review): presumably set reports whether an earlier chunk was already
// produced, size is the number of bytes still outstanding and delta is the
// length of the current chunk — confirm against awstesting.ReadCloser.
//
// When size is 0, delta must be at least len("</A></B>"); a smaller delta
// makes the slice expression panic.
func xmlData(set bool, b []byte, size, delta int) {
	const (
		head = "<B><A>"
		tail = "</A></B>"
	)

	if !set {
		copy(b, head)
	}
	if size != 0 {
		return
	}
	copy(b[delta-len(tail):], tail)
}
30
// jsonData is a fill hook (see checkForLeak, which assigns it to the response
// reader's FillData field) that stamps JSON framing onto a chunk of payload.
//
// If set is false, the prefix `{"A": "` is copied to the start of b.
// If size is 0, the suffix `"}` is copied into b so that it ends at offset
// delta. In every other case b is left untouched.
//
// NOTE(review): presumably set reports whether an earlier chunk was already
// produced, size is the number of bytes still outstanding and delta is the
// length of the current chunk — confirm against awstesting.ReadCloser.
//
// When size is 0, delta must be at least 2 (the suffix length); a smaller
// delta makes the slice expression panic.
func jsonData(set bool, b []byte, size, delta int) {
	const (
		prefix = "{\"A\": \""
		suffix = "\"}"
	)

	if !set {
		copy(b, prefix)
	}
	if size != 0 {
		return
	}
	copy(b[delta-len(suffix):], suffix)
}
39
40func buildNewRequest(data interface{}) *request.Request {
41	v := url.Values{}
42	v.Set("test", "TEST")
43	v.Add("test1", "TEST1")
44
45	req := &request.Request{
46		HTTPRequest: &http.Request{
47			Header: make(http.Header),
48			Body:   &awstesting.ReadCloser{Size: 2048},
49			URL: &url.URL{
50				RawQuery: v.Encode(),
51			},
52		},
53		Params: &struct {
54			LocationName string `locationName:"test"`
55		}{
56			"Test",
57		},
58		ClientInfo: metadata.ClientInfo{
59			ServiceName:   "test",
60			TargetPrefix:  "test",
61			JSONVersion:   "test",
62			APIVersion:    "test",
63			Endpoint:      "test",
64			SigningName:   "test",
65			SigningRegion: "test",
66		},
67		Operation: &request.Operation{
68			Name: "test",
69		},
70	}
71	req.HTTPResponse = &http.Response{
72		Body: &awstesting.ReadCloser{Size: 2048},
73		Header: http.Header{
74			"X-Amzn-Requestid": []string{"1"},
75		},
76		StatusCode: http.StatusOK,
77	}
78
79	if data == nil {
80		data = &struct {
81			_            struct{} `type:"structure"`
82			LocationName *string  `locationName:"testName"`
83			Location     *string  `location:"statusCode"`
84			A            *string  `type:"string"`
85		}{}
86	}
87
88	req.Data = data
89
90	return req
91}
92
// expected describes the outcome checkForLeak should observe after running a
// build/unmarshal pair against a freshly built request.
//
// The tests construct it with positional literals, so the field order
// (dataType, closed, size, errExists) must not change.
type expected struct {
	// dataType selects the payload filler installed on the response body:
	// jsonType picks jsonData, xmlType picks xmlData.
	dataType int
	// closed is the value the response body's Closed flag should have.
	closed bool
	// size is the value the response body's Size field should have.
	size int
	// errExists reports whether the request is expected to end up with a
	// non-nil Error.
	errExists bool
}
99
// Payload flavours accepted by expected.dataType. checkForLeak uses them to
// decide which filler function frames the generated response body.
const (
	// jsonType selects the jsonData filler.
	jsonType = iota
	// xmlType selects the xmlData filler.
	xmlType
)
104
105func checkForLeak(data interface{}, build, fn func(*request.Request), t *testing.T, result expected) {
106	req := buildNewRequest(data)
107	reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
108	switch result.dataType {
109	case jsonType:
110		reader.FillData = jsonData
111	case xmlType:
112		reader.FillData = xmlData
113	}
114	build(req)
115	fn(req)
116
117	if result.errExists {
118		if err := req.Error; err == nil {
119			t.Errorf("expect error")
120		}
121	} else {
122		if err := req.Error; err != nil {
123			t.Errorf("expect nil, %v", err)
124		}
125	}
126
127	if e, a := reader.Closed, result.closed; e != a {
128		t.Errorf("expect %v, got %v", e, a)
129	}
130	if e, a := reader.Size, result.size; e != a {
131		t.Errorf("expect %v, got %v", e, a)
132	}
133}
134
135func TestJSONRpc(t *testing.T) {
136	checkForLeak(nil, jsonrpc.Build, jsonrpc.Unmarshal, t, expected{jsonType, true, 0, false})
137	checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
138	checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalError, t, expected{jsonType, true, 0, true})
139}
140
141func TestQuery(t *testing.T) {
142	checkForLeak(nil, query.Build, query.Unmarshal, t, expected{jsonType, true, 0, false})
143	checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
144	checkForLeak(nil, query.Build, query.UnmarshalError, t, expected{jsonType, true, 0, true})
145}
146
147func TestRest(t *testing.T) {
148	// case 1: Payload io.ReadSeeker
149	checkForLeak(nil, rest.Build, rest.Unmarshal, t, expected{jsonType, false, 2048, false})
150	checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
151
152	// case 2: Payload *string
153	// should close the body
154	dataStr := struct {
155		_            struct{} `type:"structure" payload:"Payload"`
156		LocationName *string  `locationName:"testName"`
157		Location     *string  `location:"statusCode"`
158		A            *string  `type:"string"`
159		Payload      *string  `locationName:"payload" type:"blob" required:"true"`
160	}{}
161	checkForLeak(&dataStr, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
162	checkForLeak(&dataStr, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
163
164	// case 3: Payload []byte
165	// should close the body
166	dataBytes := struct {
167		_            struct{} `type:"structure" payload:"Payload"`
168		LocationName *string  `locationName:"testName"`
169		Location     *string  `location:"statusCode"`
170		A            *string  `type:"string"`
171		Payload      []byte   `locationName:"payload" type:"blob" required:"true"`
172	}{}
173	checkForLeak(&dataBytes, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
174	checkForLeak(&dataBytes, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
175
176	// case 4: Payload unsupported type
177	// should close the body
178	dataUnsupported := struct {
179		_            struct{} `type:"structure" payload:"Payload"`
180		LocationName *string  `locationName:"testName"`
181		Location     *string  `location:"statusCode"`
182		A            *string  `type:"string"`
183		Payload      string   `locationName:"payload" type:"blob" required:"true"`
184	}{}
185	checkForLeak(&dataUnsupported, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, true})
186	checkForLeak(&dataUnsupported, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
187}
188
189func TestRestJSON(t *testing.T) {
190	checkForLeak(nil, restjson.Build, restjson.Unmarshal, t, expected{jsonType, true, 0, false})
191	checkForLeak(nil, restjson.Build, restjson.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
192	checkForLeak(nil, restjson.Build, restjson.UnmarshalError, t, expected{jsonType, true, 0, true})
193}
194
195func TestRestXML(t *testing.T) {
196	checkForLeak(nil, restxml.Build, restxml.Unmarshal, t, expected{xmlType, true, 0, false})
197	checkForLeak(nil, restxml.Build, restxml.UnmarshalMeta, t, expected{xmlType, false, 2048, false})
198	checkForLeak(nil, restxml.Build, restxml.UnmarshalError, t, expected{xmlType, true, 0, true})
199}
200
201func TestXML(t *testing.T) {
202	checkForLeak(nil, ec2query.Build, ec2query.Unmarshal, t, expected{jsonType, true, 0, false})
203	checkForLeak(nil, ec2query.Build, ec2query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
204	checkForLeak(nil, ec2query.Build, ec2query.UnmarshalError, t, expected{jsonType, true, 0, true})
205}
206
207func TestProtocol(t *testing.T) {
208	checkForLeak(nil, restxml.Build, protocol.UnmarshalDiscardBody, t, expected{xmlType, true, 0, false})
209}
210