1 /*
2  * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.smithy.go.codegen.integration;
17 
18 import java.util.Map;
19 import java.util.Optional;
20 import software.amazon.smithy.codegen.core.CodegenException;
21 import software.amazon.smithy.codegen.core.Symbol;
22 import software.amazon.smithy.codegen.core.SymbolProvider;
23 import software.amazon.smithy.go.codegen.GoDelegator;
24 import software.amazon.smithy.go.codegen.GoSettings;
25 import software.amazon.smithy.go.codegen.GoWriter;
26 import software.amazon.smithy.go.codegen.SmithyGoDependency;
27 import software.amazon.smithy.go.codegen.SymbolUtils;
28 import software.amazon.smithy.model.Model;
29 import software.amazon.smithy.model.knowledge.TopDownIndex;
30 import software.amazon.smithy.model.shapes.ListShape;
31 import software.amazon.smithy.model.shapes.MemberShape;
32 import software.amazon.smithy.model.shapes.OperationShape;
33 import software.amazon.smithy.model.shapes.ServiceShape;
34 import software.amazon.smithy.model.shapes.Shape;
35 import software.amazon.smithy.model.shapes.ShapeId;
36 import software.amazon.smithy.model.shapes.SimpleShape;
37 import software.amazon.smithy.model.shapes.StructureShape;
38 import software.amazon.smithy.utils.StringUtils;
39 import software.amazon.smithy.waiters.Acceptor;
40 import software.amazon.smithy.waiters.Matcher;
41 import software.amazon.smithy.waiters.PathComparator;
42 import software.amazon.smithy.waiters.WaitableTrait;
43 import software.amazon.smithy.waiters.Waiter;
44 
45 /**
46  * Implements support for WaitableTrait.
47  */
48 public class Waiters implements GoIntegration {
    /**
     * Name of the Go method generated on every waiter client that runs the wait loop; it is emitted as
     * the method name of the invoker function written by {@code generateWaiterInvoker}.
     */
    private static final String WAITER_INVOKER_FUNCTION_NAME = "Wait";
50 
51     @Override
writeAdditionalFiles( GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator )52     public void writeAdditionalFiles(
53             GoSettings settings,
54             Model model,
55             SymbolProvider symbolProvider,
56             GoDelegator goDelegator
57     ) {
58         ServiceShape serviceShape = settings.getService(model);
59         TopDownIndex topDownIndex = TopDownIndex.of(model);
60 
61         topDownIndex.getContainedOperations(serviceShape).stream()
62                 .forEach(operation -> {
63                     if (!operation.hasTrait(WaitableTrait.ID)) {
64                         return;
65                     }
66 
67                     Map<String, Waiter> waiters = operation.expectTrait(WaitableTrait.class).getWaiters();
68 
69                     goDelegator.useShapeWriter(operation, writer -> {
70                         generateOperationWaiter(model, symbolProvider, writer, operation, waiters);
71                     });
72                 });
73     }
74 
75 
76     /**
77      * Generates all waiter components used for the operation.
78      */
generateOperationWaiter( Model model, SymbolProvider symbolProvider, GoWriter writer, OperationShape operation, Map<String, Waiter> waiters )79     private void generateOperationWaiter(
80             Model model,
81             SymbolProvider symbolProvider,
82             GoWriter writer,
83             OperationShape operation,
84             Map<String, Waiter> waiters
85     ) {
86         // generate waiter function
87         waiters.forEach((name, waiter) -> {
88             // write waiter options
89             generateWaiterOptions(model, symbolProvider, writer, operation, name, waiter);
90 
91             // write waiter client
92             generateWaiterClient(model, symbolProvider, writer, operation, name, waiter);
93 
94             // write waiter specific invoker
95             generateWaiterInvoker(model, symbolProvider, writer, operation, name, waiter);
96 
97             // write waiter state mutator for each waiter
98             generateRetryable(model, symbolProvider, writer, operation, name, waiter);
99 
100         });
101     }
102 
103     /**
104      * Generates waiter options to configure a waiter client.
105      */
generateWaiterOptions( Model model, SymbolProvider symbolProvider, GoWriter writer, OperationShape operationShape, String waiterName, Waiter waiter )106     private void generateWaiterOptions(
107             Model model,
108             SymbolProvider symbolProvider,
109             GoWriter writer,
110             OperationShape operationShape,
111             String waiterName,
112             Waiter waiter
113     ) {
114         String optionsName = generateWaiterOptionsName(waiterName);
115         String waiterClientName = generateWaiterClientName(waiterName);
116 
117         StructureShape inputShape = model.expectShape(
118                 operationShape.getInput().get(), StructureShape.class
119         );
120         StructureShape outputShape = model.expectShape(
121                 operationShape.getOutput().get(), StructureShape.class
122         );
123 
124         Symbol inputSymbol = symbolProvider.toSymbol(inputShape);
125         Symbol outputSymbol = symbolProvider.toSymbol(outputShape);
126 
127         writer.write("");
128         writer.writeDocs(
129                 String.format("%s are waiter options for %s", optionsName, waiterClientName)
130         );
131 
132         writer.openBlock("type $L struct {", "}",
133                 optionsName, () -> {
134                     writer.addUseImports(SmithyGoDependency.TIME);
135 
136                     writer.write("");
137                     writer.writeDocs(
138                             "Set of options to modify how an operation is invoked. These apply to all operations "
139                                     + "invoked for this client. Use functional options on operation call to modify "
140                                     + "this list for per operation behavior."
141                     );
142                     Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack",
143                             SmithyGoDependency.SMITHY_MIDDLEWARE)
144                             .build();
145                     writer.write("APIOptions []func($P) error", stackSymbol);
146 
147                     writer.write("");
148                     writer.writeDocs(
149                             String.format("MinDelay is the minimum amount of time to delay between retries. "
150                                     + "If unset, %s will use default minimum delay of %s seconds. "
151                                     + "Note that MinDelay must resolve to a value lesser than or equal "
152                                     + "to the MaxDelay.", waiterClientName, waiter.getMinDelay())
153                     );
154                     writer.write("MinDelay time.Duration");
155 
156                     writer.write("");
157                     writer.writeDocs(
158                             String.format("MaxDelay is the maximum amount of time to delay between retries. "
159                                     + "If unset or set to zero, %s will use default max delay of %s seconds. "
160                                     + "Note that MaxDelay must resolve to value greater than or equal "
161                                     + "to the MinDelay.", waiterClientName, waiter.getMaxDelay())
162                     );
163                     writer.write("MaxDelay time.Duration");
164 
165                     writer.write("");
166                     writer.writeDocs("LogWaitAttempts is used to enable logging for waiter retry attempts");
167                     writer.write("LogWaitAttempts bool");
168 
169                     writer.write("");
170                     writer.writeDocs(
171                             "Retryable is function that can be used to override the "
172                                     + "service defined waiter-behavior based on operation output, or returned error. "
173                                     + "This function is used by the waiter to decide if a state is retryable "
174                                     + "or a terminal state.\n\nBy default service-modeled logic "
175                                     + "will populate this option. This option can thus be used to define a custom "
176                                     + "waiter state with fall-back to service-modeled waiter state mutators."
177                                     + "The function returns an error in case of a failure state. "
178                                     + "In case of retry state, this function returns a bool value of true and "
179                                     + "nil error, while in case of success it returns a bool value of false and "
180                                     + "nil error."
181                     );
182                     writer.write(
183                             "Retryable func(context.Context, $P, $P, error) "
184                                     + "(bool, error)", inputSymbol, outputSymbol);
185                 }
186         );
187         writer.write("");
188     }
189 
190 
191     /**
192      * Generates waiter client used to invoke waiter function. The waiter client is specific to a modeled waiter.
193      * Each waiter client is unique within a enclosure of a service.
194      * This function also generates a waiter client constructor that takes in a API client interface, and waiter options
195      * to configure a waiter client.
196      */
generateWaiterClient( Model model, SymbolProvider symbolProvider, GoWriter writer, OperationShape operationShape, String waiterName, Waiter waiter )197     private void generateWaiterClient(
198             Model model,
199             SymbolProvider symbolProvider,
200             GoWriter writer,
201             OperationShape operationShape,
202             String waiterName,
203             Waiter waiter
204     ) {
205         Symbol operationSymbol = symbolProvider.toSymbol(operationShape);
206         String clientName = generateWaiterClientName(waiterName);
207 
208         writer.write("");
209         writer.writeDocs(
210                 String.format("%s defines the waiters for %s", clientName, waiterName)
211         );
212         writer.openBlock("type $L struct {", "}",
213                 clientName, () -> {
214                     writer.write("");
215                     writer.write("client $L", OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol));
216 
217                     writer.write("");
218                     writer.write("options $L", generateWaiterOptionsName(waiterName));
219                 });
220 
221         writer.write("");
222 
223         String constructorName = String.format("New%s", clientName);
224 
225         Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder(
226                 generateWaiterOptionsName(waiterName)
227         ).build();
228 
229         Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder(
230                 clientName
231         ).build();
232 
233         writer.writeDocs(
234                 String.format("%s constructs a %s.", constructorName, clientName)
235         );
236         writer.openBlock("func $L(client $L, optFns ...func($P)) $P {", "}",
237                 constructorName, OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol),
238                 waiterOptionsSymbol, clientSymbol, () -> {
239                     writer.write("options := $T{}", waiterOptionsSymbol);
240                     writer.addUseImports(SmithyGoDependency.TIME);
241 
242                     // set defaults
243                     writer.write("options.MinDelay = $L * time.Second", waiter.getMinDelay());
244                     writer.write("options.MaxDelay = $L * time.Second", waiter.getMaxDelay());
245                     writer.write("options.Retryable = $L", generateRetryableName(waiterName));
246                     writer.write("");
247 
248                     writer.openBlock("for _, fn := range optFns {",
249                             "}", () -> {
250                                 writer.write("fn(&options)");
251                             });
252 
253                     writer.openBlock("return &$T {", "}", clientSymbol, () -> {
254                         writer.write("client: client, ");
255                         writer.write("options: options, ");
256                     });
257                 });
258     }
259 
260     /**
261      * Generates waiter invoker functions to call specific operation waiters
262      * These waiter invoker functions is defined on each modeled waiter client.
263      * The invoker function takes in a context, along with operation input, and
264      * optional functional options for the waiter.
265      */
generateWaiterInvoker( Model model, SymbolProvider symbolProvider, GoWriter writer, OperationShape operationShape, String waiterName, Waiter waiter )266     private void generateWaiterInvoker(
267             Model model,
268             SymbolProvider symbolProvider,
269             GoWriter writer,
270             OperationShape operationShape,
271             String waiterName,
272             Waiter waiter
273     ) {
274         StructureShape inputShape = model.expectShape(
275                 operationShape.getInput().get(), StructureShape.class
276         );
277 
278         Symbol operationSymbol = symbolProvider.toSymbol(operationShape);
279         Symbol inputSymbol = symbolProvider.toSymbol(inputShape);
280 
281         Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder(
282                 generateWaiterOptionsName(waiterName)
283         ).build();
284 
285         Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder(
286                 generateWaiterClientName(waiterName)
287         ).build();
288 
289         writer.write("");
290         writer.addUseImports(SmithyGoDependency.CONTEXT);
291         writer.addUseImports(SmithyGoDependency.TIME);
292         writer.writeDocs(
293                 String.format(
294                         "%s calls the waiter function for %s waiter. The maxWaitDur is the maximum wait duration "
295                                 + "the waiter will wait. The maxWaitDur is required and must be greater than zero.",
296                         WAITER_INVOKER_FUNCTION_NAME, waiterName)
297         );
298         writer.openBlock(
299                 "func (w $P) $L(ctx context.Context, params $P, maxWaitDur time.Duration, optFns ...func($P)) error {",
300                 "}",
301                 clientSymbol, WAITER_INVOKER_FUNCTION_NAME, inputSymbol, waiterOptionsSymbol,
302                 () -> {
303                     writer.openBlock("if maxWaitDur <= 0 {", "}", () -> {
304                         writer.addUseImports(SmithyGoDependency.FMT);
305                         writer.write("return fmt.Errorf(\"maximum wait time for waiter must be greater than zero\")");
306                     }).write("");
307 
308                     writer.write("options := w.options");
309 
310                     writer.openBlock("for _, fn := range optFns {",
311                             "}", () -> {
312                                 writer.write("fn(&options)");
313                             });
314                     writer.write("");
315 
316                     // validate values for MaxDelay from options
317                     writer.openBlock("if options.MaxDelay <= 0 {", "}", () -> {
318                         writer.write("options.MaxDelay = $L * time.Second", waiter.getMaxDelay());
319                     });
320                     writer.write("");
321 
322                     // validate that MinDelay is lesser than or equal to resolved MaxDelay
323                     writer.openBlock("if options.MinDelay > options.MaxDelay {", "}", () -> {
324                         writer.addUseImports(SmithyGoDependency.FMT);
325                         writer.write("return fmt.Errorf(\"minimum waiter delay %v must be lesser than or equal to "
326                                 + "maximum waiter delay of %v.\", options.MinDelay, options.MaxDelay)");
327                     }).write("");
328 
329                     writer.addUseImports(SmithyGoDependency.CONTEXT);
330                     writer.write("ctx, cancelFn := context.WithTimeout(ctx, maxWaitDur)");
331                     writer.write("defer cancelFn()");
332                     writer.write("");
333 
334                     Symbol loggerMiddleware = SymbolUtils.createValueSymbolBuilder(
335                             "Logger", SmithyGoDependency.SMITHY_WAITERS
336                     ).build();
337                     writer.write("logger := $T{}", loggerMiddleware);
338                     writer.write("remainingTime := maxWaitDur").write("");
339 
340                     writer.write("var attempt int64");
341                     writer.openBlock("for {", "}", () -> {
342                         writer.write("");
343                         writer.write("attempt++");
344 
345                         writer.write("apiOptions := options.APIOptions");
346                         writer.write("start := time.Now()").write("");
347 
348                         // add waiter logger middleware to log an attempt, if LogWaitAttempts is enabled.
349                         writer.openBlock("if options.LogWaitAttempts {", "}", () -> {
350                             writer.write("logger.Attempt = attempt");
351                             writer.write(
352                                     "apiOptions = append([]func(*middleware.Stack) error{}, options.APIOptions...)");
353                             writer.write("apiOptions = append(apiOptions, logger.AddLogger)");
354                         }).write("");
355 
356                         // make a request
357                         writer.openBlock("out, err := w.client.$T(ctx, params, func (o *Options) { ", "})",
358                                 operationSymbol, () -> {
359                                     writer.write("o.APIOptions = append(o.APIOptions, apiOptions...)");
360                                 });
361                         writer.write("");
362 
363                         // handle response and identify waiter state
364                         writer.write("retryable, err := options.Retryable(ctx, params, out, err)");
365                         writer.write("if err != nil { return err }");
366                         writer.write("if !retryable { return nil }").write("");
367 
368                         // update remaining time
369                         writer.write("remainingTime -= time.Since(start)");
370 
371                         // check if next iteration is possible
372                         writer.openBlock("if remainingTime < options.MinDelay || remainingTime <= 0 {", "}", () -> {
373                             writer.write("break");
374                         });
375                         writer.write("");
376 
377                         // handle retry delay computation, sleep.
378                         Symbol computeDelaySymbol = SymbolUtils.createValueSymbolBuilder(
379                                 "ComputeDelay", SmithyGoDependency.SMITHY_WAITERS
380                         ).build();
381                         writer.writeDocs("compute exponential backoff between waiter retries");
382                         writer.openBlock("delay, err := $T(", ")", computeDelaySymbol, () -> {
383                             writer.write("attempt, options.MinDelay, options.MaxDelay, remainingTime,");
384                         });
385 
386                         writer.addUseImports(SmithyGoDependency.FMT);
387                         writer.write(
388                                 "if err != nil { return fmt.Errorf(\"error computing waiter delay, %w\", err)}");
389                         writer.write("");
390 
391                         // update remaining time as per computed delay
392                         writer.write("remainingTime -= delay");
393 
394                         // sleep for delay
395                         Symbol sleepWithContextSymbol = SymbolUtils.createValueSymbolBuilder(
396                                 "SleepWithContext", SmithyGoDependency.SMITHY_TIME
397                         ).build();
398                         writer.writeDocs("sleep for the delay amount before invoking a request");
399                         writer.openBlock("if err := $T(ctx, delay); err != nil {", "}", sleepWithContextSymbol,
400                                 () -> {
401                                     writer.write(
402                                             "return fmt.Errorf(\"request cancelled while waiting, %w\", err)");
403                                 });
404                     });
405                     writer.write("return fmt.Errorf(\"exceeded max wait time for $L waiter\")", waiterName);
406                 });
407     }
408 
409     /**
410      * Generates a waiter state mutator function which is used by the waiter retrier Middleware to mutate
411      * waiter state as per the defined logic and returned operation response.
412      *
413      * @param model          the smithy model
414      * @param symbolProvider symbol provider
415      * @param writer         the Gowriter
416      * @param operationShape operation shape on which the waiter is modeled
417      * @param waiterName     the waiter name
418      * @param waiter         the waiter structure that contains info on modeled waiter
419      */
generateRetryable( Model model, SymbolProvider symbolProvider, GoWriter writer, OperationShape operationShape, String waiterName, Waiter waiter )420     private void generateRetryable(
421             Model model,
422             SymbolProvider symbolProvider,
423             GoWriter writer,
424             OperationShape operationShape,
425             String waiterName,
426             Waiter waiter
427     ) {
428         StructureShape inputShape = model.expectShape(
429                 operationShape.getInput().get(), StructureShape.class
430         );
431         StructureShape outputShape = model.expectShape(
432                 operationShape.getOutput().get(), StructureShape.class
433         );
434 
435         Symbol inputSymbol = symbolProvider.toSymbol(inputShape);
436         Symbol outputSymbol = symbolProvider.toSymbol(outputShape);
437 
438         writer.write("");
439         writer.openBlock("func $L(ctx context.Context, input $P, output $P, err error) (bool, error) {",
440                 "}", generateRetryableName(waiterName), inputSymbol, outputSymbol, () -> {
441                     waiter.getAcceptors().forEach(acceptor -> {
442                         writer.write("");
443                         // scope each acceptor to avoid name collisions
444                         Matcher matcher = acceptor.getMatcher();
445                         switch (matcher.getMemberName()) {
446                             case "output":
447                                 writer.addUseImports(SmithyGoDependency.GO_JMESPATH);
448                                 writer.addUseImports(SmithyGoDependency.FMT);
449 
450                                 Matcher.OutputMember outputMember = (Matcher.OutputMember) matcher;
451                                 String path = outputMember.getValue().getPath();
452                                 String expectedValue = outputMember.getValue().getExpected();
453                                 PathComparator comparator = outputMember.getValue().getComparator();
454                                 writer.openBlock("if err == nil {", "}", () -> {
455                                     writer.write("pathValue, err :=  jmespath.Search($S, output)", path);
456                                     writer.openBlock("if err != nil {", "}", () -> {
457                                         writer.write(
458                                                 "return false, "
459                                                         + "fmt.Errorf(\"error evaluating waiter state: %w\", err)");
460                                     }).write("");
461                                     writer.write("expectedValue := $S", expectedValue);
462 
463                                     if (comparator == PathComparator.BOOLEAN_EQUALS) {
464                                         writeWaiterComparator(writer, acceptor, comparator, null, "pathValue",
465                                                 "expectedValue");
466                                     } else {
467                                         String[] pathMembers = path.split("\\.");
468                                         Shape targetShape = outputShape;
469                                         for (int i = 0; i < pathMembers.length; i++) {
470                                             MemberShape member = getComparedMember(model, targetShape, pathMembers[i]);
471                                             if (member == null) {
472                                                 targetShape = null;
473                                                 break;
474                                             }
475                                             targetShape = model.expectShape(member.getTarget());
476                                         }
477 
478                                         if (targetShape == null) {
479                                             writeWaiterComparator(writer, acceptor, comparator, null, "pathValue",
480                                                     "expectedValue");
481                                         } else {
482                                             Symbol targetSymbol = symbolProvider.toSymbol(targetShape);
483                                             writeWaiterComparator(writer, acceptor, comparator, targetSymbol,
484                                                     "pathValue",
485                                                     "expectedValue");
486                                         }
487                                     }
488                                 });
489                                 break;
490 
491                             case "inputOutput":
492                                 writer.addUseImports(SmithyGoDependency.GO_JMESPATH);
493                                 writer.addUseImports(SmithyGoDependency.FMT);
494 
495                                 Matcher.InputOutputMember ioMember = (Matcher.InputOutputMember) matcher;
496                                 path = ioMember.getValue().getPath();
497                                 expectedValue = ioMember.getValue().getExpected();
498                                 comparator = ioMember.getValue().getComparator();
499                                 writer.openBlock("if err == nil {", "}", () -> {
500                                     writer.openBlock("pathValue, err :=  jmespath.Search($S, &struct{",
501                                             "})", path, () -> {
502                                                 writer.write("Input $P \n Output $P \n }{", inputSymbol,
503                                                         outputSymbol);
504                                                 writer.write("Input: input, \n Output: output, \n");
505                                             });
506                                     writer.openBlock("if err != nil {", "}", () -> {
507                                         writer.write(
508                                                 "return false, "
509                                                         + "fmt.Errorf(\"error evaluating waiter state: %w\", err)");
510                                     });
511                                     writer.write("");
512                                     writer.write("expectedValue := $S", expectedValue);
513                                     writeWaiterComparator(writer, acceptor, comparator, outputSymbol, "pathValue",
514                                             "expectedValue");
515                                 });
516                                 break;
517 
518                             case "success":
519                                 Matcher.SuccessMember successMember = (Matcher.SuccessMember) matcher;
520                                 writer.openBlock("if err == nil {", "}",
521                                         () -> {
522                                             writeMatchedAcceptorReturn(writer, acceptor);
523                                         });
524                                 break;
525 
526                             case "errorType":
527                                 Matcher.ErrorTypeMember errorTypeMember = (Matcher.ErrorTypeMember) matcher;
528                                 String errorType = errorTypeMember.getValue();
529 
530                                 writer.openBlock("if err != nil {", "}", () -> {
531 
532                                     // identify if this is a modeled error shape
533                                     Optional<ShapeId> errorShapeId = operationShape.getErrors().stream().filter(
534                                             shapeId -> {
535                                                 return shapeId.getName().equalsIgnoreCase(errorType);
536                                             }).findFirst();
537 
538                                     // if modeled error shape
539                                     if (errorShapeId.isPresent()) {
540                                         Shape errorShape = model.expectShape(errorShapeId.get());
541                                         Symbol modeledErrorSymbol = symbolProvider.toSymbol(errorShape);
542                                         writer.addUseImports(SmithyGoDependency.ERRORS);
543                                         writer.write("var errorType *$T", modeledErrorSymbol);
544                                         writer.openBlock("if errors.As(err, &errorType) {", "}", () -> {
545                                             writeMatchedAcceptorReturn(writer, acceptor);
546                                         });
547                                     } else {
548                                         // fall back to un-modeled error shape matching
549                                         writer.addUseImports(SmithyGoDependency.SMITHY);
550                                         writer.addUseImports(SmithyGoDependency.ERRORS);
551 
552                                         // assert unmodeled error to smithy's API error
553                                         writer.write("var apiErr smithy.APIError");
554                                         writer.write("ok := errors.As(err, &apiErr)");
555                                         writer.openBlock("if !ok {", "}", () -> {
556                                             writer.write("return false, "
557                                                     + "fmt.Errorf(\"expected err to be of type smithy.APIError, "
558                                                     + "got %w\", err)");
559                                         });
560                                         writer.write("");
561 
562                                         writer.openBlock("if $S == apiErr.ErrorCode() {", "}",
563                                                 errorType, () -> {
564                                                     writeMatchedAcceptorReturn(writer, acceptor);
565                                                 });
566                                     }
567                                 });
568                                 break;
569 
570                             default:
571                                 throw new CodegenException(
572                                         String.format("unknown waiter state : %v", matcher.getMemberName())
573                                 );
574                         }
575                     });
576 
577                     writer.write("");
578                     writer.write("return true, nil");
579                 });
580     }
581 
582     /**
583      * writes comparators for a given waiter. The comparators are defined within the waiter acceptor.
584      *
585      * @param writer       the Gowriter
586      * @param acceptor     the waiter acceptor that defines the comparator and acceptor states
587      * @param comparator   the comparator
588      * @param targetSymbol the shape symbol of the compared type.
589      * @param actual       the variable carrying the actual value obtained.
590      *                     This may be computed via a jmespath expression or operation response status (success/failure)
591      * @param expected     the variable carrying the expected value. This value is as per the modeled waiter.
592      */
writeWaiterComparator( GoWriter writer, Acceptor acceptor, PathComparator comparator, Symbol targetSymbol, String actual, String expected )593     private void writeWaiterComparator(
594             GoWriter writer,
595             Acceptor acceptor,
596             PathComparator comparator,
597             Symbol targetSymbol,
598             String actual,
599             String expected
600     ) {
601         if (targetSymbol == null) {
602             targetSymbol = SymbolUtils.createValueSymbolBuilder("string").build();
603         }
604 
605         String valueAccessor = "string(value)";
606         Optional<Boolean> isPointable = targetSymbol.getProperty(SymbolUtils.POINTABLE, Boolean.class);
607         if (isPointable.isPresent() && isPointable.get().booleanValue()) {
608             valueAccessor = "string(*value)";
609         }
610 
611         switch (comparator) {
612             case STRING_EQUALS:
613                 writer.write("value, ok := $L.($P)", actual, targetSymbol);
614                 writer.write("if !ok {");
615                 writer.write("return false, fmt.Errorf(\"waiter comparator expected $P value, got %T\", $L)}",
616                         targetSymbol, actual);
617                 writer.write("");
618 
619                 writer.openBlock("if $L == $L {", "}", valueAccessor, expected, () -> {
620                     writeMatchedAcceptorReturn(writer, acceptor);
621                 });
622                 break;
623 
624             case BOOLEAN_EQUALS:
625                 writer.addUseImports(SmithyGoDependency.STRCONV);
626                 writer.write("bv, err := strconv.ParseBool($L)", expected);
627                 writer.write(
628                         "if err != nil { return false, "
629                                 + "fmt.Errorf(\"error parsing boolean from string %w\", err)}");
630 
631                 writer.write("value, ok := $L.(bool)", actual);
632                 writer.openBlock(" if !ok {", "}", () -> {
633                     writer.write("return false, "
634                             + "fmt.Errorf(\"waiter comparator expected bool value got %T\", $L)", actual);
635                 });
636                 writer.write("");
637 
638                 writer.openBlock("if value == bv {", "}", () -> {
639                     writeMatchedAcceptorReturn(writer, acceptor);
640                 });
641                 break;
642 
643             case ALL_STRING_EQUALS:
644                 writer.write("var match = true");
645                 writer.write("listOfValues, ok := $L.([]interface{})", actual);
646                 writer.openBlock(" if !ok {", "}", () -> {
647                     writer.write("return false, "
648                             + "fmt.Errorf(\"waiter comparator expected list got %T\", $L)", actual);
649                 });
650                 writer.write("");
651 
652                 writer.write("if len(listOfValues) == 0 { match = false }");
653 
654                 String allStringValueAccessor = valueAccessor;
655                 Symbol allStringTargetSymbol = targetSymbol;
656                 writer.openBlock("for _, v := range listOfValues {", "}", () -> {
657                     writer.write("value, ok := v.($P)", allStringTargetSymbol);
658                     writer.write("if !ok {");
659                     writer.write("return false, fmt.Errorf(\"waiter comparator expected $P value, got %T\", $L)}",
660                             allStringTargetSymbol, actual);
661                     writer.write("");
662                     writer.write("if $L != $L { match = false }", allStringValueAccessor, expected);
663                 });
664                 writer.write("");
665 
666                 writer.openBlock("if match {", "}", () -> {
667                     writeMatchedAcceptorReturn(writer, acceptor);
668                 });
669                 break;
670 
671             case ANY_STRING_EQUALS:
672                 writer.write("listOfValues, ok := $L.([]interface{})", actual);
673                 writer.openBlock(" if !ok {", "}", () -> {
674                     writer.write("return false, "
675                             + "fmt.Errorf(\"waiter comparator expected list got %T\", $L)", actual);
676                 });
677                 writer.write("");
678 
679                 String anyStringValueAccessor = valueAccessor;
680                 Symbol anyStringTargetSymbol = targetSymbol;
681                 writer.openBlock("for _, v := range listOfValues {", "}", () -> {
682                     writer.write("value, ok := v.($P)", anyStringTargetSymbol);
683                     writer.write("if !ok {");
684                     writer.write("return false, fmt.Errorf(\"waiter comparator expected $P value, got %T\", $L)}",
685                             anyStringTargetSymbol, actual);
686                     writer.write("");
687                     writer.openBlock("if $L == $L {", "}", anyStringValueAccessor, expected, () -> {
688                         writeMatchedAcceptorReturn(writer, acceptor);
689                     });
690                 });
691                 break;
692 
693             default:
694                 throw new CodegenException(
695                         String.format("Found unknown waiter path comparator, %s", comparator.toString()));
696         }
697     }
698 
699 
700     /**
701      * Writes return statement for state where a waiter's acceptor state is a match.
702      *
703      * @param writer   the Go writer
704      * @param acceptor the waiter acceptor who's state is used to write an appropriate return statement.
705      */
writeMatchedAcceptorReturn(GoWriter writer, Acceptor acceptor)706     private void writeMatchedAcceptorReturn(GoWriter writer, Acceptor acceptor) {
707         switch (acceptor.getState()) {
708             case SUCCESS:
709                 writer.write("return false, nil");
710                 break;
711 
712             case FAILURE:
713                 writer.addUseImports(SmithyGoDependency.FMT);
714                 writer.write("return false, fmt.Errorf(\"waiter state transitioned to Failure\")");
715                 break;
716 
717             case RETRY:
718                 writer.write("return true, nil");
719                 break;
720 
721             default:
722                 throw new CodegenException("unknown acceptor state defined for the waiter");
723         }
724     }
725 
generateWaiterOptionsName( String waiterName )726     private String generateWaiterOptionsName(
727             String waiterName
728     ) {
729         waiterName = StringUtils.capitalize(waiterName);
730         return String.format("%sWaiterOptions", waiterName);
731     }
732 
generateWaiterClientName( String waiterName )733     private String generateWaiterClientName(
734             String waiterName
735     ) {
736         waiterName = StringUtils.capitalize(waiterName);
737         return String.format("%sWaiter", waiterName);
738     }
739 
generateRetryableName( String waiterName )740     private String generateRetryableName(
741             String waiterName
742     ) {
743         waiterName = StringUtils.uncapitalize(waiterName);
744         return String.format("%sStateRetryable", waiterName);
745     }
746 
747 
748     /**
749      * Returns the MemberShape wrt to the provided Shape and name.
750      * For eg, If shape `A` has MemberShape `B`, and the name provided is `B` as string.
751      * We return the MemberShape `B`.
752      *
753      * @param model the generation model.
754      * @param shape the shape that is walked to retreive the shape matching provided name.
755      * @param name  name is a single scope path string, and should only match to one or less shapes.
756      * @return MemberShape matching the name.
757      */
getComparedMember(Model model, Shape shape, String name)758     private MemberShape getComparedMember(Model model, Shape shape, String name) {
759 
760         name = name.replaceAll("\\[\\]", "");
761 
762         // if shape is a simple shape, just return shape as member shape
763         if (shape instanceof SimpleShape) {
764             return shape.asMemberShape().get();
765         }
766 
767         switch (shape.getType()) {
768             case STRUCTURE:
769                 StructureShape st = shape.asStructureShape().get();
770                 for (MemberShape memberShape : st.getAllMembers().values()) {
771                     if (name.equalsIgnoreCase(memberShape.getMemberName())) {
772                         return memberShape;
773                     }
774                 }
775                 break;
776 
777             case LIST:
778                 ListShape listShape = shape.asListShape().get();
779                 MemberShape listMember = listShape.getMember();
780                 Shape listTarget = model.expectShape(listMember.getTarget());
781                 return getComparedMember(model, listTarget, name);
782 
783             default:
784                 // TODO: add support for * usage with jmespath expression.
785                 return null;
786         }
787 
788         // TODO: add support for * usage with jmespath expression.
789         // return null if no shape type matched (this would happen in case of * usage with jmespath expression).
790         return null;
791     }
792 }
793