1package corehandlers_test
2
3import (
4	"fmt"
5	"reflect"
6	"testing"
7
8	"github.com/aws/aws-sdk-go/aws"
9	"github.com/aws/aws-sdk-go/aws/awserr"
10	"github.com/aws/aws-sdk-go/aws/client"
11	"github.com/aws/aws-sdk-go/aws/client/metadata"
12	"github.com/aws/aws-sdk-go/aws/corehandlers"
13	"github.com/aws/aws-sdk-go/aws/request"
14	"github.com/aws/aws-sdk-go/awstesting/unit"
15	"github.com/aws/aws-sdk-go/service/kinesis"
16)
17
18var testSvc = func() *client.Client {
19	s := &client.Client{
20		Config: aws.Config{},
21		ClientInfo: metadata.ClientInfo{
22			ServiceName: "mock-service",
23			APIVersion:  "2015-01-01",
24		},
25	}
26	return s
27}()
28
29type StructShape struct {
30	_ struct{} `type:"structure"`
31
32	RequiredList   []*ConditionalStructShape          `required:"true"`
33	RequiredMap    map[string]*ConditionalStructShape `required:"true"`
34	RequiredBool   *bool                              `required:"true"`
35	OptionalStruct *ConditionalStructShape
36
37	hiddenParameter *string
38}
39
40func (s *StructShape) Validate() error {
41	invalidParams := request.ErrInvalidParams{Context: "StructShape"}
42	if s.RequiredList == nil {
43		invalidParams.Add(request.NewErrParamRequired("RequiredList"))
44	}
45	if s.RequiredMap == nil {
46		invalidParams.Add(request.NewErrParamRequired("RequiredMap"))
47	}
48	if s.RequiredBool == nil {
49		invalidParams.Add(request.NewErrParamRequired("RequiredBool"))
50	}
51	if s.RequiredList != nil {
52		for i, v := range s.RequiredList {
53			if v == nil {
54				continue
55			}
56			if err := v.Validate(); err != nil {
57				invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredList", i), err.(request.ErrInvalidParams))
58			}
59		}
60	}
61	if s.RequiredMap != nil {
62		for i, v := range s.RequiredMap {
63			if v == nil {
64				continue
65			}
66			if err := v.Validate(); err != nil {
67				invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredMap", i), err.(request.ErrInvalidParams))
68			}
69		}
70	}
71	if s.OptionalStruct != nil {
72		if err := s.OptionalStruct.Validate(); err != nil {
73			invalidParams.AddNested("OptionalStruct", err.(request.ErrInvalidParams))
74		}
75	}
76
77	if invalidParams.Len() > 0 {
78		return invalidParams
79	}
80	return nil
81}
82
83type ConditionalStructShape struct {
84	_ struct{} `type:"structure"`
85
86	Name *string `required:"true"`
87}
88
89func (s *ConditionalStructShape) Validate() error {
90	invalidParams := request.ErrInvalidParams{Context: "ConditionalStructShape"}
91	if s.Name == nil {
92		invalidParams.Add(request.NewErrParamRequired("Name"))
93	}
94
95	if invalidParams.Len() > 0 {
96		return invalidParams
97	}
98	return nil
99}
100
101func TestNoErrors(t *testing.T) {
102	input := &StructShape{
103		RequiredList: []*ConditionalStructShape{},
104		RequiredMap: map[string]*ConditionalStructShape{
105			"key1": {Name: aws.String("Name")},
106			"key2": {Name: aws.String("Name")},
107		},
108		RequiredBool:   aws.Bool(true),
109		OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
110	}
111
112	req := testSvc.NewRequest(&request.Operation{}, input, nil)
113	corehandlers.ValidateParametersHandler.Fn(req)
114	if req.Error != nil {
115		t.Fatalf("expect no error, got %v", req.Error)
116	}
117}
118
119func TestMissingRequiredParameters(t *testing.T) {
120	input := &StructShape{}
121	req := testSvc.NewRequest(&request.Operation{}, input, nil)
122	corehandlers.ValidateParametersHandler.Fn(req)
123
124	if req.Error == nil {
125		t.Fatalf("expect error")
126	}
127	if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
128		t.Errorf("expect %v, got %v", e, a)
129	}
130	if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
131		t.Errorf("expect %v, got %v", e, a)
132	}
133
134	errs := req.Error.(awserr.BatchedErrors).OrigErrs()
135	if e, a := 3, len(errs); e != a {
136		t.Errorf("expect %v, got %v", e, a)
137	}
138	if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error(); e != a {
139		t.Errorf("expect %v, got %v", e, a)
140	}
141	if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error(); e != a {
142		t.Errorf("expect %v, got %v", e, a)
143	}
144	if e, a := "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error(); e != a {
145		t.Errorf("expect %v, got %v", e, a)
146	}
147
148	if e, a := "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error(); e != a {
149		t.Errorf("expect %v, got %v", e, a)
150	}
151}
152
153func TestNestedMissingRequiredParameters(t *testing.T) {
154	input := &StructShape{
155		RequiredList: []*ConditionalStructShape{{}},
156		RequiredMap: map[string]*ConditionalStructShape{
157			"key1": {Name: aws.String("Name")},
158			"key2": {},
159		},
160		RequiredBool:   aws.Bool(true),
161		OptionalStruct: &ConditionalStructShape{},
162	}
163
164	req := testSvc.NewRequest(&request.Operation{}, input, nil)
165	corehandlers.ValidateParametersHandler.Fn(req)
166
167	if req.Error == nil {
168		t.Fatalf("expect error")
169	}
170	if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
171		t.Errorf("expect %v, got %v", e, a)
172	}
173	if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
174		t.Errorf("expect %v, got %v", e, a)
175	}
176
177	errs := req.Error.(awserr.BatchedErrors).OrigErrs()
178	if e, a := 3, len(errs); e != a {
179		t.Errorf("expect %v, got %v", e, a)
180	}
181	if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error(); e != a {
182		t.Errorf("expect %v, got %v", e, a)
183	}
184	if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error(); e != a {
185		t.Errorf("expect %v, got %v", e, a)
186	}
187	if e, a := "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error(); e != a {
188		t.Errorf("expect %v, got %v", e, a)
189	}
190}
191
192type testInput struct {
193	StringField *string           `min:"5"`
194	ListField   []string          `min:"3"`
195	MapField    map[string]string `min:"4"`
196}
197
198func (s testInput) Validate() error {
199	invalidParams := request.ErrInvalidParams{Context: "testInput"}
200	if s.StringField != nil && len(*s.StringField) < 5 {
201		invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
202	}
203	if s.ListField != nil && len(s.ListField) < 3 {
204		invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
205	}
206	if s.MapField != nil && len(s.MapField) < 4 {
207		invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
208	}
209
210	if invalidParams.Len() > 0 {
211		return invalidParams
212	}
213	return nil
214}
215
216var testsFieldMin = []struct {
217	err awserr.Error
218	in  testInput
219}{
220	{
221		err: func() awserr.Error {
222			invalidParams := request.ErrInvalidParams{Context: "testInput"}
223			invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
224			return invalidParams
225		}(),
226		in: testInput{StringField: aws.String("abcd")},
227	},
228	{
229		err: func() awserr.Error {
230			invalidParams := request.ErrInvalidParams{Context: "testInput"}
231			invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
232			invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
233			return invalidParams
234		}(),
235		in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}},
236	},
237	{
238		err: func() awserr.Error {
239			invalidParams := request.ErrInvalidParams{Context: "testInput"}
240			invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
241			invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
242			invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
243			return invalidParams
244		}(),
245		in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
246	},
247	{
248		err: nil,
249		in: testInput{StringField: aws.String("abcde"),
250			ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}},
251	},
252}
253
254func TestValidateFieldMinParameter(t *testing.T) {
255	for i, c := range testsFieldMin {
256		req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
257		corehandlers.ValidateParametersHandler.Fn(req)
258
259		if e, a := c.err, req.Error; !reflect.DeepEqual(e, a) {
260			t.Errorf("%d, expect %v, got %v", i, e, a)
261		}
262	}
263}
264
265func BenchmarkValidateAny(b *testing.B) {
266	input := &kinesis.PutRecordsInput{
267		StreamName: aws.String("stream"),
268	}
269	for i := 0; i < 100; i++ {
270		record := &kinesis.PutRecordsRequestEntry{
271			Data:         make([]byte, 10000),
272			PartitionKey: aws.String("partition"),
273		}
274		input.Records = append(input.Records, record)
275	}
276
277	req, _ := kinesis.New(unit.Session).PutRecordsRequest(input)
278
279	b.ResetTimer()
280	for i := 0; i < b.N; i++ {
281		corehandlers.ValidateParametersHandler.Fn(req)
282		if err := req.Error; err != nil {
283			b.Fatalf("validation failed: %v", err)
284		}
285	}
286}
287