1// Copyright 2017 Vector Creations Ltd
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package helpers
16
17import (
18	"context"
19	"fmt"
20	"sort"
21
22	"github.com/matrix-org/dendrite/roomserver/state"
23	"github.com/matrix-org/dendrite/roomserver/storage"
24	"github.com/matrix-org/dendrite/roomserver/types"
25	"github.com/matrix-org/gomatrixserverlib"
26)
27
28// CheckForSoftFail returns true if the event should be soft-failed
29// and false otherwise. The return error value should be checked before
30// the soft-fail bool.
31func CheckForSoftFail(
32	ctx context.Context,
33	db storage.Database,
34	event *gomatrixserverlib.HeaderedEvent,
35	stateEventIDs []string,
36) (bool, error) {
37	rewritesState := len(stateEventIDs) > 1
38
39	var authStateEntries []types.StateEntry
40	var err error
41	if rewritesState {
42		authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs)
43		if err != nil {
44			return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
45		}
46	} else {
47		// Work out if the room exists.
48		var roomInfo *types.RoomInfo
49		roomInfo, err = db.RoomInfo(ctx, event.RoomID())
50		if err != nil {
51			return false, fmt.Errorf("db.RoomNID: %w", err)
52		}
53		if roomInfo == nil || roomInfo.IsStub {
54			return false, nil
55		}
56
57		// Then get the state entries for the current state snapshot.
58		// We'll use this to check if the event is allowed right now.
59		roomState := state.NewStateResolution(db, *roomInfo)
60		authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
61		if err != nil {
62			return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
63		}
64	}
65
66	// As a special case, it's possible that the room will have no
67	// state because we haven't received a m.room.create event yet.
68	// If we're now processing the first create event then never
69	// soft-fail it.
70	if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate {
71		return false, nil
72	}
73
74	// Work out which of the state events we actually need.
75	stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
76
77	// Load the actual auth events from the database.
78	authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
79	if err != nil {
80		return true, fmt.Errorf("loadAuthEvents: %w", err)
81	}
82
83	// Check if the event is allowed.
84	if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
85		// return true, nil
86		return true, err
87	}
88	return false, nil
89}
90
91// CheckAuthEvents checks that the event passes authentication checks
92// Returns the numeric IDs for the auth events.
93func CheckAuthEvents(
94	ctx context.Context,
95	db storage.Database,
96	event *gomatrixserverlib.HeaderedEvent,
97	authEventIDs []string,
98) ([]types.EventNID, error) {
99	// Grab the numeric IDs for the supplied auth state events from the database.
100	authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
101	if err != nil {
102		return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err)
103	}
104	authStateEntries = types.DeduplicateStateEntries(authStateEntries)
105
106	// Work out which of the state events we actually need.
107	stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
108
109	// Load the actual auth events from the database.
110	authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
111	if err != nil {
112		return nil, fmt.Errorf("loadAuthEvents: %w", err)
113	}
114
115	// Check if the event is allowed.
116	if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
117		return nil, err
118	}
119
120	// Return the numeric IDs for the auth events.
121	result := make([]types.EventNID, len(authStateEntries))
122	for i := range authStateEntries {
123		result[i] = authStateEntries[i].EventNID
124	}
125	return result, nil
126}
127
// authEvents serves auth events out of a set of room state entries and the
// events loaded for them; its methods below implement
// gomatrixserverlib.AuthEventProvider. Instances are built by loadAuthEvents.
type authEvents struct {
	// stateKeyNIDMap maps state key strings to their numeric IDs; only keys
	// requested by loadAuthEvents are present.
	stateKeyNIDMap map[string]types.EventStateKeyNID
	// state is searched with binary search by stateEntryMap.lookup, so it is
	// expected to be sorted by state key tuple.
	state          stateEntryMap
	// events holds the loaded auth events, searched by numeric event ID.
	events         EventMap
}
133
134// Create implements gomatrixserverlib.AuthEventProvider
135func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
136	return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
137}
138
139// PowerLevels implements gomatrixserverlib.AuthEventProvider
140func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
141	return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
142}
143
144// JoinRules implements gomatrixserverlib.AuthEventProvider
145func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
146	return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
147}
148
149// Memmber implements gomatrixserverlib.AuthEventProvider
150func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
151	return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
152}
153
154// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider
155func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
156	return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
157}
158
159func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
160	eventNID, ok := ae.state.lookup(types.StateKeyTuple{
161		EventTypeNID:     typeNID,
162		EventStateKeyNID: types.EmptyStateKeyNID,
163	})
164	if !ok {
165		return nil
166	}
167	event, ok := ae.events.Lookup(eventNID)
168	if !ok {
169		return nil
170	}
171	return event.Event
172}
173
174func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
175	stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
176	if !ok {
177		return nil
178	}
179	eventNID, ok := ae.state.lookup(types.StateKeyTuple{
180		EventTypeNID:     typeNID,
181		EventStateKeyNID: stateKeyNID,
182	})
183	if !ok {
184		return nil
185	}
186	event, ok := ae.events.Lookup(eventNID)
187	if !ok {
188		return nil
189	}
190	return event.Event
191}
192
193// loadAuthEvents loads the events needed for authentication from the supplied room state.
194func loadAuthEvents(
195	ctx context.Context,
196	db storage.Database,
197	needed gomatrixserverlib.StateNeeded,
198	state []types.StateEntry,
199) (result authEvents, err error) {
200	// Look up the numeric IDs for the state keys needed for auth.
201	var neededStateKeys []string
202	neededStateKeys = append(neededStateKeys, needed.Member...)
203	neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
204	if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
205		return
206	}
207
208	// Load the events we need.
209	result.state = state
210	var eventNIDs []types.EventNID
211	keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
212	for _, keyTuple := range keyTuplesNeeded {
213		eventNID, ok := result.state.lookup(keyTuple)
214		if ok {
215			eventNIDs = append(eventNIDs, eventNID)
216		}
217	}
218	if result.events, err = db.Events(ctx, eventNIDs); err != nil {
219		return
220	}
221	return
222}
223
224// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
225func stateKeyTuplesNeeded(
226	stateKeyNIDMap map[string]types.EventStateKeyNID,
227	stateNeeded gomatrixserverlib.StateNeeded,
228) []types.StateKeyTuple {
229	var keyTuples []types.StateKeyTuple
230	if stateNeeded.Create {
231		keyTuples = append(keyTuples, types.StateKeyTuple{
232			EventTypeNID:     types.MRoomCreateNID,
233			EventStateKeyNID: types.EmptyStateKeyNID,
234		})
235	}
236	if stateNeeded.PowerLevels {
237		keyTuples = append(keyTuples, types.StateKeyTuple{
238			EventTypeNID:     types.MRoomPowerLevelsNID,
239			EventStateKeyNID: types.EmptyStateKeyNID,
240		})
241	}
242	if stateNeeded.JoinRules {
243		keyTuples = append(keyTuples, types.StateKeyTuple{
244			EventTypeNID:     types.MRoomJoinRulesNID,
245			EventStateKeyNID: types.EmptyStateKeyNID,
246		})
247	}
248	for _, member := range stateNeeded.Member {
249		stateKeyNID, ok := stateKeyNIDMap[member]
250		if ok {
251			keyTuples = append(keyTuples, types.StateKeyTuple{
252				EventTypeNID:     types.MRoomMemberNID,
253				EventStateKeyNID: stateKeyNID,
254			})
255		}
256	}
257	for _, token := range stateNeeded.ThirdPartyInvite {
258		stateKeyNID, ok := stateKeyNIDMap[token]
259		if ok {
260			keyTuples = append(keyTuples, types.StateKeyTuple{
261				EventTypeNID:     types.MRoomThirdPartyInviteNID,
262				EventStateKeyNID: stateKeyNID,
263			})
264		}
265	}
266	return keyTuples
267}
268
269// Map from event type, state key tuple to numeric event ID.
270// Implemented using binary search on a sorted array.
271type stateEntryMap []types.StateEntry
272
273// lookup an entry in the event map.
274func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
275	// Since the list is sorted we can implement this using binary search.
276	// This is faster than using a hash map.
277	// We don't have to worry about pathological cases because the keys are fixed
278	// size and are controlled by us.
279	list := []types.StateEntry(m)
280	i := sort.Search(len(list), func(i int) bool {
281		return !list[i].StateKeyTuple.LessThan(stateKey)
282	})
283	if i < len(list) && list[i].StateKeyTuple == stateKey {
284		ok = true
285		eventNID = list[i].EventNID
286	}
287	return
288}
289
290// Map from numeric event ID to event.
291// Implemented using binary search on a sorted array.
292type EventMap []types.Event
293
294// lookup an entry in the event map.
295func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
296	// Since the list is sorted we can implement this using binary search.
297	// This is faster than using a hash map.
298	// We don't have to worry about pathological cases because the keys are fixed
299	// size are controlled by us.
300	list := []types.Event(m)
301	i := sort.Search(len(list), func(i int) bool {
302		return list[i].EventNID >= eventNID
303	})
304	if i < len(list) && list[i].EventNID == eventNID {
305		ok = true
306		event = &list[i]
307	}
308	return
309}
310