1// Copyright (c) 2021 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5import (
6	"bytes"
7	"context"
8	"errors"
9	"os"
10	"strings"
11
12	"github.com/aws/aws-sdk-go-v2/aws"
13	"github.com/aws/aws-sdk-go-v2/credentials"
14	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
15	"github.com/aws/aws-sdk-go-v2/service/s3"
16	"github.com/aws/smithy-go"
17)
18
// Keys of the S3 user-metadata map under which each staged object's digest
// and client-side encryption metadata are stored (written by uploadFile and
// read back by getFileHeader).
const (
	sfcDigest  = "sfc-digest"
	amzMatdesc = "x-amz-matdesc"
	amzKey     = "x-amz-key"
	amzIv      = "x-amz-iv"
)

// Error codes compared against smithy.APIError.ErrorCode() when classifying
// S3 failures.
const (
	// notFound marks a HeadObject target that does not exist on the stage.
	notFound = "NotFound"

	// expiredToken means the stage credentials have to be renewed.
	expiredToken = "ExpiredToken"

	// errNoWsaeconnaborted is matched as a substring of the error code and
	// triggers a retry with lower concurrency. Presumably the Winsock
	// WSAECONNABORTED code; confirm.
	errNoWsaeconnaborted = "10053"
)
29
type (
	// snowflakeS3Util provides the cloudUtil operations (client creation,
	// header lookup, upload and download) for S3 stages. It holds no state.
	snowflakeS3Util struct{}

	// s3Location is a stage location split into its bucket name and the key
	// prefix inside that bucket. As produced by extractBucketNameAndPath, a
	// non-empty s3Path always ends with "/".
	s3Location struct {
		bucketName string
		s3Path     string
	}
)
37
38func (util *snowflakeS3Util) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) cloudClient {
39	stageCredentials := info.Creds
40	var resolver s3.EndpointResolver
41	if info.EndPoint != "" {
42		resolver = s3.EndpointResolverFromURL("https://" + info.EndPoint) // FIPS endpoint
43	}
44
45	return s3.New(s3.Options{
46		Region: info.Region,
47		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
48			stageCredentials.AwsKeyID,
49			stageCredentials.AwsSecretKey,
50			stageCredentials.AwsToken)),
51		EndpointResolver: resolver,
52		UseAccelerate:    useAccelerateEndpoint,
53	})
54}
55
56type s3HeaderAPI interface {
57	HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
58}
59
60// cloudUtil implementation
61func (util *snowflakeS3Util) getFileHeader(meta *fileMetadata, filename string) *fileHeader {
62	headObjInput := util.getS3Object(meta, filename)
63	var s3Cli s3HeaderAPI
64	s3Cli, ok := meta.client.(*s3.Client)
65	if !ok {
66		return nil
67	}
68	if meta.mockHeader != nil {
69		s3Cli = meta.mockHeader
70	}
71	out, err := s3Cli.HeadObject(context.Background(), headObjInput)
72	if err != nil {
73		var ae smithy.APIError
74		if errors.As(err, &ae) {
75			if ae.ErrorCode() == notFound {
76				meta.resStatus = notFoundFile
77				return &fileHeader{
78					digest:             "",
79					contentLength:      0,
80					encryptionMetadata: nil,
81				}
82			} else if ae.ErrorCode() == expiredToken {
83				meta.resStatus = renewToken
84				return nil
85			}
86			meta.resStatus = errStatus
87			meta.lastError = err
88			return nil
89		}
90	}
91
92	meta.resStatus = uploaded
93	var encMeta encryptMetadata
94	if out.Metadata[amzKey] != "" {
95		encMeta = encryptMetadata{
96			out.Metadata[amzKey],
97			out.Metadata[amzIv],
98			out.Metadata[amzMatdesc],
99		}
100	}
101	return &fileHeader{
102		out.Metadata[sfcDigest],
103		out.ContentLength,
104		&encMeta,
105	}
106}
107
108type s3UploadAPI interface {
109	Upload(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error)
110}
111
112// cloudUtil implementation
113func (util *snowflakeS3Util) uploadFile(
114	dataFile string,
115	meta *fileMetadata,
116	encryptMeta *encryptMetadata,
117	maxConcurrency int,
118	multiPartThreshold int64) error {
119	s3Meta := map[string]string{
120		httpHeaderContentType: httpHeaderValueOctetStream,
121		sfcDigest:             meta.sha256Digest,
122	}
123	if encryptMeta != nil {
124		s3Meta[amzIv] = encryptMeta.iv
125		s3Meta[amzKey] = encryptMeta.key
126		s3Meta[amzMatdesc] = encryptMeta.matdesc
127	}
128
129	s3loc := util.extractBucketNameAndPath(meta.stageInfo.Location)
130	s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/")
131
132	client, ok := meta.client.(*s3.Client)
133	if !ok {
134		return &SnowflakeError{
135			Message: "failed to cast to s3 client",
136		}
137	}
138	var uploader s3UploadAPI
139	uploader = manager.NewUploader(client, func(u *manager.Uploader) {
140		u.Concurrency = maxConcurrency
141		u.PartSize = int64Max(multiPartThreshold, manager.DefaultUploadPartSize)
142	})
143	if meta.mockUploader != nil {
144		uploader = meta.mockUploader
145	}
146
147	var err error
148	if meta.srcStream != nil {
149		uploadStream := meta.srcStream
150		if meta.realSrcStream != nil {
151			uploadStream = meta.realSrcStream
152		}
153		_, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
154			Bucket:   &s3loc.bucketName,
155			Key:      &s3path,
156			Body:     bytes.NewBuffer(uploadStream.Bytes()),
157			Metadata: s3Meta,
158		})
159	} else {
160		file, _ := os.Open(dataFile)
161		_, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
162			Bucket:   &s3loc.bucketName,
163			Key:      &s3path,
164			Body:     file,
165			Metadata: s3Meta,
166		})
167	}
168
169	if err != nil {
170		var ae smithy.APIError
171		if errors.As(err, &ae) {
172			if ae.ErrorCode() == expiredToken {
173				meta.resStatus = renewToken
174				return err
175			} else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) {
176				meta.lastError = err
177				meta.resStatus = needRetryWithLowerConcurrency
178				return err
179			}
180		}
181		meta.lastError = err
182		meta.resStatus = needRetry
183		return err
184	}
185	meta.dstFileSize = meta.uploadSize
186	meta.resStatus = uploaded
187	return nil
188}
189
190// cloudUtil implementation
191func (util *snowflakeS3Util) nativeDownloadFile(
192	meta *fileMetadata,
193	fullDstFileName string,
194	maxConcurrency int64) error {
195	s3loc := util.extractBucketNameAndPath(meta.stageInfo.Location)
196	s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/")
197	client, ok := meta.client.(*s3.Client)
198	if !ok {
199		return &SnowflakeError{
200			Message: "failed to cast to s3 client",
201		}
202	}
203
204	f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, os.ModePerm)
205	if err != nil {
206		return err
207	}
208	defer f.Close()
209	downloader := manager.NewDownloader(client, func(u *manager.Downloader) {
210		u.Concurrency = int(maxConcurrency)
211	})
212	if _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{
213		Bucket: &s3loc.bucketName,
214		Key:    &s3path,
215	}); err != nil {
216		var ae smithy.APIError
217		if errors.As(err, &ae) {
218			if ae.ErrorCode() == expiredToken {
219				meta.resStatus = renewToken
220				return err
221			} else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) {
222				meta.lastError = err
223				meta.resStatus = needRetryWithLowerConcurrency
224				return err
225			}
226			meta.lastError = err
227			meta.resStatus = errStatus
228			return err
229		}
230		meta.lastError = err
231		meta.resStatus = needRetry
232		return err
233	}
234	meta.resStatus = downloaded
235	return nil
236}
237
238func (util *snowflakeS3Util) extractBucketNameAndPath(location string) *s3Location {
239	stageLocation := expandUser(location)
240	bucketName := stageLocation
241	s3Path := ""
242
243	if idx := strings.Index(stageLocation, "/"); idx >= 0 {
244		bucketName = stageLocation[0:idx]
245		s3Path = stageLocation[idx+1:]
246		if s3Path != "" && !strings.HasSuffix(s3Path, "/") {
247			s3Path += "/"
248		}
249	}
250	return &s3Location{bucketName, s3Path}
251}
252
253func (util *snowflakeS3Util) getS3Object(meta *fileMetadata, filename string) *s3.HeadObjectInput {
254	s3loc := util.extractBucketNameAndPath(meta.stageInfo.Location)
255	s3path := s3loc.s3Path + strings.TrimLeft(filename, "/")
256	return &s3.HeadObjectInput{
257		Bucket: &s3loc.bucketName,
258		Key:    &s3path,
259	}
260}
261