1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package firestore
16
17import (
18	"context"
19	"errors"
20	"fmt"
21	"io"
22	"log"
23	"sort"
24	"time"
25
26	"cloud.google.com/go/internal/btree"
27	"github.com/golang/protobuf/ptypes"
28	gax "github.com/googleapis/gax-go/v2"
29	pb "google.golang.org/genproto/googleapis/firestore/v1"
30	"google.golang.org/grpc/codes"
31	"google.golang.org/grpc/status"
32)
33
34// LogWatchStreams controls whether watch stream status changes are logged.
35// This feature is EXPERIMENTAL and may disappear at any time.
36var LogWatchStreams = false
37
// DocumentChangeKind describes the kind of change to a document between
// query snapshots. It is reported in the Kind field of a DocumentChange.
type DocumentChangeKind int

const (
	// DocumentAdded indicates that the document was added for the first time.
	// It is the zero value of DocumentChangeKind.
	DocumentAdded DocumentChangeKind = iota
	// DocumentRemoved indicates that the document was removed.
	DocumentRemoved
	// DocumentModified indicates that the document was modified.
	DocumentModified
)
50
// A DocumentChange describes the change to a document from one query snapshot to the next.
type DocumentChange struct {
	// Kind says whether the document was added, removed or modified.
	// Doc is the affected document: for DocumentRemoved it is the snapshot that
	// was present in the previous query snapshot; for DocumentAdded and
	// DocumentModified it is the newly received snapshot.
	Kind DocumentChangeKind
	Doc  *DocumentSnapshot
	// The zero-based index of the document in the sequence of query results prior to this change,
	// or -1 if the document was not present.
	OldIndex int
	// The zero-based index of the document in the sequence of query results after this change,
	// or -1 if the document is no longer present.
	NewIndex int
}
62
63// Implementation of realtime updates (a.k.a. watch).
64// This code is closely based on the Node.js implementation,
65// https://github.com/googleapis/nodejs-firestore/blob/master/src/watch.js.
66
// The sole target ID for all streams from this client. It is sent as the
// TargetId of every watch target and is the ID looked for in server responses.
// The value is the sum of the runes 'g' and 'o', i.e. 214.
// Variable for testing.
var watchTargetID int32 = 'g' + 'o'
70
// defaultBackoff is the retry backoff that every new watchStream starts with.
// A stream's backoff is also reset to this value when a target change carrying
// a resume token indicates that the stream is healthy again.
var defaultBackoff = gax.Backoff{
	// Values from https://github.com/googleapis/nodejs-firestore/blob/master/src/backoff.js.
	Initial:    1 * time.Second,
	Max:        60 * time.Second,
	Multiplier: 1.5,
}
77
// watchStream holds the state of one realtime listener on a single target
// (a document or a query). It owns the gRPC Listen stream, reopens it after
// retryable errors, and accumulates server messages into snapshots.
//
// ctx is used when opening the stream and when sleeping between retries;
// c is the client the stream and its documents belong to.
//
// not goroutine-safe
type watchStream struct {
	ctx         context.Context
	c           *Client
	lc          pb.Firestore_ListenClient                 // the gRPC stream
	target      *pb.Target                                // document or query being watched
	backoff     gax.Backoff                               // for stream retries
	err         error                                     // sticky permanent error
	readTime    time.Time                                 // time of most recent snapshot
	current     bool                                      // saw CURRENT, but not RESET; precondition for a snapshot
	hasReturned bool                                      // have we returned a snapshot yet?
	compare     func(a, b *DocumentSnapshot) (int, error) // compare documents according to query

	// An ordered tree where DocumentSnapshots are the keys.
	// The ordering is the one defined by compare (see less).
	docTree *btree.BTree
	// Map of document name to DocumentSnapshot for the last returned snapshot.
	docMap map[string]*DocumentSnapshot
	// Map of document name to DocumentSnapshot for accumulated changes for the current snapshot.
	// A nil value means the document was removed.
	changeMap map[string]*DocumentSnapshot
}
99
100func newWatchStreamForDocument(ctx context.Context, dr *DocumentRef) *watchStream {
101	// A single document is always equal to itself.
102	compare := func(_, _ *DocumentSnapshot) (int, error) { return 0, nil }
103	return newWatchStream(ctx, dr.Parent.c, compare, &pb.Target{
104		TargetType: &pb.Target_Documents{
105			Documents: &pb.Target_DocumentsTarget{Documents: []string{dr.Path}},
106		},
107		TargetId: watchTargetID,
108	})
109}
110
111func newWatchStreamForQuery(ctx context.Context, q Query) (*watchStream, error) {
112	qp, err := q.toProto()
113	if err != nil {
114		return nil, err
115	}
116	target := &pb.Target{
117		TargetType: &pb.Target_Query{
118			Query: &pb.Target_QueryTarget{
119				Parent:    q.parentPath,
120				QueryType: &pb.Target_QueryTarget_StructuredQuery{qp},
121			},
122		},
123		TargetId: watchTargetID,
124	}
125	return newWatchStream(ctx, q.c, q.compareFunc(), target), nil
126}
127
// btreeDegree is the degree passed to btree.New when a watchStream creates
// its document tree.
const btreeDegree = 4
129
130func newWatchStream(ctx context.Context, c *Client, compare func(_, _ *DocumentSnapshot) (int, error), target *pb.Target) *watchStream {
131	w := &watchStream{
132		ctx:       ctx,
133		c:         c,
134		compare:   compare,
135		target:    target,
136		backoff:   defaultBackoff,
137		docMap:    map[string]*DocumentSnapshot{},
138		changeMap: map[string]*DocumentSnapshot{},
139	}
140	w.docTree = btree.New(btreeDegree, func(a, b interface{}) bool {
141		return w.less(a.(*DocumentSnapshot), b.(*DocumentSnapshot))
142	})
143	return w
144}
145
146func (s *watchStream) less(a, b *DocumentSnapshot) bool {
147	c, err := s.compare(a, b)
148	if err != nil {
149		s.err = err
150		return false
151	}
152	return c < 0
153}
154
155// Once nextSnapshot returns an error, it will always return the same error.
156func (s *watchStream) nextSnapshot() (*btree.BTree, []DocumentChange, time.Time, error) {
157	if s.err != nil {
158		return nil, nil, time.Time{}, s.err
159	}
160	var changes []DocumentChange
161	for {
162		// Process messages until we are in a consistent state.
163		for !s.handleNextMessage() {
164		}
165		if s.err != nil {
166			_ = s.close() // ignore error
167			return nil, nil, time.Time{}, s.err
168		}
169		var newDocTree *btree.BTree
170		newDocTree, changes = s.computeSnapshot(s.docTree, s.docMap, s.changeMap, s.readTime)
171		if s.err != nil {
172			return nil, nil, time.Time{}, s.err
173		}
174		// Only return a snapshot if something has changed, or this is the first snapshot.
175		if !s.hasReturned || newDocTree != s.docTree {
176			s.docTree = newDocTree
177			break
178		}
179	}
180	s.changeMap = map[string]*DocumentSnapshot{}
181	s.hasReturned = true
182	return s.docTree, changes, s.readTime, nil
183}
184
185// Read a message from the stream and handle it. Return true when
186// we're in a consistent state, or there is a permanent error.
187func (s *watchStream) handleNextMessage() bool {
188	res, err := s.recv()
189	if err != nil {
190		s.err = err
191		// Errors returned by recv are permanent.
192		return true
193	}
194	switch r := res.ResponseType.(type) {
195	case *pb.ListenResponse_TargetChange:
196		return s.handleTargetChange(r.TargetChange)
197
198	case *pb.ListenResponse_DocumentChange:
199		name := r.DocumentChange.Document.Name
200		s.logf("DocumentChange %q", name)
201		if hasWatchTargetID(r.DocumentChange.TargetIds) { // document changed
202			ref, err := pathToDoc(name, s.c)
203			if err == nil {
204				s.changeMap[name], err = newDocumentSnapshot(ref, r.DocumentChange.Document, s.c, nil)
205			}
206			if err != nil {
207				s.err = err
208				return true
209			}
210		} else if hasWatchTargetID(r.DocumentChange.RemovedTargetIds) { // document removed
211			s.changeMap[name] = nil
212		}
213
214	case *pb.ListenResponse_DocumentDelete:
215		s.logf("Delete %q", r.DocumentDelete.Document)
216		s.changeMap[r.DocumentDelete.Document] = nil
217
218	case *pb.ListenResponse_DocumentRemove:
219		s.logf("Remove %q", r.DocumentRemove.Document)
220		s.changeMap[r.DocumentRemove.Document] = nil
221
222	case *pb.ListenResponse_Filter:
223		s.logf("Filter %d", r.Filter.Count)
224		if int(r.Filter.Count) != s.currentSize() {
225			s.resetDocs() // Remove all the current results.
226			// The filter didn't match; close the stream so it will be re-opened on the next
227			// call to nextSnapshot.
228			_ = s.close() // ignore error
229			s.lc = nil
230		}
231
232	default:
233		s.err = fmt.Errorf("unknown response type %T", r)
234		return true
235	}
236	return false
237}
238
239// Return true if in a consistent state, or there is a permanent error.
240func (s *watchStream) handleTargetChange(tc *pb.TargetChange) bool {
241	switch tc.TargetChangeType {
242	case pb.TargetChange_NO_CHANGE:
243		s.logf("TargetNoChange %d %v", len(tc.TargetIds), tc.ReadTime)
244		if len(tc.TargetIds) == 0 && tc.ReadTime != nil && s.current {
245			// Everything is up-to-date, so we are ready to return a snapshot.
246			rt, err := ptypes.Timestamp(tc.ReadTime)
247			if err != nil {
248				s.err = err
249				return true
250			}
251			s.readTime = rt
252			s.target.ResumeType = &pb.Target_ResumeToken{tc.ResumeToken}
253			return true
254		}
255
256	case pb.TargetChange_ADD:
257		s.logf("TargetAdd")
258		if tc.TargetIds[0] != watchTargetID {
259			s.err = errors.New("unexpected target ID sent by server")
260			return true
261		}
262
263	case pb.TargetChange_REMOVE:
264		s.logf("TargetRemove")
265		// We should never see a remove.
266		if tc.Cause != nil {
267			s.err = status.Error(codes.Code(tc.Cause.Code), tc.Cause.Message)
268		} else {
269			s.err = status.Error(codes.Internal, "firestore: client saw REMOVE")
270		}
271		return true
272
273	// The targets reflect all changes committed before the targets were added
274	// to the stream.
275	case pb.TargetChange_CURRENT:
276		s.logf("TargetCurrent")
277		s.current = true
278
279	// The targets have been reset, and a new initial state for the targets will be
280	// returned in subsequent changes. Whatever changes have happened so far no
281	// longer matter.
282	case pb.TargetChange_RESET:
283		s.logf("TargetReset")
284		s.resetDocs()
285
286	default:
287		s.err = fmt.Errorf("firestore: unknown TargetChange type %s", tc.TargetChangeType)
288		return true
289	}
290	// If we see a resume token and our watch ID is affected, we assume the stream
291	// is now healthy, so we reset our backoff time to the minimum.
292	if tc.ResumeToken != nil && (len(tc.TargetIds) == 0 || hasWatchTargetID(tc.TargetIds)) {
293		s.backoff = defaultBackoff
294	}
295	return false // not in a consistent state, keep receiving
296}
297
298func (s *watchStream) resetDocs() {
299	s.target.ResumeType = nil // clear resume token
300	s.current = false
301	s.changeMap = map[string]*DocumentSnapshot{}
302	// Mark each document as deleted. If documents are not deleted, they
303	// will be send again by the server.
304	it := s.docTree.BeforeIndex(0)
305	for it.Next() {
306		s.changeMap[it.Key.(*DocumentSnapshot).Ref.Path] = nil
307	}
308}
309
310func (s *watchStream) currentSize() int {
311	_, adds, deletes := extractChanges(s.docMap, s.changeMap)
312	return len(s.docMap) + len(adds) - len(deletes)
313}
314
315// Return the changes that have occurred since the last snapshot.
316func extractChanges(docMap, changeMap map[string]*DocumentSnapshot) (updates, adds []*DocumentSnapshot, deletes []string) {
317	for name, doc := range changeMap {
318		switch {
319		case doc == nil:
320			if _, ok := docMap[name]; ok {
321				deletes = append(deletes, name)
322			}
323		case docMap[name] != nil:
324			updates = append(updates, doc)
325		default:
326			adds = append(adds, doc)
327		}
328	}
329	return updates, adds, deletes
330}
331
// assert panics with "assertion failed" when b is false.
// For development only.
// TODO(jba): remove.
func assert(b bool) {
	if b {
		return
	}
	panic("assertion failed")
}
339
// Applies the mutations in changeMap to both the document tree and the
// document lookup map. Modifies docMap in place and returns a new docTree.
// If there were no changes, returns docTree unmodified.
//
// The returned changes list removals first, then additions, then
// modifications, each group sorted by the stream's ordering. Added and
// modified documents get their ReadTime set to readTime. An update whose
// UpdateTime equals that of the stored document is ignored. docTree itself is
// never mutated: the first effective change works on a clone, so callers can
// detect "nothing changed" by comparing the returned tree with docTree.
// Comparison errors raised while sorting or during tree operations are
// recorded in s.err by s.less; callers must check s.err afterwards.
func (s *watchStream) computeSnapshot(docTree *btree.BTree, docMap, changeMap map[string]*DocumentSnapshot, readTime time.Time) (*btree.BTree, []DocumentChange) {
	var changes []DocumentChange
	updatedTree := docTree
	assert(docTree.Len() == len(docMap))
	updates, adds, deletes := extractChanges(docMap, changeMap)
	// Adds and deletes always change the tree, so clone up front. Updates may
	// all turn out to be no-ops, so for them the clone is made lazily below.
	if len(adds) > 0 || len(deletes) > 0 {
		updatedTree = docTree.Clone()
	}
	// Process the sorted changes in the order that is expected by our clients
	// (removals, additions, and then modifications). We also need to sort the
	// individual changes to assure that oldIndex/newIndex keep incrementing.
	// Deletes arrive as names; resolve them to the stored snapshots so they
	// can be sorted and looked up in the tree.
	deldocs := make([]*DocumentSnapshot, len(deletes))
	for i, d := range deletes {
		deldocs[i] = docMap[d]
	}
	sort.Sort(byLess{deldocs, s.less})
	for _, oldDoc := range deldocs {
		assert(oldDoc != nil)
		delete(docMap, oldDoc.Ref.Path)
		// The index must be read before the document leaves the tree.
		_, oldi := updatedTree.GetWithIndex(oldDoc)
		// TODO(jba): have btree.Delete return old index
		_, found := updatedTree.Delete(oldDoc)
		assert(found)
		changes = append(changes, DocumentChange{
			Kind:     DocumentRemoved,
			Doc:      oldDoc,
			OldIndex: oldi,
			NewIndex: -1,
		})
	}
	sort.Sort(byLess{adds, s.less})
	for _, newDoc := range adds {
		name := newDoc.Ref.Path
		assert(docMap[name] == nil)
		newDoc.ReadTime = readTime
		docMap[name] = newDoc
		updatedTree.Set(newDoc, nil)
		// TODO(jba): change btree so Set returns index as second value.
		_, newi := updatedTree.GetWithIndex(newDoc)
		changes = append(changes, DocumentChange{
			Kind:     DocumentAdded,
			Doc:      newDoc,
			OldIndex: -1,
			NewIndex: newi,
		})
	}
	sort.Sort(byLess{updates, s.less})
	for _, newDoc := range updates {
		name := newDoc.Ref.Path
		oldDoc := docMap[name]
		assert(oldDoc != nil)
		// Same update time means the document did not actually change.
		if newDoc.UpdateTime.Equal(oldDoc.UpdateTime) {
			continue
		}
		// First real modification and no adds/deletes: clone now so the
		// caller's tree stays untouched.
		if updatedTree == docTree {
			updatedTree = docTree.Clone()
		}
		newDoc.ReadTime = readTime
		docMap[name] = newDoc
		// Replace the old snapshot with the new one, capturing the position
		// before removal and after insertion.
		_, oldi := updatedTree.GetWithIndex(oldDoc)
		updatedTree.Delete(oldDoc)
		updatedTree.Set(newDoc, nil)
		_, newi := updatedTree.GetWithIndex(newDoc)
		changes = append(changes, DocumentChange{
			Kind:     DocumentModified,
			Doc:      newDoc,
			OldIndex: oldi,
			NewIndex: newi,
		})
	}
	assert(updatedTree.Len() == len(docMap))
	return updatedTree, changes
}
416
417type byLess struct {
418	s    []*DocumentSnapshot
419	less func(a, b *DocumentSnapshot) bool
420}
421
422func (b byLess) Len() int           { return len(b.s) }
423func (b byLess) Swap(i, j int)      { b.s[i], b.s[j] = b.s[j], b.s[i] }
424func (b byLess) Less(i, j int) bool { return b.less(b.s[i], b.s[j]) }
425
426func hasWatchTargetID(ids []int32) bool {
427	for _, id := range ids {
428		if id == watchTargetID {
429			return true
430		}
431	}
432	return false
433}
434
435func (s *watchStream) logf(format string, args ...interface{}) {
436	if LogWatchStreams {
437		log.Printf(format, args...)
438	}
439}
440
441// Close the stream. From this point on, calls to nextSnapshot will return
442// io.EOF, or the error from CloseSend.
443func (s *watchStream) stop() {
444	err := s.close()
445	if s.err != nil { // don't change existing error
446		return
447	}
448	if err != nil {
449		s.err = err
450	}
451	s.err = io.EOF // normal shutdown
452}
453
454func (s *watchStream) close() error {
455	if s.lc == nil {
456		return nil
457	}
458	return s.lc.CloseSend()
459}
460
461// recv receives the next message from the stream. It also handles opening the stream
462// initially, and reopening it on non-permanent errors.
463// recv doesn't have to be goroutine-safe.
464func (s *watchStream) recv() (*pb.ListenResponse, error) {
465	var err error
466	for {
467		if s.lc == nil {
468			s.lc, err = s.open()
469			if err != nil {
470				// Do not retry if open fails.
471				return nil, err
472			}
473		}
474		res, err := s.lc.Recv()
475		if err == nil || isPermanentWatchError(err) {
476			return res, err
477		}
478		// Non-permanent error. Sleep and retry.
479		s.changeMap = map[string]*DocumentSnapshot{} // clear changeMap
480		dur := s.backoff.Pause()
481		// If we're out of quota, wait a long time before retrying.
482		if status.Code(err) == codes.ResourceExhausted {
483			dur = s.backoff.Max
484		}
485		if err := sleep(s.ctx, dur); err != nil {
486			return nil, err
487		}
488		s.lc = nil
489	}
490}
491
492func (s *watchStream) open() (pb.Firestore_ListenClient, error) {
493	dbPath := s.c.path()
494	lc, err := s.c.c.Listen(withResourceHeader(s.ctx, dbPath))
495	if err == nil {
496		err = lc.Send(&pb.ListenRequest{
497			Database:     dbPath,
498			TargetChange: &pb.ListenRequest_AddTarget{AddTarget: s.target},
499		})
500	}
501	if err != nil {
502		return nil, err
503	}
504	return lc, nil
505}
506
507func isPermanentWatchError(err error) bool {
508	if err == io.EOF {
509		// Retry on normal end-of-stream.
510		return false
511	}
512	switch status.Code(err) {
513	case codes.Unknown, codes.DeadlineExceeded, codes.ResourceExhausted,
514		codes.Internal, codes.Unavailable, codes.Unauthenticated:
515		return false
516	default:
517		return true
518	}
519}
520