1// +build codegen
2
3package main
4
5import (
6	"bytes"
7	"encoding/json"
8	"fmt"
9	"net/url"
10	"os"
11	"os/exec"
12	"reflect"
13	"regexp"
14	"sort"
15	"strconv"
16	"strings"
17	"text/template"
18
19	"github.com/aws/aws-sdk-go/private/model/api"
20	"github.com/aws/aws-sdk-go/private/util"
21)
22
// Test suite kinds. generateTestSuite picks one per fixture file (via getType)
// based on whether the fixture path contains "output/".
const (
	// TestSuiteTypeInput marks a suite whose cases check request serialization.
	TestSuiteTypeInput = iota
	// TestSuiteTypeOutput marks a suite whose cases check response unmarshaling.
	TestSuiteTypeOutput
)
29
// testSuite is one suite from a protocol test fixture file: the API model the
// cases run against (embedded, so its fields and methods are promoted) plus the
// cases themselves. UnmarshalJSON fills in defaults for ClientEndpoint and for
// each case's InputTest Host and URI.
type testSuite struct {
	*api.API
	// Description is the suite's human-readable name; TestSuite derives title
	// from it.
	Description    string
	// ClientEndpoint is the endpoint the generated input tests configure the
	// client with (defaulted to "https://test" by UnmarshalJSON).
	ClientEndpoint string
	// Cases are the individual test cases of this suite.
	Cases          []testCase
	// Type is TestSuiteTypeInput or TestSuiteTypeOutput; assigned by
	// generateTestSuite via getType.
	Type           uint
	// title is Description reduced to an identifier fragment. It is set by
	// TestSuite and embedded in the generated test function names.
	title          string
}
38
func (s *testSuite) UnmarshalJSON(p []byte) error {
	// decoded has testSuite's fields but none of its methods, so decoding
	// into it does not recurse back into this UnmarshalJSON.
	type stub testSuite

	var decoded stub
	if err := json.Unmarshal(p, &decoded); err != nil {
		return err
	}

	// Fill in defaults the fixtures are allowed to omit.
	if decoded.ClientEndpoint == "" {
		decoded.ClientEndpoint = "https://test"
	}
	for idx := range decoded.Cases {
		in := &decoded.Cases[idx].InputTest
		if in.Host == "" {
			in.Host = "test"
		}
		if in.URI == "" {
			in.URI = "/"
		}
	}

	*s = testSuite(decoded)
	return nil
}
62
// testCase is a single case within a testSuite fixture.
type testCase struct {
	// TestSuite points back at the owning suite; testSuite.TestSuite sets it
	// before the case is rendered.
	TestSuite  *testSuite
	// Given is the operation model the case exercises.
	Given      *api.Operation
	// Params holds the operation input decoded from the fixture; only input
	// suites use it.
	Params     interface{}     `json:",omitempty"`
	// Data holds the expected unmarshaled result (fixture key "result"); only
	// output suites use it.
	Data       interface{}     `json:"result,omitempty"`
	// InputTest is the expected serialized request (fixture key "serialized").
	InputTest  testExpectation `json:"serialized"`
	// OutputTest is the canned HTTP response (fixture key "response").
	OutputTest testExpectation `json:"response"`
}
71
// testExpectation describes one HTTP message from a fixture. For input suites
// it is the request the SDK is expected to serialize; for output suites it is
// the response fed to the unmarshalers. encoding/json matches fixture keys to
// these fields case-insensitively, except StatusCode, which reads
// "status_code".
type testExpectation struct {
	// Host and URI form the expected request URL; Body is the message body.
	Host, URI, Body string
	// Headers are the message headers. JSONValues maps JSONValue member names
	// to Go source literals, and testCase.TestCase fills it in.
	Headers, JSONValues map[string]string
	// StatusCode is the HTTP status code given in the fixture.
	StatusCode uint `json:"status_code"`
}
80
// preamble is written once into the generated test file, immediately after its
// import block (see generateTestSuite). Each blank-identifier declaration
// references a package the generated file imports (most are added by
// addImports), so the file still compiles when no generated test happens to use
// that package. The init function replaces protocol.RandReader with
// awstesting.ZeroReader. NOTE(review): presumably this makes any randomly
// generated request values deterministic in tests; TODO confirm.
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
var _ = aws.String
var _ = fmt.Println
var _ = reflect.Value{}

