1package rest 2 3import ( 4 "encoding/base64" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "net/http" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/aws/aws-sdk-go/aws" 15 "github.com/aws/aws-sdk-go/aws/awserr" 16 "github.com/aws/aws-sdk-go/aws/request" 17) 18 19// UnmarshalHandler is a named request handler for unmarshaling rest protocol requests 20var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal} 21 22// UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata 23var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta} 24 25// Unmarshal unmarshals the REST component of a response in a REST service. 26func Unmarshal(r *request.Request) { 27 if r.DataFilled() { 28 v := reflect.Indirect(reflect.ValueOf(r.Data)) 29 unmarshalBody(r, v) 30 } 31} 32 33// UnmarshalMeta unmarshals the REST metadata of a response in a REST service 34func UnmarshalMeta(r *request.Request) { 35 r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") 36 if r.RequestID == "" { 37 // Alternative version of request id in the header 38 r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") 39 } 40 if r.DataFilled() { 41 v := reflect.Indirect(reflect.ValueOf(r.Data)) 42 unmarshalLocationElements(r, v) 43 } 44} 45 46func unmarshalBody(r *request.Request, v reflect.Value) { 47 if field, ok := v.Type().FieldByName("_"); ok { 48 if payloadName := field.Tag.Get("payload"); payloadName != "" { 49 pfield, _ := v.Type().FieldByName(payloadName) 50 if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" { 51 payload := v.FieldByName(payloadName) 52 if payload.IsValid() { 53 switch payload.Interface().(type) { 54 case []byte: 55 defer r.HTTPResponse.Body.Close() 56 b, err := ioutil.ReadAll(r.HTTPResponse.Body) 57 if err != nil { 58 r.Error = awserr.New("SerializationError", "failed to decode REST response", err) 59 } else { 60 payload.Set(reflect.ValueOf(b)) 61 } 62 case *string: 63 defer r.HTTPResponse.Body.Close() 64 b, err := ioutil.ReadAll(r.HTTPResponse.Body) 65 if err != nil { 66 r.Error = awserr.New("SerializationError", "failed to decode REST response", err) 67 } else { 68 str := string(b) 69 payload.Set(reflect.ValueOf(&str)) 70 } 71 default: 72 switch payload.Type().String() { 73 case "io.ReadSeeker": 74 payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body))) 75 case "aws.ReadSeekCloser", "io.ReadCloser": 76 payload.Set(reflect.ValueOf(r.HTTPResponse.Body)) 77 default: 78 io.Copy(ioutil.Discard, r.HTTPResponse.Body) 79 defer r.HTTPResponse.Body.Close() 80 r.Error = awserr.New("SerializationError", 81 "failed to decode REST response", 82 fmt.Errorf("unknown payload type %s", payload.Type())) 83 } 84 } 85 } 86 } 87 } 88 } 89} 90 91func unmarshalLocationElements(r *request.Request, v reflect.Value) { 92 for i := 0; i < v.NumField(); i++ { 93 m, field := v.Field(i), v.Type().Field(i) 94 if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) { 95 continue 96 } 97 98 if m.IsValid() { 99 name := field.Tag.Get("locationName") 100 if name == "" { 101 name = field.Name 102 } 103 104 switch field.Tag.Get("location") { 105 case "statusCode": 106 unmarshalStatusCode(m, r.HTTPResponse.StatusCode) 107 case "header": 108 err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name)) 109 if err != nil { 110 r.Error = awserr.New("SerializationError", "failed to decode REST response", err) 111 break 112 } 113 case "headers": 114 prefix := field.Tag.Get("locationName") 115 err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix) 116 if err != nil { 117 r.Error = awserr.New("SerializationError", "failed to decode REST response", err) 118 break 119 } 120 } 121 } 122 if r.Error != nil { 123 return 124 } 125 } 126} 127 128func unmarshalStatusCode(v reflect.Value, statusCode int) { 129 if !v.IsValid() { 130 return 131 } 132 133 switch v.Interface().(type) { 134 case *int64: 135 s := int64(statusCode) 136 v.Set(reflect.ValueOf(&s)) 137 } 138} 139 140func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error { 141 switch r.Interface().(type) { 142 case map[string]*string: // we only support string map value types 143 out := map[string]*string{} 144 for k, v := range headers { 145 k = http.CanonicalHeaderKey(k) 146 if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) { 147 out[k[len(prefix):]] = &v[0] 148 } 149 } 150 r.Set(reflect.ValueOf(out)) 151 } 152 return nil 153} 154 155func unmarshalHeader(v reflect.Value, header string) error { 156 if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) { 157 return nil 158 } 159 160 switch v.Interface().(type) { 161 case *string: 162 v.Set(reflect.ValueOf(&header)) 163 case []byte: 164 b, err := base64.StdEncoding.DecodeString(header) 165 if err != nil { 166 return err 167 } 168 v.Set(reflect.ValueOf(&b)) 169 case *bool: 170 b, err := strconv.ParseBool(header) 171 if err != nil { 172 return err 173 } 174 v.Set(reflect.ValueOf(&b)) 175 case *int64: 176 i, err := strconv.ParseInt(header, 10, 64) 177 if err != nil { 178 return err 179 } 180 v.Set(reflect.ValueOf(&i)) 181 case *float64: 182 f, err := strconv.ParseFloat(header, 64) 183 if err != nil { 184 return err 185 } 186 v.Set(reflect.ValueOf(&f)) 187 case *time.Time: 188 t, err := time.Parse(RFC822, header) 189 if err != nil { 190 return err 191 } 192 v.Set(reflect.ValueOf(&t)) 193 default: 194 err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type()) 195 return err 196 } 197 return nil 198} 199