1 package software.amazon.smithy.aws.go.codegen;
2 
3 import java.util.Collections;
4 import java.util.Map;
5 import java.util.Set;
6 import java.util.TreeSet;
7 import java.util.function.Predicate;
8 import java.util.logging.Logger;
9 import software.amazon.smithy.codegen.core.Symbol;
10 import software.amazon.smithy.codegen.core.SymbolProvider;
11 import software.amazon.smithy.go.codegen.GoDependency;
12 import software.amazon.smithy.go.codegen.GoValueAccessUtils;
13 import software.amazon.smithy.go.codegen.GoWriter;
14 import software.amazon.smithy.go.codegen.SmithyGoDependency;
15 import software.amazon.smithy.go.codegen.SymbolUtils;
16 import software.amazon.smithy.go.codegen.integration.DocumentShapeSerVisitor;
17 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
18 import software.amazon.smithy.go.codegen.integration.ProtocolUtils;
19 import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
20 import software.amazon.smithy.go.codegen.trait.NoSerializeTrait;
21 import software.amazon.smithy.model.shapes.CollectionShape;
22 import software.amazon.smithy.model.shapes.DocumentShape;
23 import software.amazon.smithy.model.shapes.MapShape;
24 import software.amazon.smithy.model.shapes.MemberShape;
25 import software.amazon.smithy.model.shapes.Shape;
26 import software.amazon.smithy.model.shapes.StructureShape;
27 import software.amazon.smithy.model.shapes.UnionShape;
28 import software.amazon.smithy.model.traits.EnumTrait;
29 import software.amazon.smithy.model.traits.TimestampFormatTrait;
30 import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
31 import software.amazon.smithy.model.traits.XmlFlattenedTrait;
32 import software.amazon.smithy.model.traits.XmlNameTrait;
33 import software.amazon.smithy.utils.FunctionalUtils;
34 
35 /**
36  * Visitor to generate serialization functions for shapes in AWS Query protocol
37  * document bodies.
38  * <p>
39  * This class handles function body generation for all types expected by the
40  * {@code DocumentShapeSerVisitor}. No other shape type serialization is overwritten.
41  * <p>
42  * Timestamps are serialized to {@link Format}.DATE_TIME by default.
43  */
44 class QueryShapeSerVisitor extends DocumentShapeSerVisitor {
45     private static final Format DEFAULT_TIMESTAMP_FORMAT = Format.DATE_TIME;
46     private static final Logger LOGGER = Logger.getLogger(QueryShapeSerVisitor.class.getName());
47 
48     private final Predicate<MemberShape> memberFilter;
49 
QueryShapeSerVisitor(GenerationContext context)50     public QueryShapeSerVisitor(GenerationContext context) {
51         this(context, NoSerializeTrait.excludeNoSerializeMembers().and(FunctionalUtils.alwaysTrue()));
52     }
53 
QueryShapeSerVisitor(GenerationContext context, Predicate<MemberShape> memberFilter)54     public QueryShapeSerVisitor(GenerationContext context, Predicate<MemberShape> memberFilter) {
55         super(context);
56         this.memberFilter = NoSerializeTrait.excludeNoSerializeMembers().and(memberFilter);
57     }
58 
getMemberSerVisitor(MemberShape member, String source, String dest)59     private DocumentMemberSerVisitor getMemberSerVisitor(MemberShape member, String source, String dest) {
60         // Get the timestamp format to be used, defaulting to epoch seconds.
61         Format format = member.getMemberTrait(getContext().getModel(), TimestampFormatTrait.class)
62                 .map(TimestampFormatTrait::getFormat)
63                 .orElse(DEFAULT_TIMESTAMP_FORMAT);
64         return new DocumentMemberSerVisitor(getContext(), member, source, dest, format);
65     }
66 
67     @Override
getAdditionalSerArguments()68     protected Map<String, String> getAdditionalSerArguments() {
69         return Collections.singletonMap("value", "query.Value");
70     }
71 
72     @Override
serializeCollection(GenerationContext context, CollectionShape shape)73     protected void serializeCollection(GenerationContext context, CollectionShape shape) {
74         GoWriter writer = context.getWriter();
75         MemberShape member = shape.getMember();
76         Shape target = context.getModel().expectShape(member.getTarget());
77 
78         // If the list is empty, exit early to avoid extra effort.
79         writer.write("if len(v) == 0 { return nil }");
80 
81         writer.write("array := value.Array($S)", getSerializedLocationName(member, "member"));
82         writer.write("");
83 
84         writer.openBlock("for i := range v {", "}", () -> {
85             // Null values should be omitted for query.
86             if (GoPointableIndex.of(context.getModel()).isNillable(shape.getMember())) {
87                 writer.openBlock("if vv := v[i]; vv == nil {", "}", () -> {
88                     writer.write("continue");
89                 });
90             }
91 
92             writer.write("av := array.Value()");
93             target.accept(getMemberSerVisitor(shape.getMember(), "v[i]", "av"));
94         });
95         writer.write("return nil");
96     }
97 
98     @Override
serializeDocument(GenerationContext context, DocumentShape shape)99     protected void serializeDocument(GenerationContext context, DocumentShape shape) {
100         LOGGER.warning("Document type is unsupported for Query serialization.");
101         context.getWriter().write("return &smithy.SerializationError{Err: fmt.Errorf("
102                 + "\"Document type is unsupported for the query protocol.\")}");
103     }
104 
105     @Override
serializeMap(GenerationContext context, MapShape shape)106     protected void serializeMap(GenerationContext context, MapShape shape) {
107         GoWriter writer = context.getWriter();
108 
109         // If the map is empty, exit early to avoid extra effort.
110         writer.write("if len(v) == 0 { return nil }");
111 
112         Shape target = context.getModel().expectShape(shape.getValue().getTarget());
113         String keyLocationName = getSerializedLocationName(shape.getKey(), "key");
114         String valueLocationName = getSerializedLocationName(shape.getValue(), "value");
115         writer.write("object := value.Map($S, $S)", keyLocationName, valueLocationName);
116         writer.write("");
117 
118         // Create a sorted list of the map's keys so we can have a stable body.
119         // Ideally this would be a function we dispatch to, but the lack of generics make
120         // that impractical since you can't make a function for a map[string]any
121         writer.write("keys := make([]string, 0, len(v))");
122         writer.write("for key := range v { keys = append(keys, key) }");
123         writer.addUseImports(GoDependency.standardLibraryDependency("sort", "1.15"));
124         writer.write("sort.Strings(keys)");
125         writer.write("");
126 
127         writer.addUseImports(SmithyGoDependency.FMT);
128         writer.openBlock("for _, key := range keys {", "}", () -> {
129             // Null values should be omitted for query.
130             if (GoPointableIndex.of(context.getModel()).isNillable(shape.getValue())) {
131                 writer.openBlock("if vv := v[key]; vv == nil {", "}", () -> {
132                     writer.write("continue");
133                 });
134             }
135 
136             writer.write("om := object.Key(key)");
137             target.accept(getMemberSerVisitor(shape.getValue(), "v[key]", "om"));
138         });
139 
140         writer.write("return nil");
141     }
142 
143     @Override
serializeStructure(GenerationContext context, StructureShape shape)144     protected void serializeStructure(GenerationContext context, StructureShape shape) {
145         GoWriter writer = context.getWriter();
146         writer.write("object := value.Object()");
147         writer.write("_ = object");
148         writer.write("");
149 
150         // Use a TreeSet to sort the members.
151         Set<MemberShape> members = new TreeSet<>(shape.getAllMembers().values());
152         for (MemberShape member : members) {
153             if (!memberFilter.test(member)) {
154                 continue;
155             }
156             Shape target = context.getModel().expectShape(member.getTarget());
157 
158             GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
159                     member, "v", true, member.isRequired(), (operand) -> {
160                         String locationName = getSerializedLocationName(member, member.getMemberName());
161                         if (isFlattened(context, member)) {
162                             writer.write("objectKey := object.FlatKey($S)", locationName);
163                         } else {
164                             writer.write("objectKey := object.Key($S)", locationName);
165                         }
166                         target.accept(getMemberSerVisitor(member, operand, "objectKey"));
167                     });
168             writer.write("");
169         }
170 
171         writer.write("return nil");
172     }
173 
174     /**
175      * Retrieves the correct serialization location based on the member's
176      * xmlName trait or uses the default value.
177      *
178      * @param memberShape  The member being serialized.
179      * @param defaultValue A default value for the location.
180      * @return The location where the member will be serialized.
181      */
getSerializedLocationName(MemberShape memberShape, String defaultValue)182     protected String getSerializedLocationName(MemberShape memberShape, String defaultValue) {
183         return memberShape.getTrait(XmlNameTrait.class)
184                 .map(XmlNameTrait::getValue)
185                 .orElse(defaultValue);
186     }
187 
188     /**
189      * Tells whether the contents of the member should be flattened
190      * when serialized.
191      *
192      * @param context     The generation context.
193      * @param memberShape The member being serialized.
194      * @return If the member's contents should be flattened when serialized.
195      */
isFlattened(GenerationContext context, MemberShape memberShape)196     protected boolean isFlattened(GenerationContext context, MemberShape memberShape) {
197         return memberShape.hasTrait(XmlFlattenedTrait.class);
198     }
199 
200     @Override
serializeUnion(GenerationContext context, UnionShape shape)201     protected void serializeUnion(GenerationContext context, UnionShape shape) {
202         GoWriter writer = context.getWriter();
203         SymbolProvider symbolProvider = context.getSymbolProvider();
204         Symbol symbol = symbolProvider.toSymbol(shape);
205         writer.addUseImports(SmithyGoDependency.FMT);
206 
207         writer.write("object := value.Object()");
208         writer.write("");
209 
210         writer.openBlock("switch uv := v.(type) {", "}", () -> {
211             // Use a TreeSet to sort the members.
212             Set<MemberShape> members = new TreeSet<>(shape.getAllMembers().values());
213             for (MemberShape member : members) {
214                 Shape target = context.getModel().expectShape(member.getTarget());
215                 Symbol memberSymbol = SymbolUtils.createValueSymbolBuilder(
216                         symbolProvider.toMemberName(member),
217                         symbol.getNamespace()
218                 ).build();
219 
220                 writer.openBlock("case *$T:", "", memberSymbol, () -> {
221                     String locationName = getSerializedLocationName(member, member.getMemberName());
222                     if (isFlattened(context, member)) {
223                         writer.write("objectKey := object.FlatKey($S)", locationName);
224                     } else {
225                         writer.write("objectKey := object.Key($S)", locationName);
226                     }
227                     target.accept(getMemberSerVisitor(member, "uv.Value", "objectKey"));
228                 });
229             }
230 
231             // Handle unknown union values
232             writer.openBlock("default:", "", () -> {
233                 writer.write("return fmt.Errorf(\"attempted to serialize unknown member type %T"
234                         + " for union %T\", uv, v)");
235             });
236         });
237 
238         writer.write("return nil");
239     }
240 }
241