1// +build integration
2
3package kinesis_test
4
5import (
6	crand "crypto/rand"
7	"crypto/tls"
8	"flag"
9	"fmt"
10	"io"
11	"math/rand"
12	"net/http"
13	"os"
14	"testing"
15	"time"
16
17	"github.com/aws/aws-sdk-go/aws"
18	"github.com/aws/aws-sdk-go/aws/awserr"
19	"github.com/aws/aws-sdk-go/awstesting/integration"
20	"github.com/aws/aws-sdk-go/service/kinesis"
21	"golang.org/x/net/http2"
22)
23
// Package-level state shared by TestMain and the helpers in this file. The
// flag-backed values are registered in init and are only meaningful after
// flag.Parse has run.
var (
	// skipTLSVerify is the -skip-verify flag; when set, createClient disables
	// TLS certificate verification.
	skipTLSVerify    bool
	// hUsage is the -http flag selecting the HTTP version ("default", "1", "2").
	hUsage           string
	// endpoint is the -endpoint flag.
	endpoint         string
	// streamName is the -stream flag; it defaults to a generated unique name.
	streamName       string
	// consumerName is the -consumer flag; it defaults to a generated unique name.
	consumerName     string
	// numRecords is the -records flag, the number of records put per PutRecords.
	numRecords       int
	// recordSize is the -record-size flag, the size in bytes of each record.
	recordSize       int
	// debugEventStream is the -debug-eventstream flag; when set, createClient
	// enables EventStream body debug logging.
	debugEventStream bool
	// mode is the -mode flag controlling which phases TestMain runs.
	mode             string

	// svc is the Kinesis client created once in TestMain via createClient.
	svc     *kinesis.Kinesis
	// records holds the entries generated and put to the stream by TestMain
	// before the tests run.
	records []*kinesis.PutRecordsRequestEntry

	// startingTimestamp is set in TestMain to one minute before the moment the
	// run started.
	startingTimestamp time.Time
)
40
41func init() {
42	flag.StringVar(
43		&mode, "mode", "all",
44		"Sets the mode to run in, (test,create,cleanup,all).",
45	)
46	flag.BoolVar(
47		&skipTLSVerify, "skip-verify", false,
48		"Skips verification of TLS certificate.",
49	)
50	flag.StringVar(
51		&hUsage, "http", "default",
52		"The HTTP `version` to use for the connection. (default,1,2)",
53	)
54	flag.StringVar(
55		&endpoint, "endpoint", "",
56		"Overrides SDK `URL` endpoint for tests.",
57	)
58	flag.StringVar(
59		&streamName, "stream", fmt.Sprintf("awsdkgo-s%v", UniqueID()),
60		"The `name` of the stream to test against.",
61	)
62	flag.StringVar(
63		&consumerName, "consumer", fmt.Sprintf("awsdkgo-c%v", UniqueID()),
64		"The `name` of the stream to test against.",
65	)
66	flag.IntVar(
67		&numRecords, "records", 20,
68		"The `number` of records per PutRecords to test with.",
69	)
70	flag.IntVar(
71		&recordSize, "record-size", 500,
72		"The size in `bytes` of each record.",
73	)
74	flag.BoolVar(
75		&debugEventStream, "debug-eventstream", false,
76		"Enables debugging of the EventStream messages",
77	)
78}
79
80func TestMain(m *testing.M) {
81	flag.Parse()
82
83	svc = createClient()
84
85	startingTimestamp = time.Now().Add(-time.Minute)
86
87	switch mode {
88	case "create", "all":
89		if err := createStream(streamName); err != nil {
90			panic(err)
91		}
92		if err := createStreamConsumer(streamName, consumerName); err != nil {
93			panic(err)
94		}
95		fmt.Println("Stream Ready:", streamName, consumerName)
96
97		if mode != "all" {
98			break
99		}
100		fallthrough
101	case "test":
102		records = createRecords(numRecords, recordSize)
103		if err := putRecords(streamName, records, svc); err != nil {
104			panic(err)
105		}
106		time.Sleep(time.Second)
107
108		var exitCode int
109		defer func() {
110			os.Exit(exitCode)
111		}()
112
113		exitCode = m.Run()
114
115		if mode != "all" {
116			break
117		}
118		fallthrough
119	case "cleanup":
120		if err := cleanupStreamConsumer(streamName, consumerName); err != nil {
121			panic(err)
122		}
123		if err := cleanupStream(streamName); err != nil {
124			panic(err)
125		}
126	default:
127		fmt.Fprintf(os.Stderr, "unknown mode, %v", mode)
128		os.Exit(1)
129	}
130}
131
132func createClient() *kinesis.Kinesis {
133	ts := &http.Transport{}
134
135	if skipTLSVerify {
136		ts.TLSClientConfig = &tls.Config{
137			InsecureSkipVerify: true,
138		}
139	}
140
141	http2.ConfigureTransport(ts)
142	switch hUsage {
143	case "default":
144		// Restore H2 optional support since the Transport/TLSConfig was
145		// modified.
146		http2.ConfigureTransport(ts)
147	case "1":
148		// Do nothing. Without usign ConfigureTransport h2 won't be available.
149		ts.TLSClientConfig.NextProtos = []string{"http/1.1"}
150	case "2":
151		// Force the TLS ALPN (NextProto) to H2 only.
152		ts.TLSClientConfig.NextProtos = []string{http2.NextProtoTLS}
153	default:
154		panic("unknown h usage, " + hUsage)
155	}
156
157	sess := integration.SessionWithDefaultRegion("us-west-2")
158	cfg := &aws.Config{
159		HTTPClient: &http.Client{
160			Transport: ts,
161		},
162	}
163	if debugEventStream {
164		cfg.LogLevel = aws.LogLevel(
165			sess.Config.LogLevel.Value() | aws.LogDebugWithEventStreamBody)
166	}
167
168	return kinesis.New(sess, cfg)
169}
170
171func createStream(name string) error {
172	descParams := &kinesis.DescribeStreamInput{
173		StreamName: &name,
174	}
175
176	_, err := svc.DescribeStream(descParams)
177	if aerr, ok := err.(awserr.Error); ok && aerr.Code() == kinesis.ErrCodeResourceNotFoundException {
178		_, err := svc.CreateStream(&kinesis.CreateStreamInput{
179			ShardCount: aws.Int64(100),
180			StreamName: &name,
181		})
182		if err != nil {
183			return fmt.Errorf("failed to create stream, %v", err)
184		}
185	} else if err != nil {
186		return fmt.Errorf("failed to describe stream, %v", err)
187	}
188
189	if err := svc.WaitUntilStreamExists(descParams); err != nil {
190		return fmt.Errorf("failed to wait for stream to exist, %v", err)
191	}
192
193	return nil
194}
195
196func cleanupStream(name string) error {
197	_, err := svc.DeleteStream(&kinesis.DeleteStreamInput{
198		StreamName:              &name,
199		EnforceConsumerDeletion: aws.Bool(true),
200	})
201	if err != nil {
202		return fmt.Errorf("failed to delete stream, %v", err)
203	}
204
205	return nil
206}
207
208func createStreamConsumer(streamName, consumerName string) error {
209	desc, err := svc.DescribeStream(&kinesis.DescribeStreamInput{
210		StreamName: &streamName,
211	})
212	if err != nil {
213		return fmt.Errorf("failed to describe stream, %s, %v", streamName, err)
214	}
215
216	descParams := &kinesis.DescribeStreamConsumerInput{
217		StreamARN:    desc.StreamDescription.StreamARN,
218		ConsumerName: &consumerName,
219	}
220	_, err = svc.DescribeStreamConsumer(descParams)
221	if aerr, ok := err.(awserr.Error); ok && aerr.Code() == kinesis.ErrCodeResourceNotFoundException {
222		_, err := svc.RegisterStreamConsumer(
223			&kinesis.RegisterStreamConsumerInput{
224				ConsumerName: aws.String(consumerName),
225				StreamARN:    desc.StreamDescription.StreamARN,
226			},
227		)
228		if err != nil {
229			return fmt.Errorf("failed to create stream consumer %s, %v",
230				consumerName, err)
231		}
232	} else if err != nil {
233		return fmt.Errorf("failed to describe stream consumer %s, %v",
234			consumerName, err)
235	}
236
237	for i := 0; i < 10; i++ {
238		resp, err := svc.DescribeStreamConsumer(descParams)
239		if err != nil || aws.StringValue(resp.ConsumerDescription.ConsumerStatus) != kinesis.ConsumerStatusActive {
240			time.Sleep(time.Second * 30)
241			continue
242		}
243		return nil
244	}
245
246	return fmt.Errorf("failed to wait for consumer to exist, %v, %v",
247		*descParams.StreamARN, *descParams.ConsumerName)
248}
249
250func cleanupStreamConsumer(streamName, consumerName string) error {
251	desc, err := svc.DescribeStream(&kinesis.DescribeStreamInput{
252		StreamName: &streamName,
253	})
254	if err != nil {
255		return fmt.Errorf("failed to describe stream, %s, %v",
256			streamName, err)
257	}
258
259	descCons, err := svc.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{
260		StreamARN:    desc.StreamDescription.StreamARN,
261		ConsumerName: &consumerName,
262	})
263	if err != nil {
264		return fmt.Errorf("failed to describe stream consumer, %s, %v",
265			consumerName, err)
266	}
267
268	_, err = svc.DeregisterStreamConsumer(
269		&kinesis.DeregisterStreamConsumerInput{
270			ConsumerName: descCons.ConsumerDescription.ConsumerName,
271			ConsumerARN:  descCons.ConsumerDescription.ConsumerARN,
272			StreamARN:    desc.StreamDescription.StreamARN,
273		},
274	)
275	if err != nil {
276		return fmt.Errorf("failed to delete stream consumer, %s %v",
277			consumerName, err)
278	}
279
280	return nil
281}
282
283func createRecords(num, size int) []*kinesis.PutRecordsRequestEntry {
284	var err error
285	data, err := loadRandomData(num, size)
286	if err != nil {
287		fmt.Fprintf(os.Stderr, "unable to read random data, %v", err)
288		os.Exit(1)
289	}
290
291	records := make([]*kinesis.PutRecordsRequestEntry, len(data))
292	for i, td := range data {
293		records[i] = &kinesis.PutRecordsRequestEntry{
294			Data:         td,
295			PartitionKey: aws.String(UniqueID()),
296		}
297	}
298
299	return records
300}
301
302func putRecords(stream string, records []*kinesis.PutRecordsRequestEntry, svc *kinesis.Kinesis) error {
303	resp, err := svc.PutRecords(&kinesis.PutRecordsInput{
304		StreamName: &stream,
305		Records:    records,
306	})
307	if err != nil {
308		return fmt.Errorf("failed to put records to stream %s, %v", stream, err)
309	}
310
311	if v := aws.Int64Value(resp.FailedRecordCount); v != 0 {
312		return fmt.Errorf("failed to put records to stream %s, %d failed",
313			stream, v)
314	}
315
316	return nil
317}
318
// loadRandomData returns m byte slices of n pseudo-random bytes each.
//
// All slices are carved out of one m*n byte buffer filled by math/rand. Each
// returned slice has its capacity limited to n so that appending to one part
// allocates a new array instead of overwriting the part that follows it.
//
// A negative m or n is reported as an error rather than panicking inside
// make. A zero m yields an empty result; a zero n yields m empty slices.
func loadRandomData(m, n int) ([][]byte, error) {
	if m < 0 || n < 0 {
		return nil, fmt.Errorf("invalid random data dimensions, %d x %d", m, n)
	}

	data := make([]byte, m*n)
	if _, err := rand.Read(data); err != nil {
		return nil, err
	}

	parts := make([][]byte, m)
	for i := range parts {
		lo, hi := i*n, (i+1)*n
		// Three-index slice: length and capacity are both n.
		parts[i] = data[lo:hi:hi]
	}

	return parts, nil
}
336
// UniqueID returns a unique UUID-like identifier for use in generating
// resources for integration tests.
//
// The identifier is 16 bytes read from crypto/rand, rendered as 32 lowercase
// hexadecimal characters. It panics if the random source cannot be read,
// because continuing would produce an all-zero ID that is not unique.
func UniqueID() string {
	uuid := make([]byte, 16)
	if _, err := io.ReadFull(crand.Reader, uuid); err != nil {
		panic(fmt.Sprintf("failed to read random bytes for unique ID, %v", err))
	}
	return fmt.Sprintf("%x", uuid)
}
344