1 /* 2 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"). 5 * You may not use this file except in compliance with the License. 6 * A copy of the License is located at 7 * 8 * http://aws.amazon.com/apache2.0 9 * 10 * or in the "license" file accompanying this file. This file is distributed 11 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 * express or implied. See the License for the specific language governing 13 * permissions and limitations under the License. 14 */ 15 16 package software.amazon.smithy.go.codegen.integration; 17 18 import static software.amazon.smithy.go.codegen.integration.HttpProtocolGeneratorUtils.isShapeWithResponseBindings; 19 import static software.amazon.smithy.go.codegen.integration.ProtocolUtils.requiresDocumentSerdeFunction; 20 21 import java.util.Collection; 22 import java.util.Comparator; 23 import java.util.List; 24 import java.util.Optional; 25 import java.util.Set; 26 import java.util.TreeSet; 27 import java.util.function.BiConsumer; 28 import java.util.logging.Logger; 29 import java.util.stream.Collectors; 30 import software.amazon.smithy.codegen.core.CodegenException; 31 import software.amazon.smithy.codegen.core.Symbol; 32 import software.amazon.smithy.codegen.core.SymbolProvider; 33 import software.amazon.smithy.go.codegen.ApplicationProtocol; 34 import software.amazon.smithy.go.codegen.CodegenUtils; 35 import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator; 36 import software.amazon.smithy.go.codegen.GoValueAccessUtils; 37 import software.amazon.smithy.go.codegen.GoWriter; 38 import software.amazon.smithy.go.codegen.SmithyGoDependency; 39 import software.amazon.smithy.go.codegen.SymbolUtils; 40 import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; 41 import software.amazon.smithy.go.codegen.trait.NoSerializeTrait; 42 import software.amazon.smithy.model.Model; 43 import software.amazon.smithy.model.knowledge.HttpBinding; 44 import software.amazon.smithy.model.knowledge.HttpBindingIndex; 45 import software.amazon.smithy.model.knowledge.TopDownIndex; 46 import software.amazon.smithy.model.shapes.CollectionShape; 47 import software.amazon.smithy.model.shapes.MapShape; 48 import software.amazon.smithy.model.shapes.MemberShape; 49 import software.amazon.smithy.model.shapes.OperationShape; 50 import software.amazon.smithy.model.shapes.Shape; 51 import software.amazon.smithy.model.shapes.ShapeId; 52 import software.amazon.smithy.model.shapes.ShapeType; 53 import software.amazon.smithy.model.shapes.StructureShape; 54 import software.amazon.smithy.model.shapes.ToShapeId; 55 import software.amazon.smithy.model.traits.EnumTrait; 56 import software.amazon.smithy.model.traits.HttpTrait; 57 import software.amazon.smithy.model.traits.MediaTypeTrait; 58 import software.amazon.smithy.model.traits.StreamingTrait; 59 import software.amazon.smithy.model.traits.TimestampFormatTrait; 60 import software.amazon.smithy.model.traits.TimestampFormatTrait.Format; 61 import software.amazon.smithy.utils.OptionalUtils; 62 63 64 /** 65 * Abstract implementation useful for all protocols that use HTTP bindings. 66 */ 67 public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator { 68 private static final Logger LOGGER = Logger.getLogger(HttpBindingProtocolGenerator.class.getName()); 69 70 private final boolean isErrorCodeInBody; 71 private final Set<Shape> serializeDocumentBindingShapes = new TreeSet<>(); 72 private final Set<Shape> deserializeDocumentBindingShapes = new TreeSet<>(); 73 private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>(); 74 75 /** 76 * Creates a Http binding protocol generator. 77 * 78 * @param isErrorCodeInBody A boolean that indicates if the error code for the implementing protocol is located in 79 * the error response body, meaning this generator will parse the body before attempting to 80 * load an error code. 81 */ HttpBindingProtocolGenerator(boolean isErrorCodeInBody)82 public HttpBindingProtocolGenerator(boolean isErrorCodeInBody) { 83 this.isErrorCodeInBody = isErrorCodeInBody; 84 } 85 86 @Override getApplicationProtocol()87 public ApplicationProtocol getApplicationProtocol() { 88 return ApplicationProtocol.createDefaultHttpApplicationProtocol(); 89 } 90 91 @Override generateSharedSerializerComponents(GenerationContext context)92 public void generateSharedSerializerComponents(GenerationContext context) { 93 serializeDocumentBindingShapes.addAll(ProtocolUtils.resolveRequiredDocumentShapeSerde( 94 context.getModel(), serializeDocumentBindingShapes)); 95 generateDocumentBodyShapeSerializers(context, serializeDocumentBindingShapes); 96 } 97 98 /** 99 * Get the operations with HTTP Bindings. 100 * 101 * @param context the generation context 102 * @return the list of operation shapes 103 */ getHttpBindingOperations(GenerationContext context)104 public Set<OperationShape> getHttpBindingOperations(GenerationContext context) { 105 TopDownIndex topDownIndex = context.getModel().getKnowledge(TopDownIndex.class); 106 107 Set<OperationShape> containedOperations = new TreeSet<>(); 108 for (OperationShape operation : topDownIndex.getContainedOperations(context.getService())) { 109 OptionalUtils.ifPresentOrElse( 110 operation.getTrait(HttpTrait.class), 111 httpTrait -> containedOperations.add(operation), 112 () -> LOGGER.warning(String.format( 113 "Unable to fetch %s protocol request bindings for %s because it does not have an " 114 + "http binding trait", getProtocol(), operation.getId())) 115 ); 116 } 117 return containedOperations; 118 } 119 120 @Override generateRequestSerializers(GenerationContext context)121 public void generateRequestSerializers(GenerationContext context) { 122 for (OperationShape operation : getHttpBindingOperations(context)) { 123 generateOperationSerializer(context, operation); 124 } 125 } 126 127 /** 128 * Gets the default serde format for timestamps. 129 * 130 * @return Returns the default format. 131 */ getDocumentTimestampFormat()132 protected abstract Format getDocumentTimestampFormat(); 133 134 /** 135 * Gets the default content-type when a document is synthesized in the body. 136 * 137 * @return Returns the default content-type. 138 */ getDocumentContentType()139 protected abstract String getDocumentContentType(); 140 generateOperationSerializer(GenerationContext context, OperationShape operation)141 private void generateOperationSerializer(GenerationContext context, OperationShape operation) { 142 generateOperationSerializerMiddleware(context, operation); 143 generateOperationHttpBindingSerializer(context, operation); 144 145 if (!CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectInput(context.getModel(), operation))) { 146 generateOperationDocumentSerializer(context, operation); 147 addOperationDocumentShapeBindersForSerializer(context, operation); 148 } 149 } 150 151 /** 152 * Generates the operation document serializer function. 153 * 154 * @param context the generation context 155 * @param operation the operation shape being generated 156 */ generateOperationDocumentSerializer(GenerationContext context, OperationShape operation)157 protected abstract void generateOperationDocumentSerializer(GenerationContext context, OperationShape operation); 158 159 /** 160 * Adds the top-level shapes from the operation that bind to the body document that require serializer functions. 161 * 162 * @param context the generator context 163 * @param operation the operation to add document binders from 164 */ addOperationDocumentShapeBindersForSerializer(GenerationContext context, OperationShape operation)165 private void addOperationDocumentShapeBindersForSerializer(GenerationContext context, OperationShape operation) { 166 Model model = context.getModel(); 167 168 // Walk and add members shapes to the list that will require serializer functions 169 Collection<HttpBinding> bindings = model.getKnowledge(HttpBindingIndex.class) 170 .getRequestBindings(operation).values(); 171 172 for (HttpBinding binding : bindings) { 173 Shape targetShape = model.expectShape(binding.getMember().getTarget()); 174 // Check if the input shape has a members that target the document or payload and require serializers 175 if (requiresDocumentSerdeFunction(targetShape) 176 && (binding.getLocation() == HttpBinding.Location.DOCUMENT 177 || binding.getLocation() == HttpBinding.Location.PAYLOAD)) { 178 serializeDocumentBindingShapes.add(targetShape); 179 } 180 } 181 } 182 generateOperationSerializerMiddleware(GenerationContext context, OperationShape operation)183 private void generateOperationSerializerMiddleware(GenerationContext context, OperationShape operation) { 184 GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createSerializeStepMiddleware( 185 ProtocolGenerator.getSerializeMiddlewareName(operation.getId(), getProtocolName()), 186 ProtocolUtils.OPERATION_SERIALIZER_MIDDLEWARE_ID); 187 188 SymbolProvider symbolProvider = context.getSymbolProvider(); 189 Model model = context.getModel(); 190 Shape inputShape = model.expectShape(operation.getInput() 191 .orElseThrow(() -> new CodegenException("expect input shape for operation: " + operation.getId()))); 192 Symbol inputSymbol = symbolProvider.toSymbol(inputShape); 193 ApplicationProtocol applicationProtocol = getApplicationProtocol(); 194 Symbol requestType = applicationProtocol.getRequestType(); 195 HttpTrait httpTrait = operation.expectTrait(HttpTrait.class); 196 197 middleware.writeMiddleware(context.getWriter(), (generator, writer) -> { 198 writer.addUseImports(SmithyGoDependency.FMT); 199 writer.addUseImports(SmithyGoDependency.SMITHY); 200 writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_BINDING); 201 202 // cast input request to smithy transport type, check for failures 203 writer.write("request, ok := in.Request.($P)", requestType); 204 writer.openBlock("if !ok {", "}", () -> { 205 writer.write("return out, metadata, " 206 + "&smithy.SerializationError{Err: fmt.Errorf(\"unknown transport type %T\", in.Request)}"); 207 }); 208 writer.write(""); 209 210 // cast input parameters type to the input type of the operation 211 writer.write("input, ok := in.Parameters.($P)", inputSymbol); 212 writer.write("_ = input"); 213 writer.openBlock("if !ok {", "}", () -> { 214 writer.write("return out, metadata, " 215 + "&smithy.SerializationError{Err: fmt.Errorf(\"unknown input parameters type %T\"," 216 + " in.Parameters)}"); 217 }); 218 219 writer.write(""); 220 writer.write("opPath, opQuery := httpbinding.SplitURI($S)", httpTrait.getUri()); 221 writer.write("request.URL.Path = smithyhttp.JoinPath(request.URL.Path, opPath)"); 222 writer.write("request.URL.RawQuery = smithyhttp.JoinRawQuery(request.URL.RawQuery, opQuery)"); 223 writer.write("request.Method = $S", httpTrait.getMethod()); 224 writer.write("restEncoder, err := httpbinding.NewEncoder(request.URL.Path, request.URL.RawQuery, " 225 + "request.Header)"); 226 writer.openBlock("if err != nil {", "}", () -> { 227 writer.write("return out, metadata, &smithy.SerializationError{Err: err}"); 228 }); 229 writer.write(""); 230 231 // we only generate an operations http bindings function if there are bindings 232 if (isOperationWithRestRequestBindings(model, operation)) { 233 String serFunctionName = ProtocolGenerator.getOperationHttpBindingsSerFunctionName(inputShape, 234 getProtocolName()); 235 writer.openBlock("if err := $L(input, restEncoder); err != nil {", "}", serFunctionName, () -> { 236 writer.write("return out, metadata, &smithy.SerializationError{Err: err}"); 237 }); 238 writer.write(""); 239 } 240 241 // Don't consider serializing the body if the input shape is a stubbed synthetic clone, without an 242 // archetype. 243 if (!CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectInput(model, operation))) { 244 // document bindings vs payload bindings 245 HttpBindingIndex httpBindingIndex = model.getKnowledge(HttpBindingIndex.class); 246 boolean hasDocumentBindings = !httpBindingIndex 247 .getRequestBindings(operation, HttpBinding.Location.DOCUMENT) 248 .isEmpty(); 249 Optional<HttpBinding> payloadBinding = httpBindingIndex.getRequestBindings(operation, 250 HttpBinding.Location.PAYLOAD).stream().findFirst(); 251 252 if (hasDocumentBindings) { 253 // delegate the setup and usage of the document serializer function for the protocol 254 writeMiddlewareDocumentSerializerDelegator(context, operation, generator); 255 256 } else if (payloadBinding.isPresent()) { 257 // delegate the setup and usage of the payload serializer function for the protocol 258 MemberShape memberShape = payloadBinding.get().getMember(); 259 writeMiddlewarePayloadSerializerDelegator(context, memberShape); 260 } 261 writer.write(""); 262 } 263 264 // Serialize HTTP request with payload, if set. 265 writer.openBlock("if request.Request, err = restEncoder.Encode(request.Request); err != nil {", "}", () -> { 266 writer.write("return out, metadata, &smithy.SerializationError{Err: err}"); 267 }); 268 writer.write("in.Request = request"); 269 writer.write(""); 270 271 writer.write("return next.$L(ctx, in)", generator.getHandleMethodName()); 272 }); 273 } 274 275 // Generates operation deserializer middleware that delegates to appropriate deserializers for the error, 276 // output shapes for the operation. generateOperationDeserializerMiddleware(GenerationContext context, OperationShape operation)277 private void generateOperationDeserializerMiddleware(GenerationContext context, OperationShape operation) { 278 GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware( 279 ProtocolGenerator.getDeserializeMiddlewareName(operation.getId(), getProtocolName()), 280 ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID); 281 282 SymbolProvider symbolProvider = context.getSymbolProvider(); 283 Model model = context.getModel(); 284 285 ApplicationProtocol applicationProtocol = getApplicationProtocol(); 286 Symbol responseType = applicationProtocol.getResponseType(); 287 GoWriter goWriter = context.getWriter(); 288 289 String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName( 290 operation, context.getProtocolName()); 291 292 middleware.writeMiddleware(goWriter, (generator, writer) -> { 293 writer.addUseImports(SmithyGoDependency.FMT); 294 295 writer.write("out, metadata, err = next.$L(ctx, in)", generator.getHandleMethodName()); 296 writer.write("if err != nil { return out, metadata, err }"); 297 writer.write(""); 298 299 writer.write("response, ok := out.RawResponse.($P)", responseType); 300 writer.openBlock("if !ok {", "}", () -> { 301 writer.addUseImports(SmithyGoDependency.SMITHY); 302 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}", 303 "fmt.Errorf(\"unknown transport type %T\", out.RawResponse)")); 304 }); 305 writer.write(""); 306 307 writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> { 308 writer.write("return out, metadata, $L(response, &metadata)", errorFunctionName); 309 }); 310 311 Shape outputShape = model.expectShape(operation.getOutput() 312 .orElseThrow(() -> new CodegenException("expect output shape for operation: " + operation.getId())) 313 ); 314 315 Symbol outputSymbol = symbolProvider.toSymbol(outputShape); 316 317 // initialize out.Result as output structure shape 318 writer.write("output := &$T{}", outputSymbol); 319 writer.write("out.Result = output"); 320 writer.write(""); 321 322 // Output shape HTTP binding middleware generation 323 if (isShapeWithRestResponseBindings(model, operation)) { 324 String deserFuncName = ProtocolGenerator.getOperationHttpBindingsDeserFunctionName( 325 outputShape, getProtocolName()); 326 327 writer.write("err= $L(output, response)", deserFuncName); 328 writer.openBlock("if err != nil {", "}", () -> { 329 writer.addUseImports(SmithyGoDependency.SMITHY); 330 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}", 331 "fmt.Errorf(\"failed to decode response with invalid Http bindings, %w\", err)")); 332 }); 333 writer.write(""); 334 } 335 336 // Discard without deserializing the response if the input shape is a stubbed synthetic clone 337 // without an archetype. 338 if (CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectOutput(model, operation))) { 339 writer.addUseImports(SmithyGoDependency.IOUTIL); 340 writer.openBlock("if _, err = io.Copy(ioutil.Discard, response.Body); err != nil {", "}", () -> { 341 writer.openBlock("return out, metadata, &smithy.DeserializationError{", "}", () -> { 342 writer.write("Err: fmt.Errorf(\"failed to discard response body, %w\", err),"); 343 }); 344 }); 345 } else if (isShapeWithResponseBindings(model, operation, HttpBinding.Location.DOCUMENT) 346 || isShapeWithResponseBindings(model, operation, HttpBinding.Location.PAYLOAD)) { 347 // Output Shape Document Binding middleware generation 348 writeMiddlewareDocumentDeserializerDelegator(context, operation, generator); 349 } 350 writer.write(""); 351 352 writer.write("return out, metadata, err"); 353 }); 354 goWriter.write(""); 355 356 Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher( 357 context, operation, responseType, this::writeErrorMessageCodeDeserializer); 358 deserializingErrorShapes.addAll(errorShapes); 359 deserializeDocumentBindingShapes.addAll(errorShapes); 360 } 361 362 /** 363 * Writes a code snippet that gets the error code and error message. 364 * 365 * <p>Four parameters will be available in scope: 366 * <ul> 367 * <li>{@code response: smithyhttp.HTTPResponse}: the HTTP response received.</li> 368 * <li>{@code errorBody: bytes.BytesReader}: the HTTP response body.</li> 369 * <li>{@code errorMessage: string}: the error message initialized to a default value.</li> 370 * <li>{@code errorCode: string}: the error code initialized to a default value.</li> 371 * </ul> 372 * 373 * @param context the generation context. 374 */ writeErrorMessageCodeDeserializer(GenerationContext context)375 protected abstract void writeErrorMessageCodeDeserializer(GenerationContext context); 376 377 /** 378 * Generate the document serializer logic for the serializer middleware body. 379 * 380 * @param context the generation context 381 * @param operation the operation 382 * @param generator middleware generator definition 383 */ writeMiddlewareDocumentSerializerDelegator( GenerationContext context, OperationShape operation, GoStackStepMiddlewareGenerator generator )384 protected abstract void writeMiddlewareDocumentSerializerDelegator( 385 GenerationContext context, 386 OperationShape operation, 387 GoStackStepMiddlewareGenerator generator 388 ); 389 390 /** 391 * Generate the payload serializer logic for the serializer middleware body. 392 * 393 * @param context the generation context 394 * @param memberShape the payload target member 395 */ writeMiddlewarePayloadSerializerDelegator( GenerationContext context, MemberShape memberShape )396 protected void writeMiddlewarePayloadSerializerDelegator( 397 GenerationContext context, 398 MemberShape memberShape 399 ) { 400 GoWriter writer = context.getWriter(); 401 Model model = context.getModel(); 402 Shape payloadShape = model.expectShape(memberShape.getTarget()); 403 404 GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, 405 memberShape, "input", (s) -> { 406 writer.openBlock("if !restEncoder.HasHeader(\"Content-Type\") {", "}", () -> { 407 writer.write("restEncoder.SetHeader(\"Content-Type\").String($S)", 408 getPayloadShapeMediaType(payloadShape)); 409 }); 410 writer.write(""); 411 412 if (payloadShape.hasTrait(StreamingTrait.class)) { 413 writer.write("payload := $L", s); 414 415 } else if (payloadShape.isBlobShape()) { 416 writer.addUseImports(SmithyGoDependency.BYTES); 417 writer.write("payload := bytes.NewReader($L)", s); 418 419 } else if (payloadShape.isStringShape()) { 420 writer.addUseImports(SmithyGoDependency.STRINGS); 421 writer.write("payload := strings.NewReader(*$L)", s); 422 423 } else { 424 writeMiddlewarePayloadAsDocumentSerializerDelegator(context, memberShape, s); 425 } 426 427 writer.openBlock("if request, err = request.SetStream(payload); err != nil {", "}", 428 () -> { 429 writer.write("return out, metadata, &smithy.SerializationError{Err: err}"); 430 }); 431 }); 432 } 433 434 /** 435 * Returns the MediaType for the payload shape derived from the MediaTypeTrait, shape type, or 436 * document content type. 437 * 438 * @param payloadShape shape bound to the payload. 439 * @return string for media type. 440 */ getPayloadShapeMediaType(Shape payloadShape)441 private String getPayloadShapeMediaType(Shape payloadShape) { 442 Optional<MediaTypeTrait> mediaTypeTrait = payloadShape.getTrait(MediaTypeTrait.class); 443 444 if (mediaTypeTrait.isPresent()) { 445 return mediaTypeTrait.get().getValue(); 446 } 447 448 if (payloadShape.isBlobShape()) { 449 return "application/octet-stream"; 450 } 451 452 if (payloadShape.isStringShape()) { 453 return "text/plain"; 454 } 455 456 return getDocumentContentType(); 457 } 458 459 /** 460 * Generate the payload serializers with document serializer logic for the serializer middleware body. 461 * 462 * @param context the generation context 463 * @param memberShape the payload target member 464 * @param operand the operand that is used to access the member value 465 */ writeMiddlewarePayloadAsDocumentSerializerDelegator( GenerationContext context, MemberShape memberShape, String operand )466 protected abstract void writeMiddlewarePayloadAsDocumentSerializerDelegator( 467 GenerationContext context, 468 MemberShape memberShape, 469 String operand 470 ); 471 472 /** 473 * Generate the document deserializer logic for the deserializer middleware body. 474 * 475 * @param context the generation context 476 * @param operation the operation 477 * @param generator middleware generator definition 478 */ writeMiddlewareDocumentDeserializerDelegator( GenerationContext context, OperationShape operation, GoStackStepMiddlewareGenerator generator )479 protected abstract void writeMiddlewareDocumentDeserializerDelegator( 480 GenerationContext context, 481 OperationShape operation, 482 GoStackStepMiddlewareGenerator generator 483 ); 484 isRestBinding(HttpBinding.Location location)485 private boolean isRestBinding(HttpBinding.Location location) { 486 return location == HttpBinding.Location.HEADER 487 || location == HttpBinding.Location.PREFIX_HEADERS 488 || location == HttpBinding.Location.LABEL 489 || location == HttpBinding.Location.QUERY 490 || location == HttpBinding.Location.RESPONSE_CODE; 491 } 492 493 // returns whether an operation shape has Rest Request Bindings isOperationWithRestRequestBindings(Model model, OperationShape operationShape)494 private boolean isOperationWithRestRequestBindings(Model model, OperationShape operationShape) { 495 Collection<HttpBinding> bindings = model.getKnowledge(HttpBindingIndex.class) 496 .getRequestBindings(operationShape).values(); 497 498 for (HttpBinding binding : bindings) { 499 if (isRestBinding(binding.getLocation())) { 500 return true; 501 } 502 } 503 504 return false; 505 } 506 507 /** 508 * Returns whether a shape has rest response bindings. 509 * The shape can be an operation shape, error shape or an output shape. 510 * 511 * @param model the model 512 * @param shape the shape with possible presence of rest response bindings 513 * @return boolean indicating presence of rest response bindings in the shape 514 */ isShapeWithRestResponseBindings(Model model, Shape shape)515 protected boolean isShapeWithRestResponseBindings(Model model, Shape shape) { 516 Collection<HttpBinding> bindings = model.getKnowledge(HttpBindingIndex.class) 517 .getResponseBindings(shape).values(); 518 519 for (HttpBinding binding : bindings) { 520 if (isRestBinding(binding.getLocation())) { 521 return true; 522 } 523 } 524 return false; 525 } 526 generateOperationHttpBindingSerializer(GenerationContext context, OperationShape operation)527 private void generateOperationHttpBindingSerializer(GenerationContext context, OperationShape operation) { 528 SymbolProvider symbolProvider = context.getSymbolProvider(); 529 Model model = context.getModel(); 530 GoWriter writer = context.getWriter(); 531 532 Shape inputShape = model.expectShape(operation.getInput() 533 .orElseThrow(() -> new CodegenException("missing input shape for operation: " + operation.getId()))); 534 535 HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class); 536 List<HttpBinding> bindings = bindingIndex.getRequestBindings(operation).values().stream() 537 .filter(httpBinding -> isRestBinding(httpBinding.getLocation())) 538 .sorted(Comparator.comparing(HttpBinding::getMember)) 539 .collect(Collectors.toList()); 540 541 Symbol httpBindingEncoder = getHttpBindingEncoderSymbol(); 542 Symbol inputSymbol = symbolProvider.toSymbol(inputShape); 543 String functionName = ProtocolGenerator.getOperationHttpBindingsSerFunctionName(inputShape, getProtocolName()); 544 545 writer.addUseImports(SmithyGoDependency.FMT); 546 writer.openBlock("func $L(v $P, encoder $P) error {", "}", functionName, inputSymbol, httpBindingEncoder, 547 () -> { 548 writer.openBlock("if v == nil {", "}", () -> { 549 writer.write("return fmt.Errorf(\"unsupported serialization of nil %T\", v)"); 550 }); 551 552 writer.write(""); 553 554 for (HttpBinding binding : bindings.stream() 555 .filter(NoSerializeTrait.excludeNoSerializeHttpBindingMembers()) 556 .collect(Collectors.toList())) { 557 558 writeHttpBindingMember(context, binding); 559 writer.write(""); 560 } 561 writer.write("return nil"); 562 }); 563 writer.write(""); 564 } 565 getHttpBindingEncoderSymbol()566 private Symbol getHttpBindingEncoderSymbol() { 567 return SymbolUtils.createPointableSymbolBuilder("Encoder", SmithyGoDependency.SMITHY_HTTP_BINDING).build(); 568 } 569 generateHttpBindingTimestampSerializer( Model model, GoWriter writer, MemberShape memberShape, HttpBinding.Location location, String operand, BiConsumer<GoWriter, String> locationEncoder )570 private void generateHttpBindingTimestampSerializer( 571 Model model, 572 GoWriter writer, 573 MemberShape memberShape, 574 HttpBinding.Location location, 575 String operand, 576 BiConsumer<GoWriter, String> locationEncoder 577 ) { 578 writer.addUseImports(SmithyGoDependency.SMITHY_TIME); 579 580 TimestampFormatTrait.Format format = model.getKnowledge(HttpBindingIndex.class).determineTimestampFormat( 581 memberShape, location, getDocumentTimestampFormat()); 582 583 switch (format) { 584 case DATE_TIME: 585 locationEncoder.accept(writer, "String(smithytime.FormatDateTime(" + operand + "))"); 586 break; 587 case HTTP_DATE: 588 locationEncoder.accept(writer, "String(smithytime.FormatHTTPDate(" + operand + "))"); 589 break; 590 case EPOCH_SECONDS: 591 locationEncoder.accept(writer, "Double(smithytime.FormatEpochSeconds(" + operand + "))"); 592 break; 593 default: 594 throw new CodegenException("Unknown timestamp format"); 595 } 596 } 597 isHttpDateTimestamp(Model model, HttpBinding.Location location, MemberShape memberShape)598 private boolean isHttpDateTimestamp(Model model, HttpBinding.Location location, MemberShape memberShape) { 599 Shape targetShape = model.expectShape(memberShape.getTarget().toShapeId()); 600 if (targetShape.getType() != ShapeType.TIMESTAMP) { 601 return false; 602 } 603 604 TimestampFormatTrait.Format format = HttpBindingIndex.of(model).determineTimestampFormat( 605 memberShape, location, getDocumentTimestampFormat()); 606 607 return format == Format.HTTP_DATE; 608 } 609 writeHttpBindingSetter( GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding.Location location, String operand, BiConsumer<GoWriter, String> locationEncoder )610 private void writeHttpBindingSetter( 611 GenerationContext context, 612 GoWriter writer, 613 MemberShape memberShape, 614 HttpBinding.Location location, 615 String operand, 616 BiConsumer<GoWriter, String> locationEncoder 617 ) { 618 Model model = context.getModel(); 619 Shape targetShape = model.expectShape(memberShape.getTarget()); 620 621 // We only need to dereference if we pass the shape around as reference in Go. 622 // Note we make two exceptions here: big.Int and big.Float should still be passed as reference to the helper 623 // method as they can be arbitrarily large. 624 operand = CodegenUtils.getAsValueIfDereferencable(GoPointableIndex.of(context.getModel()), memberShape, 625 operand); 626 627 switch (targetShape.getType()) { 628 case BOOLEAN: 629 locationEncoder.accept(writer, "Boolean(" + operand + ")"); 630 break; 631 case STRING: 632 operand = targetShape.hasTrait(EnumTrait.class) ? "string(" + operand + ")" : operand; 633 locationEncoder.accept(writer, "String(" + operand + ")"); 634 break; 635 case TIMESTAMP: 636 generateHttpBindingTimestampSerializer(model, writer, memberShape, location, operand, locationEncoder); 637 break; 638 case BYTE: 639 locationEncoder.accept(writer, "Byte(" + operand + ")"); 640 break; 641 case SHORT: 642 locationEncoder.accept(writer, "Short(" + operand + ")"); 643 break; 644 case INTEGER: 645 locationEncoder.accept(writer, "Integer(" + operand + ")"); 646 break; 647 case LONG: 648 locationEncoder.accept(writer, "Long(" + operand + ")"); 649 break; 650 case FLOAT: 651 locationEncoder.accept(writer, "Float(" + operand + ")"); 652 break; 653 case DOUBLE: 654 locationEncoder.accept(writer, "Double(" + operand + ")"); 655 break; 656 case BIG_INTEGER: 657 locationEncoder.accept(writer, "BigInteger(" + operand + ")"); 658 break; 659 case BIG_DECIMAL: 660 locationEncoder.accept(writer, "BigDecimal(" + operand + ")"); 661 break; 662 default: 663 throw new CodegenException("unexpected shape type " + targetShape.getType()); 664 } 665 } 666 writeHttpBindingMember( GenerationContext context, HttpBinding binding )667 private void writeHttpBindingMember( 668 GenerationContext context, 669 HttpBinding binding 670 ) { 671 GoWriter writer = context.getWriter(); 672 Model model = context.getModel(); 673 MemberShape memberShape = binding.getMember(); 674 Shape targetShape = model.expectShape(memberShape.getTarget()); 675 HttpBinding.Location location = binding.getLocation(); 676 677 // return an error if member shape targets location label, but is unset. 678 if (location.equals(HttpBinding.Location.LABEL)) { 679 // labels must always be set to be serialized on URI, and non empty strings, 680 GoValueAccessUtils.writeIfZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, 681 memberShape, "v", false, true, operand -> { 682 writer.addUseImports(SmithyGoDependency.SMITHY); 683 writer.write("return &smithy.SerializationError { " 684 + "Err: fmt.Errorf(\"input member $L must not be empty\")}", 685 memberShape.getMemberName()); 686 }); 687 } 688 689 boolean allowZeroStrings = location != HttpBinding.Location.HEADER; 690 691 GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, 692 memberShape, "v", allowZeroStrings, memberShape.isRequired(), (operand) -> { 693 final String locationName = binding.getLocationName().isEmpty() 694 ? memberShape.getMemberName() : binding.getLocationName(); 695 switch (location) { 696 case HEADER: 697 writer.write("locationName := $S", getCanonicalHeader(locationName)); 698 writeHeaderBinding(context, memberShape, operand, location, "locationName", "encoder"); 699 break; 700 case PREFIX_HEADERS: 701 MemberShape valueMemberShape = model.expectShape(targetShape.getId(), 702 MapShape.class).getValue(); 703 Shape valueMemberTarget = model.expectShape(valueMemberShape.getTarget()); 704 705 if (targetShape.getType() != ShapeType.MAP) { 706 throw new CodegenException("Unexpected prefix headers target shape " 707 + valueMemberTarget.getType() + ", " + valueMemberShape.getId()); 708 } 709 710 writer.write("hv := encoder.Headers($S)", getCanonicalHeader(locationName)); 711 writer.addUseImports(SmithyGoDependency.NET_HTTP); 712 writer.openBlock("for mapKey, mapVal := range $L {", "}", operand, () -> { 713 GoValueAccessUtils.writeIfNonZeroValue(context.getModel(), writer, valueMemberShape, 714 "mapVal", false, false, () -> { 715 writeHeaderBinding(context, valueMemberShape, "mapVal", location, 716 "http.CanonicalHeaderKey(mapKey)", "hv"); 717 }); 718 }); 719 break; 720 case LABEL: 721 writeHttpBindingSetter(context, writer, memberShape, location, operand, (w, s) -> { 722 w.openBlock("if err := encoder.SetURI($S).$L; err != nil {", "}", locationName, s, 723 () -> { 724 w.write("return err"); 725 }); 726 }); 727 break; 728 case QUERY: 729 if (targetShape instanceof CollectionShape) { 730 MemberShape collectionMember = CodegenUtils.expectCollectionShape(targetShape) 731 .getMember(); 732 writer.openBlock("for i := range $L {", "}", operand, () -> { 733 GoValueAccessUtils.writeIfZeroValue(context.getModel(), writer, collectionMember, 734 operand + "[i]", () -> { 735 writer.write("continue"); 736 }); 737 writeHttpBindingSetter(context, writer, collectionMember, location, operand + "[i]", 738 (w, s) -> { 739 w.writeInline("encoder.AddQuery($S).$L", locationName, s); 740 }); 741 }); 742 } else { 743 writeHttpBindingSetter(context, writer, memberShape, location, operand, 744 (w, s) -> w.writeInline( 745 "encoder.SetQuery($S).$L", locationName, s)); 746 } 747 break; 748 default: 749 throw new CodegenException("unexpected http binding found"); 750 } 751 }); 752 } 753 writeHeaderBinding( GenerationContext context, MemberShape memberShape, String operand, HttpBinding.Location location, String locationName, String dest )754 private void writeHeaderBinding( 755 GenerationContext context, 756 MemberShape memberShape, 757 String operand, 758 HttpBinding.Location location, 759 String locationName, 760 String dest 761 ) { 762 GoWriter writer = context.getWriter(); 763 Model model = context.getModel(); 764 Shape targetShape = model.expectShape(memberShape.getTarget()); 765 766 if (!(targetShape instanceof CollectionShape)) { 767 String op = conditionallyBase64Encode(writer, targetShape, operand); 768 writeHttpBindingSetter(context, writer, memberShape, location, op, (w, s) -> { 769 w.writeInline("$L.SetHeader($L).$L", dest, locationName, s); 770 }); 771 return; 772 } 773 774 MemberShape collectionMemberShape = CodegenUtils.expectCollectionShape(targetShape).getMember(); 775 writer.openBlock("for i := range $L {", "}", operand, () -> { 776 // Only set non-empty non-nil header values 777 String indexedOperand = operand + "[i]"; 778 GoValueAccessUtils.writeIfNonZeroValue(context.getModel(), writer, collectionMemberShape, indexedOperand, 779 false, false, () -> { 780 String op = conditionallyBase64Encode(writer, targetShape, indexedOperand); 781 writeHttpBindingSetter(context, writer, collectionMemberShape, location, op, (w, s) -> { 782 w.writeInline("$L.AddHeader($L).$L", dest, locationName, s); 783 }); 784 }); 785 }); 786 } 787 conditionallyBase64Encode(GoWriter writer, Shape targetShape, String operand)788 private String conditionallyBase64Encode(GoWriter writer, Shape targetShape, String operand) { 789 // MediaType strings written to headers must be base64 encoded 790 if (targetShape.isStringShape() && targetShape.hasTrait(MediaTypeTrait.class)) { 791 writer.addUseImports(SmithyGoDependency.SMITHY_PTR); 792 writer.addUseImports(SmithyGoDependency.BASE64); 793 writer.write("encoded := ptr.String(base64.StdEncoding.EncodeToString([]byte(*$L)))", operand); 794 return "encoded"; 795 } 796 return operand; 797 } 798 799 /** 800 * Generates serialization functions for shapes in the passed set. These functions 801 * should return a value that can then be serialized by the implementation of 802 * {@code serializeInputDocument}. 803 * 804 * @param context The generation context. 805 * @param shapes The shapes to generate serialization for. 806 */ generateDocumentBodyShapeSerializers(GenerationContext context, Set<Shape> shapes)807 protected abstract void generateDocumentBodyShapeSerializers(GenerationContext context, Set<Shape> shapes); 808 809 @Override generateResponseDeserializers(GenerationContext context)810 public void generateResponseDeserializers(GenerationContext context) { 811 for (OperationShape operation : getHttpBindingOperations(context)) { 812 generateOperationDeserializerMiddleware(context, operation); 813 generateHttpBindingDeserializer(context, operation); 814 815 if (!CodegenUtils.isStubSyntheticClone(ProtocolUtils.expectOutput(context.getModel(), operation))) { 816 generateOperationDocumentDeserializer(context, operation); 817 addOperationDocumentShapeBindersForDeserializer(context, operation); 818 } 819 } 820 821 for (StructureShape error : deserializingErrorShapes) { 822 generateHttpBindingDeserializer(context, error); 823 } 824 } 825 826 // Generates Http Binding shape deserializer function. generateHttpBindingDeserializer(GenerationContext context, Shape shape)827 private void generateHttpBindingDeserializer(GenerationContext context, Shape shape) { 828 SymbolProvider symbolProvider = context.getSymbolProvider(); 829 Model model = context.getModel(); 830 GoWriter writer = context.getWriter(); 831 832 HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class); 833 List<HttpBinding> bindings = bindingIndex.getResponseBindings(shape).values().stream() 834 .filter(binding -> isRestBinding(binding.getLocation())) 835 .sorted(Comparator.comparing(HttpBinding::getMember)) 836 .collect(Collectors.toList()); 837 838 // Don't generate anything if there are no bindings. 839 if (bindings.size() == 0) { 840 return; 841 } 842 843 Shape targetShape = shape; 844 if (shape.isOperationShape()) { 845 targetShape = ProtocolUtils.expectOutput(model, shape.asOperationShape().get()); 846 } 847 848 Symbol targetSymbol = symbolProvider.toSymbol(targetShape); 849 Symbol smithyHttpResponsePointableSymbol = SymbolUtils.createPointableSymbolBuilder( 850 "Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build(); 851 852 writer.addUseImports(SmithyGoDependency.FMT); 853 854 String functionName = ProtocolGenerator.getOperationHttpBindingsDeserFunctionName(targetShape, 855 getProtocolName()); 856 writer.openBlock("func $L(v $P, response $P) error {", "}", functionName, targetSymbol, 857 smithyHttpResponsePointableSymbol, () -> { 858 writer.openBlock("if v == nil {", "}", () -> { 859 writer.write("return fmt.Errorf(\"unsupported deserialization for nil %T\", v)"); 860 }); 861 writer.write(""); 862 863 for (HttpBinding binding : bindings) { 864 writeRestDeserializerMember(context, writer, binding); 865 writer.write(""); 866 } 867 writer.write("return nil"); 868 }); 869 } 870 generateHttpHeaderValue( GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding binding, String operand )871 private String generateHttpHeaderValue( 872 GenerationContext context, 873 GoWriter writer, 874 MemberShape memberShape, 875 HttpBinding binding, 876 String operand 877 ) { 878 Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); 879 880 if (targetShape.getType() != ShapeType.LIST && targetShape.getType() != ShapeType.SET) { 881 writer.addUseImports(SmithyGoDependency.STRINGS); 882 writer.write("$L = strings.TrimSpace($L)", operand, operand); 883 } 884 885 String value = ""; 886 switch (targetShape.getType()) { 887 case STRING: 888 if (targetShape.hasTrait(EnumTrait.class)) { 889 value = String.format("types.%s(%s)", targetShape.getId().getName(), operand); 890 return value; 891 } 892 // MediaType strings must be base-64 encoded when sent in headers. 893 if (targetShape.hasTrait(MediaTypeTrait.class)) { 894 writer.addUseImports(SmithyGoDependency.BASE64); 895 writer.write("b, err := base64.StdEncoding.DecodeString($L)", operand); 896 writer.write("if err != nil { return err }"); 897 return "string(b)"; 898 } 899 return operand; 900 case BOOLEAN: 901 writer.addUseImports(SmithyGoDependency.STRCONV); 902 writer.write("vv, err := strconv.ParseBool($L)", operand); 903 writer.write("if err != nil { return err }"); 904 return "vv"; 905 case TIMESTAMP: 906 writer.addUseImports(SmithyGoDependency.SMITHY_TIME); 907 HttpBindingIndex bindingIndex = context.getModel().getKnowledge(HttpBindingIndex.class); 908 TimestampFormatTrait.Format format = bindingIndex.determineTimestampFormat( 909 memberShape, 910 binding.getLocation(), 911 Format.HTTP_DATE 912 ); 913 switch (format) { 914 case EPOCH_SECONDS: 915 writer.addUseImports(SmithyGoDependency.STRCONV); 916 writer.write("f, err := strconv.ParseFloat($L, 64)", operand); 917 writer.write("if err != nil { return err }"); 918 writer.write("t := smithytime.ParseEpochSeconds(f)"); 919 break; 920 case HTTP_DATE: 921 writer.write("t, err := smithytime.ParseHTTPDate($L)", operand); 922 writer.write("if err != nil { return err }"); 923 break; 924 case DATE_TIME: 925 writer.write("t, err := smithytime.ParseDateTime($L)", operand); 926 writer.write("if err != nil { return err }"); 927 break; 928 default: 929 throw new CodegenException("Unexpected timestamp format " + format); 930 } 931 return "t"; 932 case BYTE: 933 writer.addUseImports(SmithyGoDependency.STRCONV); 934 writer.write("vv, err := strconv.ParseInt($L, 0, 8)", operand); 935 writer.write("if err != nil { return err }"); 936 return "int8(vv)"; 937 case SHORT: 938 writer.addUseImports(SmithyGoDependency.STRCONV); 939 writer.write("vv, err := strconv.ParseInt($L, 0, 16)", operand); 940 writer.write("if err != nil { return err }"); 941 return "int16(vv)"; 942 case INTEGER: 943 writer.addUseImports(SmithyGoDependency.STRCONV); 944 writer.write("vv, err := strconv.ParseInt($L, 0, 32)", operand); 945 writer.write("if err != nil { return err }"); 946 return "int32(vv)"; 947 case LONG: 948 writer.addUseImports(SmithyGoDependency.STRCONV); 949 writer.write("vv, err := strconv.ParseInt($L, 0, 64)", operand); 950 writer.write("if err != nil { return err }"); 951 return "vv"; 952 case FLOAT: 953 writer.addUseImports(SmithyGoDependency.STRCONV); 954 writer.write("vv, err := strconv.ParseFloat($L, 32)", operand); 955 writer.write("if err != nil { return err }"); 956 return "float32(vv)"; 957 case DOUBLE: 958 writer.addUseImports(SmithyGoDependency.STRCONV); 959 writer.write("vv, err := strconv.ParseFloat($L, 64)", operand); 960 writer.write("if err != nil { return err }"); 961 return "vv"; 962 case BIG_INTEGER: 963 writer.addUseImports(SmithyGoDependency.BIG); 964 writer.write("i := big.NewInt(0)"); 965 writer.write("bi, ok := i.SetString($L,0)", operand); 966 writer.openBlock("if !ok {", "}", () -> { 967 writer.write( 968 "return fmt.Error($S)", 969 "Incorrect conversion from string to BigInteger type" 970 ); 971 }); 972 return "*bi"; 973 case BIG_DECIMAL: 974 writer.addUseImports(SmithyGoDependency.BIG); 975 writer.write("f := big.NewFloat(0)"); 976 writer.write("bd, ok := f.SetString($L,0)", operand); 977 writer.openBlock("if !ok {", "}", () -> { 978 writer.write( 979 "return fmt.Error($S)", 980 "Incorrect conversion from string to BigDecimal type" 981 ); 982 }); 983 return "*bd"; 984 case BLOB: 985 writer.addUseImports(SmithyGoDependency.BASE64); 986 writer.write("b, err := base64.StdEncoding.DecodeString($L)", operand); 987 writer.write("if err != nil { return err }"); 988 return "b"; 989 case SET: 990 case LIST: 991 // handle list/Set as target shape 992 MemberShape targetValueListMemberShape = CodegenUtils.expectCollectionShape(targetShape).getMember(); 993 return getHttpHeaderCollectionDeserializer(context, writer, targetValueListMemberShape, 994 binding, 995 operand); 996 default: 997 throw new CodegenException("unexpected shape type " + targetShape.getType()); 998 } 999 } 1000 getHttpHeaderCollectionDeserializer( GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding binding, String operand )1001 private String getHttpHeaderCollectionDeserializer( 1002 GenerationContext context, 1003 GoWriter writer, 1004 MemberShape memberShape, 1005 HttpBinding binding, 1006 String operand 1007 ) { 1008 writer.write("var list []$P", context.getSymbolProvider().toSymbol(memberShape)); 1009 1010 String operandValue = operand + "Val"; 1011 writer.openBlock("for _, $L := range $L {", "}", operandValue, operand, () -> { 1012 String value = generateHttpHeaderValue(context, writer, memberShape, binding, operandValue); 1013 writer.write("list = append(list, $L)", 1014 CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, 1015 GoPointableIndex.of(context.getModel()), memberShape, value)); 1016 }); 1017 return "list"; 1018 } 1019 writeRestDeserializerMember( GenerationContext context, GoWriter writer, HttpBinding binding )1020 private void writeRestDeserializerMember( 1021 GenerationContext context, 1022 GoWriter writer, 1023 HttpBinding binding 1024 ) { 1025 MemberShape memberShape = binding.getMember(); 1026 Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); 1027 String memberName = context.getSymbolProvider().toMemberName(memberShape); 1028 1029 switch (binding.getLocation()) { 1030 case HEADER: 1031 writeHeaderDeserializerFunction(context, writer, memberName, memberShape, binding); 1032 break; 1033 case PREFIX_HEADERS: 1034 if (!targetShape.isMapShape()) { 1035 throw new CodegenException("unexpected prefix-header shape type found in Http bindings"); 1036 } 1037 writePrefixHeaderDeserializerFunction(context, writer, memberName, memberShape, binding); 1038 break; 1039 case RESPONSE_CODE: 1040 writer.addUseImports(SmithyGoDependency.SMITHY_PTR); 1041 writer.write("v.$L = $L", memberName, 1042 CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, 1043 GoPointableIndex.of(context.getModel()), memberShape, "int32(response.StatusCode)")); 1044 break; 1045 default: 1046 throw new CodegenException("unexpected http binding found"); 1047 } 1048 } 1049 writeHeaderDeserializerFunction( GenerationContext context, GoWriter writer, String memberName, MemberShape memberShape, HttpBinding binding )1050 private void writeHeaderDeserializerFunction( 1051 GenerationContext context, 1052 GoWriter writer, 1053 String memberName, 1054 MemberShape memberShape, 1055 HttpBinding binding 1056 ) { 1057 writer.openBlock("if headerValues := response.Header.Values($S); len(headerValues) != 0 {", "}", 1058 binding.getLocationName(), () -> { 1059 Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); 1060 1061 String operand = "headerValues"; 1062 operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand); 1063 1064 String value = generateHttpHeaderValue(context, writer, memberShape, binding, 1065 operand); 1066 writer.write("v.$L = $L", memberName, 1067 CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, 1068 GoPointableIndex.of(context.getModel()), memberShape, value)); 1069 }); 1070 } 1071 writePrefixHeaderDeserializerFunction( GenerationContext context, GoWriter writer, String memberName, MemberShape memberShape, HttpBinding binding )1072 private void writePrefixHeaderDeserializerFunction( 1073 GenerationContext context, 1074 GoWriter writer, 1075 String memberName, 1076 MemberShape memberShape, 1077 HttpBinding binding 1078 ) { 1079 String prefix = binding.getLocationName(); 1080 Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); 1081 1082 MemberShape valueMemberShape = targetShape.asMapShape() 1083 .orElseThrow(() -> new CodegenException("prefix headers must target map shape")) 1084 .getValue(); 1085 1086 writer.openBlock("for headerKey, headerValues := range response.Header {", "}", () -> { 1087 writer.addUseImports(SmithyGoDependency.STRINGS); 1088 Symbol targetSymbol = context.getSymbolProvider().toSymbol(targetShape); 1089 1090 writer.openBlock( 1091 "if lenPrefix := len($S); " 1092 + "len(headerKey) >= lenPrefix && strings.EqualFold(headerKey[:lenPrefix], $S) {", 1093 "}", prefix, prefix, () -> { 1094 writer.openBlock("if v.$L == nil {", "}", memberName, () -> { 1095 writer.write("v.$L = $P{}", memberName, targetSymbol); 1096 }); 1097 1098 String operand = "headerValues"; 1099 operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand); 1100 1101 String value = generateHttpHeaderValue(context, writer, valueMemberShape, 1102 binding, operand); 1103 writer.write("v.$L[strings.ToLower(headerKey[lenPrefix:])] = $L", memberName, 1104 CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, 1105 GoPointableIndex.of(context.getModel()), valueMemberShape, value)); 1106 }); 1107 }); 1108 } 1109 1110 /** 1111 * Returns the header value accessor operand, and also if the target shape is a list/set will write the splitting 1112 * of the header values by comma(,) utility helper. 1113 * 1114 * @param context generation context 1115 * @param writer writer 1116 * @param targetShape target shape 1117 * @param binding http binding location 1118 * @param operand operand of the header values. 1119 * @return returns operand for accessing the header values 1120 */ writeHeaderValueAccessor( GenerationContext context, GoWriter writer, Shape targetShape, HttpBinding binding, String operand )1121 private String writeHeaderValueAccessor( 1122 GenerationContext context, 1123 GoWriter writer, 1124 Shape targetShape, 1125 HttpBinding binding, 1126 String operand 1127 ) { 1128 switch (targetShape.getType()) { 1129 case LIST: 1130 case SET: 1131 writerHeaderListValuesSplit(context, writer, CodegenUtils.expectCollectionShape(targetShape), binding, 1132 operand); 1133 break; 1134 default: 1135 // Always use first element in header, ignores if there are multiple headers with this key. 1136 operand += "[0]"; 1137 break; 1138 } 1139 1140 return operand; 1141 } 1142 1143 /** 1144 * Writes the utility to split split comma separate header values into a single list for consistent iteration. Also 1145 * has special case handling for HttpDate timestamp format when serialized as a header list. Assigns the split 1146 * header values back to the same operand name. 1147 * 1148 * @param context generation context 1149 * @param writer writer 1150 * @param shape target collection shape 1151 * @param binding http binding location 1152 * @param operand operand of the header values. 1153 */ writerHeaderListValuesSplit( GenerationContext context, GoWriter writer, CollectionShape shape, HttpBinding binding, String operand )1154 private void writerHeaderListValuesSplit( 1155 GenerationContext context, 1156 GoWriter writer, 1157 CollectionShape shape, 1158 HttpBinding binding, 1159 String operand 1160 ) { 1161 writer.openBlock("{", "}", () -> { 1162 writer.write("var err error"); 1163 writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT); 1164 if (isHttpDateTimestamp(context.getModel(), binding.getLocation(), shape.getMember())) { 1165 writer.write("$L, err = smithyhttp.SplitHTTPDateTimestampHeaderListValues($L)", operand, operand); 1166 } else { 1167 writer.write("$L, err = smithyhttp.SplitHeaderListValues($L)", operand, operand); 1168 } 1169 writer.openBlock("if err != nil {", "}", () -> { 1170 writer.write("return err"); 1171 }); 1172 }); 1173 } 1174 1175 @Override generateSharedDeserializerComponents(GenerationContext context)1176 public void generateSharedDeserializerComponents(GenerationContext context) { 1177 deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error)); 1178 deserializeDocumentBindingShapes.addAll(ProtocolUtils.resolveRequiredDocumentShapeSerde( 1179 context.getModel(), deserializeDocumentBindingShapes)); 1180 generateDocumentBodyShapeDeserializers(context, deserializeDocumentBindingShapes); 1181 } 1182 1183 /** 1184 * Adds the top-level shapes from the operation that bind to the body document that require deserializer functions. 1185 * 1186 * @param context the generator context 1187 * @param operation the operation to add document binders from 1188 */ addOperationDocumentShapeBindersForDeserializer(GenerationContext context, OperationShape operation)1189 private void addOperationDocumentShapeBindersForDeserializer(GenerationContext context, OperationShape operation) { 1190 Model model = context.getModel(); 1191 HttpBindingIndex httpBindingIndex = model.getKnowledge(HttpBindingIndex.class); 1192 addDocumentDeserializerBindingShapes(model, httpBindingIndex, operation); 1193 1194 for (ShapeId errorShapeId : operation.getErrors()) { 1195 addDocumentDeserializerBindingShapes(model, httpBindingIndex, errorShapeId); 1196 } 1197 } 1198 addDocumentDeserializerBindingShapes(Model model, HttpBindingIndex index, ToShapeId shape)1199 private void addDocumentDeserializerBindingShapes(Model model, HttpBindingIndex index, ToShapeId shape) { 1200 // Walk and add members shapes to the list that will require deserializer functions 1201 for (HttpBinding binding : index.getResponseBindings(shape).values()) { 1202 Shape targetShape = model.expectShape(binding.getMember().getTarget()); 1203 if (requiresDocumentSerdeFunction(targetShape) 1204 && (binding.getLocation() == HttpBinding.Location.DOCUMENT 1205 || binding.getLocation() == HttpBinding.Location.PAYLOAD)) { 1206 deserializeDocumentBindingShapes.add(targetShape); 1207 } 1208 } 1209 } 1210 1211 /** 1212 * Generates the operation document deserializer function. 1213 * 1214 * @param context the generation context 1215 * @param operation the operation shape being generated 1216 */ generateOperationDocumentDeserializer(GenerationContext context, OperationShape operation)1217 protected abstract void generateOperationDocumentDeserializer(GenerationContext context, OperationShape operation); 1218 1219 /** 1220 * Generates deserialization functions for shapes in the passed set. These functions 1221 * should return a value that can then be deserialized by the implementation of 1222 * {@code deserializeOutputDocument}. 1223 * 1224 * @param context The generation context. 1225 * @param shapes The shapes to generate deserialization for. 1226 */ generateDocumentBodyShapeDeserializers(GenerationContext context, Set<Shape> shapes)1227 protected abstract void generateDocumentBodyShapeDeserializers(GenerationContext context, Set<Shape> shapes); 1228 generateErrorDeserializer(GenerationContext context, StructureShape shape)1229 private void generateErrorDeserializer(GenerationContext context, StructureShape shape) { 1230 GoWriter writer = context.getWriter(); 1231 String functionName = ProtocolGenerator.getErrorDeserFunctionName(shape, context.getProtocolName()); 1232 Symbol responseType = getApplicationProtocol().getResponseType(); 1233 1234 writer.addUseImports(SmithyGoDependency.BYTES); 1235 writer.openBlock("func $L(response $P, errorBody *bytes.Reader) error {", "}", 1236 functionName, responseType, () -> deserializeError(context, shape)); 1237 writer.write(""); 1238 } 1239 1240 /** 1241 * Writes a function body that deserializes the given error. 1242 * 1243 * <p>Two parameters will be available in scope: 1244 * <ul> 1245 * <li>{@code response: smithyhttp.HTTPResponse}: the HTTP response received.</li> 1246 * <li>{@code errorBody: bytes.BytesReader}: the HTTP response body.</li> 1247 * </ul> 1248 * 1249 * @param context The generation context. 1250 * @param shape The error shape. 1251 */ deserializeError(GenerationContext context, StructureShape shape)1252 protected abstract void deserializeError(GenerationContext context, StructureShape shape); 1253 1254 /** 1255 * Converts the first letter and any letter following a hyphen to upper case. The remaining letters are lower cased. 1256 * 1257 * @param key the header 1258 * @return the canonical header 1259 */ getCanonicalHeader(String key)1260 private String getCanonicalHeader(String key) { 1261 char[] chars = key.toCharArray(); 1262 boolean upper = true; 1263 for (int i = 0; i < chars.length; i++) { 1264 char c = chars[i]; 1265 c = upper ? Character.toUpperCase(c) : Character.toLowerCase(c); 1266 chars[i] = c; 1267 upper = c == '-'; 1268 } 1269 return new String(chars); 1270 } 1271 } 1272