1package socks5
2
3import (
4	"bytes"
5	"encoding/binary"
6	"io"
7	"log"
8	"net"
9	"os"
10	"strings"
11	"testing"
12)
13
// MockConn is an in-memory stand-in for a client connection. Every byte
// written to it is captured in buf so a test can inspect it afterwards, and
// it always reports the fixed remote address 127.0.0.1:65432.
type MockConn struct {
	buf bytes.Buffer
}

// Write appends b to the captured output and returns the byte count and
// error from the underlying buffer.
func (c *MockConn) Write(b []byte) (int, error) {
	n, err := c.buf.Write(b)
	return n, err
}

// RemoteAddr returns a constant loopback TCP address (127.0.0.1:65432).
func (c *MockConn) RemoteAddr() net.Addr {
	addr := &net.TCPAddr{
		IP:   net.IP{127, 0, 0, 1},
		Port: 65432,
	}
	return addr
}
25
26func TestRequest_Connect(t *testing.T) {
27	// Create a local listener
28	l, err := net.Listen("tcp", "127.0.0.1:0")
29	if err != nil {
30		t.Fatalf("err: %v", err)
31	}
32	go func() {
33		conn, err := l.Accept()
34		if err != nil {
35			t.Fatalf("err: %v", err)
36		}
37		defer conn.Close()
38
39		buf := make([]byte, 4)
40		if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
41			t.Fatalf("err: %v", err)
42		}
43
44		if !bytes.Equal(buf, []byte("ping")) {
45			t.Fatalf("bad: %v", buf)
46		}
47		conn.Write([]byte("pong"))
48	}()
49	lAddr := l.Addr().(*net.TCPAddr)
50
51	// Make server
52	s := &Server{config: &Config{
53		Rules:    PermitAll(),
54		Resolver: DNSResolver{},
55		Logger:   log.New(os.Stdout, "", log.LstdFlags),
56	}}
57
58	// Create the connect request
59	buf := bytes.NewBuffer(nil)
60	buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
61
62	port := []byte{0, 0}
63	binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
64	buf.Write(port)
65
66	// Send a ping
67	buf.Write([]byte("ping"))
68
69	// Handle the request
70	resp := &MockConn{}
71	req, err := NewRequest(buf)
72	if err != nil {
73		t.Fatalf("err: %v", err)
74	}
75
76	if err := s.handleRequest(req, resp); err != nil {
77		t.Fatalf("err: %v", err)
78	}
79
80	// Verify response
81	out := resp.buf.Bytes()
82	expected := []byte{
83		5,
84		0,
85		0,
86		1,
87		127, 0, 0, 1,
88		0, 0,
89		'p', 'o', 'n', 'g',
90	}
91
92	// Ignore the port for both
93	out[8] = 0
94	out[9] = 0
95
96	if !bytes.Equal(out, expected) {
97		t.Fatalf("bad: %v %v", out, expected)
98	}
99}
100
101func TestRequest_Connect_RuleFail(t *testing.T) {
102	// Create a local listener
103	l, err := net.Listen("tcp", "127.0.0.1:0")
104	if err != nil {
105		t.Fatalf("err: %v", err)
106	}
107	go func() {
108		conn, err := l.Accept()
109		if err != nil {
110			t.Fatalf("err: %v", err)
111		}
112		defer conn.Close()
113
114		buf := make([]byte, 4)
115		if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
116			t.Fatalf("err: %v", err)
117		}
118
119		if !bytes.Equal(buf, []byte("ping")) {
120			t.Fatalf("bad: %v", buf)
121		}
122		conn.Write([]byte("pong"))
123	}()
124	lAddr := l.Addr().(*net.TCPAddr)
125
126	// Make server
127	s := &Server{config: &Config{
128		Rules:    PermitNone(),
129		Resolver: DNSResolver{},
130		Logger:   log.New(os.Stdout, "", log.LstdFlags),
131	}}
132
133	// Create the connect request
134	buf := bytes.NewBuffer(nil)
135	buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
136
137	port := []byte{0, 0}
138	binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
139	buf.Write(port)
140
141	// Send a ping
142	buf.Write([]byte("ping"))
143
144	// Handle the request
145	resp := &MockConn{}
146	req, err := NewRequest(buf)
147	if err != nil {
148		t.Fatalf("err: %v", err)
149	}
150
151	if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
152		t.Fatalf("err: %v", err)
153	}
154
155	// Verify response
156	out := resp.buf.Bytes()
157	expected := []byte{
158		5,
159		2,
160		0,
161		1,
162		0, 0, 0, 0,
163		0, 0,
164	}
165
166	if !bytes.Equal(out, expected) {
167		t.Fatalf("bad: %v %v", out, expected)
168	}
169}
170