1 /*
2  * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.smithy.go.codegen.integration;
17 
18 import static software.amazon.smithy.go.codegen.integration.HttpProtocolGeneratorUtils.isShapeWithResponseBindings;
19 import static software.amazon.smithy.go.codegen.integration.ProtocolUtils.requiresDocumentSerdeFunction;
20 
21 import java.util.Collection;
22 import java.util.Comparator;
23 import java.util.List;
24 import java.util.Optional;
25 import java.util.Set;
26 import java.util.TreeSet;
27 import java.util.function.BiConsumer;
28 import java.util.logging.Logger;
29 import java.util.stream.Collectors;
30 import software.amazon.smithy.codegen.core.CodegenException;
31 import software.amazon.smithy.codegen.core.Symbol;
32 import software.amazon.smithy.codegen.core.SymbolProvider;
33 import software.amazon.smithy.go.codegen.ApplicationProtocol;
34 import software.amazon.smithy.go.codegen.CodegenUtils;
35 import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator;
36 import software.amazon.smithy.go.codegen.GoValueAccessUtils;
37 import software.amazon.smithy.go.codegen.GoWriter;
38 import software.amazon.smithy.go.codegen.SmithyGoDependency;
39 import software.amazon.smithy.go.codegen.SymbolUtils;
40 import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
41 import software.amazon.smithy.go.codegen.trait.NoSerializeTrait;
42 import software.amazon.smithy.model.Model;
43 import software.amazon.smithy.model.knowledge.HttpBinding;
44 import software.amazon.smithy.model.knowledge.HttpBindingIndex;
45 import software.amazon.smithy.model.knowledge.TopDownIndex;
46 import software.amazon.smithy.model.shapes.CollectionShape;
47 import software.amazon.smithy.model.shapes.MapShape;
48 import software.amazon.smithy.model.shapes.MemberShape;
49 import software.amazon.smithy.model.shapes.OperationShape;
50 import software.amazon.smithy.model.shapes.Shape;
51 import software.amazon.smithy.model.shapes.ShapeId;
52 import software.amazon.smithy.model.shapes.ShapeType;
53 import software.amazon.smithy.model.shapes.StructureShape;
54 import software.amazon.smithy.model.shapes.ToShapeId;
55 import software.amazon.smithy.model.traits.EnumTrait;
56 import software.amazon.smithy.model.traits.HttpTrait;
57 import software.amazon.smithy.model.traits.MediaTypeTrait;
58 import software.amazon.smithy.model.traits.StreamingTrait;
59 import software.amazon.smithy.model.traits.TimestampFormatTrait;
60 import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
61 import software.amazon.smithy.utils.OptionalUtils;
62 
63 
64 /**
65  * Abstract implementation useful for all protocols that use HTTP bindings.
66  */
67 public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator {
// Used to warn about operations that are skipped because they lack an @http trait.
private static final Logger LOGGER = Logger.getLogger(HttpBindingProtocolGenerator.class.getName());

// True when the protocol places the error code in the error response body, so the body must be parsed
// before the error code can be loaded.
private final boolean isErrorCodeInBody;
// Document/payload-bound request targets that need document serializer functions. Collected per operation
// and expanded by resolveRequiredDocumentShapeSerde before the shared serializers are generated.
private final Set<Shape> serializeDocumentBindingShapes = new TreeSet<>();
// Shapes that need document deserializer functions; includes the error shapes returned by the
// operation error dispatcher generation.
private final Set<Shape> deserializeDocumentBindingShapes = new TreeSet<>();
// Error structure shapes returned by the operation error dispatcher generation.
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();
74 
/**
 * Constructs the generator for an HTTP binding protocol.
 *
 * @param isErrorCodeInBody {@code true} when the protocol carries the error code inside the error response
 *                          body, which means the body has to be parsed before the error code can be
 *                          loaded; {@code false} otherwise.
 */
public HttpBindingProtocolGenerator(boolean isErrorCodeInBody) {
    this.isErrorCodeInBody = isErrorCodeInBody;
}
85 
86     @Override
getApplicationProtocol()87     public ApplicationProtocol getApplicationProtocol() {
88         return ApplicationProtocol.createDefaultHttpApplicationProtocol();
89     }
90 
91     @Override
generateSharedSerializerComponents(GenerationContext context)92     public void generateSharedSerializerComponents(GenerationContext context) {
93         serializeDocumentBindingShapes.addAll(ProtocolUtils.resolveRequiredDocumentShapeSerde(
94                 context.getModel(), serializeDocumentBindingShapes));
95         generateDocumentBodyShapeSerializers(context, serializeDocumentBindingShapes);
96     }
97 
98     /**
99      * Get the operations with HTTP Bindings.
100      *
101      * @param context the generation context
102      * @return the list of operation shapes
103      */
getHttpBindingOperations(GenerationContext context)104     public Set<OperationShape> getHttpBindingOperations(GenerationContext context) {
105         TopDownIndex topDownIndex = context.getModel().getKnowledge(TopDownIndex.class);
106 
107         Set<OperationShape> containedOperations = new TreeSet<>();
108         for (OperationShape operation : topDownIndex.getContainedOperations(context.getService())) {
109             OptionalUtils.ifPresentOrElse(
110                     operation.getTrait(HttpTrait.class),
111                     httpTrait -> containedOperations.add(operation),
112                     () -> LOGGER.warning(String.format(
113                             "Unable to fetch %s protocol request bindings for %s because it does not have an "
114                                     + "http binding trait", getProtocol(), operation.getId()))
115             );
116         }
117         return containedOperations;
118     }
119 
120     @Override
generateRequestSerializers(GenerationContext context)121     public void generateRequestSerializers(GenerationContext context) {
122         for (OperationShape operation : getHttpBindingOperations(context)) {
123             generateOperationSerializer(context, operation);
124         }
125     }
126 
127     /**
128      * Gets the default serde format for timestamps.
129      *
130      * @return Returns the default format.
131      */
getDocumentTimestampFormat()132     protected abstract Format getDocumentTimestampFormat();
133 
134     /**
135      * Gets the default content-type when a document is synthesized in the body.
136      *
137      * @return Returns the default content-type.
138      */
getDocumentContentType()139     protected abstract String getDocumentContentType();
140 
generateOperationSerializer(GenerationContext context, OperationShape operation)141     private void generateOperationSerializer(GenerationContext context, OperationShape operation) {
142         generateOperationSerializerMiddleware(context, operation);
143         generateOperationHttpBindingSerializer(context, operation);
144 
145         if (!CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectInput(context.getModel(), operation))) {
146             generateOperationDocumentSerializer(context, operation);
147             addOperationDocumentShapeBindersForSerializer(context, operation);
148         }
149     }
150 
151     /**
152      * Generates the operation document serializer function.
153      *
154      * @param context   the generation context
155      * @param operation the operation shape being generated
156      */
generateOperationDocumentSerializer(GenerationContext context, OperationShape operation)157     protected abstract void generateOperationDocumentSerializer(GenerationContext context, OperationShape operation);
158 
159     /**
160      * Adds the top-level shapes from the operation that bind to the body document that require serializer functions.
161      *
162      * @param context   the generator context
163      * @param operation the operation to add document binders from
164      */
addOperationDocumentShapeBindersForSerializer(GenerationContext context, OperationShape operation)165     private void addOperationDocumentShapeBindersForSerializer(GenerationContext context, OperationShape operation) {
166         Model model = context.getModel();
167 
168         // Walk and add members shapes to the list that will require serializer functions
169         Collection<HttpBinding> bindings = model.getKnowledge(HttpBindingIndex.class)
170                 .getRequestBindings(operation).values();
171 
172         for (HttpBinding binding : bindings) {
173             Shape targetShape = model.expectShape(binding.getMember().getTarget());
174             // Check if the input shape has a members that target the document or payload and require serializers
175             if (requiresDocumentSerdeFunction(targetShape)
176                     && (binding.getLocation() == HttpBinding.Location.DOCUMENT
177                     || binding.getLocation() == HttpBinding.Location.PAYLOAD)) {
178                 serializeDocumentBindingShapes.add(targetShape);
179             }
180         }
181     }
182 
/**
 * Generates the serialize-step middleware for an operation.
 *
 * <p>The emitted Go middleware:
 * <ol>
 *   <li>casts the request to the application protocol's request type, and the parameters to the input type;</li>
 *   <li>joins the {@code @http} URI path and query into the request and sets the method;</li>
 *   <li>creates an {@code httpbinding} encoder;</li>
 *   <li>calls the REST binding serializer when the operation has REST request bindings;</li>
 *   <li>unless the input is a stubbed synthetic clone, delegates document or payload body serialization;</li>
 *   <li>encodes the request and calls the next handler.</li>
 * </ol>
 *
 * @param context   the generation context
 * @param operation the operation to generate the middleware for; expected to carry an {@code @http} trait
 * @throws CodegenException if the operation has no input shape
 */
