1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	"math/rand"
12	"net"
13	"strconv"
14	"strings"
15	"sync"
16	"time"
17)
18
19// Listen requests the remote peer open a listening socket on
20// addr. Incoming connections will be available by calling Accept on
21// the returned net.Listener. The listener must be serviced, or the
22// SSH connection may hang.
23// N must be "tcp", "tcp4", "tcp6", or "unix".
24func (c *Client) Listen(n, addr string) (net.Listener, error) {
25	switch n {
26	case "tcp", "tcp4", "tcp6":
27		laddr, err := net.ResolveTCPAddr(n, addr)
28		if err != nil {
29			return nil, err
30		}
31		return c.ListenTCP(laddr)
32	case "unix":
33		return c.ListenUnix(addr)
34	default:
35		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
36	}
37}
38
// Automatic port allocation is broken in OpenSSH before 6.0; see
// https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In particular,
// OpenSSH 5.9 sends a channelOpenMsg with port number 0 rather than the
// actual port number. This means you can never open two different
// listeners with auto-allocated ports. We work around this by requesting
// randomly chosen explicit ports, retrying a bounded number of times.
45
const openSSHPrefix = "OpenSSH_"

// portRandomizer picks candidate ports for autoPortListenWorkaround.
// ListenTCP may be called from several goroutines at once, and a bare
// *rand.Rand is not safe for concurrent use, so access is serialized.
var portRandomizer = &lockedRand{r: rand.New(rand.NewSource(time.Now().UnixNano()))}

// lockedRand wraps a *rand.Rand with a mutex so that it can be shared
// between goroutines.
type lockedRand struct {
	mu sync.Mutex
	r  *rand.Rand
}

