1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package knownhosts
6
7import (
8	"bytes"
9	"fmt"
10	"net"
11	"reflect"
12	"testing"
13
14	"golang.org/x/crypto/ssh"
15)
16
17const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
18const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
19const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
20
21var ecKey, alternateEdKey, edKey ssh.PublicKey
22var testAddr = &net.TCPAddr{
23	IP:   net.IP{198, 41, 30, 196},
24	Port: 22,
25}
26
27var testAddr6 = &net.TCPAddr{
28	IP: net.IP{198, 41, 30, 196,
29		1, 2, 3, 4,
30		1, 2, 3, 4,
31		1, 2, 3, 4,
32	},
33	Port: 22,
34}
35
36func init() {
37	var err error
38	ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
39	if err != nil {
40		panic(err)
41	}
42	edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
43	if err != nil {
44		panic(err)
45	}
46	alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
47	if err != nil {
48		panic(err)
49	}
50}
51
52func testDB(t *testing.T, s string) *hostKeyDB {
53	db := newHostKeyDB()
54	if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil {
55		t.Fatalf("Read: %v", err)
56	}
57
58	return db
59}
60
61func TestRevoked(t *testing.T) {
62	db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
63	want := &RevokedError{
64		Revoked: KnownKey{
65			Key:      edKey,
66			Filename: "testdb",
67			Line:     3,
68		},
69	}
70	if err := db.check("", &net.TCPAddr{
71		Port: 42,
72	}, edKey); err == nil {
73		t.Fatal("no error for revoked key")
74	} else if !reflect.DeepEqual(want, err) {
75		t.Fatalf("got %#v, want %#v", want, err)
76	}
77}
78
79func TestHostAuthority(t *testing.T) {
80	for _, m := range []struct {
81		authorityFor string
82		address      string
83
84		good bool
85	}{
86		{authorityFor: "localhost", address: "localhost:22", good: true},
87		{authorityFor: "localhost", address: "localhost", good: false},
88		{authorityFor: "localhost", address: "localhost:1234", good: false},
89		{authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
90		{authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
91		{authorityFor: "[localhost]:1234", address: "localhost", good: false},
92	} {
93		db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
94		if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good {
95			t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
96				m.authorityFor, m.address, m.good, ok)
97		}
98	}
99}
100
101func TestBracket(t *testing.T) {
102	db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
103
104	if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
105		IP:   net.IP{198, 41, 30, 196},
106		Port: 29418,
107	}, edKey); err != nil {
108		t.Errorf("got error %v, want none", err)
109	}
110
111	if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
112		Port: 42,
113	}, edKey); err == nil {
114		t.Fatalf("no error for unknown address")
115	} else if ke, ok := err.(*KeyError); !ok {
116		t.Fatalf("got type %T, want *KeyError", err)
117	} else if len(ke.Want) > 0 {
118		t.Fatalf("got Want %v, want []", ke.Want)
119	}
120}
121
122func TestNewKeyType(t *testing.T) {
123	str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
124	db := testDB(t, str)
125	if err := db.check("", testAddr, ecKey); err == nil {
126		t.Fatalf("no error for unknown address")
127	} else if ke, ok := err.(*KeyError); !ok {
128		t.Fatalf("got type %T, want *KeyError", err)
129	} else if len(ke.Want) == 0 {
130		t.Fatalf("got empty KeyError.Want")
131	}
132}
133
134func TestSameKeyType(t *testing.T) {
135	str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
136	db := testDB(t, str)
137	if err := db.check("", testAddr, alternateEdKey); err == nil {
138		t.Fatalf("no error for unknown address")
139	} else if ke, ok := err.(*KeyError); !ok {
140		t.Fatalf("got type %T, want *KeyError", err)
141	} else if len(ke.Want) == 0 {
142		t.Fatalf("got empty KeyError.Want")
143	} else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
144		t.Fatalf("got key %q, want %q", got, want)
145	}
146}
147
148func TestIPAddress(t *testing.T) {
149	str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
150	db := testDB(t, str)
151	if err := db.check("", testAddr, edKey); err != nil {
152		t.Errorf("got error %q, want none", err)
153	}
154}
155
156func TestIPv6Address(t *testing.T) {
157	str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
158	db := testDB(t, str)
159
160	if err := db.check("", testAddr6, edKey); err != nil {
161		t.Errorf("got error %q, want none", err)
162	}
163}
164
165func TestBasic(t *testing.T) {
166	str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
167	db := testDB(t, str)
168	if err := db.check("server.org:22", testAddr, edKey); err != nil {
169		t.Errorf("got error %v, want none", err)
170	}
171
172	want := KnownKey{
173		Key:      edKey,
174		Filename: "testdb",
175		Line:     3,
176	}
177	if err := db.check("server.org:22", testAddr, ecKey); err == nil {
178		t.Errorf("succeeded, want KeyError")
179	} else if ke, ok := err.(*KeyError); !ok {
180		t.Errorf("got %T, want *KeyError", err)
181	} else if len(ke.Want) != 1 {
182		t.Errorf("got %v, want 1 entry", ke)
183	} else if !reflect.DeepEqual(ke.Want[0], want) {
184		t.Errorf("got %v, want %v", ke.Want[0], want)
185	}
186}
187
188func TestHostNamePrecedence(t *testing.T) {
189	var evilAddr = &net.TCPAddr{
190		IP:   net.IP{66, 66, 66, 66},
191		Port: 22,
192	}
193
194	str := fmt.Sprintf("server.org,%s %s\nevil.org,%s %s", testAddr, edKeyStr, evilAddr, ecKeyStr)
195	db := testDB(t, str)
196
197	if err := db.check("server.org:22", evilAddr, ecKey); err == nil {
198		t.Errorf("check succeeded")
199	} else if _, ok := err.(*KeyError); !ok {
200		t.Errorf("got %T, want *KeyError", err)
201	}
202}
203
204func TestDBOrderingPrecedenceKeyType(t *testing.T) {
205	str := fmt.Sprintf("server.org,%s %s\nserver.org,%s %s", testAddr, edKeyStr, testAddr, alternateEdKeyStr)
206	db := testDB(t, str)
207
208	if err := db.check("server.org:22", testAddr, alternateEdKey); err == nil {
209		t.Errorf("check succeeded")
210	} else if _, ok := err.(*KeyError); !ok {
211		t.Errorf("got %T, want *KeyError", err)
212	}
213}
214
215func TestNegate(t *testing.T) {
216	str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
217	db := testDB(t, str)
218	if err := db.check("server.org:22", testAddr, ecKey); err == nil {
219		t.Errorf("succeeded")
220	} else if ke, ok := err.(*KeyError); !ok {
221		t.Errorf("got error type %T, want *KeyError", err)
222	} else if len(ke.Want) != 0 {
223		t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
224	}
225}
226
227func TestWildcard(t *testing.T) {
228	str := fmt.Sprintf("server*.domain %s", edKeyStr)
229	db := testDB(t, str)
230
231	want := &KeyError{
232		Want: []KnownKey{{
233			Filename: "testdb",
234			Line:     1,
235			Key:      edKey,
236		}},
237	}
238
239	got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
240	if !reflect.DeepEqual(got, want) {
241		t.Errorf("got %s, want %s", got, want)
242	}
243}
244
245func TestLine(t *testing.T) {
246	for in, want := range map[string]string{
247		"server.org":                             "server.org " + edKeyStr,
248		"server.org:22":                          "server.org " + edKeyStr,
249		"server.org:23":                          "[server.org]:23 " + edKeyStr,
250		"[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr,
251		"[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr,
252	} {
253		if got := Line([]string{in}, edKey); got != want {
254			t.Errorf("Line(%q) = %q, want %q", in, got, want)
255		}
256	}
257}
258
259func TestWildcardMatch(t *testing.T) {
260	for _, c := range []struct {
261		pat, str string
262		want     bool
263	}{
264		{"a?b", "abb", true},
265		{"ab", "abc", false},
266		{"abc", "ab", false},
267		{"a*b", "axxxb", true},
268		{"a*b", "axbxb", true},
269		{"a*b", "axbxbc", false},
270		{"a*?", "axbxc", true},
271		{"a*b*", "axxbxxxxxx", true},
272		{"a*b*c", "axxbxxxxxxc", true},
273		{"a*b*?", "axxbxxxxxxc", true},
274		{"a*b*z", "axxbxxbxxxz", true},
275		{"a*b*z", "axxbxxzxxxz", true},
276		{"a*b*z", "axxbxxzxxx", false},
277	} {
278		got := wildcardMatch([]byte(c.pat), []byte(c.str))
279		if got != c.want {
280			t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
281		}
282
283	}
284}
285
286// TODO(hanwen): test coverage for certificates.
287
const (
	// testHostname is the plain host name used by the hashing tests.
	testHostname = "hostname"

	// encodedTestHostnameHash is testHostname in hashed form, generated with
	// `ssh-keygen -H -f`.
	encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
)
292
293func TestHostHash(t *testing.T) {
294	testHostHash(t, testHostname, encodedTestHostnameHash)
295}
296
297func TestHashList(t *testing.T) {
298	encoded := HashHostname(testHostname)
299	testHostHash(t, testHostname, encoded)
300}
301
302func testHostHash(t *testing.T, hostname, encoded string) {
303	typ, salt, hash, err := decodeHash(encoded)
304	if err != nil {
305		t.Fatalf("decodeHash: %v", err)
306	}
307
308	if got := encodeHash(typ, salt, hash); got != encoded {
309		t.Errorf("got encoding %s want %s", got, encoded)
310	}
311
312	if typ != sha1HashType {
313		t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
314	}
315
316	got := hashHost(hostname, salt)
317	if !bytes.Equal(got, hash) {
318		t.Errorf("got hash %x want %x", got, hash)
319	}
320}
321
322func TestNormalize(t *testing.T) {
323	for in, want := range map[string]string{
324		"127.0.0.1:22":             "127.0.0.1",
325		"[127.0.0.1]:22":           "127.0.0.1",
326		"[127.0.0.1]:23":           "[127.0.0.1]:23",
327		"127.0.0.1:23":             "[127.0.0.1]:23",
328		"[a.b.c]:22":               "a.b.c",
329		"[abcd:abcd:abcd:abcd]":    "[abcd:abcd:abcd:abcd]",
330		"[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
331		"[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
332	} {
333		got := Normalize(in)
334		if got != want {
335			t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
336		}
337	}
338}
339
340func TestHashedHostkeyCheck(t *testing.T) {
341	str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
342	db := testDB(t, str)
343	if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
344		t.Errorf("check(%s): %v", testHostname, err)
345	}
346	want := &KeyError{
347		Want: []KnownKey{{
348			Filename: "testdb",
349			Line:     1,
350			Key:      edKey,
351		}},
352	}
353	if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
354		t.Errorf("got error %v, want %v", got, want)
355	}
356}
357