1package ssh
2
3import (
4	"errors"
5	"io"
6	"net"
7)
8
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "direct-streamlocal@openssh.com" string.
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
// (the line anchor points at master and may drift as PROTOCOL is edited;
// the section title above is the authoritative reference).
//
// dialStreamLocal serializes this struct with Marshal. Marshal is assumed to
// encode fields in declaration order, so the order here must keep matching the
// wire layout in PROTOCOL. Do not reorder or insert fields.
type streamLocalChannelOpenDirectMsg struct {
	socketPath string // Unix domain socket path sent to the peer in the channel-open request
	reserved0  string // reserved; dialStreamLocal leaves it empty
	reserved1  uint32 // reserved; dialStreamLocal leaves it zero
}
19
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "forwarded-streamlocal@openssh.com" string.
//
// It is not referenced in this part of the file.
// NOTE(review): unlike the other message structs here, its fields are exported.
// Presumably this is because it is decoded from the peer's channel-open payload,
// and reflection-based decoding needs settable fields. Confirm against the
// forwarded-channel handler before renaming anything. Field order is assumed to
// be the wire order (see PROTOCOL section 2.4).
type forwardedStreamLocalPayload struct {
	SocketPath string // presumably the listening socket path the connection arrived on — confirm
	Reserved0  string // reserved
}
26
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
//
// ListenUnix sends it to request a forward, and unixListener.Close sends the
// same payload to cancel that forward. Both build it with an unkeyed
// composite literal, so adding or reordering fields requires updating them.
type streamLocalChannelForwardMsg struct {
	socketPath string // socket path to start (or stop) forwarding on the peer
}
32
33// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
34func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
35	c.handleForwardsOnce.Do(c.handleForwards)
36	m := streamLocalChannelForwardMsg{
37		socketPath,
38	}
39	// send message
40	ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
41	if err != nil {
42		return nil, err
43	}
44	if !ok {
45		return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
46	}
47	ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
48
49	return &unixListener{socketPath, c, ch}, nil
50}
51
52func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
53	msg := streamLocalChannelOpenDirectMsg{
54		socketPath: socketPath,
55	}
56	ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
57	if err != nil {
58		return nil, err
59	}
60	go DiscardRequests(in)
61	return ch, err
62}
63
// unixListener is the net.Listener returned by ListenUnix. Accept hands out
// the connections the peer forwards for socketPath.
//
// ListenUnix constructs it with an unkeyed composite literal
// (socketPath, conn, in), so the field order here is significant.
type unixListener struct {
	socketPath string // socket path requested from the peer; also reported by Addr

	conn *Client        // client on which the forward was requested (used by Close)
	in   <-chan forward // forwards for socketPath, obtained from conn.forwards.add; Accept reads it
}
70
71// Accept waits for and returns the next connection to the listener.
72func (l *unixListener) Accept() (net.Conn, error) {
73	s, ok := <-l.in
74	if !ok {
75		return nil, io.EOF
76	}
77	ch, incoming, err := s.newCh.Accept()
78	if err != nil {
79		return nil, err
80	}
81	go DiscardRequests(incoming)
82
83	return &chanConn{
84		Channel: ch,
85		laddr: &net.UnixAddr{
86			Name: l.socketPath,
87			Net:  "unix",
88		},
89		raddr: &net.UnixAddr{
90			Name: "@",
91			Net:  "unix",
92		},
93	}, nil
94}
95
96// Close closes the listener.
97func (l *unixListener) Close() error {
98	// this also closes the listener.
99	l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
100	m := streamLocalChannelForwardMsg{
101		l.socketPath,
102	}
103	ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
104	if err == nil && !ok {
105		err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
106	}
107	return err
108}
109
110// Addr returns the listener's network address.
111func (l *unixListener) Addr() net.Addr {
112	return &net.UnixAddr{
113		Name: l.socketPath,
114		Net:  "unix",
115	}
116}
117