1package azblob_test
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"errors"
8	"fmt"
9	"io/ioutil"
10	"net/url"
11	"os"
12	"reflect"
13	"runtime"
14	"strings"
15	"testing"
16	"time"
17
18	chk "gopkg.in/check.v1"
19
20	"math/rand"
21
22	"github.com/Azure/azure-pipeline-go/pipeline"
23	"github.com/Azure/azure-storage-blob-go/azblob"
24	"github.com/Azure/go-autorest/autorest/adal"
25)
26
27// For testing docs, see: https://labix.org/gocheck
28// To test a specific test: go test -check.f MyTestSuite
29
// Test is the single `go test` entry point for this package: it hands control
// to gocheck, which then runs the registered suites.
func Test(t *testing.T) {
	chk.TestingT(t)
}
32
// aztestsSuite is the gocheck suite type; the Test* methods of this package hang off it.
type aztestsSuite struct{}

// Register an aztestsSuite instance with gocheck so chk.TestingT runs its test methods.
var _ = chk.Suite(&aztestsSuite{})
36
// TestRetryPolicyRetryReadsFromSecondaryHostField guards against the
// RetryReadsFromSecondaryHost field being dropped from azblob.RetryOptions
// (for example when that file is regenerated or overwritten).
func (s *aztestsSuite) TestRetryPolicyRetryReadsFromSecondaryHostField(c *chk.C) {
	optionsType := reflect.TypeOf(azblob.RetryOptions{})
	if _, ok := optionsType.FieldByName("RetryReadsFromSecondaryHost"); ok {
		return
	}
	c.Fatal("RetryOption's RetryReadsFromSecondaryHost field must exist in the Blob SDK - uncomment it and make sure the field is returned from the retryReadsFromSecondaryHost() method too!")
}
44
// Fixed names, payloads and error fragments shared by the helpers and tests in this file.
const (
	// containerPrefix starts every name produced by generateContainerName.
	containerPrefix             = "go"
	// blobPrefix starts every name produced by generateBlobName.
	blobPrefix                  = "gotestblob"
	// blockBlobDefaultData is the content uploaded by createNewBlockBlob and createBlockBlobWithPrefix.
	blockBlobDefaultData        = "GoBlockBlobData"
	// validationErrorSubstring is matched against error messages by tests.
	// NOTE(review): not used in this section; presumably identifies SDK-side parameter validation errors — confirm.
	validationErrorSubstring    = "validation failed"
	invalidHeaderErrorSubstring = "invalid header field" // error thrown by the http client
)
52
// ctx is the context passed to the service calls made by the helpers in this file.
var ctx = context.Background()
// basicHeaders holds sample BlobHTTPHeaders values (ContentMD5 left nil)
// shared by tests that need non-default headers.
var basicHeaders = azblob.BlobHTTPHeaders{
	ContentType:        "my_type",
	ContentDisposition: "my_disposition",
	CacheControl:       "control",
	ContentMD5:         nil,
	ContentLanguage:    "my_language",
	ContentEncoding:    "my_encoding",
}

