1 package software.amazon.smithy.aws.go.codegen;
2 
3 import java.util.Collection;
4 import java.util.Optional;
5 import software.amazon.smithy.aws.go.codegen.customization.AwsCustomGoDependency;
6 import software.amazon.smithy.aws.traits.ServiceTrait;
7 import software.amazon.smithy.aws.traits.protocols.RestXmlTrait;
8 import software.amazon.smithy.codegen.core.CodegenException;
9 import software.amazon.smithy.codegen.core.Symbol;
10 import software.amazon.smithy.codegen.core.SymbolProvider;
11 import software.amazon.smithy.go.codegen.GoValueAccessUtils;
12 import software.amazon.smithy.go.codegen.GoWriter;
13 import software.amazon.smithy.go.codegen.SmithyGoDependency;
14 import software.amazon.smithy.go.codegen.SymbolUtils;
15 import software.amazon.smithy.go.codegen.SyntheticClone;
16 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
17 import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
18 import software.amazon.smithy.model.shapes.MemberShape;
19 import software.amazon.smithy.model.shapes.ServiceShape;
20 import software.amazon.smithy.model.shapes.Shape;
21 import software.amazon.smithy.model.traits.EnumTrait;
22 import software.amazon.smithy.model.traits.TimestampFormatTrait;
23 import software.amazon.smithy.model.traits.XmlAttributeTrait;
24 import software.amazon.smithy.model.traits.XmlNameTrait;
25 import software.amazon.smithy.model.traits.XmlNamespaceTrait;
26 
27 public final class XmlProtocolUtils {
XmlProtocolUtils()28     private XmlProtocolUtils() {
29 
30     }
31 
32     /**
33      * generateXMLStartElement generates the XML start element for a shape. It is used to generate smithy xml's startElement.
34      *
35      * @param context  is the generation context.
36      * @param shape    is the Shape for which xml start element is to be generated.
37      * @param dst      is the operand name which holds the generated start element.
38      * @param inputSrc is the input variable for the shape with values to be serialized.
39      */
generateXMLStartElement( ProtocolGenerator.GenerationContext context, Shape shape, String dst, String inputSrc )40     public static void generateXMLStartElement(
41             ProtocolGenerator.GenerationContext context, Shape shape, String dst, String inputSrc
42     ) {
43         GoWriter writer = context.getWriter();
44         String attrName = dst + "Attr";
45         generateXmlNamespaceAndAttributes(context, shape, attrName, inputSrc);
46 
47         writer.openBlock("$L := smithyxml.StartElement{ ", "}", dst, () -> {
48             writer.openBlock("Name:smithyxml.Name{", "},", () -> {
49                 writer.write("Local: $S,", getSerializedXMLShapeName(context, shape));
50             });
51             writer.write("Attr : $L,", attrName);
52         });
53     }
54 
55     /**
56      * Generates XML Start element for a document shape marked as a payload.
57      *
58      * @param context  is the generation context.
59      * @param memberShape is the payload as document member shape
60      * @param dst is the operand name which holds the generated start element.
61      * @param inputSrc is the input variable for the shape with values to be serialized.
62      */
generatePayloadAsDocumentXMLStartElement( ProtocolGenerator.GenerationContext context, MemberShape memberShape, String dst, String inputSrc )63     public static void generatePayloadAsDocumentXMLStartElement(
64             ProtocolGenerator.GenerationContext context, MemberShape memberShape, String dst, String inputSrc
65     ) {
66         GoWriter writer = context.getWriter();
67         String attrName = dst + "Attr";
68         Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
69 
70         generateXmlNamespaceAndAttributes(context, targetShape, attrName, inputSrc);
71 
72         writer.openBlock("$L := smithyxml.StartElement{ ", "}", dst, () -> {
73             writer.openBlock("Name:smithyxml.Name{", "},", () -> {
74                 String name = memberShape.getMemberName();
75                 if (targetShape.isStructureShape()) {
76                     if (memberShape.hasTrait(XmlNameTrait.class)) {
77                         name = getSerializedXMLMemberName(memberShape);
78                     } else {
79                         name = getSerializedXMLShapeName(context, targetShape);
80                     }
81                 }
82 
83                 writer.write("Local: $S,", name);
84 
85             });
86             writer.write("Attr : $L,", attrName);
87         });
88     }
89 
90 
91     /**
92      * Generates XML Attributes as per xmlNamespace and xmlAttribute traits.
93      *
94      * @param context is the generation context.
95      * @param shape is the shape that is decorated with XmlNamespace, XmlAttribute trait.
96      * @param dst is the operand name which holds the generated xml Attribute value.
97      * @param inputSrc is the input variable for the shape with values to be put as xml attributes.
98      */
generateXmlNamespaceAndAttributes( ProtocolGenerator.GenerationContext context, Shape shape, String dst, String inputSrc )99     private static void generateXmlNamespaceAndAttributes(
100             ProtocolGenerator.GenerationContext context, Shape shape, String dst, String inputSrc
101     ) {
102         GoWriter writer = context.getWriter();
103         writer.write("$L := []smithyxml.Attr{}", dst);
104 
105         Optional<XmlNamespaceTrait> xmlNamespaceTrait = shape.getTrait(XmlNamespaceTrait.class);
106         if (xmlNamespaceTrait.isPresent()) {
107             XmlNamespaceTrait namespace = xmlNamespaceTrait.get();
108             writer.write("$L = append($L, smithyxml.NewNamespaceAttribute($S, $S))",
109                     dst, dst,
110                     namespace.getPrefix().isPresent() ? namespace.getPrefix().get() : "", namespace.getUri()
111             );
112         }
113 
114         // Traverse member shapes to get attributes
115         if (shape.isMemberShape()) {
116             MemberShape memberShape = shape.asMemberShape().get();
117             Shape target = context.getModel().expectShape(memberShape.getTarget());
118             String memberName = context.getSymbolProvider().toMemberName(memberShape);
119             String operand = inputSrc + "." + memberName;
120             generateXmlAttributes(context, target.members(), operand, dst);
121         } else {
122             generateXmlAttributes(context, shape.members(), inputSrc, dst);
123         }
124     }
125 
generateXmlAttributes( ProtocolGenerator.GenerationContext context, Collection<MemberShape> members, String inputSrc, String dst )126     private static void generateXmlAttributes(
127             ProtocolGenerator.GenerationContext context,
128             Collection<MemberShape> members,
129             String inputSrc,
130             String dst
131     ) {
132         GoWriter writer = context.getWriter();
133         members.forEach(memberShape -> {
134             if (memberShape.hasTrait(XmlAttributeTrait.class)) {
135                 GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(),
136                         writer, memberShape, inputSrc, true, memberShape.isRequired(), (operand) -> {
137                             // xml attributes should always be string
138                             String dest = "av";
139                             formatXmlAttributeValueAsString(context, memberShape, operand, dest);
140                             writer.write("$L = append($L, smithyxml.NewAttribute($S, $L))",
141                                     dst, dst, getSerializedXMLMemberName(memberShape), dest);
142                         });
143             }
144         });
145     }
146 
147     // generates code to format xml attributes. If a shape type is timestamp, number, or boolean
148     // it will be formatted into a string.
formatXmlAttributeValueAsString( ProtocolGenerator.GenerationContext context, MemberShape member, String src, String dest )149     private static void formatXmlAttributeValueAsString(
150             ProtocolGenerator.GenerationContext context,
151             MemberShape member, String src, String dest
152     ) {
153         GoWriter writer = context.getWriter();
154         Shape target = context.getModel().expectShape(member.getTarget());
155 
156         // declare destination variable
157         writer.write("var $L string", dest);
158 
159         // Pointable value references need to be dereferenced before being used.
160         String derefSource = src;
161         if (GoPointableIndex.of(context.getModel()).isPointable(member)) {
162             derefSource = "*" + src;
163         }
164 
165         if (target.hasTrait(EnumTrait.class)) {
166             writer.write("$L = string($L)", dest, derefSource);
167             return;
168         } else if (target.isStringShape()) {
169             // create dereferenced copy of pointed to value.
170             writer.write("$L = $L", dest, derefSource);
171             return;
172         }
173 
174         if (target.isTimestampShape() || target.hasTrait(TimestampFormatTrait.class)) {
175             TimestampFormatTrait.Format format = member.getMemberTrait(context.getModel(), TimestampFormatTrait.class)
176                     .map(TimestampFormatTrait::getFormat).orElse(TimestampFormatTrait.Format.DATE_TIME);
177             writer.addUseImports(SmithyGoDependency.SMITHY_TIME);
178             switch (format) {
179                 case DATE_TIME:
180                     writer.write("$L = smithytime.FormatDateTime($L)", dest, derefSource);
181                     break;
182                 case HTTP_DATE:
183                     writer.write("$L = smithytime.FormatHTTPDate($L)", dest, derefSource);
184                     break;
185                 case EPOCH_SECONDS:
186                     writer.addUseImports(SmithyGoDependency.STRCONV);
187                     writer.write("$L = strconv.FormatFloat(smithytime.FormatEpochSeconds($L), 'f', -1, 64)",
188                             dest, derefSource);
189                     break;
190                 case UNKNOWN:
191                     throw new CodegenException("Unknown timestamp format");
192             }
193             return;
194         }
195 
196         if (target.isBooleanShape()) {
197             writer.write(SmithyGoDependency.STRCONV);
198             writer.write("$L = strconv.FormatBool($L)", dest, derefSource);
199             return;
200         }
201 
202         if (target.isByteShape() || target.isShortShape() || target.isIntegerShape() || target.isLongShape()) {
203             writer.write(SmithyGoDependency.STRCONV);
204             writer.write("$L = strconv.FormatInt(int64($L), 10)", dest, derefSource);
205             return;
206         }
207 
208         if (target.isFloatShape()) {
209             writer.write(SmithyGoDependency.STRCONV);
210             writer.write("$L = strconv.FormatFloat(float64($L),'f', -1, 32)", dest, derefSource);
211             return;
212         }
213 
214         if (target.isDoubleShape()) {
215             writer.write(SmithyGoDependency.STRCONV);
216             writer.write("$L = strconv.FormatFloat($L,'f', -1, 64)", dest, derefSource);
217             return;
218         }
219 
220         if (target.isBigIntegerShape() || target.isBigDecimalShape()) {
221             throw new CodegenException(String.format("Cannot serialize shape type %s on protocol, shape: %s.",
222                     target.getType(), target.getId()));
223         }
224 
225         throw new CodegenException(
226                 "Members serialized as XML attributes can only be of string, number, boolean or timestamp format");
227     }
228 
229     /**
230      * getSerializedXMLMemberName returns a xml member name used for serializing. If a member shape has
231      * XML name trait, xml name would be given precedence over member name.
232      *
233      * @param memberShape is the member shape for which serializer name is queried.
234      * @return name of a xml member shape used by serializers
235      */
getSerializedXMLMemberName(MemberShape memberShape)236     private static String getSerializedXMLMemberName(MemberShape memberShape) {
237         Optional<XmlNameTrait> xmlNameTrait = memberShape.getTrait(XmlNameTrait.class);
238         return xmlNameTrait.isPresent() ? xmlNameTrait.get().getValue() : memberShape.getMemberName();
239     }
240 
241     /**
242      * getSerializedXMLShapeName returns a xml shape name used for serializing. If a member shape
243      * has xml name trait, xml name would be given precedence over member name.
244      * This correctly handles renamed shapes, and returns the original shape name.
245      *
246      * @param context is the generation context for which
247      * @param shape   is the Shape for which serializer name is queried.
248      * @return name of a xml member shape used by serializers.
249      */
getSerializedXMLShapeName(ProtocolGenerator.GenerationContext context, Shape shape)250     private static String getSerializedXMLShapeName(ProtocolGenerator.GenerationContext context, Shape shape) {
251         SymbolProvider symbolProvider = context.getSymbolProvider();
252         Symbol shapeSymbol = symbolProvider.toSymbol(shape);
253         String shapeName = shapeSymbol.getName();
254 
255         // check if synthetic cloned shape
256         Optional<SyntheticClone> clone = shape.getTrait(SyntheticClone.class);
257         if (clone.isPresent()) {
258             SyntheticClone cl = clone.get();
259             if (cl.getArchetype().isPresent()) {
260                 shapeName = cl.getArchetype().get().getName();
261             }
262         }
263 
264         // check if shape is member shape
265         Optional<MemberShape> member = shape.asMemberShape();
266         if (member.isPresent()) {
267             return getSerializedXMLMemberName(member.get());
268         }
269 
270         return shape.getTrait(XmlNameTrait.class).map(XmlNameTrait::getValue).orElse(shapeName);
271     }
272 
273     /**
274      * initializeXmlDecoder generates stub code to initialize xml decoder.
275      * Returns nil in case EOF occurs while initializing xml decoder.
276      *
277      * @param writer       the go writer used to write
278      * @param bodyLocation the variable used to represent response body
279      */
initializeXmlDecoder(GoWriter writer, String bodyLocation)280     public static void initializeXmlDecoder(GoWriter writer, String bodyLocation) {
281         initializeXmlDecoder(writer, bodyLocation, "", "nil");
282     }
283 
284     /**
285      * initializeXmlDecoder generates stub code to initialize xml decoder
286      *
287      * @param writer       the go writer used to write
288      * @param bodyLocation the variable used to represent response body
289      * @param returnOnEOF  the variable to return in case an EOF error occurs while initializing xml decoder
290      */
initializeXmlDecoder(GoWriter writer, String bodyLocation, String returnOnEOF)291     public static void initializeXmlDecoder(GoWriter writer, String bodyLocation, String returnOnEOF) {
292         initializeXmlDecoder(writer, bodyLocation, "", returnOnEOF);
293     }
294 
295     /**
296      * initializeXmlDecoder generates stub code to initialize xml decoder
297      *
298      * @param writer       the go writer used to write
299      * @param bodyLocation the variable used to represent response body
300      * @param returnExtras the extra variables to be returned with the wrapped error check statement
301      * @param returnOnEOF  the variable to return in case an EOF error occurs while initializing xml decoder
302      */
initializeXmlDecoder( GoWriter writer, String bodyLocation, String returnExtras, String returnOnEOF )303     public static void initializeXmlDecoder(
304             GoWriter writer, String bodyLocation, String returnExtras, String returnOnEOF
305     ) {
306         // Use a ring buffer and tee reader to help in pinpointing any deserialization errors.
307         writer.addUseImports(SmithyGoDependency.SMITHY_IO);
308         writer.write("var buff [1024]byte");
309         writer.write("ringBuffer := smithyio.NewRingBuffer(buff[:])");
310         writer.insertTrailingNewline();
311 
312         writer.addUseImports(SmithyGoDependency.IO);
313         writer.addUseImports(SmithyGoDependency.XML);
314         writer.addUseImports(SmithyGoDependency.SMITHY_XML);
315         writer.write("body := io.TeeReader($L, ringBuffer)", bodyLocation);
316         writer.write("rootDecoder := xml.NewDecoder(body)");
317         writer.write("t, err := smithyxml.FetchRootElement(rootDecoder)");
318         writer.write("if err == io.EOF { return $L$L}", returnExtras, returnOnEOF);
319         handleDecodeError(writer, returnExtras);
320 
321         writer.insertTrailingNewline();
322         writer.write("decoder := smithyxml.WrapNodeDecoder(rootDecoder, t)");
323         writer.insertTrailingNewline();
324     }
325 
326     /**
327      * handleDecodeError handles the xml deserialization error wrapping.
328      *
329      * @param writer       the go writer used to write
330      * @param returnExtras extra variables to be returned with the wrapped error statement
331      */
handleDecodeError(GoWriter writer, String returnExtras)332     public static void handleDecodeError(GoWriter writer, String returnExtras) {
333         writer.addUseImports(SmithyGoDependency.IO);
334         writer.openBlock("if err != nil {", "}", () -> {
335             writer.addUseImports(SmithyGoDependency.BYTES);
336             writer.addUseImports(SmithyGoDependency.SMITHY);
337             writer.write("var snapshot bytes.Buffer");
338             writer.write("io.Copy(&snapshot, ringBuffer)");
339             writer.openBlock("return $L&smithy.DeserializationError {", "}", returnExtras, () -> {
340                 writer.write("Err : fmt.Errorf(\"failed to decode response body, %w\", err),");
341                 writer.write("Snapshot: snapshot.Bytes(),");
342             });
343         }).write("");
344     }
345 
346     /**
347      * Generates code to retrieve error code or error message from the error response body
348      * This method is used indirectly by generateErrorDispatcher to generate operation specific error handling functions
349      *
350      * @param context the generation context
351      * @see <a href="https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization">Rest-XML operation error serialization.</a>
352      */
writeXmlErrorMessageCodeDeserializer(ProtocolGenerator.GenerationContext context)353     public static void writeXmlErrorMessageCodeDeserializer(ProtocolGenerator.GenerationContext context) {
354         GoWriter writer = context.getWriter();
355 
356         // Check if service uses isNoErrorWrapping setting
357         boolean isNoErrorWrapping = context.getService().getTrait(RestXmlTrait.class).map(
358                 RestXmlTrait::isNoErrorWrapping).orElse(false);
359 
360         ServiceShape service = context.getService();
361 
362         if (requiresS3Customization(service)) {
363             Symbol getErrorComponentFunction = SymbolUtils.createValueSymbolBuilder(
364                     "GetErrorResponseComponents",
365                     AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION
366             ).build();
367 
368             Symbol errorOptions = SymbolUtils.createValueSymbolBuilder(
369                     "ErrorResponseDeserializerOptions",
370                     AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION
371             ).build();
372 
373             if (isS3Service(service)) {
374                 // s3 service
375                 writer.openBlock("errorComponents, err := $T(errorBody, $T{",
376                         "})", getErrorComponentFunction, errorOptions, () -> {
377                             writer.write("UseStatusCode : true, StatusCode : response.StatusCode,");
378                         });
379             } else {
380                 // s3 control
381                 writer.openBlock("errorComponents, err := $T(errorBody, $T{",
382                         "})", getErrorComponentFunction, errorOptions, () -> {
383                             writer.write("IsWrappedWithErrorTag: true,");
384                         });
385             }
386 
387             writer.write("if err != nil { return err }");
388 
389             writer.insertTrailingNewline();
390             writer.openBlock("if hostID := errorComponents.HostID; len(hostID)!=0 {", "}", () -> {
391                 writer.write("s3shared.SetHostIDMetadata(metadata, hostID)");
392             });
393         } else {
394             writer.addUseImports(AwsGoDependency.AWS_XML);
395             writer.write("errorComponents, err := awsxml.GetErrorResponseComponents(errorBody, $L)", isNoErrorWrapping);
396             writer.write("if err != nil { return err }");
397             writer.insertTrailingNewline();
398         }
399 
400         writer.addUseImports(AwsGoDependency.AWS_MIDDLEWARE);
401         writer.openBlock("if reqID := errorComponents.RequestID; len(reqID)!=0 {", "}", () -> {
402             writer.write("awsmiddleware.SetRequestIDMetadata(metadata, reqID)");
403         });
404         writer.insertTrailingNewline();
405 
406         writer.write("if len(errorComponents.Code) != 0 { errorCode = errorComponents.Code}");
407         writer.write("if len(errorComponents.Message) != 0 { errorMessage = errorComponents.Message}");
408         writer.insertTrailingNewline();
409 
410         writer.write("errorBody.Seek(0, io.SeekStart)");
411         writer.insertTrailingNewline();
412     }
413 
414     // returns true if service is either s3 or s3 control and needs s3 customization
requiresS3Customization(ServiceShape service)415     private static boolean requiresS3Customization(ServiceShape service) {
416         String serviceId = service.expectTrait(ServiceTrait.class).getSdkId();
417         return serviceId.equalsIgnoreCase("S3") || serviceId.equalsIgnoreCase("S3 Control");
418     }
419 
isS3Service(ServiceShape service)420     private static boolean isS3Service(ServiceShape service) {
421         String serviceId = service.expectTrait(ServiceTrait.class).getSdkId();
422         return serviceId.equalsIgnoreCase("S3");
423     }
424 }
425