1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11	"errors"
12	"fmt"
13	"reflect"
14	"strconv"
15	"time"
16
17	"go.mongodb.org/mongo-driver/bson"
18	"go.mongodb.org/mongo-driver/bson/bsoncodec"
19	"go.mongodb.org/mongo-driver/bson/primitive"
20	"go.mongodb.org/mongo-driver/mongo/description"
21	"go.mongodb.org/mongo-driver/mongo/options"
22	"go.mongodb.org/mongo-driver/mongo/readconcern"
23	"go.mongodb.org/mongo-driver/mongo/readpref"
24	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
25	"go.mongodb.org/mongo-driver/x/mongo/driver"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
27	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
28)
29
// Sentinel errors surfaced by ChangeStream.
var (
	// ErrMissingResumeToken indicates that a change stream notification from the server did not contain a resume token.
	ErrMissingResumeToken = errors.New("cannot provide resume functionality when the resume token is missing")

	// ErrNilCursor indicates that the underlying cursor for the change stream is nil.
	ErrNilCursor = errors.New("cursor is nil")
)

// minResumableLabelWireVersion is the wire version at which the server includes the resumable error label.
var minResumableLabelWireVersion int32 = 9

// errorCursorNotFound is the CursorNotFound error code.
var errorCursorNotFound int32 = 43

// Error labels consulted when deciding whether a change stream error is resumable.
var (
	networkErrorLabel   = "NetworkError"
	resumableErrorLabel = "ResumableChangeStreamError"
)

