1package socks5 2 3import ( 4 "bytes" 5 "encoding/binary" 6 "io" 7 "log" 8 "net" 9 "os" 10 "strings" 11 "testing" 12) 13 14type MockConn struct { 15 buf bytes.Buffer 16} 17 18func (m *MockConn) Write(b []byte) (int, error) { 19 return m.buf.Write(b) 20} 21 22func (m *MockConn) RemoteAddr() net.Addr { 23 return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432} 24} 25 26func TestRequest_Connect(t *testing.T) { 27 // Create a local listener 28 l, err := net.Listen("tcp", "127.0.0.1:0") 29 if err != nil { 30 t.Fatalf("err: %v", err) 31 } 32 go func() { 33 conn, err := l.Accept() 34 if err != nil { 35 t.Fatalf("err: %v", err) 36 } 37 defer conn.Close() 38 39 buf := make([]byte, 4) 40 if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { 41 t.Fatalf("err: %v", err) 42 } 43 44 if !bytes.Equal(buf, []byte("ping")) { 45 t.Fatalf("bad: %v", buf) 46 } 47 conn.Write([]byte("pong")) 48 }() 49 lAddr := l.Addr().(*net.TCPAddr) 50 51 // Make server 52 s := &Server{config: &Config{ 53 Rules: PermitAll(), 54 Resolver: DNSResolver{}, 55 Logger: log.New(os.Stdout, "", log.LstdFlags), 56 }} 57 58 // Create the connect request 59 buf := bytes.NewBuffer(nil) 60 buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) 61 62 port := []byte{0, 0} 63 binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) 64 buf.Write(port) 65 66 // Send a ping 67 buf.Write([]byte("ping")) 68 69 // Handle the request 70 resp := &MockConn{} 71 req, err := NewRequest(buf) 72 if err != nil { 73 t.Fatalf("err: %v", err) 74 } 75 76 if err := s.handleRequest(req, resp); err != nil { 77 t.Fatalf("err: %v", err) 78 } 79 80 // Verify response 81 out := resp.buf.Bytes() 82 expected := []byte{ 83 5, 84 0, 85 0, 86 1, 87 127, 0, 0, 1, 88 0, 0, 89 'p', 'o', 'n', 'g', 90 } 91 92 // Ignore the port for both 93 out[8] = 0 94 out[9] = 0 95 96 if !bytes.Equal(out, expected) { 97 t.Fatalf("bad: %v %v", out, expected) 98 } 99} 100 101func TestRequest_Connect_RuleFail(t *testing.T) { 102 // Create a local listener 103 l, err := net.Listen("tcp", "127.0.0.1:0") 104 if err != nil { 105 t.Fatalf("err: %v", err) 106 } 107 go func() { 108 conn, err := l.Accept() 109 if err != nil { 110 t.Fatalf("err: %v", err) 111 } 112 defer conn.Close() 113 114 buf := make([]byte, 4) 115 if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { 116 t.Fatalf("err: %v", err) 117 } 118 119 if !bytes.Equal(buf, []byte("ping")) { 120 t.Fatalf("bad: %v", buf) 121 } 122 conn.Write([]byte("pong")) 123 }() 124 lAddr := l.Addr().(*net.TCPAddr) 125 126 // Make server 127 s := &Server{config: &Config{ 128 Rules: PermitNone(), 129 Resolver: DNSResolver{}, 130 Logger: log.New(os.Stdout, "", log.LstdFlags), 131 }} 132 133 // Create the connect request 134 buf := bytes.NewBuffer(nil) 135 buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) 136 137 port := []byte{0, 0} 138 binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) 139 buf.Write(port) 140 141 // Send a ping 142 buf.Write([]byte("ping")) 143 144 // Handle the request 145 resp := &MockConn{} 146 req, err := NewRequest(buf) 147 if err != nil { 148 t.Fatalf("err: %v", err) 149 } 150 151 if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") { 152 t.Fatalf("err: %v", err) 153 } 154 155 // Verify response 156 out := resp.buf.Bytes() 157 expected := []byte{ 158 5, 159 2, 160 0, 161 1, 162 0, 0, 0, 0, 163 0, 0, 164 } 165 166 if !bytes.Equal(out, expected) { 167 t.Fatalf("bad: %v %v", out, expected) 168 } 169} 170