1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"bytes"
9	"crypto/rand"
10	"errors"
11	"fmt"
12	"io"
13	"net"
14	"reflect"
15	"runtime"
16	"strings"
17	"sync"
18	"testing"
19)
20
21type testChecker struct {
22	calls []string
23}
24
25func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
26	if dialAddr == "bad" {
27		return fmt.Errorf("dialAddr is bad")
28	}
29
30	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
31		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
32	}
33
34	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
35
36	return nil
37}
38
// netPipe returns the two ends of a real TCP connection over the loopback
// interface. It uses IPv4 if it can listen there and falls back to IPv6.
// Unlike net.Pipe, the result is buffered, so both sides may start with a
// write without deadlocking (net.Pipe deadlocks in that case).
//
// The temporary listener is always closed before returning. On error,
// no connection is left open.
func netPipe() (net.Conn, net.Conn, error) {
	var (
		ln  net.Listener
		err error
	)
	// Try IPv4 loopback first; only the last error is reported.
	for _, candidate := range []string{"127.0.0.1:0", "[::1]:0"} {
		if ln, err = net.Listen("tcp", candidate); err == nil {
			break
		}
	}
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	dialed, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	accepted, err := ln.Accept()
	if err != nil {
		dialed.Close()
		return nil, nil, err
	}

	return dialed, accepted, nil
}
64
65// noiseTransport inserts ignore messages to check that the read loop
66// and the key exchange filters out these messages.
67type noiseTransport struct {
68	keyingTransport
69}
70
71func (t *noiseTransport) writePacket(p []byte) error {
72	ignore := []byte{msgIgnore}
73	if err := t.keyingTransport.writePacket(ignore); err != nil {
74		return err
75	}
76	debug := []byte{msgDebug, 1, 2, 3}
77	if err := t.keyingTransport.writePacket(debug); err != nil {
78		return err
79	}
80
81	return t.keyingTransport.writePacket(p)
82}
83
84func addNoiseTransport(t keyingTransport) keyingTransport {
85	return &noiseTransport{t}
86}
87
88// handshakePair creates two handshakeTransports connected with each
89// other. If the noise argument is true, both transports will try to
90// confuse the other side by sending ignore and debug messages.
91func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
92	a, b, err := netPipe()
93	if err != nil {
94		return nil, nil, err
95	}
96
97	var trC, trS keyingTransport
98
99	trC = newTransport(a, rand.Reader, true)
100	trS = newTransport(b, rand.Reader, false)
101	if noise {
102		trC = addNoiseTransport(trC)
103		trS = addNoiseTransport(trS)
104	}
105	clientConf.SetDefaults()
106
107	v := []byte("version")
108	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
109
110	serverConf := &ServerConfig{}
111	serverConf.AddHostKey(testSigners["ecdsa"])
112	serverConf.AddHostKey(testSigners["rsa"])
113	serverConf.SetDefaults()
114	server = newServerTransport(trS, v, v, serverConf)
115
116	if err := server.waitSession(); err != nil {
117		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
118	}
119	if err := client.waitSession(); err != nil {
120		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
121	}
122
123	return client, server, nil
124}
125
126func TestHandshakeBasic(t *testing.T) {
127	if runtime.GOOS == "plan9" {
128		t.Skip("see golang.org/issue/7237")
129	}
130
131	checker := &syncChecker{
132		waitCall: make(chan int, 10),
133		called:   make(chan int, 10),
134	}
135
136	checker.waitCall <- 1
137	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
138	if err != nil {
139		t.Fatalf("handshakePair: %v", err)
140	}
141
142	defer trC.Close()
143	defer trS.Close()
144
145	// Let first kex complete normally.
146	<-checker.called
147
148	clientDone := make(chan int, 0)
149	gotHalf := make(chan int, 0)
150	const N = 20
151
152	go func() {
153		defer close(clientDone)
154		// Client writes a bunch of stuff, and does a key
155		// change in the middle. This should not confuse the
156		// handshake in progress. We do this twice, so we test
157		// that the packet buffer is reset correctly.
158		for i := 0; i < N; i++ {
159			p := []byte{msgRequestSuccess, byte(i)}
160			if err := trC.writePacket(p); err != nil {
161				t.Fatalf("sendPacket: %v", err)
162			}
163			if (i % 10) == 5 {
164				<-gotHalf
165				// halfway through, we request a key change.
166				trC.requestKeyExchange()
167
168				// Wait until we can be sure the key
169				// change has really started before we
170				// write more.
171				<-checker.called
172			}
173			if (i % 10) == 7 {
174				// write some packets until the kex
175				// completes, to test buffering of
176				// packets.
177				checker.waitCall <- 1
178			}
179		}
180	}()
181
182	// Server checks that client messages come in cleanly
183	i := 0
184	err = nil
185	for ; i < N; i++ {
186		var p []byte
187		p, err = trS.readPacket()
188		if err != nil {
189			break
190		}
191		if (i % 10) == 5 {
192			gotHalf <- 1
193		}
194
195		want := []byte{msgRequestSuccess, byte(i)}
196		if bytes.Compare(p, want) != 0 {
197			t.Errorf("message %d: got %v, want %v", i, p, want)
198		}
199	}
200	<-clientDone
201	if err != nil && err != io.EOF {
202		t.Fatalf("server error: %v", err)
203	}
204	if i != N {
205		t.Errorf("received %d messages, want 10.", i)
206	}
207
208	close(checker.called)
209	if _, ok := <-checker.called; ok {
210		// If all went well, we registered exactly 2 key changes: one
211		// that establishes the session, and one that we requested
212		// additionally.
213		t.Fatalf("got another host key checks after 2 handshakes")
214	}
215}
216
217func TestForceFirstKex(t *testing.T) {
218	// like handshakePair, but must access the keyingTransport.
219	checker := &testChecker{}
220	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
221	a, b, err := netPipe()
222	if err != nil {
223		t.Fatalf("netPipe: %v", err)
224	}
225
226	var trC, trS keyingTransport
227
228	trC = newTransport(a, rand.Reader, true)
229
230	// This is the disallowed packet:
231	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
232
233	// Rest of the setup.
234	trS = newTransport(b, rand.Reader, false)
235	clientConf.SetDefaults()
236
237	v := []byte("version")
238	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
239
240	serverConf := &ServerConfig{}
241	serverConf.AddHostKey(testSigners["ecdsa"])
242	serverConf.AddHostKey(testSigners["rsa"])
243	serverConf.SetDefaults()
244	server := newServerTransport(trS, v, v, serverConf)
245
246	defer client.Close()
247	defer server.Close()
248
249	// We setup the initial key exchange, but the remote side
250	// tries to send serviceRequestMsg in cleartext, which is
251	// disallowed.
252
253	if err := server.waitSession(); err == nil {
254		t.Errorf("server first kex init should reject unexpected packet")
255	}
256}
257
258func TestHandshakeAutoRekeyWrite(t *testing.T) {
259	checker := &syncChecker{
260		called:   make(chan int, 10),
261		waitCall: nil,
262	}
263	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
264	clientConf.RekeyThreshold = 500
265	trC, trS, err := handshakePair(clientConf, "addr", false)
266	if err != nil {
267		t.Fatalf("handshakePair: %v", err)
268	}
269	defer trC.Close()
270	defer trS.Close()
271
272	input := make([]byte, 251)
273	input[0] = msgRequestSuccess
274
275	done := make(chan int, 1)
276	const numPacket = 5
277	go func() {
278		defer close(done)
279		j := 0
280		for ; j < numPacket; j++ {
281			if p, err := trS.readPacket(); err != nil {
282				break
283			} else if !bytes.Equal(input, p) {
284				t.Errorf("got packet type %d, want %d", p[0], input[0])
285			}
286		}
287
288		if j != numPacket {
289			t.Errorf("got %d, want 5 messages", j)
290		}
291	}()
292
293	<-checker.called
294
295	for i := 0; i < numPacket; i++ {
296		p := make([]byte, len(input))
297		copy(p, input)
298		if err := trC.writePacket(p); err != nil {
299			t.Errorf("writePacket: %v", err)
300		}
301		if i == 2 {
302			// Make sure the kex is in progress.
303			<-checker.called
304		}
305
306	}
307	<-done
308}
309
310type syncChecker struct {
311	waitCall chan int
312	called   chan int
313}
314
315func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
316	c.called <- 1
317	if c.waitCall != nil {
318		<-c.waitCall
319	}
320	return nil
321}
322
323func TestHandshakeAutoRekeyRead(t *testing.T) {
324	sync := &syncChecker{
325		called:   make(chan int, 2),
326		waitCall: nil,
327	}
328	clientConf := &ClientConfig{
329		HostKeyCallback: sync.Check,
330	}
331	clientConf.RekeyThreshold = 500
332
333	trC, trS, err := handshakePair(clientConf, "addr", false)
334	if err != nil {
335		t.Fatalf("handshakePair: %v", err)
336	}
337	defer trC.Close()
338	defer trS.Close()
339
340	packet := make([]byte, 501)
341	packet[0] = msgRequestSuccess
342	if err := trS.writePacket(packet); err != nil {
343		t.Fatalf("writePacket: %v", err)
344	}
345
346	// While we read out the packet, a key change will be
347	// initiated.
348	done := make(chan int, 1)
349	go func() {
350		defer close(done)
351		if _, err := trC.readPacket(); err != nil {
352			t.Fatalf("readPacket(client): %v", err)
353		}
354
355	}()
356
357	<-done
358	<-sync.called
359}
360
361// errorKeyingTransport generates errors after a given number of
362// read/write operations.
363type errorKeyingTransport struct {
364	packetConn
365	readLeft, writeLeft int
366}
367
368func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
369	return nil
370}
371
372func (n *errorKeyingTransport) getSessionID() []byte {
373	return nil
374}
375
376func (n *errorKeyingTransport) writePacket(packet []byte) error {
377	if n.writeLeft == 0 {
378		n.Close()
379		return errors.New("barf")
380	}
381
382	n.writeLeft--
383	return n.packetConn.writePacket(packet)
384}
385
386func (n *errorKeyingTransport) readPacket() ([]byte, error) {
387	if n.readLeft == 0 {
388		n.Close()
389		return nil, errors.New("barf")
390	}
391
392	n.readLeft--
393	return n.packetConn.readPacket()
394}
395
396func TestHandshakeErrorHandlingRead(t *testing.T) {
397	for i := 0; i < 20; i++ {
398		testHandshakeErrorHandlingN(t, i, -1, false)
399	}
400}
401
402func TestHandshakeErrorHandlingWrite(t *testing.T) {
403	for i := 0; i < 20; i++ {
404		testHandshakeErrorHandlingN(t, -1, i, false)
405	}
406}
407
408func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
409	for i := 0; i < 20; i++ {
410		testHandshakeErrorHandlingN(t, i, -1, true)
411	}
412}
413
414func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
415	for i := 0; i < 20; i++ {
416		testHandshakeErrorHandlingN(t, -1, i, true)
417	}
418}
419
420// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
421// handshakeTransport deadlocks, the go runtime will detect it and
422// panic.
423func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
424	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
425
426	a, b := memPipe()
427	defer a.Close()
428	defer b.Close()
429
430	key := testSigners["ecdsa"]
431	serverConf := Config{RekeyThreshold: minRekeyThreshold}
432	serverConf.SetDefaults()
433	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
434	serverConn.hostKeys = []Signer{key}
435	go serverConn.readLoop()
436	go serverConn.kexLoop()
437
438	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
439	clientConf.SetDefaults()
440	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
441	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
442	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
443	go clientConn.readLoop()
444	go clientConn.kexLoop()
445
446	var wg sync.WaitGroup
447
448	for _, hs := range []packetConn{serverConn, clientConn} {
449		if !coupled {
450			wg.Add(2)
451			go func(c packetConn) {
452				for i := 0; ; i++ {
453					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
454					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
455					if err != nil {
456						break
457					}
458				}
459				wg.Done()
460				c.Close()
461			}(hs)
462			go func(c packetConn) {
463				for {
464					_, err := c.readPacket()
465					if err != nil {
466						break
467					}
468				}
469				wg.Done()
470			}(hs)
471		} else {
472			wg.Add(1)
473			go func(c packetConn) {
474				for {
475					_, err := c.readPacket()
476					if err != nil {
477						break
478					}
479					if err := c.writePacket(msg); err != nil {
480						break
481					}
482
483				}
484				wg.Done()
485			}(hs)
486		}
487	}
488	wg.Wait()
489}
490
491func TestDisconnect(t *testing.T) {
492	if runtime.GOOS == "plan9" {
493		t.Skip("see golang.org/issue/7237")
494	}
495	checker := &testChecker{}
496	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
497	if err != nil {
498		t.Fatalf("handshakePair: %v", err)
499	}
500
501	defer trC.Close()
502	defer trS.Close()
503
504	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
505	errMsg := &disconnectMsg{
506		Reason:  42,
507		Message: "such is life",
508	}
509	trC.writePacket(Marshal(errMsg))
510	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
511
512	packet, err := trS.readPacket()
513	if err != nil {
514		t.Fatalf("readPacket 1: %v", err)
515	}
516	if packet[0] != msgRequestSuccess {
517		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
518	}
519
520	_, err = trS.readPacket()
521	if err == nil {
522		t.Errorf("readPacket 2 succeeded")
523	} else if !reflect.DeepEqual(err, errMsg) {
524		t.Errorf("got error %#v, want %#v", err, errMsg)
525	}
526
527	_, err = trS.readPacket()
528	if err == nil {
529		t.Errorf("readPacket 3 succeeded")
530	}
531}
532
533func TestHandshakeRekeyDefault(t *testing.T) {
534	clientConf := &ClientConfig{
535		Config: Config{
536			Ciphers: []string{"aes128-ctr"},
537		},
538		HostKeyCallback: InsecureIgnoreHostKey(),
539	}
540	trC, trS, err := handshakePair(clientConf, "addr", false)
541	if err != nil {
542		t.Fatalf("handshakePair: %v", err)
543	}
544	defer trC.Close()
545	defer trS.Close()
546
547	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
548	trC.Close()
549
550	rgb := (1024 + trC.readBytesLeft) >> 30
551	wgb := (1024 + trC.writeBytesLeft) >> 30
552
553	if rgb != 64 {
554		t.Errorf("got rekey after %dG read, want 64G", rgb)
555	}
556	if wgb != 64 {
557		t.Errorf("got rekey after %dG write, want 64G", wgb)
558	}
559}
560