1// +build codegen
2
3package main
4
5import (
6	"bytes"
7	"encoding/json"
8	"fmt"
9	"net/url"
10	"os"
11	"os/exec"
12	"reflect"
13	"regexp"
14	"sort"
15	"strconv"
16	"strings"
17	"text/template"
18
19	"github.com/aws/aws-sdk-go/private/model/api"
20	"github.com/aws/aws-sdk-go/private/util"
21)
22
// Test suite kinds, chosen per protocol test file by getType.
//
// TestSuiteTypeInput marks a suite of input tests: each case builds a request
// and asserts its serialized body, URL and headers.
// TestSuiteTypeOutput marks a suite of output tests: each case unmarshals a
// canned HTTP response and asserts the decoded output.
const (
	TestSuiteTypeInput = iota
	TestSuiteTypeOutput
)
29
// testSuite is one suite decoded from a protocol test JSON file: a service
// model together with the cases that exercise it.
type testSuite struct {
	// The embedded service model. Its exported fields are decoded from the
	// same JSON object as Description and Cases (presumably the suite's
	// metadata/shapes keys — confirm against the fixtures).
	*api.API
	// Description is the human-readable suite name; TestSuite derives title
	// from it.
	Description string
	// Cases are the individual test cases of the suite.
	Cases       []testCase
	// Type is TestSuiteTypeInput or TestSuiteTypeOutput, assigned by
	// generateTestSuite via getType.
	Type        uint
	// title is an identifier fragment derived from Description and used in
	// generated test function names; it is set by TestSuite.
	title       string
}
37
// testCase is a single case within a testSuite.
type testCase struct {
	// TestSuite points back to the owning suite; it is set by
	// testSuite.TestSuite immediately before the case is rendered.
	TestSuite  *testSuite
	// Given is the operation under test. generateTestSuite overwrites its
	// ExportedName so every case gets a unique operation.
	Given      *api.Operation
	// Params is the operation input of an input case, as decoded JSON.
	Params     interface{}     `json:",omitempty"`
	// Data is the expected decoded output of an output case.
	Data       interface{}     `json:"result,omitempty"`
	// InputTest is the expected serialized request of an input case.
	InputTest  testExpectation `json:"serialized"`
	// OutputTest is the canned HTTP response of an output case.
	OutputTest testExpectation `json:"response"`
}
46
// testExpectation describes an HTTP message: the expected serialized request
// of an input case, or the canned response of an output case.
type testExpectation struct {
	// Body is the message body. For input cases TestCase normalizes it per
	// protocol before rendering.
	Body       string
	// URI is the expected request path/query; the generated URL assertion
	// prefixes it with "https://test".
	URI        string
	// Headers are the expected request headers (input) or the response
	// headers to set (output).
	Headers    map[string]string
	// JSONValues maps input member names to Go source for aws.JSONValue
	// literals; TestCase fills it for members flagged as JSONValue.
	JSONValues map[string]string
	// StatusCode is decoded from "status_code" but is not read anywhere in
	// this generator; the output template always uses 200.
	StatusCode uint `json:"status_code"`
}
54
// preamble is written once into every generated test file, right after the
// import block (see generateTestSuite). The blank-identifier declarations
// reference each imported package so the file compiles even when no generated
// test happens to use one of them. The init function installs
// awstesting.ZeroReader as protocol.RandReader — presumably so any randomly
// generated request values are deterministic in tests; confirm against the
// protocol package.
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
var _ = aws.String
var _ = fmt.Println
var _ = reflect.Value{}

