1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package socks provides a SOCKS version 5 client implementation.
6//
7// SOCKS protocol version 5 is defined in RFC 1928.
8// Username/Password authentication for SOCKS version 5 is defined in
9// RFC 1929.
10package socks
11
12import (
13	"context"
14	"errors"
15	"io"
16	"net"
17	"strconv"
18)
19
// A Command represents a SOCKS command.
type Command int

// String returns "socks connect" or "socks bind" for the known
// commands, and "socks " followed by the decimal value otherwise.
func (cmd Command) String() string {
	if cmd == CmdConnect {
		return "socks connect"
	}
	if cmd == cmdBind {
		return "socks bind"
	}
	return "socks " + strconv.Itoa(int(cmd))
}

// An AuthMethod represents a SOCKS authentication method.
type AuthMethod int

// A Reply represents a SOCKS command reply code.
type Reply int

// String returns a human-readable description of the reply code, or
// "unknown code: " followed by the decimal value for any code outside
// the range 0x00 through 0x08.
func (code Reply) String() string {
	// Indexed by reply code; entry 0x00 corresponds to StatusSucceeded.
	descriptions := [...]string{
		0x00: "succeeded",
		0x01: "general SOCKS server failure",
		0x02: "connection not allowed by ruleset",
		0x03: "network unreachable",
		0x04: "host unreachable",
		0x05: "connection refused",
		0x06: "TTL expired",
		0x07: "command not supported",
		0x08: "address type not supported",
	}
	if code < 0 || int(code) >= len(descriptions) {
		return "unknown code: " + strconv.Itoa(int(code))
	}
	return descriptions[code]
}

// Wire protocol constants.

// Version5 is the SOCKS protocol version number.
const Version5 = 0x05

// Address type codes.
const (
	AddrTypeIPv4 = 0x01
	AddrTypeFQDN = 0x03
	AddrTypeIPv6 = 0x04
)

// Commands.
const (
	CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
	cmdBind    Command = 0x02 // establishes a passive-open forward proxy connection
)

// Authentication methods.
const (
	AuthMethodNotRequired         AuthMethod = 0x00 // no authentication required
	AuthMethodUsernamePassword    AuthMethod = 0x02 // use username/password
	AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods
)

// StatusSucceeded is the reply code reporting that a command succeeded.
const StatusSucceeded Reply = 0x00
82
// An Addr represents a SOCKS-specific address.
// Either Name or IP is used exclusively.
type Addr struct {
	Name string // fully-qualified domain name; used only when IP is nil
	IP   net.IP
	Port int
}

// Network returns the address's network name, which is always "socks".
func (a *Addr) Network() string {
	return "socks"
}

// String returns the address in host:port form as built by
// net.JoinHostPort. When IP is non-nil it supplies the host and Name is
// ignored; otherwise Name is used. A nil *Addr yields "<nil>".
func (a *Addr) String() string {
	if a == nil {
		return "<nil>"
	}
	host := a.Name
	if a.IP != nil {
		host = a.IP.String()
	}
	return net.JoinHostPort(host, strconv.Itoa(a.Port))
}
103
// A Conn represents a forward proxy connection.
// The embedded net.Conn is the transport connection to the proxy server.
type Conn struct {
	net.Conn

	boundAddr net.Addr
}

