1/* Copyright 2017 Vector Creations Ltd
2 *
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16package gomatrixserverlib
17
18import (
19	"bytes"
20	"crypto/sha1"
21	"fmt"
22	"sort"
23)
24
25// ResolveStateConflicts takes a list of state events with conflicting state keys
26// and works out which event should be used for each state event.
27func ResolveStateConflicts(conflicted []*Event, authEvents []*Event) []*Event {
28	var r stateResolver
29	r.resolvedThirdPartyInvites = map[string]*Event{}
30	r.resolvedMembers = map[string]*Event{}
31	// Group the conflicted events by type and state key.
32	r.addConflicted(conflicted)
33	// Add the unconflicted auth events needed for auth checks.
34	for i := range authEvents {
35		r.addAuthEvent(authEvents[i])
36	}
37	// Resolve the conflicted auth events.
38	r.resolveAndAddAuthBlocks([][]*Event{r.creates})
39	r.resolveAndAddAuthBlocks([][]*Event{r.powerLevels})
40	r.resolveAndAddAuthBlocks([][]*Event{r.joinRules})
41	r.resolveAndAddAuthBlocks(r.thirdPartyInvites)
42	r.resolveAndAddAuthBlocks(r.members)
43	// Resolve any other conflicted state events.
44	for _, block := range r.others {
45		if event := r.resolveNormalBlock(block); event != nil {
46			r.result = append(r.result, event)
47		}
48	}
49	return r.result
50}
51
52// A stateResolver tracks the internal state of the state resolution algorithm
53// It has 3 sections:
54//
55//  * Lists of lists of events to resolve grouped by event type and state key.
56//  * The resolved auth events grouped by type and state key.
57//  * A List of resolved events.
58//
59// It implements the AuthEvents interface and can be used for running auth checks.
60type stateResolver struct {
61	// Lists of lists of events to resolve grouped by event type and state key:
62	//   * creates, powerLevels, joinRules have empty state keys.
63	//   * members and thirdPartyInvites are grouped by state key.
64	//   * the others are grouped by the pair of type and state key.
65	creates           []*Event
66	powerLevels       []*Event
67	joinRules         []*Event
68	thirdPartyInvites [][]*Event
69	members           [][]*Event
70	others            [][]*Event
71	// The resolved auth events grouped by type and state key.
72	resolvedCreate            *Event
73	resolvedPowerLevels       *Event
74	resolvedJoinRules         *Event
75	resolvedThirdPartyInvites map[string]*Event
76	resolvedMembers           map[string]*Event
77	// The list of resolved events.
78	// This will contain one entry for each conflicted event type and state key.
79	result []*Event
80}
81
82func (r *stateResolver) Create() (*Event, error) {
83	return r.resolvedCreate, nil
84}
85
86func (r *stateResolver) PowerLevels() (*Event, error) {
87	return r.resolvedPowerLevels, nil
88}
89
90func (r *stateResolver) JoinRules() (*Event, error) {
91	return r.resolvedJoinRules, nil
92}
93
94func (r *stateResolver) ThirdPartyInvite(key string) (*Event, error) {
95	return r.resolvedThirdPartyInvites[key], nil
96}
97
98func (r *stateResolver) Member(key string) (*Event, error) {
99	return r.resolvedMembers[key], nil
100}
101
102func (r *stateResolver) addConflicted(events []*Event) { // nolint: gocyclo
103	type conflictKey struct {
104		eventType string
105		stateKey  string
106	}
107	offsets := map[conflictKey]int{}
108	// Split up the conflicted events into blocks with the same type and state key.
109	// Separate the auth events into specifically named lists because they have
110	// special rules for state resolution.
111	for _, event := range events {
112		key := conflictKey{event.Type(), *event.StateKey()}
113		// Work out which block to add the event to.
114		// By default we add the event to a block in the others list.
115		blockList := &r.others
116		switch key.eventType {
117		case MRoomCreate:
118			if key.stateKey == "" {
119				r.creates = append(r.creates, event)
120				continue
121			}
122		case MRoomPowerLevels:
123			if key.stateKey == "" {
124				r.powerLevels = append(r.powerLevels, event)
125				continue
126			}
127		case MRoomJoinRules:
128			if key.stateKey == "" {
129				r.joinRules = append(r.joinRules, event)
130				continue
131			}
132		case MRoomMember:
133			blockList = &r.members
134		case MRoomThirdPartyInvite:
135			blockList = &r.thirdPartyInvites
136		}
137		// We need to find an entry for the state key in a block list.
138		offset, ok := offsets[key]
139		if !ok {
140			// This is the first time we've seen that state key so we add a
141			// new block to the block list.
142			offset = len(*blockList)
143			*blockList = append(*blockList, nil)
144			offsets[key] = offset
145		}
146		// Get the address of the block in the block list.
147		block := &(*blockList)[offset]
148		// Add the event to the block.
149		*block = append(*block, event)
150	}
151}
152
153// Add an event to the resolved auth events.
154func (r *stateResolver) addAuthEvent(event *Event) {
155	switch event.Type() {
156	case MRoomCreate:
157		if event.StateKeyEquals("") {
158			r.resolvedCreate = event
159		}
160	case MRoomPowerLevels:
161		if event.StateKeyEquals("") {
162			r.resolvedPowerLevels = event
163		}
164	case MRoomJoinRules:
165		if event.StateKeyEquals("") {
166			r.resolvedJoinRules = event
167		}
168	case MRoomMember:
169		r.resolvedMembers[*event.StateKey()] = event
170	case MRoomThirdPartyInvite:
171		r.resolvedThirdPartyInvites[*event.StateKey()] = event
172	}
173}
174
175// Remove the auth event with the given type and state key.
176func (r *stateResolver) removeAuthEvent(eventType, stateKey string) {
177	switch eventType {
178	case MRoomCreate:
179		if stateKey == "" {
180			r.resolvedCreate = nil
181		}
182	case MRoomPowerLevels:
183		if stateKey == "" {
184			r.resolvedPowerLevels = nil
185		}
186	case MRoomJoinRules:
187		if stateKey == "" {
188			r.resolvedJoinRules = nil
189		}
190	case MRoomMember:
191		r.resolvedMembers[stateKey] = nil
192	case MRoomThirdPartyInvite:
193		r.resolvedThirdPartyInvites[stateKey] = nil
194	}
195}
196
197// resolveAndAddAuthBlocks resolves each block of conflicting auth state events in a list of blocks
198// where all the blocks have the same event type.
199// Once every block has been resolved the resulting events are added to the events used for auth checks.
200// This is called once per auth event type and state key pair.
201func (r *stateResolver) resolveAndAddAuthBlocks(blocks [][]*Event) {
202	start := len(r.result)
203	for _, block := range blocks {
204		if len(block) == 0 {
205			continue
206		}
207		if event := r.resolveAuthBlock(block); event != nil {
208			r.result = append(r.result, event)
209		}
210	}
211	// Only add the events to the auth events once all of the events with that type have been resolved.
212	// (SPEC: This is done to avoid the result of state resolution depending on the iteration order)
213	for i := start; i < len(r.result); i++ {
214		r.addAuthEvent(r.result[i])
215	}
216}
217
218// resolveAuthBlock resolves a block of auth events with the same state key to a single event.
219func (r *stateResolver) resolveAuthBlock(events []*Event) *Event {
220	// Sort the events by depth and sha1 of event ID
221	block := sortConflictedEventsByDepthAndSHA1(events)
222
223	// Pick the "oldest" event, that is the one with the lowest depth, as the first candidate.
224	// If none of the newer events pass auth checks against this event then we pick the "oldest" event.
225	// (SPEC: This ensures that we always pick a state event for this type and state key.
226	//  Note that if all the events fail auth checks we will still pick the "oldest" event.)
227	result := block[0].event
228	// Temporarily add the candidate event to the auth events.
229	r.addAuthEvent(result)
230	for i := 1; i < len(block); i++ {
231		event := block[i].event
232		// Check if the next event passes authentication checks against the current candidate.
233		// (SPEC: This ensures that "ban" events cannot be replaced by "join" events through a conflict)
234		if Allowed(event, r) == nil {
235			// If the event passes authentication checks pick it as the current candidate.
236			// (SPEC: This prefers newer events so that we don't flip a valid state back to a previous version)
237			result = event
238			r.addAuthEvent(result)
239		} else {
240			// If the authentication check fails then we stop iterating the list and return the current candidate.
241			break
242		}
243	}
244	// Discard the event from the auth events.
245	// We'll add it back later when all events of the same type have been resolved.
246	// (SPEC: This is done to avoid the result of state resolution depending on the iteration order)
247	r.removeAuthEvent(result.Type(), *result.StateKey())
248	return result
249}
250
251// resolveNormalBlock resolves a block of normal state events with the same state key to a single event.
252func (r *stateResolver) resolveNormalBlock(events []*Event) *Event {
253	// Sort the events by depth and sha1 of event ID
254	block := sortConflictedEventsByDepthAndSHA1(events)
255	// Start at the "newest" event, that is the one with the highest depth, and go
256	// backward through the list until we find one that passes authentication checks.
257	// (SPEC: This prefers newer events so that we don't flip a valid state back to a previous version)
258	for i := len(block) - 1; i > 0; i-- {
259		event := block[i].event
260		if Allowed(event, r) == nil {
261			return event
262		}
263	}
264	// If all the auth checks for newer events fail then we pick the oldest event.
265	// (SPEC: This ensures that we always pick a state event for this type and state key.
266	//  Note that if all the events fail auth checks we will still pick the "oldest" event.)
267	return block[0].event
268}
269
270// sortConflictedEventsByDepthAndSHA1 sorts by ascending depth and descending sha1 of event ID.
271func sortConflictedEventsByDepthAndSHA1(events []*Event) []conflictedEvent {
272	block := make([]conflictedEvent, len(events))
273	for i := range events {
274		event := events[i]
275		block[i] = conflictedEvent{
276			depth:       event.Depth(),
277			eventIDSHA1: sha1.Sum([]byte(event.EventID())),
278			event:       event,
279		}
280	}
281	sort.Sort(conflictedEventSorter(block))
282	return block
283}
284
285// A conflictedEvent is used to sort the events in a block by ascending depth and descending sha1 of event ID.
286// (SPEC: We use the SHA1 of the event ID as an arbitrary tie breaker between events with the same depth)
287type conflictedEvent struct {
288	depth       int64
289	eventIDSHA1 [sha1.Size]byte
290	event       *Event
291}
292
293// A conflictedEventSorter is used to sort the events using sort.Sort.
294type conflictedEventSorter []conflictedEvent
295
296func (s conflictedEventSorter) Len() int {
297	return len(s)
298}
299
300func (s conflictedEventSorter) Less(i, j int) bool {
301	if s[i].depth == s[j].depth {
302		return bytes.Compare(s[i].eventIDSHA1[:], s[j].eventIDSHA1[:]) > 0
303	}
304	return s[i].depth < s[j].depth
305}
306
307func (s conflictedEventSorter) Swap(i, j int) {
308	s[i], s[j] = s[j], s[i]
309}
310
311// ResolveConflicts performs state resolution on the input events, returning the
312// resolved state. It will automatically decide which state resolution algorithm
313// to use, depending on the room version. `events` should be all the state events
314// to resolve. `authEvents` should be the entire set of auth_events for these `events`.
315// Returns an error if the state resolution algorithm cannot be determined.
316func ResolveConflicts(
317	version RoomVersion,
318	events []*Event,
319	authEvents []*Event,
320) ([]*Event, error) {
321	type stateKeyTuple struct {
322		Type     string
323		StateKey string
324	}
325
326	// Prepare our data structures.
327	eventIDMap := map[string]struct{}{}
328	eventMap := make(map[stateKeyTuple][]*Event)
329	var conflicted, notConflicted, resolved []*Event
330
331	// Run through all of the events that we were given and sort them
332	// into a map, sorted by (event_type, state_key) tuple. This means
333	// that we can easily spot events that are "conflicted", e.g.
334	// there are duplicate values for the same tuple key.
335	for _, event := range events {
336		if _, ok := eventIDMap[event.EventID()]; ok {
337			continue
338		}
339		eventIDMap[event.EventID()] = struct{}{}
340		if event.StateKey() == nil {
341			// Ignore events that are not state events.
342			continue
343		}
344		// Append the events if there is already a conflicted list for
345		// this tuple key, create it if not.
346		tuple := stateKeyTuple{event.Type(), *event.StateKey()}
347		eventMap[tuple] = append(eventMap[tuple], event)
348	}
349
350	// Split out the events in the map into conflicted and unconflicted
351	// buckets. The conflicted events will be ran through state res,
352	// whereas unconfliced events will always going to appear in the
353	// final resolved state.
354	for _, list := range eventMap {
355		if len(list) > 1 {
356			conflicted = append(conflicted, list...)
357		} else {
358			notConflicted = append(notConflicted, list...)
359		}
360	}
361
362	// Work out which state resolution algorithm we want to run for
363	// the room version.
364	stateResAlgo, err := version.StateResAlgorithm()
365	if err != nil {
366		return nil, err
367	}
368	switch stateResAlgo {
369	case StateResV1:
370		// Currently state res v1 doesn't handle unconflicted events
371		// for us, like state res v2 does, so we will need to add the
372		// unconflicted events into the state ourselves.
373		// TODO: Fix state res v1 so this is handled for the caller.
374		resolved = ResolveStateConflicts(conflicted, authEvents)
375		resolved = append(resolved, notConflicted...)
376	case StateResV2:
377		// TODO: auth difference here?
378		resolved = ResolveStateConflictsV2(conflicted, notConflicted, authEvents, authEvents)
379	default:
380		return nil, fmt.Errorf("unsupported state resolution algorithm %v", stateResAlgo)
381	}
382
383	// Return the final resolved state events, including both the
384	// resolved set of conflicted events, and the unconflicted events.
385	return resolved, nil
386}
387