func init() {
	protocol.RandReader = &awstesting.ZeroReader{}
}
`
74
// reStripSpace matches one whitespace character followed by a word character.
// TestSuite replaces each match with the upper-cased word character, turning
// a description such as "foo bar" into "fooBar".
var reStripSpace = regexp.MustCompile(`\s(\w)`)
76
// reImportRemoval matches a parenthesized Go import block. The (?s) flag lets
// '.' span newlines, and the lazy .+? stops at the first ')' after
// "import (", so group 1 captures the block's entries. This assumes no ')'
// appears inside the block itself (e.g. in a comment).
var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
78
79func removeImports(code string) string {
80	return reImportRemoval.ReplaceAllString(code, "")
81}
82
// extraImports are the import paths addImports injects, in this order, ahead
// of the existing entries of a generated service file's import block. The
// empty entry is emitted as a blank line, separating the standard-library
// group from the SDK group.
var extraImports = []string{
	"bytes",
	"encoding/json",
	"encoding/xml",
	"fmt",
	"io",
	"io/ioutil",
	"net/http",
	"testing",
	"time",
	"reflect",
	"net/url",
	"",
	"github.com/aws/aws-sdk-go/awstesting",
	"github.com/aws/aws-sdk-go/awstesting/unit",
	"github.com/aws/aws-sdk-go/private/protocol",
	"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
	"github.com/aws/aws-sdk-go/private/util",
}
102
103func addImports(code string) string {
104	importNames := make([]string, len(extraImports))
105	for i, n := range extraImports {
106		if n != "" {
107			importNames[i] = fmt.Sprintf("%q", n)
108		}
109	}
110
111	str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)")
112	return str
113}
114
115func (t *testSuite) TestSuite() string {
116	var buf bytes.Buffer
117
118	t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
119		return strings.ToUpper(x[1:])
120	})
121	t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
122
123	for idx, c := range t.Cases {
124		c.TestSuite = t
125		buf.WriteString(c.TestCase(idx) + "\n")
126	}
127	return buf.String()
128}
129
// tplInputTestCase renders a Go test that builds a request for the case's
// operation (with the generated input, or nil when ParamsString is empty),
// runs the protocol package's Build handler, and then asserts:
//   - the body, when an expected body is set (via BodyAssertions);
//   - the URL, when an expected URI is set;
//   - each expected header.
// It is executed with a tplInputTestCaseData.
//
// NOTE(review): header names and values are spliced between double quotes
// without escaping, so a value containing `"` or `\` would yield broken or
// incorrect Go source; consider {{ printf "%q" $v }}.
var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
	svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
	{{ if ne .ParamsString "" }}input := {{ .ParamsString }}
	{{ range $k, $v := .JSONValues -}}
	input.{{ $k }} = {{ $v }}
	{{ end -}}
	req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
	r := req.HTTPRequest

	// build request
	{{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req)
	if req.Error != nil {
		t.Errorf("expect no error, got %v", req.Error)
	}

	{{ if ne .TestCase.InputTest.Body "" }}// assert body
	if r.Body == nil {
		t.Errorf("expect body not to be nil")
	}
	{{ .BodyAssertions }}{{ end }}

	{{ if ne .TestCase.InputTest.URI "" }}// assert URL
	awstesting.AssertURL(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }}

	// assert headers
	{{ range $k, $v := .TestCase.InputTest.Headers -}}
		if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a {
			t.Errorf("expect %v to be %v", e, a)
		}
	{{ end }}
}
`))
163
// tplInputTestCaseData is the data tplInputTestCase is executed with.
type tplInputTestCaseData struct {
	// TestCase is the case being rendered.
	TestCase             *testCase
	// JSONValues maps input member names to Go source for aws.JSONValue
	// literals that the generated test assigns onto the input.
	JSONValues           map[string]string
	// OpName is the suffix of the generated Test function name. ParamsString
	// is Go source for the operation input; when it is empty the template
	// passes nil instead.
	OpName, ParamsString string
}
169
170func (t tplInputTestCaseData) BodyAssertions() string {
171	code := &bytes.Buffer{}
172	protocol := t.TestCase.TestSuite.API.Metadata.Protocol
173
174	// Extract the body bytes
175	switch protocol {
176	case "rest-xml":
177		fmt.Fprintln(code, "body := util.SortXML(r.Body)")
178	default:
179		fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
180	}
181
182	// Generate the body verification code
183	expectedBody := util.Trim(t.TestCase.InputTest.Body)
184	switch protocol {
185	case "ec2", "query":
186		fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
187			expectedBody)
188	case "rest-xml":
189		if strings.HasPrefix(expectedBody, "<") {
190			fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})",
191				expectedBody, t.TestCase.Given.InputRef.ShapeName)
192		} else {
193			code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))"))
194		}
195	case "json", "jsonrpc", "rest-json":
196		if strings.HasPrefix(expectedBody, "{") {
197			fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
198				expectedBody)
199		} else {
200			code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))"))
201		}
202	default:
203		code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))"))
204	}
205
206	return code.String()
207}
208
// fmtAssertEqual returns Go source for an `if` statement that fails the
// enclosing test (via t.Errorf) when expression e differs from expression a.
// Both arguments are inserted verbatim as Go expressions.
func fmtAssertEqual(e, a string) string {
	head := "if e, a := " + e + ", " + a + "; e != a {\n"
	body := "\t\tt.Errorf(\"expect %v, got %v\", e, a)\n"
	tail := "\t}\n\t"
	return head + body + tail
}
217
// fmtAssertNil returns Go source for an `if` statement that fails the
// enclosing test (via t.Errorf) when expression v is not nil. v is inserted
// verbatim as a Go expression.
func fmtAssertNil(v string) string {
	var src string
	src += "if e := " + v + "; e != nil {\n"
	src += "\t\tt.Errorf(\"expect nil, got %v\", e)\n"
	src += "\t}\n\t"
	return src
}
226
// tplOutputTestCase renders a Go test that fakes an HTTP response from the
// case's body and headers and feeds it through the protocol package's
// UnmarshalMeta and Unmarshal handlers. It then checks that the output is
// non-nil and runs the generated Assertions against `out`. It is executed
// with a tplOutputTestCaseData.
//
// NOTE(review): the response status is hard-coded to 200
// (testExpectation.StatusCode is ignored), and header names/values are
// inserted between quotes without escaping.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
	svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})

	buf := bytes.NewReader([]byte({{ .Body }}))
	req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
	req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}

	// set headers
	{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
	{{ end }}

	// unmarshal response
	{{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req)
	{{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req)
	if req.Error != nil {
		t.Errorf("expect not error, got %v", req.Error)
	}

	// assert response
	if out == nil {
		t.Errorf("expect not to be nil")
	}
	{{ .Assertions }}
}
`))
253
// tplOutputTestCaseData is the data tplOutputTestCase is executed with.
type tplOutputTestCaseData struct {
	// TestCase is the case being rendered.
	TestCase                 *testCase
	// Body is the response body as a quoted Go string literal, OpName is the
	// suffix of the generated Test function name, and Assertions is Go source
	// checking the unmarshaled output (see GenerateAssertions).
	Body, OpName, Assertions string
}
258
259func (i *testCase) TestCase(idx int) string {
260	var buf bytes.Buffer
261
262	opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
263
264	if i.TestSuite.Type == TestSuiteTypeInput { // input test
265		// query test should sort body as form encoded values
266		switch i.TestSuite.API.Metadata.Protocol {
267		case "query", "ec2":
268			m, _ := url.ParseQuery(i.InputTest.Body)
269			i.InputTest.Body = m.Encode()
270		case "rest-xml":
271			i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
272		case "json", "rest-json":
273			i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1)
274		}
275
276		jsonValues := buildJSONValues(i.Given.InputRef.Shape)
277		var params interface{}
278		if m, ok := i.Params.(map[string]interface{}); ok {
279			paramsMap := map[string]interface{}{}
280			for k, v := range m {
281				if _, ok := jsonValues[k]; !ok {
282					paramsMap[k] = v
283				} else {
284					if i.InputTest.JSONValues == nil {
285						i.InputTest.JSONValues = map[string]string{}
286					}
287					i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{}))
288				}
289			}
290			params = paramsMap
291		} else {
292			params = i.Params
293		}
294		input := tplInputTestCaseData{
295			TestCase:     i,
296			OpName:       strings.ToUpper(opName[0:1]) + opName[1:],
297			ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false),
298			JSONValues:   i.InputTest.JSONValues,
299		}
300
301		if err := tplInputTestCase.Execute(&buf, input); err != nil {
302			panic(err)
303		}
304	} else if i.TestSuite.Type == TestSuiteTypeOutput {
305		output := tplOutputTestCaseData{
306			TestCase:   i,
307			Body:       fmt.Sprintf("%q", i.OutputTest.Body),
308			OpName:     strings.ToUpper(opName[0:1]) + opName[1:],
309			Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
310		}
311
312		if err := tplOutputTestCase.Execute(&buf, output); err != nil {
313			panic(err)
314		}
315	}
316
317	return buf.String()
318}
319
320func serializeJSONValue(m map[string]interface{}) string {
321	str := "aws.JSONValue"
322	str += walkMap(m)
323	return str
324}
325
// walkMap renders m as the brace-delimited body of a Go composite literal
// whose element type is interface{}. The caller supplies the type in front,
// e.g. "aws.JSONValue" or "map[string]interface{}".
//
// Each entry is written as `"key":value,` plus a newline. Keys are emitted in
// sorted order so the generated code is deterministic. Nested objects and
// arrays are written with explicit map[string]interface{} / []interface{}
// types, because an element of interface type cannot use an elided
// composite-literal type.
func walkMap(m map[string]interface{}) string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b bytes.Buffer
	b.WriteString("{")
	for _, k := range keys {
		fmt.Fprintf(&b, "%q:%s,\n", k, jsonValueLiteral(m[k]))
	}
	b.WriteString("}")
	return b.String()
}

