1package socks5 2 3import ( 4 "fmt" 5 "io" 6) 7 8const ( 9 NoAuth = uint8(0) 10 noAcceptable = uint8(255) 11 UserPassAuth = uint8(2) 12 userAuthVersion = uint8(1) 13 authSuccess = uint8(0) 14 authFailure = uint8(1) 15) 16 17var ( 18 UserAuthFailed = fmt.Errorf("User authentication failed") 19 NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") 20) 21 22// A Request encapsulates authentication state provided 23// during negotiation 24type AuthContext struct { 25 // Provided auth method 26 Method uint8 27 // Payload provided during negotiation. 28 // Keys depend on the used auth method. 29 // For UserPassauth contains Username 30 Payload map[string]string 31} 32 33type Authenticator interface { 34 Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) 35 GetCode() uint8 36} 37 38// NoAuthAuthenticator is used to handle the "No Authentication" mode 39type NoAuthAuthenticator struct{} 40 41func (a NoAuthAuthenticator) GetCode() uint8 { 42 return NoAuth 43} 44 45func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { 46 _, err := writer.Write([]byte{socks5Version, NoAuth}) 47 return &AuthContext{NoAuth, nil}, err 48} 49 50// UserPassAuthenticator is used to handle username/password based 51// authentication 52type UserPassAuthenticator struct { 53 Credentials CredentialStore 54} 55 56func (a UserPassAuthenticator) GetCode() uint8 { 57 return UserPassAuth 58} 59 60func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { 61 // Tell the client to use user/pass auth 62 if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { 63 return nil, err 64 } 65 66 // Get the version and username length 67 header := []byte{0, 0} 68 if _, err := io.ReadAtLeast(reader, header, 2); err != nil { 69 return nil, err 70 } 71 72 // Ensure we are compatible 73 if header[0] != userAuthVersion { 74 return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) 75 } 76 77 // Get the user name 78 userLen := int(header[1]) 79 user := make([]byte, userLen) 80 if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { 81 return nil, err 82 } 83 84 // Get the password length 85 if _, err := reader.Read(header[:1]); err != nil { 86 return nil, err 87 } 88 89 // Get the password 90 passLen := int(header[0]) 91 pass := make([]byte, passLen) 92 if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { 93 return nil, err 94 } 95 96 // Verify the password 97 if a.Credentials.Valid(string(user), string(pass)) { 98 if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { 99 return nil, err 100 } 101 } else { 102 if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { 103 return nil, err 104 } 105 return nil, UserAuthFailed 106 } 107 108 // Done 109 return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil 110} 111 112// authenticate is used to handle connection authentication 113func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { 114 // Get the methods 115 methods, err := readMethods(bufConn) 116 if err != nil { 117 return nil, fmt.Errorf("Failed to get auth methods: %v", err) 118 } 119 120 // Select a usable method 121 for _, method := range methods { 122 cator, found := s.authMethods[method] 123 if found { 124 return cator.Authenticate(bufConn, conn) 125 } 126 } 127 128 // No usable method found 129 return nil, noAcceptableAuth(conn) 130} 131 132// noAcceptableAuth is used to handle when we have no eligible 133// authentication mechanism 134func noAcceptableAuth(conn io.Writer) error { 135 conn.Write([]byte{socks5Version, noAcceptable}) 136 return NoSupportedAuth 137} 138 139// readMethods is used to read the number of methods 140// and proceeding auth methods 141func readMethods(r io.Reader) ([]byte, error) { 142 header := []byte{0} 143 if _, err := r.Read(header); err != nil { 144 return nil, err 145 } 146 147 numMethods := int(header[0]) 148 methods := make([]byte, numMethods) 149 _, err := io.ReadAtLeast(r, methods, numMethods) 150 return methods, err 151} 152