1package gomatrixserverlib
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/matrix-org/util"
12	"golang.org/x/crypto/ed25519"
13)
14
// A PublicKeyLookupRequest is a request for a public key with a particular key ID.
// It pairs a server name with one of that server's key IDs and is used as the
// map key in KeyFetcher.FetchKeys requests and results.
type PublicKeyLookupRequest struct {
	// The server to fetch a key for.
	ServerName ServerName `json:"server_name"`
	// The ID of the key to fetch.
	KeyID KeyID `json:"key_id"`
}
22
23// MarshalText turns the public key lookup request into a string format,
24// which allows us to use it as a JSON map key.
25func (r PublicKeyLookupRequest) MarshalText() ([]byte, error) {
26	return []byte(fmt.Sprintf("%s/%s", r.ServerName, r.KeyID)), nil
27}
28
29// UnmarshalText turns the string format back into a public key lookup
30// request, from a JSON map key.
31func (r *PublicKeyLookupRequest) UnmarshalText(text []byte) error {
32	parts := strings.SplitN(string(text), "/", 2)
33	if len(parts) < 2 {
34		return errors.New("expected at least one / separator in " + string(text))
35	}
36	r.ServerName, r.KeyID = ServerName(parts[0]), KeyID(parts[1])
37	return nil
38}
39
// PublicKeyNotExpired is a magic value for PublicKeyLookupResult.ExpiredTS:
// it indicates that this is an active key which has not yet expired.
const PublicKeyNotExpired = Timestamp(0)
43
// PublicKeyNotValid is a magic value for PublicKeyLookupResult.ValidUntilTS:
// it is used when we don't have a validity period for this key. Most likely
// it is an old key with an expiry date.
// Its underlying value is zero, so it is only meaningful when read from the
// ValidUntilTS field.
const PublicKeyNotValid = Timestamp(0)
48
// A PublicKeyLookupResult is the result of looking up a server signing key.
type PublicKeyLookupResult struct {
	// The key itself (embedded).
	VerifyKey
	// If this key has expired, the time it stopped being valid for event signing in milliseconds.
	// If the key has not expired, the magic value PublicKeyNotExpired.
	ExpiredTS Timestamp `json:"expired_ts"`
	// When this result is valid until in milliseconds.
	// If the key has expired, the magic value PublicKeyNotValid.
	ValidUntilTS Timestamp `json:"valid_until_ts"`
}
59
60// WasValidAt checks if this signing key is valid for an event signed at the
61// given timestamp.
62func (r PublicKeyLookupResult) WasValidAt(atTs Timestamp, strictValidityChecking bool) bool {
63	if r.ExpiredTS != PublicKeyNotExpired {
64		return atTs < r.ExpiredTS
65	}
66	if strictValidityChecking {
67		if r.ValidUntilTS == PublicKeyNotValid {
68			return false
69		}
70		// Servers MUST use the lesser of valid_until_ts and 7 days into the
71		// future when determining if a key is valid.
72		// https://matrix.org/docs/spec/rooms/v5#signing-key-validity-period
73		sevenDaysFuture := time.Now().Add(time.Hour * 24 * 7)
74		validUntilTS := r.ValidUntilTS.Time()
75		if validUntilTS.After(sevenDaysFuture) {
76			validUntilTS = sevenDaysFuture
77		}
78		if atTs.Time().After(validUntilTS) {
79			return false
80		}
81	}
82	return true
83}
84
// PublicKeyNotaryLookupRequest maps each server name to the key IDs being
// queried for that server, with per-key query criteria. It is serialised under
// the "server_keys" JSON field.
// NOTE(review): presumably the request body for a notary key query — confirm
// against the caller.
type PublicKeyNotaryLookupRequest struct {
	ServerKeys map[ServerName]map[KeyID]PublicKeyNotaryQueryCriteria `json:"server_keys"`
}
88
// PublicKeyNotaryQueryCriteria holds the criteria attached to a single key ID
// in a PublicKeyNotaryLookupRequest.
type PublicKeyNotaryQueryCriteria struct {
	// NOTE(review): presumably the earliest valid_until_ts the requester will
	// accept for the returned key — confirm against the notary API.
	MinimumValidUntilTS Timestamp `json:"minimum_valid_until_ts"`
}
92
// A KeyFetcher is a way of fetching public keys in bulk.
type KeyFetcher interface {
	// FetchKeys looks up a batch of public keys.
	// Takes a map from (server name, key ID) pairs to timestamp.
	// The timestamp is when the keys need to be valid up to.
	// Returns a map from (server name, key ID) pairs to server key objects for
	// that server name containing that key ID.
	// The result may have fewer (server name, key ID) pairs than were in the request.
	// The result may have more (server name, key ID) pairs than were in the request.
	// Returns an error if there was a problem fetching the keys.
	FetchKeys(ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error)

	// FetcherName returns the name of this fetcher, which can then be used for
	// logging errors etc.
	FetcherName() string
}
109
// A KeyDatabase is a store for caching public keys.
// It can be queried like any other KeyFetcher and additionally accepts keys
// to be stored.
type KeyDatabase interface {
	KeyFetcher
	// StoreKeys adds a block of public keys to the database.
	// Returns an error if there was a problem storing the keys.
	// A database is not required to roll back storing all of the keys if some
	// of the keys aren't stored, and an in-progress store may be partially
	// visible to a concurrent FetchKeys(). This is acceptable since the
	// database is only used as a cache for the keys, so if a FetchKeys() races
	// with a StoreKeys() and some of the keys are missing they will just be
	// refetched.
	StoreKeys(ctx context.Context, results map[PublicKeyLookupRequest]PublicKeyLookupResult) error
}
122
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
type KeyRing struct {
	// Fetchers that VerifyJSONs tries, in order, for keys that the database
	// could not supply or that need refreshing.
	KeyFetchers []KeyFetcher
	// The cache that VerifyJSONs consults first and writes fetched keys back to.
	KeyDatabase KeyDatabase
}
128
// A VerifyJSONRequest is a request to check for a signature on a JSON message.
// A JSON message is valid for a server if the message has at least one valid
// signature from that server.
type VerifyJSONRequest struct {
	// The name of the matrix server to check for a signature for.
	ServerName ServerName
	// The millisecond posix timestamp the message needs to be valid at.
	AtTS Timestamp
	// The JSON bytes.
	Message []byte
	// Whether signing key validity checking should be enabled (room version >= 5).
	// This is passed through to PublicKeyLookupResult.WasValidAt.
	StrictValidityChecking bool
}
142
// A VerifyJSONResult is the result of checking the signature of a JSON message.
type VerifyJSONResult struct {
	// The outcome of the signature checks for the message.
	// This will be nil if the message passed the checks.
	// This will have an error if the message did not pass the checks.
	Error error
}
150
// A JSONVerifier is an object which can verify the signatures of JSON messages.
type JSONVerifier interface {
	// VerifyJSONs performs bulk JSON signature verification for a list of VerifyJSONRequests.
	// Returns a list of VerifyJSONResults with the same length and order as the request list.
	// The caller should check the Error field for each entry to see if it was valid.
	// Returns an error if there was a problem talking to the database or one of the other methods
	// of fetching the public keys.
	VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error)
}
160
161// VerifyJSONs implements JSONVerifier.
162func (k KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
163	logger := util.GetLogger(ctx)
164	results := make([]VerifyJSONResult, len(requests))
165	keyIDs := make([][]KeyID, len(requests))
166
167	// Store the initial number of requests that were made. We'll remove
168	// things from the requests array that we no longer need, but we later
169	// need to check that we satisfied the full number of requests.
170	numRequests := len(requests)
171
172	for i := range requests {
173		ids, err := ListKeyIDs(string(requests[i].ServerName), requests[i].Message)
174		if err != nil {
175			results[i].Error = fmt.Errorf("gomatrixserverlib: error extracting key IDs")
176			continue
177		}
178		for _, keyID := range ids {
179			if k.isAlgorithmSupported(keyID) {
180				keyIDs[i] = append(keyIDs[i], keyID)
181			}
182		}
183		if len(keyIDs[i]) == 0 {
184			results[i].Error = fmt.Errorf(
185				"gomatrixserverlib: not signed by %q with a supported algorithm", requests[i].ServerName,
186			)
187			continue
188		}
189		// Set a place holder error in the result field.
190		// This will be unset if one of the signature checks passes.
191		// This will be overwritten if one of the signature checks fails.
192		// Therefore this will only remain in place if the keys couldn't be downloaded.
193		results[i].Error = fmt.Errorf(
194			"gomatrixserverlib: could not download key for %q", requests[i].ServerName,
195		)
196	}
197
198	keyRequests := k.publicKeyRequests(requests, results, keyIDs)
199	if len(keyRequests) == 0 {
200		// There aren't any keys to fetch so we can stop here.
201		// This will happen if all the objects are missing supported signatures.
202		return results, nil
203	}
204	keysFromDatabase, err := k.KeyDatabase.FetchKeys(ctx, keyRequests)
205	if err != nil {
206		return nil, err
207	}
208
209	keysFetched := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
210	now := AsTimestamp(time.Now())
211	for req, res := range keysFromDatabase {
212		if res.ExpiredTS != PublicKeyNotExpired {
213			// The key is expired - it's not going to change so just return
214			// it and don't bother requesting it again.
215			keysFetched[req] = res
216			delete(keyRequests, req)
217			continue
218		}
219		// The key isn't expired so include it in the results.
220		keysFetched[req] = res
221		// If the key is inside validity then we don't need to update it.
222		if now < res.ValidUntilTS && res.ExpiredTS == PublicKeyNotExpired {
223			delete(keyRequests, req)
224		}
225	}
226
227	if len(keysFetched) == numRequests {
228		// If our key requests are all satisfied then we can try performing
229		// a verification using our keys.
230		k.checkUsingKeys(requests, results, keyIDs, keysFetched)
231
232		// If we run into any errors when verifying using the keys that we
233		// have then we can hit federation and check for updated keys.
234		errored := false
235		for _, r := range results {
236			if r.Error != nil {
237				errored = true
238				break
239			}
240		}
241		if !errored {
242			return results, nil
243		}
244	}
245
246	for _, fetcher := range k.KeyFetchers {
247		// If we have all of the keys that we need now then we can
248		// break the loop.
249		if len(keyRequests) == 0 {
250			break
251		}
252
253		fetcherLogger := logger.WithField("fetcher", fetcher.FetcherName())
254
255		// TODO: Coalesce in-flight requests for the same keys.
256		// Otherwise we risk spamming the servers we query the keys from.
257
258		fetcherLogger.WithField("num_key_requests", len(keyRequests)).
259			Info("Requesting keys from fetcher")
260
261		fetched, err := fetcher.FetchKeys(ctx, keyRequests)
262		if err != nil {
263			fetcherLogger.WithError(err).Warn("Failed to request keys from fetcher")
264			continue
265		}
266
267		if len(fetched) == 0 {
268			fetcherLogger.Warn("Failed to retrieve any keys")
269			continue
270		}
271
272		fetcherLogger.WithField("num_keys_fetched", len(fetched)).
273			Info("Got keys from fetcher")
274
275		// Hold the new keys and remove them from the request queue.
276		for req, res := range fetched {
277			keysFetched[req] = res
278			delete(keyRequests, req)
279		}
280	}
281
282	// Now that we've fetched all of the keys we need, try to check
283	// if the requests are valid.
284	k.checkUsingKeys(requests, results, keyIDs, keysFetched)
285
286	// Add the keys to the database so that we won't need to fetch them again.
287	if err := k.KeyDatabase.StoreKeys(ctx, keysFetched); err != nil {
288		return nil, err
289	}
290
291	return results, nil
292}
293
294func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
295	return strings.HasPrefix(string(keyID), "ed25519:")
296}
297
298func (k *KeyRing) publicKeyRequests(
299	requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
300) map[PublicKeyLookupRequest]Timestamp {
301	keyRequests := map[PublicKeyLookupRequest]Timestamp{}
302	for i := range requests {
303		if results[i].Error == nil {
304			// We've already verified this message, we don't need to refetch the keys for it.
305			continue
306		}
307		for _, keyID := range keyIDs[i] {
308			k := PublicKeyLookupRequest{requests[i].ServerName, keyID}
309			// Grab the maximum neeeded TS for this server and key ID.
310			// This will default to 0 if the server and keyID weren't in the map.
311			maxTS := keyRequests[k]
312			if maxTS <= requests[i].AtTS {
313				// We clobber on equality since that means that if the server and keyID
314				// weren't already in the map and since AtTS is unsigned and since the
315				// default value for maxTS is 0 we will always insert an entry for the
316				// server and keyID.
317				keyRequests[k] = requests[i].AtTS
318			}
319		}
320	}
321	return keyRequests
322}
323
324func (k *KeyRing) checkUsingKeys(
325	requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
326	keys map[PublicKeyLookupRequest]PublicKeyLookupResult,
327) {
328	for i := range requests {
329		if results[i].Error == nil {
330			// We've already checked this message and it passed the signature checks.
331			// So we can skip to the next message.
332			continue
333		}
334		for _, keyID := range keyIDs[i] {
335			serverKey, ok := keys[PublicKeyLookupRequest{requests[i].ServerName, keyID}]
336			if !ok {
337				// No key for this key ID so we continue onto the next key ID.
338				continue
339			}
340			if !serverKey.WasValidAt(requests[i].AtTS, requests[i].StrictValidityChecking) {
341				// The key wasn't valid at the timestamp we needed it to be valid at.
342				// So skip onto the next key.
343				results[i].Error = fmt.Errorf(
344					"gomatrixserverlib: key with ID %q for %q not valid at %d",
345					keyID, requests[i].ServerName, requests[i].AtTS,
346				)
347				continue
348			}
349			if err := VerifyJSON(
350				string(requests[i].ServerName), keyID, ed25519.PublicKey(serverKey.Key), requests[i].Message,
351			); err != nil {
352				// The signature wasn't valid, record the error and try the next key ID.
353				results[i].Error = err
354				continue
355			}
356			// The signature is valid, set the result to nil.
357			results[i].Error = nil
358			break
359		}
360	}
361}
362
// KeyClient is the client used by the key fetchers in this file to retrieve
// server keys.
type KeyClient interface {
	// GetServerKeys requests the keys of matrixServer. DirectKeyFetcher calls
	// it with the name of the server whose keys it wants.
	GetServerKeys(ctx context.Context, matrixServer ServerName) (ServerKeys, error)
	// LookupServerKeys asks matrixServer for the keys named in keyRequests.
	// PerspectiveKeyFetcher calls it with the perspective server's name, so
	// matrixServer is the server being queried rather than necessarily the
	// owner of the keys.
	LookupServerKeys(ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyLookupRequest]Timestamp) ([]ServerKeys, error)
}
367
// A PerspectiveKeyFetcher fetches server keys from a single perspective server.
// Every response must carry a valid signature from one of the configured
// perspective server keys before it is accepted.
type PerspectiveKeyFetcher struct {
	// The name of the perspective server to fetch keys from.
	PerspectiveServerName ServerName
	// The ed25519 public keys the perspective server must sign responses with.
	PerspectiveServerKeys map[KeyID]ed25519.PublicKey
	// The federation client to use to fetch keys with.
	Client KeyClient
}
377
378// FetcherName implements KeyFetcher
379func (p PerspectiveKeyFetcher) FetcherName() string {
380	return fmt.Sprintf("perspective server %s", p.PerspectiveServerName)
381}
382
383// FetchKeys implements KeyFetcher
384func (p *PerspectiveKeyFetcher) FetchKeys(
385	ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp,
386) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
387	serverKeys, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
388	if err != nil {
389		return nil, fmt.Errorf("gomatrixserverlib: unable to lookup server keys: %w", err)
390	}
391
392	results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
393
394	for _, keys := range serverKeys {
395		var valid bool
396		keyIDs, err := ListKeyIDs(string(p.PerspectiveServerName), keys.Raw)
397		if err != nil {
398			// The response from the perspective server was corrupted.
399			return nil, fmt.Errorf("gomatrixserverlib: unable to list key IDs: %w", err)
400		}
401		for _, keyID := range keyIDs {
402			perspectiveKey, ok := p.PerspectiveServerKeys[keyID]
403			if !ok {
404				// We don't have a key for that keyID, skip to the next keyID.
405				continue
406			}
407			if err := VerifyJSON(string(p.PerspectiveServerName), keyID, perspectiveKey, keys.Raw); err != nil {
408				// An invalid signature is very bad since it means we have a
409				// problem talking to the perspective server.
410				return nil, fmt.Errorf("gomatrixserverlib: unable to verify response: %w", err)
411			}
412			valid = true
413			break
414		}
415		if !valid {
416			// This means we don't have a known signature from the perspective server.
417			return nil, fmt.Errorf("gomatrixserverlib: not signed with a known key for the perspective server")
418		}
419
420		// Check that the keys are valid for the server they claim to be
421		checks, _ := CheckKeys(keys.ServerName, time.Unix(0, 0), keys)
422		if !checks.AllChecksOK {
423			// This is bad because it means that the perspective server was trying to feed us an invalid response.
424			return nil, fmt.Errorf("gomatrixserverlib: key response from perspective server failed checks")
425		}
426
427		// TODO (matrix-org/dendrite#345): What happens if the same key ID
428		// appears in multiple responses?
429		// We should probably take the response with the highest valid_until_ts.
430		mapServerKeysToPublicKeyLookupResult(keys, results)
431	}
432
433	return results, nil
434}
435
// A DirectKeyFetcher fetches keys directly from a server.
// This may be suitable for local deployments that are firewalled from the public internet where DNS can be trusted.
type DirectKeyFetcher struct {
	// The federation client to use to fetch keys with.
	Client KeyClient
}
442
443// FetcherName implements KeyFetcher
444func (d DirectKeyFetcher) FetcherName() string {
445	return "DirectKeyFetcher"
446}
447
448// FetchKeys implements KeyFetcher
449func (d *DirectKeyFetcher) FetchKeys(
450	ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp,
451) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
452	fetcherLogger := util.GetLogger(ctx).WithField("fetcher", d.FetcherName())
453
454	byServer := map[ServerName]map[PublicKeyLookupRequest]Timestamp{}
455	for req, ts := range requests {
456		server := byServer[req.ServerName]
457		if server == nil {
458			server = map[PublicKeyLookupRequest]Timestamp{}
459			byServer[req.ServerName] = server
460		}
461		server[req] = ts
462	}
463
464	// Work out the number of workers that we want to start. If the
465	// number of outstanding requests is less than the current max
466	// then reduce it so we don't start workers unnecessarily.
467	numWorkers := 64
468	if len(byServer) < numWorkers {
469		numWorkers = len(byServer)
470	}
471
472	// Prepare somewhere to put the results. This map is protected
473	// by the below mutex.
474	results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
475	var resultsMutex sync.Mutex
476
477	// Populate the wait group with the number of workers.
478	var wait sync.WaitGroup
479	wait.Add(numWorkers)
480
481	// Populate the jobs queue.
482	pending := make(chan ServerName, len(byServer))
483	for serverName := range byServer {
484		pending <- serverName
485	}
486	close(pending)
487
488	// Define our worker.
489	worker := func(ch <-chan ServerName) {
490		defer wait.Done()
491		for server := range ch {
492			serverResults, err := d.fetchKeysForServer(ctx, server)
493			if err != nil {
494				serverResults, err = d.fetchNotaryKeysForServer(ctx, server)
495				if err != nil {
496					// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
497					fetcherLogger.WithError(err).Error("Failed to fetch key for server")
498					continue
499				}
500			}
501			resultsMutex.Lock()
502			for req, keys := range serverResults {
503				results[req] = keys
504			}
505			resultsMutex.Unlock()
506		}
507	}
508
509	// Start the workers.
510	for i := 0; i < numWorkers; i++ {
511		go worker(pending)
512	}
513
514	// Wait for the workers to finish before returning
515	// the results.
516	wait.Wait()
517	return results, nil
518}
519
520func (d *DirectKeyFetcher) fetchKeysForServer(
521	ctx context.Context, serverName ServerName,
522) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
523	ctx, cancel := context.WithTimeout(ctx, time.Second*15)
524	defer cancel()
525
526	keys, err := d.Client.GetServerKeys(ctx, serverName)
527	if err != nil {
528		if err != nil {
529			return nil, err
530		}
531	}
532	// Check that the keys are valid for the server.
533	checks, _ := CheckKeys(serverName, time.Unix(0, 0), keys)
534	if !checks.AllChecksOK {
535		return nil, fmt.Errorf("gomatrixserverlib: key response direct from %q failed checks", serverName)
536	}
537
538	results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
539
540	// TODO (matrix-org/dendrite#345): What happens if the same key ID
541	// appears in multiple responses? We should probably reject the response.
542	mapServerKeysToPublicKeyLookupResult(keys, results)
543
544	return results, nil
545}
546
547func (d *DirectKeyFetcher) fetchNotaryKeysForServer(
548	ctx context.Context, serverName ServerName,
549) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
550	ctx, cancel := context.WithTimeout(ctx, time.Second*15)
551	defer cancel()
552
553	var keys ServerKeys
554	allKeys, err := d.Client.LookupServerKeys(ctx, serverName, map[PublicKeyLookupRequest]Timestamp{
555		{serverName, ""}: AsTimestamp(time.Now()),
556	})
557	if err != nil {
558		return nil, err
559	}
560	found := false
561	for _, serverKeys := range allKeys {
562		if serverKeys.ServerName == serverName {
563			keys = serverKeys
564			found = true
565			break
566		}
567	}
568	if !found {
569		return nil, fmt.Errorf("gomatrixserverlib: notary key response contained no results for %q", serverName)
570	}
571	// Check that the keys are valid for the server.
572	checks, _ := CheckKeys(serverName, time.Unix(0, 0), keys)
573	if !checks.AllChecksOK {
574		return nil, fmt.Errorf("gomatrixserverlib: notary key response direct from %q failed checks", serverName)
575	}
576
577	results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
578
579	// TODO (matrix-org/dendrite#345): What happens if the same key ID
580	// appears in multiple responses? We should probably reject the response.
581	mapServerKeysToPublicKeyLookupResult(keys, results)
582
583	return results, nil
584}
585
586// mapServerKeysToPublicKeyLookupResult takes the (verified) result from a
587// /key/v2/query call and inserts it into a PublicKeyLookupRequest->PublicKeyLookupResult
588// map.
589func mapServerKeysToPublicKeyLookupResult(serverKeys ServerKeys, results map[PublicKeyLookupRequest]PublicKeyLookupResult) {
590	for keyID, key := range serverKeys.VerifyKeys {
591		results[PublicKeyLookupRequest{
592			ServerName: serverKeys.ServerName,
593			KeyID:      keyID,
594		}] = PublicKeyLookupResult{
595			VerifyKey:    key,
596			ValidUntilTS: serverKeys.ValidUntilTS,
597			ExpiredTS:    PublicKeyNotExpired,
598		}
599	}
600	for keyID, key := range serverKeys.OldVerifyKeys {
601		results[PublicKeyLookupRequest{
602			ServerName: serverKeys.ServerName,
603			KeyID:      keyID,
604		}] = PublicKeyLookupResult{
605			VerifyKey:    key.VerifyKey,
606			ValidUntilTS: PublicKeyNotValid,
607			ExpiredTS:    key.ExpiredTS,
608		}
609	}
610}
611