1 /*
2  * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.smithy.go.codegen.integration;
17 
18 import java.util.Set;
19 import java.util.TreeSet;
20 import java.util.function.Consumer;
21 import software.amazon.smithy.codegen.core.CodegenException;
22 import software.amazon.smithy.go.codegen.GoWriter;
23 import software.amazon.smithy.go.codegen.MiddlewareIdentifier;
24 import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
25 import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
26 import software.amazon.smithy.model.Model;
27 import software.amazon.smithy.model.knowledge.OperationIndex;
28 import software.amazon.smithy.model.neighbor.RelationshipType;
29 import software.amazon.smithy.model.neighbor.Walker;
30 import software.amazon.smithy.model.shapes.MemberShape;
31 import software.amazon.smithy.model.shapes.OperationShape;
32 import software.amazon.smithy.model.shapes.Shape;
33 import software.amazon.smithy.model.shapes.ShapeId;
34 import software.amazon.smithy.model.shapes.ShapeType;
35 import software.amazon.smithy.model.shapes.StructureShape;
36 import software.amazon.smithy.utils.SetUtils;
37 
38 /**
39  * Utility functions for protocol generation.
40  */
41 public final class ProtocolUtils {
42     public static final MiddlewareIdentifier OPERATION_SERIALIZER_MIDDLEWARE_ID = MiddlewareIdentifier
43             .string("OperationSerializer");
44     public static final MiddlewareIdentifier OPERATION_DESERIALIZER_MIDDLEWARE_ID = MiddlewareIdentifier
45             .string("OperationDeserializer");
46 
47     private static final Set<ShapeType> REQUIRES_SERDE = SetUtils.of(
48             ShapeType.MAP, ShapeType.LIST, ShapeType.SET, ShapeType.DOCUMENT, ShapeType.STRUCTURE, ShapeType.UNION);
49     private static final Set<RelationshipType> MEMBER_RELATIONSHIPS = SetUtils.of(
50             RelationshipType.STRUCTURE_MEMBER, RelationshipType.UNION_MEMBER, RelationshipType.LIST_MEMBER,
51             RelationshipType.SET_MEMBER, RelationshipType.MAP_VALUE, RelationshipType.MEMBER_TARGET
52     );
53 
ProtocolUtils()54     private ProtocolUtils() {
55     }
56 
57     /**
58      * Resolves the entire set of shapes that will require serde given an initial set of shapes.
59      *
60      * @param model  the model
61      * @param shapes the shapes to walk and resolve additional required serializers, deserializers for
62      * @return the complete set of shapes requiring serializers, deserializers
63      */
resolveRequiredDocumentShapeSerde(Model model, Set<Shape> shapes)64     public static Set<Shape> resolveRequiredDocumentShapeSerde(Model model, Set<Shape> shapes) {
65         Set<ShapeId> processed = new TreeSet<>();
66         Set<Shape> resolvedShapes = new TreeSet<>();
67         Walker walker = new Walker(model);
68 
69         shapes.forEach(shape -> {
70             processed.add(shape.getId());
71             resolvedShapes.add(shape);
72             walker.iterateShapes(shape, relationship -> MEMBER_RELATIONSHIPS.contains(
73                     relationship.getRelationshipType()))
74                     .forEachRemaining(walkedShape -> {
75                         // MemberShape type itself is not what we are interested in
76                         if (walkedShape.getType() == ShapeType.MEMBER) {
77                             return;
78                         }
79                         if (processed.contains(walkedShape.getId())) {
80                             return;
81                         }
82                         if (requiresDocumentSerdeFunction(shape)) {
83                             resolvedShapes.add(walkedShape);
84                             processed.add(walkedShape.getId());
85                         }
86                     });
87         });
88 
89         return resolvedShapes;
90     }
91 
92     /**
93      * Determines whether a document serde function is required for the given shape.
94      * <p>
95      * The following shape types will require a serde function: maps, lists, sets, documents, structures, and unions.
96      *
97      * @param shape the shape
98      * @return true if the shape requires a serde function
99      */
requiresDocumentSerdeFunction(Shape shape)100     public static boolean requiresDocumentSerdeFunction(Shape shape) {
101         return REQUIRES_SERDE.contains(shape.getType());
102     }
103 
104     /**
105      * Gets the operation input as a structure shape or throws an exception.
106      *
107      * @param model     The model that contains the operation.
108      * @param operation The operation to get the input from.
109      * @return The operation's input as a structure shape.
110      */
expectInput(Model model, OperationShape operation)111     public static StructureShape expectInput(Model model, OperationShape operation) {
112         return model.getKnowledge(OperationIndex.class).getInput(operation)
113                 .orElseThrow(() -> new CodegenException(
114                         "Expected input shape for operation " + operation.getId().toString()));
115     }
116 
117     /**
118      * Gets the operation output as a structure shape or throws an exception.
119      *
120      * @param model     The model that contains the operation.
121      * @param operation The operation to get the output from.
122      * @return The operation's output as a structure shape.
123      */
expectOutput(Model model, OperationShape operation)124     public static StructureShape expectOutput(Model model, OperationShape operation) {
125         return model.getKnowledge(OperationIndex.class).getOutput(operation)
126                 .orElseThrow(() -> new CodegenException(
127                         "Expected output shape for operation " + operation.getId().toString()));
128     }
129 
130     /**
131      * Wraps the protocol's delegation function changing the destination variable to a double pointer if the
132      * shape type is not pointable.
133      *
134      * @param context         generation context
135      * @param writer          go writer
136      * @param member          shape to determine if pointable
137      * @param origDestOperand original variable name
138      * @param lambda          runnable
139      */
writeDeserDelegateFunction( GenerationContext context, GoWriter writer, MemberShape member, String origDestOperand, Consumer<String> lambda )140     public static void writeDeserDelegateFunction(
141             GenerationContext context,
142             GoWriter writer,
143             MemberShape member,
144             String origDestOperand,
145             Consumer<String> lambda
146     ) {
147         Shape targetShape = context.getModel().expectShape(member.getTarget());
148         Shape container = context.getModel().expectShape(member.getContainer());
149 
150         boolean withAddr = !GoPointableIndex.of(context.getModel()).isPointable(member)
151                 && GoPointableIndex.of(context.getModel()).isPointable(targetShape);
152         boolean isMap = container.getType() == ShapeType.MAP;
153 
154         String destOperand = origDestOperand;
155         if (isMap) {
156             writer.write("mapVar := $L", origDestOperand);
157             destOperand = "mapVar";
158         }
159 
160         if (withAddr) {
161             writer.write("destAddr := &$L", destOperand);
162             destOperand = "destAddr";
163         }
164 
165         lambda.accept(destOperand);
166 
167         if (isMap || withAddr) {
168             if (withAddr) {
169                 destOperand = "*" + destOperand;
170             }
171 
172             writer.write("$L = $L", origDestOperand, destOperand);
173         }
174     }
175 
176     /**
177      * Writes helper variables for the delegation function to ensure that map values are safely delegated down
178      * each level.
179      *
180      * @param context         generation context
181      * @param writer          go writer
182      * @param member          shape to determine if pointable
183      * @param origDestOperand original variable name
184      * @param lambda          runnable
185      */
writeSerDelegateFunction( GenerationContext context, GoWriter writer, MemberShape member, String origDestOperand, Consumer<String> lambda )186     public static void writeSerDelegateFunction(
187             GenerationContext context,
188             GoWriter writer,
189             MemberShape member,
190             String origDestOperand,
191             Consumer<String> lambda
192     ) {
193         Shape targetShape = context.getModel().expectShape(member.getTarget());
194         Shape container = context.getModel().expectShape(member.getContainer());
195 
196         boolean withAddr = !GoPointableIndex.of(context.getModel()).isPointable(member)
197                 && GoPointableIndex.of(context.getModel()).isPointable(targetShape);
198         boolean isMap = container.getType() == ShapeType.MAP;
199 
200         String destOperand = origDestOperand;
201         if (isMap && withAddr) {
202             writer.write("mapVar := $L", origDestOperand);
203             destOperand = "mapVar";
204         }
205 
206         String acceptVar = destOperand;
207         if (withAddr) {
208             acceptVar = "&" + destOperand;
209         }
210 
211         lambda.accept(acceptVar);
212     }
213 }
214