// BoundAddr returns the address assigned by the proxy server for
// connecting to the command target address from the proxy server.
// It is safe to call on a nil *Conn, in which case it returns nil.
func (c *Conn) BoundAddr() net.Addr {
	if c != nil {
		return c.boundAddr
	}
	return nil
}
119
120// A Dialer holds SOCKS-specific options.
121type Dialer struct {
122	cmd          Command // either CmdConnect or cmdBind
123	proxyNetwork string  // network between a proxy server and a client
124	proxyAddress string  // proxy server address
125
126	// ProxyDial specifies the optional dial function for
127	// establishing the transport connection.
128	ProxyDial func(context.Context, string, string) (net.Conn, error)
129
130	// AuthMethods specifies the list of request authentication
131	// methods.
132	// If empty, SOCKS client requests only AuthMethodNotRequired.
133	AuthMethods []AuthMethod
134
135	// Authenticate specifies the optional authentication
136	// function. It must be non-nil when AuthMethods is not empty.
137	// It must return an error when the authentication is failed.
138	Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
139}
140
141// DialContext connects to the provided address on the provided
142// network.
143//
144// The returned error value may be a net.OpError. When the Op field of
145// net.OpError contains "socks", the Source field contains a proxy
146// server address and the Addr field contains a command target
147// address.
148//
149// See func Dial of the net package of standard library for a
150// description of the network and address parameters.
151func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
152	if err := d.validateTarget(network, address); err != nil {
153		proxy, dst, _ := d.pathAddrs(address)
154		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
155	}
156	if ctx == nil {
157		proxy, dst, _ := d.pathAddrs(address)
158		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
159	}
160	var err error
161	var c net.Conn
162	if d.ProxyDial != nil {
163		c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
164	} else {
165		var dd net.Dialer
166		c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
167	}
168	if err != nil {
169		proxy, dst, _ := d.pathAddrs(address)
170		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
171	}
172	a, err := d.connect(ctx, c, address)
173	if err != nil {
174		c.Close()
175		proxy, dst, _ := d.pathAddrs(address)
176		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
177	}
178	return &Conn{Conn: c, boundAddr: a}, nil
179}
180
181// DialWithConn initiates a connection from SOCKS server to the target
182// network and address using the connection c that is already
183// connected to the SOCKS server.
184//
185// It returns the connection's local address assigned by the SOCKS
186// server.
187func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
188	if err := d.validateTarget(network, address); err != nil {
189		proxy, dst, _ := d.pathAddrs(address)
190		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
191	}
192	if ctx == nil {
193		proxy, dst, _ := d.pathAddrs(address)
194		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
195	}
196	a, err := d.connect(ctx, c, address)
197	if err != nil {
198		proxy, dst, _ := d.pathAddrs(address)
199		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
200	}
201	return a, nil
202}
203
204// Dial connects to the provided address on the provided network.
205//
206// Unlike DialContext, it returns a raw transport connection instead
207// of a forward proxy connection.
208//
209// Deprecated: Use DialContext or DialWithConn instead.
210func (d *Dialer) Dial(network, address string) (net.Conn, error) {
211	if err := d.validateTarget(network, address); err != nil {
212		proxy, dst, _ := d.pathAddrs(address)
213		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
214	}
215	var err error
216	var c net.Conn
217	if d.ProxyDial != nil {
218		c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
219	} else {
220		c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
221	}
222	if err != nil {
223		proxy, dst, _ := d.pathAddrs(address)
224		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
225	}
226	if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
227		c.Close()
228		return nil, err
229	}
230	return c, nil
231}
232
233func (d *Dialer) validateTarget(network, address string) error {
234	switch network {
235	case "tcp", "tcp6", "tcp4":
236	default:
237		return errors.New("network not implemented")
238	}
239	switch d.cmd {
240	case CmdConnect, cmdBind:
241	default:
242		return errors.New("command not implemented")
243	}
244	return nil
245}
246
247func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
248	for i, s := range []string{d.proxyAddress, address} {
249		host, port, err := splitHostPort(s)
250		if err != nil {
251			return nil, nil, err
252		}
253		a := &Addr{Port: port}
254		a.IP = net.ParseIP(host)
255		if a.IP == nil {
256			a.Name = host
257		}
258		if i == 0 {
259			proxy = a
260		} else {
261			dst = a
262		}
263	}
264	return
265}
266
267// NewDialer returns a new Dialer that dials through the provided
268// proxy server's network and address.
269func NewDialer(network, address string) *Dialer {
270	return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
271}
272
// Wire values for the username/password sub-negotiation (RFC 1929).
const (
	// authUsernamePasswordVersion is the version byte sent at the start
	// of the request and expected at the start of the reply.
	authUsernamePasswordVersion = 0x01

	// authStatusSucceeded is the reply status byte that signals success.
	authStatusSucceeded = 0x00
)
277
// UsernamePassword holds the credentials for the username/password
// authentication method. Authenticate rejects either field if it is
// empty or longer than 255 bytes.
type UsernamePassword struct {
	Username string // user name sent to the proxy server
	Password string // password sent to the proxy server
}
284
285// Authenticate authenticates a pair of username and password with the
286// proxy server.
287func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
288	switch auth {
289	case AuthMethodNotRequired:
290		return nil
291	case AuthMethodUsernamePassword:
292		if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
293			return errors.New("invalid username/password")
294		}
295		b := []byte{authUsernamePasswordVersion}
296		b = append(b, byte(len(up.Username)))
297		b = append(b, up.Username...)
298		b = append(b, byte(len(up.Password)))
299		b = append(b, up.Password...)
300		// TODO(mikio): handle IO deadlines and cancelation if
301		// necessary
302		if _, err := rw.Write(b); err != nil {
303			return err
304		}
305		if _, err := io.ReadFull(rw, b[:2]); err != nil {
306			return err
307		}
308		if b[0] != authUsernamePasswordVersion {
309			return errors.New("invalid username/password version")
310		}
311		if b[1] != authStatusSucceeded {
312			return errors.New("username/password authentication failed")
313		}
314		return nil
315	}
316	return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
317}
318