1// Copyright (c) 2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "context" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io/ioutil" 11 "net/http" 12 "net/url" 13 "os" 14 "strings" 15 "time" 16 17 "github.com/Azure/azure-storage-blob-go/azblob" 18) 19 20type snowflakeAzureUtil struct { 21} 22 23type azureLocation struct { 24 containerName string 25 path string 26} 27 28func (util *snowflakeAzureUtil) createClient(info *execResponseStageInfo, _ bool) cloudClient { 29 sasToken := info.Creds.AzureSasToken 30 p := azblob.NewPipeline(azblob.NewAnonymousCredential(), azblob.PipelineOptions{ 31 Retry: azblob.RetryOptions{ 32 Policy: azblob.RetryPolicyExponential, 33 MaxTries: 60, 34 RetryDelay: 2 * time.Second, 35 }, 36 }) 37 38 u, _ := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken)) 39 containerURL := azblob.NewContainerURL(*u, p) 40 return &containerURL 41} 42 43// cloudUtil implementation 44func (util *snowflakeAzureUtil) getFileHeader(meta *fileMetadata, filename string) *fileHeader { 45 container, ok := meta.client.(*azblob.ContainerURL) 46 if !ok { 47 return nil 48 } 49 50 azureLoc := util.extractContainerNameAndPath(meta.stageInfo.Location) 51 path := azureLoc.path + strings.TrimLeft(filename, "/") 52 b := container.NewBlockBlobURL(path) 53 resp, err := b.GetProperties(context.Background(), azblob.BlobAccessConditions{}, azblob.ClientProvidedKeyOptions{}) 54 if err != nil { 55 var se azblob.StorageError 56 if errors.As(err, &se) { 57 if se.ServiceCode() == azblob.ServiceCodeBlobNotFound { 58 meta.resStatus = notFoundFile 59 return &fileHeader{} 60 } else if se.Response().StatusCode == 403 { 61 meta.resStatus = renewToken 62 return nil 63 } 64 } 65 meta.resStatus = errStatus 66 return nil 67 } 68 69 meta.resStatus = uploaded 70 metadata := resp.NewMetadata() 71 var encData encryptionData 72 if err = json.Unmarshal([]byte(metadata["encryptiondata"]), &encData); err != nil { 73 return nil 74 } 75 encryptionMetadata := encryptMetadata{ 76 encData.WrappedContentKey.EncryptionKey, 77 encData.ContentEncryptionIV, 78 metadata["matdesc"], 79 } 80 81 return &fileHeader{ 82 metadata["sfcdigest"], 83 int64(len(metadata)), 84 &encryptionMetadata, 85 } 86} 87 88// cloudUtil implementation 89func (util *snowflakeAzureUtil) uploadFile( 90 dataFile string, 91 meta *fileMetadata, 92 encryptMeta *encryptMetadata, 93 maxConcurrency int, 94 multiPartThreshold int64) error { 95 azureMeta := map[string]string{ 96 "sfcdigest": meta.sha256Digest, 97 } 98 if encryptMeta != nil { 99 ed := &encryptionData{ 100 EncryptionMode: "FullBlob", 101 WrappedContentKey: contentKey{ 102 "symmKey1", 103 encryptMeta.key, 104 "AES_CBC_256", 105 }, 106 EncryptionAgent: encryptionAgent{ 107 "1.0", 108 "AES_CBC_128", 109 }, 110 ContentEncryptionIV: encryptMeta.iv, 111 KeyWrappingMetadata: keyMetadata{ 112 "Java 5.3.0", 113 }, 114 } 115 metadata, _ := json.Marshal(ed) 116 azureMeta["encryptiondata"] = string(metadata) 117 azureMeta["matdesc"] = encryptMeta.matdesc 118 } 119 120 azureLoc := util.extractContainerNameAndPath(meta.stageInfo.Location) 121 path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/") 122 azContainerURL, ok := meta.client.(*azblob.ContainerURL) 123 if !ok { 124 return &SnowflakeError{ 125 Message: "failed to cast to azure client", 126 } 127 } 128 129 var err error 130 blobURL := azContainerURL.NewBlockBlobURL(path) 131 if meta.srcStream != nil { 132 uploadSrc := meta.srcStream 133 if meta.realSrcStream != nil { 134 uploadSrc = meta.realSrcStream 135 } 136 _, err = azblob.UploadStreamToBlockBlob(context.Background(), uploadSrc, blobURL, azblob.UploadStreamToBlockBlobOptions{ 137 BufferSize: uploadSrc.Len(), 138 Metadata: azureMeta, 139 }) 140 } else { 141 f, _ := os.OpenFile(dataFile, os.O_RDONLY, os.ModePerm) 142 defer f.Close() 143 fi, _ := f.Stat() 144 _, err = azblob.UploadFileToBlockBlob(context.Background(), f, blobURL, azblob.UploadToBlockBlobOptions{ 145 BlockSize: fi.Size(), 146 Parallelism: uint16(maxConcurrency), 147 Metadata: azureMeta, 148 }) 149 } 150 if err != nil { 151 var se azblob.StorageError 152 if errors.As(err, &se) { 153 if se.Response().StatusCode == 403 && util.detectAzureTokenExpireError(se.Response()) { 154 meta.resStatus = renewToken 155 } else { 156 meta.resStatus = needRetry 157 meta.lastError = err 158 } 159 return err 160 } 161 meta.resStatus = errStatus 162 return err 163 } 164 165 meta.dstFileSize = meta.uploadSize 166 meta.resStatus = uploaded 167 return nil 168} 169 170// cloudUtil implementation 171func (util *snowflakeAzureUtil) nativeDownloadFile( 172 meta *fileMetadata, 173 fullDstFileName string, 174 maxConcurrency int64) error { 175 azureLoc := util.extractContainerNameAndPath(meta.stageInfo.Location) 176 path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/") 177 azContainerURL, ok := meta.client.(*azblob.ContainerURL) 178 if !ok { 179 return &SnowflakeError{ 180 Message: "failed to cast to azure client", 181 } 182 } 183 184 f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, os.ModePerm) 185 if err != nil { 186 return err 187 } 188 defer f.Close() 189 blobURL := azContainerURL.NewBlockBlobURL(path) 190 if err := azblob.DownloadBlobToFile(context.Background(), blobURL.BlobURL, 0, azblob.CountToEnd, f, azblob.DownloadFromBlobOptions{ 191 Parallelism: uint16(maxConcurrency), 192 }); err != nil { 193 return err 194 } 195 meta.resStatus = downloaded 196 return nil 197} 198 199func (util *snowflakeAzureUtil) extractContainerNameAndPath(location string) *azureLocation { 200 stageLocation := expandUser(location) 201 containerName := stageLocation 202 path := "" 203 204 if strings.Contains(stageLocation, "/") { 205 containerName = stageLocation[:strings.Index(stageLocation, "/")] 206 path = stageLocation[strings.Index(stageLocation, "/")+1:] 207 if path != "" && !strings.HasSuffix(path, "/") { 208 path += "/" 209 } 210 } 211 return &azureLocation{containerName, path} 212} 213 214func (util *snowflakeAzureUtil) detectAzureTokenExpireError(resp *http.Response) bool { 215 if resp.StatusCode != 403 { 216 return false 217 } 218 azureErr, err := ioutil.ReadAll(resp.Body) 219 if err != nil { 220 return false 221 } 222 errStr := string(azureErr) 223 return strings.Contains(errStr, "Signature not valid in the specified time frame") || 224 strings.Contains(errStr, "Server failed to authenticate the request") 225} 226