1// +build codegen
2
3package api
4
5import (
6	"bytes"
7	"fmt"
8	"io"
9	"os"
10	"strings"
11	"text/template"
12)
13
// EventStreamAPI provides details about the event stream async API and
// associated EventStream shapes.
type EventStreamAPI struct {
	// API is the API model the event stream operation belongs to.
	API       *API
	// Name is the name of the generated event stream type. It is built from
	// the operation's exported name followed by "EventStream".
	Name      string
	// Operation is the operation the event stream is attached to.
	Operation *Operation
	// Shape is the generated structure shape representing the event stream
	// API type.
	Shape     *Shape
	// Inbound is the event stream found on the operation's output shape, if
	// any.
	Inbound   *EventStream
	// Outbound is the event stream found on the operation's input shape, if
	// any.
	Outbound  *EventStream
}
24
// EventStream represents a single eventstream group (input/output) and the
// modeled events that are known for the stream.
type EventStream struct {
	// Name is the shape name of the modeled event stream shape.
	Name       string
	// RefName is the name of the member, on the operation's input or output
	// shape, that referenced the event stream shape.
	RefName    string
	// Shape is the modeled event stream shape.
	Shape      *Shape
	// Events are the members of the event stream whose shape is not flagged
	// as an exception.
	Events     []*Event
	// Exceptions are the members of the event stream whose shape is flagged
	// as an exception.
	Exceptions []*Event
}
34
// Event is a single EventStream event that can be sent or received in an
// EventStream.
type Event struct {
	// Name is the member name of the event within the event stream shape.
	Name  string
	// Shape is the shape of the event.
	Shape *Shape
	// For is the event stream the event is a member of.
	For   *EventStream
}
42
43// ShapeDoc returns the docstring for the EventStream API.
44func (esAPI *EventStreamAPI) ShapeDoc() string {
45	tmpl := template.Must(template.New("eventStreamShapeDoc").Parse(`
46{{- $.Name }} provides handling of EventStreams for
47the {{ $.Operation.ExportedName }} API.
48{{- if $.Inbound }}
49
50Use this type to receive {{ $.Inbound.Name }} events. The events
51can be read from the Events channel member.
52
53The events that can be received are:
54{{ range $_, $event := $.Inbound.Events }}
55    * {{ $event.Shape.ShapeName }}
56{{- end }}
57
58{{- end }}
59
60{{- if $.Outbound }}
61
62Use this type to send {{ $.Outbound.Name }} events. The events
63can be sent with the Send method.
64
65The events that can be sent are:
66{{ range $_, $event := $.Outbound.Events -}}
67    * {{ $event.Shape.ShapeName }}
68{{- end }}
69
70{{- end }}`))
71
72	var w bytes.Buffer
73	if err := tmpl.Execute(&w, esAPI); err != nil {
74		panic(fmt.Sprintf("failed to generate eventstream shape template for %v, %v", esAPI.Name, err))
75	}
76
77	return commentify(w.String())
78}
79
80func hasEventStream(topShape *Shape) bool {
81	for _, ref := range topShape.MemberRefs {
82		if ref.Shape.IsEventStream {
83			return true
84		}
85	}
86
87	return false
88}
89
90func eventStreamAPIShapeRefDoc(refName string) string {
91	return commentify(fmt.Sprintf("Use %s to use the API's stream.", refName))
92}
93
// setupEventStreams walks every operation of the API and wires up event
// stream support for the operations whose input or output shape has an event
// stream member.
//
// Operations with an outbound (input) event stream are not supported. They
// are removed from the API, with a notice written to stderr, when
// IgnoreUnsupportedAPIs is set; otherwise an UnsupportedAPIModelError is
// returned. An UnsupportedAPIModelError is also returned if the API's
// protocol is not rest-json, rest-xml, or json.
//
// For each supported operation a new "<ExportedName>EventStream" structure
// shape is created, registered in a.Shapes, and referenced from the
// operation's output shape under the name of the original event stream
// member. a.HasEventStream is set when at least one operation was set up.
func (a *API) setupEventStreams() error {
	const eventStreamMemberName = "EventStream"

	for opName, op := range a.Operations {
		// setupEventStream detaches the event stream member from the shape it
		// is given, so both calls mutate the model even when the operation is
		// later rejected.
		outbound := setupEventStream(op.InputRef.Shape)
		inbound := setupEventStream(op.OutputRef.Shape)

		if outbound == nil && inbound == nil {
			continue
		}

		if outbound != nil {
			err := fmt.Errorf("Outbound stream support not implemented, %s, %s",
				outbound.Name, outbound.Shape.ShapeName)

			if a.IgnoreUnsupportedAPIs {
				fmt.Fprintf(os.Stderr, "removing operation, %s, %v\n", opName, err)
				delete(a.Operations, opName)
				continue
			}
			return UnsupportedAPIModelError{
				Err: err,
			}
		}

		// From here on outbound is nil and inbound is non-nil.
		switch a.Metadata.Protocol {
		case `rest-json`, `rest-xml`, `json`:
		default:
			return UnsupportedAPIModelError{
				Err: fmt.Errorf("EventStream not supported for protocol %v",
					a.Metadata.Protocol),
			}
		}

		op.EventStreamAPI = &EventStreamAPI{
			API:       a,
			Name:      op.ExportedName + eventStreamMemberName,
			Operation: op,
			Outbound:  outbound,
			Inbound:   inbound,
		}

		// The synthetic shape that represents the event stream API type in
		// the generated code.
		streamShape := &Shape{
			API:            a,
			ShapeName:      op.EventStreamAPI.Name,
			Documentation:  op.EventStreamAPI.ShapeDoc(),
			Type:           "structure",
			EventStreamAPI: op.EventStreamAPI,
			IsEventStream:  true,
			MemberRefs: map[string]*ShapeRef{
				"Inbound": &ShapeRef{
					ShapeName: inbound.Shape.ShapeName,
					Shape:     inbound.Shape,
				},
			},
		}
		// persist reference to the original inbound shape so its event types
		// can be walked during generation.
		inbound.Shape.refs = append(inbound.Shape.refs, streamShape.MemberRefs["Inbound"])

		streamShapeRef := &ShapeRef{
			API:           a,
			ShapeName:     streamShape.ShapeName,
			Shape:         streamShape,
			Documentation: eventStreamAPIShapeRefDoc(inbound.RefName),
		}
		streamShape.refs = []*ShapeRef{streamShapeRef}
		op.EventStreamAPI.Shape = streamShape

		// Re-add the member that setupEventStream removed from the output
		// shape, now pointing at the synthetic stream shape.
		op.OutputRef.Shape.MemberRefs[inbound.RefName] = streamShapeRef
		op.OutputRef.Shape.EventStreamsMemberName = inbound.RefName

		// A modeled shape that already uses the stream shape's name is
		// renamed with a "Data" suffix before the stream shape is registered
		// under that name.
		if s, ok := a.Shapes[streamShape.ShapeName]; ok {
			s.Rename(streamShape.ShapeName + "Data")
		}
		a.Shapes[streamShape.ShapeName] = streamShape

		a.HasEventStream = true
	}

	return nil
}
176
// setupEventStream looks for an event stream member on topShape and returns
// the EventStream describing it, or nil when topShape has no such member.
//
// Every member of the event stream shape is recorded as either an Event or an
// Exception of the returned stream, and each event shape is tagged with the
// stream through its EventFor list. The event stream member is removed from
// topShape, and the event stream shape is removed from the API's shapes.
//
// Panics if topShape has more than one event stream member, or if the event
// stream shape has a member that is neither an event nor an exception.
func setupEventStream(topShape *Shape) *EventStream {
	var eventStream *EventStream
	for refName, ref := range topShape.MemberRefs {
		if !ref.Shape.IsEventStream {
			continue
		}
		if eventStream != nil {
			// Only one event stream member within an Input/Output shape is valid.
			panic(fmt.Sprintf("multiple shape ref eventstreams, %s, prev: %s",
				refName, eventStream.Name))
		}

		eventStream = &EventStream{
			Name:    ref.Shape.ShapeName,
			RefName: refName,
			Shape:   ref.Shape,
		}

		// For the json protocol the top level shape itself is also tagged as
		// belonging to the event stream.
		if topShape.API.Metadata.Protocol == "json" {
			topShape.EventFor = append(topShape.EventFor, eventStream)
		}

		for _, eventRefName := range ref.Shape.MemberNames() {
			eventRef := ref.Shape.MemberRefs[eventRefName]
			if !(eventRef.Shape.IsEvent || eventRef.Shape.Exception) {
				panic(fmt.Sprintf("unexpected non-event member reference %s.%s",
					ref.Shape.ShapeName, eventRefName))
			}

			updateEventPayloadRef(eventRef.Shape)

			eventRef.Shape.EventFor = append(eventRef.Shape.EventFor, eventStream)

			// Exceptions and events are two different lists to allow the SDK
			// to easily generate code with the two handled differently.
			event := &Event{
				Name:  eventRefName,
				Shape: eventRef.Shape,
				For:   eventStream,
			}
			if eventRef.Shape.Exception {
				eventStream.Exceptions = append(eventStream.Exceptions, event)
			} else {
				eventStream.Events = append(eventStream.Events, event)
			}
		}

		// Remove the event stream member and shape as they will be added
		// elsewhere.
		ref.Shape.removeRef(ref)
		delete(topShape.MemberRefs, refName)
		delete(topShape.API.Shapes, ref.Shape.ShapeName)
	}

	return eventStream
}
233
234func updateEventPayloadRef(parent *Shape) {
235	refName := parent.PayloadRefName()
236	if len(refName) == 0 {
237		return
238	}
239
240	payloadRef := parent.MemberRefs[refName]
241
242	if payloadRef.Shape.Type == "blob" {
243		return
244	}
245
246	if len(payloadRef.LocationName) != 0 {
247		return
248	}
249
250	payloadRef.LocationName = refName
251}
252
253func renderEventStreamAPIShape(w io.Writer, s *Shape) error {
254	// Imports needed by the EventStream APIs.
255	s.API.AddImport("fmt")
256	s.API.AddImport("bytes")
257	s.API.AddImport("io")
258	s.API.AddImport("sync")
259	s.API.AddImport("sync/atomic")
260	s.API.AddSDKImport("aws")
261	s.API.AddSDKImport("aws/awserr")
262	s.API.AddSDKImport("private/protocol/eventstream")
263	s.API.AddSDKImport("private/protocol/eventstream/eventstreamapi")
264
265	return eventStreamAPIShapeTmpl.Execute(w, s)
266}
267
268// Template for an EventStream API Shape that will provide read/writing events
269// across the EventStream. This is a special shape that's only public members
270// are the Events channel and a Close and Err method.
271//
272// Executed in the context of a Shape.
273var eventStreamAPIShapeTmpl = func() *template.Template {
274	t := template.Must(
275		template.New("eventStreamAPIShapeTmpl").
276			Funcs(template.FuncMap{}).
277			Parse(eventStreamAPITmplDef),
278	)
279
280	template.Must(
281		t.AddParseTree(
282			"eventStreamAPIReaderTmpl", eventStreamAPIReaderTmpl.Tree),
283	)
284
285	return t
286}()
287
// eventStreamAPITmplDef is the template text for the generated EventStream
// API type of an operation: the struct holding the Reader and/or Writer and
// the stream closer, its Close and Err methods, and, for inbound streams, the
// Events method followed by the "eventStreamAPIReaderTmpl" template.
//
// The Reader member is only generated when the stream has an Inbound event
// stream, and the Writer member only when it has an Outbound one. Each
// branch must therefore only reference its own side of the stream.
//
// Executed in the context of a Shape.
const eventStreamAPITmplDef = `
{{ $.Documentation }}
type {{ $.ShapeName }} struct {
	{{- if $.EventStreamAPI.Inbound }}
		// Reader is the EventStream reader for the {{ $.EventStreamAPI.Inbound.Name }}
		// events. This value is automatically set by the SDK when the API call is made
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Reader.
		//
		// Must not be nil.
		Reader {{ $.ShapeName }}Reader

	{{ end -}}

	{{- if $.EventStreamAPI.Outbound }}
		// Writer is the EventStream writer for the {{ $.EventStreamAPI.Outbound.Name }}
		// events. This value is automatically set by the SDK when the API call is made
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Writer.
		//
		// Must not be nil.
		Writer *{{ $.ShapeName }}Writer

	{{ end -}}

	// StreamCloser is the io.Closer for the EventStream connection. For HTTP
	// EventStream this is the response Body. The stream will be closed when
	// the Close method of the EventStream is called.
	StreamCloser io.Closer
}

// Close closes the EventStream. This will also cause the Events channel to be
// closed. You can use the closing of the Events channel to terminate your
// application's read from the API's EventStream.
{{- if $.EventStreamAPI.Inbound }}
//
// Will close the underlying EventStream reader. For EventStream over HTTP
// connection this will also close the HTTP connection.
{{ end -}}
//
// Close must be called when done using the EventStream API. Not calling Close
// may result in resource leaks.
func (es *{{ $.ShapeName }}) Close() (err error) {
	{{- if $.EventStreamAPI.Inbound }}
		es.Reader.Close()
	{{ end -}}
	{{- if $.EventStreamAPI.Outbound }}
		es.Writer.Close()
	{{ end -}}

	es.StreamCloser.Close()

	return es.Err()
}

// Err returns any error that occurred while reading EventStream Events from
// the service API's response. Returns nil if there were no errors.
func (es *{{ $.ShapeName }}) Err() error {
	{{- if $.EventStreamAPI.Outbound }}
		if err := es.Writer.Err(); err != nil {
			return err
		}
	{{ end -}}

	{{- if $.EventStreamAPI.Inbound }}
		if err := es.Reader.Err(); err != nil {
			return err
		}
	{{ end -}}

	return nil
}

{{ if $.EventStreamAPI.Inbound }}
	// Events returns a channel to read EventStream Events from the
	// {{ $.EventStreamAPI.Operation.ExportedName }} API.
	//
	// These events are:
	// {{ range $_, $event := $.EventStreamAPI.Inbound.Events }}
	//     * {{ $event.Shape.ShapeName }}
	{{- end }}
	func (es *{{ $.ShapeName }}) Events() <-chan {{ $.EventStreamAPI.Inbound.Name }}Event {
		return es.Reader.Events()
	}

	{{ template "eventStreamAPIReaderTmpl" $ }}
{{ end }}

{{ if $.EventStreamAPI.Outbound }}
	// TODO writer helper method.
{{ end }}

`
381
// eventStreamAPIReaderTmpl is the template for the reading side of an
// EventStream API shape. It generates the event group interface, the
// <ShapeName>Reader interface, and the unexported read<ShapeName> type with
// its constructor and methods.
//
// Its parse tree is added to eventStreamAPIShapeTmpl, which invokes it by
// name for shapes that have an Inbound event stream.
//
// Executed in the context of a Shape.
var eventStreamAPIReaderTmpl = template.Must(template.New("eventStreamAPIReaderTmpl").
	Funcs(template.FuncMap{}).
	Parse(`
// {{ $.EventStreamAPI.Inbound.Name }}Event groups together all EventStream
// events read from the {{ $.EventStreamAPI.Operation.ExportedName }} API.
//
// These events are:
// {{ range $_, $event := $.EventStreamAPI.Inbound.Events }}
//     * {{ $event.Shape.ShapeName }}
{{- end }}
type {{ $.EventStreamAPI.Inbound.Name }}Event interface {
	event{{ $.EventStreamAPI.Inbound.Name }}()
}

// {{ $.ShapeName }}Reader provides the interface for reading EventStream
// Events from the {{ $.EventStreamAPI.Operation.ExportedName }} API. The
// default implementation for this interface will be {{ $.ShapeName }}.
//
// The reader's Close method must allow multiple concurrent calls.
//
// These events are:
// {{ range $_, $event := $.EventStreamAPI.Inbound.Events }}
//     * {{ $event.Shape.ShapeName }}
{{- end }}
type {{ $.ShapeName }}Reader interface {
	// Returns a channel of events as they are read from the event stream.
	Events() <-chan {{ $.EventStreamAPI.Inbound.Name }}Event

	// Close will close the underlying event stream reader. For event stream over
	// HTTP this will also close the HTTP connection.
	Close() error

	// Returns any error that has occurred while reading from the event stream.
	Err() error
}

type read{{ $.ShapeName }} struct {
	eventReader *eventstreamapi.EventReader
	stream chan {{ $.EventStreamAPI.Inbound.Name }}Event
	errVal atomic.Value

	done      chan struct{}
	closeOnce sync.Once

	{{ if eq $.API.Metadata.Protocol "json" -}}
		initResp eventstreamapi.Unmarshaler
	{{ end -}}
}

func newRead{{ $.ShapeName }}(
	reader io.ReadCloser,
	unmarshalers request.HandlerList,
	logger aws.Logger,
	logLevel aws.LogLevelType,
	{{ if eq $.API.Metadata.Protocol "json" -}}
		initResp eventstreamapi.Unmarshaler,
	{{ end -}}
) *read{{ $.ShapeName }} {
	r := &read{{ $.ShapeName }}{
		stream: make(chan {{ $.EventStreamAPI.Inbound.Name }}Event),
		done: make(chan struct{}),
		{{ if eq $.API.Metadata.Protocol "json" -}}
			initResp: initResp,
		{{ end -}}
	}

	r.eventReader = eventstreamapi.NewEventReader(
		reader,
		protocol.HandlerPayloadUnmarshal{
			Unmarshalers: unmarshalers,
		},
		r.unmarshalerForEventType,
	)
	r.eventReader.UseLogger(logger, logLevel)

	return r
}

// Close will close the underlying event stream reader. For EventStream over
// HTTP this will also close the HTTP connection.
func (r *read{{ $.ShapeName }}) Close() error {
	r.closeOnce.Do(r.safeClose)

	return r.Err()
}

func (r *read{{ $.ShapeName }}) safeClose() {
	close(r.done)
	err := r.eventReader.Close()
	if err != nil {
		r.errVal.Store(err)
	}
}

func (r *read{{ $.ShapeName }}) Err() error {
	if v := r.errVal.Load(); v != nil {
		return v.(error)
	}

	return nil
}

func (r *read{{ $.ShapeName }}) Events() <-chan {{ $.EventStreamAPI.Inbound.Name }}Event {
	return r.stream
}

func (r *read{{ $.ShapeName }}) readEventStream() {
	defer close(r.stream)

	for {
		event, err := r.eventReader.ReadEvent()
		if err != nil {
			if err == io.EOF {
				return
			}
			select {
			case <-r.done:
				// If closed already ignore the error
				return
			default:
			}
			r.errVal.Store(err)
			return
		}

		select {
		case r.stream <- event.({{ $.EventStreamAPI.Inbound.Name }}Event):
		case <-r.done:
			return
		}
	}
}

func (r *read{{ $.ShapeName }}) unmarshalerForEventType(
	eventType string,
) (eventstreamapi.Unmarshaler, error) {
	switch eventType {
		{{- if eq $.API.Metadata.Protocol "json" }}
			case "initial-response":
				return r.initResp, nil
		{{ end -}}
		{{- range $_, $event := $.EventStreamAPI.Inbound.Events }}
			case {{ printf "%q" $event.Name }}:
				return &{{ $event.Shape.ShapeName }}{}, nil
		{{ end -}}
		{{- range $_, $event := $.EventStreamAPI.Inbound.Exceptions }}
			case {{ printf "%q" $event.Name }}:
				return &{{ $event.Shape.ShapeName }}{}, nil
		{{ end -}}
	default:
		return nil, awserr.New(
			request.ErrCodeSerialization,
			fmt.Sprintf("unknown event type name, %s, for {{ $.ShapeName }}", eventType),
			nil,
		)
	}
}
`))
540
// eventStreamAPILoopMethodTmpl is the template for the methods added to the
// EventStream API Output shape that contains the EventStream member.
//
// It generates runEventStreamLoop, which creates the stream reader from the
// HTTP response body, starts reading events in a goroutine, and assigns the
// event stream to the output shape's event stream member. For the json
// protocol it also generates unmarshalInitialResponse, which waits for the
// first event and copies it into the output shape.
//
// The type named in the generated "invalid event" error is taken from the
// shape being rendered, so the message is correct for every operation.
//
// Executed in the context of a Shape.
var eventStreamAPILoopMethodTmpl = template.Must(
	template.New("eventStreamAPILoopMethodTmpl").Parse(`
func (s *{{ $.ShapeName }}) runEventStreamLoop(r *request.Request) {
	if r.Error != nil {
		return
	}

	{{- $esMemberRef := index $.MemberRefs $.EventStreamsMemberName }}
	{{- if $esMemberRef.Shape.EventStreamAPI.Inbound }}
		reader := newRead{{ $esMemberRef.ShapeName }}(
			r.HTTPResponse.Body,
			r.Handlers.UnmarshalStream,
			r.Config.Logger,
			r.Config.LogLevel.Value(),
			{{ if eq $.API.Metadata.Protocol "json" -}}
				s,
			{{ end -}}
		)
		go reader.readEventStream()

		eventStream := &{{ $esMemberRef.ShapeName }} {
			StreamCloser: r.HTTPResponse.Body,
			Reader: reader,
		}
	{{ end -}}

	s.{{ $.EventStreamsMemberName }} = eventStream
}

{{ if eq $.API.Metadata.Protocol "json" -}}
	func (s *{{ $.ShapeName }}) unmarshalInitialResponse(r *request.Request) {
		// Wait for the initial response event, which must be the first event to be
		// received from the API.
		select {
		case event, ok := <-s.{{ $.EventStreamsMemberName }}.Events():
			if !ok {
				return
			}
			es := s.{{ $.EventStreamsMemberName }}
			v, ok := event.(*{{ $.ShapeName }})
			if !ok || v == nil {
				r.Error = awserr.New(
					request.ErrCodeSerialization,
					fmt.Sprintf("invalid event, %T, expect *{{ $.ShapeName }}, %v", event, v),
					nil,
				)
				return
			}
			*s = *v
			s.{{ $.EventStreamsMemberName }} = es
		}
	}
{{ end -}}
`))
599
// EventStreamHeaderTypeMap provides the mapping of a EventStream Header's
// Value type to the shape reference's member type.
type EventStreamHeaderTypeMap struct {
	// Header is the Go type name of the value held by the EventStream header.
	Header string
	// Member is the Go type name the header value is converted to before it
	// is assigned to the shape's member.
	Member string
}
606
607var eventStreamEventShapeTmplFuncs = template.FuncMap{
608	"EventStreamHeaderTypeMap": func(ref *ShapeRef) EventStreamHeaderTypeMap {
609		switch ref.Shape.Type {
610		case "boolean":
611			return EventStreamHeaderTypeMap{Header: "bool", Member: "bool"}
612		case "byte":
613			return EventStreamHeaderTypeMap{Header: "int8", Member: "int64"}
614		case "short":
615			return EventStreamHeaderTypeMap{Header: "int16", Member: "int64"}
616		case "integer":
617			return EventStreamHeaderTypeMap{Header: "int32", Member: "int64"}
618		case "long":
619			return EventStreamHeaderTypeMap{Header: "int64", Member: "int64"}
620		case "timestamp":
621			return EventStreamHeaderTypeMap{Header: "time.Time", Member: "time.Time"}
622		case "blob":
623			return EventStreamHeaderTypeMap{Header: "[]byte", Member: "[]byte"}
624		case "string":
625			return EventStreamHeaderTypeMap{Header: "string", Member: "string"}
626		// TODO case "uuid"  what is modeled type
627		default:
628			panic("unsupported EventStream header type, " + ref.Shape.Type)
629		}
630	},
631	"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
632}
633
634// Returns if the event has any members which are not the event's blob payload,
635// nor a header.
636func eventHasNonBlobPayloadMembers(s *Shape) bool {
637	num := len(s.MemberRefs)
638	for _, ref := range s.MemberRefs {
639		if ref.IsEventHeader || (ref.IsEventPayload && (ref.Shape.Type == "blob" || ref.Shape.Type == "string")) {
640			num--
641		}
642	}
643	return num > 0
644}
645
646// Template for an EventStream Event shape. This is a normal API shape that is
647// decorated as an EventStream Event.
648//
649// Executed in the context of a Shape.
650var eventStreamEventShapeTmpl = template.Must(template.New("eventStreamEventShapeTmpl").
651	Funcs(eventStreamEventShapeTmplFuncs).Parse(`
652{{ range $_, $eventstream := $.EventFor }}
653	// The {{ $.ShapeName }} is and event in the {{ $eventstream.Name }} group of events.
654	func (s *{{ $.ShapeName }}) event{{ $eventstream.Name }}() {}
655{{ end }}
656
657// UnmarshalEvent unmarshals the EventStream Message into the {{ $.ShapeName }} value.
658// This method is only used internally within the SDK's EventStream handling.
659func (s *{{ $.ShapeName }}) UnmarshalEvent(
660	payloadUnmarshaler protocol.PayloadUnmarshaler,
661	msg eventstream.Message,
662) error {
663	{{- range $memName, $memRef := $.MemberRefs }}
664		{{- if $memRef.IsEventHeader }}
665			if hv := msg.Headers.Get("{{ $memName }}"); hv != nil {
666				{{ $types := EventStreamHeaderTypeMap $memRef -}}
667				v := hv.Get().({{ $types.Header }})
668				{{- if ne $types.Header $types.Member }}
669					m := {{ $types.Member }}(v)
670					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}m
671				{{- else }}
672					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}v
673				{{- end }}
674			}
675		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "blob")) }}
676			s.{{ $memName }} = make([]byte, len(msg.Payload))
677			copy(s.{{ $memName }}, msg.Payload)
678		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "string")) }}
679			s.{{ $memName }} = aws.String(string(msg.Payload))
680		{{- end }}
681	{{- end }}
682	{{- if HasNonBlobPayloadMembers $ }}
683		if err := payloadUnmarshaler.UnmarshalPayload(
684			bytes.NewReader(msg.Payload), s,
685		); err != nil {
686			return err
687		}
688	{{- end }}
689	return nil
690}
691`))
692
// eventStreamExceptionEventShapeTmpl is the template for the methods added to
// an EventStream exception event shape: Code, Message, OrigErr, and Error.
//
// The generated Message method only reads the Message_ member when the shape
// has one, and guards against the member being unset so that calling Message
// or Error on an exception without a message returns an empty message
// instead of dereferencing a nil pointer.
//
// Executed in the context of a Shape.
var eventStreamExceptionEventShapeTmpl = template.Must(
	template.New("eventStreamExceptionEventShapeTmpl").Parse(`
// Code returns the exception type name.
func (s {{ $.ShapeName }}) Code() string {
	{{- if $.ErrorInfo.Code }}
		return "{{ $.ErrorInfo.Code }}"
	{{- else }}
		return "{{ $.ShapeName }}"
	{{ end -}}
}

// Message returns the exception's message.
func (s {{ $.ShapeName }}) Message() string {
	{{- if index $.MemberRefs "Message_" }}
		if s.Message_ != nil {
			return *s.Message_
		}
	{{- end }}
	return ""
}

// OrigErr always returns nil, satisfies awserr.Error interface.
func (s {{ $.ShapeName }}) OrigErr() error {
	return nil
}

func (s {{ $.ShapeName }}) Error() string {
	return fmt.Sprintf("%s: %s", s.Code(), s.Message())
}
`))
722
723// APIEventStreamTestGoCode generates Go code for EventStream operation tests.
724func (a *API) APIEventStreamTestGoCode() string {
725	var buf bytes.Buffer
726
727	a.resetImports()
728	a.AddImport("bytes")
729	a.AddImport("io/ioutil")
730	a.AddImport("net/http")
731	a.AddImport("reflect")
732	a.AddImport("testing")
733	a.AddImport("time")
734	a.AddSDKImport("aws")
735	a.AddSDKImport("aws/corehandlers")
736	a.AddSDKImport("aws/request")
737	a.AddSDKImport("aws/awserr")
738	a.AddSDKImport("awstesting/unit")
739	a.AddSDKImport("private/protocol")
740	a.AddSDKImport("private/protocol/", a.ProtocolPackage())
741	a.AddSDKImport("private/protocol/eventstream")
742	a.AddSDKImport("private/protocol/eventstream/eventstreamapi")
743	a.AddSDKImport("private/protocol/eventstream/eventstreamtest")
744
745	unused := `
746	var _ time.Time
747	var _ awserr.Error
748	`
749
750	if err := eventStreamTestTmpl.Execute(&buf, a); err != nil {
751		panic(err)
752	}
753
754	return a.importsGoCode() + unused + strings.TrimSpace(buf.String())
755}
756
757func valueForType(s *Shape, visited []string) string {
758	for _, v := range visited {
759		if v == s.ShapeName {
760			return "nil"
761		}
762	}
763
764	visited = append(visited, s.ShapeName)
765
766	switch s.Type {
767	case "blob":
768		return `[]byte("blob value goes here")`
769	case "string":
770		return `aws.String("string value goes here")`
771	case "boolean":
772		return `aws.Bool(true)`
773	case "byte":
774		return `aws.Int64(1)`
775	case "short":
776		return `aws.Int64(12)`
777	case "integer":
778		return `aws.Int64(123)`
779	case "long":
780		return `aws.Int64(1234)`
781	case "float":
782		return `aws.Float64(123.4)`
783	case "double":
784		return `aws.Float64(123.45)`
785	case "timestamp":
786		return `aws.Time(time.Unix(1396594860, 0).UTC())`
787	case "structure":
788		w := bytes.NewBuffer(nil)
789		fmt.Fprintf(w, "&%s{\n", s.ShapeName)
790		for _, refName := range s.MemberNames() {
791			fmt.Fprintf(w, "%s: %s,\n", refName, valueForType(s.MemberRefs[refName].Shape, visited))
792		}
793		fmt.Fprintf(w, "}")
794		return w.String()
795	case "list":
796		w := bytes.NewBuffer(nil)
797		fmt.Fprintf(w, "%s{\n", s.GoType())
798		for i := 0; i < 3; i++ {
799			fmt.Fprintf(w, "%s,\n", valueForType(s.MemberRef.Shape, visited))
800		}
801		fmt.Fprintf(w, "}")
802		return w.String()
803
804	case "map":
805		w := bytes.NewBuffer(nil)
806		fmt.Fprintf(w, "%s{\n", s.GoType())
807		for _, k := range []string{"a", "b", "c"} {
808			fmt.Fprintf(w, "%q: %s,\n", k, valueForType(s.ValueRef.Shape, visited))
809		}
810		fmt.Fprintf(w, "}")
811		return w.String()
812
813	default:
814		panic(fmt.Sprintf("valueForType does not support %s, %s", s.ShapeName, s.Type))
815	}
816}
817
818func setEventHeaderValueForType(s *Shape, memVar string) string {
819	switch s.Type {
820	case "blob":
821		return fmt.Sprintf("eventstream.BytesValue(%s)", memVar)
822	case "string":
823		return fmt.Sprintf("eventstream.StringValue(*%s)", memVar)
824	case "boolean":
825		return fmt.Sprintf("eventstream.BoolValue(*%s)", memVar)
826	case "byte":
827		return fmt.Sprintf("eventstream.Int8Value(int8(*%s))", memVar)
828	case "short":
829		return fmt.Sprintf("eventstream.Int16Value(int16(*%s))", memVar)
830	case "integer":
831		return fmt.Sprintf("eventstream.Int32Value(int32(*%s))", memVar)
832	case "long":
833		return fmt.Sprintf("eventstream.Int64Value(*%s)", memVar)
834	case "float":
835		return fmt.Sprintf("eventstream.Float32Value(float32(*%s))", memVar)
836	case "double":
837		return fmt.Sprintf("eventstream.Float64Value(*%s)", memVar)
838	case "timestamp":
839		return fmt.Sprintf("eventstream.TimestampValue(*%s)", memVar)
840	default:
841		panic(fmt.Sprintf("value type %s not supported for event headers, %s", s.Type, s.ShapeName))
842	}
843}
844
// templateMap builds a map from alternating key/value arguments so that
// several named values can be passed to a template as one argument.
//
// Panics if an odd number of arguments is given, or if any key is not a
// string.
func templateMap(args ...interface{}) map[string]interface{} {
	if len(args)%2 == 1 {
		panic(fmt.Sprintf("invalid map call, non-even args %v", args))
	}

	pairs := make(map[string]interface{}, len(args)/2)
	for rest := args; len(rest) > 0; rest = rest[2:] {
		key, isString := rest[0].(string)
		if !isString {
			panic(fmt.Sprintf("invalid map call, arg is not string, %T, %v", rest[0], rest[0]))
		}
		pairs[key] = rest[1]
	}

	return pairs
}
861
862var eventStreamTestTmpl = template.Must(
863	template.New("eventStreamTestTmpl").Funcs(template.FuncMap{
864		"ValueForType":               valueForType,
865		"HasNonBlobPayloadMembers":   eventHasNonBlobPayloadMembers,
866		"SetEventHeaderValueForType": setEventHeaderValueForType,
867		"Map":                        templateMap,
868		"OptionalAddInt": func(do bool, a, b int) int {
869			if !do {
870				return a
871			}
872			return a + b
873		},
874		"HasNonEventStreamMember": func(s *Shape) bool {
875			for _, ref := range s.MemberRefs {
876				if !ref.Shape.IsEventStream {
877					return true
878				}
879			}
880			return false
881		},
882	}).Parse(`
883{{ range $opName, $op := $.Operations }}
884	{{ if $op.EventStreamAPI }}
885		{{ if $op.EventStreamAPI.Inbound }}
886			{{ template "event stream inbound tests" $op.EventStreamAPI }}
887		{{ end }}
888	{{ end }}
889{{ end }}
890
891type loopReader struct {
892	source *bytes.Reader
893}
894
895func (c *loopReader) Read(p []byte) (int, error) {
896	if c.source.Len() == 0 {
897		c.source.Seek(0, 0)
898	}
899
900	return c.source.Read(p)
901}
902
903{{ define "event stream inbound tests" }}
904    {{- $esMemberName := $.Operation.OutputRef.Shape.EventStreamsMemberName }}
905	func Test{{ $.Operation.ExportedName }}_Read(t *testing.T) {
906		expectEvents, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
907		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
908			eventstreamtest.ServeEventStream{
909				T:      t,
910				Events: eventMsgs,
911			},
912			true,
913		)
914		if err != nil {
915			t.Fatalf("expect no error, %v", err)
916		}
917		defer cleanupFn()
918
919		svc := New(sess)
920		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
921		if err != nil {
922			t.Fatalf("expect no error got, %v", err)
923		}
924		defer resp.{{ $esMemberName }}.Close()
925
926		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
927			{{- if HasNonEventStreamMember $.Operation.OutputRef.Shape }}
928				expectResp := expectEvents[0].(*{{ $.Operation.OutputRef.Shape.ShapeName }})
929				{{- range $name, $ref := $.Operation.OutputRef.Shape.MemberRefs }}
930					{{- if not $ref.Shape.IsEventStream }}
931						if e, a := expectResp.{{ $name }}, resp.{{ $name }}; !reflect.DeepEqual(e,a) {
932							t.Errorf("expect %v, got %v", e, a)
933						}
934					{{- end }}
935				{{- end }}
936			{{- end }}
937			// Trim off response output type pseudo event so only event messages remain.
938			expectEvents = expectEvents[1:]
939		{{ end }}
940
941		var i int
942		for event := range resp.{{ $esMemberName }}.Events() {
943			if event == nil {
944				t.Errorf("%d, expect event, got nil", i)
945			}
946			if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
947				t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
948			}
949			i++
950		}
951
952		if err := resp.{{ $esMemberName }}.Err(); err != nil {
953			t.Errorf("expect no error, %v", err)
954		}
955	}
956
957	func Test{{ $.Operation.ExportedName }}_ReadClose(t *testing.T) {
958		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
959		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
960			eventstreamtest.ServeEventStream{
961				T:      t,
962				Events: eventMsgs,
963			},
964			true,
965		)
966		if err != nil {
967			t.Fatalf("expect no error, %v", err)
968		}
969		defer cleanupFn()
970
971		svc := New(sess)
972		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
973		if err != nil {
974			t.Fatalf("expect no error got, %v", err)
975		}
976
977		{{ if gt (len $.Inbound.Events) 0 -}}
978			// Assert calling Err before close does not close the stream.
979			resp.{{ $esMemberName }}.Err()
980			select {
981			case _, ok := <-resp.{{ $esMemberName }}.Events():
982				if !ok {
983					t.Fatalf("expect stream not to be closed, but was")
984				}
985			default:
986			}
987		{{- end }}
988
989		resp.{{ $esMemberName }}.Close()
990		<-resp.{{ $esMemberName }}.Events()
991
992		if err := resp.{{ $esMemberName }}.Err(); err != nil {
993			t.Errorf("expect no error, %v", err)
994		}
995	}
996
997	func Benchmark{{ $.Operation.ExportedName }}_Read(b *testing.B) {
998		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
999		var buf bytes.Buffer
1000		encoder := eventstream.NewEncoder(&buf)
1001		for _, msg := range eventMsgs {
1002			if err := encoder.Encode(msg); err != nil {
1003				b.Fatalf("failed to encode message, %v", err)
1004			}
1005		}
1006		stream := &loopReader{source: bytes.NewReader(buf.Bytes())}
1007
1008		sess := unit.Session
1009		svc := New(sess, &aws.Config{
1010			Endpoint:               aws.String("https://example.com"),
1011			DisableParamValidation: aws.Bool(true),
1012		})
1013		svc.Handlers.Send.Swap(corehandlers.SendHandler.Name,
1014			request.NamedHandler{Name: "mockSend",
1015				Fn: func(r *request.Request) {
1016					r.HTTPResponse = &http.Response{
1017						Status:     "200 OK",
1018						StatusCode: 200,
1019						Header:     http.Header{},
1020						Body:       ioutil.NopCloser(stream),
1021					}
1022				},
1023			},
1024		)
1025
1026		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
1027		if err != nil {
1028			b.Fatalf("failed to create request, %v", err)
1029		}
1030		defer resp.{{ $esMemberName }}.Close()
1031		b.ResetTimer()
1032
1033		for i := 0; i < b.N; i++ {
1034			if err = resp.{{ $esMemberName }}.Err(); err != nil {
1035				b.Fatalf("expect no error, got %v", err)
1036			}
1037			event := <-resp.{{ $esMemberName }}.Events()
1038			if event == nil {
1039				b.Fatalf("expect event, got nil, %v, %d", resp.{{ $esMemberName }}.Err(), i)
1040			}
1041		}
1042	}
1043
1044	func mock{{ $.Operation.ExportedName }}ReadEvents() (
1045		[]{{ $.Inbound.Name }}Event,
1046		[]eventstream.Message,
1047	) {
1048		expectEvents := []{{ $.Inbound.Name }}Event {
1049			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1050				{{- template "set event type" $.Operation.OutputRef.Shape }}
1051			{{- end }}
1052			{{- range $_, $event := $.Inbound.Events }}
1053				{{- template "set event type" $event.Shape }}
1054			{{- end }}
1055		}
1056
1057		var marshalers request.HandlerList
1058		marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
1059		payloadMarshaler := protocol.HandlerPayloadMarshal{
1060			Marshalers: marshalers,
1061		}
1062		_ = payloadMarshaler
1063
1064		eventMsgs := []eventstream.Message{
1065			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1066				{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
1067			{{- end }}
1068			{{- range $idx, $event := $.Inbound.Events }}
1069				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") $idx 1 }}
1070				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $event.Shape "eventName" $event.Name }}
1071			{{- end }}
1072		}
1073
1074		return expectEvents, eventMsgs
1075	}
1076
1077	{{- if $.Inbound.Exceptions }}
1078		func Test{{ $.Operation.ExportedName }}_ReadException(t *testing.T) {
1079			expectEvents := []{{ $.Inbound.Name }}Event {
1080				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1081					{{- template "set event type" $.Operation.OutputRef.Shape }}
1082				{{- end }}
1083
1084				{{- $exception := index $.Inbound.Exceptions 0 }}
1085				{{- template "set event type" $exception.Shape }}
1086			}
1087
1088			var marshalers request.HandlerList
1089			marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
1090			payloadMarshaler := protocol.HandlerPayloadMarshal{
1091				Marshalers: marshalers,
1092			}
1093
1094			eventMsgs := []eventstream.Message{
1095				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1096					{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
1097				{{- end }}
1098
1099				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") 0 1 }}
1100				{{- $exception := index $.Inbound.Exceptions 0 }}
1101				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $exception.Shape "eventName" $exception.Name }}
1102			}
1103
1104			sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
1105				eventstreamtest.ServeEventStream{
1106					T:      t,
1107					Events: eventMsgs,
1108				},
1109				true,
1110			)
1111			if err != nil {
1112				t.Fatalf("expect no error, %v", err)
1113			}
1114			defer cleanupFn()
1115
1116			svc := New(sess)
1117			resp, err := svc.{{ $.Operation.ExportedName }}(nil)
1118			if err != nil {
1119				t.Fatalf("expect no error got, %v", err)
1120			}
1121
1122			defer resp.{{ $esMemberName }}.Close()
1123
1124			<-resp.{{ $esMemberName }}.Events()
1125
1126			err = resp.{{ $esMemberName }}.Err()
1127			if err == nil {
1128				t.Fatalf("expect err, got none")
1129			}
1130
1131			expectErr := {{ ValueForType $exception.Shape nil }}
1132			aerr, ok := err.(awserr.Error)
1133			if !ok {
1134				t.Errorf("expect exception, got %T, %#v", err, err)
1135			}
1136			if e, a := expectErr.Code(), aerr.Code(); e != a {
1137				t.Errorf("expect %v, got %v", e, a)
1138			}
1139			if e, a := expectErr.Message(), aerr.Message(); e != a {
1140				t.Errorf("expect %v, got %v", e, a)
1141			}
1142
1143			if e, a := expectErr, aerr; !reflect.DeepEqual(e, a) {
1144				t.Errorf("expect %#v, got %#v", e, a)
1145			}
1146		}
1147
1148		{{- range $_, $exception := $.Inbound.Exceptions }}
1149			var _ awserr.Error = (*{{ $exception.Shape.ShapeName }})(nil)
1150		{{- end }}
1151
1152	{{ end }}
1153{{ end }}
1154
1155{{/* Params: *Shape */}}
1156{{ define "set event type" }}
1157	&{{ $.ShapeName }}{
1158		{{- range $memName, $memRef := $.MemberRefs }}
1159			{{- if not $memRef.Shape.IsEventStream }}
1160				{{ $memName }}: {{ ValueForType $memRef.Shape nil }},
1161			{{- end }}
1162		{{- end }}
1163	},
1164{{- end }}
1165
1166{{/* Params: idx:int, parentShape:*Shape, eventName:string */}}
1167{{ define "set event message" }}
1168	{
1169		Headers: eventstream.Headers{
1170			{{- if $.parentShape.Exception }}
1171				eventstreamtest.EventExceptionTypeHeader,
1172				{
1173					Name:  eventstreamapi.ExceptionTypeHeader,
1174					Value: eventstream.StringValue("{{ $.eventName }}"),
1175				},
1176			{{- else }}
1177				eventstreamtest.EventMessageTypeHeader,
1178				{
1179					Name:  eventstreamapi.EventTypeHeader,
1180					Value: eventstream.StringValue("{{ $.eventName }}"),
1181				},
1182			{{- end }}
1183			{{- range $memName, $memRef := $.parentShape.MemberRefs }}
1184				{{- template "set event message header" Map "idx" $.idx "parentShape" $.parentShape "memName" $memName "memRef" $memRef }}
1185			{{- end }}
1186		},
1187		{{- template "set event message payload" Map "idx" $.idx "parentShape" $.parentShape }}
1188	},
1189{{- end }}
1190
1191{{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
1192{{ define "set event message header" }}
1193	{{- if $.memRef.IsEventHeader }}
1194		{
1195			Name: "{{ $.memName }}",
1196			{{- $shapeValueVar := printf "expectEvents[%d].(%s).%s" $.idx $.parentShape.GoType $.memName }}
1197			Value: {{ SetEventHeaderValueForType $.memRef.Shape $shapeValueVar }},
1198		},
1199	{{- end }}
1200{{- end }}
1201
1202{{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
1203{{ define "set event message payload" }}
1204	{{- $payloadMemName := $.parentShape.PayloadRefName }}
1205	{{- if HasNonBlobPayloadMembers $.parentShape }}
1206		Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, expectEvents[{{ $.idx }}]),
1207	{{- else if $payloadMemName }}
1208		{{- $shapeType := (index $.parentShape.MemberRefs $payloadMemName).Shape.Type }}
1209		{{- if eq $shapeType "blob" }}
1210			Payload: expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }},
1211		{{- else if eq $shapeType "string" }}
1212			Payload: []byte(*expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}),
1213		{{- end }}
1214	{{- end }}
1215{{- end }}
1216`))
1217