1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package sockstest provides utilities for SOCKS testing. 6package sockstest 7 8import ( 9 "errors" 10 "io" 11 "net" 12 13 "golang.org/x/net/internal/socks" 14 "golang.org/x/net/nettest" 15) 16 17// An AuthRequest represents an authentication request. 18type AuthRequest struct { 19 Version int 20 Methods []socks.AuthMethod 21} 22 23// ParseAuthRequest parses an authentication request. 24func ParseAuthRequest(b []byte) (*AuthRequest, error) { 25 if len(b) < 2 { 26 return nil, errors.New("short auth request") 27 } 28 if b[0] != socks.Version5 { 29 return nil, errors.New("unexpected protocol version") 30 } 31 if len(b)-2 < int(b[1]) { 32 return nil, errors.New("short auth request") 33 } 34 req := &AuthRequest{Version: int(b[0])} 35 if b[1] > 0 { 36 req.Methods = make([]socks.AuthMethod, b[1]) 37 for i, m := range b[2 : 2+b[1]] { 38 req.Methods[i] = socks.AuthMethod(m) 39 } 40 } 41 return req, nil 42} 43 44// MarshalAuthReply returns an authentication reply in wire format. 45func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) { 46 return []byte{byte(ver), byte(m)}, nil 47} 48 49// A CmdRequest repesents a command request. 50type CmdRequest struct { 51 Version int 52 Cmd socks.Command 53 Addr socks.Addr 54} 55 56// ParseCmdRequest parses a command request. 57func ParseCmdRequest(b []byte) (*CmdRequest, error) { 58 if len(b) < 7 { 59 return nil, errors.New("short cmd request") 60 } 61 if b[0] != socks.Version5 { 62 return nil, errors.New("unexpected protocol version") 63 } 64 if socks.Command(b[1]) != socks.CmdConnect { 65 return nil, errors.New("unexpected command") 66 } 67 if b[2] != 0 { 68 return nil, errors.New("non-zero reserved field") 69 } 70 req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])} 71 l := 2 72 off := 4 73 switch b[3] { 74 case socks.AddrTypeIPv4: 75 l += net.IPv4len 76 req.Addr.IP = make(net.IP, net.IPv4len) 77 case socks.AddrTypeIPv6: 78 l += net.IPv6len 79 req.Addr.IP = make(net.IP, net.IPv6len) 80 case socks.AddrTypeFQDN: 81 l += int(b[4]) 82 off = 5 83 default: 84 return nil, errors.New("unknown address type") 85 } 86 if len(b[off:]) < l { 87 return nil, errors.New("short cmd request") 88 } 89 if req.Addr.IP != nil { 90 copy(req.Addr.IP, b[off:]) 91 } else { 92 req.Addr.Name = string(b[off : off+l-2]) 93 } 94 req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1]) 95 return req, nil 96} 97 98// MarshalCmdReply returns a command reply in wire format. 99func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) { 100 b := make([]byte, 4) 101 b[0] = byte(ver) 102 b[1] = byte(reply) 103 if a.Name != "" { 104 if len(a.Name) > 255 { 105 return nil, errors.New("fqdn too long") 106 } 107 b[3] = socks.AddrTypeFQDN 108 b = append(b, byte(len(a.Name))) 109 b = append(b, a.Name...) 110 } else if ip4 := a.IP.To4(); ip4 != nil { 111 b[3] = socks.AddrTypeIPv4 112 b = append(b, ip4...) 113 } else if ip6 := a.IP.To16(); ip6 != nil { 114 b[3] = socks.AddrTypeIPv6 115 b = append(b, ip6...) 116 } else { 117 return nil, errors.New("unknown address type") 118 } 119 b = append(b, byte(a.Port>>8), byte(a.Port)) 120 return b, nil 121} 122 123// A Server repesents a server for handshake testing. 124type Server struct { 125 ln net.Listener 126} 127 128// Addr rerurns a server address. 129func (s *Server) Addr() net.Addr { 130 return s.ln.Addr() 131} 132 133// TargetAddr returns a fake final destination address. 134// 135// The returned address is only valid for testing with Server. 136func (s *Server) TargetAddr() net.Addr { 137 a := s.ln.Addr() 138 switch a := a.(type) { 139 case *net.TCPAddr: 140 if a.IP.To4() != nil { 141 return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963} 142 } 143 if a.IP.To16() != nil && a.IP.To4() == nil { 144 return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963} 145 } 146 } 147 return nil 148} 149 150// Close closes the server. 151func (s *Server) Close() error { 152 return s.ln.Close() 153} 154 155func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) { 156 c, err := s.ln.Accept() 157 if err != nil { 158 return 159 } 160 defer c.Close() 161 go s.serve(authFunc, cmdFunc) 162 b := make([]byte, 512) 163 n, err := c.Read(b) 164 if err != nil { 165 return 166 } 167 if err := authFunc(c, b[:n]); err != nil { 168 return 169 } 170 n, err = c.Read(b) 171 if err != nil { 172 return 173 } 174 if err := cmdFunc(c, b[:n]); err != nil { 175 return 176 } 177} 178 179// NewServer returns a new server. 180// 181// The provided authFunc and cmdFunc must parse requests and return 182// appropriate replies to clients. 183func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) { 184 var err error 185 s := new(Server) 186 s.ln, err = nettest.NewLocalListener("tcp") 187 if err != nil { 188 return nil, err 189 } 190 go s.serve(authFunc, cmdFunc) 191 return s, nil 192} 193 194// NoAuthRequired handles a no-authentication-required signaling. 195func NoAuthRequired(rw io.ReadWriter, b []byte) error { 196 req, err := ParseAuthRequest(b) 197 if err != nil { 198 return err 199 } 200 b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired) 201 if err != nil { 202 return err 203 } 204 n, err := rw.Write(b) 205 if err != nil { 206 return err 207 } 208 if n != len(b) { 209 return errors.New("short write") 210 } 211 return nil 212} 213 214// NoProxyRequired handles a command signaling without constructing a 215// proxy connection to the final destination. 216func NoProxyRequired(rw io.ReadWriter, b []byte) error { 217 req, err := ParseCmdRequest(b) 218 if err != nil { 219 return err 220 } 221 req.Addr.Port += 1 222 if req.Addr.Name != "" { 223 req.Addr.Name = "boundaddr.doesnotexist" 224 } else if req.Addr.IP.To4() != nil { 225 req.Addr.IP = net.IPv4(127, 0, 0, 1) 226 } else { 227 req.Addr.IP = net.IPv6loopback 228 } 229 b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr) 230 if err != nil { 231 return err 232 } 233 n, err := rw.Write(b) 234 if err != nil { 235 return err 236 } 237 if n != len(b) { 238 return errors.New("short write") 239 } 240 return nil 241} 242