1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"crypto/rand"
9	"errors"
10	"fmt"
11	"io"
12	"log"
13	"net"
14	"sync"
15)
16
const (
	// debugHandshake, when true, logs every packet the handshake layer
	// sends or receives. Key exchange messages are printed as if DH were
	// in use, so the log output is misleading when ECDH is negotiated.
	debugHandshake = false

	// chanSize is the buffer size used for internal packet channels such
	// as handshakeTransport.incoming. It exists mainly for testing:
	// setting it to 0 makes deadlocks show up more quickly.
	chanSize = 16
)
26
27// keyingTransport is a packet based transport that supports key
28// changes. It need not be thread-safe. It should pass through
29// msgNewKeys in both directions.
30type keyingTransport interface {
31	packetConn
32
33	// prepareKeyChange sets up a key change. The key change for a
34	// direction will be effected if a msgNewKeys message is sent
35	// or received.
36	prepareKeyChange(*algorithms, *kexResult) error
37}
38
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
//
// Two goroutines drive it: readLoop pulls packets from conn and hands
// kexInit packets to kexLoop via startKex; kexLoop sends our kexInit and
// runs each key exchange. Application writes go through writePacket.
type handshakeTransport struct {
	// conn is the underlying transport; every packet is ultimately read
	// from and written to it.
	conn   keyingTransport
	// config supplies the algorithm preferences, RekeyThreshold and the
	// Rand source used during key exchange.
	config *Config

	// Version strings of both sides, passed to the key exchange through
	// handshakeMagics.
	serverVersion []byte
	clientVersion []byte

	// hostKeys is non-empty if we are the server. In that case,
	// it contains all host keys that can be used to sign the
	// connection.
	hostKeys []Signer

	// hostKeyAlgorithms is non-empty if we are the client. In that case,
	// we accept these key types from the server as host key.
	hostKeyAlgorithms []string

	// incoming carries packets from readLoop to readPacket (msgIgnore and
	// msgDebug are filtered out). On read error, readError is set first
	// and then incoming is closed, so readPacket can report the error.
	incoming  chan []byte
	readError error

	// mu serializes writers and guards the write-side state below.
	// writeError is sticky: once set, writePacket keeps returning it.
	// sentInitPacket/sentInitMsg hold our kexInit while a key exchange is
	// in progress; while sentInitMsg is non-nil, writePacket queues
	// copies of outgoing packets in pendingPackets instead of sending.
	mu             sync.Mutex
	writeError     error
	sentInitPacket []byte
	sentInitMsg    *kexInitMsg
	pendingPackets [][]byte // Used when a key exchange is in progress.

	// If the read loop wants to schedule a kex, it pings this
	// channel, and the write loop will send out a kex
	// message.
	requestKex chan struct{}

	// If the other side requests or confirms a kex, its kexInit
	// packet is sent here for the write loop to find it.
	startKex chan *pendingKex

	// data for host key checking
	hostKeyCallback HostKeyCallback
	dialAddress     string
	remoteAddr      net.Addr

	// bannerCallback is non-empty if we are the client and it has been set in
	// ClientConfig. In that case it is called during the user authentication
	// dance to handle a custom server's message.
	bannerCallback BannerCallback

	// Algorithms agreed in the last key exchange; nil until the first
	// negotiation in enterKeyExchange.
	algorithms *algorithms

	// Read-side rekey budget. Decremented by readOnePacket (read loop
	// goroutine) and reset after each completed key exchange.
	readPacketsLeft uint32
	readBytesLeft   int64

	// Write-side rekey budget. Decremented by writePacket and reset by
	// kexLoop, both while holding mu.
	writePacketsLeft uint32
	writeBytesLeft   int64

	// The session ID or nil if first kex did not complete yet.
	// NOTE(review): accessed without mu; callers of getSessionID appear to
	// rely on only reading it after waitSession has returned — confirm.
	sessionID []byte
}
98
// pendingKex is a key exchange request passed from the read loop to
// kexLoop over handshakeTransport.startKex.
type pendingKex struct {
	// otherInit is the peer's raw kexInit packet.
	otherInit []byte
	// done receives the outcome of the key exchange (nil on success).
	// readOnePacket creates it with a buffer of one.
	done      chan error
}
103
104func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
105	t := &handshakeTransport{
106		conn:          conn,
107		serverVersion: serverVersion,
108		clientVersion: clientVersion,
109		incoming:      make(chan []byte, chanSize),
110		requestKex:    make(chan struct{}, 1),
111		startKex:      make(chan *pendingKex, 1),
112
113		config: config,
114	}
115	t.resetReadThresholds()
116	t.resetWriteThresholds()
117
118	// We always start with a mandatory key exchange.
119	t.requestKex <- struct{}{}
120	return t
121}
122
// newClientTransport creates the client side of a handshake transport.
// dialAddr and addr are later passed to the host key callback. When
// config.HostKeyAlgorithms is nil, every supported host key algorithm is
// offered. The read and kex loops are running when this returns.
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
	t.dialAddress, t.remoteAddr = dialAddr, addr
	t.hostKeyCallback = config.HostKeyCallback
	t.bannerCallback = config.BannerCallback

	algos := config.HostKeyAlgorithms
	if algos == nil {
		algos = supportedHostKeyAlgos
	}
	t.hostKeyAlgorithms = algos

	go t.readLoop()
	go t.kexLoop()
	return t
}
138
139func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
140	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
141	t.hostKeys = config.hostKeys
142	go t.readLoop()
143	go t.kexLoop()
144	return t
145}
146
147func (t *handshakeTransport) getSessionID() []byte {
148	return t.sessionID
149}
150
// waitSession blocks until the initial key exchange has finished and
// must be the first call made on a new handshakeTransport. It consumes
// the synthetic msgNewKeys packet that the read side emits once the
// first kex succeeds; any other first packet is an error.
func (t *handshakeTransport) waitSession() error {
	packet, err := t.readPacket()
	switch {
	case err != nil:
		return err
	case packet[0] != msgNewKeys:
		return errors.New("ssh: first packet should be msgNewKeys")
	default:
		return nil
	}
}
164
165func (t *handshakeTransport) id() string {
166	if len(t.hostKeys) > 0 {
167		return "server"
168	}
169	return "client"
170}
171
// printPacket logs p for debugging. write selects between the "sent" and
// "got" verbs. Channel data packets are logged by size only; any other
// message is decoded and printed together with the decode error, if any.
func (t *handshakeTransport) printPacket(p []byte, write bool) {
	verb := "got"
	if write {
		verb = "sent"
	}

	switch p[0] {
	case msgChannelData, msgChannelExtendedData:
		log.Printf("%s %s data (packet %d bytes)", t.id(), verb, len(p))
	default:
		decoded, decodeErr := decode(p)
		log.Printf("%s %s %T %v (%v)", t.id(), verb, decoded, decoded, decodeErr)
	}
}
185
186func (t *handshakeTransport) readPacket() ([]byte, error) {
187	p, ok := <-t.incoming
188	if !ok {
189		return nil, t.readError
190	}
191	return p, nil
192}
193
194func (t *handshakeTransport) readLoop() {
195	first := true
196	for {
197		p, err := t.readOnePacket(first)
198		first = false
199		if err != nil {
200			t.readError = err
201			close(t.incoming)
202			break
203		}
204		if p[0] == msgIgnore || p[0] == msgDebug {
205			continue
206		}
207		t.incoming <- p
208	}
209
210	// Stop writers too.
211	t.recordWriteError(t.readError)
212
213	// Unblock the writer should it wait for this.
214	close(t.startKex)
215
216	// Don't close t.requestKex; it's also written to from writePacket.
217}
218
219func (t *handshakeTransport) pushPacket(p []byte) error {
220	if debugHandshake {
221		t.printPacket(p, true)
222	}
223	return t.conn.writePacket(p)
224}
225
226func (t *handshakeTransport) getWriteError() error {
227	t.mu.Lock()
228	defer t.mu.Unlock()
229	return t.writeError
230}
231
232func (t *handshakeTransport) recordWriteError(err error) {
233	t.mu.Lock()
234	defer t.mu.Unlock()
235	if t.writeError == nil && err != nil {
236		t.writeError = err
237	}
238}
239
240func (t *handshakeTransport) requestKeyExchange() {
241	select {
242	case t.requestKex <- struct{}{}:
243	default:
244		// something already requested a kex, so do nothing.
245	}
246}
247
248func (t *handshakeTransport) resetWriteThresholds() {
249	t.writePacketsLeft = packetRekeyThreshold
250	if t.config.RekeyThreshold > 0 {
251		t.writeBytesLeft = int64(t.config.RekeyThreshold)
252	} else if t.algorithms != nil {
253		t.writeBytesLeft = t.algorithms.w.rekeyBytes()
254	} else {
255		t.writeBytesLeft = 1 << 30
256	}
257}
258
// kexLoop runs for the lifetime of the transport and performs every key
// exchange. Each iteration of the outer loop:
//   - waits until we have both sent our kexInit and received the peer's
//     (delivered by the read loop on t.startKex); a local request on
//     t.requestKex only causes our kexInit to be sent;
//   - runs the exchange via enterKeyExchange while the read loop is
//     blocked waiting on request.done;
//   - under t.mu: stores the result as the write error, clears our
//     kexInit, resets the write budget, drains t.requestKex, reports the
//     result to the read loop, and flushes packets queued by writePacket.
//
// The loop ends once a write error is recorded or t.startKex is closed.
// Later kex requests are then answered with t.writeError by a background
// goroutine, and the connection is closed so that the reader unblocks.
func (t *handshakeTransport) kexLoop() {

write:
	for t.getWriteError() == nil {
		var request *pendingKex
		var sent bool

		// Wait until our kexInit is sent AND the peer's has arrived.
		for request == nil || !sent {
			var ok bool
			select {
			case request, ok = <-t.startKex:
				if !ok {
					// The read loop closed startKex: it has stopped.
					break write
				}
			case <-t.requestKex:
				// Local rekey request. This break only leaves the
				// select; the kexInit is sent below.
				break
			}

			if !sent {
				if err := t.sendKexInit(); err != nil {
					t.recordWriteError(err)
					// Leaves the inner loop; the write-error check
					// below then terminates kexLoop.
					break
				}
				sent = true
			}
		}

		if err := t.getWriteError(); err != nil {
			// Wake a read loop waiting for this kex, if there is one.
			if request != nil {
				request.done <- err
			}
			break
		}

		// We're not servicing t.requestKex, but that is OK:
		// we never block on sending to t.requestKex.

		// We're not servicing t.startKex, but the remote end
		// has just sent us a kexInitMsg, so it can't send
		// another key change request, until we close the done
		// channel on the pendingKex request.

		err := t.enterKeyExchange(request.otherInit)

		t.mu.Lock()
		t.writeError = err
		// Clearing sentInitMsg ends the "kex in progress" state that
		// makes writePacket queue packets.
		t.sentInitPacket = nil
		t.sentInitMsg = nil

		t.resetWriteThresholds()

		// We have completed the key exchange. Since the reader is
		// still blocked, it is safe to clear out the requestKex
		// channel. This avoids the situation where (1) we consumed
		// our own request for the initial kex, and (2) the kex from
		// the remote side caused another send on the requestKex
		// channel. Such a stale request would otherwise make the
		// next iteration start a spurious key exchange.
	clear:
		for {
			select {
			case <-t.requestKex:
				//
			default:
				break clear
			}
		}

		// Report the outcome to the read loop. done is created with a
		// buffer of one in readOnePacket.
		request.done <- t.writeError

		// kex finished. Push packets that we received while the kex
		// was in progress. They are written with pushPacket, so they
		// are not charged against the write budget, and t.startKex is
		// not consulted: triggering another kex while we are still
		// busy with the last one would make things very confusing.
		for _, p := range t.pendingPackets {
			t.writeError = t.pushPacket(p)
			if t.writeError != nil {
				break
			}
		}
		// Truncate but keep the backing array for the next kex.
		t.pendingPackets = t.pendingPackets[:0]
		t.mu.Unlock()
	}

	// drain startKex channel. We don't service t.requestKex
	// because nobody does blocking sends there.
	// NOTE(review): t.writeError is read here without t.mu. This looks
	// safe only because it is already non-nil when the loop exits and is
	// no longer changed afterwards — confirm.
	go func() {
		for init := range t.startKex {
			init.done <- t.writeError
		}
	}()

	// Unblock reader.
	t.conn.Close()
}
354
// packetRekeyThreshold is how many packets may pass in one direction
// before a rekey is requested. The protocol counts packets in a uint32,
// so the count must never reach 1<<32. Stopping at 1<<31 leaves headroom:
// the peer may keep sending after we hit the limit, and the key exchange
// itself adds a few more packets on our side.
const packetRekeyThreshold = 1 << 31
361
362func (t *handshakeTransport) resetReadThresholds() {
363	t.readPacketsLeft = packetRekeyThreshold
364	if t.config.RekeyThreshold > 0 {
365		t.readBytesLeft = int64(t.config.RekeyThreshold)
366	} else if t.algorithms != nil {
367		t.readBytesLeft = t.algorithms.r.rekeyBytes()
368	} else {
369		t.readBytesLeft = 1 << 30
370	}
371}
372
// readOnePacket reads the next packet from the connection. It charges the
// packet against the read-side rekey budget and requests a kex once that
// budget is spent. When first is true, the packet must be a kexInit.
//
// A kexInit is handed to kexLoop through t.startKex, and the call blocks
// until that key exchange finishes. On success, the kex is reported
// upward as a synthetic packet: msgNewKeys for the very first exchange
// (which waitSession waits for), and msgIgnore for later ones.
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
	p, err := t.conn.readPacket()
	if err != nil {
		return nil, err
	}

	if t.readPacketsLeft == 0 {
		t.requestKeyExchange()
	} else {
		t.readPacketsLeft--
	}

	if t.readBytesLeft <= 0 {
		t.requestKeyExchange()
	} else {
		t.readBytesLeft -= int64(len(p))
	}

	if debugHandshake {
		t.printPacket(p, false)
	}

	if p[0] != msgKexInit {
		if first {
			return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
		}
		return p, nil
	}

	// Decide this before kexLoop gets a chance to set sessionID.
	firstKex := t.sessionID == nil

	req := &pendingKex{
		otherInit: p,
		done:      make(chan error, 1),
	}
	t.startKex <- req
	err = <-req.done

	if debugHandshake {
		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
	}
	if err != nil {
		return nil, err
	}

	t.resetReadThresholds()

	if firstKex {
		// Surfacing msgNewKeys lets waitSession know the initial kex is
		// done, so authentication only begins over the encrypted link.
		return []byte{msgNewKeys}, nil
	}
	// Later key exchanges stay hidden from higher layers.
	return []byte{msgIgnore}, nil
}
434
// sendKexInit sends our kexInit message, unless one is already outstanding
// for the current key exchange. A kexInit may be sent in response to the
// peer or because our side wants to rekey, so a second call made before
// the exchange completes is a no-op.
//
// The marshalled packet and the message are stored in t.sentInitPacket and
// t.sentInitMsg, because enterKeyExchange needs both. A non-nil
// sentInitMsg also makes writePacket queue outgoing packets until kexLoop
// clears it.
//
// It returns an error if the random cookie cannot be generated or the
// packet cannot be written; in that case no kexInit is recorded as sent.
func (t *handshakeTransport) sendKexInit() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.sentInitMsg != nil {
		// A kexInit for this exchange has already been sent.
		return nil
	}

	msg := &kexInitMsg{
		KexAlgos:                t.config.KeyExchanges,
		CiphersClientServer:     t.config.Ciphers,
		CiphersServerClient:     t.config.Ciphers,
		MACsClientServer:        t.config.MACs,
		MACsServerClient:        t.config.MACs,
		CompressionClientServer: supportedCompressions,
		CompressionServerClient: supportedCompressions,
	}
	// RFC 4253 section 7.1 requires the cookie to be random. If the RNG
	// fails, abort rather than send a predictable (all-zero) cookie.
	if _, err := io.ReadFull(rand.Reader, msg.Cookie[:]); err != nil {
		return err
	}

	if len(t.hostKeys) > 0 {
		// Server: advertise the type of every configured host key.
		for _, k := range t.hostKeys {
			msg.ServerHostKeyAlgos = append(
				msg.ServerHostKeyAlgos, k.PublicKey().Type())
		}
	} else {
		// Client: advertise the host key types we are willing to accept.
		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
	}
	packet := Marshal(msg)

	// writePacket destroys the contents, so save a copy.
	packetCopy := make([]byte, len(packet))
	copy(packetCopy, packet)

	if err := t.pushPacket(packetCopy); err != nil {
		return err
	}

	t.sentInitMsg = msg
	t.sentInitPacket = packet

	return nil
}
481
// writePacket sends p to the peer and is safe for concurrent use.
//
// Only the handshake layer may send kexInit and newKeys, so those message
// types are rejected. While a key exchange is in progress (sentInitMsg is
// non-nil), a copy of p is queued for kexLoop to flush afterwards, so the
// caller may reuse p right away. Otherwise p is written immediately and
// counted against the write-side rekey budget; once that budget is spent,
// a key exchange is requested.
//
// Write errors are sticky. The call whose write fails returns the error,
// and so does every later call.
func (t *handshakeTransport) writePacket(p []byte) error {
	switch p[0] {
	case msgKexInit:
		return errors.New("ssh: only handshakeTransport can send kexInit")
	case msgNewKeys:
		return errors.New("ssh: only handshakeTransport can send newKeys")
	}

	t.mu.Lock()
	defer t.mu.Unlock()
	if t.writeError != nil {
		return t.writeError
	}

	if t.sentInitMsg != nil {
		// Copy the packet so the writer can reuse the buffer.
		cp := make([]byte, len(p))
		copy(cp, p)
		t.pendingPackets = append(t.pendingPackets, cp)
		return nil
	}

	if t.writeBytesLeft > 0 {
		t.writeBytesLeft -= int64(len(p))
	} else {
		t.requestKeyExchange()
	}

	if t.writePacketsLeft > 0 {
		t.writePacketsLeft--
	} else {
		t.requestKeyExchange()
	}

	if err := t.pushPacket(p); err != nil {
		// Make the failure sticky for later writers, and report it to
		// this caller rather than claiming the packet was sent.
		t.writeError = err
		return err
	}

	return nil
}
522
523func (t *handshakeTransport) Close() error {
524	return t.conn.Close()
525}
526
// enterKeyExchange performs one key exchange, given the peer's raw kexInit
// packet. kexLoop only calls it after sendKexInit has stored our own
// kexInit in t.sentInitPacket and t.sentInitMsg. The steps are:
//   - decode the peer's kexInit;
//   - assign roles (we are the server iff host keys are configured);
//   - negotiate algorithms into t.algorithms;
//   - discard a wrongly guessed first kex packet;
//   - run the kex;
//   - fix the session ID on the first exchange;
//   - install the new keys and exchange msgNewKeys.
//
// It reads and writes t.conn directly, bypassing readLoop and writePacket.
// The read loop is blocked waiting for this kex, and writePacket queues
// packets while sentInitMsg is set.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
	if debugHandshake {
		log.Printf("%s entered key exchange", t.id())
	}

	otherInit := &kexInitMsg{}
	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
		return err
	}

	// Assume we are the server; swapped below if we are the client.
	magics := handshakeMagics{
		clientVersion: t.clientVersion,
		serverVersion: t.serverVersion,
		clientKexInit: otherInitPacket,
		serverKexInit: t.sentInitPacket,
	}

	clientInit := otherInit
	serverInit := t.sentInitMsg
	if len(t.hostKeys) == 0 {
		// No host keys: we are the client.
		clientInit, serverInit = serverInit, clientInit

		magics.clientKexInit = t.sentInitPacket
		magics.serverKexInit = otherInitPacket
	}

	var err error
	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
	if err != nil {
		return err
	}

	// We don't send FirstKexFollows, but we handle receiving it.
	//
	// RFC 4253 section 7 defines the kex and the agreement method for
	// first_kex_packet_follows. It states that the guessed packet
	// should be ignored if the "kex algorithm and/or the host
	// key algorithm is guessed wrong (server and client have
	// different preferred algorithm), or if any of the other
	// algorithms cannot be agreed upon". The other algorithms have
	// already been checked above so the kex algorithm and host key
	// algorithm are checked here.
	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
		// other side sent a kex message for the wrong algorithm,
		// which we have to ignore.
		if _, err := t.conn.readPacket(); err != nil {
			return err
		}
	}

	kex, ok := kexAlgoMap[t.algorithms.kex]
	if !ok {
		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
	}

	var result *kexResult
	if len(t.hostKeys) > 0 {
		result, err = t.server(kex, t.algorithms, &magics)
	} else {
		result, err = t.client(kex, t.algorithms, &magics)
	}

	if err != nil {
		return err
	}

	// The session ID is the exchange hash H of the first key exchange
	// and is reused unchanged for every later rekey.
	if t.sessionID == nil {
		t.sessionID = result.H
	}
	result.SessionID = t.sessionID

	// Stage the new keys, then switch each direction over by sending and
	// receiving msgNewKeys.
	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
		return err
	}
	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
		return err
	}
	if packet, err := t.conn.readPacket(); err != nil {
		return err
	} else if packet[0] != msgNewKeys {
		return unexpectedMessageError(msgNewKeys, packet[0])
	}

	return nil
}
612
613func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
614	var hostKey Signer
615	for _, k := range t.hostKeys {
616		if algs.hostKey == k.PublicKey().Type() {
617			hostKey = k
618		}
619	}
620
621	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
622	return r, err
623}
624
625func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
626	result, err := kex.Client(t.conn, t.config.Rand, magics)
627	if err != nil {
628		return nil, err
629	}
630
631	hostKey, err := ParsePublicKey(result.HostKey)
632	if err != nil {
633		return nil, err
634	}
635
636	if err := verifyHostKeySignature(hostKey, result); err != nil {
637		return nil, err
638	}
639
640	err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
641	if err != nil {
642		return nil, err
643	}
644
645	return result, nil
646}
647