1package socks5
2
3import (
4	"fmt"
5	"io"
6)
7
// Method codes exchanged during SOCKS5 method selection, followed by the
// byte values used in the username/password sub-negotiation.
const (
	NoAuth       uint8 = 0   // "No Authentication" method code
	UserPassAuth uint8 = 2   // username/password method code
	noAcceptable uint8 = 255 // reply code when no offered method is usable

	userAuthVersion uint8 = 1 // version byte expected in the user/pass request
	authSuccess     uint8 = 0 // user/pass status byte: accepted
	authFailure     uint8 = 1 // user/pass status byte: rejected
)
16
var (
	// UserAuthFailed is returned by UserPassAuthenticator.Authenticate when
	// its CredentialStore rejects the supplied username/password pair.
	UserAuthFailed = fmt.Errorf("User authentication failed")

	// NoSupportedAuth is returned when none of the methods offered by the
	// client has a registered Authenticator on the server.
	NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
)
21
// AuthContext encapsulates the authentication state established during
// SOCKS5 method negotiation. An Authenticator returns one on success.
type AuthContext struct {
	// Method is the code of the auth method that was used
	// (for example NoAuth or UserPassAuth).
	Method uint8

	// Payload holds data gathered during negotiation; its keys depend on
	// the auth method used. NoAuthAuthenticator leaves it nil, and
	// UserPassAuthenticator sets the "Username" key.
	Payload map[string]string
}
32
33type Authenticator interface {
34	Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
35	GetCode() uint8
36}
37
38// NoAuthAuthenticator is used to handle the "No Authentication" mode
39type NoAuthAuthenticator struct{}
40
41func (a NoAuthAuthenticator) GetCode() uint8 {
42	return NoAuth
43}
44
45func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
46	_, err := writer.Write([]byte{socks5Version, NoAuth})
47	return &AuthContext{NoAuth, nil}, err
48}
49
50// UserPassAuthenticator is used to handle username/password based
51// authentication
52type UserPassAuthenticator struct {
53	Credentials CredentialStore
54}
55
56func (a UserPassAuthenticator) GetCode() uint8 {
57	return UserPassAuth
58}
59
60func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
61	// Tell the client to use user/pass auth
62	if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
63		return nil, err
64	}
65
66	// Get the version and username length
67	header := []byte{0, 0}
68	if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
69		return nil, err
70	}
71
72	// Ensure we are compatible
73	if header[0] != userAuthVersion {
74		return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
75	}
76
77	// Get the user name
78	userLen := int(header[1])
79	user := make([]byte, userLen)
80	if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
81		return nil, err
82	}
83
84	// Get the password length
85	if _, err := reader.Read(header[:1]); err != nil {
86		return nil, err
87	}
88
89	// Get the password
90	passLen := int(header[0])
91	pass := make([]byte, passLen)
92	if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
93		return nil, err
94	}
95
96	// Verify the password
97	if a.Credentials.Valid(string(user), string(pass)) {
98		if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
99			return nil, err
100		}
101	} else {
102		if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
103			return nil, err
104		}
105		return nil, UserAuthFailed
106	}
107
108	// Done
109	return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
110}
111
112// authenticate is used to handle connection authentication
113func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
114	// Get the methods
115	methods, err := readMethods(bufConn)
116	if err != nil {
117		return nil, fmt.Errorf("Failed to get auth methods: %v", err)
118	}
119
120	// Select a usable method
121	for _, method := range methods {
122		cator, found := s.authMethods[method]
123		if found {
124			return cator.Authenticate(bufConn, conn)
125		}
126	}
127
128	// No usable method found
129	return nil, noAcceptableAuth(conn)
130}
131
132// noAcceptableAuth is used to handle when we have no eligible
133// authentication mechanism
134func noAcceptableAuth(conn io.Writer) error {
135	conn.Write([]byte{socks5Version, noAcceptable})
136	return NoSupportedAuth
137}
138
// readMethods reads a method count byte from r, followed by that many
// one-byte auth method codes, and returns the codes.
//
// Both reads use io.ReadFull. A bare Read may return (0, nil), which would
// previously have been misread as "zero methods offered". On any error,
// including a stream that ends early, it returns a nil slice and the error.
func readMethods(r io.Reader) ([]byte, error) {
	var count [1]byte
	if _, err := io.ReadFull(r, count[:]); err != nil {
		return nil, err
	}

	methods := make([]byte, int(count[0]))
	if _, err := io.ReadFull(r, methods); err != nil {
		return nil, err
	}
	return methods, nil
}
152