/*
Copyright 2017 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package spanner

import (
	"container/heap"
	"container/list"
	"context"
	"fmt"
	"log"
	"math/rand"
	"strings"
	"sync"
	"time"

	"cloud.google.com/go/internal/trace"
	sppb "google.golang.org/genproto/googleapis/spanner/v1"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
)

// sessionHandle is an interface for transactions to access Cloud Spanner
// sessions safely. It is generated by sessionPool.take().
type sessionHandle struct {
	// mu guarantees that the inner session object is returned / destroyed only
	// once.
	mu sync.Mutex
	// session is a pointer to a session object. Transactions never need to
	// access it directly.
	session *session
}
46
47// recycle gives the inner session object back to its home session pool. It is
48// safe to call recycle multiple times but only the first one would take effect.
49func (sh *sessionHandle) recycle() {
50	sh.mu.Lock()
51	defer sh.mu.Unlock()
52	if sh.session == nil {
53		// sessionHandle has already been recycled.
54		return
55	}
56	sh.session.recycle()
57	sh.session = nil
58}
59
60// getID gets the Cloud Spanner session ID from the internal session object.
61// getID returns empty string if the sessionHandle is nil or the inner session
62// object has been released by recycle / destroy.
63func (sh *sessionHandle) getID() string {
64	sh.mu.Lock()
65	defer sh.mu.Unlock()
66	if sh.session == nil {
67		// sessionHandle has already been recycled/destroyed.
68		return ""
69	}
70	return sh.session.getID()
71}
72
73// getClient gets the Cloud Spanner RPC client associated with the session ID
74// in sessionHandle.
75func (sh *sessionHandle) getClient() sppb.SpannerClient {
76	sh.mu.Lock()
77	defer sh.mu.Unlock()
78	if sh.session == nil {
79		return nil
80	}
81	return sh.session.client
82}
83
84// getMetadata returns the metadata associated with the session in sessionHandle.
85func (sh *sessionHandle) getMetadata() metadata.MD {
86	sh.mu.Lock()
87	defer sh.mu.Unlock()
88	if sh.session == nil {
89		return nil
90	}
91	return sh.session.md
92}
93
94// getTransactionID returns the transaction id in the session if available.
95func (sh *sessionHandle) getTransactionID() transactionID {
96	sh.mu.Lock()
97	defer sh.mu.Unlock()
98	if sh.session == nil {
99		return nil
100	}
101	return sh.session.tx
102}
103
104// destroy destroys the inner session object. It is safe to call destroy
105// multiple times and only the first call would attempt to
106// destroy the inner session object.
107func (sh *sessionHandle) destroy() {
108	sh.mu.Lock()
109	s := sh.session
110	sh.session = nil
111	sh.mu.Unlock()
112	if s == nil {
113		// sessionHandle has already been destroyed..
114		return
115	}
116	s.destroy(false)
117}
118
// session wraps a Cloud Spanner session ID through which transactions are
// created and executed.
type session struct {
	// client is the RPC channel to Cloud Spanner. It is set only once during
	// session's creation.
	client sppb.SpannerClient
	// id is the unique id of the session in Cloud Spanner. It is set only once
	// during session's creation.
	id string
	// pool is the session's home session pool where it was created. It is set
	// only once during session's creation.
	pool *sessionPool
	// createTime is the timestamp of the session's creation. It is set only
	// once during session's creation.
	createTime time.Time

	// mu protects the following fields from concurrent access: both
	// healthcheck workers and transactions can modify them.
	mu sync.Mutex
	// valid marks the validity of a session.
	valid bool
	// hcIndex is the index of the session inside the global healthcheck queue.
	// If hcIndex < 0, the session has been unregistered from the queue.
	hcIndex int
	// idleList is the linked list node which links the session to its home
	// session pool's idle list. If idleList == nil, the session is not in the
	// idle list.
	idleList *list.Element
	// nextCheck is the timestamp of next scheduled healthcheck of the session.
	// It is maintained by the global health checker.
	nextCheck time.Time
	// checkingHealth is true if this session is currently being processed by
	// the health checker. Must be modified under the health checker lock.
	checkingHealth bool
	// md is the Metadata to be sent with each request.
	md metadata.MD
	// tx contains the transaction id if the session has been prepared for
	// write.
	tx transactionID
}

// isValid returns true if the session is still valid for use.
func (s *session) isValid() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.valid
}

// isWritePrepared returns true if the session is prepared for write.
func (s *session) isWritePrepared() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tx != nil
}

// String implements fmt.Stringer for session.
func (s *session) String() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return fmt.Sprintf("<id=%v, hcIdx=%v, idleList=%p, valid=%v, create=%v, nextcheck=%v>",
		s.id, s.hcIndex, s.idleList, s.valid, s.createTime, s.nextCheck)
}

// ping verifies if the session is still alive in Cloud Spanner.
func (s *session) ping() error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	return runRetryable(ctx, func(ctx context.Context) error {
		// s.getID is safe even when s is invalid.
		_, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()})
		return err
	})
}

// setHcIndex atomically sets the session's index in the healthcheck queue and
// returns the old index.
func (s *session) setHcIndex(i int) int {
	s.mu.Lock()
	defer s.mu.Unlock()
	oi := s.hcIndex
	s.hcIndex = i
	return oi
}

// setIdleList atomically sets the session's idle list link and returns the old
// link.
func (s *session) setIdleList(le *list.Element) *list.Element {
	s.mu.Lock()
	defer s.mu.Unlock()
	old := s.idleList
	s.idleList = le
	return old
}

// invalidate marks a session as invalid and returns the old validity.
func (s *session) invalidate() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	ov := s.valid
	s.valid = false
	return ov
}

// setNextCheck sets the timestamp for next healthcheck on the session.
func (s *session) setNextCheck(t time.Time) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.nextCheck = t
}

// setTransactionID sets the transaction id in the session.
func (s *session) setTransactionID(tx transactionID) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tx = tx
}

// getID returns the session ID which uniquely identifies the session in Cloud
// Spanner.
func (s *session) getID() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.id
}

// getHcIndex returns the session's index into the global healthcheck priority
// queue.
func (s *session) getHcIndex() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.hcIndex
}

// getIdleList returns the session's link in its home session pool's idle list.
func (s *session) getIdleList() *list.Element {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.idleList
}

// getNextCheck returns the timestamp for next healthcheck on the session.
func (s *session) getNextCheck() time.Time {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.nextCheck
}

// recycle returns the session to its home session pool.
func (s *session) recycle() {
	s.setTransactionID(nil)
	if !s.pool.recycle(s) {
		// s is rejected by its home session pool because it expired and the
		// session pool currently has enough open sessions.
		s.destroy(false)
	}
}

// destroy removes the session from its home session pool, healthcheck queue
// and Cloud Spanner service.
func (s *session) destroy(isExpire bool) bool {
	// Remove s from session pool.
	if !s.pool.remove(s, isExpire) {
		return false
	}
	// Unregister s from healthcheck queue.
	s.pool.hc.unregister(s)
	// Remove s from Cloud Spanner service.
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	s.delete(ctx)
	return true
}

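// delete removes the session from the Cloud Spanner service by calling
// DeleteSession on the session's RPC client.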
func (s *session) delete(ctx context.Context) {
	// Ignore the error returned by runRetryable because even if we fail to
	// explicitly destroy the session, it will eventually be garbage collected
	// by Cloud Spanner.
	err := runRetryable(ctx, func(ctx context.Context) error {
		_, e := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()})
		return e
	})
	if err != nil {
		log.Printf("Failed to delete session %v. Error: %v", s.getID(), err)
	}
}

// prepareForWrite prepares the session for write if it is not already in that
// state.
func (s *session) prepareForWrite(ctx context.Context) error {
	if s.isWritePrepared() {
		return nil
	}
	tx, err := beginTransaction(ctx, s.getID(), s.client)
	if err != nil {
		return err
	}
	s.setTransactionID(tx)
	return nil
}

// SessionPoolConfig stores configurations of a session pool.
type SessionPoolConfig struct {
	// getRPCClient is the caller-supplied method for getting a gRPC client to
	// Cloud Spanner; this allows the session pool to take advantage of client
	// pooling.
	getRPCClient func() (sppb.SpannerClient, error)

	// MaxOpened is the maximum number of opened sessions allowed by the session
	// pool. If the client tries to open a session and there are already
	// MaxOpened sessions, it will block until one becomes available or the
	// context passed to the client method is canceled or times out.
	//
	// Defaults to NumChannels * 100.
	MaxOpened uint64

	// MinOpened is the minimum number of opened sessions that the session pool
	// tries to maintain. The session pool won't continue to expire sessions if
	// the number of opened connections drops below MinOpened. However, if a
	// session is found to be broken, it will still be evicted from the session
	// pool, so it is possible that the number of opened sessions drops below
	// MinOpened.
	//
	// Defaults to 0.
	MinOpened uint64

	// MaxIdle is the maximum number of idle sessions the pool is allowed to
	// keep.
	//
	// Defaults to 0.
	MaxIdle uint64

	// MaxBurst is the maximum number of concurrent session creation requests.
	//
	// Defaults to 10.
	MaxBurst uint64

	// WriteSessions is the fraction of sessions we try to keep prepared for
	// write.
	//
	// Defaults to 0.
	WriteSessions float64

	// HealthCheckWorkers is the number of workers used by the health checker
	// for this pool.
	//
	// Defaults to 10.
	HealthCheckWorkers int

	// HealthCheckInterval is how often the health checker pings a session.
	//
	// Defaults to 5m.
	HealthCheckInterval time.Duration

	// healthCheckSampleInterval is how often the health checker samples live
	// sessions (for use in maintaining the session pool size).
	//
	// Defaults to 1m.
	healthCheckSampleInterval time.Duration

	// sessionLabels for the sessions created in the session pool.
	sessionLabels map[string]string
}
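
// An illustrative configuration (a sketch only; the values are hypothetical,
// and getRPCClient is wired up internally when the client constructs the
// pool):
//
//	cfg := SessionPoolConfig{
//		MinOpened:     10,
//		MaxOpened:     400,
//		MaxIdle:       20,
//		MaxBurst:      10,
//		WriteSessions: 0.2,
//	}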

// errNoRPCGetter returns an error for a SessionPoolConfig that is missing the
// getRPCClient method.
func errNoRPCGetter() error {
	return spannerErrorf(codes.InvalidArgument, "require SessionPoolConfig.getRPCClient != nil, got nil")
}

// errMinOpenedGTMaxOpened returns an error for SessionPoolConfig.MaxOpened <
// SessionPoolConfig.MinOpened when SessionPoolConfig.MaxOpened is set.
func errMinOpenedGTMaxOpened(maxOpened, minOpened uint64) error {
	return spannerErrorf(codes.InvalidArgument,
		"require SessionPoolConfig.MaxOpened >= SessionPoolConfig.MinOpened, got %v and %v", maxOpened, minOpened)
}

// validate verifies that the SessionPoolConfig is good for use.
func (spc *SessionPoolConfig) validate() error {
	if spc.getRPCClient == nil {
		return errNoRPCGetter()
	}
	if spc.MinOpened > spc.MaxOpened && spc.MaxOpened > 0 {
		return errMinOpenedGTMaxOpened(spc.MaxOpened, spc.MinOpened)
	}
	return nil
}

// sessionPool creates and caches Cloud Spanner sessions.
type sessionPool struct {
	// mu protects sessionPool from concurrent access.
	mu sync.Mutex
	// valid marks the validity of the session pool.
	valid bool
	// db is the database name that all sessions in the pool are associated with.
	db string
	// idleList caches idle session IDs. Session IDs in this list can be
	// allocated for use.
	idleList list.List
	// idleWriteList caches idle sessions which have been prepared for write.
	idleWriteList list.List
	// mayGetSession is for broadcasting that session retrieval/creation may
	// proceed.
	mayGetSession chan struct{}
	// numOpened is the total number of open sessions from the session pool.
	numOpened uint64
	// createReqs is the number of ongoing session creation requests.
	createReqs uint64
	// prepareReqs is the number of ongoing session preparation requests.
	prepareReqs uint64
	// configuration of the session pool.
	SessionPoolConfig
	// md is the metadata to be sent with each request.
	md metadata.MD
	// hc is the health checker.
	hc *healthChecker
}
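
// Sessions are obtained from the pool via take or takeWriteSession and are
// returned to it by recycling the handle. An illustrative sketch of how the
// rest of this package uses the pool (not an exported API):
//
//	sh, err := pool.take(ctx)
//	if err != nil {
//		return err
//	}
//	defer sh.recycle()
//	client, sid := sh.getClient(), sh.getID()
//	// ... issue RPCs against the session using client and sid ...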

// newSessionPool creates a new session pool.
func newSessionPool(db string, config SessionPoolConfig, md metadata.MD) (*sessionPool, error) {
	if err := config.validate(); err != nil {
		return nil, err
	}
	pool := &sessionPool{
		db:                db,
		valid:             true,
		mayGetSession:     make(chan struct{}),
		SessionPoolConfig: config,
		md:                md,
	}
	if config.HealthCheckWorkers == 0 {
		// With 10 workers and assuming an average latency of 5ms for
		// BeginTransaction, we will be able to prepare 2000 tx/sec in advance.
		// If the rate of takeWriteSession is higher than that, it will degrade
		// to doing BeginTransaction inline.
		//
		// TODO: consider resizing the worker pool dynamically according to the load.
		config.HealthCheckWorkers = 10
	}
	if config.HealthCheckInterval == 0 {
		config.HealthCheckInterval = 5 * time.Minute
	}
	if config.healthCheckSampleInterval == 0 {
		config.healthCheckSampleInterval = time.Minute
	}
	// On a GCE VM, within the same region, a healthcheck ping takes on average
	// 10ms to finish. Given a 5 minute interval and 10 healthcheck workers, a
	// healthChecker can effectively maintain
	// 100 checks_per_worker/sec * 10 workers * 300 seconds = 300K sessions.
	pool.hc = newHealthChecker(config.HealthCheckInterval, config.HealthCheckWorkers, config.healthCheckSampleInterval, pool)
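	// The pool is now fully initialized; signal the maintainer that it can
	// start running.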
	close(pool.hc.ready)
	return pool, nil
}

// isValid checks if the session pool is still valid.
func (p *sessionPool) isValid() bool {
	if p == nil {
		return false
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.valid
}

// close marks the session pool as closed.
func (p *sessionPool) close() {
	if p == nil {
		return
	}
	p.mu.Lock()
	if !p.valid {
		p.mu.Unlock()
		return
	}
	p.valid = false
	p.mu.Unlock()
	p.hc.close()
	// Destroy all the sessions.
	p.hc.mu.Lock()
	allSessions := make([]*session, len(p.hc.queue.sessions))
	copy(allSessions, p.hc.queue.sessions)
	p.hc.mu.Unlock()
	for _, s := range allSessions {
		s.destroy(false)
	}
}

// errInvalidSessionPool returns error for using an invalid session pool.
func errInvalidSessionPool() error {
	return spannerErrorf(codes.InvalidArgument, "invalid session pool")
}

// errGetSessionTimeout returns error for context timeout during
// sessionPool.take().
func errGetSessionTimeout() error {
	return spannerErrorf(codes.Canceled, "timeout / context canceled during getting session")
}

// shouldPrepareWrite returns true if we should prepare more sessions for write.
func (p *sessionPool) shouldPrepareWrite() bool {
	return float64(p.numOpened)*p.WriteSessions > float64(p.idleWriteList.Len()+int(p.prepareReqs))
}

func (p *sessionPool) createSession(ctx context.Context) (*session, error) {
	trace.TracePrintf(ctx, nil, "Creating a new session")
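	// doneCreate gives back the creation budget taken by the caller if the
	// creation failed, decrements the number of in-flight creation requests,
	// and notifies any goroutines waiting to get a session.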
	doneCreate := func(done bool) {
		p.mu.Lock()
		if !done {
			// Session creation failed, give budget back.
			p.numOpened--
			recordStat(ctx, OpenSessionCount, int64(p.numOpened))
		}
		p.createReqs--
		// Notify other waiters blocking on session creation.
		close(p.mayGetSession)
		p.mayGetSession = make(chan struct{})
		p.mu.Unlock()
	}
	sc, err := p.getRPCClient()
	if err != nil {
		doneCreate(false)
		return nil, err
	}
	s, err := createSession(ctx, sc, p.db, p.sessionLabels, p.md)
	if err != nil {
		doneCreate(false)
		// Return the error directly; the CreateSession RPC has already been
		// retried inside createSession.
		return nil, err
	}
	s.pool = p
	p.hc.register(s)
	doneCreate(true)
	return s, nil
}

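// createSession creates a new session on the Cloud Spanner backend via the
// CreateSession RPC (with retries) and wraps the result in a session object.
// The labels are attached to the backend session; md is stored on the session
// for use with later requests.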
func createSession(ctx context.Context, sc sppb.SpannerClient, db string, labels map[string]string, md metadata.MD) (*session, error) {
	var s *session
	err := runRetryable(ctx, func(ctx context.Context) error {
		sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{
			Database: db,
			Session:  &sppb.Session{Labels: labels},
		})
		if e != nil {
			return e
		}
		// If no error, construct the new session.
		s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return s, nil
}

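// isHealthy checks whether the session is likely still alive. If the session's
// scheduled healthcheck is long overdue, it pings the session inline; a failed
// ping destroys the session and the function returns false.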
func (p *sessionPool) isHealthy(s *session) bool {
	if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
		// TODO: figure out if we need to schedule a new healthcheck worker here.
		if err := s.ping(); shouldDropSession(err) {
			// The session is already bad, continue to fetch/create a new one.
			s.destroy(false)
			return false
		}
		p.hc.scheduledHC(s)
	}
	return true
}

// take returns a cached session if one is available; otherwise it tries to
// allocate a new one. Sessions returned by take should be used for read
// operations.
func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
	trace.TracePrintf(ctx, nil, "Acquiring a read-only session")
	ctx = contextWithOutgoingMetadata(ctx, p.md)
	for {
		var (
			s   *session
			err error
		)

		p.mu.Lock()
		if !p.valid {
			p.mu.Unlock()
			return nil, errInvalidSessionPool()
		}
		if p.idleList.Len() > 0 {
			// Idle sessions are available, get one from the top of the idle
			// list.
			s = p.idleList.Remove(p.idleList.Front()).(*session)
			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
				"Acquired read-only session")
		} else if p.idleWriteList.Len() > 0 {
			s = p.idleWriteList.Remove(p.idleWriteList.Front()).(*session)
			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
				"Acquired read-write session")
		}
		if s != nil {
			s.setIdleList(nil)
			p.mu.Unlock()
			// From here on, the session is no longer in the idle list, so
			// healthcheck workers won't destroy it. If the healthcheck workers
			// failed to schedule a healthcheck for the session in time, do the
			// check here. Because a session check is still much cheaper than
			// session creation, sessions should be reused as much as possible.
			if !p.isHealthy(s) {
				continue
			}
			return &sessionHandle{session: s}, nil
		}

		// Idle list is empty, block if the session pool has reached max
		// session creation concurrency or the max number of open sessions.
		if (p.MaxOpened > 0 && p.numOpened >= p.MaxOpened) || (p.MaxBurst > 0 && p.createReqs >= p.MaxBurst) {
			mayGetSession := p.mayGetSession
			p.mu.Unlock()
			trace.TracePrintf(ctx, nil, "Waiting for read-only session to become available")
			select {
			case <-ctx.Done():
				trace.TracePrintf(ctx, nil, "Context done waiting for session")
				return nil, errGetSessionTimeout()
			case <-mayGetSession:
			}
			continue
		}

		// Take budget before the actual session creation.
		p.numOpened++
		recordStat(ctx, OpenSessionCount, int64(p.numOpened))
		p.createReqs++
		p.mu.Unlock()
		if s, err = p.createSession(ctx); err != nil {
			trace.TracePrintf(ctx, nil, "Error creating session: %v", err)
			return nil, toSpannerError(err)
		}
		trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
			"Created session")
		return &sessionHandle{session: s}, nil
	}
}

// takeWriteSession returns a write-prepared cached session if one is
// available; otherwise it tries to allocate a new one. Sessions returned by
// takeWriteSession should be used for read-write transactions.
func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, error) {
	trace.TracePrintf(ctx, nil, "Acquiring a read-write session")
	ctx = contextWithOutgoingMetadata(ctx, p.md)
	for {
		var (
			s   *session
			err error
		)

		p.mu.Lock()
		if !p.valid {
			p.mu.Unlock()
			return nil, errInvalidSessionPool()
		}
		if p.idleWriteList.Len() > 0 {
			// Idle sessions are available, get one from the top of the idle
			// list.
			s = p.idleWriteList.Remove(p.idleWriteList.Front()).(*session)
			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, "Acquired read-write session")
		} else if p.idleList.Len() > 0 {
			s = p.idleList.Remove(p.idleList.Front()).(*session)
			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, "Acquired read-only session")
		}
		if s != nil {
			s.setIdleList(nil)
			p.mu.Unlock()
			// From here on, the session is no longer in the idle list, so
			// healthcheck workers won't destroy it. If the healthcheck workers
			// failed to schedule a healthcheck for the session in time, do the
			// check here. Because a session check is still much cheaper than
			// session creation, sessions should be reused as much as possible.
			if !p.isHealthy(s) {
				continue
			}
		} else {
			// Idle list is empty, block if the session pool has reached max
			// session creation concurrency or the max number of open sessions.
			if (p.MaxOpened > 0 && p.numOpened >= p.MaxOpened) || (p.MaxBurst > 0 && p.createReqs >= p.MaxBurst) {
				mayGetSession := p.mayGetSession
				p.mu.Unlock()
				trace.TracePrintf(ctx, nil, "Waiting for read-write session to become available")
				select {
				case <-ctx.Done():
					trace.TracePrintf(ctx, nil, "Context done waiting for session")
					return nil, errGetSessionTimeout()
				case <-mayGetSession:
				}
				continue
			}

			// Take budget before the actual session creation.
			p.numOpened++
			recordStat(ctx, OpenSessionCount, int64(p.numOpened))
			p.createReqs++
			p.mu.Unlock()
			if s, err = p.createSession(ctx); err != nil {
				trace.TracePrintf(ctx, nil, "Error creating session: %v", err)
				return nil, toSpannerError(err)
			}
			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
				"Created session")
		}
		if !s.isWritePrepared() {
			if err = s.prepareForWrite(ctx); err != nil {
				s.recycle()
				trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
					"Error preparing session for write")
				return nil, toSpannerError(err)
			}
		}
		return &sessionHandle{session: s}, nil
	}
}

// recycle puts session s back into the session pool's idle list. It returns
// true if the session pool successfully recycles session s.
func (p *sessionPool) recycle(s *session) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if !s.isValid() || !p.valid {
		// Reject the session if the session is invalid or the pool itself is
		// invalid.
		return false
	}
	// Put the session at the back of the list to round robin for load
	// balancing across channels.
	if s.isWritePrepared() {
		s.setIdleList(p.idleWriteList.PushBack(s))
	} else {
		s.setIdleList(p.idleList.PushBack(s))
	}
	// Broadcast that a session has been returned to the idle list.
	close(p.mayGetSession)
	p.mayGetSession = make(chan struct{})
	return true
}

// remove atomically removes session s from the session pool and invalidates s.
// If isExpire == true, the removal is triggered by session expiration and in
// such cases, only idle sessions can be removed.
func (p *sessionPool) remove(s *session, isExpire bool) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if isExpire && (p.numOpened <= p.MinOpened || s.getIdleList() == nil) {
		// Don't expire the session if the session is not in the idle list (in
		// use), or if the number of open sessions would drop below p.MinOpened.
		return false
	}
	ol := s.setIdleList(nil)
	// If the session is in the idle list, remove it.
	if ol != nil {
		// Remove it from whichever list it is in.
		p.idleList.Remove(ol)
		p.idleWriteList.Remove(ol)
	}
	if s.invalidate() {
		// Decrease the number of opened sessions.
		p.numOpened--
		recordStat(context.Background(), OpenSessionCount, int64(p.numOpened))
		// Broadcast that a session has been destroyed.
		close(p.mayGetSession)
		p.mayGetSession = make(chan struct{})
		return true
	}
	return false
}

// hcHeap implements heap.Interface. It is used to create the priority queue
// for session healthchecks.
type hcHeap struct {
	sessions []*session
}

// Len implements heap.Interface.Len.
func (h hcHeap) Len() int {
	return len(h.sessions)
}

// Less implements heap.Interface.Less.
func (h hcHeap) Less(i, j int) bool {
	return h.sessions[i].getNextCheck().Before(h.sessions[j].getNextCheck())
}

// Swap implements heap.Interface.Swap.
func (h hcHeap) Swap(i, j int) {
	h.sessions[i], h.sessions[j] = h.sessions[j], h.sessions[i]
	h.sessions[i].setHcIndex(i)
	h.sessions[j].setHcIndex(j)
}

// Push implements heap.Interface.Push.
func (h *hcHeap) Push(s interface{}) {
	ns := s.(*session)
	ns.setHcIndex(len(h.sessions))
	h.sessions = append(h.sessions, ns)
}

// Pop implements heap.Interface.Pop.
func (h *hcHeap) Pop() interface{} {
	old := h.sessions
	n := len(old)
	s := old[n-1]
	h.sessions = old[:n-1]
	s.setHcIndex(-1)
	return s
}

// healthChecker performs periodical healthchecks on registered sessions.
type healthChecker struct {
	// mu protects concurrent access to healthChecker.
	mu sync.Mutex
	// queue is the priority queue for session healthchecks. Sessions with lower
	// nextCheck rank higher in the queue.
	queue hcHeap
	// interval is the average interval between two healthchecks on a session.
	interval time.Duration
	// workers is the number of concurrent healthcheck workers.
	workers int
	// waitWorkers waits for all healthcheck workers to exit.
	waitWorkers sync.WaitGroup
	// pool is the underlying session pool.
	pool *sessionPool
	// sampleInterval is the interval of sampling by the maintainer.
	sampleInterval time.Duration
	// ready is used to signal that the maintainer can start running.
	ready chan struct{}
	// done is used to signal that the health checker should be closed.
	done chan struct{}
	// once is used for closing channel done only once.
	once             sync.Once
	maintainerCancel func()
}

// newHealthChecker initializes a new instance of healthChecker.
func newHealthChecker(interval time.Duration, workers int, sampleInterval time.Duration, pool *sessionPool) *healthChecker {
	if workers <= 0 {
		workers = 1
	}
	hc := &healthChecker{
		interval:         interval,
		workers:          workers,
		pool:             pool,
		sampleInterval:   sampleInterval,
		ready:            make(chan struct{}),
		done:             make(chan struct{}),
		maintainerCancel: func() {},
	}
	hc.waitWorkers.Add(1)
	go hc.maintainer()
	for i := 1; i <= hc.workers; i++ {
		hc.waitWorkers.Add(1)
		go hc.worker(i)
	}
	return hc
}

// close closes the healthChecker and waits for all healthcheck workers to
// exit.
func (hc *healthChecker) close() {
	hc.mu.Lock()
	hc.maintainerCancel()
	hc.mu.Unlock()
	hc.once.Do(func() { close(hc.done) })
	hc.waitWorkers.Wait()
}

// isClosing checks if a healthChecker is already closing.
func (hc *healthChecker) isClosing() bool {
	select {
	case <-hc.done:
		return true
	default:
		return false
	}
}

// getInterval gets the healthcheck interval.
func (hc *healthChecker) getInterval() time.Duration {
	hc.mu.Lock()
	defer hc.mu.Unlock()
	return hc.interval
}

// scheduledHCLocked schedules the next healthcheck on session s with the
// assumption that hc.mu is being held.
func (hc *healthChecker) scheduledHCLocked(s *session) {
	// The next healthcheck will be scheduled after
	// [interval*0.5, interval*1.5) ns.
	nsFromNow := rand.Int63n(int64(hc.interval)) + int64(hc.interval)/2
	s.setNextCheck(time.Now().Add(time.Duration(nsFromNow)))
	if hi := s.getHcIndex(); hi != -1 {
		// Session is still being tracked by healthcheck workers.
		heap.Fix(&hc.queue, hi)
	}
}

// scheduledHC schedules the next healthcheck on session s. It is safe to call
// concurrently.
func (hc *healthChecker) scheduledHC(s *session) {
	hc.mu.Lock()
	defer hc.mu.Unlock()
	hc.scheduledHCLocked(s)
}

// register registers a session with the healthChecker for periodical
// healthchecks.
func (hc *healthChecker) register(s *session) {
	hc.mu.Lock()
	defer hc.mu.Unlock()
	hc.scheduledHCLocked(s)
	heap.Push(&hc.queue, s)
}

// unregister unregisters a session from the healthcheck queue.
func (hc *healthChecker) unregister(s *session) {
	hc.mu.Lock()
	defer hc.mu.Unlock()
	oi := s.setHcIndex(-1)
	if oi >= 0 {
		heap.Remove(&hc.queue, oi)
	}
}

// markDone marks that the health check for the session has been performed.
func (hc *healthChecker) markDone(s *session) {
	hc.mu.Lock()
	defer hc.mu.Unlock()
	s.checkingHealth = false
}

// healthCheck checks the health of the session and pings it if needed.
func (hc *healthChecker) healthCheck(s *session) {
	defer hc.markDone(s)
	if !s.pool.isValid() {
		// Session pool is closed, perform a garbage collection.
		s.destroy(false)
		return
	}
	if err := s.ping(); shouldDropSession(err) {
		// Ping failed, destroy the session.
		s.destroy(false)
	}
}

// worker performs the healthcheck on sessions in healthChecker's priority
// queue.
func (hc *healthChecker) worker(i int) {
	// Returns a session which we should ping to keep it alive.
	getNextForPing := func() *session {
		hc.pool.mu.Lock()
		defer hc.pool.mu.Unlock()
		hc.mu.Lock()
		defer hc.mu.Unlock()
		if hc.queue.Len() <= 0 {
			// Queue is empty.
			return nil
		}
		s := hc.queue.sessions[0]
		if s.getNextCheck().After(time.Now()) && hc.pool.valid {
			// All sessions have been checked recently.
			return nil
		}
		hc.scheduledHCLocked(s)
		if !s.checkingHealth {
			s.checkingHealth = true
			return s
		}
		return nil
	}

	// Returns a session which we should prepare for write.
	getNextForTx := func() *session {
		hc.pool.mu.Lock()
		defer hc.pool.mu.Unlock()
		if hc.pool.shouldPrepareWrite() {
			if hc.pool.idleList.Len() > 0 && hc.pool.valid {
				hc.mu.Lock()
				defer hc.mu.Unlock()
				if hc.pool.idleList.Front().Value.(*session).checkingHealth {
					return nil
				}
				session := hc.pool.idleList.Remove(hc.pool.idleList.Front()).(*session)
				session.checkingHealth = true
				hc.pool.prepareReqs++
				return session
			}
		}
		return nil
	}

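	// Main loop: each iteration prepares at most one session for write and
	// pings at most one session; when there is no work to do the worker sleeps
	// for a short, randomized pause.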
	for {
		if hc.isClosing() {
			// Exit when the pool has been closed and all sessions have been
			// destroyed or when the health checker has been closed.
			hc.waitWorkers.Done()
			return
		}
		ws := getNextForTx()
		if ws != nil {
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			err := ws.prepareForWrite(contextWithOutgoingMetadata(ctx, hc.pool.md))
			cancel()
			if err != nil {
				// Skip handling the prepare error; the session can be prepared
				// in the next cycle.
				log.Printf("Failed to prepare session, error: %v", toSpannerError(err))
			}
			hc.pool.recycle(ws)
			hc.pool.mu.Lock()
			hc.pool.prepareReqs--
			hc.pool.mu.Unlock()
			hc.markDone(ws)
		}
		rs := getNextForPing()
		if rs == nil {
			if ws == nil {
				// No work to be done, so sleep to avoid burning CPU.
				pause := int64(100 * time.Millisecond)
				if pause > int64(hc.interval) {
					pause = int64(hc.interval)
				}
				select {
				case <-time.After(time.Duration(rand.Int63n(pause) + pause/2)):
				case <-hc.done:
				}

			}
			continue
		}
		hc.healthCheck(rs)
	}
}

// maintainer tracks the maximum number of sessions in use over a window of
// windowSize * sampleInterval and, based on that, the health checker grows or
// shrinks the session pool towards the target size.
func (hc *healthChecker) maintainer() {
	// Wait until the pool is ready.
	<-hc.ready

	windowSize := uint64(10)

	for iteration := uint64(0); ; iteration++ {
		if hc.isClosing() {
			hc.waitWorkers.Done()
			return
		}

		// maxSessionsInUse is the maximum number of sessions in use
		// concurrently over a period of time.
		var maxSessionsInUse uint64

		// Updates metrics.
		hc.pool.mu.Lock()
		currSessionsInUse := hc.pool.numOpened - uint64(hc.pool.idleList.Len()) - uint64(hc.pool.idleWriteList.Len())
		currSessionsOpened := hc.pool.numOpened
		hc.pool.mu.Unlock()

		hc.mu.Lock()
		if iteration%windowSize == 0 || maxSessionsInUse < currSessionsInUse {
			maxSessionsInUse = currSessionsInUse
		}
		sessionsToKeep := maxUint64(hc.pool.MinOpened,
			minUint64(currSessionsOpened, hc.pool.MaxIdle+maxSessionsInUse))
		ctx, cancel := context.WithTimeout(context.Background(), hc.sampleInterval)
		hc.maintainerCancel = cancel
		hc.mu.Unlock()

		// Replenish or shrink the pool if needed.
		//
		// Note: we don't need to worry about pending create session requests;
		// we only need to sample the current sessions in use. The routines
		// below will not try to create extra sessions or delete sessions that
		// are still being created.
		if sessionsToKeep > currSessionsOpened {
			hc.replenishPool(ctx, sessionsToKeep)
		} else {
			hc.shrinkPool(ctx, sessionsToKeep)
		}

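		// Wait for the sample interval to elapse or for the health checker to
		// be closed before starting the next iteration.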
		select {
		case <-ctx.Done():
		case <-hc.done:
			cancel()
		}
	}
}

// replenishPool is run when numOpened is less than sessionsToKeep; it stops
// when the sampleInterval-scoped context expires.
func (hc *healthChecker) replenishPool(ctx context.Context, sessionsToKeep uint64) {
	for {
		if ctx.Err() != nil {
			return
		}

		p := hc.pool
		p.mu.Lock()
		// Take budget before the actual session creation.
		if sessionsToKeep <= p.numOpened {
			p.mu.Unlock()
			break
		}
		p.numOpened++
		recordStat(ctx, OpenSessionCount, int64(p.numOpened))
		p.createReqs++
		shouldPrepareWrite := p.shouldPrepareWrite()
		p.mu.Unlock()
		var (
			s   *session
			err error
		)
		if s, err = p.createSession(ctx); err != nil {
			log.Printf("Failed to create session, error: %v", toSpannerError(err))
			continue
		}
		if shouldPrepareWrite {
			if err = s.prepareForWrite(ctx); err != nil {
				p.recycle(s)
				log.Printf("Failed to prepare session, error: %v", toSpannerError(err))
				continue
			}
		}
		p.recycle(s)
	}
}


// shrinkPool scales down the session pool.
func (hc *healthChecker) shrinkPool(ctx context.Context, sessionsToKeep uint64) {
	for {
		if ctx.Err() != nil {
			return
		}

		p := hc.pool
		p.mu.Lock()

		if sessionsToKeep >= p.numOpened {
			p.mu.Unlock()
			break
		}

		var s *session
		if p.idleList.Len() > 0 {
			s = p.idleList.Front().Value.(*session)
		} else if p.idleWriteList.Len() > 0 {
			s = p.idleWriteList.Front().Value.(*session)
		}
		p.mu.Unlock()
		if s != nil {
			// Destroy the session as an expired session.
			s.destroy(true)
		} else {
			break
		}
	}
}

// shouldDropSession returns true if a particular error leads to the removal of
// a session.
func shouldDropSession(err error) bool {
	if err == nil {
		return false
	}
	// If Cloud Spanner can no longer locate the session (for example, if the
	// session is garbage collected), then the caller should not try to return
	// the session to the session pool.
	//
	// TODO: once gRPC can return auxiliary error information, stop parsing the error message.
	if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found") {
		return true
	}
	return false
}

// maxUint64 returns the maximum of two uint64.
func maxUint64(a, b uint64) uint64 {
	if a > b {
		return a
	}
	return b
}

// minUint64 returns the minimum of two uint64.
func minUint64(a, b uint64) uint64 {
	if a > b {
		return b
	}
	return a
}