1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"time"
22
23	"cloud.google.com/go/internal/trace"
24	"github.com/golang/protobuf/ptypes"
25	"github.com/googleapis/gax-go/v2"
26	"google.golang.org/genproto/googleapis/rpc/errdetails"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/status"
29)
30
// retryInfoKey is a gRPC metadata key; the "-bin" suffix presumably marks it
// as binary metadata carrying a serialized google.rpc.RetryInfo (confirm
// against its users, which are outside this part of the file).
const retryInfoKey = "google.rpc.retryinfo-bin"
34
35// DefaultRetryBackoff is used for retryers as a fallback value when the server
36// did not return any retry information.
37var DefaultRetryBackoff = gax.Backoff{
38	Initial:    20 * time.Millisecond,
39	Max:        32 * time.Second,
40	Multiplier: 1.3,
41}
42
43// spannerRetryer extends the generic gax Retryer, but also checks for any
44// retry info returned by Cloud Spanner and uses that if present.
45type spannerRetryer struct {
46	gax.Retryer
47}
48
49// onCodes returns a spannerRetryer that will retry on the specified error
50// codes.
51func onCodes(bo gax.Backoff, cc ...codes.Code) gax.Retryer {
52	return &spannerRetryer{
53		Retryer: gax.OnCodes(cc, bo),
54	}
55}
56
57// Retry returns the retry delay returned by Cloud Spanner if that is present.
58// Otherwise it returns the retry delay calculated by the generic gax Retryer.
59func (r *spannerRetryer) Retry(err error) (time.Duration, bool) {
60	delay, shouldRetry := r.Retryer.Retry(err)
61	if !shouldRetry {
62		return 0, false
63	}
64	if serverDelay, hasServerDelay := extractRetryDelay(err); hasServerDelay {
65		delay = serverDelay
66	}
67	return delay, true
68}
69
70// runWithRetryOnAbortedOrSessionNotFound executes the given function and
71// retries it if it returns an Aborted or Session not found error. The retry
72// is delayed if the error was Aborted. The delay between retries is the delay
73// returned by Cloud Spanner, or if none is returned, the calculated delay with
74// a minimum of 10ms and maximum of 32s. There is no delay before the retry if
75// the error was Session not found.
76func runWithRetryOnAbortedOrSessionNotFound(ctx context.Context, f func(context.Context) error) error {
77	retryer := onCodes(DefaultRetryBackoff, codes.Aborted)
78	funcWithRetry := func(ctx context.Context) error {
79		for {
80			err := f(ctx)
81			if err == nil {
82				return nil
83			}
84			// Get Spanner or GRPC status error.
85			// TODO(loite): Refactor to unwrap Status error instead of Spanner
86			// error when statusError implements the (errors|xerrors).Wrapper
87			// interface.
88			var retryErr error
89			var se *Error
90			if errorAs(err, &se) {
91				// It is a (wrapped) Spanner error. Use that to check whether
92				// we should retry.
93				retryErr = se
94			} else {
95				// It's not a Spanner error, check if it is a status error.
96				_, ok := status.FromError(err)
97				if !ok {
98					return err
99				}
100				retryErr = err
101			}
102			if isSessionNotFoundError(retryErr) {
103				trace.TracePrintf(ctx, nil, "Retrying after Session not found")
104				continue
105			}
106			delay, shouldRetry := retryer.Retry(retryErr)
107			if !shouldRetry {
108				return err
109			}
110			trace.TracePrintf(ctx, nil, "Backing off after ABORTED for %s, then retrying", delay)
111			if err := gax.Sleep(ctx, delay); err != nil {
112				return err
113			}
114		}
115	}
116	return funcWithRetry(ctx)
117}
118
119// extractRetryDelay extracts retry backoff from a *spanner.Error if present.
120func extractRetryDelay(err error) (time.Duration, bool) {
121	var se *Error
122	var s *status.Status
123	// Unwrap status error.
124	if errorAs(err, &se) {
125		s = status.Convert(se.Unwrap())
126	} else {
127		s = status.Convert(err)
128	}
129	if s == nil {
130		return 0, false
131	}
132	for _, detail := range s.Details() {
133		if retryInfo, ok := detail.(*errdetails.RetryInfo); ok {
134			delay, err := ptypes.Duration(retryInfo.RetryDelay)
135			if err != nil {
136				return 0, false
137			}
138			return delay, true
139		}
140	}
141	return 0, false
142}
143