1// Copyright 2017 Vector Creations Ltd
2// Copyright 2018 New Vector Ltd
3// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17package state
18
19import (
20	"context"
21	"fmt"
22	"sort"
23	"time"
24
25	"github.com/matrix-org/dendrite/roomserver/storage"
26	"github.com/matrix-org/util"
27	"github.com/prometheus/client_golang/prometheus"
28
29	"github.com/matrix-org/dendrite/roomserver/types"
30	"github.com/matrix-org/gomatrixserverlib"
31)
32
// StateResolution carries what the state loading and state resolution methods
// in this file need for a single room.
//
//   - db is the roomserver storage used for every snapshot, state block and
//     numeric ID lookup.
//   - roomInfo supplies the RoomNID and RoomVersion used when storing and
//     resolving state.
//   - events maps event NIDs to events and starts out empty (see
//     NewStateResolution).
//     NOTE(review): presumably a cache filled while resolving conflicts;
//     confirm against its users.
type StateResolution struct {
	db       storage.Database
	roomInfo types.RoomInfo
	events   map[types.EventNID]*gomatrixserverlib.Event
}
38
39func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution {
40	return StateResolution{
41		db:       db,
42		roomInfo: roomInfo,
43		events:   make(map[types.EventNID]*gomatrixserverlib.Event),
44	}
45}
46
47// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
48// This is typically the state before an event or the current state of a room.
49// Returns a sorted list of state entries or an error if there was a problem talking to the database.
50func (v *StateResolution) LoadStateAtSnapshot(
51	ctx context.Context, stateNID types.StateSnapshotNID,
52) ([]types.StateEntry, error) {
53	stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
54	if err != nil {
55		return nil, err
56	}
57	// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
58	stateBlockNIDList := stateBlockNIDLists[0]
59
60	stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
61	if err != nil {
62		return nil, err
63	}
64	stateEntriesMap := stateEntryListMap(stateEntryLists)
65
66	// Combine all the state entries for this snapshot.
67	// The order of state block NIDs in the list tells us the order to combine them in.
68	var fullState []types.StateEntry
69	for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
70		entries, ok := stateEntriesMap.lookup(stateBlockNID)
71		if !ok {
72			// This should only get hit if the database is corrupt.
73			// It should be impossible for an event to reference a NID that doesn't exist
74			panic(fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID))
75		}
76		fullState = append(fullState, entries...)
77	}
78
79	// Stable sort so that the most recent entry for each state key stays
80	// remains later in the list than the older entries for the same state key.
81	sort.Stable(stateEntryByStateKeySorter(fullState))
82	// Unique returns the last entry and hence the most recent entry for each state key.
83	fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
84	return fullState, nil
85}
86
87// LoadStateAtEvent loads the full state of a room before a particular event.
88func (v *StateResolution) LoadStateAtEvent(
89	ctx context.Context, eventID string,
90) ([]types.StateEntry, error) {
91	snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
92	if err != nil {
93		return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
94	}
95	if snapshotNID == 0 {
96		return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
97	}
98
99	stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
100	if err != nil {
101		return nil, err
102	}
103
104	return stateEntries, nil
105}
106
107// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
108// and combines those snapshots together into a single list. At this point it is
109// possible to run into duplicate (type, state key) tuples.
110func (v *StateResolution) LoadCombinedStateAfterEvents(
111	ctx context.Context, prevStates []types.StateAtEvent,
112) ([]types.StateEntry, error) {
113	stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
114	for i, state := range prevStates {
115		stateNIDs[i] = state.BeforeStateSnapshotNID
116	}
117	// Fetch the state snapshots for the state before the each prev event from the database.
118	// Deduplicate the IDs before passing them to the database.
119	// There could be duplicates because the events could be state events where
120	// the snapshot of the room state before them was the same.
121	stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, UniqueStateSnapshotNIDs(stateNIDs))
122	if err != nil {
123		return nil, fmt.Errorf("v.db.StateBlockNIDs: %w", err)
124	}
125
126	var stateBlockNIDs []types.StateBlockNID
127	for _, list := range stateBlockNIDLists {
128		stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...)
129	}
130	// Fetch the state entries that will be combined to create the snapshots.
131	// Deduplicate the IDs before passing them to the database.
132	// There could be duplicates because a block of state entries could be reused by
133	// multiple snapshots.
134	stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
135	if err != nil {
136		return nil, fmt.Errorf("v.db.StateEntries: %w", err)
137	}
138	stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
139	stateEntriesMap := stateEntryListMap(stateEntryLists)
140
141	// Combine the entries from all the snapshots of state after each prev event into a single list.
142	var combined []types.StateEntry
143	for _, prevState := range prevStates {
144		// Grab the list of state data NIDs for this snapshot.
145		stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID)
146		if !ok {
147			// This should only get hit if the database is corrupt.
148			// It should be impossible for an event to reference a NID that doesn't exist
149			panic(fmt.Errorf("corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
150		}
151
152		// Combine all the state entries for this snapshot.
153		// The order of state block NIDs in the list tells us the order to combine them in.
154		var fullState []types.StateEntry
155		for _, stateBlockNID := range stateBlockNIDs {
156			entries, ok := stateEntriesMap.lookup(stateBlockNID)
157			if !ok {
158				// This should only get hit if the database is corrupt.
159				// It should be impossible for an event to reference a NID that doesn't exist
160				panic(fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID))
161			}
162			fullState = append(fullState, entries...)
163		}
164		if prevState.IsStateEvent() && !prevState.IsRejected {
165			// If the prev event was a state event then add an entry for the event itself
166			// so that we get the state after the event rather than the state before.
167			fullState = append(fullState, prevState.StateEntry)
168		}
169
170		// Stable sort so that the most recent entry for each state key stays
171		// remains later in the list than the older entries for the same state key.
172		sort.Stable(stateEntryByStateKeySorter(fullState))
173		// Unique returns the last entry and hence the most recent entry for each state key.
174		fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
175		// Add the full state for this StateSnapshotNID.
176		combined = append(combined, fullState...)
177	}
178	return combined, nil
179}
180
181// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
182func (v *StateResolution) DifferenceBetweeenStateSnapshots(
183	ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID,
184) (removed, added []types.StateEntry, err error) {
185	if oldStateNID == newStateNID {
186		// If the snapshot NIDs are the same then nothing has changed
187		return nil, nil, nil
188	}
189
190	var oldEntries []types.StateEntry
191	var newEntries []types.StateEntry
192	if oldStateNID != 0 {
193		oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID)
194		if err != nil {
195			return nil, nil, err
196		}
197	}
198	if newStateNID != 0 {
199		newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID)
200		if err != nil {
201			return nil, nil, err
202		}
203	}
204
205	var oldI int
206	var newI int
207	for {
208		switch {
209		case oldI == len(oldEntries):
210			// We've reached the end of the old entries.
211			// The rest of the new list must have been newly added.
212			added = append(added, newEntries[newI:]...)
213			return
214		case newI == len(newEntries):
215			// We've reached the end of the new entries.
216			// The rest of the old list must be have been removed.
217			removed = append(removed, oldEntries[oldI:]...)
218			return
219		case oldEntries[oldI] == newEntries[newI]:
220			// The entry is in both lists so skip over it.
221			oldI++
222			newI++
223		case oldEntries[oldI].LessThan(newEntries[newI]):
224			// The lists are sorted so the old entry being less than the new entry means that it only appears in the old list.
225			removed = append(removed, oldEntries[oldI])
226			oldI++
227		default:
228			// Reaching the default case implies that the new entry is less than the old entry.
229			// Since the lists are sorted this means that it only appears in the new list.
230			added = append(added, newEntries[newI])
231			newI++
232		}
233	}
234}
235
236// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot.
237// This is used when we only want to load a subset of the room state at a snapshot.
238// If there is no entry for a given event type and state key pair then it will be discarded.
239// This is typically the state before an event or the current state of a room.
240// Returns a sorted list of state entries or an error if there was a problem talking to the database.
241func (v *StateResolution) LoadStateAtSnapshotForStringTuples(
242	ctx context.Context,
243	stateNID types.StateSnapshotNID,
244	stateKeyTuples []gomatrixserverlib.StateKeyTuple,
245) ([]types.StateEntry, error) {
246	numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
247	if err != nil {
248		return nil, err
249	}
250	return v.loadStateAtSnapshotForNumericTuples(ctx, stateNID, numericTuples)
251}
252
253// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
254// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
255// Returns an error if there was a problem talking to the database.
256func (v *StateResolution) stringTuplesToNumericTuples(
257	ctx context.Context,
258	stringTuples []gomatrixserverlib.StateKeyTuple,
259) ([]types.StateKeyTuple, error) {
260	eventTypes := make([]string, len(stringTuples))
261	stateKeys := make([]string, len(stringTuples))
262	for i := range stringTuples {
263		eventTypes[i] = stringTuples[i].EventType
264		stateKeys[i] = stringTuples[i].StateKey
265	}
266	eventTypes = util.UniqueStrings(eventTypes)
267	eventTypeMap, err := v.db.EventTypeNIDs(ctx, eventTypes)
268	if err != nil {
269		return nil, err
270	}
271	stateKeys = util.UniqueStrings(stateKeys)
272	stateKeyMap, err := v.db.EventStateKeyNIDs(ctx, stateKeys)
273	if err != nil {
274		return nil, err
275	}
276
277	var result []types.StateKeyTuple
278	for _, stringTuple := range stringTuples {
279		var numericTuple types.StateKeyTuple
280		var ok1, ok2 bool
281		numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType]
282		numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey]
283		// Discard the tuple if there wasn't a numeric ID for either the event type or the state key.
284		if ok1 && ok2 {
285			result = append(result, numericTuple)
286		}
287	}
288
289	return result, nil
290}
291
292// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot.
293// This is used when we only want to load a subset of the room state at a snapshot.
294// If there is no entry for a given event type and state key pair then it will be discarded.
295// This is typically the state before an event or the current state of a room.
296// Returns a sorted list of state entries or an error if there was a problem talking to the database.
297func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
298	ctx context.Context,
299	stateNID types.StateSnapshotNID,
300	stateKeyTuples []types.StateKeyTuple,
301) ([]types.StateEntry, error) {
302	stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
303	if err != nil {
304		return nil, err
305	}
306	// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
307	stateBlockNIDList := stateBlockNIDLists[0]
308
309	stateEntryLists, err := v.db.StateEntriesForTuples(
310		ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
311	)
312	if err != nil {
313		return nil, err
314	}
315	stateEntriesMap := stateEntryListMap(stateEntryLists)
316
317	// Combine all the state entries for this snapshot.
318	// The order of state block NIDs in the list tells us the order to combine them in.
319	var fullState []types.StateEntry
320	for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
321		entries, ok := stateEntriesMap.lookup(stateBlockNID)
322		if !ok {
323			// If the block is missing from the map it means that none of its entries matched a requested tuple.
324			// This can happen if the block doesn't contain an update for one of the requested tuples.
325			// If none of the requested tuples are in the block then it can be safely skipped.
326			continue
327		}
328		fullState = append(fullState, entries...)
329	}
330
331	// Stable sort so that the most recent entry for each state key stays
332	// remains later in the list than the older entries for the same state key.
333	sort.Stable(stateEntryByStateKeySorter(fullState))
334	// Unique returns the last entry and hence the most recent entry for each state key.
335	fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
336	return fullState, nil
337}
338
339// LoadStateAfterEventsForStringTuples loads the state for a list of event type
340// and state key pairs after list of events.
341// This is used when we only want to load a subset of the room state after a list of events.
342// If there is no entry for a given event type and state key pair then it will be discarded.
343// This is typically the state before an event.
344// Returns a sorted list of state entries or an error if there was a problem talking to the database.
345func (v *StateResolution) LoadStateAfterEventsForStringTuples(
346	ctx context.Context,
347	prevStates []types.StateAtEvent,
348	stateKeyTuples []gomatrixserverlib.StateKeyTuple,
349) ([]types.StateEntry, error) {
350	numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
351	if err != nil {
352		return nil, err
353	}
354	return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples)
355}
356
357func (v *StateResolution) loadStateAfterEventsForNumericTuples(
358	ctx context.Context,
359	prevStates []types.StateAtEvent,
360	stateKeyTuples []types.StateKeyTuple,
361) ([]types.StateEntry, error) {
362	if len(prevStates) == 1 {
363		// Fast path for a single event.
364		prevState := prevStates[0]
365		result, err := v.loadStateAtSnapshotForNumericTuples(
366			ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples,
367		)
368		if err != nil {
369			return nil, err
370		}
371		if prevState.IsStateEvent() {
372			// The result is current the state before the requested event.
373			// We want the state after the requested event.
374			// If the requested event was a state event then we need to
375			// update that key in the result.
376			// If the requested event wasn't a state event then the state after
377			// it is the same as the state before it.
378			set := false
379			for i := range result {
380				if result[i].StateKeyTuple == prevState.StateKeyTuple {
381					result[i] = prevState.StateEntry
382					set = true
383				}
384			}
385			if !set { // no previous state exists for this event: add new state
386				result = append(result, prevState.StateEntry)
387			}
388		}
389		return result, nil
390	}
391
392	// Slow path for more that one event.
393	// Load the entire state so that we can do conflict resolution if we need to.
394	// TODO: The are some optimistations we could do here:
395	//    1) We only need to do conflict resolution if there is a conflict in the
396	//       requested tuples so we might try loading just those tuples and then
397	//       checking for conflicts.
398	//    2) When there is a conflict we still only need to load the state
399	//       needed to do conflict resolution which would save us having to load
400	//       the full state.
401
402	// TODO: Add metrics for this as it could take a long time for big rooms
403	// with large conflicts.
404	fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
405	if err != nil {
406		return nil, err
407	}
408
409	// Sort the full state so we can use it as a map.
410	sort.Sort(stateEntrySorter(fullState))
411
412	// Filter the full state down to the required tuples.
413	var result []types.StateEntry
414	for _, tuple := range stateKeyTuples {
415		eventNID, ok := stateEntryMap(fullState).lookup(tuple)
416		if ok {
417			result = append(result, types.StateEntry{
418				StateKeyTuple: tuple,
419				EventNID:      eventNID,
420			})
421		}
422	}
423	sort.Sort(stateEntrySorter(result))
424	return result, nil
425}
426
// calculateStateDurations is a histogram, in milliseconds, of how long each
// state calculation took. It is observed by calculateStateMetrics.stop.
var calculateStateDurations = prometheus.NewHistogramVec(
	prometheus.HistogramOpts{
		Namespace: "dendrite",
		Subsystem: "roomserver",
		Name:      "calculate_state_duration_milliseconds",
		Help:      "How long it takes to calculate the state after a list of events",
		Buckets: []float64{ // milliseconds
			5, 10, 25, 50, 75, 100, 200, 300, 400, 500,
			1000, 2000, 3000, 4000, 5000, 6000,
			7000, 8000, 9000, 10000, 15000, 20000, 30000,
		},
	},
	// Takes two labels:
	//   algorithm:
	//      The algorithm used to calculate the state or the step it failed on if it failed.
	//      Labels starting with "_" are used to indicate when the algorithm fails halfway.
	//  outcome:
	//      Whether the state was successfully calculated: "success" or "failure".
	//
	// The possible values for algorithm are:
	//    empty_state -> The list of events was empty so the state is empty.
	//    no_change -> The state hasn't changed.
	//    single_delta -> There was a single event added to the state in a way that can be encoded as a single delta
	//    full_state_no_conflicts -> We created a new copy of the full room state, but didn't encounter any conflicts
	//                               while doing so.
	//    full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so.
	//    _load_state_blocks -> Failed loading the state block nids for a single previous state.
	//    _load_combined_state -> Failed to load the combined state.
	//    _resolve_conflicts -> Failed to resolve conflicts.
	[]string{"algorithm", "outcome"},
)
458
// calculateStatePrevEventLength summarises how many prev states each state
// calculation was given. It is observed by calculateStateMetrics.stop and uses
// the same "algorithm" and "outcome" labels as calculateStateDurations.
var calculateStatePrevEventLength = prometheus.NewSummaryVec(
	prometheus.SummaryOpts{
		Namespace: "dendrite",
		Subsystem: "roomserver",
		Name:      "calculate_state_prev_event_length",
		Help:      "The length of the list of events to calculate the state after",
	},
	[]string{"algorithm", "outcome"},
)
468
// calculateStateFullStateLength summarises the number of entries in the full
// room state produced by a state calculation. It is observed by
// calculateStateMetrics.stop and uses the same "algorithm" and "outcome" labels
// as calculateStateDurations.
var calculateStateFullStateLength = prometheus.NewSummaryVec(
	prometheus.SummaryOpts{
		Namespace: "dendrite",
		Subsystem: "roomserver",
		Name:      "calculate_state_full_state_length",
		Help:      "The length of the full room state.",
	},
	[]string{"algorithm", "outcome"},
)
478
// calculateStateConflictLength summarises the number of conflicted state
// entries found by a state calculation. It is observed by
// calculateStateMetrics.stop and uses the same "algorithm" and "outcome" labels
// as calculateStateDurations.
var calculateStateConflictLength = prometheus.NewSummaryVec(
	prometheus.SummaryOpts{
		Namespace: "dendrite",
		Subsystem: "roomserver",
		Name:      "calculate_state_conflict_state_length",
		Help:      "The length of the conflicted room state.",
	},
	[]string{"algorithm", "outcome"},
)
488
// calculateStateMetrics collects the measurements for one state calculation so
// that stop can report them to the prometheus collectors in one go.
//
//   - algorithm is the value reported for the "algorithm" label.
//   - startTime is when the calculation began; stop reports the time elapsed
//     since then in milliseconds.
//   - prevEventLength is the number of prev states the calculation was given.
//   - fullStateLength is the number of entries in the calculated full state.
//   - conflictLength is the number of conflicted state entries that were found.
type calculateStateMetrics struct {
	algorithm       string
	startTime       time.Time
	prevEventLength int
	fullStateLength int
	conflictLength  int
}
496
497func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) {
498	var outcome string
499	if err == nil {
500		outcome = "success"
501	} else {
502		outcome = "failure"
503	}
504	calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe(
505		float64(time.Since(c.startTime).Milliseconds()),
506	)
507	calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe(
508		float64(c.prevEventLength),
509	)
510	calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe(
511		float64(c.fullStateLength),
512	)
513	calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe(
514		float64(c.conflictLength),
515	)
516	return stateNID, err
517}
518
519func init() {
520	prometheus.MustRegister(
521		calculateStateDurations, calculateStatePrevEventLength,
522		calculateStateFullStateLength, calculateStateConflictLength,
523	)
524}
525
526// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event.
527// Stores the snapshot of the state in the database.
528// Returns a numeric ID for the snapshot of the state before the event.
529func (v *StateResolution) CalculateAndStoreStateBeforeEvent(
530	ctx context.Context,
531	event *gomatrixserverlib.Event,
532	isRejected bool,
533) (types.StateSnapshotNID, error) {
534	// Load the state at the prev events.
535	prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs())
536	if err != nil {
537		return 0, err
538	}
539
540	// The state before this event will be the state after the events that came before it.
541	return v.CalculateAndStoreStateAfterEvents(ctx, prevStates)
542}
543
544// CalculateAndStoreStateAfterEvents finds the room state after the given events.
545// Stores the resulting state in the database and returns a numeric ID for that snapshot.
546func (v *StateResolution) CalculateAndStoreStateAfterEvents(
547	ctx context.Context,
548	prevStates []types.StateAtEvent,
549) (types.StateSnapshotNID, error) {
550	metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
551
552	if len(prevStates) == 0 {
553		// 2) There weren't any prev_events for this event so the state is
554		// empty.
555		metrics.algorithm = "empty_state"
556		stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil)
557		if err != nil {
558			err = fmt.Errorf("v.db.AddState: %w", err)
559		}
560		return metrics.stop(stateNID, err)
561	}
562
563	if len(prevStates) == 1 {
564		prevState := prevStates[0]
565		if prevState.EventStateKeyNID == 0 || prevState.IsRejected {
566			// 3) None of the previous events were state events and they all
567			// have the same state, so this event has exactly the same state
568			// as the previous events.
569			// This should be the internal case.
570			metrics.algorithm = "no_change"
571			return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
572		}
573		// The previous event was a state event so we need to store a copy
574		// of the previous state updated with that event.
575		stateBlockNIDLists, err := v.db.StateBlockNIDs(
576			ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
577		)
578		if err != nil {
579			metrics.algorithm = "_load_state_blocks"
580			return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err))
581		}
582		stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
583		if len(stateBlockNIDs) < maxStateBlockNIDs {
584			// 4) The number of state data blocks is small enough that we can just
585			// add the state event as a block of size one to the end of the blocks.
586			metrics.algorithm = "single_delta"
587			stateNID, err := v.db.AddState(
588				ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
589			)
590			if err != nil {
591				err = fmt.Errorf("v.db.AddState: %w", err)
592			}
593			return metrics.stop(stateNID, err)
594		}
595		// If there are too many deltas then we need to calculate the full state
596		// So fall through to calculateAndStoreStateAfterManyEvents
597	}
598
599	stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics)
600	if err != nil {
601		return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
602	}
603	return stateNID, nil
604}
605
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
// CalculateAndStoreStateAfterEvents only appends a single state event as a new delta block while the
// existing snapshot has fewer blocks than this; otherwise it recalculates and stores the full state.
// Increasing this number means that we can encode more of the state changes as simple deltas which means that
// we need fewer entries in the state data table. However making this number bigger will increase the size of
// the rows in the state table itself and will require more index lookups when retrieving a snapshot.
// TODO: Tune this to get the right balance between size and lookup performance.
const maxStateBlockNIDs = 64
612
613// calculateAndStoreStateAfterManyEvents finds the room state after the given events.
614// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
615// Stores the resulting state and returns a numeric ID for the snapshot.
616func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
617	ctx context.Context,
618	roomNID types.RoomNID,
619	prevStates []types.StateAtEvent,
620	metrics calculateStateMetrics,
621) (types.StateSnapshotNID, error) {
622	state, algorithm, conflictLength, err :=
623		v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
624	metrics.algorithm = algorithm
625	if err != nil {
626		return metrics.stop(0, fmt.Errorf("v.calculateStateAfterManyEvents: %w", err))
627	}
628
629	// TODO: Check if we can encode the new state as a delta against the
630	// previous state.
631	metrics.conflictLength = conflictLength
632	metrics.fullStateLength = len(state)
633	return metrics.stop(v.db.AddState(ctx, roomNID, nil, state))
634}
635
636func (v *StateResolution) calculateStateAfterManyEvents(
637	ctx context.Context, roomVersion gomatrixserverlib.RoomVersion,
638	prevStates []types.StateAtEvent,
639) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
640	var combined []types.StateEntry
641	// Conflict resolution.
642	// First stage: load the state after each of the prev events.
643	combined, err = v.LoadCombinedStateAfterEvents(ctx, prevStates)
644	if err != nil {
645		err = fmt.Errorf("v.LoadCombinedStateAfterEvents: %w", err)
646		algorithm = "_load_combined_state"
647		return
648	}
649
650	// Collect all the entries with the same type and key together.
651	// We don't care about the order here because the conflict resolution
652	// algorithm doesn't depend on the order of the prev events.
653	// Remove duplicate entires.
654	combined = combined[:util.SortAndUnique(stateEntrySorter(combined))]
655
656	// Find the conflicts
657	conflicts := findDuplicateStateKeys(combined)
658
659	if len(conflicts) > 0 {
660		conflictLength = len(conflicts)
661
662		// 5) There are conflicting state events, for each conflict workout
663		// what the appropriate state event is.
664
665		// Work out which entries aren't conflicted.
666		var notConflicted []types.StateEntry
667		for _, entry := range combined {
668			if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok {
669				notConflicted = append(notConflicted, entry)
670			}
671		}
672
673		var resolved []types.StateEntry
674		resolved, err = v.resolveConflicts(ctx, roomVersion, notConflicted, conflicts)
675		if err != nil {
676			err = fmt.Errorf("v.resolveConflits: %w", err)
677			algorithm = "_resolve_conflicts"
678			return
679		}
680		algorithm = "full_state_with_conflicts"
681		state = resolved[:util.SortAndUnique(stateEntrySorter(resolved))]
682	} else {
683		algorithm = "full_state_no_conflicts"
684		// 6) There weren't any conflicts
685		state = combined
686	}
687	return
688}
689
690func (v *StateResolution) resolveConflicts(
691	ctx context.Context, version gomatrixserverlib.RoomVersion,
692	notConflicted, conflicted []types.StateEntry,
693) ([]types.StateEntry, error) {
694	stateResAlgo, err := version.StateResAlgorithm()
695	if err != nil {
696		return nil, err
697	}
698	switch stateResAlgo {
699	case gomatrixserverlib.StateResV1:
700		return v.resolveConflictsV1(ctx, notConflicted, conflicted)
701	case gomatrixserverlib.StateResV2:
702		return v.resolveConflictsV2(ctx, notConflicted, conflicted)
703	}
704	return nil, fmt.Errorf("unsupported state resolution algorithm %v", stateResAlgo)
705}
706
707// resolveConflicts resolves a list of conflicted state entries. It takes two lists.
708// The first is a list of all state entries that are not conflicted.
709// The second is a list of all state entries that are conflicted
710// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple.
711// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
712// The returned list is sorted by state key tuple.
713// Returns an error if there was a problem talking to the database.
714func (v *StateResolution) resolveConflictsV1(
715	ctx context.Context,
716	notConflicted, conflicted []types.StateEntry,
717) ([]types.StateEntry, error) {
718
719	// Load the conflicted events
720	conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted)
721	if err != nil {
722		return nil, err
723	}
724
725	// Work out which auth events we need to load.
726	needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents)
727
728	// Find the numeric IDs for the necessary state keys.
729	var neededStateKeys []string
730	neededStateKeys = append(neededStateKeys, needed.Member...)
731	neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
732	stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys)
733	if err != nil {
734		return nil, err
735	}
736
737	// Load the necessary auth events.
738	tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed)
739	var authEntries []types.StateEntry
740	for _, tuple := range tuplesNeeded {
741		if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
742			authEntries = append(authEntries, types.StateEntry{
743				StateKeyTuple: tuple,
744				EventNID:      eventNID,
745			})
746		}
747	}
748	authEvents, _, err := v.loadStateEvents(ctx, authEntries)
749	if err != nil {
750		return nil, err
751	}
752
753	// Resolve the conflicts.
754	resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents)
755
756	// Map from the full events back to numeric state entries.
757	for _, resolvedEvent := range resolvedEvents {
758		entry, ok := eventIDMap[resolvedEvent.EventID()]
759		if !ok {
760			panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID()))
761		}
762		notConflicted = append(notConflicted, entry)
763	}
764
765	// Sort the result so it can be searched.
766	sort.Sort(stateEntrySorter(notConflicted))
767	return notConflicted, nil
768}
769
770// resolveConflicts resolves a list of conflicted state entries. It takes two lists.
771// The first is a list of all state entries that are not conflicted.
772// The second is a list of all state entries that are conflicted
773// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple.
774// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
775// The returned list is sorted by state key tuple.
776// Returns an error if there was a problem talking to the database.
777func (v *StateResolution) resolveConflictsV2(
778	ctx context.Context,
779	notConflicted, conflicted []types.StateEntry,
780) ([]types.StateEntry, error) {
781	estimate := len(conflicted) + len(notConflicted)
782	eventIDMap := make(map[string]types.StateEntry, estimate)
783
784	// Load the conflicted events
785	conflictedEvents, conflictedEventMap, err := v.loadStateEvents(ctx, conflicted)
786	if err != nil {
787		return nil, err
788	}
789	for k, v := range conflictedEventMap {
790		eventIDMap[k] = v
791	}
792
793	// Load the non-conflicted events
794	nonConflictedEvents, nonConflictedEventMap, err := v.loadStateEvents(ctx, notConflicted)
795	if err != nil {
796		return nil, err
797	}
798	for k, v := range nonConflictedEventMap {
799		eventIDMap[k] = v
800	}
801
802	// For each conflicted event, we will add a new set of auth events. Auth
803	// events may be duplicated across these sets but that's OK.
804	authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted))
805	authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
806	authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
807
808	// For each conflicted event, let's try and get the needed auth events.
809	neededStateKeys := make([]string, 16)
810	authEntries := make([]types.StateEntry, 16)
811	for _, conflictedEvent := range conflictedEvents {
812		// Work out which auth events we need to load.
813		key := conflictedEvent.EventID()
814		needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent})
815
816		// Find the numeric IDs for the necessary state keys.
817		neededStateKeys = neededStateKeys[:0]
818		neededStateKeys = append(neededStateKeys, needed.Member...)
819		neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
820		stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys)
821		if err != nil {
822			return nil, err
823		}
824
825		// Load the necessary auth events.
826		tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed)
827		authEntries = authEntries[:0]
828		for _, tuple := range tuplesNeeded {
829			if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
830				authEntries = append(authEntries, types.StateEntry{
831					StateKeyTuple: tuple,
832					EventNID:      eventNID,
833				})
834			}
835		}
836
837		// Store the newly found auth events in the auth set for this event.
838		authSets[key], _, err = v.loadStateEvents(ctx, authEntries)
839		if err != nil {
840			return nil, err
841		}
842		authEvents = append(authEvents, authSets[key]...)
843	}
844
845	// This function helps us to work out whether an event exists in one of the
846	// auth sets.
847	isInAuthList := func(k string, event *gomatrixserverlib.Event) bool {
848		for _, e := range authSets[k] {
849			if e.EventID() == event.EventID() {
850				return true
851			}
852		}
853		return false
854	}
855
856	// This function works out if an event exists in all of the auth sets.
857	isInAllAuthLists := func(event *gomatrixserverlib.Event) bool {
858		found := true
859		for k := range authSets {
860			found = found && isInAuthList(k, event)
861		}
862		return found
863	}
864
865	// Look through all of the auth events that we've been given and work out if
866	// there are any events which don't appear in all of the auth sets. If they
867	// don't then we add them to the auth difference.
868	for _, event := range authEvents {
869		if !isInAllAuthLists(event) {
870			authDifference = append(authDifference, event)
871		}
872	}
873
874	// Resolve the conflicts.
875	resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2(
876		conflictedEvents,
877		nonConflictedEvents,
878		authEvents,
879		authDifference,
880	)
881
882	// Map from the full events back to numeric state entries.
883	for _, resolvedEvent := range resolvedEvents {
884		entry, ok := eventIDMap[resolvedEvent.EventID()]
885		if !ok {
886			panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID()))
887		}
888		notConflicted = append(notConflicted, entry)
889	}
890
891	// Sort the result so it can be searched.
892	sort.Sort(stateEntrySorter(notConflicted))
893	return notConflicted, nil
894}
895
896// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
897func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
898	var keyTuples []types.StateKeyTuple
899	if stateNeeded.Create {
900		keyTuples = append(keyTuples, types.StateKeyTuple{
901			EventTypeNID:     types.MRoomCreateNID,
902			EventStateKeyNID: types.EmptyStateKeyNID,
903		})
904	}
905	if stateNeeded.PowerLevels {
906		keyTuples = append(keyTuples, types.StateKeyTuple{
907			EventTypeNID:     types.MRoomPowerLevelsNID,
908			EventStateKeyNID: types.EmptyStateKeyNID,
909		})
910	}
911	if stateNeeded.JoinRules {
912		keyTuples = append(keyTuples, types.StateKeyTuple{
913			EventTypeNID:     types.MRoomJoinRulesNID,
914			EventStateKeyNID: types.EmptyStateKeyNID,
915		})
916	}
917	for _, member := range stateNeeded.Member {
918		stateKeyNID, ok := stateKeyNIDMap[member]
919		if ok {
920			keyTuples = append(keyTuples, types.StateKeyTuple{
921				EventTypeNID:     types.MRoomMemberNID,
922				EventStateKeyNID: stateKeyNID,
923			})
924		}
925	}
926	for _, token := range stateNeeded.ThirdPartyInvite {
927		stateKeyNID, ok := stateKeyNIDMap[token]
928		if ok {
929			keyTuples = append(keyTuples, types.StateKeyTuple{
930				EventTypeNID:     types.MRoomThirdPartyInviteNID,
931				EventStateKeyNID: stateKeyNID,
932			})
933		}
934	}
935	return keyTuples
936}
937
938// loadStateEvents loads the matrix events for a list of state entries.
939// Returns a list of state events in no particular order and a map from string event ID back to state entry.
940// The map can be used to recover which numeric state entry a given event is for.
941// Returns an error if there was a problem talking to the database.
942func (v *StateResolution) loadStateEvents(
943	ctx context.Context, entries []types.StateEntry,
944) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
945	result := make([]*gomatrixserverlib.Event, 0, len(entries))
946	eventEntries := make([]types.StateEntry, 0, len(entries))
947	eventNIDs := make([]types.EventNID, 0, len(entries))
948	for _, entry := range entries {
949		if e, ok := v.events[entry.EventNID]; ok {
950			result = append(result, e)
951		} else {
952			eventEntries = append(eventEntries, entry)
953			eventNIDs = append(eventNIDs, entry.EventNID)
954		}
955	}
956	events, err := v.db.Events(ctx, eventNIDs)
957	if err != nil {
958		return nil, nil, err
959	}
960	eventIDMap := map[string]types.StateEntry{}
961	for _, entry := range eventEntries {
962		event, ok := eventMap(events).lookup(entry.EventNID)
963		if !ok {
964			panic(fmt.Errorf("corrupt DB: Missing event numeric ID %d", entry.EventNID))
965		}
966		result = append(result, event.Event)
967		eventIDMap[event.Event.EventID()] = entry
968		v.events[entry.EventNID] = event.Event
969	}
970	return result, eventIDMap, nil
971}
972
973// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
974// Returns a sorted list of those state entries.
975func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {
976	var result []types.StateEntry
977	// j is the starting index of a block of entries with the same state key tuple.
978	j := 0
979	for i := 1; i < len(a); i++ {
980		// Check if the state key tuple matches the start of the block
981		if a[j].StateKeyTuple != a[i].StateKeyTuple {
982			// If the state key tuple is different then we've reached the end of a block of duplicates.
983			// Check if the size of the block is bigger than one.
984			// If the size is one then there was only a single entry with that state key tuple so we don't add it to the result
985			if j+1 != i {
986				// Add the block to the result.
987				result = append(result, a[j:i]...)
988			}
989			// Start a new block for the next state key tuple.
990			j = i
991		}
992	}
993	// Check if the last block with the same state key tuple had more than one event in it.
994	if j+1 != len(a) {
995		result = append(result, a[j:]...)
996	}
997	return result
998}
999
1000type stateEntrySorter []types.StateEntry
1001
1002func (s stateEntrySorter) Len() int           { return len(s) }
1003func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
1004func (s stateEntrySorter) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
1005
1006type stateBlockNIDListMap []types.StateBlockNIDList
1007
1008func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
1009	list := []types.StateBlockNIDList(m)
1010	i := sort.Search(len(list), func(i int) bool {
1011		return list[i].StateSnapshotNID >= stateNID
1012	})
1013	if i < len(list) && list[i].StateSnapshotNID == stateNID {
1014		ok = true
1015		stateBlockNIDs = list[i].StateBlockNIDs
1016	}
1017	return
1018}
1019
1020type stateEntryListMap []types.StateEntryList
1021
1022func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
1023	list := []types.StateEntryList(m)
1024	i := sort.Search(len(list), func(i int) bool {
1025		return list[i].StateBlockNID >= stateBlockNID
1026	})
1027	if i < len(list) && list[i].StateBlockNID == stateBlockNID {
1028		ok = true
1029		stateEntries = list[i].StateEntries
1030	}
1031	return
1032}
1033
1034type stateEntryByStateKeySorter []types.StateEntry
1035
1036func (s stateEntryByStateKeySorter) Len() int { return len(s) }
1037func (s stateEntryByStateKeySorter) Less(i, j int) bool {
1038	return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
1039}
1040func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
1041
1042type stateNIDSorter []types.StateSnapshotNID
1043
1044func (s stateNIDSorter) Len() int           { return len(s) }
1045func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
1046func (s stateNIDSorter) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
1047
1048func UniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
1049	return nids[:util.SortAndUnique(stateNIDSorter(nids))]
1050}
1051
1052type stateBlockNIDSorter []types.StateBlockNID
1053
1054func (s stateBlockNIDSorter) Len() int           { return len(s) }
1055func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
1056func (s stateBlockNIDSorter) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
1057
1058func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
1059	return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))]
1060}
1061
1062// Map from event type, state key tuple to numeric event ID.
1063// Implemented using binary search on a sorted array.
1064type stateEntryMap []types.StateEntry
1065
1066// lookup an entry in the event map.
1067func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
1068	// Since the list is sorted we can implement this using binary search.
1069	// This is faster than using a hash map.
1070	// We don't have to worry about pathological cases because the keys are fixed
1071	// size and are controlled by us.
1072	list := []types.StateEntry(m)
1073	i := sort.Search(len(list), func(i int) bool {
1074		return !list[i].StateKeyTuple.LessThan(stateKey)
1075	})
1076	if i < len(list) && list[i].StateKeyTuple == stateKey {
1077		ok = true
1078		eventNID = list[i].EventNID
1079	}
1080	return
1081}
1082
1083// Map from numeric event ID to event.
1084// Implemented using binary search on a sorted array.
1085type eventMap []types.Event
1086
1087// lookup an entry in the event map.
1088func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
1089	// Since the list is sorted we can implement this using binary search.
1090	// This is faster than using a hash map.
1091	// We don't have to worry about pathological cases because the keys are fixed
1092	// size are controlled by us.
1093	list := []types.Event(m)
1094	i := sort.Search(len(list), func(i int) bool {
1095		return list[i].EventNID >= eventNID
1096	})
1097	if i < len(list) && list[i].EventNID == eventNID {
1098		ok = true
1099		event = &list[i]
1100	}
1101	return
1102}
1103