1package azblob 2 3import ( 4 "bytes" 5 "context" 6 "crypto/hmac" 7 "crypto/sha256" 8 "encoding/base64" 9 "errors" 10 "net/http" 11 "net/url" 12 "sort" 13 "strings" 14 "time" 15 16 "github.com/Azure/azure-pipeline-go/pipeline" 17) 18 19// NewSharedKeyCredential creates an immutable SharedKeyCredential containing the 20// storage account's name and either its primary or secondary key. 21func NewSharedKeyCredential(accountName, accountKey string) (*SharedKeyCredential, error) { 22 bytes, err := base64.StdEncoding.DecodeString(accountKey) 23 if err != nil { 24 return &SharedKeyCredential{}, err 25 } 26 return &SharedKeyCredential{accountName: accountName, accountKey: bytes}, nil 27} 28 29// SharedKeyCredential contains an account's name and its primary or secondary key. 30// It is immutable making it shareable and goroutine-safe. 31type SharedKeyCredential struct { 32 // Only the NewSharedKeyCredential method should set these; all other methods should treat them as read-only 33 accountName string 34 accountKey []byte 35} 36 37// AccountName returns the Storage account's name. 38func (f SharedKeyCredential) AccountName() string { 39 return f.accountName 40} 41 42func (f SharedKeyCredential) getAccountKey() []byte { 43 return f.accountKey 44} 45 46// noop function to satisfy StorageAccountCredential interface 47func (f SharedKeyCredential) getUDKParams() *UserDelegationKey { 48 return nil 49} 50 51// New creates a credential policy object. 52func (f *SharedKeyCredential) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy { 53 return pipeline.PolicyFunc(func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) { 54 // Add a x-ms-date header if it doesn't already exist 55 if d := request.Header.Get(headerXmsDate); d == "" { 56 request.Header[headerXmsDate] = []string{time.Now().UTC().Format(http.TimeFormat)} 57 } 58 stringToSign, err := f.buildStringToSign(request) 59 if err != nil { 60 return nil, err 61 } 62 signature := f.ComputeHMACSHA256(stringToSign) 63 authHeader := strings.Join([]string{"SharedKey ", f.accountName, ":", signature}, "") 64 request.Header[headerAuthorization] = []string{authHeader} 65 66 response, err := next.Do(ctx, request) 67 if err != nil && response != nil && response.Response() != nil && response.Response().StatusCode == http.StatusForbidden { 68 // Service failed to authenticate request, log it 69 po.Log(pipeline.LogError, "===== HTTP Forbidden status, String-to-Sign:\n"+stringToSign+"\n===============================\n") 70 } 71 return response, err 72 }) 73} 74 75// credentialMarker is a package-internal method that exists just to satisfy the Credential interface. 76func (*SharedKeyCredential) credentialMarker() {} 77 78// Constants ensuring that header names are correctly spelled and consistently cased. 79const ( 80 headerAuthorization = "Authorization" 81 headerCacheControl = "Cache-Control" 82 headerContentEncoding = "Content-Encoding" 83 headerContentDisposition = "Content-Disposition" 84 headerContentLanguage = "Content-Language" 85 headerContentLength = "Content-Length" 86 headerContentMD5 = "Content-MD5" 87 headerContentType = "Content-Type" 88 headerDate = "Date" 89 headerIfMatch = "If-Match" 90 headerIfModifiedSince = "If-Modified-Since" 91 headerIfNoneMatch = "If-None-Match" 92 headerIfUnmodifiedSince = "If-Unmodified-Since" 93 headerRange = "Range" 94 headerUserAgent = "User-Agent" 95 headerXmsDate = "x-ms-date" 96 headerXmsVersion = "x-ms-version" 97) 98 99// ComputeHMACSHA256 generates a hash signature for an HTTP request or for a SAS. 100func (f SharedKeyCredential) ComputeHMACSHA256(message string) (base64String string) { 101 h := hmac.New(sha256.New, f.accountKey) 102 h.Write([]byte(message)) 103 return base64.StdEncoding.EncodeToString(h.Sum(nil)) 104} 105 106func (f *SharedKeyCredential) buildStringToSign(request pipeline.Request) (string, error) { 107 // https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services 108 headers := request.Header 109 contentLength := headers.Get(headerContentLength) 110 if contentLength == "0" { 111 contentLength = "" 112 } 113 114 canonicalizedResource, err := f.buildCanonicalizedResource(request.URL) 115 if err != nil { 116 return "", err 117 } 118 119 stringToSign := strings.Join([]string{ 120 request.Method, 121 headers.Get(headerContentEncoding), 122 headers.Get(headerContentLanguage), 123 contentLength, 124 headers.Get(headerContentMD5), 125 headers.Get(headerContentType), 126 "", // Empty date because x-ms-date is expected (as per web page above) 127 headers.Get(headerIfModifiedSince), 128 headers.Get(headerIfMatch), 129 headers.Get(headerIfNoneMatch), 130 headers.Get(headerIfUnmodifiedSince), 131 headers.Get(headerRange), 132 buildCanonicalizedHeader(headers), 133 canonicalizedResource, 134 }, "\n") 135 return stringToSign, nil 136} 137 138func buildCanonicalizedHeader(headers http.Header) string { 139 cm := map[string][]string{} 140 for k, v := range headers { 141 headerName := strings.TrimSpace(strings.ToLower(k)) 142 if strings.HasPrefix(headerName, "x-ms-") { 143 cm[headerName] = v // NOTE: the value must not have any whitespace around it. 144 } 145 } 146 if len(cm) == 0 { 147 return "" 148 } 149 150 keys := make([]string, 0, len(cm)) 151 for key := range cm { 152 keys = append(keys, key) 153 } 154 sort.Strings(keys) 155 ch := bytes.NewBufferString("") 156 for i, key := range keys { 157 if i > 0 { 158 ch.WriteRune('\n') 159 } 160 ch.WriteString(key) 161 ch.WriteRune(':') 162 ch.WriteString(strings.Join(cm[key], ",")) 163 } 164 return string(ch.Bytes()) 165} 166 167func (f *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, error) { 168 // https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services 169 cr := bytes.NewBufferString("/") 170 cr.WriteString(f.accountName) 171 172 if len(u.Path) > 0 { 173 // Any portion of the CanonicalizedResource string that is derived from 174 // the resource's URI should be encoded exactly as it is in the URI. 175 // -- https://msdn.microsoft.com/en-gb/library/azure/dd179428.aspx 176 cr.WriteString(u.EscapedPath()) 177 } else { 178 // a slash is required to indicate the root path 179 cr.WriteString("/") 180 } 181 182 // params is a map[string][]string; param name is key; params values is []string 183 params, err := url.ParseQuery(u.RawQuery) // Returns URL decoded values 184 if err != nil { 185 return "", errors.New("parsing query parameters must succeed, otherwise there might be serious problems in the SDK/generated code") 186 } 187 188 if len(params) > 0 { // There is at least 1 query parameter 189 paramNames := []string{} // We use this to sort the parameter key names 190 for paramName := range params { 191 paramNames = append(paramNames, paramName) // paramNames must be lowercase 192 } 193 sort.Strings(paramNames) 194 195 for _, paramName := range paramNames { 196 paramValues := params[paramName] 197 sort.Strings(paramValues) 198 199 // Join the sorted key values separated by ',' 200 // Then prepend "keyName:"; then add this string to the buffer 201 cr.WriteString("\n" + paramName + ":" + strings.Join(paramValues, ",")) 202 } 203 } 204 return string(cr.Bytes()), nil 205} 206