1// Copyright 2020 The Matrix.org Foundation C.I.C. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package shared 16 17import ( 18 "context" 19 "database/sql" 20 "encoding/json" 21 "fmt" 22 23 eduAPI "github.com/matrix-org/dendrite/eduserver/api" 24 userapi "github.com/matrix-org/dendrite/userapi/api" 25 26 "github.com/matrix-org/dendrite/internal/eventutil" 27 "github.com/matrix-org/dendrite/internal/sqlutil" 28 "github.com/matrix-org/dendrite/roomserver/api" 29 "github.com/matrix-org/dendrite/syncapi/storage/tables" 30 "github.com/matrix-org/dendrite/syncapi/types" 31 "github.com/matrix-org/gomatrixserverlib" 32 "github.com/sirupsen/logrus" 33) 34 35// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite 36// For now this contains the shared functions 37type Database struct { 38 DB *sql.DB 39 Writer sqlutil.Writer 40 Invites tables.Invites 41 Peeks tables.Peeks 42 AccountData tables.AccountData 43 OutputEvents tables.Events 44 Topology tables.Topology 45 CurrentRoomState tables.CurrentRoomState 46 BackwardExtremities tables.BackwardsExtremities 47 SendToDevice tables.SendToDevice 48 Filter tables.Filter 49 Receipts tables.Receipts 50 Memberships tables.Memberships 51} 52 53func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { 54 return d.DB.BeginTx(ctx, &sql.TxOptions{ 55 // Set the isolation level so that we see a snapshot of the database. 
56 // In PostgreSQL repeatable read transactions will see a snapshot taken 57 // at the first query, and since the transaction is read-only it can't 58 // run into any serialisation errors. 59 // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ 60 Isolation: sql.LevelRepeatableRead, 61 ReadOnly: true, 62 }) 63} 64 65func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { 66 id, err := d.OutputEvents.SelectMaxEventID(ctx, nil) 67 if err != nil { 68 return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) 69 } 70 return types.StreamPosition(id), nil 71} 72 73func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { 74 id, err := d.Receipts.SelectMaxReceiptID(ctx, nil) 75 if err != nil { 76 return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) 77 } 78 return types.StreamPosition(id), nil 79} 80 81func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { 82 id, err := d.Invites.SelectMaxInviteID(ctx, nil) 83 if err != nil { 84 return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) 85 } 86 return types.StreamPosition(id), nil 87} 88 89func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { 90 id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil) 91 if err != nil { 92 return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) 93 } 94 return types.StreamPosition(id), nil 95} 96 97func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { 98 id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) 99 if err != nil { 100 return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) 101 } 102 return types.StreamPosition(id), nil 103} 104 105func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, 
excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { 106 return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) 107} 108 109func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { 110 return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) 111} 112 113func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { 114 return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) 115} 116 117func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { 118 return d.Topology.SelectPositionInTopology(ctx, nil, eventID) 119} 120 121func (d *Database) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { 122 return d.Invites.SelectInviteEventsInRange(ctx, nil, targetUserID, r) 123} 124 125func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { 126 return d.Peeks.SelectPeeksInRange(ctx, nil, userID, deviceID, r) 127} 128 129func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error) { 130 return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) 131} 132 133// Events lookups a list of event by their event ID. 134// Returns a list of events matching the requested IDs found in the database. 135// If an event is not found in the database then it will be omitted from the list. 136// Returns an error if there was a problem talking with the database. 
137// Does not include any transaction IDs in the returned events. 138func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { 139 streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) 140 if err != nil { 141 return nil, err 142 } 143 144 // We don't include a device here as we only include transaction IDs in 145 // incremental syncs. 146 return d.StreamEventsToEvents(nil, streamEvents), nil 147} 148 149// GetEventsInStreamingRange retrieves all of the events on a given ordering using the 150// given extremities and limit. 151func (d *Database) GetEventsInStreamingRange( 152 ctx context.Context, 153 from, to *types.StreamingToken, 154 roomID string, eventFilter *gomatrixserverlib.RoomEventFilter, 155 backwardOrdering bool, 156) (events []types.StreamEvent, err error) { 157 r := types.Range{ 158 From: from.PDUPosition, 159 To: to.PDUPosition, 160 Backwards: backwardOrdering, 161 } 162 if backwardOrdering { 163 // When using backward ordering, we want the most recent events first. 164 if events, _, err = d.OutputEvents.SelectRecentEvents( 165 ctx, nil, roomID, r, eventFilter, false, false, 166 ); err != nil { 167 return 168 } 169 } else { 170 // When using forward ordering, we want the least recent events first. 
171 if events, err = d.OutputEvents.SelectEarlyEvents( 172 ctx, nil, roomID, r, eventFilter, 173 ); err != nil { 174 return 175 } 176 } 177 return events, err 178} 179 180func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { 181 return d.CurrentRoomState.SelectJoinedUsers(ctx) 182} 183 184func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { 185 return d.Peeks.SelectPeekingDevices(ctx) 186} 187 188func (d *Database) GetStateEvent( 189 ctx context.Context, roomID, evType, stateKey string, 190) (*gomatrixserverlib.HeaderedEvent, error) { 191 return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) 192} 193 194func (d *Database) GetStateEventsForRoom( 195 ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, 196) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { 197 stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil) 198 return 199} 200 201// AddInviteEvent stores a new invite event for a user. 202// If the invite was successfully stored this returns the stream ID it was stored at. 203// Returns an error if there was a problem communicating with the database. 204func (d *Database) AddInviteEvent( 205 ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, 206) (sp types.StreamPosition, err error) { 207 _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { 208 sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) 209 return err 210 }) 211 return 212} 213 214// RetireInviteEvent removes an old invite event from the database. 215// Returns an error if there was a problem communicating with the database. 
216func (d *Database) RetireInviteEvent( 217 ctx context.Context, inviteEventID string, 218) (sp types.StreamPosition, err error) { 219 _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { 220 sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) 221 return err 222 }) 223 return 224} 225 226// AddPeek tracks the fact that a user has started peeking. 227// If the peek was successfully stored this returns the stream ID it was stored at. 228// Returns an error if there was a problem communicating with the database. 229func (d *Database) AddPeek( 230 ctx context.Context, roomID, userID, deviceID string, 231) (sp types.StreamPosition, err error) { 232 err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { 233 sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID) 234 return err 235 }) 236 return 237} 238 239// DeletePeeks tracks the fact that a user has stopped peeking from the specified 240// device. If the peeks was successfully deleted this returns the stream ID it was 241// stored at. Returns an error if there was a problem communicating with the database. 242func (d *Database) DeletePeek( 243 ctx context.Context, roomID, userID, deviceID string, 244) (sp types.StreamPosition, err error) { 245 err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { 246 sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID) 247 return err 248 }) 249 if err == sql.ErrNoRows { 250 sp = 0 251 err = nil 252 } 253 return 254} 255 256// DeletePeeks tracks the fact that a user has stopped peeking from all devices 257// If the peeks was successfully deleted this returns the stream ID it was stored at. 258// Returns an error if there was a problem communicating with the database. 
func (d *Database) DeletePeeks(
	ctx context.Context, roomID, userID string,
) (sp types.StreamPosition, err error) {
	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
		return err
	})
	// Nothing to delete is not a failure: report position 0 and no error.
	if err == sql.ErrNoRows {
		sp = 0
		err = nil
	}
	return
}

// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions.
// Returns a map following the format data[roomID] = []dataTypes.
// If no data is retrieved, returns an empty map.
// If there was an issue with the retrieval, returns an error.
func (d *Database) GetAccountDataInRange(
	ctx context.Context, userID string, r types.Range,
	accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, error) {
	return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart)
}

// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to
// (an empty room ID means the data isn't specific to any room).
// If no data with the given type, user ID and room ID exists in the database,
// creates a new row, else updates the existing one.
// Returns the stream position reported by the insert, or an error if there
// was an issue with the upsert.
func (d *Database) UpsertAccountData(
	ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) {
	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
		return err
	})
	return
}

// StreamEventsToEvents unwraps each stream event into its headered event,
// preserving order. When device is non-nil and a stream event carries a
// transaction ID whose session matches the device (and the device's user sent
// the event), the transaction ID is added to the event's unsigned section as
// "transaction_id". Failure to set that field is logged and otherwise ignored.
func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent {
	out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
	for i := 0; i < len(in); i++ {
		out[i] = in[i].HeaderedEvent
		if device != nil && in[i].TransactionID != nil {
			// Only echo the transaction ID back to the session that sent the event.
			if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
				err := out[i].SetUnsignedField(
					"transaction_id", in[i].TransactionID.TransactionID,
				)
				if err != nil {
					logrus.WithFields(logrus.Fields{
						"event_id": out[i].EventID(),
					}).WithError(err).Warnf("Failed to add transaction ID to event")
				}
			}
		}
	}
	return out
}

// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
	// The event has now arrived, so it can no longer be an extremity itself.
	if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
		return err
	}

	// Check if we have all of the event's previous events. If an event is
	// missing, add it to the room's backward extremities.
	prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
	if err != nil {
		return err
	}
	var found bool
	for _, eID := range ev.PrevEventIDs() {
		found = false
		for _, prevEv := range prevEvents {
			if eID == prevEv.EventID() {
				found = true
			}
		}

		// If the event is missing, consider it a backward extremity.
		if !found {
			if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil {
				return err
			}
		}
	}

	return nil
}

// PurgeRoomState deletes all rows of the current room state table for roomID,
// running the delete through the writer.
func (d *Database) PurgeRoomState(
	ctx context.Context, roomID string,
) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		// NOTE(review): per the original comment, callers use this when a
		// create event is replayed, on the expectation that the entire room
		// state is about to be received again. This function itself never
		// inspects an event; confirm against the callers.
		if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
			return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
		}
		return nil
	})
}

// WriteEvent stores ev in the events table and the topology table, updates
// the backward extremities, and, if any state was added or removed, updates
// the current room state. All of this runs inside a single writer closure.
// Returns the PDU stream position assigned by the events table insert.
func (d *Database) WriteEvent(
	ctx context.Context,
	ev *gomatrixserverlib.HeaderedEvent,
	addStateEvents []*gomatrixserverlib.HeaderedEvent,
	addStateEventIDs, removeStateEventIDs []string,
	transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition types.StreamPosition, returnErr error) {
	returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		var err error
		pos, err := d.OutputEvents.InsertEvent(
			ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
		)
		if err != nil {
			return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
		}
		pduPosition = pos
		var topoPosition types.StreamPosition
		if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
			return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
		}

		if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
			return fmt.Errorf("d.handleBackwardExtremities: %w", err)
		}

		if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
			// Nothing to do, the event may have just been a message event.
			return nil
		}

		return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
	})

	return pduPosition, returnErr
}

// updateRoomState applies a state change to the current room state table:
// the removed event IDs are deleted first, then each added state event is
// upserted. Membership events are additionally upserted into the memberships
// table. Added events without a state key are skipped.
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) updateRoomState(
	ctx context.Context, txn *sql.Tx,
	removedEventIDs []string,
	addedEvents []*gomatrixserverlib.HeaderedEvent,
	pduPosition types.StreamPosition,
	topoPosition types.StreamPosition,
) error {
	// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
	for _, eventID := range removedEventIDs {
		if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
			return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
		}
	}

	for _, event := range addedEvents {
		if event.StateKey() == nil {
			// ignore non state events
			continue
		}
		// membership stays nil for anything that is not an m.room.member event.
		var membership *string
		if event.Type() == "m.room.member" {
			value, err := event.Membership()
			if err != nil {
				return fmt.Errorf("event.Membership: %w", err)
			}
			membership = &value
			if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil {
				return fmt.Errorf("d.Memberships.UpsertMembership: %w", err)
			}
		}

		if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
			return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
		}
	}

	return nil
}

