1package socks5 2 3import ( 4 "fmt" 5 "io" 6 "net" 7 "strconv" 8 "strings" 9 10 "github.com/ooni/psiphon/oopsi/golang.org/x/net/context" 11) 12 13const ( 14 ConnectCommand = uint8(1) 15 BindCommand = uint8(2) 16 AssociateCommand = uint8(3) 17 ipv4Address = uint8(1) 18 fqdnAddress = uint8(3) 19 ipv6Address = uint8(4) 20) 21 22const ( 23 successReply uint8 = iota 24 serverFailure 25 ruleFailure 26 networkUnreachable 27 hostUnreachable 28 connectionRefused 29 ttlExpired 30 commandNotSupported 31 addrTypeNotSupported 32) 33 34var ( 35 unrecognizedAddrType = fmt.Errorf("Unrecognized address type") 36) 37 38// AddressRewriter is used to rewrite a destination transparently 39type AddressRewriter interface { 40 Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) 41} 42 43// AddrSpec is used to return the target AddrSpec 44// which may be specified as IPv4, IPv6, or a FQDN 45type AddrSpec struct { 46 FQDN string 47 IP net.IP 48 Port int 49} 50 51func (a *AddrSpec) String() string { 52 if a.FQDN != "" { 53 return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) 54 } 55 return fmt.Sprintf("%s:%d", a.IP, a.Port) 56} 57 58// Address returns a string suitable to dial; prefer returning IP-based 59// address, fallback to FQDN 60func (a AddrSpec) Address() string { 61 if 0 != len(a.IP) { 62 return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) 63 } 64 return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) 65} 66 67// A Request represents request received by a server 68type Request struct { 69 // Protocol version 70 Version uint8 71 // Requested command 72 Command uint8 73 // AuthContext provided during negotiation 74 AuthContext *AuthContext 75 // AddrSpec of the the network that sent the request 76 RemoteAddr *AddrSpec 77 // AddrSpec of the desired destination 78 DestAddr *AddrSpec 79 // AddrSpec of the actual destination (might be affected by rewrite) 80 realDestAddr *AddrSpec 81 bufConn io.Reader 82} 83 84type conn interface { 85 Write([]byte) (int, error) 86 RemoteAddr() net.Addr 87} 88 89// NewRequest creates a new Request from the tcp connection 90func NewRequest(bufConn io.Reader) (*Request, error) { 91 // Read the version byte 92 header := []byte{0, 0, 0} 93 if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { 94 return nil, fmt.Errorf("Failed to get command version: %v", err) 95 } 96 97 // Ensure we are compatible 98 if header[0] != socks5Version { 99 return nil, fmt.Errorf("Unsupported command version: %v", header[0]) 100 } 101 102 // Read in the destination address 103 dest, err := readAddrSpec(bufConn) 104 if err != nil { 105 return nil, err 106 } 107 108 request := &Request{ 109 Version: socks5Version, 110 Command: header[1], 111 DestAddr: dest, 112 bufConn: bufConn, 113 } 114 115 return request, nil 116} 117 118// handleRequest is used for request processing after authentication 119func (s *Server) handleRequest(req *Request, conn conn) error { 120 ctx := context.Background() 121 122 // Resolve the address if we have a FQDN 123 dest := req.DestAddr 124 if dest.FQDN != "" { 125 ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) 126 if err != nil { 127 if err := sendReply(conn, hostUnreachable, nil); err != nil { 128 return fmt.Errorf("Failed to send reply: %v", err) 129 } 130 return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) 131 } 132 ctx = ctx_ 133 dest.IP = addr 134 } 135 136 // Apply any address rewrites 137 req.realDestAddr = req.DestAddr 138 if s.config.Rewriter != nil { 139 ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) 140 } 141 142 // Switch on the command 143 switch req.Command { 144 case ConnectCommand: 145 return s.handleConnect(ctx, conn, req) 146 case BindCommand: 147 return s.handleBind(ctx, conn, req) 148 case AssociateCommand: 149 return s.handleAssociate(ctx, conn, req) 150 default: 151 if err := sendReply(conn, commandNotSupported, nil); err != nil { 152 return fmt.Errorf("Failed to send reply: %v", err) 153 } 154 return fmt.Errorf("Unsupported command: %v", req.Command) 155 } 156} 157 158// handleConnect is used to handle a connect command 159func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { 160 // Check if this is allowed 161 if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { 162 if err := sendReply(conn, ruleFailure, nil); err != nil { 163 return fmt.Errorf("Failed to send reply: %v", err) 164 } 165 return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) 166 } else { 167 ctx = ctx_ 168 } 169 170 // Attempt to connect 171 dial := s.config.Dial 172 if dial == nil { 173 dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { 174 return net.Dial(net_, addr) 175 } 176 } 177 target, err := dial(ctx, "tcp", req.realDestAddr.Address()) 178 if err != nil { 179 msg := err.Error() 180 resp := hostUnreachable 181 if strings.Contains(msg, "refused") { 182 resp = connectionRefused 183 } else if strings.Contains(msg, "network is unreachable") { 184 resp = networkUnreachable 185 } 186 if err := sendReply(conn, resp, nil); err != nil { 187 return fmt.Errorf("Failed to send reply: %v", err) 188 } 189 return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) 190 } 191 defer target.Close() 192 193 // Send success 194 local := target.LocalAddr().(*net.TCPAddr) 195 bind := AddrSpec{IP: local.IP, Port: local.Port} 196 if err := sendReply(conn, successReply, &bind); err != nil { 197 return fmt.Errorf("Failed to send reply: %v", err) 198 } 199 200 // Start proxying 201 errCh := make(chan error, 2) 202 go proxy(target, req.bufConn, errCh) 203 go proxy(conn, target, errCh) 204 205 // Wait 206 for i := 0; i < 2; i++ { 207 e := <-errCh 208 if e != nil { 209 // return from this function closes target (and conn). 210 return e 211 } 212 } 213 return nil 214} 215 216// handleBind is used to handle a connect command 217func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { 218 // Check if this is allowed 219 if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { 220 if err := sendReply(conn, ruleFailure, nil); err != nil { 221 return fmt.Errorf("Failed to send reply: %v", err) 222 } 223 return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) 224 } else { 225 ctx = ctx_ 226 } 227 228 // TODO: Support bind 229 if err := sendReply(conn, commandNotSupported, nil); err != nil { 230 return fmt.Errorf("Failed to send reply: %v", err) 231 } 232 return nil 233} 234 235// handleAssociate is used to handle a connect command 236func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { 237 // Check if this is allowed 238 if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { 239 if err := sendReply(conn, ruleFailure, nil); err != nil { 240 return fmt.Errorf("Failed to send reply: %v", err) 241 } 242 return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) 243 } else { 244 ctx = ctx_ 245 } 246 247 // TODO: Support associate 248 if err := sendReply(conn, commandNotSupported, nil); err != nil { 249 return fmt.Errorf("Failed to send reply: %v", err) 250 } 251 return nil 252} 253 254// readAddrSpec is used to read AddrSpec. 255// Expects an address type byte, follwed by the address and port 256func readAddrSpec(r io.Reader) (*AddrSpec, error) { 257 d := &AddrSpec{} 258 259 // Get the address type 260 addrType := []byte{0} 261 if _, err := r.Read(addrType); err != nil { 262 return nil, err 263 } 264 265 // Handle on a per type basis 266 switch addrType[0] { 267 case ipv4Address: 268 addr := make([]byte, 4) 269 if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { 270 return nil, err 271 } 272 d.IP = net.IP(addr) 273 274 case ipv6Address: 275 addr := make([]byte, 16) 276 if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { 277 return nil, err 278 } 279 d.IP = net.IP(addr) 280 281 case fqdnAddress: 282 if _, err := r.Read(addrType); err != nil { 283 return nil, err 284 } 285 addrLen := int(addrType[0]) 286 fqdn := make([]byte, addrLen) 287 if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { 288 return nil, err 289 } 290 d.FQDN = string(fqdn) 291 292 default: 293 return nil, unrecognizedAddrType 294 } 295 296 // Read the port 297 port := []byte{0, 0} 298 if _, err := io.ReadAtLeast(r, port, 2); err != nil { 299 return nil, err 300 } 301 d.Port = (int(port[0]) << 8) | int(port[1]) 302 303 return d, nil 304} 305 306// sendReply is used to send a reply message 307func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { 308 // Format the address 309 var addrType uint8 310 var addrBody []byte 311 var addrPort uint16 312 switch { 313 case addr == nil: 314 addrType = ipv4Address 315 addrBody = []byte{0, 0, 0, 0} 316 addrPort = 0 317 318 case addr.FQDN != "": 319 addrType = fqdnAddress 320 addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) 321 addrPort = uint16(addr.Port) 322 323 case addr.IP.To4() != nil: 324 addrType = ipv4Address 325 addrBody = []byte(addr.IP.To4()) 326 addrPort = uint16(addr.Port) 327 328 case addr.IP.To16() != nil: 329 addrType = ipv6Address 330 addrBody = []byte(addr.IP.To16()) 331 addrPort = uint16(addr.Port) 332 333 default: 334 return fmt.Errorf("Failed to format address: %v", addr) 335 } 336 337 // Format the message 338 msg := make([]byte, 6+len(addrBody)) 339 msg[0] = socks5Version 340 msg[1] = resp 341 msg[2] = 0 // Reserved 342 msg[3] = addrType 343 copy(msg[4:], addrBody) 344 msg[4+len(addrBody)] = byte(addrPort >> 8) 345 msg[4+len(addrBody)+1] = byte(addrPort & 0xff) 346 347 // Send the message 348 _, err := w.Write(msg) 349 return err 350} 351 352type closeWriter interface { 353 CloseWrite() error 354} 355 356// proxy is used to suffle data from src to destination, and sends errors 357// down a dedicated channel 358func proxy(dst io.Writer, src io.Reader, errCh chan error) { 359 _, err := io.Copy(dst, src) 360 if tcpConn, ok := dst.(closeWriter); ok { 361 tcpConn.CloseWrite() 362 } 363 errCh <- err 364} 365