1package azblob
2
3import (
4	"bytes"
5	"context"
6	"crypto/hmac"
7	"crypto/sha256"
8	"encoding/base64"
9	"errors"
10	"net/http"
11	"net/url"
12	"sort"
13	"strings"
14	"time"
15
16	"github.com/Azure/azure-pipeline-go/pipeline"
17)
18
19// NewSharedKeyCredential creates an immutable SharedKeyCredential containing the
20// storage account's name and either its primary or secondary key.
21func NewSharedKeyCredential(accountName, accountKey string) (*SharedKeyCredential, error) {
22	bytes, err := base64.StdEncoding.DecodeString(accountKey)
23	if err != nil {
24		return &SharedKeyCredential{}, err
25	}
26	return &SharedKeyCredential{accountName: accountName, accountKey: bytes}, nil
27}
28
// SharedKeyCredential contains an account's name and its primary or secondary key.
// It is immutable making it shareable and goroutine-safe.
type SharedKeyCredential struct {
	// Only the NewSharedKeyCredential method should set these; all other methods should treat them as read-only

	// accountName is the storage account name; it appears in the Authorization
	// header and at the start of the canonicalized resource.
	accountName string
	// accountKey holds the key bytes obtained by base64-decoding the key string
	// given to NewSharedKeyCredential; it is the HMAC-SHA256 signing key.
	accountKey  []byte
}
36
// AccountName returns the Storage account's name, exactly as it was supplied
// when the credential was created.
func (f SharedKeyCredential) AccountName() string {
	return f.accountName
}
41
// getAccountKey returns the decoded account key bytes held by the credential.
// The slice is returned as-is (not copied), so callers must treat it as
// read-only.
func (f SharedKeyCredential) getAccountKey() []byte {
	return f.accountKey
}
45
// getUDKParams is a no-op that always returns nil; it exists only to satisfy
// the StorageAccountCredential interface.
func (f SharedKeyCredential) getUDKParams() *UserDelegationKey {
	return nil
}
50
51// New creates a credential policy object.
52func (f *SharedKeyCredential) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
53	return pipeline.PolicyFunc(func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
54		// Add a x-ms-date header if it doesn't already exist
55		if d := request.Header.Get(headerXmsDate); d == "" {
56			request.Header[headerXmsDate] = []string{time.Now().UTC().Format(http.TimeFormat)}
57		}
58		stringToSign, err := f.buildStringToSign(request)
59		if err != nil {
60			return nil, err
61		}
62		signature := f.ComputeHMACSHA256(stringToSign)
63		authHeader := strings.Join([]string{"SharedKey ", f.accountName, ":", signature}, "")
64		request.Header[headerAuthorization] = []string{authHeader}
65
66		response, err := next.Do(ctx, request)
67		if err != nil && response != nil && response.Response() != nil && response.Response().StatusCode == http.StatusForbidden {
68			// Service failed to authenticate request, log it
69			po.Log(pipeline.LogError, "===== HTTP Forbidden status, String-to-Sign:\n"+stringToSign+"\n===============================\n")
70		}
71		return response, err
72	})
73}
74
// credentialMarker is a package-internal method that exists just to satisfy the Credential interface.
// It has an empty body and does nothing when called.
func (*SharedKeyCredential) credentialMarker() {}
77
// Constants ensuring that header names are correctly spelled and consistently cased.
// The standard HTTP header names are written in canonical MIME form, while the
// x-ms-* names are written in lowercase.
const (
	headerAuthorization      = "Authorization"
	headerCacheControl       = "Cache-Control"
	headerContentEncoding    = "Content-Encoding"
	headerContentDisposition = "Content-Disposition"
	headerContentLanguage    = "Content-Language"
	headerContentLength      = "Content-Length"
	headerContentMD5         = "Content-MD5"
	headerContentType        = "Content-Type"
	headerDate               = "Date"
	headerIfMatch            = "If-Match"
	headerIfModifiedSince    = "If-Modified-Since"
	headerIfNoneMatch        = "If-None-Match"
	headerIfUnmodifiedSince  = "If-Unmodified-Since"
	headerRange              = "Range"
	headerUserAgent          = "User-Agent"
	headerXmsDate            = "x-ms-date"
	headerXmsVersion         = "x-ms-version"
)
98
99// ComputeHMACSHA256 generates a hash signature for an HTTP request or for a SAS.
100func (f SharedKeyCredential) ComputeHMACSHA256(message string) (base64String string) {
101	h := hmac.New(sha256.New, f.accountKey)
102	h.Write([]byte(message))
103	return base64.StdEncoding.EncodeToString(h.Sum(nil))
104}
105
106func (f *SharedKeyCredential) buildStringToSign(request pipeline.Request) (string, error) {
107	// https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services
108	headers := request.Header
109	contentLength := headers.Get(headerContentLength)
110	if contentLength == "0" {
111		contentLength = ""
112	}
113
114	canonicalizedResource, err := f.buildCanonicalizedResource(request.URL)
115	if err != nil {
116		return "", err
117	}
118
119	stringToSign := strings.Join([]string{
120		request.Method,
121		headers.Get(headerContentEncoding),
122		headers.Get(headerContentLanguage),
123		contentLength,
124		headers.Get(headerContentMD5),
125		headers.Get(headerContentType),
126		"", // Empty date because x-ms-date is expected (as per web page above)
127		headers.Get(headerIfModifiedSince),
128		headers.Get(headerIfMatch),
129		headers.Get(headerIfNoneMatch),
130		headers.Get(headerIfUnmodifiedSince),
131		headers.Get(headerRange),
132		buildCanonicalizedHeader(headers),
133		canonicalizedResource,
134	}, "\n")
135	return stringToSign, nil
136}
137
// buildCanonicalizedHeader returns the canonicalized form of the x-ms-*
// headers in headers, as used in the SharedKey string-to-sign.
//
// Header names are lowercased and trimmed, and only names starting with
// "x-ms-" are kept. Each value is trimmed of surrounding whitespace, the
// values of one header are joined with ",", and headers are emitted as
// "name:values" lines sorted by name and separated by "\n" (no trailing
// newline). The empty string is returned when there are no x-ms-* headers.
//
// Keys that differ only in case (possible when the map is written directly
// rather than through Header.Set) are merged into a single entry. The merge
// follows the byte-wise sorted order of the original keys so that the result
// does not depend on map iteration order.
func buildCanonicalizedHeader(headers http.Header) string {
	// Collect the original keys of the x-ms-* headers and sort them so the
	// merge below is reproducible.
	rawNames := make([]string, 0, len(headers))
	for rawName := range headers {
		if strings.HasPrefix(strings.TrimSpace(strings.ToLower(rawName)), "x-ms-") {
			rawNames = append(rawNames, rawName)
		}
	}
	if len(rawNames) == 0 {
		return ""
	}
	sort.Strings(rawNames)

	merged := make(map[string][]string, len(rawNames))
	for _, rawName := range rawNames {
		name := strings.TrimSpace(strings.ToLower(rawName))
		values := merged[name]
		for _, v := range headers[rawName] {
			// The signed value must not have any whitespace around it.
			values = append(values, strings.TrimSpace(v))
		}
		// Assign even when there are no values so the header is still listed.
		merged[name] = values
	}

	names := make([]string, 0, len(merged))
	for name := range merged {
		names = append(names, name)
	}
	sort.Strings(names)

	var ch strings.Builder
	for i, name := range names {
		if i > 0 {
			ch.WriteByte('\n')
		}
		ch.WriteString(name)
		ch.WriteByte(':')
		ch.WriteString(strings.Join(merged[name], ","))
	}
	return ch.String()
}
166
167func (f *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, error) {
168	// https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services
169	cr := bytes.NewBufferString("/")
170	cr.WriteString(f.accountName)
171
172	if len(u.Path) > 0 {
173		// Any portion of the CanonicalizedResource string that is derived from
174		// the resource's URI should be encoded exactly as it is in the URI.
175		// -- https://msdn.microsoft.com/en-gb/library/azure/dd179428.aspx
176		cr.WriteString(u.EscapedPath())
177	} else {
178		// a slash is required to indicate the root path
179		cr.WriteString("/")
180	}
181
182	// params is a map[string][]string; param name is key; params values is []string
183	params, err := url.ParseQuery(u.RawQuery) // Returns URL decoded values
184	if err != nil {
185		return "", errors.New("parsing query parameters must succeed, otherwise there might be serious problems in the SDK/generated code")
186	}
187
188	if len(params) > 0 { // There is at least 1 query parameter
189		paramNames := []string{} // We use this to sort the parameter key names
190		for paramName := range params {
191			paramNames = append(paramNames, paramName) // paramNames must be lowercase
192		}
193		sort.Strings(paramNames)
194
195		for _, paramName := range paramNames {
196			paramValues := params[paramName]
197			sort.Strings(paramValues)
198
199			// Join the sorted key values separated by ','
200			// Then prepend "keyName:"; then add this string to the buffer
201			cr.WriteString("\n" + paramName + ":" + strings.Join(paramValues, ","))
202		}
203	}
204	return string(cr.Bytes()), nil
205}
206