// GetEventsInTopologicalRange selects up to limit event IDs of roomID between
// the depths of the two topology tokens and then loads those events from the
// events table. Both reads happen outside of any transaction.
func (d *Database) GetEventsInTopologicalRange(
	ctx context.Context,
	from, to *types.TopologyToken,
	roomID string, limit int,
	backwardOrdering bool,
) (events []types.StreamEvent, err error) {
	var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
	if backwardOrdering {
		// Backward ordering means the 'from' token has a higher depth than the 'to' token
		minDepth = to.Depth
		maxDepth = from.Depth
		// for cases where we have say 5 events with the same depth, the TopologyToken needs to
		// know which of the 5 the client has seen. This is done by using the PDU position.
		// Events with the same maxDepth but less than this PDU position will be returned.
		maxStreamPosForMaxDepth = from.PDUPosition
	} else {
		// Forward ordering means the 'from' token has a lower depth than the 'to' token.
		// maxStreamPosForMaxDepth is left at its zero value in this direction.
		minDepth = from.Depth
		maxDepth = to.Depth
	}

	// Select the event IDs from the defined range.
	var eIDs []string
	eIDs, err = d.Topology.SelectEventIDsInRange(
		ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
	)
	if err != nil {
		return
	}

	// Retrieve the events' contents using their IDs.
	events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
	return
}

// BackwardExtremitiesForRoom forwards to
// BackwardExtremities.SelectBackwardExtremitiesForRoom.
func (d *Database) BackwardExtremitiesForRoom(
	ctx context.Context, roomID string,
) (backwardExtremities map[string][]string, err error) {
	return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID)
}

// MaxTopologicalPosition returns the maximum position in the topology of
// roomID as a topology token (depth plus PDU stream position).
func (d *Database) MaxTopologicalPosition(
	ctx context.Context, roomID string,
) (types.TopologyToken, error) {
	depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
	if err != nil {
		return types.TopologyToken{}, err
	}
	return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
}

// EventPositionInTopology returns the position of eventID in the topology as
// a topology token (depth plus PDU stream position).
func (d *Database) EventPositionInTopology(
	ctx context.Context, eventID string,
) (types.TopologyToken, error) {
	depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID)
	if err != nil {
		return types.TopologyToken{}, err
	}
	return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
}

// GetFilter forwards to Filter.SelectFilter for the given localpart and filter ID.
func (d *Database) GetFilter(
	ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
	return d.Filter.SelectFilter(ctx, localpart, filterID)
}

// PutFilter stores filter for localpart via Filter.InsertFilter and returns
// the filter ID it reports.
func (d *Database) PutFilter(
	ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
	var filterID string
	var err error
	// The writer is invoked with a nil database and the closure does not use
	// txn: InsertFilter takes no transaction argument.
	err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
		filterID, err = d.Filter.InsertFilter(ctx, filter, localpart)
		return err
	})
	return filterID, err
}

// RedactEvent looks up redactedEventID, applies the redaction event
// redactedBecause to it via eventutil.RedactEvent and writes the resulting
// event JSON back to the events table. If the event to redact is not in the
// database a warning is logged and nil is returned.
func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error {
	redactedEvents, err := d.Events(ctx, []string{redactedEventID})
	if err != nil {
		return err
	}
	if len(redactedEvents) == 0 {
		logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
		return nil
	}
	eventToRedact := redactedEvents[0].Unwrap()
	redactionEvent := redactedBecause.Unwrap()
	ev, err := eventutil.RedactEvent(redactionEvent, eventToRedact)
	if err != nil {
		return err
	}

	// The rewritten event is headered with the redaction event's room version.
	newEvent := ev.Headered(redactedBecause.RoomVersion)
	// As in PutFilter, the writer gets a nil database and txn is unused.
	err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
		return d.OutputEvents.UpdateEventJSON(ctx, newEvent)
	})
	return err
}

