1// +build integration 2 3package kinesis_test 4 5import ( 6 crand "crypto/rand" 7 "crypto/tls" 8 "flag" 9 "fmt" 10 "io" 11 "math/rand" 12 "net/http" 13 "os" 14 "testing" 15 "time" 16 17 "github.com/aws/aws-sdk-go/aws" 18 "github.com/aws/aws-sdk-go/aws/awserr" 19 "github.com/aws/aws-sdk-go/awstesting/integration" 20 "github.com/aws/aws-sdk-go/service/kinesis" 21 "golang.org/x/net/http2" 22) 23 24var ( 25 skipTLSVerify bool 26 hUsage string 27 endpoint string 28 streamName string 29 consumerName string 30 numRecords int 31 recordSize int 32 debugEventStream bool 33 mode string 34 35 svc *kinesis.Kinesis 36 records []*kinesis.PutRecordsRequestEntry 37 38 startingTimestamp time.Time 39) 40 41func init() { 42 flag.StringVar( 43 &mode, "mode", "all", 44 "Sets the mode to run in, (test,create,cleanup,all).", 45 ) 46 flag.BoolVar( 47 &skipTLSVerify, "skip-verify", false, 48 "Skips verification of TLS certificate.", 49 ) 50 flag.StringVar( 51 &hUsage, "http", "default", 52 "The HTTP `version` to use for the connection. (default,1,2)", 53 ) 54 flag.StringVar( 55 &endpoint, "endpoint", "", 56 "Overrides SDK `URL` endpoint for tests.", 57 ) 58 flag.StringVar( 59 &streamName, "stream", fmt.Sprintf("awsdkgo-s%v", UniqueID()), 60 "The `name` of the stream to test against.", 61 ) 62 flag.StringVar( 63 &consumerName, "consumer", fmt.Sprintf("awsdkgo-c%v", UniqueID()), 64 "The `name` of the stream to test against.", 65 ) 66 flag.IntVar( 67 &numRecords, "records", 20, 68 "The `number` of records per PutRecords to test with.", 69 ) 70 flag.IntVar( 71 &recordSize, "record-size", 500, 72 "The size in `bytes` of each record.", 73 ) 74 flag.BoolVar( 75 &debugEventStream, "debug-eventstream", false, 76 "Enables debugging of the EventStream messages", 77 ) 78} 79 80func TestMain(m *testing.M) { 81 flag.Parse() 82 83 svc = createClient() 84 85 startingTimestamp = time.Now().Add(-time.Minute) 86 87 switch mode { 88 case "create", "all": 89 if err := createStream(streamName); err != nil { 90 panic(err) 91 } 92 if err := createStreamConsumer(streamName, consumerName); err != nil { 93 panic(err) 94 } 95 fmt.Println("Stream Ready:", streamName, consumerName) 96 97 if mode != "all" { 98 break 99 } 100 fallthrough 101 case "test": 102 records = createRecords(numRecords, recordSize) 103 if err := putRecords(streamName, records, svc); err != nil { 104 panic(err) 105 } 106 time.Sleep(time.Second) 107 108 var exitCode int 109 defer func() { 110 os.Exit(exitCode) 111 }() 112 113 exitCode = m.Run() 114 115 if mode != "all" { 116 break 117 } 118 fallthrough 119 case "cleanup": 120 if err := cleanupStreamConsumer(streamName, consumerName); err != nil { 121 panic(err) 122 } 123 if err := cleanupStream(streamName); err != nil { 124 panic(err) 125 } 126 default: 127 fmt.Fprintf(os.Stderr, "unknown mode, %v", mode) 128 os.Exit(1) 129 } 130} 131 132func createClient() *kinesis.Kinesis { 133 ts := &http.Transport{} 134 135 if skipTLSVerify { 136 ts.TLSClientConfig = &tls.Config{ 137 InsecureSkipVerify: true, 138 } 139 } 140 141 http2.ConfigureTransport(ts) 142 switch hUsage { 143 case "default": 144 // Restore H2 optional support since the Transport/TLSConfig was 145 // modified. 146 http2.ConfigureTransport(ts) 147 case "1": 148 // Do nothing. Without usign ConfigureTransport h2 won't be available. 149 ts.TLSClientConfig.NextProtos = []string{"http/1.1"} 150 case "2": 151 // Force the TLS ALPN (NextProto) to H2 only. 152 ts.TLSClientConfig.NextProtos = []string{http2.NextProtoTLS} 153 default: 154 panic("unknown h usage, " + hUsage) 155 } 156 157 sess := integration.SessionWithDefaultRegion("us-west-2") 158 cfg := &aws.Config{ 159 HTTPClient: &http.Client{ 160 Transport: ts, 161 }, 162 } 163 if debugEventStream { 164 cfg.LogLevel = aws.LogLevel( 165 sess.Config.LogLevel.Value() | aws.LogDebugWithEventStreamBody) 166 } 167 168 return kinesis.New(sess, cfg) 169} 170 171func createStream(name string) error { 172 descParams := &kinesis.DescribeStreamInput{ 173 StreamName: &name, 174 } 175 176 _, err := svc.DescribeStream(descParams) 177 if aerr, ok := err.(awserr.Error); ok && aerr.Code() == kinesis.ErrCodeResourceNotFoundException { 178 _, err := svc.CreateStream(&kinesis.CreateStreamInput{ 179 ShardCount: aws.Int64(100), 180 StreamName: &name, 181 }) 182 if err != nil { 183 return fmt.Errorf("failed to create stream, %v", err) 184 } 185 } else if err != nil { 186 return fmt.Errorf("failed to describe stream, %v", err) 187 } 188 189 if err := svc.WaitUntilStreamExists(descParams); err != nil { 190 return fmt.Errorf("failed to wait for stream to exist, %v", err) 191 } 192 193 return nil 194} 195 196func cleanupStream(name string) error { 197 _, err := svc.DeleteStream(&kinesis.DeleteStreamInput{ 198 StreamName: &name, 199 EnforceConsumerDeletion: aws.Bool(true), 200 }) 201 if err != nil { 202 return fmt.Errorf("failed to delete stream, %v", err) 203 } 204 205 return nil 206} 207 208func createStreamConsumer(streamName, consumerName string) error { 209 desc, err := svc.DescribeStream(&kinesis.DescribeStreamInput{ 210 StreamName: &streamName, 211 }) 212 if err != nil { 213 return fmt.Errorf("failed to describe stream, %s, %v", streamName, err) 214 } 215 216 descParams := &kinesis.DescribeStreamConsumerInput{ 217 StreamARN: desc.StreamDescription.StreamARN, 218 ConsumerName: &consumerName, 219 } 220 _, err = svc.DescribeStreamConsumer(descParams) 221 if aerr, ok := err.(awserr.Error); ok && aerr.Code() == kinesis.ErrCodeResourceNotFoundException { 222 _, err := svc.RegisterStreamConsumer( 223 &kinesis.RegisterStreamConsumerInput{ 224 ConsumerName: aws.String(consumerName), 225 StreamARN: desc.StreamDescription.StreamARN, 226 }, 227 ) 228 if err != nil { 229 return fmt.Errorf("failed to create stream consumer %s, %v", 230 consumerName, err) 231 } 232 } else if err != nil { 233 return fmt.Errorf("failed to describe stream consumer %s, %v", 234 consumerName, err) 235 } 236 237 for i := 0; i < 10; i++ { 238 resp, err := svc.DescribeStreamConsumer(descParams) 239 if err != nil || aws.StringValue(resp.ConsumerDescription.ConsumerStatus) != kinesis.ConsumerStatusActive { 240 time.Sleep(time.Second * 30) 241 continue 242 } 243 return nil 244 } 245 246 return fmt.Errorf("failed to wait for consumer to exist, %v, %v", 247 *descParams.StreamARN, *descParams.ConsumerName) 248} 249 250func cleanupStreamConsumer(streamName, consumerName string) error { 251 desc, err := svc.DescribeStream(&kinesis.DescribeStreamInput{ 252 StreamName: &streamName, 253 }) 254 if err != nil { 255 return fmt.Errorf("failed to describe stream, %s, %v", 256 streamName, err) 257 } 258 259 descCons, err := svc.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ 260 StreamARN: desc.StreamDescription.StreamARN, 261 ConsumerName: &consumerName, 262 }) 263 if err != nil { 264 return fmt.Errorf("failed to describe stream consumer, %s, %v", 265 consumerName, err) 266 } 267 268 _, err = svc.DeregisterStreamConsumer( 269 &kinesis.DeregisterStreamConsumerInput{ 270 ConsumerName: descCons.ConsumerDescription.ConsumerName, 271 ConsumerARN: descCons.ConsumerDescription.ConsumerARN, 272 StreamARN: desc.StreamDescription.StreamARN, 273 }, 274 ) 275 if err != nil { 276 return fmt.Errorf("failed to delete stream consumer, %s %v", 277 consumerName, err) 278 } 279 280 return nil 281} 282 283func createRecords(num, size int) []*kinesis.PutRecordsRequestEntry { 284 var err error 285 data, err := loadRandomData(num, size) 286 if err != nil { 287 fmt.Fprintf(os.Stderr, "unable to read random data, %v", err) 288 os.Exit(1) 289 } 290 291 records := make([]*kinesis.PutRecordsRequestEntry, len(data)) 292 for i, td := range data { 293 records[i] = &kinesis.PutRecordsRequestEntry{ 294 Data: td, 295 PartitionKey: aws.String(UniqueID()), 296 } 297 } 298 299 return records 300} 301 302func putRecords(stream string, records []*kinesis.PutRecordsRequestEntry, svc *kinesis.Kinesis) error { 303 resp, err := svc.PutRecords(&kinesis.PutRecordsInput{ 304 StreamName: &stream, 305 Records: records, 306 }) 307 if err != nil { 308 return fmt.Errorf("failed to put records to stream %s, %v", stream, err) 309 } 310 311 if v := aws.Int64Value(resp.FailedRecordCount); v != 0 { 312 return fmt.Errorf("failed to put records to stream %s, %d failed", 313 stream, v) 314 } 315 316 return nil 317} 318 319func loadRandomData(m, n int) ([][]byte, error) { 320 data := make([]byte, m*n) 321 322 _, err := rand.Read(data) 323 if err != nil { 324 return nil, err 325 } 326 327 parts := make([][]byte, m) 328 329 for i := 0; i < m; i++ { 330 mod := (i % m) 331 parts[i] = data[mod*n : (mod+1)*n] 332 } 333 334 return parts, nil 335} 336 337// UniqueID returns a unique UUID-like identifier for use in generating 338// resources for integration tests. 339func UniqueID() string { 340 uuid := make([]byte, 16) 341 io.ReadFull(crand.Reader, uuid) 342 return fmt.Sprintf("%x", uuid) 343} 344