1package http
2
3import (
4	"context"
5	"fmt"
6	"io"
7	"time"
8
9	"github.com/aws/smithy-go"
10	"github.com/aws/smithy-go/middleware"
11	smithyhttp "github.com/aws/smithy-go/transport/http"
12)
13
// readResult bundles the two return values of a single Read call so they can
// be delivered together over a channel.
type readResult struct {
	n   int   // byte count reported by the read
	err error // error reported by the read, nil if none
}
18
// ResponseTimeoutError reports that a read from the response body did not
// complete within the timeout the read was configured with.
type ResponseTimeoutError struct {
	// TimeoutDur is the timeout that was exceeded.
	TimeoutDur time.Duration
}

// Timeout always reports true: the error was caused by a timeout, and the
// operation can be retried.
func (e *ResponseTimeoutError) Timeout() bool {
	return true
}

// Error implements the error interface. The message includes the configured
// timeout duration.
func (e *ResponseTimeoutError) Error() string {
	limit := e.TimeoutDur
	return fmt.Sprintf("read on body reach timeout limit, %v", limit)
}
32
33// timeoutReadCloser will handle body reads that take too long.
34// We will return a ErrReadTimeout error if a timeout occurs.
35type timeoutReadCloser struct {
36	reader   io.ReadCloser
37	duration time.Duration
38}
39
40// Read will spin off a goroutine to call the reader's Read method. We will
41// select on the timer's channel or the read's channel. Whoever completes first
42// will be returned.
43func (r *timeoutReadCloser) Read(b []byte) (int, error) {
44	timer := time.NewTimer(r.duration)
45	c := make(chan readResult, 1)
46
47	go func() {
48		n, err := r.reader.Read(b)
49		timer.Stop()
50		c <- readResult{n: n, err: err}
51	}()
52
53	select {
54	case data := <-c:
55		return data.n, data.err
56	case <-timer.C:
57		return 0, &ResponseTimeoutError{TimeoutDur: r.duration}
58	}
59}
60
61func (r *timeoutReadCloser) Close() error {
62	return r.reader.Close()
63}
64
65// AddResponseReadTimeoutMiddleware adds a middleware to the stack that wraps the
66// response body so that a read that takes too long will return an error.
67func AddResponseReadTimeoutMiddleware(stack *middleware.Stack, duration time.Duration) error {
68	return stack.Deserialize.Add(&readTimeout{duration: duration}, middleware.After)
69}
70
71// readTimeout wraps the response body with a timeoutReadCloser
72type readTimeout struct {
73	duration time.Duration
74}
75
76// ID returns the id of the middleware
77func (*readTimeout) ID() string {
78	return "ReadResponseTimeout"
79}
80
81// HandleDeserialize implements the DeserializeMiddleware interface
82func (m *readTimeout) HandleDeserialize(
83	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
84) (
85	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
86) {
87	out, metadata, err = next.HandleDeserialize(ctx, in)
88	if err != nil {
89		return out, metadata, err
90	}
91
92	response, ok := out.RawResponse.(*smithyhttp.Response)
93	if !ok {
94		return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)}
95	}
96
97	response.Body = &timeoutReadCloser{
98		reader:   response.Body,
99		duration: m.duration,
100	}
101	out.RawResponse = response
102
103	return out, metadata, err
104}
105