1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"log"
23	"reflect"
24	"sync"
25	"time"
26
27	"cloud.google.com/go/internal/trace"
28	"cloud.google.com/go/internal/version"
29	vkit "cloud.google.com/go/spanner/apiv1"
30	"github.com/googleapis/gax-go/v2"
31	"google.golang.org/api/option"
32	gtransport "google.golang.org/api/transport/grpc"
33	sppb "google.golang.org/genproto/googleapis/spanner/v1"
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/metadata"
36)
37
// cidGen is the process-wide generator of sessionClient IDs. Sharing a single
// generator means every sessionClient for the same database gets a distinct
// ID.
var cidGen = newClientIDGenerator()

// clientIDGenerator hands out sequential IDs of the form "client-N", counted
// separately for each database name. It is safe for concurrent use.
type clientIDGenerator struct {
	mu  sync.Mutex
	ids map[string]int
}

// newClientIDGenerator returns a generator with no IDs issued yet.
func newClientIDGenerator() *clientIDGenerator {
	return &clientIDGenerator{ids: map[string]int{}}
}

// nextID returns the next ID for database: "client-1" on the first call for a
// given database, "client-2" on the second, and so on.
func (cg *clientIDGenerator) nextID(database string) string {
	cg.mu.Lock()
	defer cg.mu.Unlock()
	// An unseen database reads as 0 from the map, so its first ID is 1.
	cg.ids[database]++
	return fmt.Sprintf("client-%d", cg.ids[database])
}
61
// sessionConsumer is passed to the batchCreateSessions method and will receive
// the sessions that are created as they become available. A sessionConsumer
// implementation must be safe for concurrent use: batchCreateSessions starts
// one executeBatchCreateSessions goroutine per gRPC channel it uses, and each
// of those goroutines calls the consumer independently.
//
// The interface is implemented by sessionPool and is used for testing the
// sessionClient.
type sessionConsumer interface {
	// sessionReady is called once for every session that has been created
	// and is ready for use. The session is already bound to the gapic client
	// (and thus the gRPC channel) whose RPC created it.
	sessionReady(s *session)

	// sessionCreationFailed is called when the creation of a sub-batch of
	// sessions failed. The numSessions argument specifies the number of
	// sessions that could not be created as a result of this error. A
	// consumer may receive multiple errors per batch (for example, one per
	// gRPC channel). It may also have received sessionReady calls for the
	// same sub-batch before the error, because a sub-batch can be served by
	// several RPCs. In this file, err is either a converted RPC or context
	// error, or a Canceled error when the sessionClient was closed while the
	// batch was in progress.
	sessionCreationFailed(err error, numSessions int32)
}
79
// sessionClient creates sessions for a database, either in batches or one at a
// time. Each session will be affiliated with a gRPC channel. sessionClient
// will ensure that the sessions that are created are evenly distributed over
// all available channels.
type sessionClient struct {
	// mu guards closed. closed is set by close; after that, createSession and
	// batchCreateSessions fail, and running batch creations report their
	// remaining sessions as failed before issuing another RPC.
	mu     sync.Mutex
	closed bool

	// connPool supplies the gRPC connections that sessions are bound to.
	// database is the database that sessions are created for. id is a
	// "client-N" identifier obtained from cidGen. sessionLabels are applied
	// to every new session. md is sent as outgoing metadata and stored on each
	// session. batchTimeout bounds each executeBatchCreateSessions call.
	// logger is handed to every created session. callOptions, if non-nil, are
	// merged over the default gapic call options of every client.
	connPool      gtransport.ConnPool
	database      string
	id            string
	sessionLabels map[string]string
	md            metadata.MD
	batchTimeout  time.Duration
	logger        *log.Logger
	callOptions   *vkit.CallOptions
}
97
98// newSessionClient creates a session client to use for a database.
99func newSessionClient(connPool gtransport.ConnPool, database string, sessionLabels map[string]string, md metadata.MD, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient {
100	return &sessionClient{
101		connPool:      connPool,
102		database:      database,
103		id:            cidGen.nextID(database),
104		sessionLabels: sessionLabels,
105		md:            md,
106		batchTimeout:  time.Minute,
107		logger:        logger,
108		callOptions:   callOptions,
109	}
110}
111
112func (sc *sessionClient) close() error {
113	sc.mu.Lock()
114	defer sc.mu.Unlock()
115	sc.closed = true
116	return sc.connPool.Close()
117}
118
119// createSession creates one session for the database of the sessionClient. The
120// session is created using one synchronous RPC.
121func (sc *sessionClient) createSession(ctx context.Context) (*session, error) {
122	sc.mu.Lock()
123	if sc.closed {
124		sc.mu.Unlock()
125		return nil, spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
126	}
127	sc.mu.Unlock()
128	client, err := sc.nextClient()
129	if err != nil {
130		return nil, err
131	}
132	ctx = contextWithOutgoingMetadata(ctx, sc.md)
133	sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{
134		Database: sc.database,
135		Session:  &sppb.Session{Labels: sc.sessionLabels},
136	})
137	if err != nil {
138		return nil, ToSpannerError(err)
139	}
140	return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil
141}
142
143// batchCreateSessions creates a batch of sessions for the database of the
144// sessionClient and returns these to the given sessionConsumer.
145//
146// createSessionCount is the number of sessions that should be created. The
147// sessionConsumer is guaranteed to receive the requested number of sessions if
148// no error occurs. If one or more errors occur, the sessionConsumer will
149// receive any number of sessions + any number of errors, where each error will
150// include the number of sessions that could not be created as a result of the
151// error. The sum of returned sessions and errored sessions will be equal to
152// the number of requested sessions.
153// If distributeOverChannels is true, the sessions will be equally distributed
154// over all the channels that are in use by the client.
155func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distributeOverChannels bool, consumer sessionConsumer) error {
156	var sessionCountPerChannel int32
157	var remainder int32
158	if distributeOverChannels {
159		// The sessions that we create should be evenly distributed over all the
160		// channels (gapic clients) that are used by the client. Each gapic client
161		// will do a request for a fraction of the total.
162		sessionCountPerChannel = createSessionCount / int32(sc.connPool.Num())
163		// The remainder of the calculation will be added to the number of sessions
164		// that will be created for the first channel, to ensure that we create the
165		// exact number of requested sessions.
166		remainder = createSessionCount % int32(sc.connPool.Num())
167	} else {
168		sessionCountPerChannel = createSessionCount
169	}
170	sc.mu.Lock()
171	defer sc.mu.Unlock()
172	if sc.closed {
173		return spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
174	}
175	// Spread the session creation over all available gRPC channels. Spanner
176	// will maintain server side caches for a session on the gRPC channel that
177	// is used by the session. A session should therefore always use the same
178	// channel, and the sessions should be as evenly distributed as possible
179	// over the channels.
180	var numBeingCreated int32
181	for i := 0; i < sc.connPool.Num() && numBeingCreated < createSessionCount; i++ {
182		client, err := sc.nextClient()
183		if err != nil {
184			return err
185		}
186		// Determine the number of sessions that should be created for this
187		// channel. The createCount for the first channel will be increased
188		// with the remainder of the division of the total number of sessions
189		// with the number of channels. All other channels will just use the
190		// result of the division over all channels.
191		createCountForChannel := sessionCountPerChannel
192		if i == 0 {
193			// We add the remainder to the first gRPC channel we use. We could
194			// also spread the remainder over all channels, but this ensures
195			// that small batches of sessions (i.e. less than numChannels) are
196			// created in one RPC.
197			createCountForChannel += remainder
198		}
199		if createCountForChannel > 0 {
200			go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer)
201			numBeingCreated += createCountForChannel
202		}
203	}
204	return nil
205}
206
207// executeBatchCreateSessions executes the gRPC call for creating a batch of
208// sessions.
209func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) {
210	ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout)
211	defer cancel()
212	ctx = contextWithOutgoingMetadata(ctx, sc.md)
213
214	ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions")
215	defer func() { trace.EndSpan(ctx, nil) }()
216	trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount)
217	remainingCreateCount := createCount
218	for {
219		sc.mu.Lock()
220		closed := sc.closed
221		sc.mu.Unlock()
222		if closed {
223			err := spannerErrorf(codes.Canceled, "Session client closed")
224			trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err)
225			consumer.sessionCreationFailed(err, remainingCreateCount)
226			break
227		}
228		if ctx.Err() != nil {
229			trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err())
230			consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), remainingCreateCount)
231			break
232		}
233		response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{
234			SessionCount:    remainingCreateCount,
235			Database:        sc.database,
236			SessionTemplate: &sppb.Session{Labels: labels},
237		})
238		if err != nil {
239			trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err)
240			consumer.sessionCreationFailed(ToSpannerError(err), remainingCreateCount)
241			break
242		}
243		actuallyCreated := int32(len(response.Session))
244		trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated)
245		for _, s := range response.Session {
246			consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md, logger: sc.logger})
247		}
248		if actuallyCreated < remainingCreateCount {
249			// Spanner could return less sessions than requested. In that case, we
250			// should do another call using the same gRPC channel.
251			remainingCreateCount -= actuallyCreated
252		} else {
253			trace.TracePrintf(ctx, nil, "Finished creating %d sessions", createCount)
254			break
255		}
256	}
257}
258
259func (sc *sessionClient) sessionWithID(id string) (*session, error) {
260	sc.mu.Lock()
261	defer sc.mu.Unlock()
262	client, err := sc.nextClient()
263	if err != nil {
264		return nil, err
265	}
266	return &session{valid: true, client: client, id: id, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil
267}
268
269// nextClient returns the next gRPC client to use for session creation. The
270// client is set on the session, and used by all subsequent gRPC calls on the
271// session. Using the same channel for all gRPC calls for a session ensures the
272// optimal usage of server side caches.
273func (sc *sessionClient) nextClient() (*vkit.Client, error) {
274	// This call should never return an error as we are passing in an existing
275	// connection, so we can safely ignore it.
276	client, err := vkit.NewClient(context.Background(), option.WithGRPCConn(sc.connPool.Conn()))
277	if err != nil {
278		return nil, err
279	}
280	client.SetGoogleClientInfo("gccl", version.Repo)
281	if sc.callOptions != nil {
282		client.CallOptions = mergeCallOptions(client.CallOptions, sc.callOptions)
283	}
284	return client, nil
285}
286
287// mergeCallOptions merges two CallOptions into one and the first argument has
288// a lower order of precedence than the second one.
289func mergeCallOptions(a *vkit.CallOptions, b *vkit.CallOptions) *vkit.CallOptions {
290	res := &vkit.CallOptions{}
291	resVal := reflect.ValueOf(res).Elem()
292	aVal := reflect.ValueOf(a).Elem()
293	bVal := reflect.ValueOf(b).Elem()
294
295	t := aVal.Type()
296
297	for i := 0; i < aVal.NumField(); i++ {
298		fieldName := t.Field(i).Name
299
300		aFieldVal := aVal.Field(i).Interface().([]gax.CallOption)
301		bFieldVal := bVal.Field(i).Interface().([]gax.CallOption)
302
303		merged := append(aFieldVal, bFieldVal...)
304		resVal.FieldByName(fieldName).Set(reflect.ValueOf(merged))
305	}
306	return res
307}
308