1// +build integration 2 3package s3_test 4 5import ( 6 "bytes" 7 "context" 8 "crypto/tls" 9 "flag" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "net/http" 14 "os" 15 "reflect" 16 "strings" 17 "testing" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/arn" 22 "github.com/aws/aws-sdk-go/aws/endpoints" 23 "github.com/aws/aws-sdk-go/aws/request" 24 "github.com/aws/aws-sdk-go/awstesting/integration" 25 "github.com/aws/aws-sdk-go/awstesting/integration/s3integ" 26 "github.com/aws/aws-sdk-go/service/s3" 27 "github.com/aws/aws-sdk-go/service/s3control" 28 "github.com/aws/aws-sdk-go/service/sts" 29) 30 31const integBucketPrefix = "aws-sdk-go-integration" 32 33var integMetadata = struct { 34 AccountID string 35 Region string 36 Buckets struct { 37 Source struct { 38 Name string 39 ARN string 40 } 41 Target struct { 42 Name string 43 ARN string 44 } 45 } 46 47 AccessPoints struct { 48 Source struct { 49 Name string 50 ARN string 51 } 52 Target struct { 53 Name string 54 ARN string 55 } 56 } 57}{} 58 59var s3Svc *s3.S3 60var s3ControlSvc *s3control.S3Control 61var stsSvc *sts.STS 62var httpClient *http.Client 63 64// TODO: (Westeros) Remove Custom Resolver Usage Before Launch 65type customS3Resolver struct { 66 endpoint string 67 withTLS bool 68 region string 69} 70 71func (r customS3Resolver) EndpointFor(service, _ string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { 72 switch strings.ToLower(service) { 73 case "s3-control": 74 case "s3": 75 default: 76 return endpoints.ResolvedEndpoint{}, fmt.Errorf("unsupported in custom resolver") 77 } 78 79 return endpoints.ResolvedEndpoint{ 80 PartitionID: "aws", 81 SigningRegion: r.region, 82 SigningName: "s3", 83 SigningMethod: "s3v4", 84 URL: endpoints.AddScheme(r.endpoint, r.withTLS), 85 }, nil 86} 87 88func TestMain(m *testing.M) { 89 var result int 90 defer func() { 91 if r := recover(); r != nil { 92 fmt.Fprintln(os.Stderr, "S3 integration tests paniced,", r) 93 result = 1 94 } 95 os.Exit(result) 96 }() 97 98 var verifyTLS bool 99 var s3Endpoint, s3ControlEndpoint string 100 var s3EnableTLS, s3ControlEnableTLS bool 101 102 flag.StringVar(&s3Endpoint, "s3-endpoint", "", "integration endpoint for S3") 103 flag.BoolVar(&s3EnableTLS, "s3-tls", true, "enable TLS for S3 endpoint") 104 105 flag.StringVar(&s3ControlEndpoint, "s3-control-endpoint", "", "integration endpoint for S3") 106 flag.BoolVar(&s3ControlEnableTLS, "s3-control-tls", true, "enable TLS for S3 control endpoint") 107 108 flag.StringVar(&integMetadata.AccountID, "account", "", "integration account id") 109 flag.BoolVar(&verifyTLS, "verify-tls", true, "verify server TLS certificate") 110 flag.Parse() 111 112 httpClient = &http.Client{ 113 Transport: &http.Transport{ 114 TLSClientConfig: &tls.Config{InsecureSkipVerify: verifyTLS}, 115 }} 116 117 sess := integration.SessionWithDefaultRegion("us-west-2").Copy(&aws.Config{ 118 HTTPClient: httpClient, 119 }) 120 121 var s3EndpointResolver endpoints.Resolver 122 if len(s3Endpoint) != 0 { 123 s3EndpointResolver = customS3Resolver{ 124 endpoint: s3Endpoint, 125 withTLS: s3EnableTLS, 126 region: aws.StringValue(sess.Config.Region), 127 } 128 } 129 s3Svc = s3.New(sess, &aws.Config{ 130 DisableSSL: aws.Bool(!s3EnableTLS), 131 EndpointResolver: s3EndpointResolver, 132 }) 133 134 var s3ControlEndpointResolver endpoints.Resolver 135 if len(s3Endpoint) != 0 { 136 s3ControlEndpointResolver = customS3Resolver{ 137 endpoint: s3ControlEndpoint, 138 withTLS: s3ControlEnableTLS, 139 region: aws.StringValue(sess.Config.Region), 140 } 141 } 142 s3ControlSvc = s3control.New(sess, &aws.Config{ 143 DisableSSL: aws.Bool(!s3ControlEnableTLS), 144 EndpointResolver: s3ControlEndpointResolver, 145 }) 146 stsSvc = sts.New(sess) 147 148 var err error 149 integMetadata.AccountID, err = getAccountID() 150 if err != nil { 151 fmt.Fprintf(os.Stderr, "failed to get integration aws account id: %v\n", err) 152 result = 1 153 return 154 } 155 156 bucketCleanup, err := setupBuckets() 157 defer bucketCleanup() 158 if err != nil { 159 fmt.Fprintf(os.Stderr, "failed to setup integration test buckets: %v\n", err) 160 result = 1 161 return 162 } 163 164 accessPointsCleanup, err := setupAccessPoints() 165 defer accessPointsCleanup() 166 if err != nil { 167 fmt.Fprintf(os.Stderr, "failed to setup integration test access points: %v\n", err) 168 result = 1 169 return 170 } 171 172 result = m.Run() 173} 174 175func getAccountID() (string, error) { 176 if len(integMetadata.AccountID) != 0 { 177 return integMetadata.AccountID, nil 178 } 179 180 output, err := stsSvc.GetCallerIdentity(nil) 181 if err != nil { 182 return "", fmt.Errorf("faield to get sts caller identity") 183 } 184 185 return *output.Account, nil 186} 187 188func setupBuckets() (func(), error) { 189 var cleanups []func() 190 191 cleanup := func() { 192 for i := range cleanups { 193 cleanups[i]() 194 } 195 } 196 197 bucketCreates := []struct { 198 name *string 199 arn *string 200 }{ 201 {name: &integMetadata.Buckets.Source.Name, arn: &integMetadata.Buckets.Source.ARN}, 202 {name: &integMetadata.Buckets.Target.Name, arn: &integMetadata.Buckets.Target.ARN}, 203 } 204 205 for _, bucket := range bucketCreates { 206 *bucket.name = s3integ.GenerateBucketName() 207 208 if err := s3integ.SetupBucket(s3Svc, *bucket.name); err != nil { 209 return cleanup, err 210 } 211 212 // Compute ARN 213 bARN := arn.ARN{ 214 Partition: "aws", 215 Service: "s3", 216 Region: s3Svc.SigningRegion, 217 AccountID: integMetadata.AccountID, 218 Resource: fmt.Sprintf("bucket_name:%s", *bucket.name), 219 }.String() 220 221 *bucket.arn = bARN 222 223 bucketName := *bucket.name 224 cleanups = append(cleanups, func() { 225 if err := s3integ.CleanupBucket(s3Svc, bucketName); err != nil { 226 fmt.Fprintln(os.Stderr, err) 227 } 228 }) 229 } 230 231 return cleanup, nil 232} 233 234func setupAccessPoints() (func(), error) { 235 var cleanups []func() 236 237 cleanup := func() { 238 for i := range cleanups { 239 cleanups[i]() 240 } 241 } 242 243 creates := []struct { 244 bucket string 245 name *string 246 arn *string 247 }{ 248 {bucket: integMetadata.Buckets.Source.Name, name: &integMetadata.AccessPoints.Source.Name, arn: &integMetadata.AccessPoints.Source.ARN}, 249 {bucket: integMetadata.Buckets.Target.Name, name: &integMetadata.AccessPoints.Target.Name, arn: &integMetadata.AccessPoints.Target.ARN}, 250 } 251 252 for _, ap := range creates { 253 *ap.name = integration.UniqueID() 254 255 err := s3integ.SetupAccessPoint(s3ControlSvc, integMetadata.AccountID, ap.bucket, *ap.name) 256 if err != nil { 257 return cleanup, err 258 } 259 260 // Compute ARN 261 apARN := arn.ARN{ 262 Partition: "aws", 263 Service: "s3", 264 Region: s3ControlSvc.SigningRegion, 265 AccountID: integMetadata.AccountID, 266 Resource: fmt.Sprintf("accesspoint/%s", *ap.name), 267 }.String() 268 269 *ap.arn = apARN 270 271 apName := *ap.name 272 cleanups = append(cleanups, func() { 273 err := s3integ.CleanupAccessPoint(s3ControlSvc, integMetadata.AccountID, apName) 274 if err != nil { 275 fmt.Fprintln(os.Stderr, err) 276 } 277 }) 278 } 279 280 return cleanup, nil 281} 282 283func putTestFile(t *testing.T, filename, key string, opts ...request.Option) { 284 f, err := os.Open(filename) 285 if err != nil { 286 t.Fatalf("failed to open testfile, %v", err) 287 } 288 defer f.Close() 289 290 putTestContent(t, f, key, opts...) 291} 292 293func putTestContent(t *testing.T, reader io.ReadSeeker, key string, opts ...request.Option) { 294 t.Logf("uploading test file %s/%s", integMetadata.Buckets.Source.Name, key) 295 _, err := s3Svc.PutObjectWithContext(context.Background(), 296 &s3.PutObjectInput{ 297 Bucket: &integMetadata.Buckets.Source.Name, 298 Key: aws.String(key), 299 Body: reader, 300 }, opts...) 301 if err != nil { 302 t.Errorf("expect no error, got %v", err) 303 } 304} 305 306func testWriteToObject(t *testing.T, bucket string, opts ...request.Option) { 307 key := integration.UniqueID() 308 309 _, err := s3Svc.PutObjectWithContext(context.Background(), 310 &s3.PutObjectInput{ 311 Bucket: &bucket, 312 Key: &key, 313 Body: bytes.NewReader([]byte("hello world")), 314 }, opts...) 315 if err != nil { 316 t.Fatalf("expect no error, got %v", err) 317 } 318 319 resp, err := s3Svc.GetObjectWithContext(context.Background(), 320 &s3.GetObjectInput{ 321 Bucket: &bucket, 322 Key: &key, 323 }, opts...) 324 if err != nil { 325 t.Fatalf("expect no error, got %v", err) 326 } 327 328 b, _ := ioutil.ReadAll(resp.Body) 329 if e, a := []byte("hello world"), b; !bytes.Equal(e, a) { 330 t.Errorf("expect %v, got %v", e, a) 331 } 332} 333 334func testPresignedGetPut(t *testing.T, bucket string, opts ...request.Option) { 335 key := integration.UniqueID() 336 337 putreq, _ := s3Svc.PutObjectRequest(&s3.PutObjectInput{ 338 Bucket: &bucket, 339 Key: &key, 340 }) 341 putreq.ApplyOptions(opts...) 342 var err error 343 344 // Presign a PUT request 345 var puturl string 346 puturl, err = putreq.Presign(5 * time.Minute) 347 if err != nil { 348 t.Fatalf("expect no error, got %v", err) 349 } 350 351 // PUT to the presigned URL with a body 352 var puthttpreq *http.Request 353 buf := bytes.NewReader([]byte("hello world")) 354 puthttpreq, err = http.NewRequest("PUT", puturl, buf) 355 if err != nil { 356 t.Fatalf("expect no error, got %v", err) 357 } 358 359 var putresp *http.Response 360 putresp, err = httpClient.Do(puthttpreq) 361 if err != nil { 362 t.Errorf("expect put with presign url no error, got %v", err) 363 } 364 if e, a := 200, putresp.StatusCode; e != a { 365 t.Fatalf("expect %v, got %v", e, a) 366 } 367 368 // Presign a GET on the same URL 369 getreq, _ := s3Svc.GetObjectRequest(&s3.GetObjectInput{ 370 Bucket: &bucket, 371 Key: &key, 372 }) 373 getreq.ApplyOptions(opts...) 374 375 var geturl string 376 geturl, err = getreq.Presign(300 * time.Second) 377 if err != nil { 378 t.Fatalf("expect no error, got %v", err) 379 } 380 381 // Get the body 382 var getresp *http.Response 383 getresp, err = httpClient.Get(geturl) 384 if err != nil { 385 t.Fatalf("expect no error, got %v", err) 386 } 387 388 var b []byte 389 defer getresp.Body.Close() 390 b, err = ioutil.ReadAll(getresp.Body) 391 if e, a := "hello world", string(b); e != a { 392 t.Fatalf("expect %v, got %v", e, a) 393 } 394} 395 396func testCopyObject(t *testing.T, sourceBucket string, targetBucket string, opts ...request.Option) { 397 key := integration.UniqueID() 398 399 _, err := s3Svc.PutObjectWithContext(context.Background(), 400 &s3.PutObjectInput{ 401 Bucket: &sourceBucket, 402 Key: &key, 403 Body: bytes.NewReader([]byte("hello world")), 404 }, opts...) 405 if err != nil { 406 t.Fatalf("expect no error, got %v", err) 407 } 408 409 _, err = s3Svc.CopyObjectWithContext(context.Background(), 410 &s3.CopyObjectInput{ 411 Bucket: &targetBucket, 412 CopySource: aws.String("/" + sourceBucket + "/" + key), 413 Key: &key, 414 }, opts...) 415 if err != nil { 416 t.Fatalf("expect no error, got %v", err) 417 } 418 419 resp, err := s3Svc.GetObjectWithContext(context.Background(), 420 &s3.GetObjectInput{ 421 Bucket: &targetBucket, 422 Key: &key, 423 }, opts...) 424 if err != nil { 425 t.Fatalf("expect no error, got %v", err) 426 } 427 428 b, _ := ioutil.ReadAll(resp.Body) 429 if e, a := []byte("hello world"), b; !reflect.DeepEqual(e, a) { 430 t.Errorf("expect %v, got %v", e, a) 431 } 432} 433