1 /*
2  * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.smithy.aws.go.codegen;
17 
18 import java.util.Set;
19 import java.util.TreeSet;
20 import java.util.function.Consumer;
21 import software.amazon.smithy.codegen.core.Symbol;
22 import software.amazon.smithy.go.codegen.GoWriter;
23 import software.amazon.smithy.go.codegen.SmithyGoDependency;
24 import software.amazon.smithy.go.codegen.SymbolUtils;
25 import software.amazon.smithy.go.codegen.integration.HttpProtocolTestGenerator;
26 import software.amazon.smithy.go.codegen.integration.HttpProtocolUnitTestGenerator;
27 import software.amazon.smithy.go.codegen.integration.HttpProtocolUnitTestGenerator.ConfigValue;
28 import software.amazon.smithy.go.codegen.integration.HttpProtocolUnitTestRequestGenerator;
29 import software.amazon.smithy.go.codegen.integration.HttpProtocolUnitTestResponseErrorGenerator;
30 import software.amazon.smithy.go.codegen.integration.HttpProtocolUnitTestResponseGenerator;
31 import software.amazon.smithy.go.codegen.integration.IdempotencyTokenMiddlewareGenerator;
32 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
33 import software.amazon.smithy.model.shapes.ShapeId;
34 import software.amazon.smithy.utils.SetUtils;
35 
36 /**
37  * Utility methods for generating AWS protocols.
38  */
39 final class AwsProtocolUtils {
AwsProtocolUtils()40     private AwsProtocolUtils() {
41     }
42 
43     /**
44      * Generates HTTP protocol tests with all required AWS-specific configuration set.
45      *
46      * @param context The generation context.
47      */
generateHttpProtocolTests(GenerationContext context)48     static void generateHttpProtocolTests(GenerationContext context) {
49         Set<HttpProtocolUnitTestGenerator.ConfigValue> configValues = new TreeSet<>(SetUtils.of(
50                 HttpProtocolUnitTestGenerator.ConfigValue.builder()
51                         .name(AddAwsConfigFields.REGION_CONFIG_NAME)
52                         .value(writer -> writer.write("$S,", "us-west-2"))
53                         .build(),
54                 HttpProtocolUnitTestGenerator.ConfigValue.builder()
55                         .name(AddAwsConfigFields.ENDPOINT_RESOLVER_CONFIG_NAME)
56                         .value(writer -> {
57                             writer.addUseImports(AwsGoDependency.AWS_CORE);
58                             writer.openBlock("$L(func(region string, options $L) (e aws.Endpoint, err error) {", "}),",
59                                     EndpointGenerator.RESOLVER_FUNC_NAME, EndpointGenerator.RESOLVER_OPTIONS, () -> {
60                                         writer.write("e.URL = url");
61                                         writer.write("e.SigningRegion = \"us-west-2\"");
62                                         writer.write("return e, err");
63                                     });
64                         })
65                         .build(),
66                 HttpProtocolUnitTestGenerator.ConfigValue.builder()
67                         .name("APIOptions")
68                         .value(writer -> {
69                             Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack",
70                                     SmithyGoDependency.SMITHY_MIDDLEWARE).build();
71                             writer.openBlock("[]func($P) error{", "},", stackSymbol, () -> {
72                                 writer.openBlock("func(s $P) error {", "},", stackSymbol, () -> {
73                                     writer.write("s.Finalize.Clear()");
74                                     writer.write("return nil");
75                                 });
76                             });
77                         })
78                         .build()
79         ));
80 
81         // TODO can this check be replaced with a lookup into the runtime plugins?
82         if (IdempotencyTokenMiddlewareGenerator.hasOperationsWithIdempotencyToken(context.getModel(),
83                 context.getService())) {
84             configValues.add(
85                     HttpProtocolUnitTestGenerator.ConfigValue.builder()
86                             .name(IdempotencyTokenMiddlewareGenerator.IDEMPOTENCY_CONFIG_NAME)
87                             .value(writer -> {
88                                 writer.addUseImports(SmithyGoDependency.SMITHY_RAND);
89                                 writer.addUseImports(SmithyGoDependency.SMITHY_TESTING);
90                                 writer.write("smithyrand.NewUUIDIdempotencyToken(&smithytesting.ByteLoop{}),");
91                             })
92                             .build()
93             );
94         }
95 
96         Set<ConfigValue> inputConfigValues = new TreeSet<>(configValues);
97         inputConfigValues.add(HttpProtocolUnitTestGenerator.ConfigValue.builder()
98                 .name(AddAwsConfigFields.HTTP_CLIENT_CONFIG_NAME)
99                 .value(writer -> {
100                     writer.addUseImports(AwsGoDependency.AWS_HTTP_TRANSPORT);
101                     writer.write("awshttp.NewBuildableClient(),");
102                 })
103                 .build());
104 
105         Set<HttpProtocolUnitTestGenerator.SkipTest> inputSkipTests = new TreeSet<>(SetUtils.of(
106                 // REST-JSON Documents
107                 HttpProtocolUnitTestGenerator.SkipTest.builder()
108                         .service(ShapeId.from("aws.protocoltests.restjson#RestJson"))
109                         .operation(ShapeId.from("aws.protocoltests.restjson#InlineDocument"))
110                         .build(),
111                 HttpProtocolUnitTestGenerator.SkipTest.builder()
112                         .service(ShapeId.from("aws.protocoltests.restjson#RestJson"))
113                         .operation(ShapeId.from("aws.protocoltests.restjson#InlineDocumentAsPayload"))
114                         .build(),
115 
116                 // Null lists/maps without sparse tag
117                 HttpProtocolUnitTestGenerator.SkipTest.builder()
118                         .service(ShapeId.from("aws.protocoltests.restjson#RestJson"))
119                         .operation(ShapeId.from("aws.protocoltests.restjson#JsonLists"))
120                         .addTestName("RestJsonListsSerializeNull")
121                         .build(),
122                 HttpProtocolUnitTestGenerator.SkipTest.builder()
123                         .service(ShapeId.from("aws.protocoltests.restjson#RestJson"))
124                         .operation(ShapeId.from("aws.protocoltests.restjson#JsonMaps"))
125                         .addTestName("RestJsonSerializesNullMapValues")
126                         .build(),
127                 HttpProtocolUnitTestGenerator.SkipTest.builder()
128                         .service(ShapeId.from("aws.protocoltests.json#JsonProtocol"))
129                         .operation(ShapeId.from("aws.protocoltests.json#NullOperation"))
130                         .addTestName("AwsJson11MapsSerializeNullValues")
131                         .addTestName("AwsJson11ListsSerializeNull")
132                         .build(),
133 
134                 // JSON RPC Documents
135                 HttpProtocolUnitTestGenerator.SkipTest.builder()
136                         .service(ShapeId.from("aws.protocoltests.json#JsonProtocol"))
137                         .operation(ShapeId.from("aws.protocoltests.json#PutAndGetInlineDocuments"))
138                         .build(),
139 
140                 // JSON RPC serialize empty modeled input should always serialize something
141                 HttpProtocolUnitTestGenerator.SkipTest.builder()
142                         .service(ShapeId.from("aws.protocoltests.json10#JsonRpc10"))
143                         .operation(ShapeId.from("aws.protocoltests.json10#EmptyInputAndEmptyOutput"))
144                         .addTestName("AwsJson10EmptyInputAndEmptyOutput")
145                         .build(),
146 
147                 // Rest XML namespaced attributes. This needs to be fixed, but can be punted
148                 // temporarily since this is only used in an output in a single service.
149                 // TODO: fix serializing namespaced xml attributes
150                 HttpProtocolUnitTestGenerator.SkipTest.builder()
151                         .service(ShapeId.from("aws.protocoltests.restxml.xmlns#RestXmlWithNamespace"))
152                         .operation(ShapeId.from("aws.protocoltests.restxml.xmlns#SimpleScalarProperties"))
153                         .build()
154                 ));
155 
156         Set<HttpProtocolUnitTestGenerator.SkipTest> outputSkipTests = new TreeSet<>(SetUtils.of(
157                 // REST-XML opinionated test - prefix headers as empty vs nil map
158                 HttpProtocolUnitTestGenerator.SkipTest.builder()
159                         .service(ShapeId.from("aws.protocoltests.restxml#RestXml"))
160                         .operation(ShapeId.from("aws.protocoltests.restxml#HttpPrefixHeaders"))
161                         .addTestName("HttpPrefixHeadersAreNotPresent")
162                         .build()
163         ));
164 
165         new HttpProtocolTestGenerator(context,
166                 (HttpProtocolUnitTestRequestGenerator.Builder) new HttpProtocolUnitTestRequestGenerator
167                         .Builder()
168                         .addSkipTests(inputSkipTests)
169                         .addClientConfigValues(inputConfigValues),
170                 (HttpProtocolUnitTestResponseGenerator.Builder) new HttpProtocolUnitTestResponseGenerator
171                         .Builder()
172                         .addSkipTests(outputSkipTests)
173                         .addClientConfigValues(configValues),
174                 (HttpProtocolUnitTestResponseErrorGenerator.Builder) new HttpProtocolUnitTestResponseErrorGenerator
175                         .Builder()
176                         .addClientConfigValues(configValues)
177         ).generateProtocolTests();
178     }
179 
writeJsonErrorMessageCodeDeserializer(GenerationContext context)180     public static void writeJsonErrorMessageCodeDeserializer(GenerationContext context) {
181         GoWriter writer = context.getWriter();
182         // The error code could be in the headers, even though for this protocol it should be in the body.
183         writer.write("code := response.Header.Get(\"X-Amzn-ErrorType\")");
184         writer.write("if len(code) != 0 { errorCode = restjson.SanitizeErrorCode(code) }");
185         writer.write("");
186 
187         initializeJsonDecoder(writer, "errorBody");
188         writer.addUseImports(AwsGoDependency.AWS_REST_JSON_PROTOCOL);
189         // This will check various body locations for the error code and error message
190         writer.write("code, message, err := restjson.GetErrorInfo(decoder)");
191         handleDecodeError(writer);
192 
193         writer.addUseImports(SmithyGoDependency.IO);
194         // Reset the body in case it needs to be used for anything else.
195         writer.write("errorBody.Seek(0, io.SeekStart)");
196 
197         // Only set the values if something was found so that we keep the default values.
198         writer.write("if len(code) != 0 { errorCode = restjson.SanitizeErrorCode(code) }");
199         writer.write("if len(message) != 0 { errorMessage = message }");
200         writer.write("");
201     }
202 
initializeJsonDecoder(GoWriter writer, String bodyLocation)203     public static void initializeJsonDecoder(GoWriter writer, String bodyLocation) {
204         // Use a ring buffer and tee reader to help in pinpointing any deserialization errors.
205         writer.addUseImports(SmithyGoDependency.SMITHY_IO);
206         writer.write("var buff [1024]byte");
207         writer.write("ringBuffer := smithyio.NewRingBuffer(buff[:])");
208         writer.write("");
209 
210         writer.addUseImports(SmithyGoDependency.IO);
211         writer.addUseImports(SmithyGoDependency.JSON);
212         writer.write("body := io.TeeReader($L, ringBuffer)", bodyLocation);
213         writer.write("decoder := json.NewDecoder(body)");
214         writer.write("decoder.UseNumber()");
215     }
216 
217     /**
218      * Decodes JSON into {@code shape} with type {@code interface{}} using the encoding/json decoder
219      * referenced by {@code decoder}.
220      *
221      * @param writer            GoWriter to write code to
222      * @param errorReturnExtras extra parameters to return if an error occurs
223      */
decodeJsonIntoInterface(GoWriter writer, String errorReturnExtras)224     public static void decodeJsonIntoInterface(GoWriter writer, String errorReturnExtras) {
225         writer.write("var shape interface{}");
226         writer.addUseImports(SmithyGoDependency.IO);
227         writer.openBlock("if err := decoder.Decode(&shape); err != nil && err != io.EOF {", "}", () -> {
228             wrapAsDeserializationError(writer);
229             writer.write("return $Lerr", errorReturnExtras);
230         });
231         writer.write("");
232     }
233 
234     /**
235      * Wraps the Go error {@code err} in a {@code DeserializationError} with a snapshot
236      *
237      * @param writer
238      */
wrapAsDeserializationError(GoWriter writer)239     private static void wrapAsDeserializationError(GoWriter writer) {
240         writer.write("var snapshot bytes.Buffer");
241         writer.write("io.Copy(&snapshot, ringBuffer)");
242         writer.openBlock("err = &smithy.DeserializationError {", "}", () -> {
243             writer.write("Err: fmt.Errorf(\"failed to decode response body, %w\", err),");
244             writer.write("Snapshot: snapshot.Bytes(),");
245         });
246     }
247 
handleDecodeError(GoWriter writer, String returnExtras)248     public static void handleDecodeError(GoWriter writer, String returnExtras) {
249         writer.openBlock("if err != nil {", "}", () -> {
250             writer.addUseImports(SmithyGoDependency.BYTES);
251             writer.addUseImports(SmithyGoDependency.SMITHY);
252             writer.addUseImports(SmithyGoDependency.IO);
253             wrapAsDeserializationError(writer);
254             writer.write("return $Lerr", returnExtras);
255         }).write("");
256     }
257 
handleDecodeError(GoWriter writer)258     public static void handleDecodeError(GoWriter writer) {
259         handleDecodeError(writer, "");
260     }
261 }
262