1 package software.amazon.smithy.aws.go.codegen.customization;
2 
3 import java.util.List;
4 import software.amazon.smithy.aws.go.codegen.XmlProtocolUtils;
5 import software.amazon.smithy.codegen.core.CodegenException;
6 import software.amazon.smithy.codegen.core.Symbol;
7 import software.amazon.smithy.codegen.core.SymbolProvider;
8 import software.amazon.smithy.go.codegen.GoDelegator;
9 import software.amazon.smithy.go.codegen.GoSettings;
10 import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator;
11 import software.amazon.smithy.go.codegen.GoWriter;
12 import software.amazon.smithy.go.codegen.SmithyGoDependency;
13 import software.amazon.smithy.go.codegen.SymbolUtils;
14 import software.amazon.smithy.go.codegen.integration.GoIntegration;
15 import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
16 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
17 import software.amazon.smithy.go.codegen.integration.ProtocolUtils;
18 import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
19 import software.amazon.smithy.model.Model;
20 import software.amazon.smithy.model.shapes.OperationShape;
21 import software.amazon.smithy.model.shapes.ServiceShape;
22 import software.amazon.smithy.model.shapes.Shape;
23 import software.amazon.smithy.model.shapes.ShapeId;
24 import software.amazon.smithy.utils.ListUtils;
25 
26 /**
27  * This integration generates a custom deserializer for GetBucketLocation response.
28  * Amazon S3 service does not wrap the GetBucketLocation response with Operation
29  * name xml tags, and thus custom deserialization is required.
30  * <p>
31  * Related to aws/aws-sdk-go-v2#908
32  */
33 public class S3GetBucketLocation implements GoIntegration {
34 
35     private final String protocolName = "awsRestxml";
36     private final String swapDeserializerFuncName = "swapDeserializerHelper";
37     private final String getBucketLocationOpID = "GetBucketLocation";
38 
39     /**
40      * Return true if service is Amazon S3.
41      *
42      * @param model   is the generation model.
43      * @param service is the service shape being audited.
44      */
isS3Service(Model model, ServiceShape service)45     private static boolean isS3Service(Model model, ServiceShape service) {
46         return S3ModelUtils.isServiceS3(model, service);
47     }
48 
49     /**
50      * returns name of the deserializer middleware written wrt this customization.
51      *
52      * @param operation the operation for which custom deserializer is generated.
53      */
getDeserializeMiddlewareName(OperationShape operation)54     private String getDeserializeMiddlewareName(OperationShape operation) {
55         return ProtocolGenerator.getDeserializeMiddlewareName(operation.getId(), protocolName) + "_custom";
56     }
57 
58     @Override
getClientPlugins()59     public List<RuntimeClientPlugin> getClientPlugins() {
60         return ListUtils.of(
61                 RuntimeClientPlugin.builder()
62                         .operationPredicate((model, service, operation) -> {
63                             return isS3Service(model, service) && operation.getId().getName()
64                                     .equals(getBucketLocationOpID);
65                         })
66                         .registerMiddleware(MiddlewareRegistrar.builder()
67                                 .resolvedFunction(
68                                         SymbolUtils.createValueSymbolBuilder(swapDeserializerFuncName).build())
69                                 .build())
70                         .build()
71         );
72     }
73 
74     @Override
writeAdditionalFiles( GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator )75     public void writeAdditionalFiles(
76             GoSettings settings,
77             Model model,
78             SymbolProvider symbolProvider,
79             GoDelegator goDelegator
80     ) {
81         ShapeId serviceId = settings.getService();
82         ServiceShape service = model.expectShape(serviceId, ServiceShape.class);
83         if (!isS3Service(model, service)) {
84             return;
85         }
86 
87         for (ShapeId operationId : service.getAllOperations()) {
88             if (!(operationId.getName().equals(getBucketLocationOpID))) {
89                 continue;
90             }
91 
92             OperationShape operation = model.expectShape(operationId, OperationShape.class);
93             goDelegator.useShapeWriter(operation, writer -> {
94                 writeCustomDeserializer(writer, model, symbolProvider, operation);
95                 writeDeserializerSwapFunction(writer, operation);
96             });
97         }
98 
99     }
100 
101     /**
102      * writes helper function to swap deserialization middleware with the generated
103      * custom deserializer middleware.
104      *
105      * @param writer    is the go writer used
106      * @param operation is the operation for which swap function is written.
107      */
writeDeserializerSwapFunction( GoWriter writer, OperationShape operation )108     private void writeDeserializerSwapFunction(
109             GoWriter writer,
110             OperationShape operation
111     ) {
112         writer.writeDocs("Helper to swap in a custom deserializer");
113         writer.openBlock("func $L(stack *middleware.Stack) error{", "}",
114                 swapDeserializerFuncName, () -> {
115                     writer.write("_, err := stack.Deserialize.Swap($S, &$L{})",
116                             ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID.getString(),
117                             getDeserializeMiddlewareName(operation)
118                     );
119                     writer.write("if err != nil { return err }");
120                     writer.write("return nil");
121                 });
122     }
123 
124     /**
125      * writes a custom deserializer middleware for the provided operation.
126      *
127      * @param goWriter       is the go writer used.
128      * @param model          is the generation model.
129      * @param symbolProvider is the symbol provider.
130      * @param operation      is the operation shape for which custom deserializer is written.
131      */
writeCustomDeserializer( GoWriter goWriter, Model model, SymbolProvider symbolProvider, OperationShape operation )132     private void writeCustomDeserializer(
133             GoWriter goWriter,
134             Model model,
135             SymbolProvider symbolProvider,
136             OperationShape operation
137     ) {
138 
139         GoStackStepMiddlewareGenerator middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware(
140                 getDeserializeMiddlewareName(operation), ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID);
141 
142         String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName(
143                 operation, protocolName);
144 
145         middleware.writeMiddleware(goWriter, (generator, writer) -> {
146             writer.addUseImports(SmithyGoDependency.FMT);
147 
148             writer.write("out, metadata, err = next.$L(ctx, in)", generator.getHandleMethodName());
149             writer.write("if err != nil { return out, metadata, err }");
150             writer.write("");
151 
152             writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT);
153             writer.write("response, ok := out.RawResponse.(*smithyhttp.Response)");
154             writer.openBlock("if !ok {", "}", () -> {
155                 writer.addUseImports(SmithyGoDependency.SMITHY);
156                 writer.write(String.format("return out, metadata, &smithy.DeserializationError{Err: %s}",
157                         "fmt.Errorf(\"unknown transport type %T\", out.RawResponse)"));
158             });
159             writer.write("");
160 
161             writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> {
162                 writer.write("return out, metadata, $L(response, &metadata)", errorFunctionName);
163             });
164 
165             Shape outputShape = model.expectShape(operation.getOutput()
166                     .orElseThrow(() -> new CodegenException("expect output shape for operation: " + operation.getId()))
167             );
168 
169             Symbol outputSymbol = symbolProvider.toSymbol(outputShape);
170 
171             // initialize out.Result as output structure shape
172             writer.write("output := &$T{}", outputSymbol);
173             writer.write("out.Result = output");
174             writer.write("");
175 
176             writer.addUseImports(SmithyGoDependency.XML);
177             writer.addUseImports(SmithyGoDependency.SMITHY_XML);
178             writer.addUseImports(SmithyGoDependency.IO);
179             writer.addUseImports(SmithyGoDependency.SMITHY_IO);
180 
181             writer.write("var buff [1024]byte");
182             writer.write("ringBuffer := smithyio.NewRingBuffer(buff[:])");
183             writer.write("body := io.TeeReader(response.Body, ringBuffer)");
184             writer.write("rootDecoder := xml.NewDecoder(body)");
185 
186             // define a decoder with empty start element since we s3 does not wrap Location Constraint
187             // xml tag with operation specific xml tag.
188             writer.write("decoder := smithyxml.WrapNodeDecoder(rootDecoder, xml.StartElement{})");
189 
190             String deserFuncName = ProtocolGenerator.getDocumentDeserializerFunctionName(outputShape, protocolName);
191             writer.addUseImports(SmithyGoDependency.IO);
192 
193             // delegate to already generated inner body deserializer function.
194             writer.write("err = $L(&output, decoder)", deserFuncName);
195 
196             // EOF error is valid in this case, as we provide a NOP start element at start.
197             // Note that we correctly handle unexpected EOF.
198             writer.addUseImports(SmithyGoDependency.IO);
199             writer.write("if err == io.EOF { err = nil }");
200 
201             XmlProtocolUtils.handleDecodeError(writer, "out, metadata,");
202 
203             writer.write("");
204             writer.write("return out, metadata, err");
205         });
206     }
207 }
208