1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package knownhosts
6
7import (
8	"bytes"
9	"fmt"
10	"net"
11	"reflect"
12	"testing"
13
14	"golang.org/x/crypto/ssh"
15)
16
17const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
18const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
19const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
20
21var ecKey, alternateEdKey, edKey ssh.PublicKey
22var testAddr = &net.TCPAddr{
23	IP:   net.IP{198, 41, 30, 196},
24	Port: 22,
25}
26
27var testAddr6 = &net.TCPAddr{
28	IP: net.IP{198, 41, 30, 196,
29		1, 2, 3, 4,
30		1, 2, 3, 4,
31		1, 2, 3, 4,
32	},
33	Port: 22,
34}
35
36func init() {
37	var err error
38	ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
39	if err != nil {
40		panic(err)
41	}
42	edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
43	if err != nil {
44		panic(err)
45	}
46	alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
47	if err != nil {
48		panic(err)
49	}
50}
51
52func testDB(t *testing.T, s string) *hostKeyDB {
53	db := newHostKeyDB()
54	if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil {
55		t.Fatalf("Read: %v", err)
56	}
57
58	return db
59}
60
61func TestRevoked(t *testing.T) {
62	db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
63	want := &RevokedError{
64		Revoked: KnownKey{
65			Key:      edKey,
66			Filename: "testdb",
67			Line:     3,
68		},
69	}
70	if err := db.check("", &net.TCPAddr{
71		Port: 42,
72	}, edKey); err == nil {
73		t.Fatal("no error for revoked key")
74	} else if !reflect.DeepEqual(want, err) {
75		t.Fatalf("got %#v, want %#v", want, err)
76	}
77}
78
79func TestHostAuthority(t *testing.T) {
80	for _, m := range []struct {
81		authorityFor string
82		address      string
83
84		good bool
85	}{
86		{authorityFor: "localhost", address: "localhost:22", good: true},
87		{authorityFor: "localhost", address: "localhost", good: false},
88		{authorityFor: "localhost", address: "localhost:1234", good: false},
89		{authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
90		{authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
91		{authorityFor: "[localhost]:1234", address: "localhost", good: false},
92	} {
93		db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
94		if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good {
95			t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
96				m.authorityFor, m.address, m.good, ok)
97		}
98	}
99}
100
101func TestBracket(t *testing.T) {
102	db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
103
104	if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
105		IP:   net.IP{198, 41, 30, 196},
106		Port: 29418,
107	}, edKey); err != nil {
108		t.Errorf("got error %v, want none", err)
109	}
110
111	if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
112		Port: 42,
113	}, edKey); err == nil {
114		t.Fatalf("no error for unknown address")
115	} else if ke, ok := err.(*KeyError); !ok {
116		t.Fatalf("got type %T, want *KeyError", err)
117	} else if len(ke.Want) > 0 {
118		t.Fatalf("got Want %v, want []", ke.Want)
119	}
120}
121
122func TestNewKeyType(t *testing.T) {
123	str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
124	db := testDB(t, str)
125	if err := db.check("", testAddr, ecKey); err == nil {
126		t.Fatalf("no error for unknown address")
127	} else if ke, ok := err.(*KeyError); !ok {
128		t.Fatalf("got type %T, want *KeyError", err)
129	} else if len(ke.Want) == 0 {
130		t.Fatalf("got empty KeyError.Want")
131	}
132}
133
134func TestSameKeyType(t *testing.T) {
135	str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
136	db := testDB(t, str)
137	if err := db.check("", testAddr, alternateEdKey); err == nil {
138		t.Fatalf("no error for unknown address")
139	} else if ke, ok := err.(*KeyError); !ok {
140		t.Fatalf("got type %T, want *KeyError", err)
141	} else if len(ke.Want) == 0 {
142		t.Fatalf("got empty KeyError.Want")
143	} else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
144		t.Fatalf("got key %q, want %q", got, want)
145	}
146}
147
148func TestIPAddress(t *testing.T) {
149	str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
150	db := testDB(t, str)
151	if err := db.check("", testAddr, edKey); err != nil {
152		t.Errorf("got error %q, want none", err)
153	}
154}
155
156func TestIPv6Address(t *testing.T) {
157	str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
158	db := testDB(t, str)
159
160	if err := db.check("", testAddr6, edKey); err != nil {
161		t.Errorf("got error %q, want none", err)
162	}
163}
164
165func TestBasic(t *testing.T) {
166	str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
167	db := testDB(t, str)
168	if err := db.check("server.org:22", testAddr, edKey); err != nil {
169		t.Errorf("got error %q, want none", err)
170	}
171
172	want := KnownKey{
173		Key:      edKey,
174		Filename: "testdb",
175		Line:     3,
176	}
177	if err := db.check("server.org:22", testAddr, ecKey); err == nil {
178		t.Errorf("succeeded, want KeyError")
179	} else if ke, ok := err.(*KeyError); !ok {
180		t.Errorf("got %T, want *KeyError", err)
181	} else if len(ke.Want) != 1 {
182		t.Errorf("got %v, want 1 entry", ke)
183	} else if !reflect.DeepEqual(ke.Want[0], want) {
184		t.Errorf("got %v, want %v", ke.Want[0], want)
185	}
186}
187
188func TestNegate(t *testing.T) {
189	str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
190	db := testDB(t, str)
191	if err := db.check("server.org:22", testAddr, ecKey); err == nil {
192		t.Errorf("succeeded")
193	} else if ke, ok := err.(*KeyError); !ok {
194		t.Errorf("got error type %T, want *KeyError", err)
195	} else if len(ke.Want) != 0 {
196		t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
197	}
198}
199
200func TestWildcard(t *testing.T) {
201	str := fmt.Sprintf("server*.domain %s", edKeyStr)
202	db := testDB(t, str)
203
204	want := &KeyError{
205		Want: []KnownKey{{
206			Filename: "testdb",
207			Line:     1,
208			Key:      edKey,
209		}},
210	}
211
212	got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
213	if !reflect.DeepEqual(got, want) {
214		t.Errorf("got %s, want %s", got, want)
215	}
216}
217
218func TestLine(t *testing.T) {
219	for in, want := range map[string]string{
220		"server.org":                             "server.org " + edKeyStr,
221		"server.org:22":                          "server.org " + edKeyStr,
222		"server.org:23":                          "[server.org]:23 " + edKeyStr,
223		"[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr,
224		"[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr,
225	} {
226		if got := Line([]string{in}, edKey); got != want {
227			t.Errorf("Line(%q) = %q, want %q", in, got, want)
228		}
229	}
230}
231
232func TestWildcardMatch(t *testing.T) {
233	for _, c := range []struct {
234		pat, str string
235		want     bool
236	}{
237		{"a?b", "abb", true},
238		{"ab", "abc", false},
239		{"abc", "ab", false},
240		{"a*b", "axxxb", true},
241		{"a*b", "axbxb", true},
242		{"a*b", "axbxbc", false},
243		{"a*?", "axbxc", true},
244		{"a*b*", "axxbxxxxxx", true},
245		{"a*b*c", "axxbxxxxxxc", true},
246		{"a*b*?", "axxbxxxxxxc", true},
247		{"a*b*z", "axxbxxbxxxz", true},
248		{"a*b*z", "axxbxxzxxxz", true},
249		{"a*b*z", "axxbxxzxxx", false},
250	} {
251		got := wildcardMatch([]byte(c.pat), []byte(c.str))
252		if got != c.want {
253			t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
254		}
255
256	}
257}
258
259// TODO(hanwen): test coverage for certificates.
260
const (
	// testHostname is the plain host name used by the hashing tests.
	testHostname = "hostname"

	// encodedTestHostnameHash is testHostname in the hashed "|1|..." form
	// that decodeHash parses. It was generated externally with the
	// "-H -f" key-generation tool options.
	encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
)
265
266func TestHostHash(t *testing.T) {
267	testHostHash(t, testHostname, encodedTestHostnameHash)
268}
269
270func TestHashList(t *testing.T) {
271	encoded := HashHostname(testHostname)
272	testHostHash(t, testHostname, encoded)
273}
274
275func testHostHash(t *testing.T, hostname, encoded string) {
276	typ, salt, hash, err := decodeHash(encoded)
277	if err != nil {
278		t.Fatalf("decodeHash: %v", err)
279	}
280
281	if got := encodeHash(typ, salt, hash); got != encoded {
282		t.Errorf("got encoding %s want %s", got, encoded)
283	}
284
285	if typ != sha1HashType {
286		t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
287	}
288
289	got := hashHost(hostname, salt)
290	if !bytes.Equal(got, hash) {
291		t.Errorf("got hash %x want %x", got, hash)
292	}
293}
294
295func TestNormalize(t *testing.T) {
296	for in, want := range map[string]string{
297		"127.0.0.1:22":             "127.0.0.1",
298		"[127.0.0.1]:22":           "127.0.0.1",
299		"[127.0.0.1]:23":           "[127.0.0.1]:23",
300		"127.0.0.1:23":             "[127.0.0.1]:23",
301		"[a.b.c]:22":               "a.b.c",
302		"[abcd:abcd:abcd:abcd]":    "[abcd:abcd:abcd:abcd]",
303		"[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
304		"[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
305	} {
306		got := Normalize(in)
307		if got != want {
308			t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
309		}
310	}
311}
312
313func TestHashedHostkeyCheck(t *testing.T) {
314	str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
315	db := testDB(t, str)
316	if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
317		t.Errorf("check(%s): %v", testHostname, err)
318	}
319	want := &KeyError{
320		Want: []KnownKey{{
321			Filename: "testdb",
322			Line:     1,
323			Key:      edKey,
324		}},
325	}
326	if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
327		t.Errorf("got error %v, want %v", got, want)
328	}
329}
330