1// +build go1.8
2
3package query
4
5import (
6	"io/ioutil"
7	"net/http"
8	"strings"
9	"testing"
10
11	"github.com/aws/aws-sdk-go/aws/awserr"
12	"github.com/aws/aws-sdk-go/aws/request"
13)
14
15func TestUnmarshalError(t *testing.T) {
16	cases := map[string]struct {
17		Request   *request.Request
18		Code, Msg string
19		ReqID     string
20		Status    int
21	}{
22		"ErrorResponse": {
23			Request: &request.Request{
24				HTTPResponse: &http.Response{
25					StatusCode: 400,
26					Header:     http.Header{},
27					Body: ioutil.NopCloser(strings.NewReader(
28						`<ErrorResponse>
29							<Error>
30								<Code>codeAbc</Code><Message>msg123</Message>
31							</Error>
32							<RequestId>reqID123</RequestId>
33						</ErrorResponse>`)),
34				},
35			},
36			Code: "codeAbc", Msg: "msg123",
37			Status: 400, ReqID: "reqID123",
38		},
39		"ServiceUnavailableException": {
40			Request: &request.Request{
41				HTTPResponse: &http.Response{
42					StatusCode: 502,
43					Header:     http.Header{},
44					Body: ioutil.NopCloser(strings.NewReader(
45						`<ServiceUnavailableException>
46							<Something>else</Something>
47						</ServiceUnavailableException>`)),
48				},
49			},
50			Code:   "ServiceUnavailableException",
51			Msg:    "service is unavailable",
52			Status: 502,
53		},
54		"unknown tag": {
55			Request: &request.Request{
56				HTTPResponse: &http.Response{
57					StatusCode: 400,
58					Header:     http.Header{},
59					Body: ioutil.NopCloser(strings.NewReader(
60						`<Hello>
61							<World>.</World>
62						</Hello>`)),
63				},
64			},
65			Code:   request.ErrCodeSerialization,
66			Msg:    "failed to unmarshal error message",
67			Status: 400,
68		},
69	}
70
71	for name, c := range cases {
72		t.Run(name, func(t *testing.T) {
73			r := c.Request
74			UnmarshalError(r)
75			if r.Error == nil {
76				t.Fatalf("expect error, got none")
77			}
78
79			aerr := r.Error.(awserr.RequestFailure)
80			if e, a := c.Code, aerr.Code(); e != a {
81				t.Errorf("expect %v code, got %v", e, a)
82			}
83			if e, a := c.Msg, aerr.Message(); e != a {
84				t.Errorf("expect %q message, got %q", e, a)
85			}
86			if e, a := c.ReqID, aerr.RequestID(); e != a {
87				t.Errorf("expect %v request ID, got %v", e, a)
88			}
89			if e, a := c.Status, aerr.StatusCode(); e != a {
90				t.Errorf("expect %v status code, got %v", e, a)
91			}
92		})
93	}
94}
95