1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package shared
16
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"

	eduAPI "github.com/matrix-org/dendrite/eduserver/api"
	"github.com/matrix-org/dendrite/internal/eventutil"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/roomserver/api"
	"github.com/matrix-org/dendrite/syncapi/storage/tables"
	"github.com/matrix-org/dendrite/syncapi/types"
	userapi "github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/sirupsen/logrus"
)
34
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
// For now this contains the shared functions.
//
// DB is the handle used to begin transactions (both the read-only snapshots
// and the write transactions passed to Writer.Do). Writer is used to run
// every write in this file. The remaining fields are the per-table
// accessors that the methods on Database delegate to.
type Database struct {
	DB                  *sql.DB
	Writer              sqlutil.Writer
	Invites             tables.Invites
	Peeks               tables.Peeks
	AccountData         tables.AccountData
	OutputEvents        tables.Events
	Topology            tables.Topology
	CurrentRoomState    tables.CurrentRoomState
	BackwardExtremities tables.BackwardsExtremities
	SendToDevice        tables.SendToDevice
	Filter              tables.Filter
	Receipts            tables.Receipts
	Memberships         tables.Memberships
}
52
53func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) {
54	return d.DB.BeginTx(ctx, &sql.TxOptions{
55		// Set the isolation level so that we see a snapshot of the database.
56		// In PostgreSQL repeatable read transactions will see a snapshot taken
57		// at the first query, and since the transaction is read-only it can't
58		// run into any serialisation errors.
59		// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
60		Isolation: sql.LevelRepeatableRead,
61		ReadOnly:  true,
62	})
63}
64
65func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
66	id, err := d.OutputEvents.SelectMaxEventID(ctx, nil)
67	if err != nil {
68		return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
69	}
70	return types.StreamPosition(id), nil
71}
72
73func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
74	id, err := d.Receipts.SelectMaxReceiptID(ctx, nil)
75	if err != nil {
76		return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
77	}
78	return types.StreamPosition(id), nil
79}
80
81func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
82	id, err := d.Invites.SelectMaxInviteID(ctx, nil)
83	if err != nil {
84		return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
85	}
86	return types.StreamPosition(id), nil
87}
88
89func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
90	id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
91	if err != nil {
92		return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
93	}
94	return types.StreamPosition(id), nil
95}
96
97func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
98	id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
99	if err != nil {
100		return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
101	}
102	return types.StreamPosition(id), nil
103}
104
105func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
106	return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs)
107}
108
109func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
110	return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
111}
112
113func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
114	return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
115}
116
117func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
118	return d.Topology.SelectPositionInTopology(ctx, nil, eventID)
119}
120
121func (d *Database) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
122	return d.Invites.SelectInviteEventsInRange(ctx, nil, targetUserID, r)
123}
124
125func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
126	return d.Peeks.SelectPeeksInRange(ctx, nil, userID, deviceID, r)
127}
128
129func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error) {
130	return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
131}
132
133// Events lookups a list of event by their event ID.
134// Returns a list of events matching the requested IDs found in the database.
135// If an event is not found in the database then it will be omitted from the list.
136// Returns an error if there was a problem talking with the database.
137// Does not include any transaction IDs in the returned events.
138func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
139	streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
140	if err != nil {
141		return nil, err
142	}
143
144	// We don't include a device here as we only include transaction IDs in
145	// incremental syncs.
146	return d.StreamEventsToEvents(nil, streamEvents), nil
147}
148
149// GetEventsInStreamingRange retrieves all of the events on a given ordering using the
150// given extremities and limit.
151func (d *Database) GetEventsInStreamingRange(
152	ctx context.Context,
153	from, to *types.StreamingToken,
154	roomID string, eventFilter *gomatrixserverlib.RoomEventFilter,
155	backwardOrdering bool,
156) (events []types.StreamEvent, err error) {
157	r := types.Range{
158		From:      from.PDUPosition,
159		To:        to.PDUPosition,
160		Backwards: backwardOrdering,
161	}
162	if backwardOrdering {
163		// When using backward ordering, we want the most recent events first.
164		if events, _, err = d.OutputEvents.SelectRecentEvents(
165			ctx, nil, roomID, r, eventFilter, false, false,
166		); err != nil {
167			return
168		}
169	} else {
170		// When using forward ordering, we want the least recent events first.
171		if events, err = d.OutputEvents.SelectEarlyEvents(
172			ctx, nil, roomID, r, eventFilter,
173		); err != nil {
174			return
175		}
176	}
177	return events, err
178}
179
180func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
181	return d.CurrentRoomState.SelectJoinedUsers(ctx)
182}
183
184func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
185	return d.Peeks.SelectPeekingDevices(ctx)
186}
187
188func (d *Database) GetStateEvent(
189	ctx context.Context, roomID, evType, stateKey string,
190) (*gomatrixserverlib.HeaderedEvent, error) {
191	return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey)
192}
193
194func (d *Database) GetStateEventsForRoom(
195	ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
196) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
197	stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil)
198	return
199}
200
201// AddInviteEvent stores a new invite event for a user.
202// If the invite was successfully stored this returns the stream ID it was stored at.
203// Returns an error if there was a problem communicating with the database.
204func (d *Database) AddInviteEvent(
205	ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent,
206) (sp types.StreamPosition, err error) {
207	_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
208		sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
209		return err
210	})
211	return
212}
213
214// RetireInviteEvent removes an old invite event from the database.
215// Returns an error if there was a problem communicating with the database.
216func (d *Database) RetireInviteEvent(
217	ctx context.Context, inviteEventID string,
218) (sp types.StreamPosition, err error) {
219	_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
220		sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID)
221		return err
222	})
223	return
224}
225
226// AddPeek tracks the fact that a user has started peeking.
227// If the peek was successfully stored this returns the stream ID it was stored at.
228// Returns an error if there was a problem communicating with the database.
229func (d *Database) AddPeek(
230	ctx context.Context, roomID, userID, deviceID string,
231) (sp types.StreamPosition, err error) {
232	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
233		sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
234		return err
235	})
236	return
237}
238
239// DeletePeeks tracks the fact that a user has stopped peeking from the specified
240// device. If the peeks was successfully deleted this returns the stream ID it was
241// stored at. Returns an error if there was a problem communicating with the database.
242func (d *Database) DeletePeek(
243	ctx context.Context, roomID, userID, deviceID string,
244) (sp types.StreamPosition, err error) {
245	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
246		sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID)
247		return err
248	})
249	if err == sql.ErrNoRows {
250		sp = 0
251		err = nil
252	}
253	return
254}
255
256// DeletePeeks tracks the fact that a user has stopped peeking from all devices
257// If the peeks was successfully deleted this returns the stream ID it was stored at.
258// Returns an error if there was a problem communicating with the database.
259func (d *Database) DeletePeeks(
260	ctx context.Context, roomID, userID string,
261) (sp types.StreamPosition, err error) {
262	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
263		sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
264		return err
265	})
266	if err == sql.ErrNoRows {
267		sp = 0
268		err = nil
269	}
270	return
271}
272
273// GetAccountDataInRange returns all account data for a given user inserted or
274// updated between two given positions
275// Returns a map following the format data[roomID] = []dataTypes
276// If no data is retrieved, returns an empty map
277// If there was an issue with the retrieval, returns an error
278func (d *Database) GetAccountDataInRange(
279	ctx context.Context, userID string, r types.Range,
280	accountDataFilterPart *gomatrixserverlib.EventFilter,
281) (map[string][]string, error) {
282	return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart)
283}
284
285// UpsertAccountData keeps track of new or updated account data, by saving the type
286// of the new/updated data, and the user ID and room ID the data is related to (empty)
287// room ID means the data isn't specific to any room)
288// If no data with the given type, user ID and room ID exists in the database,
289// creates a new row, else update the existing one
290// Returns an error if there was an issue with the upsert
291func (d *Database) UpsertAccountData(
292	ctx context.Context, userID, roomID, dataType string,
293) (sp types.StreamPosition, err error) {
294	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
295		sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
296		return err
297	})
298	return
299}
300
301func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent {
302	out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
303	for i := 0; i < len(in); i++ {
304		out[i] = in[i].HeaderedEvent
305		if device != nil && in[i].TransactionID != nil {
306			if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
307				err := out[i].SetUnsignedField(
308					"transaction_id", in[i].TransactionID.TransactionID,
309				)
310				if err != nil {
311					logrus.WithFields(logrus.Fields{
312						"event_id": out[i].EventID(),
313					}).WithError(err).Warnf("Failed to add transaction ID to event")
314				}
315			}
316		}
317	}
318	return out
319}
320
321// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
322// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
323// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
324// This function should always be called within a sqlutil.Writer for safety in SQLite.
325func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
326	if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
327		return err
328	}
329
330	// Check if we have all of the event's previous events. If an event is
331	// missing, add it to the room's backward extremities.
332	prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
333	if err != nil {
334		return err
335	}
336	var found bool
337	for _, eID := range ev.PrevEventIDs() {
338		found = false
339		for _, prevEv := range prevEvents {
340			if eID == prevEv.EventID() {
341				found = true
342			}
343		}
344
345		// If the event is missing, consider it a backward extremity.
346		if !found {
347			if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil {
348				return err
349			}
350		}
351	}
352
353	return nil
354}
355
356func (d *Database) PurgeRoomState(
357	ctx context.Context, roomID string,
358) error {
359	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
360		// If the event is a create event then we'll delete all of the existing
361		// data for the room. The only reason that a create event would be replayed
362		// to us in this way is if we're about to receive the entire room state.
363		if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
364			return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
365		}
366		return nil
367	})
368}
369
370func (d *Database) WriteEvent(
371	ctx context.Context,
372	ev *gomatrixserverlib.HeaderedEvent,
373	addStateEvents []*gomatrixserverlib.HeaderedEvent,
374	addStateEventIDs, removeStateEventIDs []string,
375	transactionID *api.TransactionID, excludeFromSync bool,
376) (pduPosition types.StreamPosition, returnErr error) {
377	returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
378		var err error
379		pos, err := d.OutputEvents.InsertEvent(
380			ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
381		)
382		if err != nil {
383			return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
384		}
385		pduPosition = pos
386		var topoPosition types.StreamPosition
387		if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
388			return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
389		}
390
391		if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
392			return fmt.Errorf("d.handleBackwardExtremities: %w", err)
393		}
394
395		if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
396			// Nothing to do, the event may have just been a message event.
397			return nil
398		}
399
400		return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
401	})
402
403	return pduPosition, returnErr
404}
405
406// This function should always be called within a sqlutil.Writer for safety in SQLite.
407func (d *Database) updateRoomState(
408	ctx context.Context, txn *sql.Tx,
409	removedEventIDs []string,
410	addedEvents []*gomatrixserverlib.HeaderedEvent,
411	pduPosition types.StreamPosition,
412	topoPosition types.StreamPosition,
413) error {
414	// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
415	for _, eventID := range removedEventIDs {
416		if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
417			return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
418		}
419	}
420
421	for _, event := range addedEvents {
422		if event.StateKey() == nil {
423			// ignore non state events
424			continue
425		}
426		var membership *string
427		if event.Type() == "m.room.member" {
428			value, err := event.Membership()
429			if err != nil {
430				return fmt.Errorf("event.Membership: %w", err)
431			}
432			membership = &value
433			if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil {
434				return fmt.Errorf("d.Memberships.UpsertMembership: %w", err)
435			}
436		}
437
438		if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
439			return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
440		}
441	}
442
443	return nil
444}
445
446func (d *Database) GetEventsInTopologicalRange(
447	ctx context.Context,
448	from, to *types.TopologyToken,
449	roomID string, limit int,
450	backwardOrdering bool,
451) (events []types.StreamEvent, err error) {
452	var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
453	if backwardOrdering {
454		// Backward ordering means the 'from' token has a higher depth than the 'to' token
455		minDepth = to.Depth
456		maxDepth = from.Depth
457		// for cases where we have say 5 events with the same depth, the TopologyToken needs to
458		// know which of the 5 the client has seen. This is done by using the PDU position.
459		// Events with the same maxDepth but less than this PDU position will be returned.
460		maxStreamPosForMaxDepth = from.PDUPosition
461	} else {
462		// Forward ordering means the 'from' token has a lower depth than the 'to' token.
463		minDepth = from.Depth
464		maxDepth = to.Depth
465	}
466
467	// Select the event IDs from the defined range.
468	var eIDs []string
469	eIDs, err = d.Topology.SelectEventIDsInRange(
470		ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
471	)
472	if err != nil {
473		return
474	}
475
476	// Retrieve the events' contents using their IDs.
477	events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
478	return
479}
480
481func (d *Database) BackwardExtremitiesForRoom(
482	ctx context.Context, roomID string,
483) (backwardExtremities map[string][]string, err error) {
484	return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID)
485}
486
487func (d *Database) MaxTopologicalPosition(
488	ctx context.Context, roomID string,
489) (types.TopologyToken, error) {
490	depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
491	if err != nil {
492		return types.TopologyToken{}, err
493	}
494	return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
495}
496
497func (d *Database) EventPositionInTopology(
498	ctx context.Context, eventID string,
499) (types.TopologyToken, error) {
500	depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID)
501	if err != nil {
502		return types.TopologyToken{}, err
503	}
504	return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
505}
506
507func (d *Database) GetFilter(
508	ctx context.Context, localpart string, filterID string,
509) (*gomatrixserverlib.Filter, error) {
510	return d.Filter.SelectFilter(ctx, localpart, filterID)
511}
512
513func (d *Database) PutFilter(
514	ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
515) (string, error) {
516	var filterID string
517	var err error
518	err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
519		filterID, err = d.Filter.InsertFilter(ctx, filter, localpart)
520		return err
521	})
522	return filterID, err
523}
524
525func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error {
526	redactedEvents, err := d.Events(ctx, []string{redactedEventID})
527	if err != nil {
528		return err
529	}
530	if len(redactedEvents) == 0 {
531		logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
532		return nil
533	}
534	eventToRedact := redactedEvents[0].Unwrap()
535	redactionEvent := redactedBecause.Unwrap()
536	ev, err := eventutil.RedactEvent(redactionEvent, eventToRedact)
537	if err != nil {
538		return err
539	}
540
541	newEvent := ev.Headered(redactedBecause.RoomVersion)
542	err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
543		return d.OutputEvents.UpdateEventJSON(ctx, newEvent)
544	})
545	return err
546}
547
548// Retrieve the backward topology position, i.e. the position of the
549// oldest event in the room's topology.
550func (d *Database) GetBackwardTopologyPos(
551	ctx context.Context,
552	events []types.StreamEvent,
553) (types.TopologyToken, error) {
554	zeroToken := types.TopologyToken{}
555	if len(events) == 0 {
556		return zeroToken, nil
557	}
558	pos, spos, err := d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID())
559	if err != nil {
560		return zeroToken, err
561	}
562	tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
563	tok.Decrement()
564	return tok, nil
565}
566
567// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
568// Returns a map of room ID to list of events.
569func (d *Database) fetchStateEvents(
570	ctx context.Context, txn *sql.Tx,
571	roomIDToEventIDSet map[string]map[string]bool,
572	eventIDToEvent map[string]types.StreamEvent,
573) (map[string][]types.StreamEvent, error) {
574	stateBetween := make(map[string][]types.StreamEvent)
575	missingEvents := make(map[string][]string)
576	for roomID, ids := range roomIDToEventIDSet {
577		events := stateBetween[roomID]
578		for id, need := range ids {
579			if !need {
580				continue // deleted state
581			}
582			e, ok := eventIDToEvent[id]
583			if ok {
584				events = append(events, e)
585			} else {
586				m := missingEvents[roomID]
587				m = append(m, id)
588				missingEvents[roomID] = m
589			}
590		}
591		stateBetween[roomID] = events
592	}
593
594	if len(missingEvents) > 0 {
595		// This happens when add_state_ids has an event ID which is not in the provided range.
596		// We need to explicitly fetch them.
597		allMissingEventIDs := []string{}
598		for _, missingEvIDs := range missingEvents {
599			allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
600		}
601		evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
602		if err != nil {
603			return nil, err
604		}
605		// we know we got them all otherwise an error would've been returned, so just loop the events
606		for _, ev := range evs {
607			roomID := ev.RoomID()
608			stateBetween[roomID] = append(stateBetween[roomID], ev)
609		}
610	}
611	return stateBetween, nil
612}
613
614func (d *Database) fetchMissingStateEvents(
615	ctx context.Context, txn *sql.Tx, eventIDs []string,
616) ([]types.StreamEvent, error) {
617	// Fetch from the events table first so we pick up the stream ID for the
618	// event.
619	events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
620	if err != nil {
621		return nil, err
622	}
623
624	have := map[string]bool{}
625	for _, event := range events {
626		have[event.EventID()] = true
627	}
628	var missing []string
629	for _, eventID := range eventIDs {
630		if !have[eventID] {
631			missing = append(missing, eventID)
632		}
633	}
634	if len(missing) == 0 {
635		return events, nil
636	}
637
638	// If they are missing from the events table then they should be state
639	// events that we received from outside the main event stream.
640	// These should be in the room state table.
641	stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)
642
643	if err != nil {
644		return nil, err
645	}
646	if len(stateEvents) != len(missing) {
647		logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))
648
649		// TODO: Why is this happening? It's probably the roomserver. Uncomment
650		// this error again when we work out what it is and fix it, otherwise we
651		// just end up returning lots of 500s to the client and that breaks
652		// pretty much everything, rather than just sending what we have.
653		//return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
654	}
655	events = append(events, stateEvents...)
656	return events, nil
657}
658
// GetStateDeltas returns state deltas for the range r for userID. The result
// contains:
//   - a Peek delta for each non-deleted peek by device;
//   - a delta for each room in which userID has a non-join membership event
//     in the range;
//   - a Join delta for every room the user is currently joined to.
//
// Rooms with a join in the range, and newly-started peeks, receive their full
// current state rather than just the state in the range.
// All reads happen inside one read-only snapshot transaction. device must be
// non-nil (device.ID is read).
// The list of currently joined room IDs is also returned in case the caller needs it.
func (d *Database) GetStateDeltas(
	ctx context.Context, device *userapi.Device,
	r types.Range, userID string,
	stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) {
	// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
	// - Get membership list changes for this user in this sync response
	// - For each room which has membership list changes:
	//     * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
	//       If it is, then we need to send the full room state down (and 'limited' is always true).
	//     * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
	//     * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
	// - Get all CURRENTLY joined rooms, and add them to 'joined' block.
	txn, err := d.readOnlySnapshot(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
	}
	var succeeded bool
	// NOTE(review): err is not a named result, so anything
	// EndTransactionWithCheck writes into it after return is not seen by the
	// caller — confirm that is acceptable for this read-only snapshot.
	defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)

	var deltas []types.StateDelta

	// get all the state events ever (i.e. for all available rooms) between these two positions
	stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
	if err != nil {
		return nil, nil, err
	}
	state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
	if err != nil {
		return nil, nil, err
	}

	// find out which rooms this user is peeking, if any.
	// We do this before joins so any peeks get overwritten
	peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
	if err != nil {
		return nil, nil, err
	}

	// add peek blocks
	for _, peek := range peeks {
		if peek.New {
			// send full room state down instead of a delta
			var s []types.StreamEvent
			s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
			if err != nil {
				return nil, nil, err
			}
			state[peek.RoomID] = s
		}
		if !peek.Deleted {
			deltas = append(deltas, types.StateDelta{
				Membership:  gomatrixserverlib.Peek,
				StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
				RoomID:      peek.RoomID,
			})
		}
	}

	// handle newly joined rooms and non-joined rooms
	for roomID, stateStreamEvents := range state {
		for _, ev := range stateStreamEvents {
			// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
			//       We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
			//       dupe join events will result in the entire room state coming down to the client again. This is added in
			//       the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
			//       the timeline.
			if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
				if membership == gomatrixserverlib.Join {
					// send full room state down instead of a delta
					var s []types.StreamEvent
					s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
					if err != nil {
						return nil, nil, err
					}
					// Overwrites the value for a key already being ranged
					// over; the Join delta itself is emitted below from the
					// joined-rooms list using this full state.
					state[roomID] = s
					// NOTE(review): continue resumes scanning the original
					// stateStreamEvents slice, so a later non-join membership
					// event for this user in the same room can still append a
					// delta below — confirm that is intended.
					continue // we'll add this room in when we do joined rooms
				}

				deltas = append(deltas, types.StateDelta{
					Membership:    membership,
					MembershipPos: ev.StreamPosition,
					StateEvents:   d.StreamEventsToEvents(device, stateStreamEvents),
					RoomID:        roomID,
				})
				break
			}
		}
	}

	// Add in currently joined rooms
	joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
	if err != nil {
		return nil, nil, err
	}
	for _, joinedRoomID := range joinedRoomIDs {
		deltas = append(deltas, types.StateDelta{
			Membership:  gomatrixserverlib.Join,
			StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
			RoomID:      joinedRoomID,
		})
	}

	succeeded = true
	return deltas, joinedRoomIDs, nil
}
769
// GetStateDeltasForFullStateSync is a variant of GetStateDeltas used for /sync
// requests with full_state=true. It returns one delta per room:
//   - full current state for every non-deleted peek;
//   - range state for rooms where userID has a non-join membership event in r;
//   - full current state for every currently joined room.
//
// Later entries for the same room replace earlier ones, and the order of the
// returned slice is not defined (it comes from map iteration). All reads
// happen inside one read-only snapshot transaction; device must be non-nil.
// The list of currently joined room IDs is also returned.
func (d *Database) GetStateDeltasForFullStateSync(
	ctx context.Context, device *userapi.Device,
	r types.Range, userID string,
	stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) {
	txn, err := d.readOnlySnapshot(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
	}
	var succeeded bool
	// NOTE(review): err is not a named result, so anything
	// EndTransactionWithCheck writes into it after return is not seen by the
	// caller — confirm that is acceptable for this read-only snapshot.
	defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)

	// Deltas keyed by room ID, so a later entry for a room replaces an earlier one.
	deltas := make(map[string]types.StateDelta)

	peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
	if err != nil {
		return nil, nil, err
	}

	// Add full states for all peeking rooms
	for _, peek := range peeks {
		if !peek.Deleted {
			s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
			if stateErr != nil {
				return nil, nil, stateErr
			}
			deltas[peek.RoomID] = types.StateDelta{
				Membership:  gomatrixserverlib.Peek,
				StateEvents: d.StreamEventsToEvents(device, s),
				RoomID:      peek.RoomID,
			}
		}
	}

	// Get all the state events ever between these two positions
	stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
	if err != nil {
		return nil, nil, err
	}
	state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
	if err != nil {
		return nil, nil, err
	}

	// Only the first membership event for userID found in each room's state is considered.
	for roomID, stateStreamEvents := range state {
		for _, ev := range stateStreamEvents {
			if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
				if membership != gomatrixserverlib.Join { // Joined rooms are skipped here; full state for every currently joined room is added below.
					deltas[roomID] = types.StateDelta{
						Membership:    membership,
						MembershipPos: ev.StreamPosition,
						StateEvents:   d.StreamEventsToEvents(device, stateStreamEvents),
						RoomID:        roomID,
					}
				}

				break
			}
		}
	}

	joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
	if err != nil {
		return nil, nil, err
	}

	// Add full states for all joined rooms
	for _, joinedRoomID := range joinedRoomIDs {
		s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
		if stateErr != nil {
			return nil, nil, stateErr
		}
		deltas[joinedRoomID] = types.StateDelta{
			Membership:  gomatrixserverlib.Join,
			StateEvents: d.StreamEventsToEvents(device, s),
			RoomID:      joinedRoomID,
		}
	}

	// Create a response array.
	result := make([]types.StateDelta, len(deltas))
	i := 0
	for _, delta := range deltas {
		result[i] = delta
		i++
	}

	succeeded = true
	return result, joinedRoomIDs, nil
}
865
866func (d *Database) currentStateStreamEventsForRoom(
867	ctx context.Context, txn *sql.Tx, roomID string,
868	stateFilter *gomatrixserverlib.StateFilter,
869) ([]types.StreamEvent, error) {
870	allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil)
871	if err != nil {
872		return nil, err
873	}
874	s := make([]types.StreamEvent, len(allState))
875	for i := 0; i < len(s); i++ {
876		s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
877	}
878	return s, nil
879}
880
881func (d *Database) StoreNewSendForDeviceMessage(
882	ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
883) (newPos types.StreamPosition, err error) {
884	j, err := json.Marshal(event)
885	if err != nil {
886		return 0, err
887	}
888	// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
889	// that we don't lock the table for writes in more than one place.
890	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
891		newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
892			ctx, txn, userID, deviceID, string(j),
893		)
894		return err
895	})
896	if err != nil {
897		return 0, err
898	}
899	return newPos, nil
900}
901
902func (d *Database) SendToDeviceUpdatesForSync(
903	ctx context.Context,
904	userID, deviceID string,
905	from, to types.StreamPosition,
906) (types.StreamPosition, []types.SendToDeviceEvent, error) {
907	// First of all, get our send-to-device updates for this user.
908	lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
909	if err != nil {
910		return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
911	}
912	// If there's nothing to do then stop here.
913	if len(events) == 0 {
914		return to, nil, nil
915	}
916	return lastPos, events, nil
917}
918
919func (d *Database) CleanSendToDeviceUpdates(
920	ctx context.Context,
921	userID, deviceID string, before types.StreamPosition,
922) (err error) {
923	if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
924		return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
925	}); err != nil {
926		logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
927		return err
928	}
929	return nil
930}
931
932// getMembershipFromEvent returns the value of content.membership iff the event is a state event
933// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
934func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
935	if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
936		return ""
937	}
938	membership, err := ev.Membership()
939	if err != nil {
940		return ""
941	}
942	return membership
943}
944
945// StoreReceipt stores user receipts
946func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
947	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
948		pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp)
949		return err
950	})
951	return
952}
953
954func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
955	_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
956	return receipts, err
957}
958