1 package software.amazon.smithy.aws.go.codegen.customization; 2 3 import java.util.List; 4 import software.amazon.smithy.aws.go.codegen.XmlProtocolUtils; 5 import software.amazon.smithy.aws.traits.ServiceTrait; 6 import software.amazon.smithy.codegen.core.CodegenException; 7 import software.amazon.smithy.codegen.core.Symbol; 8 import software.amazon.smithy.codegen.core.SymbolProvider; 9 import software.amazon.smithy.go.codegen.GoDelegator; 10 import software.amazon.smithy.go.codegen.GoSettings; 11 import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator; 12 import software.amazon.smithy.go.codegen.GoWriter; 13 import software.amazon.smithy.go.codegen.SmithyGoDependency; 14 import software.amazon.smithy.go.codegen.SymbolUtils; 15 import software.amazon.smithy.go.codegen.integration.GoIntegration; 16 import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; 17 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; 18 import software.amazon.smithy.go.codegen.integration.ProtocolUtils; 19 import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; 20 import software.amazon.smithy.model.Model; 21 import software.amazon.smithy.model.shapes.OperationShape; 22 import software.amazon.smithy.model.shapes.ServiceShape; 23 import software.amazon.smithy.model.shapes.Shape; 24 import software.amazon.smithy.model.shapes.ShapeId; 25 import software.amazon.smithy.utils.ListUtils; 26 27 /** 28 * This integration generates a custom deserializer for GetBucketLocation response. 29 * Amazon S3 service does not wrap the GetBucketLocation response with Operation 30 * name xml tags, and thus custom deserialization is required. 31 * <p> 32 * Related to aws/aws-sdk-go-v2#908 33 */ 34 public class S3GetBucketLocation implements GoIntegration { 35 36 private final String protocolName = "awsRestxml"; 37 private final String swapDeserializerFuncName = "swapDeserializerHelper"; 38 private final String getBucketLocationOpID = "GetBucketLocation"; 39 40 /** 41 * Return true if service is Amazon S3. 42 * 43 * @param model is the generation model. 44 * @param service is the service shape being audited. 45 */ isS3Service(Model model, ServiceShape service)46 private static boolean isS3Service(Model model, ServiceShape service) { 47 String serviceId = service.expectTrait(ServiceTrait.class).getSdkId(); 48 return serviceId.equalsIgnoreCase("S3"); 49 } 50 51 /** 52 * returns name of the deserializer middleware written wrt this customization. 53 * 54 * @param operation the operation for which custom deserializer is generated. 55 */ getDeserializeMiddlewareName(OperationShape operation)56 private String getDeserializeMiddlewareName(OperationShape operation) { 57 return ProtocolGenerator.getDeserializeMiddlewareName(operation.getId(), protocolName) + "_custom"; 58 } 59 60 @Override getClientPlugins()61 public List<RuntimeClientPlugin> getClientPlugins() { 62 return ListUtils.of( 63 RuntimeClientPlugin.builder() 64 .operationPredicate((model, service, operation) -> { 65 return isS3Service(model, service) && operation.getId().getName() 66 .equals(getBucketLocationOpID); 67 }) 68 .registerMiddleware(MiddlewareRegistrar.builder() 69 .resolvedFunction( 70 SymbolUtils.createValueSymbolBuilder(swapDeserializerFuncName).build()) 71 .build()) 72 .build() 73 ); 74 } 75 76 @Override writeAdditionalFiles( GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator )77 public void writeAdditionalFiles( 78 GoSettings settings, 79 Model model, 80 SymbolProvider symbolProvider, 81 GoDelegator goDelegator 82 ) { 83 ShapeId serviceId = settings.getService(); 84 ServiceShape service = model.expectShape(serviceId, ServiceShape.class); 85 if (!isS3Service(model, service)) { 86 return; 87 } 88 89 for (ShapeId operationId : service.getAllOperations()) { 90 if (!(operationId.getName().equals(getBucketLocationOpID))) { 91 continue; 92 } 93 94 OperationShape operation = model.expectShape(operationId, OperationShape.class); 95 goDelegator.useShapeWriter(operation, writer -> { 96 writeCustomDeserializer(writer, model, symbolProvider, operation); 97 writeDeserializerSwapFunction(writer, operation); 98 }); 99 } 100 101 } 102 103 /** 104 * writes helper function to swap deserialization middleware with the generated 105 * custom deserializer middleware. 106 * 107 * @param writer is the go writer used 108 * @param operation is the operation for which swap function is written. 109 */ writeDeserializerSwapFunction( GoWriter writer, OperationShape operation )110 private void writeDeserializerSwapFunction( 111 GoWriter writer, 112 OperationShape operation 113 ) { 114 writer.writeDocs("Helper to swap in a custom deserializer"); 115 writer.openBlock("func $L(stack *middleware.Stack) error{", "}", 116 swapDeserializerFuncName, () -> { 117 writer.write("_, err := stack.Deserialize.Swap($S, &$L{})", 118 ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID.getString(), 119 getDeserializeMiddlewareName(operation) 120 ); 121 writer.write("if err != nil { return err }"); 122 writer.write("return nil"); 123 }); 124 } 125 126 /** 127 * writes a custom deserializer middleware for the provided operation. 128 * 129 * @param goWriter is the go writer used. 130 * @param model is the generation model. 131 * @param symbolProvider is the symbol provider. 132 * @param operation is the operation shape for which custom deserializer is written. 133 */ writeCustomDeserializer( GoWriter goWriter, Model model, SymbolProvider symbolProvider, OperationShape operation )134 private void writeCustomDeserializer( 135 GoWriter goWriter, 136 Model model, 137 SymbolProvider symbolProvider, 138 OperationShape operation 139 ) { 140 141 GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware( 142 getDeserializeMiddlewareName(operation), ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID); 143 144 String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName( 145 operation, protocolName); 146 147 middleware.writeMiddleware(goWriter, (generator, writer) -> { 148 writer.addUseImports(SmithyGoDependency.FMT); 149 150 writer.write("out, metadata, err = next.$L(ctx, in)", generator.getHandleMethodName()); 151 writer.write("if err != nil { return out, metadata, err }"); 152 writer.write(""); 153 154 writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT); 155 writer.write("response, ok := out.RawResponse.(*smithyhttp.Response)"); 156 writer.openBlock("if !ok {", "}", () -> { 157 writer.addUseImports(SmithyGoDependency.SMITHY); 158 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}", 159 "fmt.Errorf(\"unknown transport type %T\", out.RawResponse)")); 160 }); 161 writer.write(""); 162 163 writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> { 164 writer.write("return out, metadata, $L(response, &metadata)", errorFunctionName); 165 }); 166 167 Shape outputShape = model.expectShape(operation.getOutput() 168 .orElseThrow(() -> new CodegenException("expect output shape for operation: " + operation.getId())) 169 ); 170 171 Symbol outputSymbol = symbolProvider.toSymbol(outputShape); 172 173 // initialize out.Result as output structure shape 174 writer.write("output := &$T{}", outputSymbol); 175 writer.write("out.Result = output"); 176 writer.write(""); 177 178 writer.addUseImports(SmithyGoDependency.XML); 179 writer.addUseImports(SmithyGoDependency.SMITHY_XML); 180 writer.addUseImports(SmithyGoDependency.IO); 181 writer.addUseImports(SmithyGoDependency.SMITHY_IO); 182 183 writer.write("var buff [1024]byte"); 184 writer.write("ringBuffer := smithyio.NewRingBuffer(buff[:])"); 185 writer.write("body := io.TeeReader(response.Body, ringBuffer)"); 186 writer.write("rootDecoder := xml.NewDecoder(body)"); 187 188 // define a decoder with empty start element since we s3 does not wrap Location Constraint 189 // xml tag with operation specific xml tag. 190 writer.write("decoder := smithyxml.WrapNodeDecoder(rootDecoder, xml.StartElement{})"); 191 192 String deserFuncName = ProtocolGenerator.getDocumentDeserializerFunctionName(outputShape, protocolName); 193 writer.addUseImports(SmithyGoDependency.IO); 194 195 // delegate to already generated inner body deserializer function. 196 writer.write("err = $L(&output, decoder)", deserFuncName); 197 198 // EOF error is valid in this case, as we provide a NOP start element at start. 199 // Note that we correctly handle unexpected EOF. 200 writer.addUseImports(SmithyGoDependency.IO); 201 writer.write("if err == io.EOF { err = nil }"); 202 203 XmlProtocolUtils.handleDecodeError(writer, "out, metadata,"); 204 205 writer.write(""); 206 writer.write("return out, metadata, err"); 207 }); 208 } 209 } 210