func init() {
	protocol.RandReader = &awstesting.ZeroReader{}
}
`
100
101var reStripSpace = regexp.MustCompile(`\s(\w)`)
102
103var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
104
105func removeImports(code string) string {
106	return reImportRemoval.ReplaceAllString(code, "")
107}
108
109var extraImports = []string{
110	"bytes",
111	"encoding/json",
112	"encoding/xml",
113	"fmt",
114	"io",
115	"io/ioutil",
116	"net/http",
117	"testing",
118	"time",
119	"reflect",
120	"net/url",
121	"",
122	"github.com/aws/aws-sdk-go/awstesting",
123	"github.com/aws/aws-sdk-go/awstesting/unit",
124	"github.com/aws/aws-sdk-go/private/protocol",
125	"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
126	"github.com/aws/aws-sdk-go/private/util",
127}
128
129func addImports(code string) string {
130	importNames := make([]string, len(extraImports))
131	for i, n := range extraImports {
132		if n != "" {
133			importNames[i] = fmt.Sprintf("%q", n)
134		}
135	}
136
137	str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)")
138	return str
139}
140
141func (t *testSuite) TestSuite() string {
142	var buf bytes.Buffer
143
144	t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
145		return strings.ToUpper(x[1:])
146	})
147	t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
148
149	for idx, c := range t.Cases {
150		c.TestSuite = t
151		buf.WriteString(c.TestCase(idx) + "\n")
152	}
153	return buf.String()
154}
155
// tplInputTestCase renders one request-serialization test. The test:
//   - builds the operation request, using ParamsString as the input (plus any
//     JSONValues assignments), or nil when ParamsString is empty;
//   - runs req.Build;
//   - asserts the body (via BodyAssertions, only when an expected body exists);
//   - asserts the URL and then each expected header.
//
// Header names and values are rendered with printf "%q", so they become valid
// Go string literals even when a fixture value contains quotes or backslashes.
var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
	svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("{{ .TestCase.TestSuite.ClientEndpoint  }}")})
	{{ if ne .ParamsString "" }}input := {{ .ParamsString }}
	{{ range $k, $v := .JSONValues -}}
	input.{{ $k }} = {{ $v }}
	{{ end -}}
	req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
	r := req.HTTPRequest

	// build request
	req.Build()
	if req.Error != nil {
		t.Errorf("expect no error, got %v", req.Error)
	}

	{{ if ne .TestCase.InputTest.Body "" }}// assert body
	if r.Body == nil {
		t.Errorf("expect body not to be nil")
	}
	{{ .BodyAssertions }}{{ end }}

	// assert URL
	awstesting.AssertURL(t, "https://{{ .TestCase.InputTest.Host }}{{ .TestCase.InputTest.URI }}", r.URL.String())

	// assert headers
	{{ range $k, $v := .TestCase.InputTest.Headers -}}
		if e, a := {{ printf "%q" $v }}, r.Header.Get({{ printf "%q" $k }}); e != a {
			t.Errorf("expect %v, got %v", e, a)
		}
	{{ end }}
}
`))
189
// tplInputTestCaseData is the data passed to tplInputTestCase.
type tplInputTestCaseData struct {
	// TestCase is the case being rendered.
	TestCase             *testCase
	// JSONValues maps JSONValue member names to Go literals; the template
	// assigns each one onto the input after constructing it.
	JSONValues           map[string]string
	// OpName is the suffix of the generated Test function name. ParamsString
	// is the Go literal for the operation input; "" makes the template pass a
	// nil input instead.
	OpName, ParamsString string
}
195
// BodyAssertions returns Go source that reads the built request body from r
// and checks it against the case's expected serialized body.
//
// For rest-xml the body is read through util.SortXML, and testCase.TestCase
// sorts the expected body the same way. Every other protocol reads the body
// with ioutil.ReadAll. The comparison then depends on the protocol:
//   - query/ec2 use awstesting.AssertQuery;
//   - XML-looking (rest-xml) and JSON-looking (json, jsonrpc, rest-json)
//     bodies use awstesting.AssertXML and awstesting.AssertJSON;
//   - everything else is compared against a quoted Go string literal through
//     fmtAssertEqual.
//
// NOTE(review): expected bodies containing a backtick would break the
// raw-string literals emitted for the Assert* helpers.
func (t tplInputTestCaseData) BodyAssertions() string {
	code := &bytes.Buffer{}
	protocol := t.TestCase.TestSuite.API.Metadata.Protocol

	// Extract the body bytes
	switch protocol {
	case "rest-xml":
		fmt.Fprintln(code, "body := util.SortXML(r.Body)")
	default:
		fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
	}

	// Generate the body verification code
	expectedBody := util.Trim(t.TestCase.InputTest.Body)
	// quotedBody is expectedBody as an interpreted Go string literal, used
	// wherever the body is compared by plain string equality.
	quotedBody := fmt.Sprintf("%q", expectedBody)
	switch protocol {
	case "ec2", "query":
		fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
			expectedBody)
	case "rest-xml":
		if strings.HasPrefix(expectedBody, "<") {
			fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(body))",
				expectedBody)
		} else {
			code.WriteString(fmtAssertEqual(quotedBody, "util.Trim(string(body))"))
		}
	case "json", "jsonrpc", "rest-json":
		if strings.HasPrefix(expectedBody, "{") {
			fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
				expectedBody)
		} else {
			code.WriteString(fmtAssertEqual(quotedBody, "util.Trim(string(body))"))
		}
	default:
		// The expected body is raw payload text, not Go source, so it must be
		// quoted like the other fallbacks or the generated test won't compile.
		code.WriteString(fmtAssertEqual(quotedBody, "util.Trim(string(body))"))
	}

	return code.String()
}
234
235func fmtAssertEqual(e, a string) string {
236	const format = `if e, a := %s, %s; e != a {
237		t.Errorf("expect %%v, got %%v", e, a)
238	}
239	`
240
241	return fmt.Sprintf(format, e, a)
242}
243
244func fmtAssertNil(v string) string {
245	const format = `if e := %s; e != nil {
246		t.Errorf("expect nil, got %%v", e)
247	}
248	`
249
250	return fmt.Sprintf(format, v)
251}
252
// tplOutputTestCase renders one response-unmarshaling test. The test builds the
// operation request with a nil input and attaches an HTTP response carrying
// Body (already a quoted Go literal) and the fixture headers. It then runs the
// UnmarshalMeta and Unmarshal handlers and appends Assertions.
//
// Header names and values are rendered with printf "%q", so they become valid
// Go string literals even when a fixture value contains quotes or backslashes.
//
// NOTE(review): the response status is always 200 and the endpoint always
// "https://test". The fixture's status_code and the suite's ClientEndpoint are
// not used here.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
	svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})

	buf := bytes.NewReader([]byte({{ .Body }}))
	req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
	req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}

	// set headers
	{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set({{ printf "%q" $k }}, {{ printf "%q" $v }})
	{{ end }}

	// unmarshal response
	req.Handlers.UnmarshalMeta.Run(req)
	req.Handlers.Unmarshal.Run(req)
	if req.Error != nil {
		t.Errorf("expect not error, got %v", req.Error)
	}

	// assert response
	if out == nil {
		t.Errorf("expect not to be nil")
	}
	{{ .Assertions }}
}
`))
279
// tplOutputTestCaseData is the data passed to tplOutputTestCase.
type tplOutputTestCaseData struct {
	// TestCase is the case being rendered.
	TestCase                 *testCase
	// Body is the canned response body as a quoted Go string literal, OpName
	// is the suffix of the generated Test function name, and Assertions is
	// the Go source (from GenerateAssertions) that checks the unmarshaled
	// output.
	Body, OpName, Assertions string
}
284
// TestCase renders the generated Go test function for the case at 0-based
// position idx within its suite.
//
// Input suites render tplInputTestCase and output suites render
// tplOutputTestCase. For input suites, rendering also mutates i:
//   - InputTest.Body is normalized for query/ec2 and rest-xml;
//   - JSONValue members are moved out of Params into InputTest.JSONValues as
//     Go source literals.
//
// Template execution errors panic.
func (i *testCase) TestCase(idx int) string {
	var buf bytes.Buffer

	baseName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
	// The templates emit "func Test<OpName>", and go test only picks up
	// functions whose name continues with an upper-case letter after "Test".
	testName := strings.ToUpper(baseName[0:1]) + baseName[1:]

	switch i.TestSuite.Type {
	case TestSuiteTypeInput:
		// Normalize the expected body. url.Values.Encode sorts query/ec2 form
		// values by key, and rest-xml bodies go through util.SortXML, the same
		// helper the generated rest-xml assertions apply to the actual body.
		switch i.TestSuite.API.Metadata.Protocol {
		case "query", "ec2":
			form, _ := url.ParseQuery(i.InputTest.Body)
			i.InputTest.Body = form.Encode()
		case "rest-xml":
			i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
		}

		// JSONValue members are assigned separately by the template, so split
		// them out of the parameters that build the input struct literal.
		jsonMembers := buildJSONValues(i.Given.InputRef.Shape)
		params := i.Params
		if raw, isMap := i.Params.(map[string]interface{}); isMap {
			plain := map[string]interface{}{}
			for name, value := range raw {
				if _, isJSON := jsonMembers[name]; !isJSON {
					plain[name] = value
					continue
				}
				if i.InputTest.JSONValues == nil {
					i.InputTest.JSONValues = map[string]string{}
				}
				i.InputTest.JSONValues[name] = serializeJSONValue(value.(map[string]interface{}))
			}
			params = plain
		}

		data := tplInputTestCaseData{
			TestCase:     i,
			OpName:       testName,
			ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false),
			JSONValues:   i.InputTest.JSONValues,
		}
		if err := tplInputTestCase.Execute(&buf, data); err != nil {
			panic(err)
		}

	case TestSuiteTypeOutput:
		data := tplOutputTestCaseData{
			TestCase:   i,
			Body:       fmt.Sprintf("%q", i.OutputTest.Body),
			OpName:     testName,
			Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
		}
		if err := tplOutputTestCase.Execute(&buf, data); err != nil {
			panic(err)
		}
	}

	return buf.String()
}
345
// serializeJSONValue renders m as a Go composite literal of type aws.JSONValue
// for use in generated test code.
func serializeJSONValue(m map[string]interface{}) string {
	return "aws.JSONValue" + walkMap(m)
}

// walkMap renders m as the brace-delimited body of a Go map composite literal,
// `{"key":value,\n...}`. The caller supplies the map type prefix. Keys are
// emitted in sorted order so the generated code is identical from run to run.
func walkMap(m map[string]interface{}) string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b bytes.Buffer
	b.WriteString("{")
	for _, k := range keys {
		b.WriteString(strconv.Quote(k))
		b.WriteString(":")
		b.WriteString(walkValue(m[k]))
		b.WriteString(",\n")
	}
	b.WriteString("}")
	return b.String()
}

// walkValue renders one decoded-JSON value as a Go expression. When stored in
// an interface{}, the expression has the same dynamic type encoding/json
// produces: nil, bool, string, float64, map[string]interface{} or
// []interface{}. int is also accepted for hand-built inputs.
func walkValue(v interface{}) string {
	switch tv := v.(type) {
	case nil:
		return "nil"
	case bool:
		return strconv.FormatBool(tv)
	case string:
		return strconv.Quote(tv)
	case int:
		return strconv.Itoa(tv)
	case float64:
		// Shortest exact representation, forced to be a floating-point
		// constant so it stays float64 inside an interface{}.
		lit := strconv.FormatFloat(tv, 'g', -1, 64)
		if !strings.ContainsAny(lit, ".eE") {
			lit += ".0"
		}
		return lit
	case map[string]interface{}:
		// Nested values need an explicit type: the element type of the
		// enclosing literal is interface{}, so it cannot be elided.
		return "map[string]interface{}" + walkMap(tv)
	case []interface{}:
		elems := make([]string, len(tv))
		for i, e := range tv {
			elems[i] = walkValue(e)
		}
		return "[]interface{}{" + strings.Join(elems, ", ") + "}"
	default:
		return fmt.Sprintf("%#v", tv)
	}
}
372
373func buildJSONValues(shape *api.Shape) map[string]struct{} {
374	keys := map[string]struct{}{}
375	for key, field := range shape.MemberRefs {
376		if field.JSONValue {
377			keys[key] = struct{}{}
378		}
379	}
380	return keys
381}
382
// generateTestSuite generates a protocol test suite for a given configuration
// JSON protocol test file.
//
// filename names a JSON array of test suites. A path containing "output/"
// produces output (response unmarshaling) tests; any other path produces input
// (request serialization) tests. The result is a complete Go test file with,
// in order:
//   - the generated-code header and package clause;
//   - the first suite's import block (with extraImports added) and preamble;
//   - each suite's service and API code, with renamed constructors and shapes;
//   - all of the generated test functions.
//
// Any failure panics.
func generateTestSuite(filename string) string {
	inout := "Input"
	if strings.Contains(filename, "output/") {
		inout = "Output"
	}

	var suites []testSuite
	f, err := os.Open(filename)
	if err != nil {
		panic(err)
	}
	defer f.Close()

	if err := json.NewDecoder(f).Decode(&suites); err != nil {
		panic(err)
	}
	if len(suites) == 0 {
		// The package clause below is derived from the first suite.
		panic(fmt.Sprintf("no test suites found in %s", filename))
	}

	var buf bytes.Buffer
	buf.WriteString("// Code generated by models/protocol_tests/generate.go. DO NOT EDIT.\n\n")
	buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")

	var innerBuf bytes.Buffer
	innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")

	for i, suite := range suites {
		svcPrefix := inout + "Service" + strconv.Itoa(i+1)
		suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
		suite.API.Operations = map[string]*api.Operation{}
		for idx, c := range suite.Cases {
			c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
			suite.API.Operations[c.Given.ExportedName] = c.Given
		}

		suite.Type = getType(inout)
		suite.API.NoInitMethods = true       // don't generate init methods
		suite.API.NoStringerMethods = true   // don't generate stringer methods
		suite.API.NoConstServiceNames = true // don't generate service names
		suite.API.Setup()
		suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
		suite.API.Metadata.EndpointsID = suite.API.Metadata.EndpointPrefix

		// Sort in order for deterministic test generation
		names := make([]string, 0, len(suite.API.Shapes))
		for n := range suite.API.Shapes {
			names = append(names, n)
		}
		sort.Strings(names)
		for _, name := range names {
			s := suite.API.Shapes[name]
			s.Rename(svcPrefix + "TestShape" + name)
		}

		svcCode := addImports(suite.API.ServiceGoCode())
		if i == 0 {
			// The first suite's import block, with the extra imports added,
			// becomes the file's only import block. Every suite's own blocks
			// are stripped by removeImports below.
			importMatch := reImportRemoval.FindStringSubmatch(svcCode)
			if importMatch == nil {
				panic("generated service code has no import block")
			}
			buf.WriteString(importMatch[0] + "\n\n")
			buf.WriteString(preamble + "\n\n")
		}
		svcCode = removeImports(svcCode)
		// Several suites share one package, so give each client constructor a
		// suite-specific name.
		svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
		svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
		svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
		buf.WriteString(svcCode + "\n\n")

		apiCode := removeImports(suite.API.APIGoCode())
		apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
		apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
		apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
		buf.WriteString(apiCode + "\n\n")

		innerBuf.WriteString(suite.TestSuite() + "\n")
	}

	return buf.String() + innerBuf.String()
}
460
// findMember returns the name of the member of shape that equals key under
// Unicode case folding (strings.EqualFold), or "" when no member matches. If
// several members match, which one is returned depends on map iteration order.
func findMember(shape *api.Shape, key string) string {
	for name := range shape.MemberRefs {
		if !strings.EqualFold(key, name) {
			continue
		}
		return name
	}
	return ""
}
470
// GenerateAssertions builds assertions for a shape based on its type.
//
// The shape's recursive values also will have assertions generated for them.
//
// The parameters are:
//   - out: the expected value as decoded from the fixture JSON;
//   - shape: the model describing out;
//   - prefix: the Go expression that yields the actual value in the generated
//     test (testCase.TestCase passes "out").
//
// Assertions are built as follows:
//   - map and list shapes recurse per element;
//   - structures recurse per member, matching fixture keys to member names
//     case-insensitively via findMember;
//   - jsonvalue shapes are compared whole with reflect.DeepEqual;
//   - scalars become fmtAssertEqual or fmtAssertNil checks.
//
// A nil shape yields no assertions. A fixture key with no matching member
// panics.
func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
	if shape == nil {
		return ""
	}
	switch t := out.(type) {
	case map[string]interface{}:
		keys := util.SortedKeys(t)

		code := ""
		switch shape.Type {
		case "map":
			for _, k := range keys {
				s := shape.ValueRef.Shape
				// Quote the key so keys containing '"' or '\' stay valid Go.
				code += GenerateAssertions(t[k], s, prefix+"["+strconv.Quote(k)+"]")
			}
		case "jsonvalue":
			// Emit a real check: a bare reflect.DeepEqual call would discard its
			// result and assert nothing. DeepEqual treats distinct named types
			// as unequal, so the actual value is converted to the plain map
			// type. NOTE(review): this assumes the generated field's underlying
			// type is map[string]interface{}, as for aws.JSONValue used by
			// serializeJSONValue. TODO confirm.
			code += fmt.Sprintf("if e, a := map[string]interface{}%s, map[string]interface{}(%s); !reflect.DeepEqual(e, a) {\n\t\tt.Errorf(\"expect %%v, got %%v\", e, a)\n\t}\n\t",
				walkMap(t), prefix)
		default:
			for _, k := range keys {
				m := findMember(shape, k)
				ref, ok := shape.MemberRefs[m]
				if !ok {
					panic(fmt.Sprintf("GenerateAssertions: no member of %s matches fixture key %q", prefix, k))
				}
				code += GenerateAssertions(t[k], ref.Shape, prefix+"."+m)
			}
		}
		return code
	case []interface{}:
		code := ""
		for i, v := range t {
			s := shape.MemberRef.Shape
			code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
		}
		return code
	default:
		switch shape.Type {
		case "timestamp":
			return fmtAssertEqual(
				fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out),
				fmt.Sprintf("%s.UTC().String()", prefix),
			)
		case "blob":
			return fmtAssertEqual(
				fmt.Sprintf("%#v", out),
				fmt.Sprintf("string(%s)", prefix),
			)
		case "integer", "long":
			return fmtAssertEqual(
				fmt.Sprintf("int64(%#v)", out),
				fmt.Sprintf("*%s", prefix),
			)
		default:
			// A null fixture value decodes to an untyped nil.
			if !reflect.ValueOf(out).IsValid() {
				return fmtAssertNil(prefix)
			}
			return fmtAssertEqual(
				fmt.Sprintf("%#v", out),
				fmt.Sprintf("*%s", prefix),
			)
		}
	}
}
535
536func getType(t string) uint {
537	switch t {
538	case "Input":
539		return TestSuiteTypeInput
540	case "Output":
541		return TestSuiteTypeOutput
542	default:
543		panic("Invalid type for test suite")
544	}
545}
546
// main generates one protocol test file.
//
// os.Args[1] is the fixture JSON file. When exactly one more argument is given,
// the formatted result is written to that path and `gofmt -s -w` is run on it.
// Otherwise the generated code is printed to stdout. A non-empty
// AWS_SDK_CODEGEN_DEBUG environment variable enables api debug logging to
// stdout.
func main() {
	if len(os.Getenv("AWS_SDK_CODEGEN_DEBUG")) != 0 {
		api.LogDebug(os.Stdout)
	}

	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: generate <protocol-test-fixture.json> [output.go]")
		os.Exit(1)
	}

	fmt.Println("Generating test suite", os.Args[1:])
	out := generateTestSuite(os.Args[1])
	if len(os.Args) != 3 {
		fmt.Println(out)
		return
	}

	f, err := os.Create(os.Args[2])
	if err != nil {
		panic(err)
	}
	if _, err := f.WriteString(util.GoFmt(out)); err != nil {
		f.Close()
		panic(err)
	}
	// Close explicitly and check the error, so a failed write-back surfaces
	// before gofmt rewrites the file.
	if err := f.Close(); err != nil {
		panic(err)
	}

	c := exec.Command("gofmt", "-s", "-w", os.Args[2])
	if err := c.Run(); err != nil {
		panic(err)
	}
}
571