1package internal
2
3import (
4	"context"
5	"crypto/ed25519"
6	"fmt"
7	"time"
8
9	"github.com/matrix-org/dendrite/setup/config"
10	"github.com/matrix-org/dendrite/signingkeyserver/api"
11	"github.com/matrix-org/gomatrixserverlib"
12	"github.com/sirupsen/logrus"
13)
14
// ServerKeyAPI answers public signing-key lookups. It serves this server's
// own keys (current and old) directly from configuration, and looks up all
// other keys in the key database and, when needed, the key fetchers held in
// OurKeyRing.
type ServerKeyAPI struct {
	// Embedded so that ServerKeyAPI satisfies api.SigningKeyServerAPI.
	// NOTE(review): if this embedded interface is left nil, calling any
	// interface method not defined on ServerKeyAPI will panic — confirm
	// how this struct is constructed.
	api.SigningKeyServerAPI

	// ServerName is our own server name; key requests for this name are
	// answered locally by handleLocalKeys.
	ServerName gomatrixserverlib.ServerName
	// ServerPublicKey is the public half of our current signing key.
	ServerPublicKey ed25519.PublicKey
	// ServerKeyID is the key ID of our current signing key.
	ServerKeyID gomatrixserverlib.KeyID
	// ServerKeyValidity is added to the current time to compute the
	// ValidUntilTS we advertise for our current key.
	ServerKeyValidity time.Duration
	// OldServerKeys are our retired signing keys. They are served with
	// their configured expiry time and no validity period.
	OldServerKeys []config.OldVerifyKeys

	// OurKeyRing supplies the key database, used to read and store keys,
	// and the ordered list of key fetchers consulted for missing keys.
	OurKeyRing gomatrixserverlib.KeyRing
	// FedClient is not referenced by the code in this file section.
	// NOTE(review): presumably used elsewhere for federation key
	// requests — confirm.
	FedClient gomatrixserverlib.KeyClient
}
27
28func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing {
29	// Return a keyring that forces requests to be proxied through the
30	// below functions. That way we can enforce things like validity
31	// and keeping the cache up-to-date.
32	return &gomatrixserverlib.KeyRing{
33		KeyDatabase: s,
34		KeyFetchers: []gomatrixserverlib.KeyFetcher{},
35	}
36}
37
38func (s *ServerKeyAPI) StoreKeys(
39	_ context.Context,
40	results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
41) error {
42	// Run in a background context - we don't want to stop this work just
43	// because the caller gives up waiting.
44	ctx := context.Background()
45
46	// Store any keys that we were given in our database.
47	return s.OurKeyRing.KeyDatabase.StoreKeys(ctx, results)
48}
49
50func (s *ServerKeyAPI) FetchKeys(
51	_ context.Context,
52	requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
53) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
54	// Run in a background context - we don't want to stop this work just
55	// because the caller gives up waiting.
56	ctx := context.Background()
57	now := gomatrixserverlib.AsTimestamp(time.Now())
58	results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
59	origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{}
60	for k, v := range requests {
61		origRequests[k] = v
62	}
63
64	// First, check if any of these key checks are for our own keys. If
65	// they are then we will satisfy them directly.
66	s.handleLocalKeys(ctx, requests, results)
67
68	// Then consult our local database and see if we have the requested
69	// keys. These might come from a cache, depending on the database
70	// implementation used.
71	if err := s.handleDatabaseKeys(ctx, now, requests, results); err != nil {
72		return nil, err
73	}
74
75	// For any key requests that we still have outstanding, next try to
76	// fetch them directly. We'll go through each of the key fetchers to
77	// ask for the remaining keys
78	for _, fetcher := range s.OurKeyRing.KeyFetchers {
79		// If there are no more keys to look up then stop.
80		if len(requests) == 0 {
81			break
82		}
83
84		// Ask the fetcher to look up our keys.
85		if err := s.handleFetcherKeys(ctx, now, fetcher, requests, results); err != nil {
86			logrus.WithError(err).WithFields(logrus.Fields{
87				"fetcher_name": fetcher.FetcherName(),
88			}).Errorf("Failed to retrieve %d key(s)", len(requests))
89			continue
90		}
91	}
92
93	// Check that we've actually satisfied all of the key requests that we
94	// were given. We should report an error if we didn't.
95	for req := range origRequests {
96		if _, ok := results[req]; !ok {
97			// The results don't contain anything for this specific request, so
98			// we've failed to satisfy it from local keys, database keys or from
99			// all of the fetchers. Report an error.
100			logrus.Warnf("Failed to retrieve key %q for server %q", req.KeyID, req.ServerName)
101		}
102	}
103
104	// Return the keys.
105	return results, nil
106}
107
108func (s *ServerKeyAPI) FetcherName() string {
109	return fmt.Sprintf("ServerKeyAPI (wrapping %q)", s.OurKeyRing.KeyDatabase.FetcherName())
110}
111
112// handleLocalKeys handles cases where the key request contains
113// a request for our own server keys, either current or old.
114func (s *ServerKeyAPI) handleLocalKeys(
115	_ context.Context,
116	requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
117	results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
118) {
119	for req := range requests {
120		if req.ServerName != s.ServerName {
121			continue
122		}
123		if req.KeyID == s.ServerKeyID {
124			// We found a key request that is supposed to be for our own
125			// keys. Remove it from the request list so we don't hit the
126			// database or the fetchers for it.
127			delete(requests, req)
128
129			// Insert our own key into the response.
130			results[req] = gomatrixserverlib.PublicKeyLookupResult{
131				VerifyKey: gomatrixserverlib.VerifyKey{
132					Key: gomatrixserverlib.Base64Bytes(s.ServerPublicKey),
133				},
134				ExpiredTS:    gomatrixserverlib.PublicKeyNotExpired,
135				ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.ServerKeyValidity)),
136			}
137		} else {
138			// The key request doesn't match our current key. Let's see
139			// if it matches any of our old verify keys.
140			for _, oldVerifyKey := range s.OldServerKeys {
141				if req.KeyID == oldVerifyKey.KeyID {
142					// We found a key request that is supposed to be an expired
143					// key.
144					delete(requests, req)
145
146					// Insert our own key into the response.
147					results[req] = gomatrixserverlib.PublicKeyLookupResult{
148						VerifyKey: gomatrixserverlib.VerifyKey{
149							Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)),
150						},
151						ExpiredTS:    oldVerifyKey.ExpiredAt,
152						ValidUntilTS: gomatrixserverlib.PublicKeyNotValid,
153					}
154
155					// No need to look at the other keys.
156					break
157				}
158			}
159		}
160	}
161}
162
163// handleDatabaseKeys handles cases where the key requests can be
164// satisfied from our local database/cache.
165func (s *ServerKeyAPI) handleDatabaseKeys(
166	ctx context.Context,
167	now gomatrixserverlib.Timestamp,
168	requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
169	results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
170) error {
171	// Ask the database/cache for the keys.
172	dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests)
173	if err != nil {
174		return err
175	}
176
177	// We successfully got some keys. Add them to the results.
178	for req, res := range dbResults {
179		// The key we've retrieved from the database/cache might
180		// have passed its validity period, but right now, it's
181		// the best thing we've got, and it might be sufficient to
182		// verify a past event.
183		results[req] = res
184
185		// If the key is valid right now then we can also remove it
186		// from the request list as we don't need to fetch it again
187		// in that case. If the key isn't valid right now, then by
188		// leaving it in the 'requests' map, we'll try to update the
189		// key using the fetchers in handleFetcherKeys.
190		if res.WasValidAt(now, true) {
191			delete(requests, req)
192		}
193	}
194	return nil
195}
196
197// handleFetcherKeys handles cases where a fetcher can satisfy
198// the remaining requests.
199func (s *ServerKeyAPI) handleFetcherKeys(
200	ctx context.Context,
201	_ gomatrixserverlib.Timestamp,
202	fetcher gomatrixserverlib.KeyFetcher,
203	requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
204	results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
205) error {
206	logrus.WithFields(logrus.Fields{
207		"fetcher_name": fetcher.FetcherName(),
208	}).Infof("Fetching %d key(s)", len(requests))
209
210	// Create a context that limits our requests to 30 seconds.
211	fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30)
212	defer fetcherCancel()
213
214	// Try to fetch the keys.
215	fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests)
216	if err != nil {
217		return fmt.Errorf("fetcher.FetchKeys: %w", err)
218	}
219
220	// Build a map of the results that we want to commit to the
221	// database. We do this in a separate map because otherwise we
222	// might end up trying to rewrite database entries.
223	storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
224
225	// Now let's look at the results that we got from this fetcher.
226	for req, res := range fetcherResults {
227		if prev, ok := results[req]; ok {
228			// We've already got a previous entry for this request
229			// so let's see if the newly retrieved one contains a more
230			// up-to-date validity period.
231			if res.ValidUntilTS > prev.ValidUntilTS {
232				// This key is newer than the one we had so let's store
233				// it in the database.
234				storeResults[req] = res
235			}
236		} else {
237			// We didn't already have a previous entry for this request
238			// so store it in the database anyway for now.
239			storeResults[req] = res
240		}
241
242		// Update the results map with this new result. If nothing
243		// else, we can try verifying against this key.
244		results[req] = res
245
246		// Remove it from the request list so we won't re-fetch it.
247		delete(requests, req)
248	}
249
250	// Store the keys from our store map.
251	if err = s.OurKeyRing.KeyDatabase.StoreKeys(context.Background(), storeResults); err != nil {
252		logrus.WithError(err).WithFields(logrus.Fields{
253			"fetcher_name":  fetcher.FetcherName(),
254			"database_name": s.OurKeyRing.KeyDatabase.FetcherName(),
255		}).Errorf("Failed to store keys in the database")
256		return fmt.Errorf("server key API failed to store retrieved keys: %w", err)
257	}
258
259	if len(storeResults) > 0 {
260		logrus.WithFields(logrus.Fields{
261			"fetcher_name": fetcher.FetcherName(),
262		}).Infof("Updated %d of %d key(s) in database (%d keys remaining)", len(storeResults), len(results), len(requests))
263	}
264
265	return nil
266}
267