private void generateOperationSerializerMiddleware(GenerationContext context, OperationShape operation) {
    GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createSerializeStepMiddleware(
            ProtocolGenerator.getSerializeMiddlewareName(operation.getId(), getProtocolName()),
            ProtocolUtils.OPERATION_SERIALIZER_MIDDLEWARE_ID);

    SymbolProvider symbolProvider = context.getSymbolProvider();
    Model model = context.getModel();
    Shape inputShape = model.expectShape(operation.getInput()
            .orElseThrow(() -> new CodegenException("expect input shape for operation: " + operation.getId())));
    Symbol inputSymbol = symbolProvider.toSymbol(inputShape);
    ApplicationProtocol applicationProtocol = getApplicationProtocol();
    Symbol requestType = applicationProtocol.getRequestType();
    HttpTrait httpTrait = operation.expectTrait(HttpTrait.class);

    middleware.writeMiddleware(context.getWriter(), (generator, writer) -> {
        writer.addUseImports(SmithyGoDependency.FMT);
        writer.addUseImports(SmithyGoDependency.SMITHY);
        writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_BINDING);

        // cast input request to smithy transport type, check for failures
        writer.write("request, ok := in.Request.($P)", requestType);
        writer.openBlock("if !ok {", "}", () -> {
            writer.write("return out, metadata, "
                    + "&smithy.SerializationError{Err: fmt.Errorf(\"unknown transport type %T\", in.Request)}");
        });
        writer.write("");

        // cast input parameters type to the input type of the operation
        writer.write("input, ok := in.Parameters.($P)", inputSymbol);
        // Marks input as used in Go, since nothing below may reference it
        // (e.g. no REST bindings and a stubbed input).
        writer.write("_ = input");
        writer.openBlock("if !ok {", "}", () -> {
            writer.write("return out, metadata, "
                    + "&smithy.SerializationError{Err: fmt.Errorf(\"unknown input parameters type %T\","
                    + " in.Parameters)}");
        });

        writer.write("");
        // Apply the operation's @http URI (path + query) and method on top of the existing request URL.
        writer.write("opPath, opQuery := httpbinding.SplitURI($S)", httpTrait.getUri());
        writer.write("request.URL.Path = smithyhttp.JoinPath(request.URL.Path, opPath)");
        writer.write("request.URL.RawQuery = smithyhttp.JoinRawQuery(request.URL.RawQuery, opQuery)");
        writer.write("request.Method = $S", httpTrait.getMethod());
        writer.write("restEncoder, err := httpbinding.NewEncoder(request.URL.Path, request.URL.RawQuery, "
                + "request.Header)");
        writer.openBlock("if err != nil {", "}", () -> {
            writer.write("return out, metadata, &smithy.SerializationError{Err: err}");
        });
        writer.write("");

        // we only generate an operations http bindings function if there are bindings
        if (isOperationWithRestRequestBindings(model, operation)) {
            String serFunctionName = ProtocolGenerator.getOperationHttpBindingsSerFunctionName(inputShape,
                    getProtocolName());
            writer.openBlock("if err := $L(input, restEncoder); err != nil {", "}", serFunctionName, () -> {
                writer.write("return out, metadata, &smithy.SerializationError{Err: err}");
            });
            writer.write("");
        }

        // Don't consider serializing the body if the input shape is a stubbed synthetic clone, without an
        // archetype.
        if (!CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectInput(model, operation))) {
            // document bindings vs payload bindings; document bindings take precedence when both exist
            HttpBindingIndex httpBindingIndex = model.getKnowledge(HttpBindingIndex.class);
            boolean hasDocumentBindings = !httpBindingIndex
                    .getRequestBindings(operation, HttpBinding.Location.DOCUMENT)
                    .isEmpty();
            Optional<HttpBinding> payloadBinding = httpBindingIndex.getRequestBindings(operation,
                    HttpBinding.Location.PAYLOAD).stream().findFirst();

            if (hasDocumentBindings) {
                // delegate the setup and usage of the document serializer function for the protocol
                writeMiddlewareDocumentSerializerDelegator(context, operation, generator);

            } else if (payloadBinding.isPresent()) {
                // delegate the setup and usage of the payload serializer function for the protocol
                MemberShape memberShape = payloadBinding.get().getMember();
                writeMiddlewarePayloadSerializerDelegator(context, memberShape);
            }
            writer.write("");
        }

        // Serialize HTTP request with payload, if set.
        writer.openBlock("if request.Request, err = restEncoder.Encode(request.Request); err != nil {", "}", () -> {
            writer.write("return out, metadata, &smithy.SerializationError{Err: err}");
        });
        writer.write("in.Request = request");
        writer.write("");

        writer.write("return next.$L(ctx, in)", generator.getHandleMethodName());
    });
}
274 
275     // Generates operation deserializer middleware that delegates to appropriate deserializers for the error,
276     // output shapes for the operation.
generateOperationDeserializerMiddleware(GenerationContext context, OperationShape operation)277     private void generateOperationDeserializerMiddleware(GenerationContext context, OperationShape operation) {
278         GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware(
279                 ProtocolGenerator.getDeserializeMiddlewareName(operation.getId(), getProtocolName()),
280                 ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID);
281 
282         SymbolProvider symbolProvider = context.getSymbolProvider();
283         Model model = context.getModel();
284 
285         ApplicationProtocol applicationProtocol = getApplicationProtocol();
286         Symbol responseType = applicationProtocol.getResponseType();
287         GoWriter goWriter = context.getWriter();
288 
289         String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName(
290                 operation, context.getProtocolName());
291 
292         middleware.writeMiddleware(goWriter, (generator, writer) -> {
293             writer.addUseImports(SmithyGoDependency.FMT);
294 
295             writer.write("out, metadata, err = next.$L(ctx, in)", generator.getHandleMethodName());
296             writer.write("if err != nil { return out, metadata, err }");
297             writer.write("");
298 
299             writer.write("response, ok := out.RawResponse.($P)", responseType);
300             writer.openBlock("if !ok {", "}", () -> {
301                 writer.addUseImports(SmithyGoDependency.SMITHY);
302                 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}",
303                         "fmt.Errorf(\"unknown transport type %T\", out.RawResponse)"));
304             });
305             writer.write("");
306 
307             writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> {
308                 writer.write("return out, metadata, $L(response, &metadata)", errorFunctionName);
309             });
310 
311             Shape outputShape = model.expectShape(operation.getOutput()
312                     .orElseThrow(() -> new CodegenException("expect output shape for operation: " + operation.getId()))
313             );
314 
315             Symbol outputSymbol = symbolProvider.toSymbol(outputShape);
316 
317             // initialize out.Result as output structure shape
318             writer.write("output := &$T{}", outputSymbol);
319             writer.write("out.Result = output");
320             writer.write("");
321 
322             // Output shape HTTP binding middleware generation
323             if (isShapeWithRestResponseBindings(model, operation)) {
324                 String deserFuncName = ProtocolGenerator.getOperationHttpBindingsDeserFunctionName(
325                         outputShape, getProtocolName());
326 
327                 writer.write("err= $L(output, response)", deserFuncName);
328                 writer.openBlock("if err != nil {", "}", () -> {
329                     writer.addUseImports(SmithyGoDependency.SMITHY);
330                     writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}",
331                             "fmt.Errorf(\"failed to decode response with invalid Http bindings, %w\", err)"));
332                 });
333                 writer.write("");
334             }
335 
336             // Discard without deserializing the response if the input shape is a stubbed synthetic clone
337             // without an archetype.
338             if (CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectOutput(model, operation))) {
339                 writer.addUseImports(SmithyGoDependency.IOUTIL);
340                 writer.openBlock("if _, err = io.Copy(ioutil.Discard, response.Body); err != nil {", "}", () -> {
341                     writer.openBlock("return out, metadata, &smithy.DeserializationError{", "}", () -> {
342                         writer.write("Err: fmt.Errorf(\"failed to discard response body, %w\", err),");
343                     });
344                 });
345             } else if (isShapeWithResponseBindings(model, operation, HttpBinding.Location.DOCUMENT)
346                     || isShapeWithResponseBindings(model, operation, HttpBinding.Location.PAYLOAD)) {
347                 // Output Shape Document Binding middleware generation
348                 writeMiddlewareDocumentDeserializerDelegator(context, operation, generator);
349             }
350             writer.write("");
351 
352             writer.write("return out, metadata, err");
353         });
354         goWriter.write("");
355 
356         Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
357                 context, operation, responseType, this::writeErrorMessageCodeDeserializer);
358         deserializingErrorShapes.addAll(errorShapes);
359         deserializeDocumentBindingShapes.addAll(errorShapes);
360     }
361 
362     /**
363      * Writes a code snippet that gets the error code and error message.
364      *
365      * <p>Four parameters will be available in scope:
366      * <ul>
367      *   <li>{@code response: smithyhttp.HTTPResponse}: the HTTP response received.</li>
368      *   <li>{@code errorBody: bytes.BytesReader}: the HTTP response body.</li>
369      *   <li>{@code errorMessage: string}: the error message initialized to a default value.</li>
370      *   <li>{@code errorCode: string}: the error code initialized to a default value.</li>
371      * </ul>
372      *
373      * @param context the generation context.
374      */
writeErrorMessageCodeDeserializer(GenerationContext context)375     protected abstract void writeErrorMessageCodeDeserializer(GenerationContext context);
376 
377     /**
378      * Generate the document serializer logic for the serializer middleware body.
379      *
380      * @param context   the generation context
381      * @param operation the operation
382      * @param generator middleware generator definition
383      */
writeMiddlewareDocumentSerializerDelegator( GenerationContext context, OperationShape operation, GoStackStepMiddlewareGenerator generator )384     protected abstract void writeMiddlewareDocumentSerializerDelegator(
385             GenerationContext context,
386             OperationShape operation,
387             GoStackStepMiddlewareGenerator generator
388     );
389 
390     /**
391      * Generate the payload serializer logic for the serializer middleware body.
392      *
393      * @param context     the generation context
394      * @param memberShape the payload target member
395      */
writeMiddlewarePayloadSerializerDelegator( GenerationContext context, MemberShape memberShape )396     protected void writeMiddlewarePayloadSerializerDelegator(
397             GenerationContext context,
398             MemberShape memberShape
399     ) {
400         GoWriter writer = context.getWriter();
401         Model model = context.getModel();
402         Shape payloadShape = model.expectShape(memberShape.getTarget());
403 
404         GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
405                 memberShape, "input", (s) -> {
406                     writer.openBlock("if !restEncoder.HasHeader(\"Content-Type\") {", "}", () -> {
407                         writer.write("restEncoder.SetHeader(\"Content-Type\").String($S)",
408                                 getPayloadShapeMediaType(payloadShape));
409                     });
410                     writer.write("");
411 
412                     if (payloadShape.hasTrait(StreamingTrait.class)) {
413                         writer.write("payload := $L", s);
414 
415                     } else if (payloadShape.isBlobShape()) {
416                         writer.addUseImports(SmithyGoDependency.BYTES);
417                         writer.write("payload := bytes.NewReader($L)", s);
418 
419                     } else if (payloadShape.isStringShape()) {
420                         writer.addUseImports(SmithyGoDependency.STRINGS);
421                         writer.write("payload := strings.NewReader(*$L)", s);
422 
423                     } else {
424                         writeMiddlewarePayloadAsDocumentSerializerDelegator(context, memberShape, s);
425                     }
426 
427                     writer.openBlock("if request, err = request.SetStream(payload); err != nil {", "}",
428                             () -> {
429                                 writer.write("return out, metadata, &smithy.SerializationError{Err: err}");
430                             });
431                 });
432     }
433 
434     /**
435      * Returns the MediaType for the payload shape derived from the MediaTypeTrait, shape type, or
436      * document content type.
437      *
438      * @param payloadShape shape bound to the payload.
439      * @return string for media type.
440      */
getPayloadShapeMediaType(Shape payloadShape)441     private String getPayloadShapeMediaType(Shape payloadShape) {
442         Optional<MediaTypeTrait> mediaTypeTrait = payloadShape.getTrait(MediaTypeTrait.class);
443 
444         if (mediaTypeTrait.isPresent()) {
445             return mediaTypeTrait.get().getValue();
446         }
447 
448         if (payloadShape.isBlobShape()) {
449             return "application/octet-stream";
450         }
451 
452         if (payloadShape.isStringShape()) {
453             return "text/plain";
454         }
455 
456         return getDocumentContentType();
457     }
458 
459     /**
460      * Generate the payload serializers with document serializer logic for the serializer middleware body.
461      *
462      * @param context     the generation context
463      * @param memberShape the payload target member
464      * @param operand     the operand that is used to access the member value
465      */
writeMiddlewarePayloadAsDocumentSerializerDelegator( GenerationContext context, MemberShape memberShape, String operand )466     protected abstract void writeMiddlewarePayloadAsDocumentSerializerDelegator(
467             GenerationContext context,
468             MemberShape memberShape,
469             String operand
470     );
471 
472     /**
473      * Generate the document deserializer logic for the deserializer middleware body.
474      *
475      * @param context   the generation context
476      * @param operation the operation
477      * @param generator middleware generator definition
478      */
writeMiddlewareDocumentDeserializerDelegator( GenerationContext context, OperationShape operation, GoStackStepMiddlewareGenerator generator )479     protected abstract void writeMiddlewareDocumentDeserializerDelegator(
480             GenerationContext context,
481             OperationShape operation,
482             GoStackStepMiddlewareGenerator generator
483     );
484 
isRestBinding(HttpBinding.Location location)485     private boolean isRestBinding(HttpBinding.Location location) {
486         return location == HttpBinding.Location.HEADER
487                 || location == HttpBinding.Location.PREFIX_HEADERS
488                 || location == HttpBinding.Location.LABEL
489                 || location == HttpBinding.Location.QUERY
490                 || location == HttpBinding.Location.RESPONSE_CODE;
491     }
492 
493     // returns whether an operation shape has Rest Request Bindings
isOperationWithRestRequestBindings(Model model, OperationShape operationShape)494     private boolean isOperationWithRestRequestBindings(Model model, OperationShape operationShape) {
495         Collection<HttpBinding> bindings = model.getKnowledge(HttpBindingIndex.class)
496                 .getRequestBindings(operationShape).values();
497 
498         for (HttpBinding binding : bindings) {
499             if (isRestBinding(binding.getLocation())) {
500                 return true;
501             }
502         }
503 
504         return false;
505     }
506 
507     /**
508      * Returns whether a shape has rest response bindings.
509      * The shape can be an operation shape, error shape or an output shape.
510      *
511      * @param model the model
512      * @param shape the shape with possible presence of rest response bindings
513      * @return boolean indicating presence of rest response bindings in the shape
514      */
isShapeWithRestResponseBindings(Model model, Shape shape)515     protected boolean isShapeWithRestResponseBindings(Model model, Shape shape) {
516         Collection<HttpBinding> bindings = model.getKnowledge(HttpBindingIndex.class)
517                 .getResponseBindings(shape).values();
518 
519         for (HttpBinding binding : bindings) {
520             if (isRestBinding(binding.getLocation())) {
521                 return true;
522             }
523         }
524         return false;
525     }
526 
generateOperationHttpBindingSerializer(GenerationContext context, OperationShape operation)527     private void generateOperationHttpBindingSerializer(GenerationContext context, OperationShape operation) {
528         SymbolProvider symbolProvider = context.getSymbolProvider();
529         Model model = context.getModel();
530         GoWriter writer = context.getWriter();
531 
532         Shape inputShape = model.expectShape(operation.getInput()
533                 .orElseThrow(() -> new CodegenException("missing input shape for operation: " + operation.getId())));
534 
535         HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class);
536         List<HttpBinding> bindings = bindingIndex.getRequestBindings(operation).values().stream()
537                 .filter(httpBinding -> isRestBinding(httpBinding.getLocation()))
538                 .sorted(Comparator.comparing(HttpBinding::getMember))
539                 .collect(Collectors.toList());
540 
541         Symbol httpBindingEncoder = getHttpBindingEncoderSymbol();
542         Symbol inputSymbol = symbolProvider.toSymbol(inputShape);
543         String functionName = ProtocolGenerator.getOperationHttpBindingsSerFunctionName(inputShape, getProtocolName());
544 
545         writer.addUseImports(SmithyGoDependency.FMT);
546         writer.openBlock("func $L(v $P, encoder $P) error {", "}", functionName, inputSymbol, httpBindingEncoder,
547                 () -> {
548                     writer.openBlock("if v == nil {", "}", () -> {
549                         writer.write("return fmt.Errorf(\"unsupported serialization of nil %T\", v)");
550                     });
551 
552                     writer.write("");
553 
554                     for (HttpBinding binding : bindings.stream()
555                             .filter(NoSerializeTrait.excludeNoSerializeHttpBindingMembers())
556                             .collect(Collectors.toList())) {
557 
558                         writeHttpBindingMember(context, binding);
559                         writer.write("");
560                     }
561                     writer.write("return nil");
562                 });
563         writer.write("");
564     }
565 
getHttpBindingEncoderSymbol()566     private Symbol getHttpBindingEncoderSymbol() {
567         return SymbolUtils.createPointableSymbolBuilder("Encoder", SmithyGoDependency.SMITHY_HTTP_BINDING).build();
568     }
569 
generateHttpBindingTimestampSerializer( Model model, GoWriter writer, MemberShape memberShape, HttpBinding.Location location, String operand, BiConsumer<GoWriter, String> locationEncoder )570     private void generateHttpBindingTimestampSerializer(
571             Model model,
572             GoWriter writer,
573             MemberShape memberShape,
574             HttpBinding.Location location,
575             String operand,
576             BiConsumer<GoWriter, String> locationEncoder
577     ) {
578         writer.addUseImports(SmithyGoDependency.SMITHY_TIME);
579 
580         TimestampFormatTrait.Format format = model.getKnowledge(HttpBindingIndex.class).determineTimestampFormat(
581                 memberShape, location, getDocumentTimestampFormat());
582 
583         switch (format) {
584             case DATE_TIME:
585                 locationEncoder.accept(writer, "String(smithytime.FormatDateTime(" + operand + "))");
586                 break;
587             case HTTP_DATE:
588                 locationEncoder.accept(writer, "String(smithytime.FormatHTTPDate(" + operand + "))");
589                 break;
590             case EPOCH_SECONDS:
591                 locationEncoder.accept(writer, "Double(smithytime.FormatEpochSeconds(" + operand + "))");
592                 break;
593             default:
594                 throw new CodegenException("Unknown timestamp format");
595         }
596     }
597 
isHttpDateTimestamp(Model model, HttpBinding.Location location, MemberShape memberShape)598     private boolean isHttpDateTimestamp(Model model, HttpBinding.Location location, MemberShape memberShape) {
599         Shape targetShape = model.expectShape(memberShape.getTarget().toShapeId());
600         if (targetShape.getType() != ShapeType.TIMESTAMP) {
601             return false;
602         }
603 
604         TimestampFormatTrait.Format format = HttpBindingIndex.of(model).determineTimestampFormat(
605                 memberShape, location, getDocumentTimestampFormat());
606 
607         return format == Format.HTTP_DATE;
608     }
609 
writeHttpBindingSetter( GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding.Location location, String operand, BiConsumer<GoWriter, String> locationEncoder )610     private void writeHttpBindingSetter(
611             GenerationContext context,
612             GoWriter writer,
613             MemberShape memberShape,
614             HttpBinding.Location location,
615             String operand,
616             BiConsumer<GoWriter, String> locationEncoder
617     ) {
618         Model model = context.getModel();
619         Shape targetShape = model.expectShape(memberShape.getTarget());
620 
621         // We only need to dereference if we pass the shape around as reference in Go.
622         // Note we make two exceptions here: big.Int and big.Float should still be passed as reference to the helper
623         // method as they can be arbitrarily large.
624         operand = CodegenUtils.getAsValueIfDereferencable(GoPointableIndex.of(context.getModel()), memberShape,
625                 operand);
626 
627         switch (targetShape.getType()) {
628             case BOOLEAN:
629                 locationEncoder.accept(writer, "Boolean(" + operand + ")");
630                 break;
631             case STRING:
632                 operand = targetShape.hasTrait(EnumTrait.class) ? "string(" + operand + ")" : operand;
633                 locationEncoder.accept(writer, "String(" + operand + ")");
634                 break;
635             case TIMESTAMP:
636                 generateHttpBindingTimestampSerializer(model, writer, memberShape, location, operand, locationEncoder);
637                 break;
638             case BYTE:
639                 locationEncoder.accept(writer, "Byte(" + operand + ")");
640                 break;
641             case SHORT:
642                 locationEncoder.accept(writer, "Short(" + operand + ")");
643                 break;
644             case INTEGER:
645                 locationEncoder.accept(writer, "Integer(" + operand + ")");
646                 break;
647             case LONG:
648                 locationEncoder.accept(writer, "Long(" + operand + ")");
649                 break;
650             case FLOAT:
651                 locationEncoder.accept(writer, "Float(" + operand + ")");
652                 break;
653             case DOUBLE:
654                 locationEncoder.accept(writer, "Double(" + operand + ")");
655                 break;
656             case BIG_INTEGER:
657                 locationEncoder.accept(writer, "BigInteger(" + operand + ")");
658                 break;
659             case BIG_DECIMAL:
660                 locationEncoder.accept(writer, "BigDecimal(" + operand + ")");
661                 break;
662             default:
663                 throw new CodegenException("unexpected shape type " + targetShape.getType());
664         }
665     }
666 
writeHttpBindingMember( GenerationContext context, HttpBinding binding )667     private void writeHttpBindingMember(
668             GenerationContext context,
669             HttpBinding binding
670     ) {
671         GoWriter writer = context.getWriter();
672         Model model = context.getModel();
673         MemberShape memberShape = binding.getMember();
674         Shape targetShape = model.expectShape(memberShape.getTarget());
675         HttpBinding.Location location = binding.getLocation();
676 
677         // return an error if member shape targets location label, but is unset.
678         if (location.equals(HttpBinding.Location.LABEL)) {
679             // labels must always be set to be serialized on URI, and non empty strings,
680             GoValueAccessUtils.writeIfZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
681                     memberShape, "v", false, true, operand -> {
682                         writer.addUseImports(SmithyGoDependency.SMITHY);
683                         writer.write("return &smithy.SerializationError { "
684                                         + "Err: fmt.Errorf(\"input member $L must not be empty\")}",
685                                 memberShape.getMemberName());
686                     });
687         }
688 
689         boolean allowZeroStrings = location != HttpBinding.Location.HEADER;
690 
691         GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
692                 memberShape, "v", allowZeroStrings, memberShape.isRequired(), (operand) -> {
693                     final String locationName = binding.getLocationName().isEmpty()
694                             ? memberShape.getMemberName() : binding.getLocationName();
695                     switch (location) {
696                         case HEADER:
697                             writer.write("locationName := $S", getCanonicalHeader(locationName));
698                             writeHeaderBinding(context, memberShape, operand, location, "locationName", "encoder");
699                             break;
700                         case PREFIX_HEADERS:
701                             MemberShape valueMemberShape = model.expectShape(targetShape.getId(),
702                                     MapShape.class).getValue();
703                             Shape valueMemberTarget = model.expectShape(valueMemberShape.getTarget());
704 
705                             if (targetShape.getType() != ShapeType.MAP) {
706                                 throw new CodegenException("Unexpected prefix headers target shape "
707                                         + valueMemberTarget.getType() + ", " + valueMemberShape.getId());
708                             }
709 
710                             writer.write("hv := encoder.Headers($S)", getCanonicalHeader(locationName));
711                             writer.addUseImports(SmithyGoDependency.NET_HTTP);
712                             writer.openBlock("for mapKey, mapVal := range $L {", "}", operand, () -> {
713                                 GoValueAccessUtils.writeIfNonZeroValue(context.getModel(), writer, valueMemberShape,
714                                         "mapVal", false, false, () -> {
715                                             writeHeaderBinding(context, valueMemberShape, "mapVal", location,
716                                                     "http.CanonicalHeaderKey(mapKey)", "hv");
717                                         });
718                             });
719                             break;
720                         case LABEL:
721                             writeHttpBindingSetter(context, writer, memberShape, location, operand, (w, s) -> {
722                                 w.openBlock("if err := encoder.SetURI($S).$L; err != nil {", "}", locationName, s,
723                                         () -> {
724                                             w.write("return err");
725                                         });
726                             });
727                             break;
728                         case QUERY:
729                             if (targetShape instanceof CollectionShape) {
730                                 MemberShape collectionMember = CodegenUtils.expectCollectionShape(targetShape)
731                                         .getMember();
732                                 writer.openBlock("for i := range $L {", "}", operand, () -> {
733                                     GoValueAccessUtils.writeIfZeroValue(context.getModel(), writer, collectionMember,
734                                             operand + "[i]", () -> {
735                                                 writer.write("continue");
736                                             });
737                                     writeHttpBindingSetter(context, writer, collectionMember, location, operand + "[i]",
738                                             (w, s) -> {
739                                                 w.writeInline("encoder.AddQuery($S).$L", locationName, s);
740                                             });
741                                 });
742                             } else {
743                                 writeHttpBindingSetter(context, writer, memberShape, location, operand,
744                                         (w, s) -> w.writeInline(
745                                                 "encoder.SetQuery($S).$L", locationName, s));
746                             }
747                             break;
748                         default:
749                             throw new CodegenException("unexpected http binding found");
750                     }
751                 });
752     }
753 
writeHeaderBinding( GenerationContext context, MemberShape memberShape, String operand, HttpBinding.Location location, String locationName, String dest )754     private void writeHeaderBinding(
755             GenerationContext context,
756             MemberShape memberShape,
757             String operand,
758             HttpBinding.Location location,
759             String locationName,
760             String dest
761     ) {
762         GoWriter writer = context.getWriter();
763         Model model = context.getModel();
764         Shape targetShape = model.expectShape(memberShape.getTarget());
765 
766         if (!(targetShape instanceof CollectionShape)) {
767             String op = conditionallyBase64Encode(writer, targetShape, operand);
768             writeHttpBindingSetter(context, writer, memberShape, location, op, (w, s) -> {
769                 w.writeInline("$L.SetHeader($L).$L", dest, locationName, s);
770             });
771             return;
772         }
773 
774         MemberShape collectionMemberShape = CodegenUtils.expectCollectionShape(targetShape).getMember();
775         writer.openBlock("for i := range $L {", "}", operand, () -> {
776             // Only set non-empty non-nil header values
777             String indexedOperand = operand + "[i]";
778             GoValueAccessUtils.writeIfNonZeroValue(context.getModel(), writer, collectionMemberShape, indexedOperand,
779                     false, false, () -> {
780                         String op = conditionallyBase64Encode(writer, targetShape, indexedOperand);
781                         writeHttpBindingSetter(context, writer, collectionMemberShape, location, op, (w, s) -> {
782                             w.writeInline("$L.AddHeader($L).$L", dest, locationName, s);
783                         });
784                     });
785         });
786     }
787 
conditionallyBase64Encode(GoWriter writer, Shape targetShape, String operand)788     private String conditionallyBase64Encode(GoWriter writer, Shape targetShape, String operand) {
789         // MediaType strings written to headers must be base64 encoded
790         if (targetShape.isStringShape() && targetShape.hasTrait(MediaTypeTrait.class)) {
791             writer.addUseImports(SmithyGoDependency.SMITHY_PTR);
792             writer.addUseImports(SmithyGoDependency.BASE64);
793             writer.write("encoded := ptr.String(base64.StdEncoding.EncodeToString([]byte(*$L)))", operand);
794             return "encoded";
795         }
796         return operand;
797     }
798 
799     /**
800      * Generates serialization functions for shapes in the passed set. These functions
801      * should return a value that can then be serialized by the implementation of
802      * {@code serializeInputDocument}.
803      *
804      * @param context The generation context.
805      * @param shapes  The shapes to generate serialization for.
806      */
generateDocumentBodyShapeSerializers(GenerationContext context, Set<Shape> shapes)807     protected abstract void generateDocumentBodyShapeSerializers(GenerationContext context, Set<Shape> shapes);
808 
809     @Override
generateResponseDeserializers(GenerationContext context)810     public void generateResponseDeserializers(GenerationContext context) {
811         for (OperationShape operation : getHttpBindingOperations(context)) {
812             generateOperationDeserializerMiddleware(context, operation);
813             generateHttpBindingDeserializer(context, operation);
814 
815             if (!CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectOutput(context.getModel(), operation))) {
816                 generateOperationDocumentDeserializer(context, operation);
817                 addOperationDocumentShapeBindersForDeserializer(context, operation);
818             }
819         }
820 
821         for (StructureShape error : deserializingErrorShapes) {
822             generateHttpBindingDeserializer(context, error);
823         }
824     }
825 
826     // Generates Http Binding shape deserializer function.
generateHttpBindingDeserializer(GenerationContext context, Shape shape)827     private void generateHttpBindingDeserializer(GenerationContext context, Shape shape) {
828         SymbolProvider symbolProvider = context.getSymbolProvider();
829         Model model = context.getModel();
830         GoWriter writer = context.getWriter();
831 
832         HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class);
833         List<HttpBinding> bindings = bindingIndex.getResponseBindings(shape).values().stream()
834                 .filter(binding -> isRestBinding(binding.getLocation()))
835                 .sorted(Comparator.comparing(HttpBinding::getMember))
836                 .collect(Collectors.toList());
837 
838         // Don't generate anything if there are no bindings.
839         if (bindings.size() == 0) {
840             return;
841         }
842 
843         Shape targetShape = shape;
844         if (shape.isOperationShape()) {
845             targetShape = ProtocolUtils.expectOutput(model, shape.asOperationShape().get());
846         }
847 
848         Symbol targetSymbol = symbolProvider.toSymbol(targetShape);
849         Symbol smithyHttpResponsePointableSymbol = SymbolUtils.createPointableSymbolBuilder(
850                 "Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build();
851 
852         writer.addUseImports(SmithyGoDependency.FMT);
853 
854         String functionName = ProtocolGenerator.getOperationHttpBindingsDeserFunctionName(targetShape,
855                 getProtocolName());
856         writer.openBlock("func $L(v $P, response $P) error {", "}", functionName, targetSymbol,
857                 smithyHttpResponsePointableSymbol, () -> {
858                     writer.openBlock("if v == nil {", "}", () -> {
859                         writer.write("return fmt.Errorf(\"unsupported deserialization for nil %T\", v)");
860                     });
861                     writer.write("");
862 
863                     for (HttpBinding binding : bindings) {
864                         writeRestDeserializerMember(context, writer, binding);
865                         writer.write("");
866                     }
867                     writer.write("return nil");
868                 });
869     }
870 
generateHttpHeaderValue( GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding binding, String operand )871     private String generateHttpHeaderValue(
872             GenerationContext context,
873             GoWriter writer,
874             MemberShape memberShape,
875             HttpBinding binding,
876             String operand
877     ) {
878         Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
879 
880         if (targetShape.getType() != ShapeType.LIST && targetShape.getType() != ShapeType.SET) {
881             writer.addUseImports(SmithyGoDependency.STRINGS);
882             writer.write("$L = strings.TrimSpace($L)", operand, operand);
883         }
884 
885         String value = "";
886         switch (targetShape.getType()) {
887             case STRING:
888                 if (targetShape.hasTrait(EnumTrait.class)) {
889                     value = String.format("types.%s(%s)", targetShape.getId().getName(), operand);
890                     return value;
891                 }
892                 // MediaType strings must be base-64 encoded when sent in headers.
893                 if (targetShape.hasTrait(MediaTypeTrait.class)) {
894                     writer.addUseImports(SmithyGoDependency.BASE64);
895                     writer.write("b, err := base64.StdEncoding.DecodeString($L)", operand);
896                     writer.write("if err != nil { return err }");
897                     return "string(b)";
898                 }
899                 return operand;
900             case BOOLEAN:
901                 writer.addUseImports(SmithyGoDependency.STRCONV);
902                 writer.write("vv, err := strconv.ParseBool($L)", operand);
903                 writer.write("if err != nil { return err }");
904                 return "vv";
905             case TIMESTAMP:
906                 writer.addUseImports(SmithyGoDependency.SMITHY_TIME);
907                 HttpBindingIndex bindingIndex = context.getModel().getKnowledge(HttpBindingIndex.class);
908                 TimestampFormatTrait.Format format = bindingIndex.determineTimestampFormat(
909                         memberShape,
910                         binding.getLocation(),
911                         Format.HTTP_DATE
912                 );
913                 switch (format) {
914                     case EPOCH_SECONDS:
915                         writer.addUseImports(SmithyGoDependency.STRCONV);
916                         writer.write("f, err := strconv.ParseFloat($L, 64)", operand);
917                         writer.write("if err != nil { return err }");
918                         writer.write("t := smithytime.ParseEpochSeconds(f)");
919                         break;
920                     case HTTP_DATE:
921                         writer.write("t, err := smithytime.ParseHTTPDate($L)", operand);
922                         writer.write("if err != nil { return err }");
923                         break;
924                     case DATE_TIME:
925                         writer.write("t, err := smithytime.ParseDateTime($L)", operand);
926                         writer.write("if err != nil { return err }");
927                         break;
928                     default:
929                         throw new CodegenException("Unexpected timestamp format " + format);
930                 }
931                 return "t";
932             case BYTE:
933                 writer.addUseImports(SmithyGoDependency.STRCONV);
934                 writer.write("vv, err := strconv.ParseInt($L, 0, 8)", operand);
935                 writer.write("if err != nil { return err }");
936                 return "int8(vv)";
937             case SHORT:
938                 writer.addUseImports(SmithyGoDependency.STRCONV);
939                 writer.write("vv, err := strconv.ParseInt($L, 0, 16)", operand);
940                 writer.write("if err != nil { return err }");
941                 return "int16(vv)";
942             case INTEGER:
943                 writer.addUseImports(SmithyGoDependency.STRCONV);
944                 writer.write("vv, err := strconv.ParseInt($L, 0, 32)", operand);
945                 writer.write("if err != nil { return err }");
946                 return "int32(vv)";
947             case LONG:
948                 writer.addUseImports(SmithyGoDependency.STRCONV);
949                 writer.write("vv, err := strconv.ParseInt($L, 0, 64)", operand);
950                 writer.write("if err != nil { return err }");
951                 return "vv";
952             case FLOAT:
953                 writer.addUseImports(SmithyGoDependency.STRCONV);
954                 writer.write("vv, err := strconv.ParseFloat($L, 32)", operand);
955                 writer.write("if err != nil { return err }");
956                 return "float32(vv)";
957             case DOUBLE:
958                 writer.addUseImports(SmithyGoDependency.STRCONV);
959                 writer.write("vv, err := strconv.ParseFloat($L, 64)", operand);
960                 writer.write("if err != nil { return err }");
961                 return "vv";
962             case BIG_INTEGER:
963                 writer.addUseImports(SmithyGoDependency.BIG);
964                 writer.write("i := big.NewInt(0)");
965                 writer.write("bi, ok := i.SetString($L,0)", operand);
966                 writer.openBlock("if !ok {", "}", () -> {
967                     writer.write(
968                             "return fmt.Error($S)",
969                             "Incorrect conversion from string to BigInteger type"
970                     );
971                 });
972                 return "*bi";
973             case BIG_DECIMAL:
974                 writer.addUseImports(SmithyGoDependency.BIG);
975                 writer.write("f := big.NewFloat(0)");
976                 writer.write("bd, ok := f.SetString($L,0)", operand);
977                 writer.openBlock("if !ok {", "}", () -> {
978                     writer.write(
979                             "return fmt.Error($S)",
980                             "Incorrect conversion from string to BigDecimal type"
981                     );
982                 });
983                 return "*bd";
984             case BLOB:
985                 writer.addUseImports(SmithyGoDependency.BASE64);
986                 writer.write("b, err := base64.StdEncoding.DecodeString($L)", operand);
987                 writer.write("if err != nil { return err }");
988                 return "b";
989             case SET:
990             case LIST:
991                 // handle list/Set as target shape
992                 MemberShape targetValueListMemberShape = CodegenUtils.expectCollectionShape(targetShape).getMember();
993                 return getHttpHeaderCollectionDeserializer(context, writer, targetValueListMemberShape,
994                         binding,
995                         operand);
996             default:
997                 throw new CodegenException("unexpected shape type " + targetShape.getType());
998         }
999     }
1000 
getHttpHeaderCollectionDeserializer( GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding binding, String operand )1001     private String getHttpHeaderCollectionDeserializer(
1002             GenerationContext context,
1003             GoWriter writer,
1004             MemberShape memberShape,
1005             HttpBinding binding,
1006             String operand
1007     ) {
1008         writer.write("var list []$P", context.getSymbolProvider().toSymbol(memberShape));
1009 
1010         String operandValue = operand + "Val";
1011         writer.openBlock("for _, $L := range $L {", "}", operandValue, operand, () -> {
1012             String value = generateHttpHeaderValue(context, writer, memberShape, binding, operandValue);
1013             writer.write("list = append(list, $L)",
1014                     CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
1015                             GoPointableIndex.of(context.getModel()), memberShape, value));
1016         });
1017         return "list";
1018     }
1019 
writeRestDeserializerMember( GenerationContext context, GoWriter writer, HttpBinding binding )1020     private void writeRestDeserializerMember(
1021             GenerationContext context,
1022             GoWriter writer,
1023             HttpBinding binding
1024     ) {
1025         MemberShape memberShape = binding.getMember();
1026         Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
1027         String memberName = context.getSymbolProvider().toMemberName(memberShape);
1028 
1029         switch (binding.getLocation()) {
1030             case HEADER:
1031                 writeHeaderDeserializerFunction(context, writer, memberName, memberShape, binding);
1032                 break;
1033             case PREFIX_HEADERS:
1034                 if (!targetShape.isMapShape()) {
1035                     throw new CodegenException("unexpected prefix-header shape type found in Http bindings");
1036                 }
1037                 writePrefixHeaderDeserializerFunction(context, writer, memberName, memberShape, binding);
1038                 break;
1039             case RESPONSE_CODE:
1040                 writer.addUseImports(SmithyGoDependency.SMITHY_PTR);
1041                 writer.write("v.$L = $L", memberName,
1042                         CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
1043                                 GoPointableIndex.of(context.getModel()), memberShape, "int32(response.StatusCode)"));
1044                 break;
1045             default:
1046                 throw new CodegenException("unexpected http binding found");
1047         }
1048     }
1049 
writeHeaderDeserializerFunction( GenerationContext context, GoWriter writer, String memberName, MemberShape memberShape, HttpBinding binding )1050     private void writeHeaderDeserializerFunction(
1051             GenerationContext context,
1052             GoWriter writer,
1053             String memberName,
1054             MemberShape memberShape,
1055             HttpBinding binding
1056     ) {
1057         writer.openBlock("if headerValues := response.Header.Values($S); len(headerValues) != 0 {", "}",
1058                 binding.getLocationName(), () -> {
1059                     Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
1060 
1061                     String operand = "headerValues";
1062                     operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand);
1063 
1064                     String value = generateHttpHeaderValue(context, writer, memberShape, binding,
1065                             operand);
1066                     writer.write("v.$L = $L", memberName,
1067                             CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
1068                                     GoPointableIndex.of(context.getModel()), memberShape, value));
1069                 });
1070     }
1071 
writePrefixHeaderDeserializerFunction( GenerationContext context, GoWriter writer, String memberName, MemberShape memberShape, HttpBinding binding )1072     private void writePrefixHeaderDeserializerFunction(
1073             GenerationContext context,
1074             GoWriter writer,
1075             String memberName,
1076             MemberShape memberShape,
1077             HttpBinding binding
1078     ) {
1079         String prefix = binding.getLocationName();
1080         Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
1081 
1082         MemberShape valueMemberShape = targetShape.asMapShape()
1083                 .orElseThrow(() -> new CodegenException("prefix headers must target map shape"))
1084                 .getValue();
1085 
1086         writer.openBlock("for headerKey, headerValues := range response.Header {", "}", () -> {
1087             writer.addUseImports(SmithyGoDependency.STRINGS);
1088             Symbol targetSymbol = context.getSymbolProvider().toSymbol(targetShape);
1089 
1090             writer.openBlock(
1091                     "if lenPrefix := len($S); "
1092                             + "len(headerKey) >= lenPrefix && strings.EqualFold(headerKey[:lenPrefix], $S) {",
1093                     "}", prefix, prefix, () -> {
1094                         writer.openBlock("if v.$L == nil {", "}", memberName, () -> {
1095                             writer.write("v.$L = $P{}", memberName, targetSymbol);
1096                         });
1097 
1098                         String operand = "headerValues";
1099                         operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand);
1100 
1101                         String value = generateHttpHeaderValue(context, writer, valueMemberShape,
1102                                 binding, operand);
1103                         writer.write("v.$L[strings.ToLower(headerKey[lenPrefix:])] = $L", memberName,
1104                                 CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
1105                                         GoPointableIndex.of(context.getModel()), valueMemberShape, value));
1106                     });
1107         });
1108     }
1109 
1110     /**
1111      * Returns the header value accessor operand, and also if the target shape is a list/set will write the splitting
1112      * of the header values by comma(,) utility helper.
1113      *
1114      * @param context     generation context
1115      * @param writer      writer
1116      * @param targetShape target shape
1117      * @param binding     http binding location
1118      * @param operand     operand of the header values.
1119      * @return returns operand for accessing the header values
1120      */
writeHeaderValueAccessor( GenerationContext context, GoWriter writer, Shape targetShape, HttpBinding binding, String operand )1121     private String writeHeaderValueAccessor(
1122             GenerationContext context,
1123             GoWriter writer,
1124             Shape targetShape,
1125             HttpBinding binding,
1126             String operand
1127     ) {
1128         switch (targetShape.getType()) {
1129             case LIST:
1130             case SET:
1131                 writerHeaderListValuesSplit(context, writer, CodegenUtils.expectCollectionShape(targetShape), binding,
1132                         operand);
1133                 break;
1134             default:
1135                 // Always use first element in header, ignores if there are multiple headers with this key.
1136                 operand += "[0]";
1137                 break;
1138         }
1139 
1140         return operand;
1141     }
1142 
1143     /**
1144      * Writes the utility to split split comma separate header values into a single list for consistent iteration. Also
1145      * has special case handling for HttpDate timestamp format when serialized as a header list. Assigns the split
1146      * header values back to the same operand name.
1147      *
1148      * @param context generation context
1149      * @param writer  writer
1150      * @param shape   target collection shape
1151      * @param binding http binding location
1152      * @param operand operand of the header values.
1153      */
writerHeaderListValuesSplit( GenerationContext context, GoWriter writer, CollectionShape shape, HttpBinding binding, String operand )1154     private void writerHeaderListValuesSplit(
1155             GenerationContext context,
1156             GoWriter writer,
1157             CollectionShape shape,
1158             HttpBinding binding,
1159             String operand
1160     ) {
1161         writer.openBlock("{", "}", () -> {
1162             writer.write("var err error");
1163             writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT);
1164             if (isHttpDateTimestamp(context.getModel(), binding.getLocation(), shape.getMember())) {
1165                 writer.write("$L, err = smithyhttp.SplitHTTPDateTimestampHeaderListValues($L)", operand, operand);
1166             } else {
1167                 writer.write("$L, err = smithyhttp.SplitHeaderListValues($L)", operand, operand);
1168             }
1169             writer.openBlock("if err != nil {", "}", () -> {
1170                 writer.write("return err");
1171             });
1172         });
1173     }
1174 
1175     @Override
generateSharedDeserializerComponents(GenerationContext context)1176     public void generateSharedDeserializerComponents(GenerationContext context) {
1177         deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error));
1178         deserializeDocumentBindingShapes.addAll(ProtocolUtils.resolveRequiredDocumentShapeSerde(
1179                 context.getModel(), deserializeDocumentBindingShapes));
1180         generateDocumentBodyShapeDeserializers(context, deserializeDocumentBindingShapes);
1181     }
1182 
1183     /**
1184      * Adds the top-level shapes from the operation that bind to the body document that require deserializer functions.
1185      *
1186      * @param context   the generator context
1187      * @param operation the operation to add document binders from
1188      */
addOperationDocumentShapeBindersForDeserializer(GenerationContext context, OperationShape operation)1189     private void addOperationDocumentShapeBindersForDeserializer(GenerationContext context, OperationShape operation) {
1190         Model model = context.getModel();
1191         HttpBindingIndex httpBindingIndex = model.getKnowledge(HttpBindingIndex.class);
1192         addDocumentDeserializerBindingShapes(model, httpBindingIndex, operation);
1193 
1194         for (ShapeId errorShapeId : operation.getErrors()) {
1195             addDocumentDeserializerBindingShapes(model, httpBindingIndex, errorShapeId);
1196         }
1197     }
1198 
addDocumentDeserializerBindingShapes(Model model, HttpBindingIndex index, ToShapeId shape)1199     private void addDocumentDeserializerBindingShapes(Model model, HttpBindingIndex index, ToShapeId shape) {
1200         // Walk and add members shapes to the list that will require deserializer functions
1201         for (HttpBinding binding : index.getResponseBindings(shape).values()) {
1202             Shape targetShape = model.expectShape(binding.getMember().getTarget());
1203             if (requiresDocumentSerdeFunction(targetShape)
1204                     && (binding.getLocation() == HttpBinding.Location.DOCUMENT
1205                     || binding.getLocation() == HttpBinding.Location.PAYLOAD)) {
1206                 deserializeDocumentBindingShapes.add(targetShape);
1207             }
1208         }
1209     }
1210 
1211     /**
1212      * Generates the operation document deserializer function.
1213      *
1214      * @param context   the generation context
1215      * @param operation the operation shape being generated
1216      */
generateOperationDocumentDeserializer(GenerationContext context, OperationShape operation)1217     protected abstract void generateOperationDocumentDeserializer(GenerationContext context, OperationShape operation);
1218 
1219     /**
1220      * Generates deserialization functions for shapes in the passed set. These functions
1221      * should return a value that can then be deserialized by the implementation of
1222      * {@code deserializeOutputDocument}.
1223      *
1224      * @param context The generation context.
1225      * @param shapes  The shapes to generate deserialization for.
1226      */
generateDocumentBodyShapeDeserializers(GenerationContext context, Set<Shape> shapes)1227     protected abstract void generateDocumentBodyShapeDeserializers(GenerationContext context, Set<Shape> shapes);
1228 
generateErrorDeserializer(GenerationContext context, StructureShape shape)1229     private void generateErrorDeserializer(GenerationContext context, StructureShape shape) {
1230         GoWriter writer = context.getWriter();
1231         String functionName = ProtocolGenerator.getErrorDeserFunctionName(shape, context.getProtocolName());
1232         Symbol responseType = getApplicationProtocol().getResponseType();
1233 
1234         writer.addUseImports(SmithyGoDependency.BYTES);
1235         writer.openBlock("func $L(response $P, errorBody *bytes.Reader) error {", "}",
1236                 functionName, responseType, () -> deserializeError(context, shape));
1237         writer.write("");
1238     }
1239 
1240     /**
1241      * Writes a function body that deserializes the given error.
1242      *
1243      * <p>Two parameters will be available in scope:
1244      * <ul>
1245      *   <li>{@code response: smithyhttp.HTTPResponse}: the HTTP response received.</li>
1246      *   <li>{@code errorBody: bytes.BytesReader}: the HTTP response body.</li>
1247      * </ul>
1248      *
1249      * @param context The generation context.
1250      * @param shape   The error shape.
1251      */
deserializeError(GenerationContext context, StructureShape shape)1252     protected abstract void deserializeError(GenerationContext context, StructureShape shape);
1253 
1254     /**
1255      * Converts the first letter and any letter following a hyphen to upper case. The remaining letters are lower cased.
1256      *
1257      * @param key the header
1258      * @return the canonical header
1259      */
getCanonicalHeader(String key)1260     private String getCanonicalHeader(String key) {
1261         char[] chars = key.toCharArray();
1262         boolean upper = true;
1263         for (int i = 0; i < chars.length; i++) {
1264             char c = chars[i];
1265             c = upper ? Character.toUpperCase(c) : Character.toLowerCase(c);
1266             chars[i] = c;
1267             upper = c == '-';
1268         }
1269         return new String(chars);
1270     }
1271 }
1272