1// Copyright (c) 2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "bytes" 7 "context" 8 "errors" 9 "os" 10 "strings" 11 12 "github.com/aws/aws-sdk-go-v2/aws" 13 "github.com/aws/aws-sdk-go-v2/credentials" 14 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" 15 "github.com/aws/aws-sdk-go-v2/service/s3" 16 "github.com/aws/smithy-go" 17) 18 19const ( 20 sfcDigest = "sfc-digest" 21 amzMatdesc = "x-amz-matdesc" 22 amzKey = "x-amz-key" 23 amzIv = "x-amz-iv" 24 25 notFound = "NotFound" 26 expiredToken = "ExpiredToken" 27 errNoWsaeconnaborted = "10053" 28) 29 30type snowflakeS3Util struct { 31} 32 33type s3Location struct { 34 bucketName string 35 s3Path string 36} 37 38func (util *snowflakeS3Util) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) cloudClient { 39 stageCredentials := info.Creds 40 var resolver s3.EndpointResolver 41 if info.EndPoint != "" { 42 resolver = s3.EndpointResolverFromURL("https://" + info.EndPoint) // FIPS endpoint 43 } 44 45 return s3.New(s3.Options{ 46 Region: info.Region, 47 Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( 48 stageCredentials.AwsKeyID, 49 stageCredentials.AwsSecretKey, 50 stageCredentials.AwsToken)), 51 EndpointResolver: resolver, 52 UseAccelerate: useAccelerateEndpoint, 53 }) 54} 55 56type s3HeaderAPI interface { 57 HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) 58} 59 60// cloudUtil implementation 61func (util *snowflakeS3Util) getFileHeader(meta *fileMetadata, filename string) *fileHeader { 62 headObjInput := util.getS3Object(meta, filename) 63 var s3Cli s3HeaderAPI 64 s3Cli, ok := meta.client.(*s3.Client) 65 if !ok { 66 return nil 67 } 68 if meta.mockHeader != nil { 69 s3Cli = meta.mockHeader 70 } 71 out, err := s3Cli.HeadObject(context.Background(), headObjInput) 72 if err != nil { 73 var ae smithy.APIError 74 if errors.As(err, &ae) { 75 if ae.ErrorCode() == notFound { 76 meta.resStatus = notFoundFile 77 return &fileHeader{ 78 digest: "", 79 contentLength: 0, 80 encryptionMetadata: nil, 81 } 82 } else if ae.ErrorCode() == expiredToken { 83 meta.resStatus = renewToken 84 return nil 85 } 86 meta.resStatus = errStatus 87 meta.lastError = err 88 return nil 89 } 90 } 91 92 meta.resStatus = uploaded 93 var encMeta encryptMetadata 94 if out.Metadata[amzKey] != "" { 95 encMeta = encryptMetadata{ 96 out.Metadata[amzKey], 97 out.Metadata[amzIv], 98 out.Metadata[amzMatdesc], 99 } 100 } 101 return &fileHeader{ 102 out.Metadata[sfcDigest], 103 out.ContentLength, 104 &encMeta, 105 } 106} 107 108type s3UploadAPI interface { 109 Upload(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) 110} 111 112// cloudUtil implementation 113func (util *snowflakeS3Util) uploadFile( 114 dataFile string, 115 meta *fileMetadata, 116 encryptMeta *encryptMetadata, 117 maxConcurrency int, 118 multiPartThreshold int64) error { 119 s3Meta := map[string]string{ 120 httpHeaderContentType: httpHeaderValueOctetStream, 121 sfcDigest: meta.sha256Digest, 122 } 123 if encryptMeta != nil { 124 s3Meta[amzIv] = encryptMeta.iv 125 s3Meta[amzKey] = encryptMeta.key 126 s3Meta[amzMatdesc] = encryptMeta.matdesc 127 } 128 129 s3loc := util.extractBucketNameAndPath(meta.stageInfo.Location) 130 s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/") 131 132 client, ok := meta.client.(*s3.Client) 133 if !ok { 134 return &SnowflakeError{ 135 Message: "failed to cast to s3 client", 136 } 137 } 138 var uploader s3UploadAPI 139 uploader = manager.NewUploader(client, func(u *manager.Uploader) { 140 u.Concurrency = maxConcurrency 141 u.PartSize = int64Max(multiPartThreshold, manager.DefaultUploadPartSize) 142 }) 143 if meta.mockUploader != nil { 144 uploader = meta.mockUploader 145 } 146 147 var err error 148 if meta.srcStream != nil { 149 uploadStream := meta.srcStream 150 if meta.realSrcStream != nil { 151 uploadStream = meta.realSrcStream 152 } 153 _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 154 Bucket: &s3loc.bucketName, 155 Key: &s3path, 156 Body: bytes.NewBuffer(uploadStream.Bytes()), 157 Metadata: s3Meta, 158 }) 159 } else { 160 file, _ := os.Open(dataFile) 161 _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 162 Bucket: &s3loc.bucketName, 163 Key: &s3path, 164 Body: file, 165 Metadata: s3Meta, 166 }) 167 } 168 169 if err != nil { 170 var ae smithy.APIError 171 if errors.As(err, &ae) { 172 if ae.ErrorCode() == expiredToken { 173 meta.resStatus = renewToken 174 return err 175 } else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) { 176 meta.lastError = err 177 meta.resStatus = needRetryWithLowerConcurrency 178 return err 179 } 180 } 181 meta.lastError = err 182 meta.resStatus = needRetry 183 return err 184 } 185 meta.dstFileSize = meta.uploadSize 186 meta.resStatus = uploaded 187 return nil 188} 189 190// cloudUtil implementation 191func (util *snowflakeS3Util) nativeDownloadFile( 192 meta *fileMetadata, 193 fullDstFileName string, 194 maxConcurrency int64) error { 195 s3loc := util.extractBucketNameAndPath(meta.stageInfo.Location) 196 s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/") 197 client, ok := meta.client.(*s3.Client) 198 if !ok { 199 return &SnowflakeError{ 200 Message: "failed to cast to s3 client", 201 } 202 } 203 204 f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, os.ModePerm) 205 if err != nil { 206 return err 207 } 208 defer f.Close() 209 downloader := manager.NewDownloader(client, func(u *manager.Downloader) { 210 u.Concurrency = int(maxConcurrency) 211 }) 212 if _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ 213 Bucket: &s3loc.bucketName, 214 Key: &s3path, 215 }); err != nil { 216 var ae smithy.APIError 217 if errors.As(err, &ae) { 218 if ae.ErrorCode() == expiredToken { 219 meta.resStatus = renewToken 220 return err 221 } else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) { 222 meta.lastError = err 223 meta.resStatus = needRetryWithLowerConcurrency 224 return err 225 } 226 meta.lastError = err 227 meta.resStatus = errStatus 228 return err 229 } 230 meta.lastError = err 231 meta.resStatus = needRetry 232 return err 233 } 234 meta.resStatus = downloaded 235 return nil 236} 237 238func (util *snowflakeS3Util) extractBucketNameAndPath(location string) *s3Location { 239 stageLocation := expandUser(location) 240 bucketName := stageLocation 241 s3Path := "" 242 243 if idx := strings.Index(stageLocation, "/"); idx >= 0 { 244 bucketName = stageLocation[0:idx] 245 s3Path = stageLocation[idx+1:] 246 if s3Path != "" && !strings.HasSuffix(s3Path, "/") { 247 s3Path += "/" 248 } 249 } 250 return &s3Location{bucketName, s3Path} 251} 252 253func (util *snowflakeS3Util) getS3Object(meta *fileMetadata, filename string) *s3.HeadObjectInput { 254 s3loc := util.extractBucketNameAndPath(meta.stageInfo.Location) 255 s3path := s3loc.s3Path + strings.TrimLeft(filename, "/") 256 return &s3.HeadObjectInput{ 257 Bucket: &s3loc.bucketName, 258 Key: &s3path, 259 } 260} 261