1// +build go1.13 2 3// Copyright (c) Microsoft Corporation. All rights reserved. 4// Licensed under the MIT License. 5 6package azcore 7 8import ( 9 "bytes" 10 "context" 11 "encoding/base64" 12 "encoding/json" 13 "encoding/xml" 14 "errors" 15 "fmt" 16 "io" 17 "io/ioutil" 18 "mime/multipart" 19 "net/http" 20 "reflect" 21 "strconv" 22 "strings" 23 24 "golang.org/x/net/http/httpguts" 25) 26 27const ( 28 contentTypeAppJSON = "application/json" 29 contentTypeAppXML = "application/xml" 30) 31 32// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when 33// encoding/decoding a slice of bytes to/from a string. 34type Base64Encoding int 35 36const ( 37 // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. 38 Base64StdFormat Base64Encoding = 0 39 40 // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. 41 Base64URLFormat Base64Encoding = 1 42) 43 44// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. 45// Don't use this type directly, use NewRequest() instead. 46type Request struct { 47 *http.Request 48 body ReadSeekCloser 49 policies []Policy 50 values opValues 51} 52 53type opValues map[reflect.Type]interface{} 54 55// Set adds/changes a value 56func (ov opValues) set(value interface{}) { 57 ov[reflect.TypeOf(value)] = value 58} 59 60// Get looks for a value set by SetValue first 61func (ov opValues) get(value interface{}) bool { 62 v, ok := ov[reflect.ValueOf(value).Elem().Type()] 63 if ok { 64 reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v)) 65 } 66 return ok 67} 68 69// JoinPaths concatenates multiple URL path segments into one path, 70// inserting path separation characters as required. 71func JoinPaths(paths ...string) string { 72 if len(paths) == 0 { 73 return "" 74 } 75 path := paths[0] 76 for i := 1; i < len(paths); i++ { 77 if path[len(path)-1] == '/' && paths[i][0] == '/' { 78 // strip off trailing '/' to avoid doubling up 79 path = path[:len(path)-1] 80 } else if path[len(path)-1] != '/' && paths[i][0] != '/' { 81 // add a trailing '/' 82 path = path + "/" 83 } 84 path += paths[i] 85 } 86 return path 87} 88 89// NewRequest creates a new Request with the specified input. 90func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { 91 req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) 92 if err != nil { 93 return nil, err 94 } 95 if req.URL.Host == "" { 96 return nil, errors.New("no Host in request URL") 97 } 98 if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { 99 return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) 100 } 101 return &Request{Request: req}, nil 102} 103 104// Next calls the next policy in the pipeline. 105// If there are no more policies, nil and ErrNoMorePolicies are returned. 106// This method is intended to be called from pipeline policies. 107// To send a request through a pipeline call Pipeline.Do(). 108func (req *Request) Next() (*Response, error) { 109 if len(req.policies) == 0 { 110 return nil, ErrNoMorePolicies 111 } 112 nextPolicy := req.policies[0] 113 nextReq := *req 114 nextReq.policies = nextReq.policies[1:] 115 return nextPolicy.Do(&nextReq) 116} 117 118// MarshalAsByteArray will base-64 encode the byte slice v, then calls SetBody. 119// The encoded value is treated as a JSON string. 120func (req *Request) MarshalAsByteArray(v []byte, format Base64Encoding) error { 121 // send as a JSON string 122 encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) 123 return req.SetBody(NopCloser(strings.NewReader(encode)), contentTypeAppJSON) 124} 125 126// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. 127func (req *Request) MarshalAsJSON(v interface{}) error { 128 v = cloneWithoutReadOnlyFields(v) 129 b, err := json.Marshal(v) 130 if err != nil { 131 return fmt.Errorf("error marshalling type %T: %s", v, err) 132 } 133 return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppJSON) 134} 135 136// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. 137func (req *Request) MarshalAsXML(v interface{}) error { 138 b, err := xml.Marshal(v) 139 if err != nil { 140 return fmt.Errorf("error marshalling type %T: %s", v, err) 141 } 142 return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppXML) 143} 144 145// SetOperationValue adds/changes a mutable key/value associated with a single operation. 146func (req *Request) SetOperationValue(value interface{}) { 147 if req.values == nil { 148 req.values = opValues{} 149 } 150 req.values.set(value) 151} 152 153// OperationValue looks for a value set by SetOperationValue(). 154func (req *Request) OperationValue(value interface{}) bool { 155 if req.values == nil { 156 return false 157 } 158 return req.values.get(value) 159} 160 161// SetBody sets the specified ReadSeekCloser as the HTTP request body. 162func (req *Request) SetBody(body ReadSeekCloser, contentType string) error { 163 // Set the body and content length. 164 size, err := body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size 165 if err != nil { 166 return err 167 } 168 if size == 0 { 169 body.Close() 170 return nil 171 } 172 _, err = body.Seek(0, io.SeekStart) 173 if err != nil { 174 return err 175 } 176 // keep a copy of the original body. this is to handle cases 177 // where req.Body is replaced, e.g. httputil.DumpRequest and friends. 178 req.body = body 179 req.Request.Body = body 180 req.Request.ContentLength = size 181 req.Header.Set(HeaderContentType, contentType) 182 req.Header.Set(HeaderContentLength, strconv.FormatInt(size, 10)) 183 return nil 184} 185 186// SetMultipartFormData writes the specified keys/values as multi-part form 187// fields with the specified value. File content must be specified as a ReadSeekCloser. 188// All other values are treated as string values. 189func (req *Request) SetMultipartFormData(formData map[string]interface{}) error { 190 body := bytes.Buffer{} 191 writer := multipart.NewWriter(&body) 192 for k, v := range formData { 193 if rsc, ok := v.(ReadSeekCloser); ok { 194 // this is the body to upload, the key is its file name 195 fd, err := writer.CreateFormFile(k, k) 196 if err != nil { 197 return err 198 } 199 // copy the data to the form file 200 if _, err = io.Copy(fd, rsc); err != nil { 201 return err 202 } 203 continue 204 } 205 // ensure the value is in string format 206 s, ok := v.(string) 207 if !ok { 208 s = fmt.Sprintf("%v", v) 209 } 210 if err := writer.WriteField(k, s); err != nil { 211 return err 212 } 213 } 214 if err := writer.Close(); err != nil { 215 return err 216 } 217 req.body = NopCloser(bytes.NewReader(body.Bytes())) 218 req.Body = req.body 219 req.ContentLength = int64(body.Len()) 220 req.Header.Set(HeaderContentType, writer.FormDataContentType()) 221 req.Header.Set(HeaderContentLength, strconv.FormatInt(req.ContentLength, 10)) 222 return nil 223} 224 225// SkipBodyDownload will disable automatic downloading of the response body. 226func (req *Request) SkipBodyDownload() { 227 req.SetOperationValue(bodyDownloadPolicyOpValues{skip: true}) 228} 229 230// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. 231func (req *Request) RewindBody() error { 232 if req.body != nil { 233 // Reset the stream back to the beginning and restore the body 234 _, err := req.body.Seek(0, io.SeekStart) 235 req.Body = req.body 236 return err 237 } 238 return nil 239} 240 241// Close closes the request body. 242func (req *Request) Close() error { 243 if req.Body == nil { 244 return nil 245 } 246 return req.Body.Close() 247} 248 249// Telemetry adds telemetry data to the request. 250// If telemetry reporting is disabled the value is discarded. 251func (req *Request) Telemetry(v string) { 252 req.SetOperationValue(requestTelemetry(v)) 253} 254 255// clone returns a deep copy of the request with its context changed to ctx 256func (req *Request) clone(ctx context.Context) *Request { 257 r2 := Request{} 258 r2 = *req 259 r2.Request = req.Request.Clone(ctx) 260 return &r2 261} 262 263// valid returns nil if the underlying http.Request is well-formed. 264func (req *Request) valid() error { 265 // check copied from Transport.roundTrip() 266 for k, vv := range req.Header { 267 if !httpguts.ValidHeaderFieldName(k) { 268 req.Close() 269 return fmt.Errorf("invalid header field name %q", k) 270 } 271 for _, v := range vv { 272 if !httpguts.ValidHeaderFieldValue(v) { 273 req.Close() 274 return fmt.Errorf("invalid header field value %q for key %v", v, k) 275 } 276 } 277 } 278 return nil 279} 280 281// writes to a buffer, used for logging purposes 282func (req *Request) writeBody(b *bytes.Buffer) error { 283 if req.Body == nil { 284 fmt.Fprint(b, " Request contained no body\n") 285 return nil 286 } 287 if ct := req.Header.Get(HeaderContentType); !shouldLogBody(b, ct) { 288 return nil 289 } 290 body, err := ioutil.ReadAll(req.Body) 291 if err != nil { 292 fmt.Fprintf(b, " Failed to read request body: %s\n", err.Error()) 293 return err 294 } 295 if err := req.RewindBody(); err != nil { 296 return err 297 } 298 logBody(b, body) 299 return nil 300} 301 302// EncodeByteArray will base-64 encode the byte slice v. 303func EncodeByteArray(v []byte, format Base64Encoding) string { 304 if format == Base64URLFormat { 305 return base64.RawURLEncoding.EncodeToString(v) 306 } 307 return base64.StdEncoding.EncodeToString(v) 308} 309 310// returns a clone of the object graph pointed to by v, omitting values of all read-only 311// fields. if there are no read-only fields in the object graph, no clone is created. 312func cloneWithoutReadOnlyFields(v interface{}) interface{} { 313 val := reflect.Indirect(reflect.ValueOf(v)) 314 if val.Kind() != reflect.Struct { 315 // not a struct, skip 316 return v 317 } 318 // first walk the graph to find any R/O fields. 319 // if there aren't any, skip cloning the graph. 320 if !recursiveFindReadOnlyField(val) { 321 return v 322 } 323 return recursiveCloneWithoutReadOnlyFields(val) 324} 325 326// returns true if any field in the object graph of val contains the `azure:"ro"` tag value 327func recursiveFindReadOnlyField(val reflect.Value) bool { 328 t := val.Type() 329 // iterate over the fields, looking for the "azure" tag. 330 for i := 0; i < t.NumField(); i++ { 331 field := t.Field(i) 332 aztag := field.Tag.Get("azure") 333 if azureTagIsReadOnly(aztag) { 334 return true 335 } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct && recursiveFindReadOnlyField(reflect.Indirect(val.Field(i))) { 336 return true 337 } 338 } 339 return false 340} 341 342// clones the object graph of val. all non-R/O properties are copied to the clone 343func recursiveCloneWithoutReadOnlyFields(val reflect.Value) interface{} { 344 clone := reflect.New(val.Type()) 345 t := val.Type() 346 // iterate over the fields, looking for the "azure" tag. 347 for i := 0; i < t.NumField(); i++ { 348 field := t.Field(i) 349 aztag := field.Tag.Get("azure") 350 if azureTagIsReadOnly(aztag) { 351 // omit from payload 352 } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct { 353 // recursive case 354 v := recursiveCloneWithoutReadOnlyFields(reflect.Indirect(val.Field(i))) 355 if t.Field(i).Anonymous { 356 // NOTE: this does not handle the case of embedded fields of unexported struct types. 357 // this should be ok as we don't generate any code like this at present 358 reflect.Indirect(clone).Field(i).Set(reflect.Indirect(reflect.ValueOf(v))) 359 } else { 360 reflect.Indirect(clone).Field(i).Set(reflect.ValueOf(v)) 361 } 362 } else { 363 // no azure RO tag, non-recursive case, include in payload 364 reflect.Indirect(clone).Field(i).Set(val.Field(i)) 365 } 366 } 367 return clone.Interface() 368} 369 370// returns true if the "azure" tag contains the option "ro" 371func azureTagIsReadOnly(tag string) bool { 372 if tag == "" { 373 return false 374 } 375 parts := strings.Split(tag, ",") 376 for _, part := range parts { 377 if part == "ro" { 378 return true 379 } 380 } 381 return false 382} 383 384func logBody(b *bytes.Buffer, body []byte) { 385 fmt.Fprintln(b, " --------------------------------------------------------------------------------") 386 fmt.Fprintln(b, string(body)) 387 fmt.Fprintln(b, " --------------------------------------------------------------------------------") 388} 389