1// Copyright 2017 Vector Creations Ltd 2// Copyright 2018 New Vector Ltd 3// Copyright 2019-2020 The Matrix.org Foundation C.I.C. 4// 5// Licensed under the Apache License, Version 2.0 (the "License"); 6// you may not use this file except in compliance with the License. 7// You may obtain a copy of the License at 8// 9// http://www.apache.org/licenses/LICENSE-2.0 10// 11// Unless required by applicable law or agreed to in writing, software 12// distributed under the License is distributed on an "AS IS" BASIS, 13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14// See the License for the specific language governing permissions and 15// limitations under the License. 16 17package state 18 19import ( 20 "context" 21 "fmt" 22 "sort" 23 "time" 24 25 "github.com/matrix-org/dendrite/roomserver/storage" 26 "github.com/matrix-org/util" 27 "github.com/prometheus/client_golang/prometheus" 28 29 "github.com/matrix-org/dendrite/roomserver/types" 30 "github.com/matrix-org/gomatrixserverlib" 31) 32 33type StateResolution struct { 34 db storage.Database 35 roomInfo types.RoomInfo 36 events map[types.EventNID]*gomatrixserverlib.Event 37} 38 39func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { 40 return StateResolution{ 41 db: db, 42 roomInfo: roomInfo, 43 events: make(map[types.EventNID]*gomatrixserverlib.Event), 44 } 45} 46 47// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. 48// This is typically the state before an event or the current state of a room. 49// Returns a sorted list of state entries or an error if there was a problem talking to the database. 50func (v *StateResolution) LoadStateAtSnapshot( 51 ctx context.Context, stateNID types.StateSnapshotNID, 52) ([]types.StateEntry, error) { 53 stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) 54 if err != nil { 55 return nil, err 56 } 57 // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. 58 stateBlockNIDList := stateBlockNIDLists[0] 59 60 stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) 61 if err != nil { 62 return nil, err 63 } 64 stateEntriesMap := stateEntryListMap(stateEntryLists) 65 66 // Combine all the state entries for this snapshot. 67 // The order of state block NIDs in the list tells us the order to combine them in. 68 var fullState []types.StateEntry 69 for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { 70 entries, ok := stateEntriesMap.lookup(stateBlockNID) 71 if !ok { 72 // This should only get hit if the database is corrupt. 73 // It should be impossible for an event to reference a NID that doesn't exist 74 panic(fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID)) 75 } 76 fullState = append(fullState, entries...) 77 } 78 79 // Stable sort so that the most recent entry for each state key stays 80 // remains later in the list than the older entries for the same state key. 81 sort.Stable(stateEntryByStateKeySorter(fullState)) 82 // Unique returns the last entry and hence the most recent entry for each state key. 83 fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] 84 return fullState, nil 85} 86 87// LoadStateAtEvent loads the full state of a room before a particular event. 88func (v *StateResolution) LoadStateAtEvent( 89 ctx context.Context, eventID string, 90) ([]types.StateEntry, error) { 91 snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) 92 if err != nil { 93 return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) 94 } 95 if snapshotNID == 0 { 96 return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) 97 } 98 99 stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) 100 if err != nil { 101 return nil, err 102 } 103 104 return stateEntries, nil 105} 106 107// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events 108// and combines those snapshots together into a single list. At this point it is 109// possible to run into duplicate (type, state key) tuples. 110func (v *StateResolution) LoadCombinedStateAfterEvents( 111 ctx context.Context, prevStates []types.StateAtEvent, 112) ([]types.StateEntry, error) { 113 stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) 114 for i, state := range prevStates { 115 stateNIDs[i] = state.BeforeStateSnapshotNID 116 } 117 // Fetch the state snapshots for the state before the each prev event from the database. 118 // Deduplicate the IDs before passing them to the database. 119 // There could be duplicates because the events could be state events where 120 // the snapshot of the room state before them was the same. 121 stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, UniqueStateSnapshotNIDs(stateNIDs)) 122 if err != nil { 123 return nil, fmt.Errorf("v.db.StateBlockNIDs: %w", err) 124 } 125 126 var stateBlockNIDs []types.StateBlockNID 127 for _, list := range stateBlockNIDLists { 128 stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) 129 } 130 // Fetch the state entries that will be combined to create the snapshots. 131 // Deduplicate the IDs before passing them to the database. 132 // There could be duplicates because a block of state entries could be reused by 133 // multiple snapshots. 134 stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) 135 if err != nil { 136 return nil, fmt.Errorf("v.db.StateEntries: %w", err) 137 } 138 stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) 139 stateEntriesMap := stateEntryListMap(stateEntryLists) 140 141 // Combine the entries from all the snapshots of state after each prev event into a single list. 142 var combined []types.StateEntry 143 for _, prevState := range prevStates { 144 // Grab the list of state data NIDs for this snapshot. 145 stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) 146 if !ok { 147 // This should only get hit if the database is corrupt. 148 // It should be impossible for an event to reference a NID that doesn't exist 149 panic(fmt.Errorf("corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) 150 } 151 152 // Combine all the state entries for this snapshot. 153 // The order of state block NIDs in the list tells us the order to combine them in. 154 var fullState []types.StateEntry 155 for _, stateBlockNID := range stateBlockNIDs { 156 entries, ok := stateEntriesMap.lookup(stateBlockNID) 157 if !ok { 158 // This should only get hit if the database is corrupt. 159 // It should be impossible for an event to reference a NID that doesn't exist 160 panic(fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID)) 161 } 162 fullState = append(fullState, entries...) 163 } 164 if prevState.IsStateEvent() && !prevState.IsRejected { 165 // If the prev event was a state event then add an entry for the event itself 166 // so that we get the state after the event rather than the state before. 167 fullState = append(fullState, prevState.StateEntry) 168 } 169 170 // Stable sort so that the most recent entry for each state key stays 171 // remains later in the list than the older entries for the same state key. 172 sort.Stable(stateEntryByStateKeySorter(fullState)) 173 // Unique returns the last entry and hence the most recent entry for each state key. 174 fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] 175 // Add the full state for this StateSnapshotNID. 176 combined = append(combined, fullState...) 177 } 178 return combined, nil 179} 180 181// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. 182func (v *StateResolution) DifferenceBetweeenStateSnapshots( 183 ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, 184) (removed, added []types.StateEntry, err error) { 185 if oldStateNID == newStateNID { 186 // If the snapshot NIDs are the same then nothing has changed 187 return nil, nil, nil 188 } 189 190 var oldEntries []types.StateEntry 191 var newEntries []types.StateEntry 192 if oldStateNID != 0 { 193 oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID) 194 if err != nil { 195 return nil, nil, err 196 } 197 } 198 if newStateNID != 0 { 199 newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID) 200 if err != nil { 201 return nil, nil, err 202 } 203 } 204 205 var oldI int 206 var newI int 207 for { 208 switch { 209 case oldI == len(oldEntries): 210 // We've reached the end of the old entries. 211 // The rest of the new list must have been newly added. 212 added = append(added, newEntries[newI:]...) 213 return 214 case newI == len(newEntries): 215 // We've reached the end of the new entries. 216 // The rest of the old list must be have been removed. 217 removed = append(removed, oldEntries[oldI:]...) 218 return 219 case oldEntries[oldI] == newEntries[newI]: 220 // The entry is in both lists so skip over it. 221 oldI++ 222 newI++ 223 case oldEntries[oldI].LessThan(newEntries[newI]): 224 // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. 225 removed = append(removed, oldEntries[oldI]) 226 oldI++ 227 default: 228 // Reaching the default case implies that the new entry is less than the old entry. 229 // Since the lists are sorted this means that it only appears in the new list. 230 added = append(added, newEntries[newI]) 231 newI++ 232 } 233 } 234} 235 236// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. 237// This is used when we only want to load a subset of the room state at a snapshot. 238// If there is no entry for a given event type and state key pair then it will be discarded. 239// This is typically the state before an event or the current state of a room. 240// Returns a sorted list of state entries or an error if there was a problem talking to the database. 241func (v *StateResolution) LoadStateAtSnapshotForStringTuples( 242 ctx context.Context, 243 stateNID types.StateSnapshotNID, 244 stateKeyTuples []gomatrixserverlib.StateKeyTuple, 245) ([]types.StateEntry, error) { 246 numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) 247 if err != nil { 248 return nil, err 249 } 250 return v.loadStateAtSnapshotForNumericTuples(ctx, stateNID, numericTuples) 251} 252 253// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs 254// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. 255// Returns an error if there was a problem talking to the database. 256func (v *StateResolution) stringTuplesToNumericTuples( 257 ctx context.Context, 258 stringTuples []gomatrixserverlib.StateKeyTuple, 259) ([]types.StateKeyTuple, error) { 260 eventTypes := make([]string, len(stringTuples)) 261 stateKeys := make([]string, len(stringTuples)) 262 for i := range stringTuples { 263 eventTypes[i] = stringTuples[i].EventType 264 stateKeys[i] = stringTuples[i].StateKey 265 } 266 eventTypes = util.UniqueStrings(eventTypes) 267 eventTypeMap, err := v.db.EventTypeNIDs(ctx, eventTypes) 268 if err != nil { 269 return nil, err 270 } 271 stateKeys = util.UniqueStrings(stateKeys) 272 stateKeyMap, err := v.db.EventStateKeyNIDs(ctx, stateKeys) 273 if err != nil { 274 return nil, err 275 } 276 277 var result []types.StateKeyTuple 278 for _, stringTuple := range stringTuples { 279 var numericTuple types.StateKeyTuple 280 var ok1, ok2 bool 281 numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] 282 numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey] 283 // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. 284 if ok1 && ok2 { 285 result = append(result, numericTuple) 286 } 287 } 288 289 return result, nil 290} 291 292// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. 293// This is used when we only want to load a subset of the room state at a snapshot. 294// If there is no entry for a given event type and state key pair then it will be discarded. 295// This is typically the state before an event or the current state of a room. 296// Returns a sorted list of state entries or an error if there was a problem talking to the database. 297func (v *StateResolution) loadStateAtSnapshotForNumericTuples( 298 ctx context.Context, 299 stateNID types.StateSnapshotNID, 300 stateKeyTuples []types.StateKeyTuple, 301) ([]types.StateEntry, error) { 302 stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) 303 if err != nil { 304 return nil, err 305 } 306 // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. 307 stateBlockNIDList := stateBlockNIDLists[0] 308 309 stateEntryLists, err := v.db.StateEntriesForTuples( 310 ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, 311 ) 312 if err != nil { 313 return nil, err 314 } 315 stateEntriesMap := stateEntryListMap(stateEntryLists) 316 317 // Combine all the state entries for this snapshot. 318 // The order of state block NIDs in the list tells us the order to combine them in. 319 var fullState []types.StateEntry 320 for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { 321 entries, ok := stateEntriesMap.lookup(stateBlockNID) 322 if !ok { 323 // If the block is missing from the map it means that none of its entries matched a requested tuple. 324 // This can happen if the block doesn't contain an update for one of the requested tuples. 325 // If none of the requested tuples are in the block then it can be safely skipped. 326 continue 327 } 328 fullState = append(fullState, entries...) 329 } 330 331 // Stable sort so that the most recent entry for each state key stays 332 // remains later in the list than the older entries for the same state key. 333 sort.Stable(stateEntryByStateKeySorter(fullState)) 334 // Unique returns the last entry and hence the most recent entry for each state key. 335 fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] 336 return fullState, nil 337} 338 339// LoadStateAfterEventsForStringTuples loads the state for a list of event type 340// and state key pairs after list of events. 341// This is used when we only want to load a subset of the room state after a list of events. 342// If there is no entry for a given event type and state key pair then it will be discarded. 343// This is typically the state before an event. 344// Returns a sorted list of state entries or an error if there was a problem talking to the database. 345func (v *StateResolution) LoadStateAfterEventsForStringTuples( 346 ctx context.Context, 347 prevStates []types.StateAtEvent, 348 stateKeyTuples []gomatrixserverlib.StateKeyTuple, 349) ([]types.StateEntry, error) { 350 numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) 351 if err != nil { 352 return nil, err 353 } 354 return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples) 355} 356 357func (v *StateResolution) loadStateAfterEventsForNumericTuples( 358 ctx context.Context, 359 prevStates []types.StateAtEvent, 360 stateKeyTuples []types.StateKeyTuple, 361) ([]types.StateEntry, error) { 362 if len(prevStates) == 1 { 363 // Fast path for a single event. 364 prevState := prevStates[0] 365 result, err := v.loadStateAtSnapshotForNumericTuples( 366 ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, 367 ) 368 if err != nil { 369 return nil, err 370 } 371 if prevState.IsStateEvent() { 372 // The result is current the state before the requested event. 373 // We want the state after the requested event. 374 // If the requested event was a state event then we need to 375 // update that key in the result. 376 // If the requested event wasn't a state event then the state after 377 // it is the same as the state before it. 378 set := false 379 for i := range result { 380 if result[i].StateKeyTuple == prevState.StateKeyTuple { 381 result[i] = prevState.StateEntry 382 set = true 383 } 384 } 385 if !set { // no previous state exists for this event: add new state 386 result = append(result, prevState.StateEntry) 387 } 388 } 389 return result, nil 390 } 391 392 // Slow path for more that one event. 393 // Load the entire state so that we can do conflict resolution if we need to. 394 // TODO: The are some optimistations we could do here: 395 // 1) We only need to do conflict resolution if there is a conflict in the 396 // requested tuples so we might try loading just those tuples and then 397 // checking for conflicts. 398 // 2) When there is a conflict we still only need to load the state 399 // needed to do conflict resolution which would save us having to load 400 // the full state. 401 402 // TODO: Add metrics for this as it could take a long time for big rooms 403 // with large conflicts. 404 fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) 405 if err != nil { 406 return nil, err 407 } 408 409 // Sort the full state so we can use it as a map. 410 sort.Sort(stateEntrySorter(fullState)) 411 412 // Filter the full state down to the required tuples. 413 var result []types.StateEntry 414 for _, tuple := range stateKeyTuples { 415 eventNID, ok := stateEntryMap(fullState).lookup(tuple) 416 if ok { 417 result = append(result, types.StateEntry{ 418 StateKeyTuple: tuple, 419 EventNID: eventNID, 420 }) 421 } 422 } 423 sort.Sort(stateEntrySorter(result)) 424 return result, nil 425} 426 427var calculateStateDurations = prometheus.NewHistogramVec( 428 prometheus.HistogramOpts{ 429 Namespace: "dendrite", 430 Subsystem: "roomserver", 431 Name: "calculate_state_duration_milliseconds", 432 Help: "How long it takes to calculate the state after a list of events", 433 Buckets: []float64{ // milliseconds 434 5, 10, 25, 50, 75, 100, 200, 300, 400, 500, 435 1000, 2000, 3000, 4000, 5000, 6000, 436 7000, 8000, 9000, 10000, 15000, 20000, 30000, 437 }, 438 }, 439 // Takes two labels: 440 // algorithm: 441 // The algorithm used to calculate the state or the step it failed on if it failed. 442 // Labels starting with "_" are used to indicate when the algorithm fails halfway. 443 // outcome: 444 // Whether the state was successfully calculated. 445 // 446 // The possible values for algorithm are: 447 // empty_state -> The list of events was empty so the state is empty. 448 // no_change -> The state hasn't changed. 449 // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta 450 // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts 451 // while doing so. 452 // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so. 453 // _load_state_block_nids -> Failed loading the state block nids for a single previous state. 454 // _load_combined_state -> Failed to load the combined state. 455 // _resolve_conflicts -> Failed to resolve conflicts. 456 []string{"algorithm", "outcome"}, 457) 458 459var calculateStatePrevEventLength = prometheus.NewSummaryVec( 460 prometheus.SummaryOpts{ 461 Namespace: "dendrite", 462 Subsystem: "roomserver", 463 Name: "calculate_state_prev_event_length", 464 Help: "The length of the list of events to calculate the state after", 465 }, 466 []string{"algorithm", "outcome"}, 467) 468 469var calculateStateFullStateLength = prometheus.NewSummaryVec( 470 prometheus.SummaryOpts{ 471 Namespace: "dendrite", 472 Subsystem: "roomserver", 473 Name: "calculate_state_full_state_length", 474 Help: "The length of the full room state.", 475 }, 476 []string{"algorithm", "outcome"}, 477) 478 479var calculateStateConflictLength = prometheus.NewSummaryVec( 480 prometheus.SummaryOpts{ 481 Namespace: "dendrite", 482 Subsystem: "roomserver", 483 Name: "calculate_state_conflict_state_length", 484 Help: "The length of the conflicted room state.", 485 }, 486 []string{"algorithm", "outcome"}, 487) 488 489type calculateStateMetrics struct { 490 algorithm string 491 startTime time.Time 492 prevEventLength int 493 fullStateLength int 494 conflictLength int 495} 496 497func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) { 498 var outcome string 499 if err == nil { 500 outcome = "success" 501 } else { 502 outcome = "failure" 503 } 504 calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe( 505 float64(time.Since(c.startTime).Milliseconds()), 506 ) 507 calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe( 508 float64(c.prevEventLength), 509 ) 510 calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe( 511 float64(c.fullStateLength), 512 ) 513 calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe( 514 float64(c.conflictLength), 515 ) 516 return stateNID, err 517} 518 519func init() { 520 prometheus.MustRegister( 521 calculateStateDurations, calculateStatePrevEventLength, 522 calculateStateFullStateLength, calculateStateConflictLength, 523 ) 524} 525 526// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event. 527// Stores the snapshot of the state in the database. 528// Returns a numeric ID for the snapshot of the state before the event. 529func (v *StateResolution) CalculateAndStoreStateBeforeEvent( 530 ctx context.Context, 531 event *gomatrixserverlib.Event, 532 isRejected bool, 533) (types.StateSnapshotNID, error) { 534 // Load the state at the prev events. 535 prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) 536 if err != nil { 537 return 0, err 538 } 539 540 // The state before this event will be the state after the events that came before it. 541 return v.CalculateAndStoreStateAfterEvents(ctx, prevStates) 542} 543 544// CalculateAndStoreStateAfterEvents finds the room state after the given events. 545// Stores the resulting state in the database and returns a numeric ID for that snapshot. 546func (v *StateResolution) CalculateAndStoreStateAfterEvents( 547 ctx context.Context, 548 prevStates []types.StateAtEvent, 549) (types.StateSnapshotNID, error) { 550 metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} 551 552 if len(prevStates) == 0 { 553 // 2) There weren't any prev_events for this event so the state is 554 // empty. 555 metrics.algorithm = "empty_state" 556 stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) 557 if err != nil { 558 err = fmt.Errorf("v.db.AddState: %w", err) 559 } 560 return metrics.stop(stateNID, err) 561 } 562 563 if len(prevStates) == 1 { 564 prevState := prevStates[0] 565 if prevState.EventStateKeyNID == 0 || prevState.IsRejected { 566 // 3) None of the previous events were state events and they all 567 // have the same state, so this event has exactly the same state 568 // as the previous events. 569 // This should be the internal case. 570 metrics.algorithm = "no_change" 571 return metrics.stop(prevState.BeforeStateSnapshotNID, nil) 572 } 573 // The previous event was a state event so we need to store a copy 574 // of the previous state updated with that event. 575 stateBlockNIDLists, err := v.db.StateBlockNIDs( 576 ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, 577 ) 578 if err != nil { 579 metrics.algorithm = "_load_state_blocks" 580 return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err)) 581 } 582 stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs 583 if len(stateBlockNIDs) < maxStateBlockNIDs { 584 // 4) The number of state data blocks is small enough that we can just 585 // add the state event as a block of size one to the end of the blocks. 586 metrics.algorithm = "single_delta" 587 stateNID, err := v.db.AddState( 588 ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, 589 ) 590 if err != nil { 591 err = fmt.Errorf("v.db.AddState: %w", err) 592 } 593 return metrics.stop(stateNID, err) 594 } 595 // If there are too many deltas then we need to calculate the full state 596 // So fall through to calculateAndStoreStateAfterManyEvents 597 } 598 599 stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics) 600 if err != nil { 601 return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) 602 } 603 return stateNID, nil 604} 605 606// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. 607// Increasing this number means that we can encode more of the state changes as simple deltas which means that 608// we need fewer entries in the state data table. However making this number bigger will increase the size of 609// the rows in the state table itself and will require more index lookups when retrieving a snapshot. 610// TODO: Tune this to get the right balance between size and lookup performance. 611const maxStateBlockNIDs = 64 612 613// calculateAndStoreStateAfterManyEvents finds the room state after the given events. 614// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. 615// Stores the resulting state and returns a numeric ID for the snapshot. 616func (v *StateResolution) calculateAndStoreStateAfterManyEvents( 617 ctx context.Context, 618 roomNID types.RoomNID, 619 prevStates []types.StateAtEvent, 620 metrics calculateStateMetrics, 621) (types.StateSnapshotNID, error) { 622 state, algorithm, conflictLength, err := 623 v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) 624 metrics.algorithm = algorithm 625 if err != nil { 626 return metrics.stop(0, fmt.Errorf("v.calculateStateAfterManyEvents: %w", err)) 627 } 628 629 // TODO: Check if we can encode the new state as a delta against the 630 // previous state. 631 metrics.conflictLength = conflictLength 632 metrics.fullStateLength = len(state) 633 return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) 634} 635 636func (v *StateResolution) calculateStateAfterManyEvents( 637 ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, 638 prevStates []types.StateAtEvent, 639) (state []types.StateEntry, algorithm string, conflictLength int, err error) { 640 var combined []types.StateEntry 641 // Conflict resolution. 642 // First stage: load the state after each of the prev events. 643 combined, err = v.LoadCombinedStateAfterEvents(ctx, prevStates) 644 if err != nil { 645 err = fmt.Errorf("v.LoadCombinedStateAfterEvents: %w", err) 646 algorithm = "_load_combined_state" 647 return 648 } 649 650 // Collect all the entries with the same type and key together. 651 // We don't care about the order here because the conflict resolution 652 // algorithm doesn't depend on the order of the prev events. 653 // Remove duplicate entires. 654 combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] 655 656 // Find the conflicts 657 conflicts := findDuplicateStateKeys(combined) 658 659 if len(conflicts) > 0 { 660 conflictLength = len(conflicts) 661 662 // 5) There are conflicting state events, for each conflict workout 663 // what the appropriate state event is. 664 665 // Work out which entries aren't conflicted. 666 var notConflicted []types.StateEntry 667 for _, entry := range combined { 668 if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { 669 notConflicted = append(notConflicted, entry) 670 } 671 } 672 673 var resolved []types.StateEntry 674 resolved, err = v.resolveConflicts(ctx, roomVersion, notConflicted, conflicts) 675 if err != nil { 676 err = fmt.Errorf("v.resolveConflits: %w", err) 677 algorithm = "_resolve_conflicts" 678 return 679 } 680 algorithm = "full_state_with_conflicts" 681 state = resolved[:util.SortAndUnique(stateEntrySorter(resolved))] 682 } else { 683 algorithm = "full_state_no_conflicts" 684 // 6) There weren't any conflicts 685 state = combined 686 } 687 return 688} 689 690func (v *StateResolution) resolveConflicts( 691 ctx context.Context, version gomatrixserverlib.RoomVersion, 692 notConflicted, conflicted []types.StateEntry, 693) ([]types.StateEntry, error) { 694 stateResAlgo, err := version.StateResAlgorithm() 695 if err != nil { 696 return nil, err 697 } 698 switch stateResAlgo { 699 case gomatrixserverlib.StateResV1: 700 return v.resolveConflictsV1(ctx, notConflicted, conflicted) 701 case gomatrixserverlib.StateResV2: 702 return v.resolveConflictsV2(ctx, notConflicted, conflicted) 703 } 704 return nil, fmt.Errorf("unsupported state resolution algorithm %v", stateResAlgo) 705} 706 707// resolveConflicts resolves a list of conflicted state entries. It takes two lists. 708// The first is a list of all state entries that are not conflicted. 709// The second is a list of all state entries that are conflicted 710// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. 711// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. 712// The returned list is sorted by state key tuple. 713// Returns an error if there was a problem talking to the database. 714func (v *StateResolution) resolveConflictsV1( 715 ctx context.Context, 716 notConflicted, conflicted []types.StateEntry, 717) ([]types.StateEntry, error) { 718 719 // Load the conflicted events 720 conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) 721 if err != nil { 722 return nil, err 723 } 724 725 // Work out which auth events we need to load. 726 needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) 727 728 // Find the numeric IDs for the necessary state keys. 729 var neededStateKeys []string 730 neededStateKeys = append(neededStateKeys, needed.Member...) 731 neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) 732 stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) 733 if err != nil { 734 return nil, err 735 } 736 737 // Load the necessary auth events. 738 tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) 739 var authEntries []types.StateEntry 740 for _, tuple := range tuplesNeeded { 741 if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { 742 authEntries = append(authEntries, types.StateEntry{ 743 StateKeyTuple: tuple, 744 EventNID: eventNID, 745 }) 746 } 747 } 748 authEvents, _, err := v.loadStateEvents(ctx, authEntries) 749 if err != nil { 750 return nil, err 751 } 752 753 // Resolve the conflicts. 754 resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) 755 756 // Map from the full events back to numeric state entries. 757 for _, resolvedEvent := range resolvedEvents { 758 entry, ok := eventIDMap[resolvedEvent.EventID()] 759 if !ok { 760 panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())) 761 } 762 notConflicted = append(notConflicted, entry) 763 } 764 765 // Sort the result so it can be searched. 766 sort.Sort(stateEntrySorter(notConflicted)) 767 return notConflicted, nil 768} 769 770// resolveConflicts resolves a list of conflicted state entries. It takes two lists. 771// The first is a list of all state entries that are not conflicted. 772// The second is a list of all state entries that are conflicted 773// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. 774// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. 775// The returned list is sorted by state key tuple. 776// Returns an error if there was a problem talking to the database. 777func (v *StateResolution) resolveConflictsV2( 778 ctx context.Context, 779 notConflicted, conflicted []types.StateEntry, 780) ([]types.StateEntry, error) { 781 estimate := len(conflicted) + len(notConflicted) 782 eventIDMap := make(map[string]types.StateEntry, estimate) 783 784 // Load the conflicted events 785 conflictedEvents, conflictedEventMap, err := v.loadStateEvents(ctx, conflicted) 786 if err != nil { 787 return nil, err 788 } 789 for k, v := range conflictedEventMap { 790 eventIDMap[k] = v 791 } 792 793 // Load the non-conflicted events 794 nonConflictedEvents, nonConflictedEventMap, err := v.loadStateEvents(ctx, notConflicted) 795 if err != nil { 796 return nil, err 797 } 798 for k, v := range nonConflictedEventMap { 799 eventIDMap[k] = v 800 } 801 802 // For each conflicted event, we will add a new set of auth events. Auth 803 // events may be duplicated across these sets but that's OK. 804 authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) 805 authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) 806 authDifference := make([]*gomatrixserverlib.Event, 0, estimate) 807 808 // For each conflicted event, let's try and get the needed auth events. 809 neededStateKeys := make([]string, 16) 810 authEntries := make([]types.StateEntry, 16) 811 for _, conflictedEvent := range conflictedEvents { 812 // Work out which auth events we need to load. 813 key := conflictedEvent.EventID() 814 needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent}) 815 816 // Find the numeric IDs for the necessary state keys. 817 neededStateKeys = neededStateKeys[:0] 818 neededStateKeys = append(neededStateKeys, needed.Member...) 819 neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) 820 stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) 821 if err != nil { 822 return nil, err 823 } 824 825 // Load the necessary auth events. 826 tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) 827 authEntries = authEntries[:0] 828 for _, tuple := range tuplesNeeded { 829 if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { 830 authEntries = append(authEntries, types.StateEntry{ 831 StateKeyTuple: tuple, 832 EventNID: eventNID, 833 }) 834 } 835 } 836 837 // Store the newly found auth events in the auth set for this event. 838 authSets[key], _, err = v.loadStateEvents(ctx, authEntries) 839 if err != nil { 840 return nil, err 841 } 842 authEvents = append(authEvents, authSets[key]...) 843 } 844 845 // This function helps us to work out whether an event exists in one of the 846 // auth sets. 847 isInAuthList := func(k string, event *gomatrixserverlib.Event) bool { 848 for _, e := range authSets[k] { 849 if e.EventID() == event.EventID() { 850 return true 851 } 852 } 853 return false 854 } 855 856 // This function works out if an event exists in all of the auth sets. 857 isInAllAuthLists := func(event *gomatrixserverlib.Event) bool { 858 found := true 859 for k := range authSets { 860 found = found && isInAuthList(k, event) 861 } 862 return found 863 } 864 865 // Look through all of the auth events that we've been given and work out if 866 // there are any events which don't appear in all of the auth sets. If they 867 // don't then we add them to the auth difference. 868 for _, event := range authEvents { 869 if !isInAllAuthLists(event) { 870 authDifference = append(authDifference, event) 871 } 872 } 873 874 // Resolve the conflicts. 875 resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2( 876 conflictedEvents, 877 nonConflictedEvents, 878 authEvents, 879 authDifference, 880 ) 881 882 // Map from the full events back to numeric state entries. 883 for _, resolvedEvent := range resolvedEvents { 884 entry, ok := eventIDMap[resolvedEvent.EventID()] 885 if !ok { 886 panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())) 887 } 888 notConflicted = append(notConflicted, entry) 889 } 890 891 // Sort the result so it can be searched. 892 sort.Sort(stateEntrySorter(notConflicted)) 893 return notConflicted, nil 894} 895 896// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. 897func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { 898 var keyTuples []types.StateKeyTuple 899 if stateNeeded.Create { 900 keyTuples = append(keyTuples, types.StateKeyTuple{ 901 EventTypeNID: types.MRoomCreateNID, 902 EventStateKeyNID: types.EmptyStateKeyNID, 903 }) 904 } 905 if stateNeeded.PowerLevels { 906 keyTuples = append(keyTuples, types.StateKeyTuple{ 907 EventTypeNID: types.MRoomPowerLevelsNID, 908 EventStateKeyNID: types.EmptyStateKeyNID, 909 }) 910 } 911 if stateNeeded.JoinRules { 912 keyTuples = append(keyTuples, types.StateKeyTuple{ 913 EventTypeNID: types.MRoomJoinRulesNID, 914 EventStateKeyNID: types.EmptyStateKeyNID, 915 }) 916 } 917 for _, member := range stateNeeded.Member { 918 stateKeyNID, ok := stateKeyNIDMap[member] 919 if ok { 920 keyTuples = append(keyTuples, types.StateKeyTuple{ 921 EventTypeNID: types.MRoomMemberNID, 922 EventStateKeyNID: stateKeyNID, 923 }) 924 } 925 } 926 for _, token := range stateNeeded.ThirdPartyInvite { 927 stateKeyNID, ok := stateKeyNIDMap[token] 928 if ok { 929 keyTuples = append(keyTuples, types.StateKeyTuple{ 930 EventTypeNID: types.MRoomThirdPartyInviteNID, 931 EventStateKeyNID: stateKeyNID, 932 }) 933 } 934 } 935 return keyTuples 936} 937 938// loadStateEvents loads the matrix events for a list of state entries. 939// Returns a list of state events in no particular order and a map from string event ID back to state entry. 940// The map can be used to recover which numeric state entry a given event is for. 941// Returns an error if there was a problem talking to the database. 942func (v *StateResolution) loadStateEvents( 943 ctx context.Context, entries []types.StateEntry, 944) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { 945 result := make([]*gomatrixserverlib.Event, 0, len(entries)) 946 eventEntries := make([]types.StateEntry, 0, len(entries)) 947 eventNIDs := make([]types.EventNID, 0, len(entries)) 948 for _, entry := range entries { 949 if e, ok := v.events[entry.EventNID]; ok { 950 result = append(result, e) 951 } else { 952 eventEntries = append(eventEntries, entry) 953 eventNIDs = append(eventNIDs, entry.EventNID) 954 } 955 } 956 events, err := v.db.Events(ctx, eventNIDs) 957 if err != nil { 958 return nil, nil, err 959 } 960 eventIDMap := map[string]types.StateEntry{} 961 for _, entry := range eventEntries { 962 event, ok := eventMap(events).lookup(entry.EventNID) 963 if !ok { 964 panic(fmt.Errorf("corrupt DB: Missing event numeric ID %d", entry.EventNID)) 965 } 966 result = append(result, event.Event) 967 eventIDMap[event.Event.EventID()] = entry 968 v.events[entry.EventNID] = event.Event 969 } 970 return result, eventIDMap, nil 971} 972 973// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. 974// Returns a sorted list of those state entries. 975func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { 976 var result []types.StateEntry 977 // j is the starting index of a block of entries with the same state key tuple. 978 j := 0 979 for i := 1; i < len(a); i++ { 980 // Check if the state key tuple matches the start of the block 981 if a[j].StateKeyTuple != a[i].StateKeyTuple { 982 // If the state key tuple is different then we've reached the end of a block of duplicates. 983 // Check if the size of the block is bigger than one. 984 // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result 985 if j+1 != i { 986 // Add the block to the result. 987 result = append(result, a[j:i]...) 988 } 989 // Start a new block for the next state key tuple. 990 j = i 991 } 992 } 993 // Check if the last block with the same state key tuple had more than one event in it. 994 if j+1 != len(a) { 995 result = append(result, a[j:]...) 996 } 997 return result 998} 999 1000type stateEntrySorter []types.StateEntry 1001 1002func (s stateEntrySorter) Len() int { return len(s) } 1003func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } 1004func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 1005 1006type stateBlockNIDListMap []types.StateBlockNIDList 1007 1008func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { 1009 list := []types.StateBlockNIDList(m) 1010 i := sort.Search(len(list), func(i int) bool { 1011 return list[i].StateSnapshotNID >= stateNID 1012 }) 1013 if i < len(list) && list[i].StateSnapshotNID == stateNID { 1014 ok = true 1015 stateBlockNIDs = list[i].StateBlockNIDs 1016 } 1017 return 1018} 1019 1020type stateEntryListMap []types.StateEntryList 1021 1022func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { 1023 list := []types.StateEntryList(m) 1024 i := sort.Search(len(list), func(i int) bool { 1025 return list[i].StateBlockNID >= stateBlockNID 1026 }) 1027 if i < len(list) && list[i].StateBlockNID == stateBlockNID { 1028 ok = true 1029 stateEntries = list[i].StateEntries 1030 } 1031 return 1032} 1033 1034type stateEntryByStateKeySorter []types.StateEntry 1035 1036func (s stateEntryByStateKeySorter) Len() int { return len(s) } 1037func (s stateEntryByStateKeySorter) Less(i, j int) bool { 1038 return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) 1039} 1040func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 1041 1042type stateNIDSorter []types.StateSnapshotNID 1043 1044func (s stateNIDSorter) Len() int { return len(s) } 1045func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } 1046func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 1047 1048func UniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { 1049 return nids[:util.SortAndUnique(stateNIDSorter(nids))] 1050} 1051 1052type stateBlockNIDSorter []types.StateBlockNID 1053 1054func (s stateBlockNIDSorter) Len() int { return len(s) } 1055func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } 1056func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 1057 1058func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { 1059 return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] 1060} 1061 1062// Map from event type, state key tuple to numeric event ID. 1063// Implemented using binary search on a sorted array. 1064type stateEntryMap []types.StateEntry 1065 1066// lookup an entry in the event map. 1067func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { 1068 // Since the list is sorted we can implement this using binary search. 1069 // This is faster than using a hash map. 1070 // We don't have to worry about pathological cases because the keys are fixed 1071 // size and are controlled by us. 1072 list := []types.StateEntry(m) 1073 i := sort.Search(len(list), func(i int) bool { 1074 return !list[i].StateKeyTuple.LessThan(stateKey) 1075 }) 1076 if i < len(list) && list[i].StateKeyTuple == stateKey { 1077 ok = true 1078 eventNID = list[i].EventNID 1079 } 1080 return 1081} 1082 1083// Map from numeric event ID to event. 1084// Implemented using binary search on a sorted array. 1085type eventMap []types.Event 1086 1087// lookup an entry in the event map. 1088func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { 1089 // Since the list is sorted we can implement this using binary search. 1090 // This is faster than using a hash map. 1091 // We don't have to worry about pathological cases because the keys are fixed 1092 // size are controlled by us. 1093 list := []types.Event(m) 1094 i := sort.Search(len(list), func(i int) bool { 1095 return list[i].EventNID >= eventNID 1096 }) 1097 if i < len(list) && list[i].EventNID == eventNID { 1098 ok = true 1099 event = &list[i] 1100 } 1101 return 1102} 1103