1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"log"
23	"sync"
24	"time"
25
26	"cloud.google.com/go/internal/trace"
27	vkit "cloud.google.com/go/spanner/apiv1"
28	sppb "google.golang.org/genproto/googleapis/spanner/v1"
29	"google.golang.org/grpc/codes"
30	"google.golang.org/grpc/metadata"
31)
32
33// sessionConsumer is passed to the batchCreateSessions method and will receive
34// the sessions that are created as they become available. A sessionConsumer
35// implementation must be safe for concurrent use.
36//
37// The interface is implemented by sessionPool and is used for testing the
38// sessionClient.
39type sessionConsumer interface {
40	// sessionReady is called when a session has been created and is ready for
41	// use.
42	sessionReady(s *session)
43
44	// sessionCreationFailed is called when the creation of a sub-batch of
45	// sessions failed. The numSessions argument specifies the number of
46	// sessions that could not be created as a result of this error. A
47	// consumer may receive multiple errors per batch.
48	sessionCreationFailed(err error, numSessions int32)
49}
50
// sessionClient creates sessions for a database, either in batches or one at a
// time. Each session will be affiliated with a gRPC channel. sessionClient
// will ensure that the sessions that are created are evenly distributed over
// all available channels.
type sessionClient struct {
	mu     sync.Mutex // guards rr and closed
	rr     int        // index of the client last returned by rrNextGapicClientLocked
	closed bool       // set by close; checked before any session is created

	gapicClients  []*vkit.Client    // clients that sessions are spread over round-robin
	database      string            // database that sessions are created for
	sessionLabels map[string]string // labels set on every created session
	md            metadata.MD       // outgoing RPC metadata, also stored on sessions
	batchTimeout  time.Duration     // timeout for one executeBatchCreateSessions run
	logger        *log.Logger       // passed on to every created session
}
67
68// newSessionClient creates a session client to use for a database.
69func newSessionClient(gapicClients []*vkit.Client, database string, sessionLabels map[string]string, md metadata.MD, logger *log.Logger) *sessionClient {
70	return &sessionClient{
71		gapicClients:  gapicClients,
72		database:      database,
73		sessionLabels: sessionLabels,
74		md:            md,
75		batchTimeout:  time.Minute,
76		logger:        logger,
77	}
78}
79
80func (sc *sessionClient) close() error {
81	sc.mu.Lock()
82	defer sc.mu.Unlock()
83	sc.closed = true
84	var errs []error
85	for _, gpc := range sc.gapicClients {
86		if err := gpc.Close(); err != nil {
87			errs = append(errs, err)
88		}
89	}
90	switch len(errs) {
91	case 0:
92		return nil
93	case 1:
94		return errs[0]
95	default:
96		return fmt.Errorf("closing gapic clients returned multiple errors: %v", errs)
97	}
98}
99
100// createSession creates one session for the database of the sessionClient. The
101// session is created using one synchronous RPC.
102func (sc *sessionClient) createSession(ctx context.Context) (*session, error) {
103	ctx = contextWithOutgoingMetadata(ctx, sc.md)
104	sc.mu.Lock()
105	if sc.closed {
106		return nil, spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
107	}
108	client := sc.rrNextGapicClientLocked()
109	sc.mu.Unlock()
110	sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{
111		Database: sc.database,
112		Session:  &sppb.Session{Labels: sc.sessionLabels},
113	})
114	if err != nil {
115		return nil, toSpannerError(err)
116	}
117	return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil
118}
119
120// batchCreateSessions creates a batch of sessions for the database of the
121// sessionClient and returns these to the given sessionConsumer.
122//
123// createSessionCount is the number of sessions that should be created. The
124// sessionConsumer is guaranteed to receive the requested number of sessions if
125// no error occurs. If one or more errors occur, the sessionConsumer will
126// receive any number of sessions + any number of errors, where each error will
127// include the number of sessions that could not be created as a result of the
128// error. The sum of returned sessions and errored sessions will be equal to
129// the number of requested sessions.
130func (sc *sessionClient) batchCreateSessions(createSessionCount int32, consumer sessionConsumer) error {
131	// The sessions that we create should be evenly distributed over all the
132	// channels (gapic clients) that are used by the client. Each gapic client
133	// will do a request for a fraction of the total.
134	sessionCountPerChannel := createSessionCount / int32(len(sc.gapicClients))
135	// The remainder of the calculation will be added to the number of sessions
136	// that will be created for the first channel, to ensure that we create the
137	// exact number of requested sessions.
138	remainder := createSessionCount % int32(len(sc.gapicClients))
139	sc.mu.Lock()
140	defer sc.mu.Unlock()
141	if sc.closed {
142		return spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
143	}
144	// Spread the session creation over all available gRPC channels. Spanner
145	// will maintain server side caches for a session on the gRPC channel that
146	// is used by the session. A session should therefore always use the same
147	// channel, and the sessions should be as evenly distributed as possible
148	// over the channels.
149	for i := 0; i < len(sc.gapicClients); i++ {
150		client := sc.rrNextGapicClientLocked()
151		// Determine the number of sessions that should be created for this
152		// channel. The createCount for the first channel will be increased
153		// with the remainder of the division of the total number of sessions
154		// with the number of channels. All other channels will just use the
155		// result of the division over all channels.
156		createCountForChannel := sessionCountPerChannel
157		if i == 0 {
158			// We add the remainder to the first gRPC channel we use. We could
159			// also spread the remainder over all channels, but this ensures
160			// that small batches of sessions (i.e. less than numChannels) are
161			// created in one RPC.
162			createCountForChannel += remainder
163		}
164		if createCountForChannel > 0 {
165			go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer)
166		}
167	}
168	return nil
169}
170
171// executeBatchCreateSessions executes the gRPC call for creating a batch of
172// sessions.
173func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) {
174	ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout)
175	defer cancel()
176	ctx = contextWithOutgoingMetadata(ctx, sc.md)
177
178	ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions")
179	defer func() { trace.EndSpan(ctx, nil) }()
180	trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount)
181	remainingCreateCount := createCount
182	for {
183		sc.mu.Lock()
184		closed := sc.closed
185		sc.mu.Unlock()
186		if closed {
187			err := spannerErrorf(codes.Canceled, "Session client closed")
188			trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err)
189			consumer.sessionCreationFailed(err, remainingCreateCount)
190			break
191		}
192		if ctx.Err() != nil {
193			trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err())
194			consumer.sessionCreationFailed(toSpannerError(ctx.Err()), remainingCreateCount)
195			break
196		}
197		response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{
198			SessionCount:    remainingCreateCount,
199			Database:        sc.database,
200			SessionTemplate: &sppb.Session{Labels: labels},
201		})
202		if err != nil {
203			trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err)
204			consumer.sessionCreationFailed(toSpannerError(err), remainingCreateCount)
205			break
206		}
207		actuallyCreated := int32(len(response.Session))
208		trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated)
209		for _, s := range response.Session {
210			consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md, logger: sc.logger})
211		}
212		if actuallyCreated < remainingCreateCount {
213			// Spanner could return less sessions than requested. In that case, we
214			// should do another call using the same gRPC channel.
215			remainingCreateCount -= actuallyCreated
216		} else {
217			trace.TracePrintf(ctx, nil, "Finished creating %d sessions", createCount)
218			break
219		}
220	}
221}
222
223func (sc *sessionClient) sessionWithID(id string) *session {
224	sc.mu.Lock()
225	defer sc.mu.Unlock()
226	return &session{valid: true, client: sc.rrNextGapicClientLocked(), id: id, createTime: time.Now(), md: sc.md, logger: sc.logger}
227}
228
229// rrNextGapicClientLocked returns the next gRPC client to use for session creation. The
230// client is set on the session, and used by all subsequent gRPC calls on the
231// session. Using the same channel for all gRPC calls for a session ensures the
232// optimal usage of server side caches.
233func (sc *sessionClient) rrNextGapicClientLocked() *vkit.Client {
234	sc.rr = (sc.rr + 1) % len(sc.gapicClients)
235	return sc.gapicClients[sc.rr]
236}
237