1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package libkb
5
6import (
7	"encoding/json"
8	"fmt"
9	"reflect"
10	"sort"
11	"testing"
12	"time"
13
14	keybase1 "github.com/keybase/client/go/protocol/keybase1"
15	jsonw "github.com/keybase/go-jsonw"
16	testvectors "github.com/keybase/keybase-test-vectors/go"
17	"github.com/stretchr/testify/require"
18)
19
20// Returns a map from error name strings to sets of Go error types. If a test
21// returns any error type in the corresponding set, it's a pass. (The reason
22// the types aren't one-to-one here is that implementation differences between
23// the Go and JS sigchains make that more trouble than it's worth.)
24func getErrorTypesMap() map[string]map[reflect.Type]bool {
25	return map[string]map[reflect.Type]bool{
26		"CTIME_MISMATCH": {
27			reflect.TypeOf(CtimeMismatchError{}): true,
28		},
29		"EXPIRED_SIBKEY": {
30			reflect.TypeOf(KeyExpiredError{}): true,
31		},
32		"FINGERPRINT_MISMATCH": {
33			reflect.TypeOf(ChainLinkFingerprintMismatchError{}): true,
34		},
35		"INVALID_SIBKEY": {
36			reflect.TypeOf(KeyRevokedError{}): true,
37		},
38		"NO_KEY_WITH_THIS_HASH": {
39			reflect.TypeOf(NoKeyError{}): true,
40		},
41		"KEY_OWNERSHIP": {
42			reflect.TypeOf(KeyFamilyError{}): true,
43		},
44		"KID_MISMATCH": {
45			reflect.TypeOf(ChainLinkKIDMismatchError{}): true,
46		},
47		"NONEXISTENT_KID": {
48			reflect.TypeOf(KeyFamilyError{}): true,
49		},
50		"NOT_LATEST_SUBCHAIN": {
51			reflect.TypeOf(NotLatestSubchainError{}): true,
52		},
53		"REVERSE_SIG_VERIFY_FAILED": {
54			reflect.TypeOf(ReverseSigError{}): true,
55		},
56		"VERIFY_FAILED": {
57			reflect.TypeOf(BadSigError{}): true,
58		},
59		"WRONG_UID": {
60			reflect.TypeOf(UIDMismatchError{}): true,
61		},
62		"WRONG_USERNAME": {
63			reflect.TypeOf(BadUsernameError{}): true,
64		},
65		"WRONG_SEQNO": {
66			reflect.TypeOf(ChainLinkWrongSeqnoError{}): true,
67		},
68		"WRONG_PREV": {
69			reflect.TypeOf(ChainLinkPrevHashMismatchError{}): true,
70		},
71		"BAD_CHAIN_LINK": {
72			reflect.TypeOf(ChainLinkError{}): true,
73		},
74		"CHAIN_LINK_STUBBED_UNSUPPORTED": {
75			reflect.TypeOf(ChainLinkStubbedUnsupportedError{}): true,
76		},
77		"SIGCHAIN_V2_STUBBED_SIGNATURE_NEEDED": {
78			reflect.TypeOf(SigchainV2StubbedSignatureNeededError{}): true,
79		},
80		"SIGCHAIN_V2_STUBBED_FIRST_LINK": {
81			reflect.TypeOf(SigchainV2StubbedFirstLinkError{}): true,
82		},
83		"SIGCHAIN_V2_MISMATCHED_FIELD": {
84			reflect.TypeOf(SigchainV2MismatchedFieldError{}): true,
85		},
86		"SIGCHAIN_V2_MISMATCHED_HASH": {
87			reflect.TypeOf(SigchainV2MismatchedHashError{}): true,
88		},
89		"WRONG_PER_USER_KEY_REVERSE_SIG": {
90			reflect.TypeOf(ReverseSigError{}): true,
91		},
92	}
93}
94
95type subchainSummary struct {
96	EldestSeqno keybase1.Seqno `json:"eldest_seqno"`
97	Sibkeys     int            `json:"sibkeys"`
98	Subkeys     int            `json:"subkeys"`
99}
100
101// One of the test cases from the JSON list of all tests.
102type TestCase struct {
103	Input         string            `json:"input"`
104	Len           int               `json:"len"`
105	Sibkeys       int               `json:"sibkeys"`
106	Subkeys       int               `json:"subkeys"`
107	ErrType       string            `json:"err_type"`
108	Eldest        string            `json:"eldest"`
109	EldestSeqno   *keybase1.Seqno   `json:"eldest_seqno,omitempty"`
110	PrevSubchains []subchainSummary `json:"previous_subchains,omitempty"`
111}
112
113// The JSON list of all test cases.
114type TestList struct {
115	Tests      map[string]TestCase `json:"tests"`
116	ErrorTypes []string            `json:"error_types"`
117}
118
// TestInput is the structured part of one test's input JSON file (each test
// has its own file).
type TestInput struct {
	// The "chain" member is intentionally absent: the links are needed in
	// blob form, so they are read through a separate untyped parse.
	Username string `json:"username"`
	// UID is the user ID in hex form.
	UID string `json:"uid"`
	// Keys holds the key bundles; absent an explicit eldest label, the first
	// entry is taken to be the eldest key.
	Keys []string `json:"keys"`
	// LabelKids maps labels to KID strings.
	LabelKids map[string]string `json:"label_kids"`
	// LabelSigs is the "label_sigs" object of the input.
	LabelSigs map[string]string `json:"label_sigs"`
}
128
129func TestAllChains(t *testing.T) {
130	tc := SetupTest(t, "test_all_chains", 1)
131	defer tc.Cleanup()
132
133	var testList TestList
134	err := json.Unmarshal([]byte(testvectors.ChainTests), &testList)
135	require.NoError(t, err, "failed to unmarshal the chain tests")
136	// Always do the tests in alphabetical order.
137	testNames := []string{}
138	for name := range testList.Tests {
139		testNames = append(testNames, name)
140	}
141	sort.Strings(testNames)
142	for _, name := range testNames {
143		testCase := testList.Tests[name]
144		tc.G.Log.Info("starting sigchain test case %s (%s)", name, testCase.Input)
145		doChainTest(t, tc, testCase)
146	}
147}
148
149func doChainTest(t *testing.T, tc TestContext, testCase TestCase) {
150	inputJSON, exists := testvectors.ChainTestInputs[testCase.Input]
151	if !exists {
152		t.Fatal("missing test input: " + testCase.Input)
153	}
154	// Unmarshal test input in two ways: once for the structured data and once
155	// for the chain link blobs.
156	var input TestInput
157	err := json.Unmarshal([]byte(inputJSON), &input)
158	if err != nil {
159		t.Fatal(err)
160	}
161	inputBlob, err := jsonw.Unmarshal([]byte(inputJSON))
162	if err != nil {
163		t.Fatal(err)
164	}
165	uid, err := UIDFromHex(input.UID)
166	if err != nil {
167		t.Fatal(err)
168	}
169	chainLen, err := inputBlob.AtKey("chain").Len()
170	if err != nil {
171		t.Fatal(err)
172	}
173
174	// Get the eldest key. This is assumed to be the first key in the list of
175	// bundles, unless the "eldest" field is given in the test description, in
176	// which case the eldest key is specified by name.
177	var eldestKID keybase1.KID
178	if testCase.Eldest == "" {
179		eldestKey, _, err := ParseGenericKey(input.Keys[0])
180		if err != nil {
181			t.Fatal(err)
182		}
183		eldestKID = eldestKey.GetKID()
184	} else {
185		eldestKIDStr, found := input.LabelKids[testCase.Eldest]
186		if !found {
187			t.Fatalf("No KID found for label %s", testCase.Eldest)
188		}
189		eldestKID = keybase1.KIDFromString(eldestKIDStr)
190	}
191
192	// Parse all the key bundles.
193	keyFamily, err := createKeyFamily(tc.G, input.Keys)
194	if err != nil {
195		t.Fatal(err)
196	}
197
198	// Run the actual sigchain parsing and verification. This is most of the
199	// code that's actually being tested.
200	var sigchainErr error
201	m := NewMetaContextForTest(tc)
202	ckf := ComputedKeyFamily{Contextified: NewContextified(tc.G), kf: keyFamily}
203	sigchain := SigChain{
204		username:          NewNormalizedUsername(input.Username),
205		uid:               uid,
206		loadedFromLinkOne: true,
207		Contextified:      NewContextified(tc.G),
208	}
209	for i := 0; i < chainLen; i++ {
210		linkBlob := inputBlob.AtKey("chain").AtIndex(i)
211		rawLinkBlob, err := linkBlob.Marshal()
212		if err != nil {
213			sigchainErr = err
214			break
215		}
216		link, err := ImportLinkFromServer(m, &sigchain, rawLinkBlob, uid)
217		if err != nil {
218			sigchainErr = err
219			break
220		}
221		require.Equal(t, keybase1.SeqType_PUBLIC, link.unpacked.seqType, "all user chains are public")
222		if link.unpacked.outerLinkV2 != nil {
223			require.Equal(t, link.unpacked.outerLinkV2.SeqType, link.unpacked.seqType, "inner-outer seq_type match")
224		}
225		sigchain.chainLinks = append(sigchain.chainLinks, link)
226	}
227	if sigchainErr == nil {
228		_, sigchainErr = sigchain.VerifySigsAndComputeKeys(NewMetaContextForTest(tc), eldestKID, &ckf, uid)
229	}
230
231	// Some tests expect an error. If we get one, make sure it's the right
232	// type.
233	if testCase.ErrType != "" {
234		if sigchainErr == nil {
235			t.Fatalf("Expected %s error from VerifySigsAndComputeKeys. No error returned.", testCase.ErrType)
236		}
237		foundType := reflect.TypeOf(sigchainErr)
238		expectedTypes := getErrorTypesMap()[testCase.ErrType]
239		if len(expectedTypes) == 0 {
240			msg := "No Go error types defined for expected failure %s.\n" +
241				"This could be because of new test cases in github.com/keybase/keybase-test-vectors.\n" +
242				"Go error returned: %s"
243			t.Fatalf(msg, testCase.ErrType, foundType)
244		}
245		if expectedTypes[foundType] {
246			// Success! We found the error we expected. This test is done.
247			tc.G.Log.Debug("EXPECTED error encountered: %s", sigchainErr)
248			return
249		}
250
251		// Got an error, but one of the wrong type. Tests with error names
252		// that are missing from the map (maybe because we add new test
253		// cases in the future) will also hit this branch.
254		t.Fatalf("Wrong error type encountered. Expected %v (%s), got %s: %s",
255			expectedTypes, testCase.ErrType, foundType, sigchainErr)
256
257	}
258
259	// Tests that expected an error terminated above. Tests that get here
260	// should succeed without errors.
261	if sigchainErr != nil {
262		t.Fatal(sigchainErr)
263	}
264
265	// Check the expected results: total unrevoked links, sibkeys, and subkeys.
266	unrevokedCount := 0
267
268	idtable, err := NewIdentityTable(NewMetaContextForTest(tc), eldestKID, &sigchain, nil)
269	if err != nil {
270		t.Fatal(err)
271	}
272	for _, link := range idtable.links {
273		if !link.IsDirectlyRevoked() {
274			unrevokedCount++
275		}
276	}
277
278	fatalStr := ""
279	if unrevokedCount != testCase.Len {
280		fatalStr += fmt.Sprintf("Expected %d unrevoked links, but found %d.\n", testCase.Len, unrevokedCount)
281	}
282	if testCase.Len > 0 && sigchain.currentSubchainStart == 0 {
283		fatalStr += fmt.Sprintf("Expected nonzero currentSubchainStart, but found %d.\n", sigchain.currentSubchainStart)
284	}
285	// Don't use the current time to get keys, because that will cause test
286	// failures 5 years from now :-D
287	testTime := getCurrentTimeForTest(sigchain, keyFamily)
288	numSibkeys := len(ckf.GetAllActiveSibkeysAtTime(testTime))
289	if numSibkeys != testCase.Sibkeys {
290		fatalStr += fmt.Sprintf("Expected %d sibkeys, got %d\n", testCase.Sibkeys, numSibkeys)
291	}
292	numSubkeys := len(ckf.GetAllActiveSubkeysAtTime(testTime))
293	if numSubkeys != testCase.Subkeys {
294		fatalStr += fmt.Sprintf("Expected %d subkeys, got %d\n", testCase.Subkeys, numSubkeys)
295	}
296
297	if fatalStr != "" {
298		t.Fatal(fatalStr)
299	}
300
301	if testCase.EldestSeqno != nil && sigchain.EldestSeqno() != *testCase.EldestSeqno {
302		t.Fatalf("wrong eldest seqno: wanted %d but got %d", *testCase.EldestSeqno, sigchain.EldestSeqno())
303	}
304	if testCase.PrevSubchains != nil {
305		if len(testCase.PrevSubchains) != len(sigchain.prevSubchains) {
306			t.Fatalf("wrong number of historical subchains; wanted %d but got %d", len(testCase.PrevSubchains), len(sigchain.prevSubchains))
307		}
308		for i, expected := range testCase.PrevSubchains {
309			received := sigchain.prevSubchains[i]
310			if received.EldestSeqno() != expected.EldestSeqno {
311				t.Fatalf("For historical subchain %d, wrong eldest seqno; wanted %d but got %d", i, expected.EldestSeqno, received.EldestSeqno())
312			}
313			ckf := ComputedKeyFamily{kf: keyFamily, cki: received.GetComputedKeyInfos()}
314			n := len(ckf.GetAllSibkeysUnchecked())
315			if n != expected.Sibkeys {
316				t.Fatalf("For historical subchain %d, wrong number of sibkeys; wanted %d but got %d", i, expected.Sibkeys, n)
317			}
318			m := len(ckf.GetAllSubkeysUnchecked())
319			if m != expected.Subkeys {
320				t.Fatalf("For historical subchain %d, wrong number of subkeys; wanted %d but got %d", i, expected.Sibkeys, m)
321			}
322		}
323	}
324
325	storeAndLoad(t, tc, &sigchain)
326	// Success!
327}
328
329func storeAndLoad(t *testing.T, tc TestContext, chain *SigChain) {
330	err := chain.Store(NewMetaContextForTest(tc))
331	if err != nil {
332		t.Fatal(err)
333	}
334	sgl := SigChainLoader{
335		user: &User{
336			name: chain.username.String(),
337			id:   chain.uid,
338		},
339		self: false,
340		leaf: &MerkleUserLeaf{
341			public: chain.GetCurrentTailTriple(),
342			uid:    chain.uid,
343		},
344		chainType:        PublicChain,
345		MetaContextified: NewMetaContextified(NewMetaContextForTest(tc)),
346	}
347	sgl.chain = chain
348	sgl.dirtyTail = chain.GetCurrentTailTriple()
349	err = sgl.Store()
350	if err != nil {
351		t.Fatal(err)
352	}
353	sgl.chain = nil
354	sgl.dirtyTail = nil
355	var sc2 *SigChain
356	// Reset the link cache so that we're sure our loads hits storage.
357	tc.G.cacheMu.Lock()
358	tc.G.linkCache = NewLinkCache(1000, time.Hour)
359	tc.G.cacheMu.Unlock()
360	sc2, err = sgl.Load()
361	if err != nil {
362		t.Fatal(err)
363	}
364
365	// Loading sigchains from cache doesn't benefit from knowing the current
366	// eldest KID from the Merkle tree. That means if the account just reset,
367	// for example, loading from cache will still produce the old subchain
368	// start. Avoid failing on this case by skipping the comparison when
369	// `currentSubchainStart` is 0 (invalid) in the original chain.
370	if chain.currentSubchainStart == 0 {
371		// As described above, short circuit when we know loading from cache
372		// would give us a different answer.
373		return
374	}
375	if chain.currentSubchainStart != sc2.currentSubchainStart {
376		t.Fatalf("disagreement about currentSubchainStart: %d != %d", chain.currentSubchainStart, sc2.currentSubchainStart)
377	}
378	if len(chain.chainLinks) != len(sc2.chainLinks) {
379		t.Fatalf("subchains don't have the same length: %d != %d", len(chain.chainLinks), len(sc2.chainLinks))
380	}
381	for i := 0; i < len(chain.chainLinks); i++ {
382		if chain.chainLinks[i].GetSeqno() != sc2.chainLinks[i].GetSeqno() {
383			t.Fatalf("stored and loaded chains mismatched links: %d != %d", chain.chainLinks[i].GetSeqno(), sc2.chainLinks[i].GetSeqno())
384		}
385	}
386}
387
388func createKeyFamily(g *GlobalContext, bundles []string) (*KeyFamily, error) {
389	allKeys := jsonw.NewArray(len(bundles))
390	for i, bundle := range bundles {
391		err := allKeys.SetIndex(i, jsonw.NewString(bundle))
392		if err != nil {
393			return nil, err
394		}
395	}
396	publicKeys := jsonw.NewDictionary()
397	err := publicKeys.SetKey("all_bundles", allKeys)
398	if err != nil {
399		return nil, err
400	}
401	return ParseKeyFamily(g, publicKeys)
402}
403
404func getCurrentTimeForTest(sigChain SigChain, keyFamily *KeyFamily) time.Time {
405	// Pick a test time that's the latest ctime of all links and PGP keys.
406	var t time.Time
407	for _, link := range sigChain.chainLinks {
408		linkCTime := time.Unix(link.unpacked.ctime, 0)
409		if linkCTime.After(t) {
410			t = linkCTime
411		}
412	}
413	for _, ks := range keyFamily.PGPKeySets {
414		keyCTime := ks.PermissivelyMergedKey.PrimaryKey.CreationTime
415		if keyCTime.After(t) {
416			t = keyCTime
417		}
418	}
419	return t
420}
421