1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package sockstest provides utilities for SOCKS testing.
6package sockstest
7
8import (
9	"errors"
10	"io"
11	"net"
12
13	"golang.org/x/net/internal/socks"
14	"golang.org/x/net/nettest"
15)
16
// An AuthRequest represents an authentication request.
//
// Version is the protocol version carried in the request and Methods
// lists the authentication methods carried in the request, in the
// order they appear on the wire.
type AuthRequest struct {
	Version int
	Methods []socks.AuthMethod
}
22
23// ParseAuthRequest parses an authentication request.
24func ParseAuthRequest(b []byte) (*AuthRequest, error) {
25	if len(b) < 2 {
26		return nil, errors.New("short auth request")
27	}
28	if b[0] != socks.Version5 {
29		return nil, errors.New("unexpected protocol version")
30	}
31	if len(b)-2 < int(b[1]) {
32		return nil, errors.New("short auth request")
33	}
34	req := &AuthRequest{Version: int(b[0])}
35	if b[1] > 0 {
36		req.Methods = make([]socks.AuthMethod, b[1])
37		for i, m := range b[2 : 2+b[1]] {
38			req.Methods[i] = socks.AuthMethod(m)
39		}
40	}
41	return req, nil
42}
43
44// MarshalAuthReply returns an authentication reply in wire format.
45func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
46	return []byte{byte(ver), byte(m)}, nil
47}
48
// A CmdRequest represents a command request.
//
// Version is the protocol version carried in the request, Cmd is the
// requested command, and Addr holds the address (an IP address or a
// host name) and port carried in the request.
type CmdRequest struct {
	Version int
	Cmd     socks.Command
	Addr    socks.Addr
}
55
56// ParseCmdRequest parses a command request.
57func ParseCmdRequest(b []byte) (*CmdRequest, error) {
58	if len(b) < 7 {
59		return nil, errors.New("short cmd request")
60	}
61	if b[0] != socks.Version5 {
62		return nil, errors.New("unexpected protocol version")
63	}
64	if socks.Command(b[1]) != socks.CmdConnect {
65		return nil, errors.New("unexpected command")
66	}
67	if b[2] != 0 {
68		return nil, errors.New("non-zero reserved field")
69	}
70	req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
71	l := 2
72	off := 4
73	switch b[3] {
74	case socks.AddrTypeIPv4:
75		l += net.IPv4len
76		req.Addr.IP = make(net.IP, net.IPv4len)
77	case socks.AddrTypeIPv6:
78		l += net.IPv6len
79		req.Addr.IP = make(net.IP, net.IPv6len)
80	case socks.AddrTypeFQDN:
81		l += int(b[4])
82		off = 5
83	default:
84		return nil, errors.New("unknown address type")
85	}
86	if len(b[off:]) < l {
87		return nil, errors.New("short cmd request")
88	}
89	if req.Addr.IP != nil {
90		copy(req.Addr.IP, b[off:])
91	} else {
92		req.Addr.Name = string(b[off : off+l-2])
93	}
94	req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1])
95	return req, nil
96}
97
98// MarshalCmdReply returns a command reply in wire format.
99func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
100	b := make([]byte, 4)
101	b[0] = byte(ver)
102	b[1] = byte(reply)
103	if a.Name != "" {
104		if len(a.Name) > 255 {
105			return nil, errors.New("fqdn too long")
106		}
107		b[3] = socks.AddrTypeFQDN
108		b = append(b, byte(len(a.Name)))
109		b = append(b, a.Name...)
110	} else if ip4 := a.IP.To4(); ip4 != nil {
111		b[3] = socks.AddrTypeIPv4
112		b = append(b, ip4...)
113	} else if ip6 := a.IP.To16(); ip6 != nil {
114		b[3] = socks.AddrTypeIPv6
115		b = append(b, ip6...)
116	} else {
117		return nil, errors.New("unknown address type")
118	}
119	b = append(b, byte(a.Port>>8), byte(a.Port))
120	return b, nil
121}
122
// A Server represents a server for handshake testing.
type Server struct {
	ln net.Listener // listener on which handshake connections are accepted
}

// Addr returns the address the server is listening on.
func (s *Server) Addr() net.Addr {
	return s.ln.Addr()
}

// TargetAddr returns a fake final destination address.
//
// The result is a loopback TCP address on port 5963 in the same
// address family as the listener address. It is nil when the
// listener address is not a TCP address carrying a valid IP.
//
// The returned address is only valid for testing with Server.
func (s *Server) TargetAddr() net.Addr {
	tcp, ok := s.ln.Addr().(*net.TCPAddr)
	if !ok {
		return nil
	}
	switch {
	case tcp.IP.To4() != nil:
		return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
	case tcp.IP.To16() != nil:
		return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
	}
	return nil
}

// Close closes the server by closing its listener.
func (s *Server) Close() error {
	return s.ln.Close()
}

// serve accepts one connection and runs a single handshake on it.
//
// As soon as a connection is accepted, a new goroutine is started to
// accept the next one. The handshake consists of two steps: the
// bytes returned by one Read are handed to authFunc, then the bytes
// returned by another Read are handed to cmdFunc. Each step sees at
// most 512 bytes. Any error ends the handshake silently, and the
// connection is closed when serve returns.
func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
	conn, err := s.ln.Accept()
	if err != nil {
		return
	}
	defer conn.Close()
	go s.serve(authFunc, cmdFunc)

	buf := make([]byte, 512)
	steps := []func(io.ReadWriter, []byte) error{authFunc, cmdFunc}
	for _, handle := range steps {
		n, err := conn.Read(buf)
		if err != nil {
			return
		}
		if err := handle(conn, buf[:n]); err != nil {
			return
		}
	}
}
178
179// NewServer returns a new server.
180//
181// The provided authFunc and cmdFunc must parse requests and return
182// appropriate replies to clients.
183func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
184	var err error
185	s := new(Server)
186	s.ln, err = nettest.NewLocalListener("tcp")
187	if err != nil {
188		return nil, err
189	}
190	go s.serve(authFunc, cmdFunc)
191	return s, nil
192}
193
194// NoAuthRequired handles a no-authentication-required signaling.
195func NoAuthRequired(rw io.ReadWriter, b []byte) error {
196	req, err := ParseAuthRequest(b)
197	if err != nil {
198		return err
199	}
200	b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
201	if err != nil {
202		return err
203	}
204	n, err := rw.Write(b)
205	if err != nil {
206		return err
207	}
208	if n != len(b) {
209		return errors.New("short write")
210	}
211	return nil
212}
213
214// NoProxyRequired handles a command signaling without constructing a
215// proxy connection to the final destination.
216func NoProxyRequired(rw io.ReadWriter, b []byte) error {
217	req, err := ParseCmdRequest(b)
218	if err != nil {
219		return err
220	}
221	req.Addr.Port += 1
222	if req.Addr.Name != "" {
223		req.Addr.Name = "boundaddr.doesnotexist"
224	} else if req.Addr.IP.To4() != nil {
225		req.Addr.IP = net.IPv4(127, 0, 0, 1)
226	} else {
227		req.Addr.IP = net.IPv6loopback
228	}
229	b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr)
230	if err != nil {
231		return err
232	}
233	n, err := rw.Write(b)
234	if err != nil {
235		return err
236	}
237	if n != len(b) {
238		return errors.New("short write")
239	}
240	return nil
241}
242