// GetBackwardTopologyPos retrieves the backward topology position, i.e. the
// position of the oldest event in the room's topology. It looks up the
// topology position of events[0] and returns it after decrementing the token.
// An empty events slice yields the zero token and no error.
func (d *Database) GetBackwardTopologyPos(
	ctx context.Context,
	events []types.StreamEvent,
) (types.TopologyToken, error) {
	zeroToken := types.TopologyToken{}
	if len(events) == 0 {
		return zeroToken, nil
	}
	pos, spos, err := d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID())
	if err != nil {
		return zeroToken, err
	}
	tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
	tok.Decrement()
	return tok, nil
}

// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Event IDs whose flag in roomIDToEventIDSet is false are skipped.
// Returns a map of room ID to list of events.
func (d *Database) fetchStateEvents(
	ctx context.Context, txn *sql.Tx,
	roomIDToEventIDSet map[string]map[string]bool,
	eventIDToEvent map[string]types.StreamEvent,
) (map[string][]types.StreamEvent, error) {
	stateBetween := make(map[string][]types.StreamEvent)
	missingEvents := make(map[string][]string)
	for roomID, ids := range roomIDToEventIDSet {
		events := stateBetween[roomID]
		for id, need := range ids {
			if !need {
				continue // deleted state
			}
			e, ok := eventIDToEvent[id]
			if ok {
				events = append(events, e)
			} else {
				// Not supplied by the caller; remember it for the lookup below.
				m := missingEvents[roomID]
				m = append(m, id)
				missingEvents[roomID] = m
			}
		}
		stateBetween[roomID] = events
	}

	if len(missingEvents) > 0 {
		// This happens when add_state_ids has an event ID which is not in the provided range.
		// We need to explicitly fetch them.
		allMissingEventIDs := []string{}
		for _, missingEvIDs := range missingEvents {
			allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
		}
		evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
		if err != nil {
			return nil, err
		}
		// NOTE(review): fetchMissingStateEvents currently only logs a warning
		// when some IDs cannot be resolved, so evs may be shorter than
		// allMissingEventIDs. Each fetched event is filed under its own room ID.
		for _, ev := range evs {
			roomID := ev.RoomID()
			stateBetween[roomID] = append(stateBetween[roomID], ev)
		}
	}
	return stateBetween, nil
}

// fetchMissingStateEvents resolves eventIDs to stream events, first from the
// events table and then, for any IDs not found there, from the current room
// state table. IDs found in neither table are dropped with a warning rather
// than causing an error.
func (d *Database) fetchMissingStateEvents(
	ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StreamEvent, error) {
	// Fetch from the events table first so we pick up the stream ID for the
	// event.
	events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
	if err != nil {
		return nil, err
	}

	// Work out which of the requested IDs the events table did not return.
	have := map[string]bool{}
	for _, event := range events {
		have[event.EventID()] = true
	}
	var missing []string
	for _, eventID := range eventIDs {
		if !have[eventID] {
			missing = append(missing, eventID)
		}
	}
	if len(missing) == 0 {
		return events, nil
	}

	// If they are missing from the events table then they should be state
	// events that we received from outside the main event stream.
	// These should be in the room state table.
	stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)

	if err != nil {
		return nil, err
	}
	if len(stateEvents) != len(missing) {
		logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))

		// TODO: Why is this happening? It's probably the roomserver. Uncomment
		// this error again when we work out what it is and fix it, otherwise we
		// just end up returning lots of 500s to the client and that breaks
		// pretty much everything, rather than just sending what we have.
		//return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
	}
	events = append(events, stateEvents...)
	return events, nil
}

