1// +build codegen
2
3package api
4
5import (
6	"fmt"
7	"io"
8	"strings"
9	"text/template"
10)
11
12func renderEventStreamAPI(w io.Writer, op *Operation) error {
13	// Imports needed by the EventStream APIs.
14	op.API.AddImport("fmt")
15	op.API.AddImport("bytes")
16	op.API.AddImport("io")
17	op.API.AddImport("time")
18	op.API.AddSDKImport("aws")
19	op.API.AddSDKImport("aws/awserr")
20	op.API.AddSDKImport("aws/request")
21	op.API.AddSDKImport("private/protocol/eventstream")
22	op.API.AddSDKImport("private/protocol/eventstream/eventstreamapi")
23
24	return eventStreamAPITmpl.Execute(w, op)
25}
26
27// Template for an EventStream API Shape that will provide read/writing events
28// across the EventStream. This is a special shape that's only public members
29// are the Events channel and a Close and Err method.
30//
31// Executed in the context of a Shape.
32var eventStreamAPITmpl = template.Must(
33	template.New("eventStreamAPITmplDef").
34		Funcs(template.FuncMap{
35			"unexported": func(v string) string {
36				return strings.ToLower(string(v[0])) + v[1:]
37			},
38		}).
39		Parse(eventStreamAPITmplDef),
40)
41
// eventStreamAPITmplDef is the text of the template that generates an
// operation's EventStream API type: the struct, its constructor, the helpers
// that set up and run the input and output streams, and the Close and Err
// methods. Sections are emitted conditionally on whether the operation has an
// input stream, an output stream, uses the "json" protocol, or is Legacy.
//
// The generated sendInitialEvent passes the request's context to Send, since
// the generated Send method takes a context as its first argument.
const eventStreamAPITmplDef = `
{{- $esapi := $.EventStreamAPI }}
{{- $outputStream := $esapi.OutputStream }}
{{- $inputStream := $esapi.InputStream }}

// {{ $esapi.Name }} provides the event stream handling for the {{ $.ExportedName }}.
type {{ $esapi.Name }} struct {
	{{- if $inputStream }}

		// Writer is the EventStream writer for the {{ $inputStream.Name }}
		// events. This value is automatically set by the SDK when the API call is made
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Writer.
		//
		// Must not be nil.
		Writer {{ $inputStream.StreamWriterAPIName }}

		inputWriter io.WriteCloser
		{{- if eq .API.Metadata.Protocol "json" }}
			input {{ $.InputRef.GoType }}
		{{- end }}
	{{- end }}

	{{- if $outputStream }}

		// Reader is the EventStream reader for the {{ $outputStream.Name }}
		// events. This value is automatically set by the SDK when the API call is made
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Reader.
		//
		// Must not be nil.
		Reader {{ $outputStream.StreamReaderAPIName }}

		outputReader io.ReadCloser
		{{- if eq .API.Metadata.Protocol "json" }}
			output {{ $.OutputRef.GoType }}
		{{- end }}
	{{- end }}

	{{- if $esapi.Legacy }}

		// StreamCloser is the io.Closer for the EventStream connection. For HTTP
		// EventStream this is the response Body. The stream will be closed when
		// the Close method of the EventStream is called.
		StreamCloser io.Closer
	{{- end }}

	done      chan struct{}
	closeOnce sync.Once
	err       *eventstreamapi.OnceError
}

func new{{ $esapi.Name }}() *{{ $esapi.Name }} {
	return &{{ $esapi.Name }} {
		done: make(chan struct{}),
		err: eventstreamapi.NewOnceError(),
	}
}

{{- if $esapi.Legacy }}

	func (es *{{ $esapi.Name }}) setStreamCloser(r *request.Request) {
		es.StreamCloser = r.HTTPResponse.Body
	}
{{- end }}

func (es *{{ $esapi.Name }}) runOnStreamPartClose(r *request.Request) {
	if es.done == nil {
		return
	}
	go es.waitStreamPartClose()

}

func (es *{{ $esapi.Name }}) waitStreamPartClose() {
	{{- if $inputStream }}
		var inputErrCh <-chan struct{}
		if v, ok := es.Writer.(interface{ErrorSet() <-chan struct{}}); ok {
			inputErrCh = v.ErrorSet()
		}
	{{- end }}
	{{- if $outputStream }}
		var outputErrCh <-chan struct{}
		if v, ok := es.Reader.(interface{ErrorSet() <-chan struct{}}); ok {
			outputErrCh = v.ErrorSet()
		}
		var outputClosedCh <- chan struct{}
		if v, ok := es.Reader.(interface{Closed() <-chan struct{}}); ok {
			outputClosedCh = v.Closed()
		}
	{{- end }}

	select {
		case <-es.done:

		{{- if $inputStream }}
		case <-inputErrCh:
			es.err.SetError(es.Writer.Err())
			es.Close()
		{{- end }}

		{{- if $outputStream }}
		case <-outputErrCh:
			es.err.SetError(es.Reader.Err())
			es.Close()
		case <-outputClosedCh:
			if err := es.Reader.Err(); err != nil {
				es.err.SetError(err)
			}
			es.Close()
		{{- end }}
	}
}

{{- if $inputStream }}

	{{- if eq .API.Metadata.Protocol "json" }}

		func {{ $esapi.StreamInputEventTypeGetterName }}(event {{ $inputStream.EventGroupName }}) (string, error) {
			if _, ok := event.({{ $.InputRef.GoType }}); ok {
				return "initial-request", nil
			}
			return {{ $inputStream.StreamEventTypeGetterName }}(event)
		}
	{{- end }}

	func (es *{{ $esapi.Name }}) setupInputPipe(r *request.Request) {
			inputReader, inputWriter := io.Pipe()
			r.SetStreamingBody(inputReader)
			es.inputWriter = inputWriter
	}

	// Send writes the event to the stream blocking until the event is written.
	// Returns an error if the event was not written.
	//
	// These events are:
	// {{ range $_, $event := $inputStream.Events }}
	//     * {{ $event.Shape.ShapeName }}
	{{- end }}
	func (es *{{ $esapi.Name }}) Send(ctx aws.Context, event {{ $inputStream.EventGroupName }}) error {
		return es.Writer.Send(ctx, event)
	}

	func (es *{{ $esapi.Name }}) runInputStream(r *request.Request) {
		var opts []func(*eventstream.Encoder)
		if r.Config.Logger != nil && r.Config.LogLevel.Matches(aws.LogDebugWithEventStreamBody) {
			opts = append(opts, eventstream.EncodeWithLogger(r.Config.Logger))
		}
		var encoder eventstreamapi.Encoder = eventstream.NewEncoder(es.inputWriter, opts...)

		var closer aws.MultiCloser
		{{- if $.ShouldSignRequestBody }}
			{{- $_ := $.API.AddSDKImport "aws/signer/v4" }}
			sigSeed, err := v4.GetSignedRequestSignature(r.HTTPRequest)
			if err != nil {
				r.Error = awserr.New(request.ErrCodeSerialization,
					"unable to get initial request's signature", err)
				return
			}
			signer := eventstreamapi.NewSignEncoder(
				v4.NewStreamSigner(r.ClientInfo.SigningRegion, r.ClientInfo.SigningName,
					sigSeed, r.Config.Credentials),
				encoder,
			)
			encoder = signer
			closer = append(closer, signer)
		{{- end }}
		closer = append(closer, es.inputWriter)

		eventWriter := eventstreamapi.NewEventWriter(encoder,
			protocol.HandlerPayloadMarshal{
				Marshalers: r.Handlers.BuildStream,
			},
			{{- if eq .API.Metadata.Protocol "json" }}
				{{ $esapi.StreamInputEventTypeGetterName }},
			{{- else }}
				{{ $inputStream.StreamEventTypeGetterName }},
			{{- end }}
		)

		es.Writer = &{{ $inputStream.StreamWriterImplName }}{
			StreamWriter: eventstreamapi.NewStreamWriter(eventWriter, closer),
		}
	}

	{{- if eq .API.Metadata.Protocol "json" }}
		func (es *{{ $esapi.Name }}) sendInitialEvent(r *request.Request) {
			if err := es.Send(r.Context(), es.input); err != nil {
				r.Error = err
			}
		}
	{{- end }}
{{- end }}

{{- if $outputStream }}
	{{- if eq .API.Metadata.Protocol "json" }}

		type {{ $esapi.StreamOutputUnmarshalerForEventName }} struct {
			unmarshalerForEvent func(string) (eventstreamapi.Unmarshaler, error)
			output {{ $.OutputRef.GoType }}
		}
		func (e {{ $esapi.StreamOutputUnmarshalerForEventName }}) UnmarshalerForEventName(eventType string) (eventstreamapi.Unmarshaler, error) {
			if eventType == "initial-response" {
				return e.output, nil
			}
			return e.unmarshalerForEvent(eventType)
		}
	{{- end }}

	// Events returns a channel to read events from.
	//
	// These events are:
	// {{ range $_, $event := $outputStream.Events }}
	//     * {{ $event.Shape.ShapeName }}
	{{- end }}
	func (es *{{ $esapi.Name }}) Events() <-chan {{ $outputStream.EventGroupName }} {
		return es.Reader.Events()
	}

	func (es *{{ $esapi.Name }}) runOutputStream(r *request.Request) {
		var opts []func(*eventstream.Decoder)
		if r.Config.Logger != nil && r.Config.LogLevel.Matches(aws.LogDebugWithEventStreamBody) {
			opts = append(opts, eventstream.DecodeWithLogger(r.Config.Logger))
		}

		unmarshalerForEvent := {{ $outputStream.StreamUnmarshalerForEventName }}{
			metadata: protocol.ResponseMetadata{
				StatusCode: r.HTTPResponse.StatusCode,
				RequestID: r.RequestID,
			},
		}.UnmarshalerForEventName
		{{- if eq .API.Metadata.Protocol "json" }}
			unmarshalerForEvent = {{ $esapi.StreamOutputUnmarshalerForEventName }}{
				unmarshalerForEvent: unmarshalerForEvent,
				output: es.output,
			}.UnmarshalerForEventName
		{{- end }}

		decoder := eventstream.NewDecoder(r.HTTPResponse.Body, opts...)
		eventReader := eventstreamapi.NewEventReader(decoder,
			protocol.HandlerPayloadUnmarshal{
				Unmarshalers: r.Handlers.UnmarshalStream,
			},
			unmarshalerForEvent,
		)

		es.outputReader = r.HTTPResponse.Body
		es.Reader = {{ $outputStream.StreamReaderImplConstructorName }}(eventReader)
	}

	{{- if eq .API.Metadata.Protocol "json" }}
		func (es *{{ $esapi.Name }}) recvInitialEvent(r *request.Request) {
			// Wait for the initial response event, which must be the first
			// event to be received from the API.
			select {
			case event, ok := <- es.Events():
				if !ok {
					return
				}

				v, ok := event.({{ $.OutputRef.GoType }})
				if !ok || v == nil {
					r.Error = awserr.New(
						request.ErrCodeSerialization,
						fmt.Sprintf("invalid event, %T, expect %T, %v",
							event, ({{ $.OutputRef.GoType }})(nil), v),
						nil,
					)
					return
				}

				*es.output = *v
				es.output.{{ $.EventStreamAPI.OutputMemberName  }} = es
			}
		}
	{{- end }}
{{- end }}

// Close closes the stream. This will also cause the stream to be closed.
// Close must be called when done using the stream API. Not calling Close
// may result in resource leaks.
{{- if $inputStream }}
//
// Will close the underlying EventStream writer, and no more events can be
// sent.
{{- end }}
{{- if $outputStream }}
//
// You can use the closing of the Reader's Events channel to terminate your
// application's read from the API's stream.
{{- end }}
//
func (es *{{ $esapi.Name }}) Close() (err error) {
	es.closeOnce.Do(es.safeClose)
	return es.Err()
}

func (es *{{ $esapi.Name }}) safeClose() {
	if es.done != nil {
		close(es.done)
	}

	{{- if $inputStream }}

		t := time.NewTicker(time.Second)
		defer t.Stop()
		writeCloseDone := make(chan error)
		go func() {
			if err := es.Writer.Close(); err != nil {
				es.err.SetError(err)
			}
			close(writeCloseDone)
		}()
		select {
		case <-t.C:
		case <-writeCloseDone:
		}
		if es.inputWriter != nil {
			es.inputWriter.Close()
		}
	{{- end }}

	{{- if $outputStream }}

		es.Reader.Close()
		if es.outputReader != nil {
			es.outputReader.Close()
		}
	{{- end }}

	{{- if $esapi.Legacy }}

		es.StreamCloser.Close()
	{{- end }}
}

// Err returns any error that occurred while reading or writing EventStream
// Events from the service API's response. Returns nil if there were no errors.
func (es *{{ $esapi.Name }}) Err() error {
	if err := es.err.Err(); err != nil {
		return err
	}

	{{- if $inputStream }}
		if err := es.Writer.Err(); err != nil {
			return err
		}
	{{- end }}

	{{- if $outputStream }}
		if err := es.Reader.Err(); err != nil {
			return err
		}
	{{- end }}

	return nil
}
`
400
401func renderEventStreamShape(w io.Writer, s *Shape) error {
402	// Imports needed by the EventStream APIs.
403	s.API.AddImport("fmt")
404	s.API.AddImport("bytes")
405	s.API.AddImport("io")
406	s.API.AddImport("sync")
407	s.API.AddSDKImport("aws")
408	s.API.AddSDKImport("aws/awserr")
409	s.API.AddSDKImport("private/protocol/eventstream")
410	s.API.AddSDKImport("private/protocol/eventstream/eventstreamapi")
411
412	return eventStreamShapeTmpl.Execute(w, s)
413}
414
415var eventStreamShapeTmpl = func() *template.Template {
416	t := template.Must(
417		template.New("eventStreamShapeTmplDef").
418			Parse(eventStreamShapeTmplDef),
419	)
420	template.Must(
421		t.AddParseTree(
422			"eventStreamShapeWriterTmpl", eventStreamShapeWriterTmpl.Tree),
423	)
424	template.Must(
425		t.AddParseTree(
426			"eventStreamShapeReaderTmpl", eventStreamShapeReaderTmpl.Tree),
427	)
428
429	return t
430}()
431
// eventStreamShapeTmplDef is the template text for an EventStream shape. It
// generates the "<Name>Event" interface that groups the stream's events: a
// marker method named after the stream plus the eventstreamapi Marshaler and
// Unmarshaler interfaces. It then includes the "eventStreamShapeWriterTmpl"
// template when the shape is an input event stream and the
// "eventStreamShapeReaderTmpl" template when it is an output event stream.
const eventStreamShapeTmplDef = `
{{- $eventStream := $.EventStream }}
{{- $eventStreamEventGroup := printf "%sEvent" $eventStream.Name }}

// {{ $eventStreamEventGroup }} groups together all EventStream
// events writes for {{ $eventStream.Name }}.
//
// These events are:
// {{ range $_, $event := $eventStream.Events }}
//     * {{ $event.Shape.ShapeName }}
{{- end }}
type {{ $eventStreamEventGroup }} interface {
	event{{ $eventStream.Name }}()
	eventstreamapi.Marshaler
	eventstreamapi.Unmarshaler
}

{{- if $.IsInputEventStream }}
	{{- template "eventStreamShapeWriterTmpl" $ }}
{{- end }}

{{- if $.IsOutputEventStream }}
	{{- template "eventStreamShapeReaderTmpl" $ }}
{{- end }}
`
457
// EventStreamHeaderTypeMap pairs two Go type names for an EventStream header
// member: Header is the type the header's value is asserted to when read from
// a message, and Member is the type of the shape member the value is stored
// in. When the two differ, the generated code converts from Header to Member.
type EventStreamHeaderTypeMap struct {
	Header, Member string
}
464
465// Returns if the event has any members which are not the event's blob payload,
466// nor a header.
467func eventHasNonBlobPayloadMembers(s *Shape) bool {
468	num := len(s.MemberRefs)
469	for _, ref := range s.MemberRefs {
470		if ref.IsEventHeader || (ref.IsEventPayload && (ref.Shape.Type == "blob" || ref.Shape.Type == "string")) {
471			num--
472		}
473	}
474	return num > 0
475}
476
477func setEventHeaderValueForType(s *Shape, memVar string) string {
478	switch s.Type {
479	case "blob":
480		return fmt.Sprintf("eventstream.BytesValue(%s)", memVar)
481	case "string":
482		return fmt.Sprintf("eventstream.StringValue(*%s)", memVar)
483	case "boolean":
484		return fmt.Sprintf("eventstream.BoolValue(*%s)", memVar)
485	case "byte":
486		return fmt.Sprintf("eventstream.Int8Value(int8(*%s))", memVar)
487	case "short":
488		return fmt.Sprintf("eventstream.Int16Value(int16(*%s))", memVar)
489	case "integer":
490		return fmt.Sprintf("eventstream.Int32Value(int32(*%s))", memVar)
491	case "long":
492		return fmt.Sprintf("eventstream.Int64Value(*%s)", memVar)
493	case "float":
494		return fmt.Sprintf("eventstream.Float32Value(float32(*%s))", memVar)
495	case "double":
496		return fmt.Sprintf("eventstream.Float64Value(*%s)", memVar)
497	case "timestamp":
498		return fmt.Sprintf("eventstream.TimestampValue(*%s)", memVar)
499	default:
500		panic(fmt.Sprintf("value type %s not supported for event headers, %s", s.Type, s.ShapeName))
501	}
502}
503
504func shapeMessageType(s *Shape) string {
505	if s.Exception {
506		return "eventstreamapi.ExceptionMessageType"
507	}
508	return "eventstreamapi.EventMessageType"
509}
510
511var eventStreamEventShapeTmplFuncs = template.FuncMap{
512	"EventStreamHeaderTypeMap": func(ref *ShapeRef) EventStreamHeaderTypeMap {
513		switch ref.Shape.Type {
514		case "boolean":
515			return EventStreamHeaderTypeMap{Header: "bool", Member: "bool"}
516		case "byte":
517			return EventStreamHeaderTypeMap{Header: "int8", Member: "int64"}
518		case "short":
519			return EventStreamHeaderTypeMap{Header: "int16", Member: "int64"}
520		case "integer":
521			return EventStreamHeaderTypeMap{Header: "int32", Member: "int64"}
522		case "long":
523			return EventStreamHeaderTypeMap{Header: "int64", Member: "int64"}
524		case "timestamp":
525			return EventStreamHeaderTypeMap{Header: "time.Time", Member: "time.Time"}
526		case "blob":
527			return EventStreamHeaderTypeMap{Header: "[]byte", Member: "[]byte"}
528		case "string":
529			return EventStreamHeaderTypeMap{Header: "string", Member: "string"}
530		case "uuid":
531			return EventStreamHeaderTypeMap{Header: "[]byte", Member: "[]byte"}
532		default:
533			panic("unsupported EventStream header type, " + ref.Shape.Type)
534		}
535	},
536	"EventHeaderValueForType":  setEventHeaderValueForType,
537	"ShapeMessageType":         shapeMessageType,
538	"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
539}
540
// Template for an EventStream Event shape. This is a normal API shape that is
// decorated as an EventStream Event.
//
// The template generates, for the shape:
//   - one empty marker method per event stream listed in the shape's
//     EventFor, named after that event stream;
//   - UnmarshalEvent, which sets header members from the message headers,
//     copies a blob payload member or converts a string payload member from
//     the message payload, and unmarshals the payload through the payload
//     unmarshaler when HasNonBlobPayloadMembers reports true for the shape;
//   - MarshalEvent, which sets the message type header, sets one header per
//     header member, assigns a blob or string payload member to the message
//     payload, and marshals the shape through the payload marshaler when
//     HasNonBlobPayloadMembers reports true for the shape.
//
// Executed in the context of a Shape.
var eventStreamEventShapeTmpl = template.Must(template.New("eventStreamEventShapeTmpl").
	Funcs(eventStreamEventShapeTmplFuncs).Parse(`
{{ range $_, $eventStream := $.EventFor }}
	// The {{ $.ShapeName }} is and event in the {{ $eventStream.Name }} group of events.
	func (s *{{ $.ShapeName }}) event{{ $eventStream.Name }}() {}
{{ end }}

// UnmarshalEvent unmarshals the EventStream Message into the {{ $.ShapeName }} value.
// This method is only used internally within the SDK's EventStream handling.
func (s *{{ $.ShapeName }}) UnmarshalEvent(
	payloadUnmarshaler protocol.PayloadUnmarshaler,
	msg eventstream.Message,
) error {
	{{- range $memName, $memRef := $.MemberRefs }}
		{{- if $memRef.IsEventHeader }}
			if hv := msg.Headers.Get("{{ $memName }}"); hv != nil {
				{{ $types := EventStreamHeaderTypeMap $memRef -}}
				v := hv.Get().({{ $types.Header }})
				{{- if ne $types.Header $types.Member }}
					m := {{ $types.Member }}(v)
					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}m
				{{- else }}
					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}v
				{{- end }}
			}
		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "blob")) }}
			s.{{ $memName }} = make([]byte, len(msg.Payload))
			copy(s.{{ $memName }}, msg.Payload)
		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "string")) }}
			s.{{ $memName }} = aws.String(string(msg.Payload))
		{{- end }}
	{{- end }}
	{{- if HasNonBlobPayloadMembers $ }}
		if err := payloadUnmarshaler.UnmarshalPayload(
			bytes.NewReader(msg.Payload), s,
		); err != nil {
			return err
		}
	{{- end }}
	return nil
}

func (s *{{ $.ShapeName}}) MarshalEvent(pm protocol.PayloadMarshaler) (msg eventstream.Message, err error) {
	msg.Headers.Set(eventstreamapi.MessageTypeHeader, eventstream.StringValue({{ ShapeMessageType $ }}))

	{{- range $memName, $memRef := $.MemberRefs }}
		{{- if $memRef.IsEventHeader }}
			{{ $memVar := printf "s.%s" $memName -}}
			{{ $typedMem := EventHeaderValueForType $memRef.Shape $memVar -}}
			msg.Headers.Set("{{ $memName }}", {{ $typedMem }})
		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "blob")) }}
			msg.Headers.Set(":content-type", eventstream.StringValue("application/octet-stream"))
			msg.Payload = s.{{ $memName }}
		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "string")) }}
			msg.Payload = []byte(aws.StringValue(s.{{ $memName }}))
		{{- end }}
	{{- end }}
	{{- if HasNonBlobPayloadMembers $ }}
		var buf bytes.Buffer
		if err = pm.MarshalPayload(&buf, s); err != nil {
			return eventstream.Message{}, err
		}
		msg.Payload = buf.Bytes()
	{{- end }}
	return msg, err
}
`))
612