1package customizations
2
3import (
4	"context"
5	"fmt"
6	"github.com/aws/aws-sdk-go-v2/aws"
7	"net/url"
8	"strings"
9
10	"github.com/aws/smithy-go/middleware"
11	"github.com/aws/smithy-go/transport/http"
12
13	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
14	"github.com/aws/aws-sdk-go-v2/service/internal/s3shared"
15	"github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn"
16	s3arn "github.com/aws/aws-sdk-go-v2/service/s3/internal/arn"
17)
18
// Service IDs recorded on the request context (via
// awsmiddleware.SetServiceID) when a request is rewritten for an ARN resource.
const (
	// s3AccessPoint is the service ID set for access point ARN requests.
	s3AccessPoint  = "s3-accesspoint"
	// s3ObjectLambda is the service ID set for S3 Object Lambda access point
	// ARN requests.
	s3ObjectLambda = "s3-object-lambda"
)
23
// processARNResource is a serialize middleware that, when an ARN resource is
// stored on the request context, validates the ARN against the client
// configuration and rewrites the request to target the endpoint for that
// resource (access point, S3 Object Lambda access point, or S3 on Outposts
// access point).
type processARNResource struct {

	// UseARNRegion indicates if the region parsed from an ARN should be used.
	// When set together with FIPS, cross-region access point requests are
	// rejected.
	UseARNRegion bool

	// UseAccelerate indicates if S3 transfer acceleration is enabled. Every
	// supported ARN resource kind rejects requests while this is set.
	UseAccelerate bool

	// UseDualstack indicates if the S3 dual-stack endpoint config is enabled.
	// S3 Object Lambda and Outposts access point ARNs reject requests while
	// this is set.
	UseDualstack bool

	// EndpointResolver is used to resolve endpoints. This may be a custom
	// endpoint resolver.
	EndpointResolver EndpointResolver

	// EndpointResolverOptions are passed to every EndpointResolver call.
	EndpointResolverOptions EndpointResolverOptions
}
42
43// ID returns the middleware ID.
44func (*processARNResource) ID() string { return "S3:ProcessARNResource" }
45
46func (m *processARNResource) HandleSerialize(
47	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
48) (
49	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
50) {
51	// check if arn was provided, if not skip this middleware
52	arnValue, ok := s3shared.GetARNResourceFromContext(ctx)
53	if !ok {
54		return next.HandleSerialize(ctx, in)
55	}
56
57	req, ok := in.Request.(*http.Request)
58	if !ok {
59		return out, metadata, fmt.Errorf("unknown request type %T", req)
60	}
61
62	// parse arn into an endpoint arn wrt to service
63	resource, err := s3arn.ParseEndpointARN(arnValue)
64	if err != nil {
65		return out, metadata, err
66	}
67
68	// build a resource request struct
69	resourceRequest := s3shared.ResourceRequest{
70		Resource:      resource,
71		UseARNRegion:  m.UseARNRegion,
72		RequestRegion: awsmiddleware.GetRegion(ctx),
73		SigningRegion: awsmiddleware.GetSigningRegion(ctx),
74		PartitionID:   awsmiddleware.GetPartitionID(ctx),
75	}
76
77	// validate resource request
78	if err := validateResourceRequest(resourceRequest); err != nil {
79		return out, metadata, err
80	}
81
82	// switch to correct endpoint updater
83	switch tv := resource.(type) {
84	case arn.AccessPointARN:
85		// check if accelerate
86		if m.UseAccelerate {
87			return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
88				resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
89		}
90
91		// fetch arn region to resolve request
92		resolveRegion := tv.Region
93		// check if request region is FIPS
94		if resourceRequest.UseFips() {
95			// if use arn region is enabled and request signing region is not same as arn region
96			if m.UseARNRegion && resourceRequest.IsCrossRegion() {
97				// FIPS with cross region is not supported, the SDK must fail
98				// because there is no well defined method for SDK to construct a
99				// correct FIPS endpoint.
100				return out, metadata,
101					s3shared.NewClientConfiguredForCrossRegionFIPSError(
102						tv,
103						resourceRequest.PartitionID,
104						resourceRequest.RequestRegion,
105						nil,
106					)
107			}
108
109			// if use arn region is NOT set, we should use the request region
110			resolveRegion = resourceRequest.RequestRegion
111		}
112
113		// build access point request
114		ctx, err = buildAccessPointRequest(ctx, accesspointOptions{
115			processARNResource: *m,
116			request:            req,
117			resource:           tv,
118			resolveRegion:      resolveRegion,
119			partitionID:        resourceRequest.PartitionID,
120			requestRegion:      resourceRequest.RequestRegion,
121		})
122		if err != nil {
123			return out, metadata, err
124		}
125
126	case arn.S3ObjectLambdaAccessPointARN:
127		// check if accelerate
128		if m.UseAccelerate {
129			return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
130				resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
131		}
132
133		// check if dualstack
134		if m.UseDualstack {
135			return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
136				resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
137		}
138
139		// fetch arn region to resolve request
140		resolveRegion := tv.Region
141
142		if resourceRequest.UseFips() {
143			// if use arn region is enabled and request signing region is not same as arn region
144			if m.UseARNRegion && resourceRequest.IsCrossRegion() {
145				// FIPS with cross region is not supported, the SDK must fail
146				// because there is no well defined method for SDK to construct a
147				// correct FIPS endpoint.
148				return out, metadata,
149					s3shared.NewClientConfiguredForCrossRegionFIPSError(
150						tv,
151						resourceRequest.PartitionID,
152						resourceRequest.RequestRegion,
153						nil,
154					)
155			}
156
157			// if use arn region is NOT set, we should use the request region
158			resolveRegion = resourceRequest.RequestRegion
159		}
160
161		// build access point request
162		ctx, err = buildS3ObjectLambdaAccessPointRequest(ctx, accesspointOptions{
163			processARNResource: *m,
164			request:            req,
165			resource:           tv.AccessPointARN,
166			resolveRegion:      resolveRegion,
167			partitionID:        resourceRequest.PartitionID,
168			requestRegion:      resourceRequest.RequestRegion,
169		})
170		if err != nil {
171			return out, metadata, err
172		}
173
174	// process outpost accesspoint ARN
175	case arn.OutpostAccessPointARN:
176		// check if accelerate
177		if m.UseAccelerate {
178			return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
179				resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
180		}
181
182		// check if dual stack
183		if m.UseDualstack {
184			return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
185				resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
186		}
187
188		// check if resource arn region is FIPS
189		if resourceRequest.ResourceConfiguredForFIPS() {
190			return out, metadata, s3shared.NewInvalidARNWithFIPSError(tv, nil)
191		}
192
193		// build outpost access point request
194		ctx, err = buildOutpostAccessPointRequest(ctx, outpostAccessPointOptions{
195			processARNResource: *m,
196			resource:           tv,
197			request:            req,
198			partitionID:        resourceRequest.PartitionID,
199			requestRegion:      resourceRequest.RequestRegion,
200		})
201		if err != nil {
202			return out, metadata, err
203		}
204
205	default:
206		return out, metadata, s3shared.NewInvalidARNError(resource, nil)
207	}
208
209	return next.HandleSerialize(ctx, in)
210}
211
212// validate if s3 resource and request config is compatible.
213func validateResourceRequest(resourceRequest s3shared.ResourceRequest) error {
214	// check if resourceRequest leads to a cross partition error
215	v, err := resourceRequest.IsCrossPartition()
216	if err != nil {
217		return err
218	}
219	if v {
220		// if cross partition
221		return s3shared.NewClientPartitionMismatchError(resourceRequest.Resource,
222			resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
223	}
224
225	// check if resourceRequest leads to a cross region error
226	if !resourceRequest.AllowCrossRegion() && resourceRequest.IsCrossRegion() {
227		// if cross region, but not use ARN region is not enabled
228		return s3shared.NewClientRegionMismatchError(resourceRequest.Resource,
229			resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
230	}
231
232	return nil
233}
234
235// === Accesspoint ==========
236
237type accesspointOptions struct {
238	processARNResource
239	request       *http.Request
240	resource      arn.AccessPointARN
241	resolveRegion string
242	partitionID   string
243	requestRegion string
244}
245
246func buildAccessPointRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
247	tv := options.resource
248	req := options.request
249	resolveRegion := options.resolveRegion
250
251	resolveService := tv.Service
252
253	// resolve endpoint
254	endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
255	if err != nil {
256		return ctx, s3shared.NewFailedToResolveEndpointError(
257			tv,
258			options.partitionID,
259			options.requestRegion,
260			err,
261		)
262	}
263
264	// assign resolved endpoint url to request url
265	req.URL, err = url.Parse(endpoint.URL)
266	if err != nil {
267		return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
268	}
269
270	if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
271		ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
272	} else {
273		// Must sign with s3-object-lambda
274		ctx = awsmiddleware.SetSigningName(ctx, resolveService)
275	}
276
277	if len(endpoint.SigningRegion) != 0 {
278		ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
279	} else {
280		ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
281	}
282
283	// update serviceID to "s3-accesspoint"
284	ctx = awsmiddleware.SetServiceID(ctx, s3AccessPoint)
285
286	// disable host prefix behavior
287	ctx = http.DisableEndpointHostPrefix(ctx, true)
288
289	// remove the serialized arn in place of /{Bucket}
290	ctx = setBucketToRemoveOnContext(ctx, tv.String())
291
292	// skip arn processing, if arn region resolves to a immutable endpoint
293	if endpoint.HostnameImmutable {
294		return ctx, nil
295	}
296
297	updateS3HostForS3AccessPoint(req)
298
299	ctx, err = buildAccessPointHostPrefix(ctx, req, tv)
300	if err != nil {
301		return ctx, err
302	}
303
304	return ctx, nil
305}
306
307func buildS3ObjectLambdaAccessPointRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
308	tv := options.resource
309	req := options.request
310	resolveRegion := options.resolveRegion
311
312	resolveService := tv.Service
313
314	// resolve endpoint
315	endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
316	if err != nil {
317		return ctx, s3shared.NewFailedToResolveEndpointError(
318			tv,
319			options.partitionID,
320			options.requestRegion,
321			err,
322		)
323	}
324
325	// assign resolved endpoint url to request url
326	req.URL, err = url.Parse(endpoint.URL)
327	if err != nil {
328		return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
329	}
330
331	if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
332		ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
333	} else {
334		// Must sign with s3-object-lambda
335		ctx = awsmiddleware.SetSigningName(ctx, resolveService)
336	}
337
338	if len(endpoint.SigningRegion) != 0 {
339		ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
340	} else {
341		ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
342	}
343
344	// update serviceID to "s3-object-lambda"
345	ctx = awsmiddleware.SetServiceID(ctx, s3ObjectLambda)
346
347	// disable host prefix behavior
348	ctx = http.DisableEndpointHostPrefix(ctx, true)
349
350	// remove the serialized arn in place of /{Bucket}
351	ctx = setBucketToRemoveOnContext(ctx, tv.String())
352
353	// skip arn processing, if arn region resolves to a immutable endpoint
354	if endpoint.HostnameImmutable {
355		return ctx, nil
356	}
357
358	if endpoint.Source == aws.EndpointSourceServiceMetadata {
359		updateS3HostForS3ObjectLambda(req)
360	}
361
362	ctx, err = buildAccessPointHostPrefix(ctx, req, tv)
363	if err != nil {
364		return ctx, err
365	}
366
367	return ctx, nil
368}
369
370func buildAccessPointHostPrefix(ctx context.Context, req *http.Request, tv arn.AccessPointARN) (context.Context, error) {
371	// add host prefix for access point
372	accessPointHostPrefix := tv.AccessPointName + "-" + tv.AccountID + "."
373	req.URL.Host = accessPointHostPrefix + req.URL.Host
374	if len(req.Host) > 0 {
375		req.Host = accessPointHostPrefix + req.Host
376	}
377
378	// validate the endpoint host
379	if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
380		return ctx, s3shared.NewInvalidARNError(tv, err)
381	}
382
383	return ctx, nil
384}
385
386// ====== Outpost Accesspoint ========
387
388type outpostAccessPointOptions struct {
389	processARNResource
390	request       *http.Request
391	resource      arn.OutpostAccessPointARN
392	partitionID   string
393	requestRegion string
394}
395
396func buildOutpostAccessPointRequest(ctx context.Context, options outpostAccessPointOptions) (context.Context, error) {
397	tv := options.resource
398	req := options.request
399
400	resolveRegion := tv.Region
401	resolveService := tv.Service
402	endpointsID := resolveService
403	if strings.EqualFold(resolveService, "s3-outposts") {
404		// assign endpoints ID as "S3"
405		endpointsID = "s3"
406	}
407
408	// resolve regional endpoint for resolved region.
409	endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
410	if err != nil {
411		return ctx, s3shared.NewFailedToResolveEndpointError(
412			tv,
413			options.partitionID,
414			options.requestRegion,
415			err,
416		)
417	}
418
419	// assign resolved endpoint url to request url
420	req.URL, err = url.Parse(endpoint.URL)
421	if err != nil {
422		return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
423	}
424
425	// assign resolved service from arn as signing name
426	if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
427		ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
428	} else {
429		ctx = awsmiddleware.SetSigningName(ctx, resolveService)
430	}
431
432	if len(endpoint.SigningRegion) != 0 {
433		// redirect signer to use resolved endpoint signing name and region
434		ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
435	} else {
436		ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
437	}
438
439	// update serviceID to resolved service id
440	ctx = awsmiddleware.SetServiceID(ctx, resolveService)
441
442	// disable host prefix behavior
443	ctx = http.DisableEndpointHostPrefix(ctx, true)
444
445	// remove the serialized arn in place of /{Bucket}
446	ctx = setBucketToRemoveOnContext(ctx, tv.String())
447
448	// skip further customizations, if arn region resolves to a immutable endpoint
449	if endpoint.HostnameImmutable {
450		return ctx, nil
451	}
452
453	updateHostPrefix(req, endpointsID, resolveService)
454
455	// add host prefix for s3-outposts
456	outpostAPHostPrefix := tv.AccessPointName + "-" + tv.AccountID + "." + tv.OutpostID + "."
457	req.URL.Host = outpostAPHostPrefix + req.URL.Host
458	if len(req.Host) > 0 {
459		req.Host = outpostAPHostPrefix + req.Host
460	}
461
462	// validate the endpoint host
463	if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
464		return ctx, s3shared.NewInvalidARNError(tv, err)
465	}
466
467	return ctx, nil
468}
469