1 package software.amazon.smithy.aws.go.codegen.customization; 2 3 import java.util.List; 4 import software.amazon.smithy.aws.go.codegen.XmlProtocolUtils; 5 import software.amazon.smithy.codegen.core.CodegenException; 6 import software.amazon.smithy.codegen.core.Symbol; 7 import software.amazon.smithy.codegen.core.SymbolProvider; 8 import software.amazon.smithy.go.codegen.GoDelegator; 9 import software.amazon.smithy.go.codegen.GoSettings; 10 import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator; 11 import software.amazon.smithy.go.codegen.GoWriter; 12 import software.amazon.smithy.go.codegen.SmithyGoDependency; 13 import software.amazon.smithy.go.codegen.SymbolUtils; 14 import software.amazon.smithy.go.codegen.integration.GoIntegration; 15 import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; 16 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; 17 import software.amazon.smithy.go.codegen.integration.ProtocolUtils; 18 import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; 19 import software.amazon.smithy.model.Model; 20 import software.amazon.smithy.model.shapes.OperationShape; 21 import software.amazon.smithy.model.shapes.ServiceShape; 22 import software.amazon.smithy.model.shapes.Shape; 23 import software.amazon.smithy.model.shapes.ShapeId; 24 import software.amazon.smithy.utils.ListUtils; 25 26 /** 27 * This integration generates a custom deserializer for GetBucketLocation response. 28 * Amazon S3 service does not wrap the GetBucketLocation response with Operation 29 * name xml tags, and thus custom deserialization is required. 30 * <p> 31 * Related to aws/aws-sdk-go-v2#908 32 */ 33 public class S3GetBucketLocation implements GoIntegration { 34 35 private final String protocolName = "awsRestxml"; 36 private final String swapDeserializerFuncName = "swapDeserializerHelper"; 37 private final String getBucketLocationOpID = "GetBucketLocation"; 38 39 /** 40 * Return true if service is Amazon S3. 41 * 42 * @param model is the generation model. 43 * @param service is the service shape being audited. 44 */ isS3Service(Model model, ServiceShape service)45 private static boolean isS3Service(Model model, ServiceShape service) { 46 return S3ModelUtils.isServiceS3(model, service); 47 } 48 49 /** 50 * returns name of the deserializer middleware written wrt this customization. 51 * 52 * @param operation the operation for which custom deserializer is generated. 53 */ getDeserializeMiddlewareName(OperationShape operation)54 private String getDeserializeMiddlewareName(OperationShape operation) { 55 return ProtocolGenerator.getDeserializeMiddlewareName(operation.getId(), protocolName) + "_custom"; 56 } 57 58 @Override getClientPlugins()59 public List<RuntimeClientPlugin> getClientPlugins() { 60 return ListUtils.of( 61 RuntimeClientPlugin.builder() 62 .operationPredicate((model, service, operation) -> { 63 return isS3Service(model, service) && operation.getId().getName() 64 .equals(getBucketLocationOpID); 65 }) 66 .registerMiddleware(MiddlewareRegistrar.builder() 67 .resolvedFunction( 68 SymbolUtils.createValueSymbolBuilder(swapDeserializerFuncName).build()) 69 .build()) 70 .build() 71 ); 72 } 73 74 @Override writeAdditionalFiles( GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator )75 public void writeAdditionalFiles( 76 GoSettings settings, 77 Model model, 78 SymbolProvider symbolProvider, 79 GoDelegator goDelegator 80 ) { 81 ShapeId serviceId = settings.getService(); 82 ServiceShape service = model.expectShape(serviceId, ServiceShape.class); 83 if (!isS3Service(model, service)) { 84 return; 85 } 86 87 for (ShapeId operationId : service.getAllOperations()) { 88 if (!(operationId.getName().equals(getBucketLocationOpID))) { 89 continue; 90 } 91 92 OperationShape operation = model.expectShape(operationId, OperationShape.class); 93 goDelegator.useShapeWriter(operation, writer -> { 94 writeCustomDeserializer(writer, model, symbolProvider, operation); 95 writeDeserializerSwapFunction(writer, operation); 96 }); 97 } 98 99 } 100 101 /** 102 * writes helper function to swap deserialization middleware with the generated 103 * custom deserializer middleware. 104 * 105 * @param writer is the go writer used 106 * @param operation is the operation for which swap function is written. 107 */ writeDeserializerSwapFunction( GoWriter writer, OperationShape operation )108 private void writeDeserializerSwapFunction( 109 GoWriter writer, 110 OperationShape operation 111 ) { 112 writer.writeDocs("Helper to swap in a custom deserializer"); 113 writer.openBlock("func $L(stack *middleware.Stack) error{", "}", 114 swapDeserializerFuncName, () -> { 115 writer.write("_, err := stack.Deserialize.Swap($S, &$L{})", 116 ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID.getString(), 117 getDeserializeMiddlewareName(operation) 118 ); 119 writer.write("if err != nil { return err }"); 120 writer.write("return nil"); 121 }); 122 } 123 124 /** 125 * writes a custom deserializer middleware for the provided operation. 126 * 127 * @param goWriter is the go writer used. 128 * @param model is the generation model. 129 * @param symbolProvider is the symbol provider. 130 * @param operation is the operation shape for which custom deserializer is written. 131 */ writeCustomDeserializer( GoWriter goWriter, Model model, SymbolProvider symbolProvider, OperationShape operation )132 private void writeCustomDeserializer( 133 GoWriter goWriter, 134 Model model, 135 SymbolProvider symbolProvider, 136 OperationShape operation 137 ) { 138 139 GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware( 140 getDeserializeMiddlewareName(operation), ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID); 141 142 String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName( 143 operation, protocolName); 144 145 middleware.writeMiddleware(goWriter, (generator, writer) -> { 146 writer.addUseImports(SmithyGoDependency.FMT); 147 148 writer.write("out, metadata, err = next.$L(ctx, in)", generator.getHandleMethodName()); 149 writer.write("if err != nil { return out, metadata, err }"); 150 writer.write(""); 151 152 writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT); 153 writer.write("response, ok := out.RawResponse.(*smithyhttp.Response)"); 154 writer.openBlock("if !ok {", "}", () -> { 155 writer.addUseImports(SmithyGoDependency.SMITHY); 156 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}", 157 "fmt.Errorf(\"unknown transport type %T\", out.RawResponse)")); 158 }); 159 writer.write(""); 160 161 writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> { 162 writer.write("return out, metadata, $L(response, &metadata)", errorFunctionName); 163 }); 164 165 Shape outputShape = model.expectShape(operation.getOutput() 166 .orElseThrow(() -> new CodegenException("expect output shape for operation: " + operation.getId())) 167 ); 168 169 Symbol outputSymbol = symbolProvider.toSymbol(outputShape); 170 171 // initialize out.Result as output structure shape 172 writer.write("output := &$T{}", outputSymbol); 173 writer.write("out.Result = output"); 174 writer.write(""); 175 176 writer.addUseImports(SmithyGoDependency.XML); 177 writer.addUseImports(SmithyGoDependency.SMITHY_XML); 178 writer.addUseImports(SmithyGoDependency.IO); 179 writer.addUseImports(SmithyGoDependency.SMITHY_IO); 180 181 writer.write("var buff [1024]byte"); 182 writer.write("ringBuffer := smithyio.NewRingBuffer(buff[:])"); 183 writer.write("body := io.TeeReader(response.Body, ringBuffer)"); 184 writer.write("rootDecoder := xml.NewDecoder(body)"); 185 186 // define a decoder with empty start element since we s3 does not wrap Location Constraint 187 // xml tag with operation specific xml tag. 188 writer.write("decoder := smithyxml.WrapNodeDecoder(rootDecoder, xml.StartElement{})"); 189 190 String deserFuncName = ProtocolGenerator.getDocumentDeserializerFunctionName(outputShape, protocolName); 191 writer.addUseImports(SmithyGoDependency.IO); 192 193 // delegate to already generated inner body deserializer function. 194 writer.write("err = $L(&output, decoder)", deserFuncName); 195 196 // EOF error is valid in this case, as we provide a NOP start element at start. 197 // Note that we correctly handle unexpected EOF. 198 writer.addUseImports(SmithyGoDependency.IO); 199 writer.write("if err == io.EOF { err = nil }"); 200 201 XmlProtocolUtils.handleDecodeError(writer, "out, metadata,"); 202 203 writer.write(""); 204 writer.write("return out, metadata, err"); 205 }); 206 } 207 } 208