1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"log"
23	"sync"
24	"time"
25
26	"cloud.google.com/go/internal/trace"
27	vkit "cloud.google.com/go/spanner/apiv1"
28	"google.golang.org/api/option"
29	gtransport "google.golang.org/api/transport/grpc"
30	sppb "google.golang.org/genproto/googleapis/spanner/v1"
31	"google.golang.org/grpc/codes"
32	"google.golang.org/grpc/metadata"
33)
34
// cidGen is the process-wide generator that newSessionClient uses to assign
// client IDs.
var cidGen = newClientIDGenerator()

// clientIDGenerator hands out client IDs that are numbered separately for
// every database name.
type clientIDGenerator struct {
	// mu protects ids.
	mu sync.Mutex

	// ids maps a database name to the most recently issued ID number.
	ids map[string]int
}

// newClientIDGenerator returns a clientIDGenerator whose ID map is already
// allocated and empty, so it is ready for use.
func newClientIDGenerator() *clientIDGenerator {
	gen := new(clientIDGenerator)
	gen.ids = map[string]int{}
	return gen
}
45
46func (cg *clientIDGenerator) nextID(database string) string {
47	cg.mu.Lock()
48	defer cg.mu.Unlock()
49	var id int
50	if val, ok := cg.ids[database]; ok {
51		id = val + 1
52	} else {
53		id = 1
54	}
55	cg.ids[database] = id
56	return fmt.Sprintf("client-%d", id)
57}
58
59// sessionConsumer is passed to the batchCreateSessions method and will receive
60// the sessions that are created as they become available. A sessionConsumer
61// implementation must be safe for concurrent use.
62//
63// The interface is implemented by sessionPool and is used for testing the
64// sessionClient.
65type sessionConsumer interface {
66	// sessionReady is called when a session has been created and is ready for
67	// use.
68	sessionReady(s *session)
69
70	// sessionCreationFailed is called when the creation of a sub-batch of
71	// sessions failed. The numSessions argument specifies the number of
72	// sessions that could not be created as a result of this error. A
73	// consumer may receive multiple errors per batch.
74	sessionCreationFailed(err error, numSessions int32)
75}
76
77// sessionClient creates sessions for a database, either in batches or one at a
78// time. Each session will be affiliated with a gRPC channel. sessionClient
79// will ensure that the sessions that are created are evenly distributed over
80// all available channels.
81type sessionClient struct {
82	mu     sync.Mutex
83	closed bool
84
85	connPool      gtransport.ConnPool
86	database      string
87	id            string
88	sessionLabels map[string]string
89	md            metadata.MD
90	batchTimeout  time.Duration
91	logger        *log.Logger
92}
93
94// newSessionClient creates a session client to use for a database.
95func newSessionClient(connPool gtransport.ConnPool, database string, sessionLabels map[string]string, md metadata.MD, logger *log.Logger) *sessionClient {
96	return &sessionClient{
97		connPool:      connPool,
98		database:      database,
99		id:            cidGen.nextID(database),
100		sessionLabels: sessionLabels,
101		md:            md,
102		batchTimeout:  time.Minute,
103		logger:        logger,
104	}
105}
106
107func (sc *sessionClient) close() error {
108	sc.mu.Lock()
109	defer sc.mu.Unlock()
110	sc.closed = true
111	return sc.connPool.Close()
112}
113
114// createSession creates one session for the database of the sessionClient. The
115// session is created using one synchronous RPC.
116func (sc *sessionClient) createSession(ctx context.Context) (*session, error) {
117	sc.mu.Lock()
118	if sc.closed {
119		sc.mu.Unlock()
120		return nil, spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
121	}
122	sc.mu.Unlock()
123	client, err := sc.nextClient()
124	if err != nil {
125		return nil, err
126	}
127	ctx = contextWithOutgoingMetadata(ctx, sc.md)
128	sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{
129		Database: sc.database,
130		Session:  &sppb.Session{Labels: sc.sessionLabels},
131	})
132	if err != nil {
133		return nil, toSpannerError(err)
134	}
135	return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil
136}
137
138// batchCreateSessions creates a batch of sessions for the database of the
139// sessionClient and returns these to the given sessionConsumer.
140//
141// createSessionCount is the number of sessions that should be created. The
142// sessionConsumer is guaranteed to receive the requested number of sessions if
143// no error occurs. If one or more errors occur, the sessionConsumer will
144// receive any number of sessions + any number of errors, where each error will
145// include the number of sessions that could not be created as a result of the
146// error. The sum of returned sessions and errored sessions will be equal to
147// the number of requested sessions.
148func (sc *sessionClient) batchCreateSessions(createSessionCount int32, consumer sessionConsumer) error {
149	// The sessions that we create should be evenly distributed over all the
150	// channels (gapic clients) that are used by the client. Each gapic client
151	// will do a request for a fraction of the total.
152	sessionCountPerChannel := createSessionCount / int32(sc.connPool.Num())
153	// The remainder of the calculation will be added to the number of sessions
154	// that will be created for the first channel, to ensure that we create the
155	// exact number of requested sessions.
156	remainder := createSessionCount % int32(sc.connPool.Num())
157	sc.mu.Lock()
158	defer sc.mu.Unlock()
159	if sc.closed {
160		return spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
161	}
162	// Spread the session creation over all available gRPC channels. Spanner
163	// will maintain server side caches for a session on the gRPC channel that
164	// is used by the session. A session should therefore always use the same
165	// channel, and the sessions should be as evenly distributed as possible
166	// over the channels.
167	for i := 0; i < sc.connPool.Num(); i++ {
168		client, err := sc.nextClient()
169		if err != nil {
170			return err
171		}
172		// Determine the number of sessions that should be created for this
173		// channel. The createCount for the first channel will be increased
174		// with the remainder of the division of the total number of sessions
175		// with the number of channels. All other channels will just use the
176		// result of the division over all channels.
177		createCountForChannel := sessionCountPerChannel
178		if i == 0 {
179			// We add the remainder to the first gRPC channel we use. We could
180			// also spread the remainder over all channels, but this ensures
181			// that small batches of sessions (i.e. less than numChannels) are
182			// created in one RPC.
183			createCountForChannel += remainder
184		}
185		if createCountForChannel > 0 {
186			go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer)
187		}
188	}
189	return nil
190}
191
192// executeBatchCreateSessions executes the gRPC call for creating a batch of
193// sessions.
194func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) {
195	ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout)
196	defer cancel()
197	ctx = contextWithOutgoingMetadata(ctx, sc.md)
198
199	ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions")
200	defer func() { trace.EndSpan(ctx, nil) }()
201	trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount)
202	remainingCreateCount := createCount
203	for {
204		sc.mu.Lock()
205		closed := sc.closed
206		sc.mu.Unlock()
207		if closed {
208			err := spannerErrorf(codes.Canceled, "Session client closed")
209			trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err)
210			consumer.sessionCreationFailed(err, remainingCreateCount)
211			break
212		}
213		if ctx.Err() != nil {
214			trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err())
215			consumer.sessionCreationFailed(toSpannerError(ctx.Err()), remainingCreateCount)
216			break
217		}
218		response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{
219			SessionCount:    remainingCreateCount,
220			Database:        sc.database,
221			SessionTemplate: &sppb.Session{Labels: labels},
222		})
223		if err != nil {
224			trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err)
225			consumer.sessionCreationFailed(toSpannerError(err), remainingCreateCount)
226			break
227		}
228		actuallyCreated := int32(len(response.Session))
229		trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated)
230		for _, s := range response.Session {
231			consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md, logger: sc.logger})
232		}
233		if actuallyCreated < remainingCreateCount {
234			// Spanner could return less sessions than requested. In that case, we
235			// should do another call using the same gRPC channel.
236			remainingCreateCount -= actuallyCreated
237		} else {
238			trace.TracePrintf(ctx, nil, "Finished creating %d sessions", createCount)
239			break
240		}
241	}
242}
243
244func (sc *sessionClient) sessionWithID(id string) (*session, error) {
245	sc.mu.Lock()
246	defer sc.mu.Unlock()
247	client, err := sc.nextClient()
248	if err != nil {
249		return nil, err
250	}
251	return &session{valid: true, client: client, id: id, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil
252}
253
254// nextClient returns the next gRPC client to use for session creation. The
255// client is set on the session, and used by all subsequent gRPC calls on the
256// session. Using the same channel for all gRPC calls for a session ensures the
257// optimal usage of server side caches.
258func (sc *sessionClient) nextClient() (*vkit.Client, error) {
259	// This call should never return an error as we are passing in an existing
260	// connection, so we can safely ignore it.
261	client, err := vkit.NewClient(context.Background(), option.WithGRPCConn(sc.connPool.Conn()))
262	if err != nil {
263		return nil, err
264	}
265	return client, nil
266}
267