1package customizations
2
3import (
4	"context"
5	"fmt"
6	"log"
7	"net/url"
8	"strings"
9
10	"github.com/aws/aws-sdk-go-v2/aws"
11	"github.com/aws/smithy-go/middleware"
12	smithyhttp "github.com/aws/smithy-go/transport/http"
13
14	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
15	"github.com/aws/aws-sdk-go-v2/service/internal/s3shared"
16
17	internalendpoints "github.com/aws/aws-sdk-go-v2/service/s3/internal/endpoints"
18)
19
// EndpointResolver is the interface for resolving service endpoints.
//
// ResolveEndpoint takes a region name together with resolver options and
// returns the resolved aws.Endpoint, or an error.
type EndpointResolver interface {
	ResolveEndpoint(region string, options EndpointResolverOptions) (aws.Endpoint, error)
}
24
// EndpointResolverOptions is the service endpoint resolver options. It is an
// alias for the Options type of the internal endpoints package.
type EndpointResolverOptions = internalendpoints.Options
27
// UpdateEndpointParameterAccessor represents accessor functions used by the
// endpoint middleware to read values from an operation input.
type UpdateEndpointParameterAccessor struct {
	// GetBucketFromInput fetches the bucket name from the provided operation
	// input. It returns a pointer to the bucket name and a bool. The
	// middleware in this file only acts on the bucket when the bool is true
	// and the pointer is non-nil, so the bool is expected to report whether
	// the input has a bucket member.
	GetBucketFromInput func(interface{}) (*string, bool)
}
36
// UpdateEndpointOptions provides the options for the UpdateEndpoint middleware setup.
type UpdateEndpointOptions struct {

	// Accessor holds the parameter accessors used by the middleware to read
	// the bucket from the operation input.
	Accessor UpdateEndpointParameterAccessor

	// UsePathStyle requests path-style addressing: when set, the bucket is
	// left in the URL path rather than moved into the host. It is ignored
	// (with a log message) when transfer acceleration is in effect.
	UsePathStyle bool

	// UseAccelerate requests S3 transfer acceleration. It is ignored (with a
	// log message) when SupportsAccelerate is false.
	UseAccelerate bool

	// SupportsAccelerate indicates if an operation supports S3 transfer
	// acceleration.
	SupportsAccelerate bool

	// UseDualstack requests dual-stack endpoints; it is forwarded to the ARN
	// processing and dual-stack middleware.
	UseDualstack bool

	// UseARNRegion is forwarded to the ARN processing middleware.
	UseARNRegion bool

	// EndpointResolver used to resolve endpoints. This may be a custom endpoint resolver.
	EndpointResolver EndpointResolver

	// EndpointResolverOptions are the options passed along with the
	// endpoint resolver.
	EndpointResolverOptions EndpointResolverOptions
}
64
65// UpdateEndpoint adds the middleware to the middleware stack based on the UpdateEndpointOptions.
66func UpdateEndpoint(stack *middleware.Stack, options UpdateEndpointOptions) (err error) {
67	// initial arn look up middleware
68	err = stack.Initialize.Add(&s3shared.ARNLookup{
69		GetARNValue: options.Accessor.GetBucketFromInput,
70	}, middleware.Before)
71	if err != nil {
72		return err
73	}
74
75	// process arn
76	err = stack.Serialize.Insert(&processARNResource{
77		UseARNRegion:            options.UseARNRegion,
78		UseAccelerate:           options.UseAccelerate,
79		UseDualstack:            options.UseDualstack,
80		EndpointResolver:        options.EndpointResolver,
81		EndpointResolverOptions: options.EndpointResolverOptions,
82	}, "OperationSerializer", middleware.Before)
83	if err != nil {
84		return err
85	}
86
87	// remove bucket arn middleware
88	err = stack.Serialize.Insert(&removeBucketFromPathMiddleware{}, "OperationSerializer", middleware.After)
89	if err != nil {
90		return err
91	}
92
93	// enable dual stack support
94	err = stack.Serialize.Insert(&s3shared.EnableDualstack{
95		UseDualstack:     options.UseDualstack,
96		DefaultServiceID: "s3",
97	}, "OperationSerializer", middleware.After)
98	if err != nil {
99		return err
100	}
101
102	// update endpoint to use options for path style and accelerate
103	err = stack.Serialize.Insert(&updateEndpoint{
104		usePathStyle:       options.UsePathStyle,
105		getBucketFromInput: options.Accessor.GetBucketFromInput,
106		useAccelerate:      options.UseAccelerate,
107		supportsAccelerate: options.SupportsAccelerate,
108	}, (*s3shared.EnableDualstack)(nil).ID(), middleware.After)
109	if err != nil {
110		return err
111	}
112
113	return err
114}
115
// updateEndpoint is the serialize-step middleware that applies the
// path-style and transfer-acceleration options to the request endpoint.
type updateEndpoint struct {
	// path style options

	// usePathStyle keeps the bucket in the URL path when true.
	usePathStyle bool
	// getBucketFromInput reads the bucket name from the operation input; when
	// nil, the endpoint is left untouched.
	getBucketFromInput func(interface{}) (*string, bool)

	// accelerate options

	// useAccelerate requests the s3-accelerate endpoint.
	useAccelerate bool
	// supportsAccelerate reports whether the operation supports acceleration;
	// when false, useAccelerate is ignored.
	supportsAccelerate bool
}
125
126// ID returns the middleware ID.
127func (*updateEndpoint) ID() string {
128	return "S3:UpdateEndpoint"
129}
130
131func (u *updateEndpoint) HandleSerialize(
132	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
133) (
134	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
135) {
136	// if arn was processed, skip this middleware
137	if _, ok := s3shared.GetARNResourceFromContext(ctx); ok {
138		return next.HandleSerialize(ctx, in)
139	}
140
141	// skip this customization if host name is set as immutable
142	if smithyhttp.GetHostnameImmutable(ctx) {
143		return next.HandleSerialize(ctx, in)
144	}
145
146	req, ok := in.Request.(*smithyhttp.Request)
147	if !ok {
148		return out, metadata, fmt.Errorf("unknown request type %T", req)
149	}
150
151	// check if accelerate is supported
152	if u.useAccelerate && !u.supportsAccelerate {
153		// accelerate is not supported, thus will be ignored
154		log.Println("Transfer acceleration is not supported for the operation, ignoring UseAccelerate.")
155		u.useAccelerate = false
156	}
157
158	// transfer acceleration is not supported with path style urls
159	if u.useAccelerate && u.usePathStyle {
160		log.Println("UseAccelerate is not compatible with UsePathStyle, ignoring UsePathStyle.")
161		u.usePathStyle = false
162	}
163
164	if u.getBucketFromInput != nil {
165		// Below customization only apply if bucket name is provided
166		bucket, ok := u.getBucketFromInput(in.Parameters)
167		if ok && bucket != nil {
168			region := awsmiddleware.GetRegion(ctx)
169			if err := u.updateEndpointFromConfig(req, *bucket, region); err != nil {
170				return out, metadata, err
171			}
172		}
173	}
174
175	return next.HandleSerialize(ctx, in)
176}
177
178func (u updateEndpoint) updateEndpointFromConfig(req *smithyhttp.Request, bucket string, region string) error {
179	// do nothing if path style is enforced
180	if u.usePathStyle {
181		return nil
182	}
183
184	if !hostCompatibleBucketName(req.URL, bucket) {
185		// bucket name must be valid to put into the host for accelerate operations.
186		// For non-accelerate operations the bucket name can stay in the path if
187		// not valid hostname.
188		var err error
189		if u.useAccelerate {
190			err = fmt.Errorf("bucket name %s is not compatible with S3", bucket)
191		}
192
193		// No-Op if not using accelerate.
194		return err
195	}
196
197	// accelerate is only supported if use path style is disabled
198	if u.useAccelerate {
199		parts := strings.Split(req.URL.Host, ".")
200		if len(parts) < 3 {
201			return fmt.Errorf("unable to update endpoint host for S3 accelerate, hostname invalid, %s", req.URL.Host)
202		}
203
204		if parts[0] == "s3" || strings.HasPrefix(parts[0], "s3-") {
205			parts[0] = "s3-accelerate"
206		}
207
208		for i := 1; i+1 < len(parts); i++ {
209			if strings.EqualFold(parts[i], region) {
210				parts = append(parts[:i], parts[i+1:]...)
211				break
212			}
213		}
214
215		// construct the url host
216		req.URL.Host = strings.Join(parts, ".")
217	}
218
219	// move bucket to follow virtual host style
220	moveBucketNameToHost(req.URL, bucket)
221	return nil
222}
223
224// updates endpoint to use virtual host styling
225func moveBucketNameToHost(u *url.URL, bucket string) {
226	u.Host = bucket + "." + u.Host
227	removeBucketFromPath(u, bucket)
228}
229
// removeBucketFromPath removes the leading "/<bucket>" segment from the
// URL's Path and RawPath.
//
// Only the first path segment is removed, and only when it is exactly the
// bucket (followed by "/" or the end of the path). Later occurrences of the
// bucket name, for example inside the object key, are preserved. An empty
// bucket leaves the paths unchanged. A path that ends up empty is set to "/".
func removeBucketFromPath(u *url.URL, bucket string) {
	prefix := "/" + bucket

	stripBucket := func(p string) string {
		if bucket != "" && strings.HasPrefix(p, prefix) {
			// require a segment boundary so "/foo" does not match "/foobar"
			if rest := p[len(prefix):]; rest == "" || rest[0] == '/' {
				p = rest
			}
		}
		if p == "" {
			return "/"
		}
		return p
	}

	// modify url path
	u.Path = stripBucket(u.Path)

	// modify url raw path
	// NOTE(review): the unescaped bucket is matched here, as before; a bucket
	// value that is escaped differently in RawPath would not be stripped.
	u.RawPath = stripBucket(u.RawPath)
}
244
245// hostCompatibleBucketName returns true if the request should
246// put the bucket in the host. This is false if S3ForcePathStyle is
247// explicitly set or if the bucket is not DNS compatible.
248func hostCompatibleBucketName(u *url.URL, bucket string) bool {
249	// Bucket might be DNS compatible but dots in the hostname will fail
250	// certificate validation, so do not use host-style.
251	if u.Scheme == "https" && strings.Contains(bucket, ".") {
252		return false
253	}
254
255	// if the bucket is DNS compatible
256	return dnsCompatibleBucketName(bucket)
257}
258
// dnsCompatibleBucketName returns true if the bucket name is DNS compatible.
// Buckets created outside of the classic region MUST be DNS compatible.
//
// A name is compatible when it matches
// `^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`, contains no consecutive dots, and
// is not formatted like an IP address (`^(\d+\.){3}\d+$`). Names outside the
// 3 to 63 character range, including the empty string, are reported as
// incompatible.
func dnsCompatibleBucketName(bucket string) bool {
	// The domain pattern admits 3 to 63 characters. Checking the length first
	// also keeps the index expressions below in range.
	if n := len(bucket); n < 3 || n > 63 {
		return false
	}

	if strings.Contains(bucket, "..") {
		return false
	}

	// first and last characters must be a lowercase letter or a digit
	isAlnum := func(c byte) bool {
		return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
	}
	if !isAlnum(bucket[0]) || !isAlnum(bucket[len(bucket)-1]) {
		return false
	}

	// every character must be a lowercase letter, digit, '.' or '-'; along
	// the way, note whether the name consists solely of digits and dots
	onlyDigitsAndDots := true
	for i := 0; i < len(bucket); i++ {
		switch c := bucket[i]; {
		case (c >= '0' && c <= '9') || c == '.':
		case (c >= 'a' && c <= 'z') || c == '-':
			onlyDigitsAndDots = false
		default:
			return false
		}
	}

	// checks for `^(\d+\.){3}\d+$` IP addressing: four dot-separated groups
	// made of digits only
	if onlyDigitsAndDots && strings.Count(bucket, ".") == 3 {
		return false
	}

	return true
}
292