// GetStateDeltas returns the state deltas for the range r (per the original
// note: exclusive of the old position, inclusive of the new position), for
// the rooms in which the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
func (d *Database) GetStateDeltas(
	ctx context.Context, device *userapi.Device,
	r types.Range, userID string,
	stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) {
	// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
	// - Get membership list changes for this user in this sync response
	// - For each room which has membership list changes:
	//     * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
	//       If it is, then we need to send the full room state down (and 'limited' is always true).
	//     * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
	//     * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
	// - Get all CURRENTLY joined rooms, and add them to 'joined' block.
	//
	// All reads below share one read-only snapshot transaction.
	txn, err := d.readOnlySnapshot(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
	}
	// succeeded is only set on the success path at the end. It is handed to
	// EndTransactionWithCheck together with the local err.
	var succeeded bool
	defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)

	var deltas []types.StateDelta

	// get all the state events ever (i.e. for all available rooms) between these two positions
	stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
	if err != nil {
		return nil, nil, err
	}
	state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
	if err != nil {
		return nil, nil, err
	}

	// find out which rooms this user is peeking, if any.
	// We do this before joins so any peeks get overwritten
	// Note that device is dereferenced here, so it must be non-nil.
	peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
	if err != nil {
		return nil, nil, err
	}

	// add peek blocks
	for _, peek := range peeks {
		if peek.New {
			// send full room state down instead of a delta
			var s []types.StreamEvent
			s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
			if err != nil {
				return nil, nil, err
			}
			state[peek.RoomID] = s
		}
		if !peek.Deleted {
			deltas = append(deltas, types.StateDelta{
				Membership:  gomatrixserverlib.Peek,
				StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
				RoomID:      peek.RoomID,
			})
		}
	}

	// handle newly joined rooms and non-joined rooms
	for roomID, stateStreamEvents := range state {
		for _, ev := range stateStreamEvents {
			// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
			// We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
			// dupe join events will result in the entire room state coming down to the client again. This is added in
			// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
			// the timeline.
			if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
				if membership == gomatrixserverlib.Join {
					// send full room state down instead of a delta
					var s []types.StreamEvent
					s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
					if err != nil {
						return nil, nil, err
					}
					// Replaces the entry for a key that already exists in the
					// map being ranged over. stateStreamEvents still refers to
					// the previous slice for the rest of this inner loop.
					state[roomID] = s
					continue // we'll add this room in when we do joined rooms (this moves on to the next event of the same room)
				}

				// First non-join membership event for this user in the room:
				// emit one delta for the room and stop scanning its events.
				deltas = append(deltas, types.StateDelta{
					Membership:    membership,
					MembershipPos: ev.StreamPosition,
					StateEvents:   d.StreamEventsToEvents(device, stateStreamEvents),
					RoomID:        roomID,
				})
				break
			}
		}
	}

	// Add in currently joined rooms
	joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
	if err != nil {
		return nil, nil, err
	}
	for _, joinedRoomID := range joinedRoomIDs {
		deltas = append(deltas, types.StateDelta{
			Membership:  gomatrixserverlib.Join,
			StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
			RoomID:      joinedRoomID,
		})
	}

	succeeded = true
	return deltas, joinedRoomIDs, nil
}

