1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"bytes"
9	"crypto/rand"
10	"errors"
11	"fmt"
12	"net"
13	"runtime"
14	"strings"
15	"sync"
16	"testing"
17)
18
19type testChecker struct {
20	calls []string
21}
22
23func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
24	if dialAddr == "bad" {
25		return fmt.Errorf("dialAddr is bad")
26	}
27
28	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
29		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
30	}
31
32	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
33
34	return nil
35}
36
// netPipe is analogous to net.Pipe, but it is backed by a real TCP
// connection on the loopback interface and is therefore buffered
// (net.Pipe deadlocks if both sides start with a write). The listener is
// closed before returning; only the dialed and accepted ends are returned.
func netPipe() (net.Conn, net.Conn, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	dialed, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	accepted, err := ln.Accept()
	if err != nil {
		// Don't leak the dialed end when the accept fails.
		dialed.Close()
		return nil, nil, err
	}
	return dialed, accepted, nil
}
59
60func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
61	a, b, err := netPipe()
62	if err != nil {
63		return nil, nil, err
64	}
65
66	trC := newTransport(a, rand.Reader, true)
67	trS := newTransport(b, rand.Reader, false)
68	clientConf.SetDefaults()
69
70	v := []byte("version")
71	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
72
73	serverConf := &ServerConfig{}
74	serverConf.AddHostKey(testSigners["ecdsa"])
75	serverConf.AddHostKey(testSigners["rsa"])
76	serverConf.SetDefaults()
77	server = newServerTransport(trS, v, v, serverConf)
78
79	return client, server, nil
80}
81
82func TestHandshakeBasic(t *testing.T) {
83	if runtime.GOOS == "plan9" {
84		t.Skip("see golang.org/issue/7237")
85	}
86	checker := &testChecker{}
87	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
88	if err != nil {
89		t.Fatalf("handshakePair: %v", err)
90	}
91
92	defer trC.Close()
93	defer trS.Close()
94
95	go func() {
96		// Client writes a bunch of stuff, and does a key
97		// change in the middle. This should not confuse the
98		// handshake in progress
99		for i := 0; i < 10; i++ {
100			p := []byte{msgRequestSuccess, byte(i)}
101			if err := trC.writePacket(p); err != nil {
102				t.Fatalf("sendPacket: %v", err)
103			}
104			if i == 5 {
105				// halfway through, we request a key change.
106				_, _, err := trC.sendKexInit()
107				if err != nil {
108					t.Fatalf("sendKexInit: %v", err)
109				}
110			}
111		}
112		trC.Close()
113	}()
114
115	// Server checks that client messages come in cleanly
116	i := 0
117	for {
118		p, err := trS.readPacket()
119		if err != nil {
120			break
121		}
122		if p[0] == msgNewKeys {
123			continue
124		}
125		want := []byte{msgRequestSuccess, byte(i)}
126		if bytes.Compare(p, want) != 0 {
127			t.Errorf("message %d: got %q, want %q", i, p, want)
128		}
129		i++
130	}
131	if i != 10 {
132		t.Errorf("received %d messages, want 10.", i)
133	}
134
135	// If all went well, we registered exactly 1 key change.
136	if len(checker.calls) != 1 {
137		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
138	}
139
140	pub := testSigners["ecdsa"].PublicKey()
141	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
142	if want != checker.calls[0] {
143		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
144	}
145}
146
147func TestHandshakeError(t *testing.T) {
148	checker := &testChecker{}
149	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
150	if err != nil {
151		t.Fatalf("handshakePair: %v", err)
152	}
153	defer trC.Close()
154	defer trS.Close()
155
156	// send a packet
157	packet := []byte{msgRequestSuccess, 42}
158	if err := trC.writePacket(packet); err != nil {
159		t.Errorf("writePacket: %v", err)
160	}
161
162	// Now request a key change.
163	_, _, err = trC.sendKexInit()
164	if err != nil {
165		t.Errorf("sendKexInit: %v", err)
166	}
167
168	// the key change will fail, and afterwards we can't write.
169	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
170		t.Errorf("writePacket after botched rekey succeeded.")
171	}
172
173	readback, err := trS.readPacket()
174	if err != nil {
175		t.Fatalf("server closed too soon: %v", err)
176	}
177	if bytes.Compare(readback, packet) != 0 {
178		t.Errorf("got %q want %q", readback, packet)
179	}
180	readback, err = trS.readPacket()
181	if err == nil {
182		t.Errorf("got a message %q after failed key change", readback)
183	}
184}
185
186func TestHandshakeTwice(t *testing.T) {
187	checker := &testChecker{}
188	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
189	if err != nil {
190		t.Fatalf("handshakePair: %v", err)
191	}
192
193	defer trC.Close()
194	defer trS.Close()
195
196	// send a packet
197	packet := make([]byte, 5)
198	packet[0] = msgRequestSuccess
199	if err := trC.writePacket(packet); err != nil {
200		t.Errorf("writePacket: %v", err)
201	}
202
203	// Now request a key change.
204	_, _, err = trC.sendKexInit()
205	if err != nil {
206		t.Errorf("sendKexInit: %v", err)
207	}
208
209	// Send another packet. Use a fresh one, since writePacket destroys.
210	packet = make([]byte, 5)
211	packet[0] = msgRequestSuccess
212	if err := trC.writePacket(packet); err != nil {
213		t.Errorf("writePacket: %v", err)
214	}
215
216	// 2nd key change.
217	_, _, err = trC.sendKexInit()
218	if err != nil {
219		t.Errorf("sendKexInit: %v", err)
220	}
221
222	packet = make([]byte, 5)
223	packet[0] = msgRequestSuccess
224	if err := trC.writePacket(packet); err != nil {
225		t.Errorf("writePacket: %v", err)
226	}
227
228	packet = make([]byte, 5)
229	packet[0] = msgRequestSuccess
230	for i := 0; i < 5; i++ {
231		msg, err := trS.readPacket()
232		if err != nil {
233			t.Fatalf("server closed too soon: %v", err)
234		}
235		if msg[0] == msgNewKeys {
236			continue
237		}
238
239		if bytes.Compare(msg, packet) != 0 {
240			t.Errorf("packet %d: got %q want %q", i, msg, packet)
241		}
242	}
243	if len(checker.calls) != 2 {
244		t.Errorf("got %d key changes, want 2", len(checker.calls))
245	}
246}
247
248func TestHandshakeAutoRekeyWrite(t *testing.T) {
249	checker := &testChecker{}
250	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
251	clientConf.RekeyThreshold = 500
252	trC, trS, err := handshakePair(clientConf, "addr")
253	if err != nil {
254		t.Fatalf("handshakePair: %v", err)
255	}
256	defer trC.Close()
257	defer trS.Close()
258
259	for i := 0; i < 5; i++ {
260		packet := make([]byte, 251)
261		packet[0] = msgRequestSuccess
262		if err := trC.writePacket(packet); err != nil {
263			t.Errorf("writePacket: %v", err)
264		}
265	}
266
267	j := 0
268	for ; j < 5; j++ {
269		_, err := trS.readPacket()
270		if err != nil {
271			break
272		}
273	}
274
275	if j != 5 {
276		t.Errorf("got %d, want 5 messages", j)
277	}
278
279	if len(checker.calls) != 2 {
280		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
281	}
282}
283
284type syncChecker struct {
285	called chan int
286}
287
288func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
289	t.called <- 1
290	return nil
291}
292
293func TestHandshakeAutoRekeyRead(t *testing.T) {
294	sync := &syncChecker{make(chan int, 2)}
295	clientConf := &ClientConfig{
296		HostKeyCallback: sync.Check,
297	}
298	clientConf.RekeyThreshold = 500
299
300	trC, trS, err := handshakePair(clientConf, "addr")
301	if err != nil {
302		t.Fatalf("handshakePair: %v", err)
303	}
304	defer trC.Close()
305	defer trS.Close()
306
307	packet := make([]byte, 501)
308	packet[0] = msgRequestSuccess
309	if err := trS.writePacket(packet); err != nil {
310		t.Fatalf("writePacket: %v", err)
311	}
312	// While we read out the packet, a key change will be
313	// initiated.
314	if _, err := trC.readPacket(); err != nil {
315		t.Fatalf("readPacket(client): %v", err)
316	}
317
318	<-sync.called
319}
320
321// errorKeyingTransport generates errors after a given number of
322// read/write operations.
323type errorKeyingTransport struct {
324	packetConn
325	readLeft, writeLeft int
326}
327
328func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
329	return nil
330}
331func (n *errorKeyingTransport) getSessionID() []byte {
332	return nil
333}
334
335func (n *errorKeyingTransport) writePacket(packet []byte) error {
336	if n.writeLeft == 0 {
337		n.Close()
338		return errors.New("barf")
339	}
340
341	n.writeLeft--
342	return n.packetConn.writePacket(packet)
343}
344
345func (n *errorKeyingTransport) readPacket() ([]byte, error) {
346	if n.readLeft == 0 {
347		n.Close()
348		return nil, errors.New("barf")
349	}
350
351	n.readLeft--
352	return n.packetConn.readPacket()
353}
354
355func TestHandshakeErrorHandlingRead(t *testing.T) {
356	for i := 0; i < 20; i++ {
357		testHandshakeErrorHandlingN(t, i, -1)
358	}
359}
360
361func TestHandshakeErrorHandlingWrite(t *testing.T) {
362	for i := 0; i < 20; i++ {
363		testHandshakeErrorHandlingN(t, -1, i)
364	}
365}
366
367// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
368// handshakeTransport deadlocks, the go runtime will detect it and
369// panic.
370func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
371	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
372
373	a, b := memPipe()
374	defer a.Close()
375	defer b.Close()
376
377	key := testSigners["ecdsa"]
378	serverConf := Config{RekeyThreshold: minRekeyThreshold}
379	serverConf.SetDefaults()
380	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
381	serverConn.hostKeys = []Signer{key}
382	go serverConn.readLoop()
383
384	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
385	clientConf.SetDefaults()
386	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
387	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
388	go clientConn.readLoop()
389
390	var wg sync.WaitGroup
391	wg.Add(4)
392
393	for _, hs := range []packetConn{serverConn, clientConn} {
394		go func(c packetConn) {
395			for {
396				err := c.writePacket(msg)
397				if err != nil {
398					break
399				}
400			}
401			wg.Done()
402		}(hs)
403		go func(c packetConn) {
404			for {
405				_, err := c.readPacket()
406				if err != nil {
407					break
408				}
409			}
410			wg.Done()
411		}(hs)
412	}
413
414	wg.Wait()
415}
416