1// Copyright (c) 2021 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5import (
6	"context"
7	"encoding/json"
8	"errors"
9	"fmt"
10	"io/ioutil"
11	"net/http"
12	"net/url"
13	"os"
14	"strings"
15	"time"
16
17	"github.com/Azure/azure-storage-blob-go/azblob"
18)
19
// snowflakeAzureUtil is the receiver for the Azure Blob Storage helpers
// defined in this file. It has no fields.
type snowflakeAzureUtil struct{}
22
// azureLocation is a stage location split into its container and the path
// prefix inside that container, as produced by extractContainerNameAndPath.
type azureLocation struct {
	containerName string // portion of the location before the first "/" (the whole location if it has none)
	path          string // portion after the first "/"; when non-empty it always ends with "/"
}
27
28func (util *snowflakeAzureUtil) createClient(info *execResponseStageInfo, _ bool) cloudClient {
29	sasToken := info.Creds.AzureSasToken
30	p := azblob.NewPipeline(azblob.NewAnonymousCredential(), azblob.PipelineOptions{
31		Retry: azblob.RetryOptions{
32			Policy:     azblob.RetryPolicyExponential,
33			MaxTries:   60,
34			RetryDelay: 2 * time.Second,
35		},
36	})
37
38	u, _ := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken))
39	containerURL := azblob.NewContainerURL(*u, p)
40	return &containerURL
41}
42
43// cloudUtil implementation
44func (util *snowflakeAzureUtil) getFileHeader(meta *fileMetadata, filename string) *fileHeader {
45	container, ok := meta.client.(*azblob.ContainerURL)
46	if !ok {
47		return nil
48	}
49
50	azureLoc := util.extractContainerNameAndPath(meta.stageInfo.Location)
51	path := azureLoc.path + strings.TrimLeft(filename, "/")
52	b := container.NewBlockBlobURL(path)
53	resp, err := b.GetProperties(context.Background(), azblob.BlobAccessConditions{}, azblob.ClientProvidedKeyOptions{})
54	if err != nil {
55		var se azblob.StorageError
56		if errors.As(err, &se) {
57			if se.ServiceCode() == azblob.ServiceCodeBlobNotFound {
58				meta.resStatus = notFoundFile
59				return &fileHeader{}
60			} else if se.Response().StatusCode == 403 {
61				meta.resStatus = renewToken
62				return nil
63			}
64		}
65		meta.resStatus = errStatus
66		return nil
67	}
68
69	meta.resStatus = uploaded
70	metadata := resp.NewMetadata()
71	var encData encryptionData
72	if err = json.Unmarshal([]byte(metadata["encryptiondata"]), &encData); err != nil {
73		return nil
74	}
75	encryptionMetadata := encryptMetadata{
76		encData.WrappedContentKey.EncryptionKey,
77		encData.ContentEncryptionIV,
78		metadata["matdesc"],
79	}
80
81	return &fileHeader{
82		metadata["sfcdigest"],
83		int64(len(metadata)),
84		&encryptionMetadata,
85	}
86}
87
88// cloudUtil implementation
89func (util *snowflakeAzureUtil) uploadFile(
90	dataFile string,
91	meta *fileMetadata,
92	encryptMeta *encryptMetadata,
93	maxConcurrency int,
94	multiPartThreshold int64) error {
95	azureMeta := map[string]string{
96		"sfcdigest": meta.sha256Digest,
97	}
98	if encryptMeta != nil {
99		ed := &encryptionData{
100			EncryptionMode: "FullBlob",
101			WrappedContentKey: contentKey{
102				"symmKey1",
103				encryptMeta.key,
104				"AES_CBC_256",
105			},
106			EncryptionAgent: encryptionAgent{
107				"1.0",
108				"AES_CBC_128",
109			},
110			ContentEncryptionIV: encryptMeta.iv,
111			KeyWrappingMetadata: keyMetadata{
112				"Java 5.3.0",
113			},
114		}
115		metadata, _ := json.Marshal(ed)
116		azureMeta["encryptiondata"] = string(metadata)
117		azureMeta["matdesc"] = encryptMeta.matdesc
118	}
119
120	azureLoc := util.extractContainerNameAndPath(meta.stageInfo.Location)
121	path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/")
122	azContainerURL, ok := meta.client.(*azblob.ContainerURL)
123	if !ok {
124		return &SnowflakeError{
125			Message: "failed to cast to azure client",
126		}
127	}
128
129	var err error
130	blobURL := azContainerURL.NewBlockBlobURL(path)
131	if meta.srcStream != nil {
132		uploadSrc := meta.srcStream
133		if meta.realSrcStream != nil {
134			uploadSrc = meta.realSrcStream
135		}
136		_, err = azblob.UploadStreamToBlockBlob(context.Background(), uploadSrc, blobURL, azblob.UploadStreamToBlockBlobOptions{
137			BufferSize: uploadSrc.Len(),
138			Metadata:   azureMeta,
139		})
140	} else {
141		f, _ := os.OpenFile(dataFile, os.O_RDONLY, os.ModePerm)
142		defer f.Close()
143		fi, _ := f.Stat()
144		_, err = azblob.UploadFileToBlockBlob(context.Background(), f, blobURL, azblob.UploadToBlockBlobOptions{
145			BlockSize:   fi.Size(),
146			Parallelism: uint16(maxConcurrency),
147			Metadata:    azureMeta,
148		})
149	}
150	if err != nil {
151		var se azblob.StorageError
152		if errors.As(err, &se) {
153			if se.Response().StatusCode == 403 && util.detectAzureTokenExpireError(se.Response()) {
154				meta.resStatus = renewToken
155			} else {
156				meta.resStatus = needRetry
157				meta.lastError = err
158			}
159			return err
160		}
161		meta.resStatus = errStatus
162		return err
163	}
164
165	meta.dstFileSize = meta.uploadSize
166	meta.resStatus = uploaded
167	return nil
168}
169
170// cloudUtil implementation
171func (util *snowflakeAzureUtil) nativeDownloadFile(
172	meta *fileMetadata,
173	fullDstFileName string,
174	maxConcurrency int64) error {
175	azureLoc := util.extractContainerNameAndPath(meta.stageInfo.Location)
176	path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/")
177	azContainerURL, ok := meta.client.(*azblob.ContainerURL)
178	if !ok {
179		return &SnowflakeError{
180			Message: "failed to cast to azure client",
181		}
182	}
183
184	f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, os.ModePerm)
185	if err != nil {
186		return err
187	}
188	defer f.Close()
189	blobURL := azContainerURL.NewBlockBlobURL(path)
190	if err := azblob.DownloadBlobToFile(context.Background(), blobURL.BlobURL, 0, azblob.CountToEnd, f, azblob.DownloadFromBlobOptions{
191		Parallelism: uint16(maxConcurrency),
192	}); err != nil {
193		return err
194	}
195	meta.resStatus = downloaded
196	return nil
197}
198
199func (util *snowflakeAzureUtil) extractContainerNameAndPath(location string) *azureLocation {
200	stageLocation := expandUser(location)
201	containerName := stageLocation
202	path := ""
203
204	if strings.Contains(stageLocation, "/") {
205		containerName = stageLocation[:strings.Index(stageLocation, "/")]
206		path = stageLocation[strings.Index(stageLocation, "/")+1:]
207		if path != "" && !strings.HasSuffix(path, "/") {
208			path += "/"
209		}
210	}
211	return &azureLocation{containerName, path}
212}
213
214func (util *snowflakeAzureUtil) detectAzureTokenExpireError(resp *http.Response) bool {
215	if resp.StatusCode != 403 {
216		return false
217	}
218	azureErr, err := ioutil.ReadAll(resp.Body)
219	if err != nil {
220		return false
221	}
222	errStr := string(azureErr)
223	return strings.Contains(errStr, "Signature not valid in the specified time frame") ||
224		strings.Contains(errStr, "Server failed to authenticate the request")
225}
226