1// +build codegen
2
3package api
4
5import (
6	"fmt"
7	"io"
8	"strings"
9	"text/template"
10)
11
12func renderEventStreamAPI(w io.Writer, op *Operation) error {
13	// Imports needed by the EventStream APIs.
14	op.API.AddImport("fmt")
15	op.API.AddImport("bytes")
16	op.API.AddImport("io")
17	op.API.AddImport("time")
18	op.API.AddSDKImport("aws")
19	op.API.AddSDKImport("aws/awserr")
20	op.API.AddSDKImport("aws/request")
21	op.API.AddSDKImport("private/protocol/eventstream")
22	op.API.AddSDKImport("private/protocol/eventstream/eventstreamapi")
23
24	w.Write([]byte(`
25var _ awserr.Error
26`))
27
28	return eventStreamAPITmpl.Execute(w, op)
29}
30
31// Template for an EventStream API Shape that will provide read/writing events
32// across the EventStream. This is a special shape that's only public members
33// are the Events channel and a Close and Err method.
34//
35// Executed in the context of a Shape.
36var eventStreamAPITmpl = template.Must(
37	template.New("eventStreamAPITmplDef").
38		Funcs(template.FuncMap{
39			"unexported": func(v string) string {
40				return strings.ToLower(string(v[0])) + v[1:]
41			},
42		}).
43		Parse(eventStreamAPITmplDef),
44)
45
const eventStreamAPITmplDef = `
{{- $esapi := $.EventStreamAPI }}
{{- $outputStream := $esapi.OutputStream }}
{{- $inputStream := $esapi.InputStream }}

// {{ $esapi.Name }} provides the event stream handling for the {{ $.ExportedName }}.
//
// For testing and mocking the event stream this type should be initialized via
// the New{{ $esapi.Name }} constructor function. Use the functional options
// to pass in nested mock behavior.
type {{ $esapi.Name }} struct {
	{{- if $inputStream }}

		// Writer is the EventStream writer for the {{ $inputStream.Name }}
		// events. This value is automatically set by the SDK when the API call is made.
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Writer.
		//
		// Must not be nil.
		Writer {{ $inputStream.StreamWriterAPIName }}

		inputWriter io.WriteCloser
		{{- if eq .API.Metadata.Protocol "json" }}
			input {{ $.InputRef.GoType }}
		{{- end }}
	{{- end }}

	{{- if $outputStream }}

		// Reader is the EventStream reader for the {{ $outputStream.Name }}
		// events. This value is automatically set by the SDK when the API call is made.
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Reader.
		//
		// Must not be nil.
		Reader {{ $outputStream.StreamReaderAPIName }}

		outputReader io.ReadCloser
		{{- if eq .API.Metadata.Protocol "json" }}
			output {{ $.OutputRef.GoType }}
		{{- end }}
	{{- end }}

	{{- if $esapi.Legacy }}

		// StreamCloser is the io.Closer for the EventStream connection. For HTTP
		// EventStream this is the response Body. The stream will be closed when
		// the Close method of the EventStream is called.
		StreamCloser io.Closer
	{{- end }}

	done      chan struct{}
	closeOnce sync.Once
	err       *eventstreamapi.OnceError
}

// New{{ $esapi.Name }} initializes an {{ $esapi.Name }}.
// This function should only be used for testing and mocking the {{ $esapi.Name }}
// stream within your application.
{{- if $inputStream }}
//
// The Writer member must be set before writing events to the stream.
{{- end }}
{{- if $outputStream }}
//
// The Reader member must be set before reading events from the stream.
{{- end }}
{{- if $esapi.Legacy }}
//
// The StreamCloser member should be set to the underlying io.Closer,
// (e.g. http.Response.Body), that will be closed when the stream Close method
// is called.
{{- end }}
//
//   es := New{{ $esapi.Name }}(func(o *{{ $esapi.Name}}) {
{{- if $inputStream }}
//       o.Writer = myMockStreamWriter
{{- end }}
{{- if $outputStream }}
//       o.Reader = myMockStreamReader
{{- end }}
{{- if $esapi.Legacy }}
//       o.StreamCloser = myMockStreamCloser
{{- end }}
//   })
func New{{ $esapi.Name }}(opts ...func(*{{ $esapi.Name}})) *{{ $esapi.Name }} {
	es := &{{ $esapi.Name }} {
		done: make(chan struct{}),
		err: eventstreamapi.NewOnceError(),
	}

	for _, fn := range opts {
		fn(es)
	}

	return es
}

{{- if $esapi.Legacy }}

	func (es *{{ $esapi.Name }}) setStreamCloser(r *request.Request) {
		es.StreamCloser = r.HTTPResponse.Body
	}
{{- end }}

func (es *{{ $esapi.Name }}) runOnStreamPartClose(r *request.Request) {
	if es.done == nil {
		return
	}
	go es.waitStreamPartClose()

}

func (es *{{ $esapi.Name }}) waitStreamPartClose() {
	{{- if $inputStream }}
		var inputErrCh <-chan struct{}
		if v, ok := es.Writer.(interface{ErrorSet() <-chan struct{}}); ok {
			inputErrCh = v.ErrorSet()
		}
	{{- end }}
	{{- if $outputStream }}
		var outputErrCh <-chan struct{}
		if v, ok := es.Reader.(interface{ErrorSet() <-chan struct{}}); ok {
			outputErrCh = v.ErrorSet()
		}
		var outputClosedCh <- chan struct{}
		if v, ok := es.Reader.(interface{Closed() <-chan struct{}}); ok {
			outputClosedCh = v.Closed()
		}
	{{- end }}

	select {
		case <-es.done:

		{{- if $inputStream }}
		case <-inputErrCh:
			es.err.SetError(es.Writer.Err())
			es.Close()
		{{- end }}

		{{- if $outputStream }}
		case <-outputErrCh:
			es.err.SetError(es.Reader.Err())
			es.Close()
		case <-outputClosedCh:
			if err := es.Reader.Err(); err != nil {
				es.err.SetError(err)
			}
			es.Close()
		{{- end }}
	}
}

{{- if $inputStream }}

	{{- if eq .API.Metadata.Protocol "json" }}

		func {{ $esapi.StreamInputEventTypeGetterName }}(event {{ $inputStream.EventGroupName }}) (string, error) {
			if _, ok := event.({{ $.InputRef.GoType }}); ok {
				return "initial-request", nil
			}
			return {{ $inputStream.StreamEventTypeGetterName }}(event)
		}
	{{- end }}

	func (es *{{ $esapi.Name }}) setupInputPipe(r *request.Request) {
			inputReader, inputWriter := io.Pipe()
			r.SetStreamingBody(inputReader)
			es.inputWriter = inputWriter
	}

	// Send writes the event to the stream blocking until the event is written.
	// Returns an error if the event was not written.
	//
	// These events are:
	// {{ range $_, $event := $inputStream.Events }}
	//     * {{ $event.Shape.ShapeName }}
	{{- end }}
	func (es *{{ $esapi.Name }}) Send(ctx aws.Context, event {{ $inputStream.EventGroupName }}) error {
		return es.Writer.Send(ctx, event)
	}

	func (es *{{ $esapi.Name }}) runInputStream(r *request.Request) {
		var opts []func(*eventstream.Encoder)
		if r.Config.Logger != nil && r.Config.LogLevel.Matches(aws.LogDebugWithEventStreamBody) {
			opts = append(opts, eventstream.EncodeWithLogger(r.Config.Logger))
		}
		var encoder eventstreamapi.Encoder = eventstream.NewEncoder(es.inputWriter, opts...)

		var closer aws.MultiCloser
		{{- if $.ShouldSignRequestBody }}
			{{- $_ := $.API.AddSDKImport "aws/signer/v4" }}
			sigSeed, err := v4.GetSignedRequestSignature(r.HTTPRequest)
			if err != nil {
				r.Error = awserr.New(request.ErrCodeSerialization,
					"unable to get initial request's signature", err)
				return
			}
			signer := eventstreamapi.NewSignEncoder(
				v4.NewStreamSigner(r.ClientInfo.SigningRegion, r.ClientInfo.SigningName,
					sigSeed, r.Config.Credentials),
				encoder,
			)
			encoder = signer
			closer = append(closer, signer)
		{{- end }}
		closer = append(closer, es.inputWriter)

		eventWriter := eventstreamapi.NewEventWriter(encoder,
			protocol.HandlerPayloadMarshal{
				Marshalers: r.Handlers.BuildStream,
			},
			{{- if eq .API.Metadata.Protocol "json" }}
				{{ $esapi.StreamInputEventTypeGetterName }},
			{{- else }}
				{{ $inputStream.StreamEventTypeGetterName }},
			{{- end }}
		)

		es.Writer = &{{ $inputStream.StreamWriterImplName }}{
			StreamWriter: eventstreamapi.NewStreamWriter(eventWriter, closer),
		}
	}

	{{- if eq .API.Metadata.Protocol "json" }}
		func (es *{{ $esapi.Name }}) sendInitialEvent(r *request.Request) {
			if err := es.Send(r.Context(), es.input); err != nil {
				r.Error = err
			}
		}
	{{- end }}
{{- end }}

{{- if $outputStream }}
	{{- if eq .API.Metadata.Protocol "json" }}

		type {{ $esapi.StreamOutputUnmarshalerForEventName }} struct {
			unmarshalerForEvent func(string) (eventstreamapi.Unmarshaler, error)
			output {{ $.OutputRef.GoType }}
		}
		func (e {{ $esapi.StreamOutputUnmarshalerForEventName }}) UnmarshalerForEventName(eventType string) (eventstreamapi.Unmarshaler, error) {
			if eventType == "initial-response" {
				return e.output, nil
			}
			return e.unmarshalerForEvent(eventType)
		}
	{{- end }}

	// Events returns a channel to read events from.
	//
	// These events are:
	// {{ range $_, $event := $outputStream.Events }}
	//     * {{ $event.Shape.ShapeName }}
	{{- end }}
	//     * {{ $outputStream.StreamUnknownEventName }}
	func (es *{{ $esapi.Name }}) Events() <-chan {{ $outputStream.EventGroupName }} {
		return es.Reader.Events()
	}

	func (es *{{ $esapi.Name }}) runOutputStream(r *request.Request) {
		var opts []func(*eventstream.Decoder)
		if r.Config.Logger != nil && r.Config.LogLevel.Matches(aws.LogDebugWithEventStreamBody) {
			opts = append(opts, eventstream.DecodeWithLogger(r.Config.Logger))
		}

		unmarshalerForEvent := {{ $outputStream.StreamUnmarshalerForEventName }}{
			metadata: protocol.ResponseMetadata{
				StatusCode: r.HTTPResponse.StatusCode,
				RequestID: r.RequestID,
			},
		}.UnmarshalerForEventName
		{{- if eq .API.Metadata.Protocol "json" }}
			unmarshalerForEvent = {{ $esapi.StreamOutputUnmarshalerForEventName }}{
				unmarshalerForEvent: unmarshalerForEvent,
				output: es.output,
			}.UnmarshalerForEventName
		{{- end }}

		decoder := eventstream.NewDecoder(r.HTTPResponse.Body, opts...)
		eventReader := eventstreamapi.NewEventReader(decoder,
			protocol.HandlerPayloadUnmarshal{
				Unmarshalers: r.Handlers.UnmarshalStream,
			},
			unmarshalerForEvent,
		)

		es.outputReader = r.HTTPResponse.Body
		es.Reader = {{ $outputStream.StreamReaderImplConstructorName }}(eventReader)
	}

	{{- if eq .API.Metadata.Protocol "json" }}
		func (es *{{ $esapi.Name }}) recvInitialEvent(r *request.Request) {
			// Wait for the initial response event, which must be the first
			// event to be received from the API.
			select {
			case event, ok := <- es.Events():
				if !ok {
					return
				}

				v, ok := event.({{ $.OutputRef.GoType }})
				if !ok || v == nil {
					r.Error = awserr.New(
						request.ErrCodeSerialization,
						fmt.Sprintf("invalid event, %T, expect %T, %v",
							event, ({{ $.OutputRef.GoType }})(nil), v),
						nil,
					)
					return
				}

				*es.output = *v
				es.output.{{ $.EventStreamAPI.OutputMemberName  }} = es
			}
		}
	{{- end }}
{{- end }}

// Close closes the stream. Close must be called when done using the stream
// API. Not calling Close may result in resource leaks.
{{- if $inputStream }}
//
// Will close the underlying EventStream writer, and no more events can be
// sent.
{{- end }}
{{- if $outputStream }}
//
// You can use the closing of the Reader's Events channel to terminate your
// application's read from the API's stream.
{{- end }}
//
func (es *{{ $esapi.Name }}) Close() (err error) {
	es.closeOnce.Do(es.safeClose)
	return es.Err()
}

func (es *{{ $esapi.Name }}) safeClose() {
	if es.done != nil {
		close(es.done)
	}

	{{- if $inputStream }}

		t := time.NewTicker(time.Second)
		defer t.Stop()
		writeCloseDone := make(chan error)
		go func() {
			if err := es.Writer.Close(); err != nil {
				es.err.SetError(err)
			}
			close(writeCloseDone)
		}()
		select {
		case <-t.C:
		case <-writeCloseDone:
		}
		if es.inputWriter != nil {
			es.inputWriter.Close()
		}
	{{- end }}

	{{- if $outputStream }}

		es.Reader.Close()
		if es.outputReader != nil {
			es.outputReader.Close()
		}
	{{- end }}

	{{- if $esapi.Legacy }}

		if es.StreamCloser != nil {
			es.StreamCloser.Close()
		}
	{{- end }}
}

// Err returns any error that occurred while reading or writing EventStream
// Events from the service API's response. Returns nil if there were no errors.
func (es *{{ $esapi.Name }}) Err() error {
	if err := es.err.Err(); err != nil {
		return err
	}

	{{- if $inputStream }}
		if err := es.Writer.Err(); err != nil {
			return err
		}
	{{- end }}

	{{- if $outputStream }}
		if err := es.Reader.Err(); err != nil {
			return err
		}
	{{- end }}

	return nil
}
`
444
445func renderEventStreamShape(w io.Writer, s *Shape) error {
446	// Imports needed by the EventStream APIs.
447	s.API.AddImport("fmt")
448	s.API.AddImport("bytes")
449	s.API.AddImport("io")
450	s.API.AddImport("sync")
451	s.API.AddSDKImport("aws")
452	s.API.AddSDKImport("aws/awserr")
453	s.API.AddSDKImport("private/protocol/eventstream")
454	s.API.AddSDKImport("private/protocol/eventstream/eventstreamapi")
455
456	return eventStreamShapeTmpl.Execute(w, s)
457}
458
459var eventStreamShapeTmpl = func() *template.Template {
460	t := template.Must(
461		template.New("eventStreamShapeTmplDef").
462			Parse(eventStreamShapeTmplDef),
463	)
464	template.Must(
465		t.AddParseTree(
466			"eventStreamShapeWriterTmpl", eventStreamShapeWriterTmpl.Tree),
467	)
468	template.Must(
469		t.AddParseTree(
470			"eventStreamShapeReaderTmpl", eventStreamShapeReaderTmpl.Tree),
471	)
472
473	return t
474}()
475
// eventStreamShapeTmplDef generates the event-group interface for an
// EventStream shape. It delegates to the writer and/or reader sub-templates
// depending on whether the shape is an input or output EventStream.
const eventStreamShapeTmplDef = `
{{- $eventStream := $.EventStream }}
{{- $eventStreamEventGroup := printf "%sEvent" $eventStream.Name }}

// {{ $eventStreamEventGroup }} groups together all EventStream
// events for {{ $eventStream.Name }}.
//
// These events are:
// {{ range $_, $event := $eventStream.Events }}
//     * {{ $event.Shape.ShapeName }}
{{- end }}
type {{ $eventStreamEventGroup }} interface {
	event{{ $eventStream.Name }}()
	eventstreamapi.Marshaler
	eventstreamapi.Unmarshaler
}

{{- if $.IsInputEventStream }}
	{{- template "eventStreamShapeWriterTmpl" $ }}
{{- end }}

{{- if $.IsOutputEventStream }}
	{{- template "eventStreamShapeReaderTmpl" $ }}
{{- end }}
`
501
// EventStreamHeaderTypeMap pairs the Go type an EventStream header value is
// asserted to when decoded (Header) with the Go type of the shape member that
// the value is assigned to (Member).
type EventStreamHeaderTypeMap struct {
	Header, Member string
}
508
509// Returns if the event has any members which are not the event's blob payload,
510// nor a header.
511func eventHasNonBlobPayloadMembers(s *Shape) bool {
512	num := len(s.MemberRefs)
513	for _, ref := range s.MemberRefs {
514		if ref.IsEventHeader || (ref.IsEventPayload && (ref.Shape.Type == "blob" || ref.Shape.Type == "string")) {
515			num--
516		}
517	}
518	return num > 0
519}
520
521func setEventHeaderValueForType(s *Shape, memVar string) string {
522	switch s.Type {
523	case "blob":
524		return fmt.Sprintf("eventstream.BytesValue(%s)", memVar)
525	case "string":
526		return fmt.Sprintf("eventstream.StringValue(*%s)", memVar)
527	case "boolean":
528		return fmt.Sprintf("eventstream.BoolValue(*%s)", memVar)
529	case "byte":
530		return fmt.Sprintf("eventstream.Int8Value(int8(*%s))", memVar)
531	case "short":
532		return fmt.Sprintf("eventstream.Int16Value(int16(*%s))", memVar)
533	case "integer":
534		return fmt.Sprintf("eventstream.Int32Value(int32(*%s))", memVar)
535	case "long":
536		return fmt.Sprintf("eventstream.Int64Value(*%s)", memVar)
537	case "float":
538		return fmt.Sprintf("eventstream.Float32Value(float32(*%s))", memVar)
539	case "double":
540		return fmt.Sprintf("eventstream.Float64Value(*%s)", memVar)
541	case "timestamp":
542		return fmt.Sprintf("eventstream.TimestampValue(*%s)", memVar)
543	default:
544		panic(fmt.Sprintf("value type %s not supported for event headers, %s", s.Type, s.ShapeName))
545	}
546}
547
548func shapeMessageType(s *Shape) string {
549	if s.Exception {
550		return "eventstreamapi.ExceptionMessageType"
551	}
552	return "eventstreamapi.EventMessageType"
553}
554
555var eventStreamEventShapeTmplFuncs = template.FuncMap{
556	"EventStreamHeaderTypeMap": func(ref *ShapeRef) EventStreamHeaderTypeMap {
557		switch ref.Shape.Type {
558		case "boolean":
559			return EventStreamHeaderTypeMap{Header: "bool", Member: "bool"}
560		case "byte":
561			return EventStreamHeaderTypeMap{Header: "int8", Member: "int64"}
562		case "short":
563			return EventStreamHeaderTypeMap{Header: "int16", Member: "int64"}
564		case "integer":
565			return EventStreamHeaderTypeMap{Header: "int32", Member: "int64"}
566		case "long":
567			return EventStreamHeaderTypeMap{Header: "int64", Member: "int64"}
568		case "timestamp":
569			return EventStreamHeaderTypeMap{Header: "time.Time", Member: "time.Time"}
570		case "blob":
571			return EventStreamHeaderTypeMap{Header: "[]byte", Member: "[]byte"}
572		case "string":
573			return EventStreamHeaderTypeMap{Header: "string", Member: "string"}
574		case "uuid":
575			return EventStreamHeaderTypeMap{Header: "[]byte", Member: "[]byte"}
576		default:
577			panic("unsupported EventStream header type, " + ref.Shape.Type)
578		}
579	},
580	"EventHeaderValueForType":  setEventHeaderValueForType,
581	"ShapeMessageType":         shapeMessageType,
582	"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
583}
584
585// Template for an EventStream Event shape. This is a normal API shape that is
586// decorated as an EventStream Event.
587//
588// Executed in the context of a Shape.
589var eventStreamEventShapeTmpl = template.Must(template.New("eventStreamEventShapeTmpl").
590	Funcs(eventStreamEventShapeTmplFuncs).Parse(`
591{{ range $_, $eventStream := $.EventFor }}
592	// The {{ $.ShapeName }} is and event in the {{ $eventStream.Name }} group of events.
593	func (s *{{ $.ShapeName }}) event{{ $eventStream.Name }}() {}
594{{ end }}
595
596// UnmarshalEvent unmarshals the EventStream Message into the {{ $.ShapeName }} value.
597// This method is only used internally within the SDK's EventStream handling.
598func (s *{{ $.ShapeName }}) UnmarshalEvent(
599	payloadUnmarshaler protocol.PayloadUnmarshaler,
600	msg eventstream.Message,
601) error {
602	{{- range $memName, $memRef := $.MemberRefs }}
603		{{- if $memRef.IsEventHeader }}
604			if hv := msg.Headers.Get("{{ $memName }}"); hv != nil {
605				{{ $types := EventStreamHeaderTypeMap $memRef -}}
606				v := hv.Get().({{ $types.Header }})
607				{{- if ne $types.Header $types.Member }}
608					m := {{ $types.Member }}(v)
609					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}m
610				{{- else }}
611					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}v
612				{{- end }}
613			}
614		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "blob")) }}
615			s.{{ $memName }} = make([]byte, len(msg.Payload))
616			copy(s.{{ $memName }}, msg.Payload)
617		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "string")) }}
618			s.{{ $memName }} = aws.String(string(msg.Payload))
619		{{- end }}
620	{{- end }}
621	{{- if HasNonBlobPayloadMembers $ }}
622		if err := payloadUnmarshaler.UnmarshalPayload(
623			bytes.NewReader(msg.Payload), s,
624		); err != nil {
625			return err
626		}
627	{{- end }}
628	return nil
629}
630
631// MarshalEvent marshals the type into an stream event value. This method
632// should only used internally within the SDK's EventStream handling.
633func (s *{{ $.ShapeName}}) MarshalEvent(pm protocol.PayloadMarshaler) (msg eventstream.Message, err error) {
634	msg.Headers.Set(eventstreamapi.MessageTypeHeader, eventstream.StringValue({{ ShapeMessageType $ }}))
635
636	{{- range $memName, $memRef := $.MemberRefs }}
637		{{- if $memRef.IsEventHeader }}
638			{{ $memVar := printf "s.%s" $memName -}}
639			{{ $typedMem := EventHeaderValueForType $memRef.Shape $memVar -}}
640			msg.Headers.Set("{{ $memName }}", {{ $typedMem }})
641		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "blob")) }}
642			msg.Headers.Set(":content-type", eventstream.StringValue("application/octet-stream"))
643			msg.Payload = s.{{ $memName }}
644		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "string")) }}
645			msg.Payload = []byte(aws.StringValue(s.{{ $memName }}))
646		{{- end }}
647	{{- end }}
648	{{- if HasNonBlobPayloadMembers $ }}
649		var buf bytes.Buffer
650		if err = pm.MarshalPayload(&buf, s); err != nil {
651			return eventstream.Message{}, err
652		}
653		msg.Payload = buf.Bytes()
654	{{- end }}
655	return msg, err
656}
657`))
658