1 package software.amazon.smithy.aws.go.codegen; 2 3 import java.util.Collections; 4 import java.util.Map; 5 import java.util.Set; 6 import java.util.TreeSet; 7 import java.util.function.Predicate; 8 import java.util.logging.Logger; 9 import software.amazon.smithy.codegen.core.Symbol; 10 import software.amazon.smithy.codegen.core.SymbolProvider; 11 import software.amazon.smithy.go.codegen.GoDependency; 12 import software.amazon.smithy.go.codegen.GoValueAccessUtils; 13 import software.amazon.smithy.go.codegen.GoWriter; 14 import software.amazon.smithy.go.codegen.SmithyGoDependency; 15 import software.amazon.smithy.go.codegen.SymbolUtils; 16 import software.amazon.smithy.go.codegen.integration.DocumentShapeSerVisitor; 17 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext; 18 import software.amazon.smithy.go.codegen.integration.ProtocolUtils; 19 import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; 20 import software.amazon.smithy.go.codegen.trait.NoSerializeTrait; 21 import software.amazon.smithy.model.shapes.CollectionShape; 22 import software.amazon.smithy.model.shapes.DocumentShape; 23 import software.amazon.smithy.model.shapes.MapShape; 24 import software.amazon.smithy.model.shapes.MemberShape; 25 import software.amazon.smithy.model.shapes.Shape; 26 import software.amazon.smithy.model.shapes.StructureShape; 27 import software.amazon.smithy.model.shapes.UnionShape; 28 import software.amazon.smithy.model.traits.EnumTrait; 29 import software.amazon.smithy.model.traits.TimestampFormatTrait; 30 import software.amazon.smithy.model.traits.TimestampFormatTrait.Format; 31 import software.amazon.smithy.model.traits.XmlFlattenedTrait; 32 import software.amazon.smithy.model.traits.XmlNameTrait; 33 import software.amazon.smithy.utils.FunctionalUtils; 34 35 /** 36 * Visitor to generate serialization functions for shapes in AWS Query protocol 37 * document bodies. 38 * <p> 39 * This class handles function body generation for all types expected by the 40 * {@code DocumentShapeSerVisitor}. No other shape type serialization is overwritten. 41 * <p> 42 * Timestamps are serialized to {@link Format}.DATE_TIME by default. 43 */ 44 class QueryShapeSerVisitor extends DocumentShapeSerVisitor { 45 private static final Format DEFAULT_TIMESTAMP_FORMAT = Format.DATE_TIME; 46 private static final Logger LOGGER = Logger.getLogger(QueryShapeSerVisitor.class.getName()); 47 48 private final Predicate<MemberShape> memberFilter; 49 QueryShapeSerVisitor(GenerationContext context)50 public QueryShapeSerVisitor(GenerationContext context) { 51 this(context, NoSerializeTrait.excludeNoSerializeMembers().and(FunctionalUtils.alwaysTrue())); 52 } 53 QueryShapeSerVisitor(GenerationContext context, Predicate<MemberShape> memberFilter)54 public QueryShapeSerVisitor(GenerationContext context, Predicate<MemberShape> memberFilter) { 55 super(context); 56 this.memberFilter = NoSerializeTrait.excludeNoSerializeMembers().and(memberFilter); 57 } 58 getMemberSerVisitor(MemberShape member, String source, String dest)59 private DocumentMemberSerVisitor getMemberSerVisitor(MemberShape member, String source, String dest) { 60 // Get the timestamp format to be used, defaulting to epoch seconds. 61 Format format = member.getMemberTrait(getContext().getModel(), TimestampFormatTrait.class) 62 .map(TimestampFormatTrait::getFormat) 63 .orElse(DEFAULT_TIMESTAMP_FORMAT); 64 return new DocumentMemberSerVisitor(getContext(), member, source, dest, format); 65 } 66 67 @Override getAdditionalSerArguments()68 protected Map<String, String> getAdditionalSerArguments() { 69 return Collections.singletonMap("value", "query.Value"); 70 } 71 72 @Override serializeCollection(GenerationContext context, CollectionShape shape)73 protected void serializeCollection(GenerationContext context, CollectionShape shape) { 74 GoWriter writer = context.getWriter(); 75 MemberShape member = shape.getMember(); 76 Shape target = context.getModel().expectShape(member.getTarget()); 77 78 // If the list is empty, exit early to avoid extra effort. 79 writer.write("if len(v) == 0 { return nil }"); 80 81 writer.write("array := value.Array($S)", getSerializedLocationName(member, "member")); 82 writer.write(""); 83 84 writer.openBlock("for i := range v {", "}", () -> { 85 // Null values should be omitted for query. 86 if (GoPointableIndex.of(context.getModel()).isNillable(shape.getMember())) { 87 writer.openBlock("if vv := v[i]; vv == nil {", "}", () -> { 88 writer.write("continue"); 89 }); 90 } 91 92 writer.write("av := array.Value()"); 93 target.accept(getMemberSerVisitor(shape.getMember(), "v[i]", "av")); 94 }); 95 writer.write("return nil"); 96 } 97 98 @Override serializeDocument(GenerationContext context, DocumentShape shape)99 protected void serializeDocument(GenerationContext context, DocumentShape shape) { 100 LOGGER.warning("Document type is unsupported for Query serialization."); 101 context.getWriter().write("return &smithy.SerializationError{Err: fmt.Errorf(" 102 + "\"Document type is unsupported for the query protocol.\")}"); 103 } 104 105 @Override serializeMap(GenerationContext context, MapShape shape)106 protected void serializeMap(GenerationContext context, MapShape shape) { 107 GoWriter writer = context.getWriter(); 108 109 // If the map is empty, exit early to avoid extra effort. 110 writer.write("if len(v) == 0 { return nil }"); 111 112 Shape target = context.getModel().expectShape(shape.getValue().getTarget()); 113 String keyLocationName = getSerializedLocationName(shape.getKey(), "key"); 114 String valueLocationName = getSerializedLocationName(shape.getValue(), "value"); 115 writer.write("object := value.Map($S, $S)", keyLocationName, valueLocationName); 116 writer.write(""); 117 118 // Create a sorted list of the map's keys so we can have a stable body. 119 // Ideally this would be a function we dispatch to, but the lack of generics make 120 // that impractical since you can't make a function for a map[string]any 121 writer.write("keys := make([]string, 0, len(v))"); 122 writer.write("for key := range v { keys = append(keys, key) }"); 123 writer.addUseImports(GoDependency.standardLibraryDependency("sort", "1.15")); 124 writer.write("sort.Strings(keys)"); 125 writer.write(""); 126 127 writer.addUseImports(SmithyGoDependency.FMT); 128 writer.openBlock("for _, key := range keys {", "}", () -> { 129 // Null values should be omitted for query. 130 if (GoPointableIndex.of(context.getModel()).isNillable(shape.getValue())) { 131 writer.openBlock("if vv := v[key]; vv == nil {", "}", () -> { 132 writer.write("continue"); 133 }); 134 } 135 136 writer.write("om := object.Key(key)"); 137 target.accept(getMemberSerVisitor(shape.getValue(), "v[key]", "om")); 138 }); 139 140 writer.write("return nil"); 141 } 142 143 @Override serializeStructure(GenerationContext context, StructureShape shape)144 protected void serializeStructure(GenerationContext context, StructureShape shape) { 145 GoWriter writer = context.getWriter(); 146 writer.write("object := value.Object()"); 147 writer.write("_ = object"); 148 writer.write(""); 149 150 // Use a TreeSet to sort the members. 151 Set<MemberShape> members = new TreeSet<>(shape.getAllMembers().values()); 152 for (MemberShape member : members) { 153 if (!memberFilter.test(member)) { 154 continue; 155 } 156 Shape target = context.getModel().expectShape(member.getTarget()); 157 158 GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, 159 member, "v", true, member.isRequired(), (operand) -> { 160 String locationName = getSerializedLocationName(member, member.getMemberName()); 161 if (isFlattened(context, member)) { 162 writer.write("objectKey := object.FlatKey($S)", locationName); 163 } else { 164 writer.write("objectKey := object.Key($S)", locationName); 165 } 166 target.accept(getMemberSerVisitor(member, operand, "objectKey")); 167 }); 168 writer.write(""); 169 } 170 171 writer.write("return nil"); 172 } 173 174 /** 175 * Retrieves the correct serialization location based on the member's 176 * xmlName trait or uses the default value. 177 * 178 * @param memberShape The member being serialized. 179 * @param defaultValue A default value for the location. 180 * @return The location where the member will be serialized. 181 */ getSerializedLocationName(MemberShape memberShape, String defaultValue)182 protected String getSerializedLocationName(MemberShape memberShape, String defaultValue) { 183 return memberShape.getTrait(XmlNameTrait.class) 184 .map(XmlNameTrait::getValue) 185 .orElse(defaultValue); 186 } 187 188 /** 189 * Tells whether the contents of the member should be flattened 190 * when serialized. 191 * 192 * @param context The generation context. 193 * @param memberShape The member being serialized. 194 * @return If the member's contents should be flattened when serialized. 195 */ isFlattened(GenerationContext context, MemberShape memberShape)196 protected boolean isFlattened(GenerationContext context, MemberShape memberShape) { 197 return memberShape.hasTrait(XmlFlattenedTrait.class); 198 } 199 200 @Override serializeUnion(GenerationContext context, UnionShape shape)201 protected void serializeUnion(GenerationContext context, UnionShape shape) { 202 GoWriter writer = context.getWriter(); 203 SymbolProvider symbolProvider = context.getSymbolProvider(); 204 Symbol symbol = symbolProvider.toSymbol(shape); 205 writer.addUseImports(SmithyGoDependency.FMT); 206 207 writer.write("object := value.Object()"); 208 writer.write(""); 209 210 writer.openBlock("switch uv := v.(type) {", "}", () -> { 211 // Use a TreeSet to sort the members. 212 Set<MemberShape> members = new TreeSet<>(shape.getAllMembers().values()); 213 for (MemberShape member : members) { 214 Shape target = context.getModel().expectShape(member.getTarget()); 215 Symbol memberSymbol = SymbolUtils.createValueSymbolBuilder( 216 symbolProvider.toMemberName(member), 217 symbol.getNamespace() 218 ).build(); 219 220 writer.openBlock("case *$T:", "", memberSymbol, () -> { 221 String locationName = getSerializedLocationName(member, member.getMemberName()); 222 if (isFlattened(context, member)) { 223 writer.write("objectKey := object.FlatKey($S)", locationName); 224 } else { 225 writer.write("objectKey := object.Key($S)", locationName); 226 } 227 target.accept(getMemberSerVisitor(member, "uv.Value", "objectKey")); 228 }); 229 } 230 231 // Handle unknown union values 232 writer.openBlock("default:", "", () -> { 233 writer.write("return fmt.Errorf(\"attempted to serialize unknown member type %T" 234 + " for union %T\", uv, v)"); 235 }); 236 }); 237 238 writer.write("return nil"); 239 } 240 } 241