// Intn returns a pseudo-random int in [0, n). It panics if n <= 0.
func (l *lockedRand) Intn(n int) int {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.r.Intn(n)
}
49
50// isBrokenOpenSSHVersion returns true if the given version string
51// specifies a version of OpenSSH that is known to have a bug in port
52// forwarding.
53func isBrokenOpenSSHVersion(versionStr string) bool {
54	i := strings.Index(versionStr, openSSHPrefix)
55	if i < 0 {
56		return false
57	}
58	i += len(openSSHPrefix)
59	j := i
60	for ; j < len(versionStr); j++ {
61		if versionStr[j] < '0' || versionStr[j] > '9' {
62			break
63		}
64	}
65	version, _ := strconv.Atoi(versionStr[i:j])
66	return version < 6
67}
68
69// autoPortListenWorkaround simulates automatic port allocation by
70// trying random ports repeatedly.
71func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
72	var sshListener net.Listener
73	var err error
74	const tries = 10
75	for i := 0; i < tries; i++ {
76		addr := *laddr
77		addr.Port = 1024 + portRandomizer.Intn(60000)
78		sshListener, err = c.ListenTCP(&addr)
79		if err == nil {
80			laddr.Port = addr.Port
81			return sshListener, err
82		}
83	}
84	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
85}
86
// RFC 4254 7.1
//
// channelForwardMsg is the request-specific data sent with the
// "tcpip-forward" and "cancel-tcpip-forward" global requests.
// NOTE(review): the field order presumably defines the wire layout
// produced by Marshal, and values are built positionally; do not reorder.
type channelForwardMsg struct {
	addr  string // address the peer should bind
	rport uint32 // port the peer should bind; 0 asks the peer to choose one
}
92
93// handleForwards starts goroutines handling forwarded connections.
94// It's called on first use by (*Client).ListenTCP to not launch
95// goroutines until needed.
96func (c *Client) handleForwards() {
97	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
98	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
99}
100
101// ListenTCP requests the remote peer open a listening socket
102// on laddr. Incoming connections will be available by calling
103// Accept on the returned net.Listener.
104func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
105	c.handleForwardsOnce.Do(c.handleForwards)
106	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
107		return c.autoPortListenWorkaround(laddr)
108	}
109
110	m := channelForwardMsg{
111		laddr.IP.String(),
112		uint32(laddr.Port),
113	}
114	// send message
115	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
116	if err != nil {
117		return nil, err
118	}
119	if !ok {
120		return nil, errors.New("ssh: tcpip-forward request denied by peer")
121	}
122
123	// If the original port was 0, then the remote side will
124	// supply a real port number in the response.
125	if laddr.Port == 0 {
126		var p struct {
127			Port uint32
128		}
129		if err := Unmarshal(resp, &p); err != nil {
130			return nil, err
131		}
132		laddr.Port = int(p.Port)
133	}
134
135	// Register this forward, using the port number we obtained.
136	ch := c.forwards.add(laddr)
137
138	return &tcpListener{laddr, c, ch}, nil
139}
140
// forwardList is the registry of active remote forwards. It maps the
// address each forward was registered under to the channel that feeds
// the corresponding listener. The embedded mutex guards entries; every
// method locks it before touching entries.
type forwardList struct {
	sync.Mutex
	entries []forwardEntry
}
147
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a listener (for example
// a tcpListener).
type forwardEntry struct {
	laddr net.Addr     // address the forward was registered under
	c     chan forward // delivers incoming connections; closed by remove/closeAll
}
154
// forward represents an incoming forwarded connection (a
// "forwarded-tcpip" or "forwarded-streamlocal@openssh.com" channel)
// waiting to be accepted by a listener.
//
// Note: forwardList matches entries by address (in add, remove and
// forward). The address passed should be the one given in the
// original forward request.
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}
162
163func (l *forwardList) add(addr net.Addr) chan forward {
164	l.Lock()
165	defer l.Unlock()
166	f := forwardEntry{
167		laddr: addr,
168		c:     make(chan forward, 1),
169	}
170	l.entries = append(l.entries, f)
171	return f.c
172}
173
// See RFC 4254, section 7.2
//
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel
// open request.
type forwardedTCPPayload struct {
	Addr       string // address that was connected (the forward's bound address)
	Port       uint32 // port that was connected
	OriginAddr string // originator IP address
	OriginPort uint32 // originator port
}
181
// parseTCPAddr converts an address/port pair received from the peer into
// a *net.TCPAddr. addr must be a literal IP address (host names are
// rejected), and port must lie in [1, 65535].
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
	if port < 1 || port > 65535 {
		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
	}
	if ip := net.ParseIP(addr); ip != nil {
		return &net.TCPAddr{IP: ip, Port: int(port)}, nil
	}
	return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
}
193
194func (l *forwardList) handleChannels(in <-chan NewChannel) {
195	for ch := range in {
196		var (
197			laddr net.Addr
198			raddr net.Addr
199			err   error
200		)
201		switch channelType := ch.ChannelType(); channelType {
202		case "forwarded-tcpip":
203			var payload forwardedTCPPayload
204			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
205				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
206				continue
207			}
208
209			// RFC 4254 section 7.2 specifies that incoming
210			// addresses should list the address, in string
211			// format. It is implied that this should be an IP
212			// address, as it would be impossible to connect to it
213			// otherwise.
214			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
215			if err != nil {
216				ch.Reject(ConnectionFailed, err.Error())
217				continue
218			}
219			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
220			if err != nil {
221				ch.Reject(ConnectionFailed, err.Error())
222				continue
223			}
224
225		case "forwarded-streamlocal@openssh.com":
226			var payload forwardedStreamLocalPayload
227			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
228				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
229				continue
230			}
231			laddr = &net.UnixAddr{
232				Name: payload.SocketPath,
233				Net:  "unix",
234			}
235			raddr = &net.UnixAddr{
236				Name: "@",
237				Net:  "unix",
238			}
239		default:
240			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
241		}
242		if ok := l.forward(laddr, raddr, ch); !ok {
243			// Section 7.2, implementations MUST reject spurious incoming
244			// connections.
245			ch.Reject(Prohibited, "no forward for address")
246			continue
247		}
248
249	}
250}
251
252// remove removes the forward entry, and the channel feeding its
253// listener.
254func (l *forwardList) remove(addr net.Addr) {
255	l.Lock()
256	defer l.Unlock()
257	for i, f := range l.entries {
258		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
259			l.entries = append(l.entries[:i], l.entries[i+1:]...)
260			close(f.c)
261			return
262		}
263	}
264}
265
266// closeAll closes and clears all forwards.
267func (l *forwardList) closeAll() {
268	l.Lock()
269	defer l.Unlock()
270	for _, f := range l.entries {
271		close(f.c)
272	}
273	l.entries = nil
274}
275
276func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
277	l.Lock()
278	defer l.Unlock()
279	for _, f := range l.entries {
280		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
281			f.c <- forward{newCh: ch, raddr: raddr}
282			return true
283		}
284	}
285	return false
286}
287
288type tcpListener struct {
289	laddr *net.TCPAddr
290
291	conn *Client
292	in   <-chan forward
293}
294
295// Accept waits for and returns the next connection to the listener.
296func (l *tcpListener) Accept() (net.Conn, error) {
297	s, ok := <-l.in
298	if !ok {
299		return nil, io.EOF
300	}
301	ch, incoming, err := s.newCh.Accept()
302	if err != nil {
303		return nil, err
304	}
305	go DiscardRequests(incoming)
306
307	return &chanConn{
308		Channel: ch,
309		laddr:   l.laddr,
310		raddr:   s.raddr,
311	}, nil
312}
313
314// Close closes the listener.
315func (l *tcpListener) Close() error {
316	m := channelForwardMsg{
317		l.laddr.IP.String(),
318		uint32(l.laddr.Port),
319	}
320
321	// this also closes the listener.
322	l.conn.forwards.remove(l.laddr)
323	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
324	if err == nil && !ok {
325		err = errors.New("ssh: cancel-tcpip-forward failed")
326	}
327	return err
328}
329
330// Addr returns the listener's network address.
331func (l *tcpListener) Addr() net.Addr {
332	return l.laddr
333}
334
335// Dial initiates a connection to the addr from the remote host.
336// The resulting connection has a zero LocalAddr() and RemoteAddr().
337func (c *Client) Dial(n, addr string) (net.Conn, error) {
338	var ch Channel
339	switch n {
340	case "tcp", "tcp4", "tcp6":
341		// Parse the address into host and numeric port.
342		host, portString, err := net.SplitHostPort(addr)
343		if err != nil {
344			return nil, err
345		}
346		port, err := strconv.ParseUint(portString, 10, 16)
347		if err != nil {
348			return nil, err
349		}
350		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
351		if err != nil {
352			return nil, err
353		}
354		// Use a zero address for local and remote address.
355		zeroAddr := &net.TCPAddr{
356			IP:   net.IPv4zero,
357			Port: 0,
358		}
359		return &chanConn{
360			Channel: ch,
361			laddr:   zeroAddr,
362			raddr:   zeroAddr,
363		}, nil
364	case "unix":
365		var err error
366		ch, err = c.dialStreamLocal(addr)
367		if err != nil {
368			return nil, err
369		}
370		return &chanConn{
371			Channel: ch,
372			laddr: &net.UnixAddr{
373				Name: "@",
374				Net:  "unix",
375			},
376			raddr: &net.UnixAddr{
377				Name: addr,
378				Net:  "unix",
379			},
380		}, nil
381	default:
382		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
383	}
384}
385
386// DialTCP connects to the remote address raddr on the network net,
387// which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
388// as the local address for the connection.
389func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
390	if laddr == nil {
391		laddr = &net.TCPAddr{
392			IP:   net.IPv4zero,
393			Port: 0,
394		}
395	}
396	ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
397	if err != nil {
398		return nil, err
399	}
400	return &chanConn{
401		Channel: ch,
402		laddr:   laddr,
403		raddr:   raddr,
404	}, nil
405}
406
// RFC 4254 7.2
//
// channelOpenDirectMsg is the extra data of a "direct-tcpip" channel
// open request.
// NOTE(review): the field order presumably defines the wire layout
// produced by Marshal; do not reorder.
type channelOpenDirectMsg struct {
	raddr string // host to connect to
	rport uint32 // port to connect to
	laddr string // originator IP address
	lport uint32 // originator port
}
414
415func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
416	msg := channelOpenDirectMsg{
417		raddr: raddr,
418		rport: uint32(rport),
419		laddr: laddr,
420		lport: uint32(lport),
421	}
422	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
423	if err != nil {
424		return nil, err
425	}
426	go DiscardRequests(in)
427	return ch, err
428}
429
// tcpChan wraps a Channel.
// NOTE(review): nothing in this section uses tcpChan (connections are
// returned as chanConn). It may be kept for compatibility or used
// elsewhere in the package; confirm before removing.
type tcpChan struct {
	Channel // the backing channel
}
433
434// chanConn fulfills the net.Conn interface without
435// the tcpChan having to hold laddr or raddr directly.
436type chanConn struct {
437	Channel
438	laddr, raddr net.Addr
439}
440
441// LocalAddr returns the local network address.
442func (t *chanConn) LocalAddr() net.Addr {
443	return t.laddr
444}
445
446// RemoteAddr returns the remote network address.
447func (t *chanConn) RemoteAddr() net.Addr {
448	return t.raddr
449}
450
451// SetDeadline sets the read and write deadlines associated
452// with the connection.
453func (t *chanConn) SetDeadline(deadline time.Time) error {
454	if err := t.SetReadDeadline(deadline); err != nil {
455		return err
456	}
457	return t.SetWriteDeadline(deadline)
458}
459
460// SetReadDeadline sets the read deadline.
461// A zero value for t means Read will not time out.
462// After the deadline, the error from Read will implement net.Error
463// with Timeout() == true.
464func (t *chanConn) SetReadDeadline(deadline time.Time) error {
465	// for compatibility with previous version,
466	// the error message contains "tcpChan"
467	return errors.New("ssh: tcpChan: deadline not supported")
468}
469
470// SetWriteDeadline exists to satisfy the net.Conn interface
471// but is not implemented by this type.  It always returns an error.
472func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
473	return errors.New("ssh: tcpChan: deadline not supported")
474}
475