1/* 2Copyright 2019 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spanner 18 19import ( 20 "context" 21 "fmt" 22 "log" 23 "reflect" 24 "sync" 25 "time" 26 27 "cloud.google.com/go/internal/trace" 28 "cloud.google.com/go/internal/version" 29 vkit "cloud.google.com/go/spanner/apiv1" 30 "github.com/googleapis/gax-go/v2" 31 "google.golang.org/api/option" 32 gtransport "google.golang.org/api/transport/grpc" 33 sppb "google.golang.org/genproto/googleapis/spanner/v1" 34 "google.golang.org/grpc/codes" 35 "google.golang.org/grpc/metadata" 36) 37 38var cidGen = newClientIDGenerator() 39 40type clientIDGenerator struct { 41 mu sync.Mutex 42 ids map[string]int 43} 44 45func newClientIDGenerator() *clientIDGenerator { 46 return &clientIDGenerator{ids: make(map[string]int)} 47} 48 49func (cg *clientIDGenerator) nextID(database string) string { 50 cg.mu.Lock() 51 defer cg.mu.Unlock() 52 var id int 53 if val, ok := cg.ids[database]; ok { 54 id = val + 1 55 } else { 56 id = 1 57 } 58 cg.ids[database] = id 59 return fmt.Sprintf("client-%d", id) 60} 61 62// sessionConsumer is passed to the batchCreateSessions method and will receive 63// the sessions that are created as they become available. A sessionConsumer 64// implementation must be safe for concurrent use. 65// 66// The interface is implemented by sessionPool and is used for testing the 67// sessionClient. 68type sessionConsumer interface { 69 // sessionReady is called when a session has been created and is ready for 70 // use. 71 sessionReady(s *session) 72 73 // sessionCreationFailed is called when the creation of a sub-batch of 74 // sessions failed. The numSessions argument specifies the number of 75 // sessions that could not be created as a result of this error. A 76 // consumer may receive multiple errors per batch. 77 sessionCreationFailed(err error, numSessions int32) 78} 79 80// sessionClient creates sessions for a database, either in batches or one at a 81// time. Each session will be affiliated with a gRPC channel. sessionClient 82// will ensure that the sessions that are created are evenly distributed over 83// all available channels. 84type sessionClient struct { 85 mu sync.Mutex 86 closed bool 87 88 connPool gtransport.ConnPool 89 database string 90 id string 91 sessionLabels map[string]string 92 md metadata.MD 93 batchTimeout time.Duration 94 logger *log.Logger 95 callOptions *vkit.CallOptions 96} 97 98// newSessionClient creates a session client to use for a database. 99func newSessionClient(connPool gtransport.ConnPool, database string, sessionLabels map[string]string, md metadata.MD, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient { 100 return &sessionClient{ 101 connPool: connPool, 102 database: database, 103 id: cidGen.nextID(database), 104 sessionLabels: sessionLabels, 105 md: md, 106 batchTimeout: time.Minute, 107 logger: logger, 108 callOptions: callOptions, 109 } 110} 111 112func (sc *sessionClient) close() error { 113 sc.mu.Lock() 114 defer sc.mu.Unlock() 115 sc.closed = true 116 return sc.connPool.Close() 117} 118 119// createSession creates one session for the database of the sessionClient. The 120// session is created using one synchronous RPC. 121func (sc *sessionClient) createSession(ctx context.Context) (*session, error) { 122 sc.mu.Lock() 123 if sc.closed { 124 sc.mu.Unlock() 125 return nil, spannerErrorf(codes.FailedPrecondition, "SessionClient is closed") 126 } 127 sc.mu.Unlock() 128 client, err := sc.nextClient() 129 if err != nil { 130 return nil, err 131 } 132 ctx = contextWithOutgoingMetadata(ctx, sc.md) 133 sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{ 134 Database: sc.database, 135 Session: &sppb.Session{Labels: sc.sessionLabels}, 136 }) 137 if err != nil { 138 return nil, ToSpannerError(err) 139 } 140 return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil 141} 142 143// batchCreateSessions creates a batch of sessions for the database of the 144// sessionClient and returns these to the given sessionConsumer. 145// 146// createSessionCount is the number of sessions that should be created. The 147// sessionConsumer is guaranteed to receive the requested number of sessions if 148// no error occurs. If one or more errors occur, the sessionConsumer will 149// receive any number of sessions + any number of errors, where each error will 150// include the number of sessions that could not be created as a result of the 151// error. The sum of returned sessions and errored sessions will be equal to 152// the number of requested sessions. 153// If distributeOverChannels is true, the sessions will be equally distributed 154// over all the channels that are in use by the client. 155func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distributeOverChannels bool, consumer sessionConsumer) error { 156 var sessionCountPerChannel int32 157 var remainder int32 158 if distributeOverChannels { 159 // The sessions that we create should be evenly distributed over all the 160 // channels (gapic clients) that are used by the client. Each gapic client 161 // will do a request for a fraction of the total. 162 sessionCountPerChannel = createSessionCount / int32(sc.connPool.Num()) 163 // The remainder of the calculation will be added to the number of sessions 164 // that will be created for the first channel, to ensure that we create the 165 // exact number of requested sessions. 166 remainder = createSessionCount % int32(sc.connPool.Num()) 167 } else { 168 sessionCountPerChannel = createSessionCount 169 } 170 sc.mu.Lock() 171 defer sc.mu.Unlock() 172 if sc.closed { 173 return spannerErrorf(codes.FailedPrecondition, "SessionClient is closed") 174 } 175 // Spread the session creation over all available gRPC channels. Spanner 176 // will maintain server side caches for a session on the gRPC channel that 177 // is used by the session. A session should therefore always use the same 178 // channel, and the sessions should be as evenly distributed as possible 179 // over the channels. 180 var numBeingCreated int32 181 for i := 0; i < sc.connPool.Num() && numBeingCreated < createSessionCount; i++ { 182 client, err := sc.nextClient() 183 if err != nil { 184 return err 185 } 186 // Determine the number of sessions that should be created for this 187 // channel. The createCount for the first channel will be increased 188 // with the remainder of the division of the total number of sessions 189 // with the number of channels. All other channels will just use the 190 // result of the division over all channels. 191 createCountForChannel := sessionCountPerChannel 192 if i == 0 { 193 // We add the remainder to the first gRPC channel we use. We could 194 // also spread the remainder over all channels, but this ensures 195 // that small batches of sessions (i.e. less than numChannels) are 196 // created in one RPC. 197 createCountForChannel += remainder 198 } 199 if createCountForChannel > 0 { 200 go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer) 201 numBeingCreated += createCountForChannel 202 } 203 } 204 return nil 205} 206 207// executeBatchCreateSessions executes the gRPC call for creating a batch of 208// sessions. 209func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) { 210 ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout) 211 defer cancel() 212 ctx = contextWithOutgoingMetadata(ctx, sc.md) 213 214 ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions") 215 defer func() { trace.EndSpan(ctx, nil) }() 216 trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount) 217 remainingCreateCount := createCount 218 for { 219 sc.mu.Lock() 220 closed := sc.closed 221 sc.mu.Unlock() 222 if closed { 223 err := spannerErrorf(codes.Canceled, "Session client closed") 224 trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err) 225 consumer.sessionCreationFailed(err, remainingCreateCount) 226 break 227 } 228 if ctx.Err() != nil { 229 trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err()) 230 consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), remainingCreateCount) 231 break 232 } 233 response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{ 234 SessionCount: remainingCreateCount, 235 Database: sc.database, 236 SessionTemplate: &sppb.Session{Labels: labels}, 237 }) 238 if err != nil { 239 trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err) 240 consumer.sessionCreationFailed(ToSpannerError(err), remainingCreateCount) 241 break 242 } 243 actuallyCreated := int32(len(response.Session)) 244 trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated) 245 for _, s := range response.Session { 246 consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md, logger: sc.logger}) 247 } 248 if actuallyCreated < remainingCreateCount { 249 // Spanner could return less sessions than requested. In that case, we 250 // should do another call using the same gRPC channel. 251 remainingCreateCount -= actuallyCreated 252 } else { 253 trace.TracePrintf(ctx, nil, "Finished creating %d sessions", createCount) 254 break 255 } 256 } 257} 258 259func (sc *sessionClient) sessionWithID(id string) (*session, error) { 260 sc.mu.Lock() 261 defer sc.mu.Unlock() 262 client, err := sc.nextClient() 263 if err != nil { 264 return nil, err 265 } 266 return &session{valid: true, client: client, id: id, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil 267} 268 269// nextClient returns the next gRPC client to use for session creation. The 270// client is set on the session, and used by all subsequent gRPC calls on the 271// session. Using the same channel for all gRPC calls for a session ensures the 272// optimal usage of server side caches. 273func (sc *sessionClient) nextClient() (*vkit.Client, error) { 274 // This call should never return an error as we are passing in an existing 275 // connection, so we can safely ignore it. 276 client, err := vkit.NewClient(context.Background(), option.WithGRPCConn(sc.connPool.Conn())) 277 if err != nil { 278 return nil, err 279 } 280 client.SetGoogleClientInfo("gccl", version.Repo) 281 if sc.callOptions != nil { 282 client.CallOptions = mergeCallOptions(client.CallOptions, sc.callOptions) 283 } 284 return client, nil 285} 286 287// mergeCallOptions merges two CallOptions into one and the first argument has 288// a lower order of precedence than the second one. 289func mergeCallOptions(a *vkit.CallOptions, b *vkit.CallOptions) *vkit.CallOptions { 290 res := &vkit.CallOptions{} 291 resVal := reflect.ValueOf(res).Elem() 292 aVal := reflect.ValueOf(a).Elem() 293 bVal := reflect.ValueOf(b).Elem() 294 295 t := aVal.Type() 296 297 for i := 0; i < aVal.NumField(); i++ { 298 fieldName := t.Field(i).Name 299 300 aFieldVal := aVal.Field(i).Interface().([]gax.CallOption) 301 bFieldVal := bVal.Field(i).Interface().([]gax.CallOption) 302 303 merged := append(aFieldVal, bFieldVal...) 304 resVal.FieldByName(fieldName).Set(reflect.ValueOf(merged)) 305 } 306 return res 307} 308