// GetStateDeltasForFullStateSync is a variant of GetStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses SelectStateInRange to get
// updates for other rooms.
func (d *Database) GetStateDeltasForFullStateSync(
	ctx context.Context, device *userapi.Device,
	r types.Range, userID string,
	stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) {
	// All reads below share one read-only snapshot transaction.
	txn, err := d.readOnlySnapshot(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
	}
	// succeeded is only set on the success path at the end. It is handed to
	// EndTransactionWithCheck together with the local err.
	var succeeded bool
	defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)

	// Keyed by room ID, so a later delta for the same room replaces an
	// earlier one (peek, then other memberships, then join).
	deltas := make(map[string]types.StateDelta)

	// Note that device is dereferenced here, so it must be non-nil.
	peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
	if err != nil {
		return nil, nil, err
	}

	// Add full states for all peeking rooms
	for _, peek := range peeks {
		if !peek.Deleted {
			s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
			if stateErr != nil {
				return nil, nil, stateErr
			}
			deltas[peek.RoomID] = types.StateDelta{
				Membership:  gomatrixserverlib.Peek,
				StateEvents: d.StreamEventsToEvents(device, s),
				RoomID:      peek.RoomID,
			}
		}
	}

	// Get all the state events ever between these two positions
	stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
	if err != nil {
		return nil, nil, err
	}
	state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
	if err != nil {
		return nil, nil, err
	}

	for roomID, stateStreamEvents := range state {
		for _, ev := range stateStreamEvents {
			if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
				if membership != gomatrixserverlib.Join { // Joined rooms get their full state added further down.
					deltas[roomID] = types.StateDelta{
						Membership:    membership,
						MembershipPos: ev.StreamPosition,
						StateEvents:   d.StreamEventsToEvents(device, stateStreamEvents),
						RoomID:        roomID,
					}
				}

				// Only the first membership event found for this user in the
				// room is considered.
				break
			}
		}
	}

	joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
	if err != nil {
		return nil, nil, err
	}

	// Add full states for all joined rooms
	for _, joinedRoomID := range joinedRoomIDs {
		s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
		if stateErr != nil {
			return nil, nil, stateErr
		}
		deltas[joinedRoomID] = types.StateDelta{
			Membership:  gomatrixserverlib.Join,
			StateEvents: d.StreamEventsToEvents(device, s),
			RoomID:      joinedRoomID,
		}
	}

	// Create a response array. Map iteration order is unspecified, so the
	// order of the deltas in the result is too.
	result := make([]types.StateDelta, len(deltas))
	i := 0
	for _, delta := range deltas {
		result[i] = delta
		i++
	}

	succeeded = true
	return result, joinedRoomIDs, nil
}

// currentStateStreamEventsForRoom loads the current state of roomID that
// passes stateFilter and wraps each event in a StreamEvent with stream
// position 0.
func (d *Database) currentStateStreamEventsForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
	stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StreamEvent, error) {
	allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil)
	if err != nil {
		return nil, err
	}
	s := make([]types.StreamEvent, len(allState))
	for i := 0; i < len(s); i++ {
		s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
	}
	return s, nil
}

// StoreNewSendForDeviceMessage JSON-encodes event and stores it as a
// send-to-device message for the given user and device. Returns the stream
// position reported by the insert, or 0 and an error on failure.
func (d *Database) StoreNewSendForDeviceMessage(
	ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (newPos types.StreamPosition, err error) {
	j, err := json.Marshal(event)
	if err != nil {
		return 0, err
	}
	// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
	// that we don't lock the table for writes in more than one place.
	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
			ctx, txn, userID, deviceID, string(j),
		)
		return err
	})
	if err != nil {
		return 0, err
	}
	return newPos, nil
}

// SendToDeviceUpdatesForSync returns the send-to-device messages for the
// given user and device between from and to, together with the last position
// reported by the table. On error the from position is returned; when there
// are no messages the to position is returned.
func (d *Database) SendToDeviceUpdatesForSync(
	ctx context.Context,
	userID, deviceID string,
	from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, error) {
	// First of all, get our send-to-device updates for this user.
	lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
	if err != nil {
		return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
	}
	// If there's nothing to do then stop here.
	if len(events) == 0 {
		return to, nil, nil
	}
	return lastPos, events, nil
}

// CleanSendToDeviceUpdates deletes send-to-device messages for the given user
// and device through the writer, passing before to
// SendToDevice.DeleteSendToDeviceMessages. A failure is logged and returned.
func (d *Database) CleanSendToDeviceUpdates(
	ctx context.Context,
	userID, deviceID string, before types.StreamPosition,
) (err error) {
	if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
	}); err != nil {
		logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
		return err
	}
	return nil
}

// getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
934func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { 935 if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { 936 return "" 937 } 938 membership, err := ev.Membership() 939 if err != nil { 940 return "" 941 } 942 return membership 943} 944 945// StoreReceipt stores user receipts 946func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { 947 err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { 948 pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) 949 return err 950 }) 951 return 952} 953 954func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { 955 _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) 956 return receipts, err 957} 958