1package ec2 2 3import ( 4 "time" 5 6 "github.com/aws/aws-sdk-go/aws" 7 "github.com/aws/aws-sdk-go/aws/awsutil" 8 "github.com/aws/aws-sdk-go/aws/client" 9 "github.com/aws/aws-sdk-go/aws/endpoints" 10 "github.com/aws/aws-sdk-go/aws/request" 11 "github.com/aws/aws-sdk-go/internal/sdkrand" 12) 13 14type retryer struct { 15 client.DefaultRetryer 16} 17 18func (d retryer) RetryRules(r *request.Request) time.Duration { 19 switch r.Operation.Name { 20 case opModifyNetworkInterfaceAttribute: 21 fallthrough 22 case opAssignPrivateIpAddresses: 23 return customRetryRule(r) 24 default: 25 return d.DefaultRetryer.RetryRules(r) 26 } 27} 28 29func customRetryRule(r *request.Request) time.Duration { 30 retryTimes := []time.Duration{ 31 time.Second, 32 3 * time.Second, 33 5 * time.Second, 34 } 35 36 count := r.RetryCount 37 if count >= len(retryTimes) { 38 count = len(retryTimes) - 1 39 } 40 41 minTime := int(retryTimes[count]) 42 return time.Duration(sdkrand.SeededRand.Intn(minTime) + minTime) 43} 44 45func setCustomRetryer(c *client.Client) { 46 maxRetries := aws.IntValue(c.Config.MaxRetries) 47 if c.Config.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries { 48 maxRetries = 3 49 } 50 51 c.Retryer = retryer{ 52 DefaultRetryer: client.DefaultRetryer{ 53 NumMaxRetries: maxRetries, 54 }, 55 } 56} 57 58func init() { 59 initClient = func(c *client.Client) { 60 if c.Config.Retryer == nil { 61 // Only override the retryer with a custom one if the config 62 // does not already contain a retryer 63 setCustomRetryer(c) 64 } 65 } 66 initRequest = func(r *request.Request) { 67 if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter 68 r.Handlers.Build.PushFront(fillPresignedURL) 69 } 70 } 71} 72 73func fillPresignedURL(r *request.Request) { 74 if !r.ParamsFilled() { 75 return 76 } 77 78 origParams := r.Params.(*CopySnapshotInput) 79 80 // Stop if PresignedURL/DestinationRegion is set 81 if origParams.PresignedUrl != nil || origParams.DestinationRegion != nil { 82 return 83 } 84 85 origParams.DestinationRegion = r.Config.Region 86 newParams := awsutil.CopyOf(r.Params).(*CopySnapshotInput) 87 88 // Create a new request based on the existing request. We will use this to 89 // presign the CopySnapshot request against the source region. 90 cfg := r.Config.Copy(aws.NewConfig(). 91 WithEndpoint(""). 92 WithRegion(aws.StringValue(origParams.SourceRegion))) 93 94 clientInfo := r.ClientInfo 95 resolved, err := r.Config.EndpointResolver.EndpointFor( 96 clientInfo.ServiceName, aws.StringValue(cfg.Region), 97 func(opt *endpoints.Options) { 98 opt.DisableSSL = aws.BoolValue(cfg.DisableSSL) 99 opt.UseDualStack = aws.BoolValue(cfg.UseDualStack) 100 }, 101 ) 102 if err != nil { 103 r.Error = err 104 return 105 } 106 107 clientInfo.Endpoint = resolved.URL 108 clientInfo.SigningRegion = resolved.SigningRegion 109 110 // Presign a CopySnapshot request with modified params 111 req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data) 112 url, err := req.Presign(5 * time.Minute) // 5 minutes should be enough. 113 if err != nil { // bubble error back up to original request 114 r.Error = err 115 return 116 } 117 118 // We have our URL, set it on params 119 origParams.PresignedUrl = &url 120} 121