1// +build integration,perftest
2
3package main
4
5import (
6	"context"
7	"flag"
8	"fmt"
9	"io"
10	"io/ioutil"
11	"os"
12	"os/signal"
13	"time"
14
15	"github.com/aws/aws-sdk-go/aws"
16	"github.com/aws/aws-sdk-go/aws/credentials"
17	"github.com/aws/aws-sdk-go/aws/request"
18	"github.com/aws/aws-sdk-go/aws/session"
19	"github.com/aws/aws-sdk-go/service/s3"
20)
21
22var config Config
23
24func init() {
25	config.SetupFlags("", flag.CommandLine)
26}
27
28func main() {
29	if err := flag.CommandLine.Parse(os.Args[1:]); err != nil {
30		flag.CommandLine.PrintDefaults()
31		exitErrorf(err, "failed to parse CLI commands")
32	}
33	if err := config.Validate(); err != nil {
34		flag.CommandLine.PrintDefaults()
35		exitErrorf(err, "invalid arguments")
36	}
37
38	client := NewClient(config.Client)
39
40	var creds *credentials.Credentials
41	if config.SDK.Anonymous {
42		creds = credentials.AnonymousCredentials
43	}
44
45	var endpoint *string
46	if v := config.Endpoint; len(v) != 0 {
47		endpoint = &v
48	}
49
50	sess, err := session.NewSession(&aws.Config{
51		HTTPClient:           client,
52		Endpoint:             endpoint,
53		Credentials:          creds,
54		S3Disable100Continue: aws.Bool(!config.SDK.ExpectContinue),
55	})
56	if err != nil {
57		exitErrorf(err, "failed to load config")
58	}
59
60	// Create context cancel for Ctrl+C/Interrupt
61	ctx, cancelFn := context.WithCancel(context.Background())
62	defer cancelFn()
63	sigCh := make(chan os.Signal, 1)
64	signal.Notify(sigCh, os.Interrupt)
65	go func() {
66		<-sigCh
67		cancelFn()
68	}()
69
70	// Use the request duration timeout if specified.
71	if config.RequestDuration != 0 {
72		var timeoutFn func()
73		ctx, timeoutFn = context.WithTimeout(ctx, config.RequestDuration)
74		defer timeoutFn()
75	}
76
77	logger := NewLogger(os.Stdout)
78
79	// Start making the requests.
80	svc := s3.New(sess)
81	var reqCount int64
82	errCount := 0
83	for {
84		trace := doRequest(ctx, reqCount, svc, config)
85		select {
86		case <-ctx.Done():
87			return
88		default:
89		}
90		logger.RecordTrace(trace)
91
92		if err := trace.Err(); err != nil {
93			fmt.Fprintf(os.Stderr, err.Error())
94			errCount++
95		} else {
96			errCount = 0
97		}
98
99		if config.RequestCount > 0 && reqCount == config.RequestCount {
100			return
101		}
102
103		reqCount++
104
105		// If the first several requests fail, exist, something is broken.
106		if errCount == 5 && reqCount == 5 {
107			exitErrorf(trace.Err(), "unable to make requests")
108		}
109
110		if config.RequestDelay > 0 {
111			time.Sleep(config.RequestDelay)
112		}
113	}
114}
115
116func doRequest(ctx context.Context, id int64, svc *s3.S3, config Config) *RequestTrace {
117	traceCtx := NewRequestTrace(ctx, id)
118	defer traceCtx.RequestDone()
119
120	resp, err := svc.GetObjectWithContext(traceCtx, &s3.GetObjectInput{
121		Bucket: &config.Bucket,
122		Key:    &config.Key,
123	}, func(r *request.Request) {
124		r.Handlers.Send.PushFront(traceCtx.OnSendAttempt)
125		r.Handlers.Complete.PushBack(traceCtx.OnCompleteRequest)
126		r.Handlers.CompleteAttempt.PushBack(traceCtx.OnCompleteAttempt)
127	})
128	if err != nil {
129		traceCtx.AppendError(fmt.Errorf("request failed, %v", err))
130		return traceCtx
131	}
132	defer resp.Body.Close()
133
134	if n, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
135		traceCtx.AppendError(fmt.Errorf("read request body failed, read %v, %v", n, err))
136		return traceCtx
137	}
138
139	return traceCtx
140}
141
// exitErrorf reports a fatal failure and terminates the process with
// exit status 1.
//
// It writes "FAILED: <err>" on one line to stderr, then msg on the next line.
// msg is treated as a format string and filled in with args. Deferred calls
// in the caller do not run, because the process exits via os.Exit.
func exitErrorf(err error, msg string, args ...interface{}) {
	format := "FAILED: %v\n" + msg + "\n"
	fmtArgs := make([]interface{}, 0, len(args)+1)
	fmtArgs = append(fmtArgs, err)
	fmtArgs = append(fmtArgs, args...)
	fmt.Fprintf(os.Stderr, format, fmtArgs...)
	os.Exit(1)
}
146