1// +build codegen 2 3package main 4 5import ( 6 "bytes" 7 "encoding/json" 8 "fmt" 9 "net/url" 10 "os" 11 "os/exec" 12 "reflect" 13 "regexp" 14 "sort" 15 "strconv" 16 "strings" 17 "text/template" 18 19 "github.com/aws/aws-sdk-go/private/model/api" 20 "github.com/aws/aws-sdk-go/private/util" 21) 22 23// TestSuiteTypeInput input test 24// TestSuiteTypeInput output test 25const ( 26 TestSuiteTypeInput = iota 27 TestSuiteTypeOutput 28) 29 30type testSuite struct { 31 *api.API 32 Description string 33 ClientEndpoint string 34 Cases []testCase 35 Type uint 36 title string 37} 38 39func (s *testSuite) UnmarshalJSON(p []byte) error { 40 type stub testSuite 41 42 var v stub 43 if err := json.Unmarshal(p, &v); err != nil { 44 return err 45 } 46 47 if len(v.ClientEndpoint) == 0 { 48 v.ClientEndpoint = "https://test" 49 } 50 for i := 0; i < len(v.Cases); i++ { 51 if len(v.Cases[i].InputTest.Host) == 0 { 52 v.Cases[i].InputTest.Host = "test" 53 } 54 if len(v.Cases[i].InputTest.URI) == 0 { 55 v.Cases[i].InputTest.URI = "/" 56 } 57 } 58 59 *s = testSuite(v) 60 return nil 61} 62 63type testCase struct { 64 TestSuite *testSuite 65 Given *api.Operation 66 Params interface{} `json:",omitempty"` 67 Data interface{} `json:"result,omitempty"` 68 InputTest testExpectation `json:"serialized"` 69 OutputTest testExpectation `json:"response"` 70} 71 72type testExpectation struct { 73 Body string 74 Host string 75 URI string 76 Headers map[string]string 77 JSONValues map[string]string 78 StatusCode uint `json:"status_code"` 79} 80 81const preamble = ` 82var _ bytes.Buffer // always import bytes 83var _ http.Request 84var _ json.Marshaler 85var _ time.Time 86var _ xmlutil.XMLNode 87var _ xml.Attr 88var _ = ioutil.Discard 89var _ = util.Trim("") 90var _ = url.Values{} 91var _ = io.EOF 92var _ = aws.String 93var _ = fmt.Println 94var _ = reflect.Value{} 95 96func init() { 97 protocol.RandReader = &awstesting.ZeroReader{} 98} 99` 100 101var reStripSpace = regexp.MustCompile(`\s(\w)`) 102 103var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`) 104 105func removeImports(code string) string { 106 return reImportRemoval.ReplaceAllString(code, "") 107} 108 109var extraImports = []string{ 110 "bytes", 111 "encoding/json", 112 "encoding/xml", 113 "fmt", 114 "io", 115 "io/ioutil", 116 "net/http", 117 "testing", 118 "time", 119 "reflect", 120 "net/url", 121 "", 122 "github.com/aws/aws-sdk-go/awstesting", 123 "github.com/aws/aws-sdk-go/awstesting/unit", 124 "github.com/aws/aws-sdk-go/private/protocol", 125 "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", 126 "github.com/aws/aws-sdk-go/private/util", 127} 128 129func addImports(code string) string { 130 importNames := make([]string, len(extraImports)) 131 for i, n := range extraImports { 132 if n != "" { 133 importNames[i] = fmt.Sprintf("%q", n) 134 } 135 } 136 137 str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)") 138 return str 139} 140 141func (t *testSuite) TestSuite() string { 142 var buf bytes.Buffer 143 144 t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string { 145 return strings.ToUpper(x[1:]) 146 }) 147 t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "") 148 149 for idx, c := range t.Cases { 150 c.TestSuite = t 151 buf.WriteString(c.TestCase(idx) + "\n") 152 } 153 return buf.String() 154} 155 156var tplInputTestCase = template.Must(template.New("inputcase").Parse(` 157func Test{{ .OpName }}(t *testing.T) { 158 svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("{{ .TestCase.TestSuite.ClientEndpoint }}")}) 159 {{ if ne .ParamsString "" }}input := {{ .ParamsString }} 160 {{ range $k, $v := .JSONValues -}} 161 input.{{ $k }} = {{ $v }} 162 {{ end -}} 163 req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }} 164 r := req.HTTPRequest 165 166 // build request 167 req.Build() 168 if req.Error != nil { 169 t.Errorf("expect no error, got %v", req.Error) 170 } 171 172 {{ if ne .TestCase.InputTest.Body "" }}// assert body 173 if r.Body == nil { 174 t.Errorf("expect body not to be nil") 175 } 176 {{ .BodyAssertions }}{{ end }} 177 178 // assert URL 179 awstesting.AssertURL(t, "https://{{ .TestCase.InputTest.Host }}{{ .TestCase.InputTest.URI }}", r.URL.String()) 180 181 // assert headers 182 {{ range $k, $v := .TestCase.InputTest.Headers -}} 183 if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a { 184 t.Errorf("expect %v, got %v", e, a) 185 } 186 {{ end }} 187} 188`)) 189 190type tplInputTestCaseData struct { 191 TestCase *testCase 192 JSONValues map[string]string 193 OpName, ParamsString string 194} 195 196func (t tplInputTestCaseData) BodyAssertions() string { 197 code := &bytes.Buffer{} 198 protocol := t.TestCase.TestSuite.API.Metadata.Protocol 199 200 // Extract the body bytes 201 switch protocol { 202 case "rest-xml": 203 fmt.Fprintln(code, "body := util.SortXML(r.Body)") 204 default: 205 fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)") 206 } 207 208 // Generate the body verification code 209 expectedBody := util.Trim(t.TestCase.InputTest.Body) 210 switch protocol { 211 case "ec2", "query": 212 fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))", 213 expectedBody) 214 case "rest-xml": 215 if strings.HasPrefix(expectedBody, "<") { 216 fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(body))", 217 expectedBody) 218 } else { 219 code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) 220 } 221 case "json", "jsonrpc", "rest-json": 222 if strings.HasPrefix(expectedBody, "{") { 223 fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))", 224 expectedBody) 225 } else { 226 code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) 227 } 228 default: 229 code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))")) 230 } 231 232 return code.String() 233} 234 235func fmtAssertEqual(e, a string) string { 236 const format = `if e, a := %s, %s; e != a { 237 t.Errorf("expect %%v, got %%v", e, a) 238 } 239 ` 240 241 return fmt.Sprintf(format, e, a) 242} 243 244func fmtAssertNil(v string) string { 245 const format = `if e := %s; e != nil { 246 t.Errorf("expect nil, got %%v", e) 247 } 248 ` 249 250 return fmt.Sprintf(format, v) 251} 252 253var tplOutputTestCase = template.Must(template.New("outputcase").Parse(` 254func Test{{ .OpName }}(t *testing.T) { 255 svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")}) 256 257 buf := bytes.NewReader([]byte({{ .Body }})) 258 req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil) 259 req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}} 260 261 // set headers 262 {{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}") 263 {{ end }} 264 265 // unmarshal response 266 req.Handlers.UnmarshalMeta.Run(req) 267 req.Handlers.Unmarshal.Run(req) 268 if req.Error != nil { 269 t.Errorf("expect not error, got %v", req.Error) 270 } 271 272 // assert response 273 if out == nil { 274 t.Errorf("expect not to be nil") 275 } 276 {{ .Assertions }} 277} 278`)) 279 280type tplOutputTestCaseData struct { 281 TestCase *testCase 282 Body, OpName, Assertions string 283} 284 285func (i *testCase) TestCase(idx int) string { 286 var buf bytes.Buffer 287 288 opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1) 289 290 if i.TestSuite.Type == TestSuiteTypeInput { // input test 291 // query test should sort body as form encoded values 292 switch i.TestSuite.API.Metadata.Protocol { 293 case "query", "ec2": 294 m, _ := url.ParseQuery(i.InputTest.Body) 295 i.InputTest.Body = m.Encode() 296 case "rest-xml": 297 i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body))) 298 case "json", "rest-json": 299 // Nothing to do 300 } 301 302 jsonValues := buildJSONValues(i.Given.InputRef.Shape) 303 var params interface{} 304 if m, ok := i.Params.(map[string]interface{}); ok { 305 paramsMap := map[string]interface{}{} 306 for k, v := range m { 307 if _, ok := jsonValues[k]; !ok { 308 paramsMap[k] = v 309 } else { 310 if i.InputTest.JSONValues == nil { 311 i.InputTest.JSONValues = map[string]string{} 312 } 313 i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{})) 314 } 315 } 316 params = paramsMap 317 } else { 318 params = i.Params 319 } 320 input := tplInputTestCaseData{ 321 TestCase: i, 322 OpName: strings.ToUpper(opName[0:1]) + opName[1:], 323 ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false), 324 JSONValues: i.InputTest.JSONValues, 325 } 326 327 if err := tplInputTestCase.Execute(&buf, input); err != nil { 328 panic(err) 329 } 330 } else if i.TestSuite.Type == TestSuiteTypeOutput { 331 output := tplOutputTestCaseData{ 332 TestCase: i, 333 Body: fmt.Sprintf("%q", i.OutputTest.Body), 334 OpName: strings.ToUpper(opName[0:1]) + opName[1:], 335 Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"), 336 } 337 338 if err := tplOutputTestCase.Execute(&buf, output); err != nil { 339 panic(err) 340 } 341 } 342 343 return buf.String() 344} 345 346func serializeJSONValue(m map[string]interface{}) string { 347 str := "aws.JSONValue" 348 str += walkMap(m) 349 return str 350} 351 352func walkMap(m map[string]interface{}) string { 353 str := "{" 354 for k, v := range m { 355 str += fmt.Sprintf("%q:", k) 356 switch v.(type) { 357 case bool: 358 str += fmt.Sprintf("%t,\n", v.(bool)) 359 case string: 360 str += fmt.Sprintf("%q,\n", v.(string)) 361 case int: 362 str += fmt.Sprintf("%d,\n", v.(int)) 363 case float64: 364 str += fmt.Sprintf("%f,\n", v.(float64)) 365 case map[string]interface{}: 366 str += walkMap(v.(map[string]interface{})) 367 } 368 } 369 str += "}" 370 return str 371} 372 373func buildJSONValues(shape *api.Shape) map[string]struct{} { 374 keys := map[string]struct{}{} 375 for key, field := range shape.MemberRefs { 376 if field.JSONValue { 377 keys[key] = struct{}{} 378 } 379 } 380 return keys 381} 382 383// generateTestSuite generates a protocol test suite for a given configuration 384// JSON protocol test file. 385func generateTestSuite(filename string) string { 386 inout := "Input" 387 if strings.Contains(filename, "output/") { 388 inout = "Output" 389 } 390 391 var suites []testSuite 392 f, err := os.Open(filename) 393 if err != nil { 394 panic(err) 395 } 396 397 err = json.NewDecoder(f).Decode(&suites) 398 if err != nil { 399 panic(err) 400 } 401 402 var buf bytes.Buffer 403 buf.WriteString("// Code generated by models/protocol_tests/generate.go. DO NOT EDIT.\n\n") 404 buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n") 405 406 var innerBuf bytes.Buffer 407 innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n") 408 409 for i, suite := range suites { 410 svcPrefix := inout + "Service" + strconv.Itoa(i+1) 411 suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest" 412 suite.API.Operations = map[string]*api.Operation{} 413 for idx, c := range suite.Cases { 414 c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1) 415 suite.API.Operations[c.Given.ExportedName] = c.Given 416 } 417 418 suite.Type = getType(inout) 419 suite.API.NoInitMethods = true // don't generate init methods 420 suite.API.NoStringerMethods = true // don't generate stringer methods 421 suite.API.NoConstServiceNames = true // don't generate service names 422 suite.API.Setup() 423 suite.API.Metadata.EndpointPrefix = suite.API.PackageName() 424 suite.API.Metadata.EndpointsID = suite.API.Metadata.EndpointPrefix 425 426 // Sort in order for deterministic test generation 427 names := make([]string, 0, len(suite.API.Shapes)) 428 for n := range suite.API.Shapes { 429 names = append(names, n) 430 } 431 sort.Strings(names) 432 for _, name := range names { 433 s := suite.API.Shapes[name] 434 s.Rename(svcPrefix + "TestShape" + name) 435 } 436 437 svcCode := addImports(suite.API.ServiceGoCode()) 438 if i == 0 { 439 importMatch := reImportRemoval.FindStringSubmatch(svcCode) 440 buf.WriteString(importMatch[0] + "\n\n") 441 buf.WriteString(preamble + "\n\n") 442 } 443 svcCode = removeImports(svcCode) 444 svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1) 445 svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1) 446 svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1) 447 buf.WriteString(svcCode + "\n\n") 448 449 apiCode := removeImports(suite.API.APIGoCode()) 450 apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1) 451 apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1) 452 apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1) 453 buf.WriteString(apiCode + "\n\n") 454 455 innerBuf.WriteString(suite.TestSuite() + "\n") 456 } 457 458 return buf.String() + innerBuf.String() 459} 460 461// findMember searches the shape for the member with the matching key name. 462func findMember(shape *api.Shape, key string) string { 463 for actualKey := range shape.MemberRefs { 464 if strings.EqualFold(key, actualKey) { 465 return actualKey 466 } 467 } 468 return "" 469} 470 471// GenerateAssertions builds assertions for a shape based on its type. 472// 473// The shape's recursive values also will have assertions generated for them. 474func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string { 475 if shape == nil { 476 return "" 477 } 478 switch t := out.(type) { 479 case map[string]interface{}: 480 keys := util.SortedKeys(t) 481 482 code := "" 483 if shape.Type == "map" { 484 for _, k := range keys { 485 v := t[k] 486 s := shape.ValueRef.Shape 487 code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]") 488 } 489 } else if shape.Type == "jsonvalue" { 490 code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)\n", prefix, walkMap(out.(map[string]interface{}))) 491 } else { 492 for _, k := range keys { 493 v := t[k] 494 m := findMember(shape, k) 495 s := shape.MemberRefs[m].Shape 496 code += GenerateAssertions(v, s, prefix+"."+m+"") 497 } 498 } 499 return code 500 case []interface{}: 501 code := "" 502 for i, v := range t { 503 s := shape.MemberRef.Shape 504 code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]") 505 } 506 return code 507 default: 508 switch shape.Type { 509 case "timestamp": 510 return fmtAssertEqual( 511 fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out), 512 fmt.Sprintf("%s.UTC().String()", prefix), 513 ) 514 case "blob": 515 return fmtAssertEqual( 516 fmt.Sprintf("%#v", out), 517 fmt.Sprintf("string(%s)", prefix), 518 ) 519 case "integer", "long": 520 return fmtAssertEqual( 521 fmt.Sprintf("int64(%#v)", out), 522 fmt.Sprintf("*%s", prefix), 523 ) 524 default: 525 if !reflect.ValueOf(out).IsValid() { 526 return fmtAssertNil(prefix) 527 } 528 return fmtAssertEqual( 529 fmt.Sprintf("%#v", out), 530 fmt.Sprintf("*%s", prefix), 531 ) 532 } 533 } 534} 535 536func getType(t string) uint { 537 switch t { 538 case "Input": 539 return TestSuiteTypeInput 540 case "Output": 541 return TestSuiteTypeOutput 542 default: 543 panic("Invalid type for test suite") 544 } 545} 546 547func main() { 548 if len(os.Getenv("AWS_SDK_CODEGEN_DEBUG")) != 0 { 549 api.LogDebug(os.Stdout) 550 } 551 552 fmt.Println("Generating test suite", os.Args[1:]) 553 out := generateTestSuite(os.Args[1]) 554 if len(os.Args) == 3 { 555 f, err := os.Create(os.Args[2]) 556 defer f.Close() 557 if err != nil { 558 panic(err) 559 } 560 f.WriteString(util.GoFmt(out)) 561 f.Close() 562 563 c := exec.Command("gofmt", "-s", "-w", os.Args[2]) 564 if err := c.Run(); err != nil { 565 panic(err) 566 } 567 } else { 568 fmt.Println(out) 569 } 570} 571