1package request
2
3import (
4	"io"
5	"time"
6
7	"github.com/aws/aws-sdk-go/aws/awserr"
8)
9
10var timeoutErr = awserr.New(
11	ErrCodeResponseTimeout,
12	"read on body has reached the timeout limit",
13	nil,
14)
15
// readResult bundles the two return values of a single Read call so they can
// be passed over a channel as one value.
type readResult struct {
	// n is the byte count reported by the Read call.
	n int
	// err is the error reported by the Read call; nil when it succeeded.
	err error
}
20
// timeoutReadCloser wraps an io.ReadCloser and puts a time limit on each
// individual Read. When a Read does not finish within the limit, an error
// with the ErrCodeResponseTimeout code is returned instead.
type timeoutReadCloser struct {
	// reader is the wrapped body that reads and closes are delegated to.
	reader io.ReadCloser
	// duration is the maximum time a single Read call may take.
	duration time.Duration
}
27
28// Read will spin off a goroutine to call the reader's Read method. We will
29// select on the timer's channel or the read's channel. Whoever completes first
30// will be returned.
31func (r *timeoutReadCloser) Read(b []byte) (int, error) {
32	timer := time.NewTimer(r.duration)
33	c := make(chan readResult, 1)
34
35	go func() {
36		n, err := r.reader.Read(b)
37		timer.Stop()
38		c <- readResult{n: n, err: err}
39	}()
40
41	select {
42	case data := <-c:
43		return data.n, data.err
44	case <-timer.C:
45		return 0, timeoutErr
46	}
47}
48
49func (r *timeoutReadCloser) Close() error {
50	return r.reader.Close()
51}
52
// HandlerResponseTimeout is the name under which the response timeout
// handler is registered in a request's Send handler list.
const HandlerResponseTimeout = "ResponseTimeoutHandler"
58
59// adaptToResponseTimeoutError is a handler that will replace any top level error
60// to a ErrCodeResponseTimeout, if its child is that.
61func adaptToResponseTimeoutError(req *Request) {
62	if err, ok := req.Error.(awserr.Error); ok {
63		aerr, ok := err.OrigErr().(awserr.Error)
64		if ok && aerr.Code() == ErrCodeResponseTimeout {
65			req.Error = aerr
66		}
67	}
68}
69
70// WithResponseReadTimeout is a request option that will wrap the body in a timeout read closer.
71// This will allow for per read timeouts. If a timeout occurred, we will return the
72// ErrCodeResponseTimeout.
73//
74//     svc.PutObjectWithContext(ctx, params, request.WithTimeoutReadCloser(30 * time.Second)
75func WithResponseReadTimeout(duration time.Duration) Option {
76	return func(r *Request) {
77
78		var timeoutHandler = NamedHandler{
79			HandlerResponseTimeout,
80			func(req *Request) {
81				req.HTTPResponse.Body = &timeoutReadCloser{
82					reader:   req.HTTPResponse.Body,
83					duration: duration,
84				}
85			}}
86
87		// remove the handler so we are not stomping over any new durations.
88		r.Handlers.Send.RemoveByName(HandlerResponseTimeout)
89		r.Handlers.Send.PushBackNamed(timeoutHandler)
90
91		r.Handlers.Unmarshal.PushBack(adaptToResponseTimeoutError)
92		r.Handlers.UnmarshalError.PushBack(adaptToResponseTimeoutError)
93	}
94}
95