1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package consumers
16
17import (
18	"context"
19	"encoding/json"
20
21	"github.com/matrix-org/dendrite/internal"
22	"github.com/matrix-org/dendrite/keyserver/api"
23	"github.com/matrix-org/dendrite/keyserver/storage"
24	"github.com/matrix-org/dendrite/setup/config"
25	"github.com/matrix-org/dendrite/setup/process"
26	"github.com/matrix-org/gomatrixserverlib"
27	"github.com/sirupsen/logrus"
28
29	"github.com/Shopify/sarama"
30)
31
// OutputCrossSigningKeyUpdateConsumer consumes messages from the keyserver's
// key change topic and forwards cross-signing key updates to the key API.
//
// Fields:
//   - eduServerConsumer: the continual consumer that reads the key change
//     topic and invokes onMessage for each message. Despite the name, it is
//     configured with the keyserver's TopicOutputKeyChangeEvent topic.
//   - keyDB: the keyserver database handed to the constructor; it is also
//     used as the partition store of the continual consumer.
//   - keyAPI: the key API on which PerformUploadDeviceKeys is invoked.
//   - serverName: the configured name of this server, compared against the
//     host part of user IDs to skip updates about local users.
type OutputCrossSigningKeyUpdateConsumer struct {
	eduServerConsumer *internal.ContinualConsumer
	keyDB             storage.Database
	keyAPI            api.KeyInternalAPI
	serverName        string
}
38
39func NewOutputCrossSigningKeyUpdateConsumer(
40	process *process.ProcessContext,
41	cfg *config.Dendrite,
42	kafkaConsumer sarama.Consumer,
43	keyDB storage.Database,
44	keyAPI api.KeyInternalAPI,
45) *OutputCrossSigningKeyUpdateConsumer {
46	// The keyserver both produces and consumes on the TopicOutputKeyChangeEvent
47	// topic. We will only produce events where the UserID matches our server name,
48	// and we will only consume events where the UserID does NOT match our server
49	// name (because the update came from a remote server).
50	consumer := internal.ContinualConsumer{
51		Process:        process,
52		ComponentName:  "keyserver/keyserver",
53		Topic:          cfg.Global.Kafka.TopicFor(config.TopicOutputKeyChangeEvent),
54		Consumer:       kafkaConsumer,
55		PartitionStore: keyDB,
56	}
57	s := &OutputCrossSigningKeyUpdateConsumer{
58		eduServerConsumer: &consumer,
59		keyDB:             keyDB,
60		keyAPI:            keyAPI,
61		serverName:        string(cfg.Global.ServerName),
62	}
63	consumer.ProcessMessage = s.onMessage
64
65	return s
66}
67
68func (s *OutputCrossSigningKeyUpdateConsumer) Start() error {
69	return s.eduServerConsumer.Start()
70}
71
72// onMessage is called in response to a message received on the
73// key change events topic from the key server.
74func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *sarama.ConsumerMessage) error {
75	var m api.DeviceMessage
76	if err := json.Unmarshal(msg.Value, &m); err != nil {
77		logrus.WithError(err).Errorf("failed to read device message from key change topic")
78		return nil
79	}
80	if m.OutputCrossSigningKeyUpdate == nil {
81		// This probably shouldn't happen but stops us from panicking if we come
82		// across an update that doesn't satisfy either types.
83		return nil
84	}
85	switch m.Type {
86	case api.TypeCrossSigningUpdate:
87		return t.onCrossSigningMessage(m)
88	default:
89		return nil
90	}
91}
92
93func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
94	output := m.CrossSigningKeyUpdate
95	_, host, err := gomatrixserverlib.SplitID('@', output.UserID)
96	if err != nil {
97		logrus.WithError(err).Errorf("eduserver output log: user ID parse failure")
98		return nil
99	}
100	if host == gomatrixserverlib.ServerName(s.serverName) {
101		// Ignore any messages that contain information about our own users, as
102		// they already originated from this server.
103		return nil
104	}
105	uploadReq := &api.PerformUploadDeviceKeysRequest{
106		UserID: output.UserID,
107	}
108	if output.MasterKey != nil {
109		uploadReq.MasterKey = *output.MasterKey
110	}
111	if output.SelfSigningKey != nil {
112		uploadReq.SelfSigningKey = *output.SelfSigningKey
113	}
114	uploadRes := &api.PerformUploadDeviceKeysResponse{}
115	s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes)
116	return uploadRes.Error
117}
118