1 package software.amazon.smithy.aws.go.codegen.customization;
2 
3 import java.util.List;
4 import software.amazon.smithy.aws.go.codegen.XmlProtocolUtils;
5 import software.amazon.smithy.aws.traits.ServiceTrait;
6 import software.amazon.smithy.codegen.core.CodegenException;
7 import software.amazon.smithy.codegen.core.Symbol;
8 import software.amazon.smithy.codegen.core.SymbolProvider;
9 import software.amazon.smithy.go.codegen.GoDelegator;
10 import software.amazon.smithy.go.codegen.GoSettings;
11 import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator;
12 import software.amazon.smithy.go.codegen.GoWriter;
13 import software.amazon.smithy.go.codegen.SmithyGoDependency;
14 import software.amazon.smithy.go.codegen.SymbolUtils;
15 import software.amazon.smithy.go.codegen.integration.GoIntegration;
16 import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
17 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
18 import software.amazon.smithy.go.codegen.integration.ProtocolUtils;
19 import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
20 import software.amazon.smithy.model.Model;
21 import software.amazon.smithy.model.shapes.OperationShape;
22 import software.amazon.smithy.model.shapes.ServiceShape;
23 import software.amazon.smithy.model.shapes.Shape;
24 import software.amazon.smithy.model.shapes.ShapeId;
25 import software.amazon.smithy.utils.ListUtils;
26 
27 /**
28  * This integration generates a custom deserializer for GetBucketLocation response.
29  * Amazon S3 service does not wrap the GetBucketLocation response with Operation
30  * name xml tags, and thus custom deserialization is required.
31  * <p>
32  * Related to aws/aws-sdk-go-v2#908
33  */
34 public class S3GetBucketLocation implements GoIntegration {
35 
36     private final String protocolName = "awsRestxml";
37     private final String swapDeserializerFuncName = "swapDeserializerHelper";
38     private final String getBucketLocationOpID = "GetBucketLocation";
39 
40     /**
41      * Return true if service is Amazon S3.
42      *
43      * @param model   is the generation model.
44      * @param service is the service shape being audited.
45      */
isS3Service(Model model, ServiceShape service)46     private static boolean isS3Service(Model model, ServiceShape service) {
47         String serviceId = service.expectTrait(ServiceTrait.class).getSdkId();
48         return serviceId.equalsIgnoreCase("S3");
49     }
50 
51     /**
52      * returns name of the deserializer middleware written wrt this customization.
53      *
54      * @param operation the operation for which custom deserializer is generated.
55      */
getDeserializeMiddlewareName(OperationShape operation)56     private String getDeserializeMiddlewareName(OperationShape operation) {
57         return ProtocolGenerator.getDeserializeMiddlewareName(operation.getId(), protocolName) + "_custom";
58     }
59 
60     @Override
getClientPlugins()61     public List<RuntimeClientPlugin> getClientPlugins() {
62         return ListUtils.of(
63                 RuntimeClientPlugin.builder()
64                         .operationPredicate((model, service, operation) -> {
65                             return isS3Service(model, service) && operation.getId().getName()
66                                     .equals(getBucketLocationOpID);
67                         })
68                         .registerMiddleware(MiddlewareRegistrar.builder()
69                                 .resolvedFunction(
70                                         SymbolUtils.createValueSymbolBuilder(swapDeserializerFuncName).build())
71                                 .build())
72                         .build()
73         );
74     }
75 
76     @Override
writeAdditionalFiles( GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator )77     public void writeAdditionalFiles(
78             GoSettings settings,
79             Model model,
80             SymbolProvider symbolProvider,
81             GoDelegator goDelegator
82     ) {
83         ShapeId serviceId = settings.getService();
84         ServiceShape service = model.expectShape(serviceId, ServiceShape.class);
85         if (!isS3Service(model, service)) {
86             return;
87         }
88 
89         for (ShapeId operationId : service.getAllOperations()) {
90             if (!(operationId.getName().equals(getBucketLocationOpID))) {
91                 continue;
92             }
93 
94             OperationShape operation = model.expectShape(operationId, OperationShape.class);
95             goDelegator.useShapeWriter(operation, writer -> {
96                 writeCustomDeserializer(writer, model, symbolProvider, operation);
97                 writeDeserializerSwapFunction(writer, operation);
98             });
99         }
100 
101     }
102 
103     /**
104      * writes helper function to swap deserialization middleware with the generated
105      * custom deserializer middleware.
106      *
107      * @param writer    is the go writer used
108      * @param operation is the operation for which swap function is written.
109      */
writeDeserializerSwapFunction( GoWriter writer, OperationShape operation )110     private void writeDeserializerSwapFunction(
111             GoWriter writer,
112             OperationShape operation
113     ) {
114         writer.writeDocs("Helper to swap in a custom deserializer");
115         writer.openBlock("func $L(stack *middleware.Stack) error{", "}",
116                 swapDeserializerFuncName, () -> {
117                     writer.write("_, err := stack.Deserialize.Swap($S, &$L{})",
118                             ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID.getString(),
119                             getDeserializeMiddlewareName(operation)
120                     );
121                     writer.write("if err != nil { return err }");
122                     writer.write("return nil");
123                 });
124     }
125 
126     /**
127      * writes a custom deserializer middleware for the provided operation.
128      *
129      * @param goWriter       is the go writer used.
130      * @param model          is the generation model.
131      * @param symbolProvider is the symbol provider.
132      * @param operation      is the operation shape for which custom deserializer is written.
133      */
writeCustomDeserializer( GoWriter goWriter, Model model, SymbolProvider symbolProvider, OperationShape operation )134     private void writeCustomDeserializer(
135             GoWriter goWriter,
136             Model model,
137             SymbolProvider symbolProvider,
138             OperationShape operation
139     ) {
140 
141         GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware(
142                 getDeserializeMiddlewareName(operation), ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID);
143 
144         String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName(
145                 operation, protocolName);
146 
147         middleware.writeMiddleware(goWriter, (generator, writer) -> {
148             writer.addUseImports(SmithyGoDependency.FMT);
149 
150             writer.write("out, metadata, err = next.$L(ctx, in)", generator.getHandleMethodName());
151             writer.write("if err != nil { return out, metadata, err }");
152             writer.write("");
153 
154             writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT);
155             writer.write("response, ok := out.RawResponse.(*smithyhttp.Response)");
156             writer.openBlock("if !ok {", "}", () -> {
157                 writer.addUseImports(SmithyGoDependency.SMITHY);
158                 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}",
159                         "fmt.Errorf(\"unknown transport type %T\", out.RawResponse)"));
160             });
161             writer.write("");
162 
163             writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> {
164                 writer.write("return out, metadata, $L(response, &metadata)", errorFunctionName);
165             });
166 
167             Shape outputShape = model.expectShape(operation.getOutput()
168                     .orElseThrow(() -> new CodegenException("expect output shape for operation: " + operation.getId()))
169             );
170 
171             Symbol outputSymbol = symbolProvider.toSymbol(outputShape);
172 
173             // initialize out.Result as output structure shape
174             writer.write("output := &$T{}", outputSymbol);
175             writer.write("out.Result = output");
176             writer.write("");
177 
178             writer.addUseImports(SmithyGoDependency.XML);
179             writer.addUseImports(SmithyGoDependency.SMITHY_XML);
180             writer.addUseImports(SmithyGoDependency.IO);
181             writer.addUseImports(SmithyGoDependency.SMITHY_IO);
182 
183             writer.write("var buff [1024]byte");
184             writer.write("ringBuffer := smithyio.NewRingBuffer(buff[:])");
185             writer.write("body := io.TeeReader(response.Body, ringBuffer)");
186             writer.write("rootDecoder := xml.NewDecoder(body)");
187 
188             // define a decoder with empty start element since we s3 does not wrap Location Constraint
189             // xml tag with operation specific xml tag.
190             writer.write("decoder := smithyxml.WrapNodeDecoder(rootDecoder, xml.StartElement{})");
191 
192             String deserFuncName = ProtocolGenerator.getDocumentDeserializerFunctionName(outputShape, protocolName);
193             writer.addUseImports(SmithyGoDependency.IO);
194 
195             // delegate to already generated inner body deserializer function.
196             writer.write("err = $L(&output, decoder)", deserFuncName);
197 
198             // EOF error is valid in this case, as we provide a NOP start element at start.
199             // Note that we correctly handle unexpected EOF.
200             writer.addUseImports(SmithyGoDependency.IO);
201             writer.write("if err == io.EOF { err = nil }");
202 
203             XmlProtocolUtils.handleDecodeError(writer, "out, metadata,");
204 
205             writer.write("");
206             writer.write("return out, metadata, err");
207         });
208     }
209 }
210