1// +build go1.8
2
3package ec2query
4
5import (
6	"io/ioutil"
7	"net/http"
8	"strings"
9	"testing"
10
11	"github.com/aws/aws-sdk-go/aws/awserr"
12	"github.com/aws/aws-sdk-go/aws/request"
13)
14
15func TestUnmarshalError(t *testing.T) {
16	cases := map[string]struct {
17		Request   *request.Request
18		Code, Msg string
19		ReqID     string
20		Status    int
21	}{
22		"ErrorResponse": {
23			Request: &request.Request{
24				HTTPResponse: &http.Response{
25					StatusCode: 400,
26					Header:     http.Header{},
27					Body: ioutil.NopCloser(strings.NewReader(
28						`<Response>
29							<Errors>
30								<Error>
31									<Code>codeAbc</Code>
32									<Message>msg123</Message>
33								</Error>
34							</Errors>
35							<RequestID>reqID123</RequestID>
36						</Response>`)),
37				},
38			},
39			Code: "codeAbc", Msg: "msg123",
40			Status: 400, ReqID: "reqID123",
41		},
42		"unknown tag": {
43			Request: &request.Request{
44				HTTPResponse: &http.Response{
45					StatusCode: 400,
46					Header:     http.Header{},
47					Body: ioutil.NopCloser(strings.NewReader(
48						`<Hello>
49							<World>.</World>
50						</Hello>`)),
51				},
52			},
53			Code:   request.ErrCodeSerialization,
54			Msg:    "failed to unmarshal error message",
55			Status: 400,
56		},
57	}
58
59	for name, c := range cases {
60		t.Run(name, func(t *testing.T) {
61			r := c.Request
62			UnmarshalError(r)
63			if r.Error == nil {
64				t.Fatalf("expect error, got none")
65			}
66
67			aerr := r.Error.(awserr.RequestFailure)
68			if e, a := c.Code, aerr.Code(); e != a {
69				t.Errorf("expect %v code, got %v", e, a)
70			}
71			if e, a := c.Msg, aerr.Message(); e != a {
72				t.Errorf("expect %q message, got %q", e, a)
73			}
74			if e, a := c.ReqID, aerr.RequestID(); e != a {
75				t.Errorf("expect %v request ID, got %v", e, a)
76			}
77			if e, a := c.Status, aerr.StatusCode(); e != a {
78				t.Errorf("expect %v status code, got %v", e, a)
79			}
80		})
81	}
82}
83