// jsonValueLiteral renders one decoded-JSON value as a Go expression. It
// covers nil, bool, string, int, float64, nested objects and arrays, and
// falls back to %#v for anything else.
func jsonValueLiteral(v interface{}) string {
	switch tv := v.(type) {
	case nil:
		return "nil"
	case bool:
		return fmt.Sprintf("%t", tv)
	case string:
		return fmt.Sprintf("%q", tv)
	case int:
		return fmt.Sprintf("%d", tv)
	case float64:
		return fmt.Sprintf("%f", tv)
	case map[string]interface{}:
		return "map[string]interface{}" + walkMap(tv)
	case []interface{}:
		var b bytes.Buffer
		b.WriteString("[]interface{}{")
		for _, elem := range tv {
			b.WriteString(jsonValueLiteral(elem))
			b.WriteString(",\n")
		}
		b.WriteString("}")
		return b.String()
	default:
		return fmt.Sprintf("%#v", tv)
	}
}
346
347func buildJSONValues(shape *api.Shape) map[string]struct{} {
348	keys := map[string]struct{}{}
349	for key, field := range shape.MemberRefs {
350		if field.JSONValue {
351			keys[key] = struct{}{}
352		}
353	}
354	return keys
355}
356
357// generateTestSuite generates a protocol test suite for a given configuration
358// JSON protocol test file.
359func generateTestSuite(filename string) string {
360	inout := "Input"
361	if strings.Contains(filename, "output/") {
362		inout = "Output"
363	}
364
365	var suites []testSuite
366	f, err := os.Open(filename)
367	if err != nil {
368		panic(err)
369	}
370
371	err = json.NewDecoder(f).Decode(&suites)
372	if err != nil {
373		panic(err)
374	}
375
376	var buf bytes.Buffer
377	buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
378
379	var innerBuf bytes.Buffer
380	innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
381
382	for i, suite := range suites {
383		svcPrefix := inout + "Service" + strconv.Itoa(i+1)
384		suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
385		suite.API.Operations = map[string]*api.Operation{}
386		for idx, c := range suite.Cases {
387			c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
388			suite.API.Operations[c.Given.ExportedName] = c.Given
389		}
390
391		suite.Type = getType(inout)
392		suite.API.NoInitMethods = true       // don't generate init methods
393		suite.API.NoStringerMethods = true   // don't generate stringer methods
394		suite.API.NoConstServiceNames = true // don't generate service names
395		suite.API.Setup()
396		suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
397
398		// Sort in order for deterministic test generation
399		names := make([]string, 0, len(suite.API.Shapes))
400		for n := range suite.API.Shapes {
401			names = append(names, n)
402		}
403		sort.Strings(names)
404		for _, name := range names {
405			s := suite.API.Shapes[name]
406			s.Rename(svcPrefix + "TestShape" + name)
407		}
408
409		svcCode := addImports(suite.API.ServiceGoCode())
410		if i == 0 {
411			importMatch := reImportRemoval.FindStringSubmatch(svcCode)
412			buf.WriteString(importMatch[0] + "\n\n")
413			buf.WriteString(preamble + "\n\n")
414		}
415		svcCode = removeImports(svcCode)
416		svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
417		svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
418		svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
419		buf.WriteString(svcCode + "\n\n")
420
421		apiCode := removeImports(suite.API.APIGoCode())
422		apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
423		apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
424		apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
425		buf.WriteString(apiCode + "\n\n")
426
427		innerBuf.WriteString(suite.TestSuite() + "\n")
428	}
429
430	return buf.String() + innerBuf.String()
431}
432
433// findMember searches the shape for the member with the matching key name.
434func findMember(shape *api.Shape, key string) string {
435	for actualKey := range shape.MemberRefs {
436		if strings.ToLower(key) == strings.ToLower(actualKey) {
437			return actualKey
438		}
439	}
440	return ""
441}
442
443// GenerateAssertions builds assertions for a shape based on its type.
444//
445// The shape's recursive values also will have assertions generated for them.
446func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
447	if shape == nil {
448		return ""
449	}
450	switch t := out.(type) {
451	case map[string]interface{}:
452		keys := util.SortedKeys(t)
453
454		code := ""
455		if shape.Type == "map" {
456			for _, k := range keys {
457				v := t[k]
458				s := shape.ValueRef.Shape
459				code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]")
460			}
461		} else if shape.Type == "jsonvalue" {
462			code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)\n", prefix, walkMap(out.(map[string]interface{})))
463		} else {
464			for _, k := range keys {
465				v := t[k]
466				m := findMember(shape, k)
467				s := shape.MemberRefs[m].Shape
468				code += GenerateAssertions(v, s, prefix+"."+m+"")
469			}
470		}
471		return code
472	case []interface{}:
473		code := ""
474		for i, v := range t {
475			s := shape.MemberRef.Shape
476			code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
477		}
478		return code
479	default:
480		switch shape.Type {
481		case "timestamp":
482			return fmtAssertEqual(
483				fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out),
484				fmt.Sprintf("%s.String()", prefix),
485			)
486		case "blob":
487			return fmtAssertEqual(
488				fmt.Sprintf("%#v", out),
489				fmt.Sprintf("string(%s)", prefix),
490			)
491		case "integer", "long":
492			return fmtAssertEqual(
493				fmt.Sprintf("int64(%#v)", out),
494				fmt.Sprintf("*%s", prefix),
495			)
496		default:
497			if !reflect.ValueOf(out).IsValid() {
498				return fmtAssertNil(prefix)
499			}
500			return fmtAssertEqual(
501				fmt.Sprintf("%#v", out),
502				fmt.Sprintf("*%s", prefix),
503			)
504		}
505	}
506}
507
508func getType(t string) uint {
509	switch t {
510	case "Input":
511		return TestSuiteTypeInput
512	case "Output":
513		return TestSuiteTypeOutput
514	default:
515		panic("Invalid type for test suite")
516	}
517}
518
519func main() {
520	fmt.Println("Generating test suite", os.Args[1:])
521	out := generateTestSuite(os.Args[1])
522	if len(os.Args) == 3 {
523		f, err := os.Create(os.Args[2])
524		defer f.Close()
525		if err != nil {
526			panic(err)
527		}
528		f.WriteString(util.GoFmt(out))
529		f.Close()
530
531		c := exec.Command("gofmt", "-s", "-w", os.Args[2])
532		if err := c.Run(); err != nil {
533			panic(err)
534		}
535	} else {
536		fmt.Println(out)
537	}
538}
539