1// +build codegen
2
3package api
4
5import (
6	"bytes"
7	"encoding/json"
8	"fmt"
9	"os"
10	"text/template"
11)
12
13// SmokeTestSuite defines the test suite for smoke tests.
14type SmokeTestSuite struct {
15	Version       int             `json:"version"`
16	DefaultRegion string          `json:"defaultRegion"`
17	TestCases     []SmokeTestCase `json:"testCases"`
18}
19
// SmokeTestCase describes one integration smoke test: the API operation to
// invoke, the raw input values to invoke it with, and whether the service is
// expected to reject the request.
type SmokeTestCase struct {
	// OpName names the operation under test; it is used as the key into the
	// API model's Operations when the test is generated.
	OpName string `json:"operationName"`

	// Input holds the operation's input parameter values as decoded from JSON.
	Input map[string]interface{} `json:"input"`

	// ExpectErr reports whether the generated test should assert that the
	// service call fails.
	ExpectErr bool `json:"errorExpectedFromService"`
}
26
27// BuildInputShape returns the Go code as a string for initializing the test
28// case's input shape.
29func (c SmokeTestCase) BuildInputShape(ref *ShapeRef) string {
30	b := NewShapeValueBuilder()
31	return fmt.Sprintf("&%s{\n%s\n}",
32		b.GoType(ref, true),
33		b.BuildShape(ref, c.Input, false),
34	)
35}
36
37// AttachSmokeTests attaches the smoke test cases to the API model.
38func (a *API) AttachSmokeTests(filename string) error {
39	f, err := os.Open(filename)
40	if err != nil {
41		return fmt.Errorf("failed to open smoke tests %s, err: %v", filename, err)
42	}
43	defer f.Close()
44
45	if err := json.NewDecoder(f).Decode(&a.SmokeTests); err != nil {
46		return fmt.Errorf("failed to decode smoke tests %s, err: %v", filename, err)
47	}
48
49	if v := a.SmokeTests.Version; v != 1 {
50		return fmt.Errorf("invalid smoke test version, %d", v)
51	}
52
53	return nil
54}
55
56// APISmokeTestsGoCode returns the Go Code string for the smoke tests.
57func (a *API) APISmokeTestsGoCode() string {
58	w := bytes.NewBuffer(nil)
59
60	a.resetImports()
61	a.AddImport("context")
62	a.AddImport("testing")
63	a.AddImport("time")
64	a.AddSDKImport("aws")
65	a.AddSDKImport("aws/request")
66	a.AddSDKImport("aws/awserr")
67	a.AddSDKImport("aws/request")
68	a.AddSDKImport("awstesting/integration")
69	a.AddImport(a.ImportPath())
70
71	smokeTests := struct {
72		API *API
73		SmokeTestSuite
74	}{
75		API:            a,
76		SmokeTestSuite: a.SmokeTests,
77	}
78
79	if err := smokeTestTmpl.Execute(w, smokeTests); err != nil {
80		panic(fmt.Sprintf("failed to create smoke tests, %v", err))
81	}
82
83	ignoreImports := `
84	var _ aws.Config
85	var _ awserr.Error
86	var _ request.Request
87	`
88
89	return a.importsGoCode() + ignoreImports + w.String()
90}
91
// smokeTestTmpl renders one Go test function per smoke test case, named
// TestInteg_<two-digit index>_<operation exported name>. It is executed with a
// value that exposes the API model as .API and the SmokeTestSuite fields
// (DefaultRegion, TestCases) at the top level.
//
// Each generated test does the following:
//   - creates a context with a 5 second timeout;
//   - builds a session with the suite's DefaultRegion via the integration
//     package;
//   - builds the operation input from the test case via BuildInputShape;
//   - calls <Operation>WithContext with the handler named
//     "core.ValidateParametersHandler" removed from the request's Validate
//     handlers.
//
// When ExpectErr is true, the test requires an awserr.RequestFailure with a
// non-empty code and message whose code is not request.ErrCodeSerialization.
// Otherwise, it requires the call to succeed.
//
// NOTE(review): if OpName is not a key in .API.Operations, `index` yields the
// map's zero value (presumably a nil *Operation). Template execution would then
// fail on $op.ExportedName, which makes APISmokeTestsGoCode panic.
var smokeTestTmpl = template.Must(template.New(`smokeTestTmpl`).Parse(`
{{- range $i, $testCase := $.TestCases }}
	{{- $op := index $.API.Operations $testCase.OpName }}
	func TestInteg_{{ printf "%02d" $i }}_{{ $op.ExportedName }}(t *testing.T) {
		ctx, cancelFn := context.WithTimeout(context.Background(), 5 *time.Second)
		defer cancelFn()

		sess := integration.SessionWithDefaultRegion("{{ $.DefaultRegion }}")
		svc := {{ $.API.PackageName }}.New(sess)
		params := {{ $testCase.BuildInputShape $op.InputRef }}
		_, err := svc.{{ $op.ExportedName }}WithContext(ctx, params, func(r *request.Request) {
			r.Handlers.Validate.RemoveByName("core.ValidateParametersHandler")
		})
		{{- if $testCase.ExpectErr }}
			if err == nil {
				t.Fatalf("expect request to fail")
			}
			aerr, ok := err.(awserr.RequestFailure)
			if !ok {
				t.Fatalf("expect awserr, was %T", err)
			}
			if len(aerr.Code()) == 0 {
				t.Errorf("expect non-empty error code")
			}
			if len(aerr.Message()) == 0 {
				t.Errorf("expect non-empty error message")
			}
			if v := aerr.Code(); v == request.ErrCodeSerialization {
				t.Errorf("expect API error code got serialization failure")
			}
		{{- else }}
			if err != nil {
				t.Errorf("expect no error, got %v", err)
			}
		{{- end }}
	}
{{- end }}
`))
130