1// +build codegen 2 3package main 4 5import ( 6 "bytes" 7 "encoding/json" 8 "fmt" 9 "net/url" 10 "os" 11 "os/exec" 12 "reflect" 13 "regexp" 14 "sort" 15 "strconv" 16 "strings" 17 "text/template" 18 19 "github.com/aws/aws-sdk-go/private/model/api" 20 "github.com/aws/aws-sdk-go/private/util" 21) 22 23// TestSuiteTypeInput input test 24// TestSuiteTypeInput output test 25const ( 26 TestSuiteTypeInput = iota 27 TestSuiteTypeOutput 28) 29 30type testSuite struct { 31 *api.API 32 Description string 33 Cases []testCase 34 Type uint 35 title string 36} 37 38type testCase struct { 39 TestSuite *testSuite 40 Given *api.Operation 41 Params interface{} `json:",omitempty"` 42 Data interface{} `json:"result,omitempty"` 43 InputTest testExpectation `json:"serialized"` 44 OutputTest testExpectation `json:"response"` 45} 46 47type testExpectation struct { 48 Body string 49 URI string 50 Headers map[string]string 51 JSONValues map[string]string 52 StatusCode uint `json:"status_code"` 53} 54 55const preamble = ` 56var _ bytes.Buffer // always import bytes 57var _ http.Request 58var _ json.Marshaler 59var _ time.Time 60var _ xmlutil.XMLNode 61var _ xml.Attr 62var _ = ioutil.Discard 63var _ = util.Trim("") 64var _ = url.Values{} 65var _ = io.EOF 66var _ = aws.String 67var _ = fmt.Println 68var _ = reflect.Value{} 69 70func init() { 71 protocol.RandReader = &awstesting.ZeroReader{} 72} 73` 74 75var reStripSpace = regexp.MustCompile(`\s(\w)`) 76 77var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`) 78 79func removeImports(code string) string { 80 return reImportRemoval.ReplaceAllString(code, "") 81} 82 83var extraImports = []string{ 84 "bytes", 85 "encoding/json", 86 "encoding/xml", 87 "fmt", 88 "io", 89 "io/ioutil", 90 "net/http", 91 "testing", 92 "time", 93 "reflect", 94 "net/url", 95 "", 96 "github.com/aws/aws-sdk-go/awstesting", 97 "github.com/aws/aws-sdk-go/awstesting/unit", 98 "github.com/aws/aws-sdk-go/private/protocol", 99 "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", 100 "github.com/aws/aws-sdk-go/private/util", 101} 102 103func addImports(code string) string { 104 importNames := make([]string, len(extraImports)) 105 for i, n := range extraImports { 106 if n != "" { 107 importNames[i] = fmt.Sprintf("%q", n) 108 } 109 } 110 111 str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)") 112 return str 113} 114 115func (t *testSuite) TestSuite() string { 116 var buf bytes.Buffer 117 118 t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string { 119 return strings.ToUpper(x[1:]) 120 }) 121 t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "") 122 123 for idx, c := range t.Cases { 124 c.TestSuite = t 125 buf.WriteString(c.TestCase(idx) + "\n") 126 } 127 return buf.String() 128} 129 130var tplInputTestCase = template.Must(template.New("inputcase").Parse(` 131func Test{{ .OpName }}(t *testing.T) { 132 svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")}) 133 {{ if ne .ParamsString "" }}input := {{ .ParamsString }} 134 {{ range $k, $v := .JSONValues -}} 135 input.{{ $k }} = {{ $v }} 136 {{ end -}} 137 req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }} 138 r := req.HTTPRequest 139 140 // build request 141 {{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req) 142 if req.Error != nil { 143 t.Errorf("expect no error, got %v", req.Error) 144 } 145 146 {{ if ne .TestCase.InputTest.Body "" }}// assert body 147 if r.Body == nil { 148 t.Errorf("expect body not to be nil") 149 } 150 {{ .BodyAssertions }}{{ end }} 151 152 {{ if ne .TestCase.InputTest.URI "" }}// assert URL 153 awstesting.AssertURL(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }} 154 155 // assert headers 156 {{ range $k, $v := .TestCase.InputTest.Headers -}} 157 if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a { 158 t.Errorf("expect %v to be %v", e, a) 159 } 160 {{ end }} 161} 162`)) 163 164type tplInputTestCaseData struct { 165 TestCase *testCase 166 JSONValues map[string]string 167 OpName, ParamsString string 168} 169 170func (t tplInputTestCaseData) BodyAssertions() string { 171 code := &bytes.Buffer{} 172 protocol := t.TestCase.TestSuite.API.Metadata.Protocol 173 174 // Extract the body bytes 175 switch protocol { 176 case "rest-xml": 177 fmt.Fprintln(code, "body := util.SortXML(r.Body)") 178 default: 179 fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)") 180 } 181 182 // Generate the body verification code 183 expectedBody := util.Trim(t.TestCase.InputTest.Body) 184 switch protocol { 185 case "ec2", "query": 186 fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))", 187 expectedBody) 188 case "rest-xml": 189 if strings.HasPrefix(expectedBody, "<") { 190 fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})", 191 expectedBody, t.TestCase.Given.InputRef.ShapeName) 192 } else { 193 code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) 194 } 195 case "json", "jsonrpc", "rest-json": 196 if strings.HasPrefix(expectedBody, "{") { 197 fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))", 198 expectedBody) 199 } else { 200 code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) 201 } 202 default: 203 code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))")) 204 } 205 206 return code.String() 207} 208 209func fmtAssertEqual(e, a string) string { 210 const format = `if e, a := %s, %s; e != a { 211 t.Errorf("expect %%v, got %%v", e, a) 212 } 213 ` 214 215 return fmt.Sprintf(format, e, a) 216} 217 218func fmtAssertNil(v string) string { 219 const format = `if e := %s; e != nil { 220 t.Errorf("expect nil, got %%v", e) 221 } 222 ` 223 224 return fmt.Sprintf(format, v) 225} 226 227var tplOutputTestCase = template.Must(template.New("outputcase").Parse(` 228func Test{{ .OpName }}(t *testing.T) { 229 svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")}) 230 231 buf := bytes.NewReader([]byte({{ .Body }})) 232 req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil) 233 req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}} 234 235 // set headers 236 {{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}") 237 {{ end }} 238 239 // unmarshal response 240 {{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req) 241 {{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req) 242 if req.Error != nil { 243 t.Errorf("expect not error, got %v", req.Error) 244 } 245 246 // assert response 247 if out == nil { 248 t.Errorf("expect not to be nil") 249 } 250 {{ .Assertions }} 251} 252`)) 253 254type tplOutputTestCaseData struct { 255 TestCase *testCase 256 Body, OpName, Assertions string 257} 258 259func (i *testCase) TestCase(idx int) string { 260 var buf bytes.Buffer 261 262 opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1) 263 264 if i.TestSuite.Type == TestSuiteTypeInput { // input test 265 // query test should sort body as form encoded values 266 switch i.TestSuite.API.Metadata.Protocol { 267 case "query", "ec2": 268 m, _ := url.ParseQuery(i.InputTest.Body) 269 i.InputTest.Body = m.Encode() 270 case "rest-xml": 271 i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body))) 272 case "json", "rest-json": 273 i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1) 274 } 275 276 jsonValues := buildJSONValues(i.Given.InputRef.Shape) 277 var params interface{} 278 if m, ok := i.Params.(map[string]interface{}); ok { 279 paramsMap := map[string]interface{}{} 280 for k, v := range m { 281 if _, ok := jsonValues[k]; !ok { 282 paramsMap[k] = v 283 } else { 284 if i.InputTest.JSONValues == nil { 285 i.InputTest.JSONValues = map[string]string{} 286 } 287 i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{})) 288 } 289 } 290 params = paramsMap 291 } else { 292 params = i.Params 293 } 294 input := tplInputTestCaseData{ 295 TestCase: i, 296 OpName: strings.ToUpper(opName[0:1]) + opName[1:], 297 ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false), 298 JSONValues: i.InputTest.JSONValues, 299 } 300 301 if err := tplInputTestCase.Execute(&buf, input); err != nil { 302 panic(err) 303 } 304 } else if i.TestSuite.Type == TestSuiteTypeOutput { 305 output := tplOutputTestCaseData{ 306 TestCase: i, 307 Body: fmt.Sprintf("%q", i.OutputTest.Body), 308 OpName: strings.ToUpper(opName[0:1]) + opName[1:], 309 Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"), 310 } 311 312 if err := tplOutputTestCase.Execute(&buf, output); err != nil { 313 panic(err) 314 } 315 } 316 317 return buf.String() 318} 319 320func serializeJSONValue(m map[string]interface{}) string { 321 str := "aws.JSONValue" 322 str += walkMap(m) 323 return str 324} 325 326func walkMap(m map[string]interface{}) string { 327 str := "{" 328 for k, v := range m { 329 str += fmt.Sprintf("%q:", k) 330 switch v.(type) { 331 case bool: 332 str += fmt.Sprintf("%t,\n", v.(bool)) 333 case string: 334 str += fmt.Sprintf("%q,\n", v.(string)) 335 case int: 336 str += fmt.Sprintf("%d,\n", v.(int)) 337 case float64: 338 str += fmt.Sprintf("%f,\n", v.(float64)) 339 case map[string]interface{}: 340 str += walkMap(v.(map[string]interface{})) 341 } 342 } 343 str += "}" 344 return str 345} 346 347func buildJSONValues(shape *api.Shape) map[string]struct{} { 348 keys := map[string]struct{}{} 349 for key, field := range shape.MemberRefs { 350 if field.JSONValue { 351 keys[key] = struct{}{} 352 } 353 } 354 return keys 355} 356 357// generateTestSuite generates a protocol test suite for a given configuration 358// JSON protocol test file. 359func generateTestSuite(filename string) string { 360 inout := "Input" 361 if strings.Contains(filename, "output/") { 362 inout = "Output" 363 } 364 365 var suites []testSuite 366 f, err := os.Open(filename) 367 if err != nil { 368 panic(err) 369 } 370 371 err = json.NewDecoder(f).Decode(&suites) 372 if err != nil { 373 panic(err) 374 } 375 376 var buf bytes.Buffer 377 buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n") 378 379 var innerBuf bytes.Buffer 380 innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n") 381 382 for i, suite := range suites { 383 svcPrefix := inout + "Service" + strconv.Itoa(i+1) 384 suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest" 385 suite.API.Operations = map[string]*api.Operation{} 386 for idx, c := range suite.Cases { 387 c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1) 388 suite.API.Operations[c.Given.ExportedName] = c.Given 389 } 390 391 suite.Type = getType(inout) 392 suite.API.NoInitMethods = true // don't generate init methods 393 suite.API.NoStringerMethods = true // don't generate stringer methods 394 suite.API.NoConstServiceNames = true // don't generate service names 395 suite.API.Setup() 396 suite.API.Metadata.EndpointPrefix = suite.API.PackageName() 397 398 // Sort in order for deterministic test generation 399 names := make([]string, 0, len(suite.API.Shapes)) 400 for n := range suite.API.Shapes { 401 names = append(names, n) 402 } 403 sort.Strings(names) 404 for _, name := range names { 405 s := suite.API.Shapes[name] 406 s.Rename(svcPrefix + "TestShape" + name) 407 } 408 409 svcCode := addImports(suite.API.ServiceGoCode()) 410 if i == 0 { 411 importMatch := reImportRemoval.FindStringSubmatch(svcCode) 412 buf.WriteString(importMatch[0] + "\n\n") 413 buf.WriteString(preamble + "\n\n") 414 } 415 svcCode = removeImports(svcCode) 416 svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1) 417 svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1) 418 svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1) 419 buf.WriteString(svcCode + "\n\n") 420 421 apiCode := removeImports(suite.API.APIGoCode()) 422 apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1) 423 apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1) 424 apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1) 425 buf.WriteString(apiCode + "\n\n") 426 427 innerBuf.WriteString(suite.TestSuite() + "\n") 428 } 429 430 return buf.String() + innerBuf.String() 431} 432 433// findMember searches the shape for the member with the matching key name. 434func findMember(shape *api.Shape, key string) string { 435 for actualKey := range shape.MemberRefs { 436 if strings.ToLower(key) == strings.ToLower(actualKey) { 437 return actualKey 438 } 439 } 440 return "" 441} 442 443// GenerateAssertions builds assertions for a shape based on its type. 444// 445// The shape's recursive values also will have assertions generated for them. 446func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string { 447 if shape == nil { 448 return "" 449 } 450 switch t := out.(type) { 451 case map[string]interface{}: 452 keys := util.SortedKeys(t) 453 454 code := "" 455 if shape.Type == "map" { 456 for _, k := range keys { 457 v := t[k] 458 s := shape.ValueRef.Shape 459 code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]") 460 } 461 } else if shape.Type == "jsonvalue" { 462 code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)\n", prefix, walkMap(out.(map[string]interface{}))) 463 } else { 464 for _, k := range keys { 465 v := t[k] 466 m := findMember(shape, k) 467 s := shape.MemberRefs[m].Shape 468 code += GenerateAssertions(v, s, prefix+"."+m+"") 469 } 470 } 471 return code 472 case []interface{}: 473 code := "" 474 for i, v := range t { 475 s := shape.MemberRef.Shape 476 code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]") 477 } 478 return code 479 default: 480 switch shape.Type { 481 case "timestamp": 482 return fmtAssertEqual( 483 fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out), 484 fmt.Sprintf("%s.String()", prefix), 485 ) 486 case "blob": 487 return fmtAssertEqual( 488 fmt.Sprintf("%#v", out), 489 fmt.Sprintf("string(%s)", prefix), 490 ) 491 case "integer", "long": 492 return fmtAssertEqual( 493 fmt.Sprintf("int64(%#v)", out), 494 fmt.Sprintf("*%s", prefix), 495 ) 496 default: 497 if !reflect.ValueOf(out).IsValid() { 498 return fmtAssertNil(prefix) 499 } 500 return fmtAssertEqual( 501 fmt.Sprintf("%#v", out), 502 fmt.Sprintf("*%s", prefix), 503 ) 504 } 505 } 506} 507 508func getType(t string) uint { 509 switch t { 510 case "Input": 511 return TestSuiteTypeInput 512 case "Output": 513 return TestSuiteTypeOutput 514 default: 515 panic("Invalid type for test suite") 516 } 517} 518 519func main() { 520 fmt.Println("Generating test suite", os.Args[1:]) 521 out := generateTestSuite(os.Args[1]) 522 if len(os.Args) == 3 { 523 f, err := os.Create(os.Args[2]) 524 defer f.Close() 525 if err != nil { 526 panic(err) 527 } 528 f.WriteString(util.GoFmt(out)) 529 f.Close() 530 531 c := exec.Command("gofmt", "-s", "-w", os.Args[2]) 532 if err := c.Run(); err != nil { 533 panic(err) 534 } 535 } else { 536 fmt.Println(out) 537 } 538} 539