// basicMetadata is a one-entry metadata map shared by tests.
// NOTE(review): it is a package-level map, so tests must not mutate it.
var basicMetadata = azblob.Metadata{"foo": "bar"}
64
65type testPipeline struct{}
66
67const testPipelineMessage string = "Test factory invoked"
68
69func (tm testPipeline) Do(ctx context.Context, methodFactory pipeline.Factory, request pipeline.Request) (pipeline.Response, error) {
70	return nil, errors.New(testPipelineMessage)
71}
72
// generateName builds an entity name by concatenating prefix, the lowercased
// name of the gocheck test that (directly or via helpers) requested it, and the
// minute, second and nanosecond of the call. This makes it easy to associate
// entities with their test, to tell them apart, and to see the order in which
// they were created.
//
// The test is located by walking up the call stack to the first function whose
// name contains "Suite" and keeping whatever follows "Test" in that name. If
// no such frame exists, or it has no "Test" in its name, the test-name part is
// left empty.
//
// Note that this imposes a restriction on the length of test names, since the
// test name becomes part of the resource name.
func generateName(prefix string) string {
	// Decode only the n PCs actually filled in by runtime.Callers; the unused
	// slots of the buffer are not real frames.
	pcs := make([]uintptr, 32)
	n := runtime.Callers(0, pcs)
	frames := runtime.CallersFrames(pcs[:n])

	// Standard Frames iteration: check the frame first, then stop when there
	// are no more. This also inspects the final frame, which a
	// `for f, more := Next(); more; ...` loop would skip.
	suiteFunc := ""
	for {
		frame, more := frames.Next()
		if strings.Contains(frame.Function, "Suite") {
			suiteFunc = frame.Function
			break
		}
		if !more {
			break
		}
	}

	// Keep only the test's own name (drop package path and receiver). Guard
	// the index so a missing "Test" cannot slice from a bogus offset.
	testName := ""
	if i := strings.Index(suiteFunc, "Test"); i >= 0 {
		testName = strings.ToLower(suiteFunc[i+len("Test"):]) // ensure it is a valid resource name
	}

	now := time.Now()
	return fmt.Sprintf("%s%s%d%d%d", prefix, testName, now.Minute(), now.Second(), now.Nanosecond())
}
98
99func generateContainerName() string {
100	return generateName(containerPrefix)
101}
102
103func generateBlobName() string {
104	return generateName(blobPrefix)
105}
106
107func getContainerURL(c *chk.C, bsu azblob.ServiceURL) (container azblob.ContainerURL, name string) {
108	name = generateContainerName()
109	container = bsu.NewContainerURL(name)
110
111	return container, name
112}
113
114func getBlockBlobURL(c *chk.C, container azblob.ContainerURL) (blob azblob.BlockBlobURL, name string) {
115	name = generateBlobName()
116	blob = container.NewBlockBlobURL(name)
117
118	return blob, name
119}
120
121func getAppendBlobURL(c *chk.C, container azblob.ContainerURL) (blob azblob.AppendBlobURL, name string) {
122	name = generateBlobName()
123	blob = container.NewAppendBlobURL(name)
124
125	return blob, name
126}
127
128func getPageBlobURL(c *chk.C, container azblob.ContainerURL) (blob azblob.PageBlobURL, name string) {
129	name = generateBlobName()
130	blob = container.NewPageBlobURL(name)
131
132	return
133}
134
// getReaderToRandomBytes returns a reader over n pseudo-random bytes; the
// backing slice itself is not exposed.
func getReaderToRandomBytes(n int) *bytes.Reader {
	reader, _ := getRandomDataAndReader(n)
	return reader
}
139
// getRandomDataAndReader fills a new n-byte slice from math/rand and returns a
// reader positioned at its start together with the slice itself.
func getRandomDataAndReader(n int) (*bytes.Reader, []byte) {
	buf := make([]byte, n)
	_, _ = rand.Read(buf) // math/rand's Read always returns len(buf), nil
	return bytes.NewReader(buf), buf
}
145
146func createNewContainer(c *chk.C, bsu azblob.ServiceURL) (container azblob.ContainerURL, name string) {
147	container, name = getContainerURL(c, bsu)
148
149	cResp, err := container.Create(ctx, nil, azblob.PublicAccessNone)
150	c.Assert(err, chk.IsNil)
151	c.Assert(cResp.StatusCode(), chk.Equals, 201)
152	return container, name
153}
154
155func createNewContainerWithSuffix(c *chk.C, bsu azblob.ServiceURL, suffix string) (container azblob.ContainerURL, name string) {
156	// The goal of adding the suffix is to be able to predetermine what order the containers will be in when listed.
157	// We still need the container prefix to come first, though, to ensure only containers as a part of this test
158	// are listed at all.
159	name = generateName(containerPrefix + suffix)
160	container = bsu.NewContainerURL(name)
161
162	cResp, err := container.Create(ctx, nil, azblob.PublicAccessNone)
163	c.Assert(err, chk.IsNil)
164	c.Assert(cResp.StatusCode(), chk.Equals, 201)
165	return container, name
166}
167
// createNewBlockBlob uploads blockBlobDefaultData to a new block blob with a
// generated name in container, failing the test unless the upload returns 201.
func createNewBlockBlob(c *chk.C, container azblob.ContainerURL) (blob azblob.BlockBlobURL, name string) {
	blob, name = getBlockBlobURL(c, container)

	body := strings.NewReader(blockBlobDefaultData)
	uploadResp, err := blob.Upload(ctx, body, azblob.BlobHTTPHeaders{}, nil, azblob.BlobAccessConditions{})
	c.Assert(err, chk.IsNil)
	c.Assert(uploadResp.StatusCode(), chk.Equals, 201)

	return blob, name
}
179
180func createNewAppendBlob(c *chk.C, container azblob.ContainerURL) (blob azblob.AppendBlobURL, name string) {
181	blob, name = getAppendBlobURL(c, container)
182
183	resp, err := blob.Create(ctx, azblob.BlobHTTPHeaders{}, nil, azblob.BlobAccessConditions{})
184
185	c.Assert(err, chk.IsNil)
186	c.Assert(resp.StatusCode(), chk.Equals, 201)
187	return
188}
189
190func createNewPageBlob(c *chk.C, container azblob.ContainerURL) (blob azblob.PageBlobURL, name string) {
191	blob, name = getPageBlobURL(c, container)
192
193	resp, err := blob.Create(ctx, azblob.PageBlobPageBytes*10, 0, azblob.BlobHTTPHeaders{}, nil, azblob.BlobAccessConditions{})
194	c.Assert(err, chk.IsNil)
195	c.Assert(resp.StatusCode(), chk.Equals, 201)
196	return
197}
198
199func createNewPageBlobWithSize(c *chk.C, container azblob.ContainerURL, sizeInBytes int64) (blob azblob.PageBlobURL, name string) {
200	blob, name = getPageBlobURL(c, container)
201
202	resp, err := blob.Create(ctx, sizeInBytes, 0, azblob.BlobHTTPHeaders{}, nil, azblob.BlobAccessConditions{})
203
204	c.Assert(err, chk.IsNil)
205	c.Assert(resp.StatusCode(), chk.Equals, 201)
206	return
207}
208
// createBlockBlobWithPrefix uploads blockBlobDefaultData to a new block blob
// named prefix followed by a generated blob name, failing the test unless the
// upload returns 201.
func createBlockBlobWithPrefix(c *chk.C, container azblob.ContainerURL, prefix string) (blob azblob.BlockBlobURL, name string) {
	generated := generateName(blobPrefix)
	name = prefix + generated
	blob = container.NewBlockBlobURL(name)

	body := strings.NewReader(blockBlobDefaultData)
	uploadResp, err := blob.Upload(ctx, body, azblob.BlobHTTPHeaders{}, nil, azblob.BlobAccessConditions{})
	c.Assert(err, chk.IsNil)
	c.Assert(uploadResp.StatusCode(), chk.Equals, 201)

	return blob, name
}
220
221func deleteContainer(c *chk.C, container azblob.ContainerURL) {
222	resp, err := container.Delete(ctx, azblob.ContainerAccessConditions{})
223	c.Assert(err, chk.IsNil)
224	c.Assert(resp.StatusCode(), chk.Equals, 202)
225}
226
// getGenericCredential builds a shared-key credential from the environment
// variables <accountType>ACCOUNT_NAME and <accountType>ACCOUNT_KEY, where
// accountType is a prefix such as "" or "PREMIUM_". It returns an error when
// either variable is unset or empty.
func getGenericCredential(accountType string) (*azblob.SharedKeyCredential, error) {
	nameVar, keyVar := accountType+"ACCOUNT_NAME", accountType+"ACCOUNT_KEY"
	accountName := os.Getenv(nameVar)
	accountKey := os.Getenv(keyVar)
	if accountName != "" && accountKey != "" {
		return azblob.NewSharedKeyCredential(accountName, accountKey)
	}
	return nil, errors.New(nameVar + " and/or " + keyVar + " environment variables not specified.")
}
236
// getOAuthCredential builds an Azure Storage OAuth token credential from
// environment variables prefixed with accountType, in one of two ways:
//   - Direct: a JSON-serialized ADAL token in <accountType>OAUTH_TOKEN plus the
//     application ID in <accountType>APPLICATION_ID, used to refresh that token.
//   - Client secret: <accountType>CLIENT_SECRET plus <accountType>APPLICATION_ID,
//     used for service-principal authentication.
//
// If OAUTH_TOKEN is set it takes precedence over CLIENT_SECRET.
// <accountType>TENANT_ID is optional and defaults to "common".
//
// The returned credential is refreshed through azblob's refresher callback.
// Each successful renewal of the service-principal token is pushed into the
// credential with SetToken.
func getOAuthCredential(accountType string) (*azblob.TokenCredential, error) {
	oauthTokenEnvVar := accountType + "OAUTH_TOKEN"
	clientSecretEnvVar := accountType + "CLIENT_SECRET"
	applicationIDEnvVar := accountType + "APPLICATION_ID"
	tenantIDEnvVar := accountType + "TENANT_ID"

	oauthToken := []byte(os.Getenv(oauthTokenEnvVar))
	appID := os.Getenv(applicationIDEnvVar)
	tenantID := os.Getenv(tenantIDEnvVar)
	clientSecret := os.Getenv(clientSecretEnvVar)

	if (len(oauthToken) == 0 && clientSecret == "") || appID == "" {
		return nil, errors.New("(" + oauthTokenEnvVar + " OR " + clientSecretEnvVar + ") and/or " + applicationIDEnvVar + " environment variables not specified.")
	}
	if tenantID == "" {
		tenantID = "common"
	}

	var token adal.Token
	if len(oauthToken) != 0 {
		if err := json.Unmarshal(oauthToken, &token); err != nil {
			return nil, err
		}
	}

	oauthConfig, err := adal.NewOAuthConfig("https://login.microsoftonline.com", tenantID)
	if err != nil {
		return nil, err
	}

	var spt *adal.ServicePrincipalToken
	if len(oauthToken) == 0 {
		spt, err = adal.NewServicePrincipalToken(*oauthConfig, appID, clientSecret, "https://storage.azure.com")
	} else {
		spt, err = adal.NewServicePrincipalTokenFromManualToken(*oauthConfig, appID, "https://storage.azure.com", token)
	}
	if err != nil {
		return nil, err
	}

	if err := spt.Refresh(); err != nil {
		return nil, err
	}

	tc := azblob.NewTokenCredential(spt.Token().AccessToken, func(cred azblob.TokenCredential) time.Duration {
		// Renewing spt alone is not enough: the credential holds its own copy of
		// the access token, which must be replaced or requests keep using the
		// original (eventually expired) token.
		if err := spt.Refresh(); err == nil {
			cred.SetToken(spt.Token().AccessToken)
		}
		// A failed refresh is tolerated, as before: the credential keeps its
		// current token. The returned duration (time until spt's token expires)
		// tells azblob when to invoke this refresher again.
		return time.Until(spt.Token().Expires())
	})

	return &tc, nil
}
300
// getGenericBSU returns a ServiceURL for https://<account>.blob.core.windows.net/.
// The account's shared-key credentials are read from the accountType-prefixed
// ACCOUNT_NAME / ACCOUNT_KEY environment variables. On any error it returns
// the zero ServiceURL together with that error.
func getGenericBSU(accountType string) (azblob.ServiceURL, error) {
	credential, err := getGenericCredential(accountType)
	if err != nil {
		return azblob.ServiceURL{}, err
	}

	// Checked rather than ignored: a malformed account name makes url.Parse
	// return a nil *url.URL, which would panic on the dereference below.
	blobPrimaryURL, err := url.Parse("https://" + credential.AccountName() + ".blob.core.windows.net/")
	if err != nil {
		return azblob.ServiceURL{}, err
	}

	// Named p so it does not shadow the imported pipeline package.
	p := azblob.NewPipeline(credential, azblob.PipelineOptions{})
	return azblob.NewServiceURL(*blobPrimaryURL, p), nil
}
311
312func getBSU() azblob.ServiceURL {
313	bsu, _ := getGenericBSU("")
314	return bsu
315}
316
317func getAlternateBSU() (azblob.ServiceURL, error) {
318	return getGenericBSU("SECONDARY_")
319}
320
321func getPremiumBSU() (azblob.ServiceURL, error) {
322	return getGenericBSU("PREMIUM_")
323}
324
325func getBlobStorageBSU() (azblob.ServiceURL, error) {
326	return getGenericBSU("BLOB_STORAGE_")
327}
328
329func validateStorageError(c *chk.C, err error, code azblob.ServiceCodeType) {
330	serr, _ := err.(azblob.StorageError)
331	c.Assert(serr.ServiceCode(), chk.Equals, code)
332}
333
// getRelativeTimeGMT returns the current time shifted by amount seconds,
// expressed in a fixed zone named "GMT" with zero offset. Note that amount is
// a count of seconds (it is multiplied by time.Second), not an ordinary
// time.Duration: pass 10 for ten seconds, not 10*time.Second.
func getRelativeTimeGMT(amount time.Duration) time.Time {
	gmt := time.FixedZone("GMT", 0)
	return time.Now().Add(amount * time.Second).In(gmt)
}
339
// generateCurrentTimeWithModerateResolution returns the current UTC time with
// its sub-second part dropped (whole-second resolution).
func generateCurrentTimeWithModerateResolution() time.Time {
	return time.Now().UTC().Truncate(time.Second)
}
345
// runTestRequiringServiceProperties runs a test that depends on a service
// property. New service properties can take up to 30 seconds to be reflected
// across all front ends, so this function works in four steps:
//  1. It enables the property.
//  2. It runs testImplFunc.
//  3. If that attempt fails with exactly the error string code (the sign that
//     the change has not propagated yet), it waits 30 seconds and runs
//     testImplFunc once more.
//  4. It disables the property again when it returns.
//
// The final attempt must succeed.
//
// It is the responsibility of testImplFunc to decide which error string
// indicates the test should be retried. There can only be one such string.
// All errors that cannot be due to this propagation delay should be asserted
// inside testImplFunc rather than returned. Any other error that is returned
// anyway fails the test instead of being silently ignored.
func runTestRequiringServiceProperties(c *chk.C, bsu azblob.ServiceURL, code string,
	enableServicePropertyFunc func(*chk.C, azblob.ServiceURL),
	testImplFunc func(*chk.C, azblob.ServiceURL) error,
	disableServicePropertyFunc func(*chk.C, azblob.ServiceURL)) {
	enableServicePropertyFunc(c, bsu)
	defer disableServicePropertyFunc(c, bsu)

	err := testImplFunc(c, bsu)
	// We cannot assume that the error indicative of slow update will necessarily be a StorageError (as in ListBlobs), so match on the message.
	if err != nil && err.Error() == code {
		time.Sleep(time.Second * 30)
		err = testImplFunc(c, bsu)
	}
	// Asserted unconditionally so an unexpected first-attempt error is not dropped.
	c.Assert(err, chk.IsNil)
}
365
366func enableSoftDelete(c *chk.C, bsu azblob.ServiceURL) {
367	days := int32(1)
368	_, err := bsu.SetProperties(ctx, azblob.StorageServiceProperties{DeleteRetentionPolicy: &azblob.RetentionPolicy{Enabled: true, Days: &days}})
369	c.Assert(err, chk.IsNil)
370}
371
372func disableSoftDelete(c *chk.C, bsu azblob.ServiceURL) {
373	_, err := bsu.SetProperties(ctx, azblob.StorageServiceProperties{DeleteRetentionPolicy: &azblob.RetentionPolicy{Enabled: false}})
374	c.Assert(err, chk.IsNil)
375}
376
// validateUpload downloads the blob at blobURL (offset 0, count 0) and asserts
// that its content is empty.
// NOTE(review): despite the name, this only checks for a zero-length blob.
func validateUpload(c *chk.C, blobURL azblob.BlockBlobURL) {
	resp, err := blobURL.Download(ctx, 0, 0, azblob.BlobAccessConditions{}, false)
	c.Assert(err, chk.IsNil)

	body := resp.Response().Body
	// Close the body so the HTTP connection is released; it was previously leaked.
	defer body.Close()

	data, err := ioutil.ReadAll(body)
	c.Assert(err, chk.IsNil)
	c.Assert(data, chk.HasLen, 0)
}
383