// resumableChangeStreamErrors is the set of error codes that are considered resumable, listed in numeric order.
var resumableChangeStreamErrors = map[int32]struct{}{
	6:     {}, // HostUnreachable
	7:     {}, // HostNotFound
	63:    {}, // StaleShardVersion
	89:    {}, // NetworkTimeout
	91:    {}, // ShutdownInProgress
	133:   {}, // FailedToSatisfyReadPreference
	150:   {}, // StaleEpoch
	189:   {}, // PrimarySteppedDown
	234:   {}, // RetryChangeStream
	262:   {}, // ExceededTimeLimit
	9001:  {}, // SocketException
	10107: {}, // NotMaster
	11600: {}, // InterruptedAtShutdown
	11602: {}, // InterruptedDueToReplStateChange
	13388: {}, // StaleConfig
	13435: {}, // NotMasterNoSlaveOk
	13436: {}, // NotMasterOrSecondary
}
62
63// ChangeStream is used to iterate over a stream of events. Each event can be decoded into a Go type via the Decode
64// method or accessed as raw BSON via the Current field. For more information about change streams, see
65// https://docs.mongodb.com/manual/changeStreams/.
66type ChangeStream struct {
67	// Current is the BSON bytes of the current event. This property is only valid until the next call to Next or
68	// TryNext. If continued access is required, a copy must be made.
69	Current bson.Raw
70
71	aggregate     *operation.Aggregate
72	pipelineSlice []bsoncore.Document
73	cursor        changeStreamCursor
74	cursorOptions driver.CursorOptions
75	batch         []bsoncore.Document
76	resumeToken   bson.Raw
77	err           error
78	sess          *session.Client
79	client        *Client
80	registry      *bsoncodec.Registry
81	streamType    StreamType
82	options       *options.ChangeStreamOptions
83	selector      description.ServerSelector
84	operationTime *primitive.Timestamp
85	wireVersion   *description.VersionRange
86}
87
88type changeStreamConfig struct {
89	readConcern    *readconcern.ReadConcern
90	readPreference *readpref.ReadPref
91	client         *Client
92	registry       *bsoncodec.Registry
93	streamType     StreamType
94	collectionName string
95	databaseName   string
96	crypt          *driver.Crypt
97}
98
99func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{},
100	opts ...*options.ChangeStreamOptions) (*ChangeStream, error) {
101	if ctx == nil {
102		ctx = context.Background()
103	}
104
105	cs := &ChangeStream{
106		client:        config.client,
107		registry:      config.registry,
108		streamType:    config.streamType,
109		options:       options.MergeChangeStreamOptions(opts...),
110		selector:      description.ReadPrefSelector(config.readPreference),
111		cursorOptions: config.client.createBaseCursorOptions(),
112	}
113
114	cs.sess = sessionFromContext(ctx)
115	if cs.sess == nil && cs.client.sessionPool != nil {
116		cs.sess, cs.err = session.NewClientSession(cs.client.sessionPool, cs.client.id, session.Implicit)
117		if cs.err != nil {
118			return nil, cs.Err()
119		}
120	}
121	if cs.err = cs.client.validSession(cs.sess); cs.err != nil {
122		closeImplicitSession(cs.sess)
123		return nil, cs.Err()
124	}
125
126	cs.aggregate = operation.NewAggregate(nil).
127		ReadPreference(config.readPreference).ReadConcern(config.readConcern).
128		Deployment(cs.client.deployment).ClusterClock(cs.client.clock).
129		CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone).
130		Crypt(config.crypt)
131
132	if cs.options.Collation != nil {
133		cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument()))
134	}
135	if cs.options.BatchSize != nil {
136		cs.aggregate.BatchSize(*cs.options.BatchSize)
137		cs.cursorOptions.BatchSize = *cs.options.BatchSize
138	}
139	if cs.options.MaxAwaitTime != nil {
140		cs.cursorOptions.MaxTimeMS = int64(time.Duration(*cs.options.MaxAwaitTime) / time.Millisecond)
141	}
142
143	switch cs.streamType {
144	case ClientStream:
145		cs.aggregate.Database("admin")
146	case DatabaseStream:
147		cs.aggregate.Database(config.databaseName)
148	case CollectionStream:
149		cs.aggregate.Collection(config.collectionName).Database(config.databaseName)
150	default:
151		closeImplicitSession(cs.sess)
152		return nil, fmt.Errorf("must supply a valid StreamType in config, instead of %v", cs.streamType)
153	}
154
155	// When starting a change stream, cache startAfter as the first resume token if it is set. If not, cache
156	// resumeAfter. If neither is set, do not cache a resume token.
157	resumeToken := cs.options.StartAfter
158	if resumeToken == nil {
159		resumeToken = cs.options.ResumeAfter
160	}
161	var marshaledToken bson.Raw
162	if resumeToken != nil {
163		if marshaledToken, cs.err = bson.Marshal(resumeToken); cs.err != nil {
164			closeImplicitSession(cs.sess)
165			return nil, cs.Err()
166		}
167	}
168	cs.resumeToken = marshaledToken
169
170	if cs.err = cs.buildPipelineSlice(pipeline); cs.err != nil {
171		closeImplicitSession(cs.sess)
172		return nil, cs.Err()
173	}
174	var pipelineArr bsoncore.Document
175	pipelineArr, cs.err = cs.pipelineToBSON()
176	cs.aggregate.Pipeline(pipelineArr)
177
178	if cs.err = cs.executeOperation(ctx, false); cs.err != nil {
179		closeImplicitSession(cs.sess)
180		return nil, cs.Err()
181	}
182
183	return cs, cs.Err()
184}
185
186func (cs *ChangeStream) createOperationDeployment(server driver.Server, connection driver.Connection) driver.Deployment {
187	return &changeStreamDeployment{
188		topologyKind: cs.client.deployment.Kind(),
189		server:       server,
190		conn:         connection,
191	}
192}
193
194func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error {
195	var server driver.Server
196	var conn driver.Connection
197	var err error
198
199	if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil {
200		return cs.Err()
201	}
202	if conn, cs.err = server.Connection(ctx); cs.err != nil {
203		return cs.Err()
204	}
205	defer conn.Close()
206	cs.wireVersion = conn.Description().WireVersion
207
208	cs.aggregate.Deployment(cs.createOperationDeployment(server, conn))
209
210	if resuming {
211		cs.replaceOptions(ctx, cs.wireVersion)
212
213		csOptDoc := cs.createPipelineOptionsDoc()
214		pipIdx, pipDoc := bsoncore.AppendDocumentStart(nil)
215		pipDoc = bsoncore.AppendDocumentElement(pipDoc, "$changeStream", csOptDoc)
216		if pipDoc, cs.err = bsoncore.AppendDocumentEnd(pipDoc, pipIdx); cs.err != nil {
217			return cs.Err()
218		}
219		cs.pipelineSlice[0] = pipDoc
220
221		var plArr bsoncore.Document
222		if plArr, cs.err = cs.pipelineToBSON(); cs.err != nil {
223			return cs.Err()
224		}
225		cs.aggregate.Pipeline(plArr)
226	}
227
228	if original := cs.aggregate.Execute(ctx); original != nil {
229		retryableRead := cs.client.retryReads && cs.wireVersion != nil && cs.wireVersion.Max >= 6
230		if !retryableRead {
231			cs.err = replaceErrors(original)
232			return cs.err
233		}
234
235		cs.err = original
236		switch tt := original.(type) {
237		case driver.Error:
238			if !tt.RetryableRead() {
239				break
240			}
241
242			server, err = cs.client.deployment.SelectServer(ctx, cs.selector)
243			if err != nil {
244				break
245			}
246
247			conn.Close()
248			conn, err = server.Connection(ctx)
249			if err != nil {
250				break
251			}
252			defer conn.Close()
253			cs.wireVersion = conn.Description().WireVersion
254
255			if cs.wireVersion == nil || cs.wireVersion.Max < 6 {
256				break
257			}
258
259			cs.aggregate.Deployment(cs.createOperationDeployment(server, conn))
260			cs.err = cs.aggregate.Execute(ctx)
261		}
262
263		if cs.err != nil {
264			cs.err = replaceErrors(cs.err)
265			return cs.Err()
266		}
267
268	}
269	cs.err = nil
270
271	cr := cs.aggregate.ResultCursorResponse()
272	cr.Server = server
273
274	cs.cursor, cs.err = driver.NewBatchCursor(cr, cs.sess, cs.client.clock, cs.cursorOptions)
275	if cs.err = replaceErrors(cs.err); cs.err != nil {
276		return cs.Err()
277	}
278
279	cs.updatePbrtFromCommand()
280	if cs.options.StartAtOperationTime == nil && cs.options.ResumeAfter == nil &&
281		cs.options.StartAfter == nil && cs.wireVersion.Max >= 7 &&
282		cs.emptyBatch() && cs.resumeToken == nil {
283		cs.operationTime = cs.sess.OperationTime
284	}
285
286	return cs.Err()
287}
288
289// Updates the post batch resume token after a successful aggregate or getMore operation.
290func (cs *ChangeStream) updatePbrtFromCommand() {
291	// Only cache the pbrt if an empty batch was returned and a pbrt was included
292	if pbrt := cs.cursor.PostBatchResumeToken(); cs.emptyBatch() && pbrt != nil {
293		cs.resumeToken = bson.Raw(pbrt)
294	}
295}
296
297func (cs *ChangeStream) storeResumeToken() error {
298	// If cs.Current is the last document in the batch and a pbrt is included, cache the pbrt
299	// Otherwise, cache the _id of the document
300	var tokenDoc bson.Raw
301	if len(cs.batch) == 0 {
302		if pbrt := cs.cursor.PostBatchResumeToken(); pbrt != nil {
303			tokenDoc = bson.Raw(pbrt)
304		}
305	}
306
307	if tokenDoc == nil {
308		var ok bool
309		tokenDoc, ok = cs.Current.Lookup("_id").DocumentOK()
310		if !ok {
311			_ = cs.Close(context.Background())
312			return ErrMissingResumeToken
313		}
314	}
315
316	cs.resumeToken = tokenDoc
317	return nil
318}
319
320func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error {
321	val := reflect.ValueOf(pipeline)
322	if !val.IsValid() || !(val.Kind() == reflect.Slice) {
323		cs.err = errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid")
324		return cs.err
325	}
326
327	cs.pipelineSlice = make([]bsoncore.Document, 0, val.Len()+1)
328
329	csIdx, csDoc := bsoncore.AppendDocumentStart(nil)
330	csDocTemp := cs.createPipelineOptionsDoc()
331	if cs.err != nil {
332		return cs.err
333	}
334	csDoc = bsoncore.AppendDocumentElement(csDoc, "$changeStream", csDocTemp)
335	csDoc, cs.err = bsoncore.AppendDocumentEnd(csDoc, csIdx)
336	if cs.err != nil {
337		return cs.err
338	}
339	cs.pipelineSlice = append(cs.pipelineSlice, csDoc)
340
341	for i := 0; i < val.Len(); i++ {
342		var elem []byte
343		elem, cs.err = transformBsoncoreDocument(cs.registry, val.Index(i).Interface(), true, fmt.Sprintf("pipeline stage :%v", i))
344		if cs.err != nil {
345			return cs.err
346		}
347
348		cs.pipelineSlice = append(cs.pipelineSlice, elem)
349	}
350
351	return cs.err
352}
353
354func (cs *ChangeStream) createPipelineOptionsDoc() bsoncore.Document {
355	plDocIdx, plDoc := bsoncore.AppendDocumentStart(nil)
356
357	if cs.streamType == ClientStream {
358		plDoc = bsoncore.AppendBooleanElement(plDoc, "allChangesForCluster", true)
359	}
360
361	if cs.options.FullDocument != nil {
362		plDoc = bsoncore.AppendStringElement(plDoc, "fullDocument", string(*cs.options.FullDocument))
363	}
364
365	if cs.options.ResumeAfter != nil {
366		var raDoc bsoncore.Document
367		raDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.ResumeAfter, true, "resumeAfter")
368		if cs.err != nil {
369			return nil
370		}
371
372		plDoc = bsoncore.AppendDocumentElement(plDoc, "resumeAfter", raDoc)
373	}
374
375	if cs.options.StartAfter != nil {
376		var saDoc bsoncore.Document
377		saDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.StartAfter, true, "startAfter")
378		if cs.err != nil {
379			return nil
380		}
381
382		plDoc = bsoncore.AppendDocumentElement(plDoc, "startAfter", saDoc)
383	}
384
385	if cs.options.StartAtOperationTime != nil {
386		plDoc = bsoncore.AppendTimestampElement(plDoc, "startAtOperationTime", cs.options.StartAtOperationTime.T, cs.options.StartAtOperationTime.I)
387	}
388
389	if plDoc, cs.err = bsoncore.AppendDocumentEnd(plDoc, plDocIdx); cs.err != nil {
390		return nil
391	}
392
393	return plDoc
394}
395
396func (cs *ChangeStream) pipelineToBSON() (bsoncore.Document, error) {
397	pipelineDocIdx, pipelineArr := bsoncore.AppendArrayStart(nil)
398	for i, doc := range cs.pipelineSlice {
399		pipelineArr = bsoncore.AppendDocumentElement(pipelineArr, strconv.Itoa(i), doc)
400	}
401	if pipelineArr, cs.err = bsoncore.AppendArrayEnd(pipelineArr, pipelineDocIdx); cs.err != nil {
402		return nil, cs.err
403	}
404	return pipelineArr, cs.err
405}
406
407func (cs *ChangeStream) replaceOptions(ctx context.Context, wireVersion *description.VersionRange) {
408	// Cached resume token: use the resume token as the resumeAfter option and set no other resume options
409	if cs.resumeToken != nil {
410		cs.options.SetResumeAfter(cs.resumeToken)
411		cs.options.SetStartAfter(nil)
412		cs.options.SetStartAtOperationTime(nil)
413		return
414	}
415
416	// No cached resume token but cached operation time: use the operation time as the startAtOperationTime option and
417	// set no other resume options
418	if (cs.sess.OperationTime != nil || cs.options.StartAtOperationTime != nil) && wireVersion.Max >= 7 {
419		opTime := cs.options.StartAtOperationTime
420		if cs.operationTime != nil {
421			opTime = cs.sess.OperationTime
422		}
423
424		cs.options.SetStartAtOperationTime(opTime)
425		cs.options.SetResumeAfter(nil)
426		cs.options.SetStartAfter(nil)
427		return
428	}
429
430	// No cached resume token or operation time: set none of the resume options
431	cs.options.SetResumeAfter(nil)
432	cs.options.SetStartAfter(nil)
433	cs.options.SetStartAtOperationTime(nil)
434}
435
436// ID returns the ID for this change stream, or 0 if the cursor has been closed or exhausted.
437func (cs *ChangeStream) ID() int64 {
438	if cs.cursor == nil {
439		return 0
440	}
441	return cs.cursor.ID()
442}
443
444// Decode will unmarshal the current event document into val and return any errors from the unmarshalling process
445// without any modification. If val is nil or is a typed nil, an error will be returned.
446func (cs *ChangeStream) Decode(val interface{}) error {
447	if cs.cursor == nil {
448		return ErrNilCursor
449	}
450
451	return bson.UnmarshalWithRegistry(cs.registry, cs.Current, val)
452}
453
454// Err returns the last error seen by the change stream, or nil if no errors has occurred.
455func (cs *ChangeStream) Err() error {
456	if cs.err != nil {
457		return replaceErrors(cs.err)
458	}
459	if cs.cursor == nil {
460		return nil
461	}
462
463	return replaceErrors(cs.cursor.Err())
464}
465
466// Close closes this change stream and the underlying cursor. Next and TryNext must not be called after Close has been
467// called. Close is idempotent. After the first call, any subsequent calls will not change the state.
468func (cs *ChangeStream) Close(ctx context.Context) error {
469	if ctx == nil {
470		ctx = context.Background()
471	}
472
473	defer closeImplicitSession(cs.sess)
474
475	if cs.cursor == nil {
476		return nil // cursor is already closed
477	}
478
479	cs.err = replaceErrors(cs.cursor.Close(ctx))
480	cs.cursor = nil
481	return cs.Err()
482}
483
484// ResumeToken returns the last cached resume token for this change stream, or nil if a resume token has not been
485// stored.
486func (cs *ChangeStream) ResumeToken() bson.Raw {
487	return cs.resumeToken
488}
489
490// Next gets the next event for this change stream. It returns true if there were no errors and the next event document
491// is available.
492//
493// Next blocks until an event is available, an error occurs, or ctx expires. If ctx expires, the error
494// will be set to ctx.Err(). In an error case, Next will return false.
495//
496// If Next returns false, subsequent calls will also return false.
497func (cs *ChangeStream) Next(ctx context.Context) bool {
498	return cs.next(ctx, false)
499}
500
501// TryNext attempts to get the next event for this change stream. It returns true if there were no errors and the next
502// event document is available.
503//
504// TryNext returns false if the change stream is closed by the server, an error occurs when getting changes from the
505// server, the next change is not yet available, or ctx expires. If ctx expires, the error will be set to ctx.Err().
506//
507// If TryNext returns false and an error occurred or the change stream was closed
508// (i.e. cs.Err() != nil || cs.ID() == 0), subsequent attempts will also return false. Otherwise, it is safe to call
509// TryNext again until a change is available.
510//
511// This method requires driver version >= 1.2.0.
512func (cs *ChangeStream) TryNext(ctx context.Context) bool {
513	return cs.next(ctx, true)
514}
515
516func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool {
517	// return false right away if the change stream has already errored or if cursor is closed.
518	if cs.err != nil {
519		return false
520	}
521
522	if ctx == nil {
523		ctx = context.Background()
524	}
525
526	if len(cs.batch) == 0 {
527		cs.loopNext(ctx, nonBlocking)
528		if cs.err != nil {
529			cs.err = replaceErrors(cs.err)
530			return false
531		}
532		if len(cs.batch) == 0 {
533			return false
534		}
535	}
536
537	// successfully got non-empty batch
538	cs.Current = bson.Raw(cs.batch[0])
539	cs.batch = cs.batch[1:]
540	if cs.err = cs.storeResumeToken(); cs.err != nil {
541		return false
542	}
543	return true
544}
545
546func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) {
547	for {
548		if cs.cursor == nil {
549			return
550		}
551
552		if cs.cursor.Next(ctx) {
553			// non-empty batch returned
554			cs.batch, cs.err = cs.cursor.Batch().Documents()
555			return
556		}
557
558		cs.err = replaceErrors(cs.cursor.Err())
559		if cs.err == nil {
560			// Check if cursor is alive
561			if cs.ID() == 0 {
562				return
563			}
564
565			// If a getMore was done but the batch was empty, the batch cursor will return false with no error.
566			// Update the tracked resume token to catch the post batch resume token from the server response.
567			cs.updatePbrtFromCommand()
568			if nonBlocking {
569				// stop after a successful getMore, even though the batch was empty
570				return
571			}
572			continue // loop getMore until a non-empty batch is returned or an error occurs
573		}
574
575		if !cs.isResumableError() {
576			return
577		}
578
579		// ignore error from cursor close because if the cursor is deleted or errors we tried to close it and will remake and try to get next batch
580		_ = cs.cursor.Close(ctx)
581		if cs.err = cs.executeOperation(ctx, true); cs.err != nil {
582			return
583		}
584	}
585}
586
587func (cs *ChangeStream) isResumableError() bool {
588	commandErr, ok := cs.err.(CommandError)
589	if !ok || commandErr.HasErrorLabel(networkErrorLabel) {
590		// All non-server errors or network errors are resumable.
591		return true
592	}
593
594	if commandErr.Code == errorCursorNotFound {
595		return true
596	}
597
598	// For wire versions 9 and above, a server error is resumable if it has the ResumableChangeStreamError label.
599	if cs.wireVersion != nil && cs.wireVersion.Includes(minResumableLabelWireVersion) {
600		return commandErr.HasErrorLabel(resumableErrorLabel)
601	}
602
603	// For wire versions below 9, a server error is resumable if its code is on the whitelist.
604	_, resumable := resumableChangeStreamErrors[commandErr.Code]
605	return resumable
606}
607
608// Returns true if the underlying cursor's batch is empty
609func (cs *ChangeStream) emptyBatch() bool {
610	return cs.cursor.Batch().Empty()
611}
612
// StreamType represents the cluster type against which a ChangeStream was created.
type StreamType uint8

// These constants represent valid change stream types. A change stream can be initialized over a collection, all
// collections in a database, or over a cluster.
const (
	// CollectionStream is a change stream over a single collection.
	CollectionStream = StreamType(iota)
	// DatabaseStream is a change stream over all collections in a database.
	DatabaseStream
	// ClientStream is a change stream over a cluster.
	ClientStream
)
623