1package ntlmssp
2
3import (
4	"bytes"
5	"encoding/base64"
6	"io"
7	"io/ioutil"
8	"net/http"
9	"strings"
10)
11
// GetDomain splits a "DOMAIN\user" string into its user and domain parts,
// returning them in that order (user first, then domain).
//
// Only the first backslash acts as the separator; everything after it is the
// user part. If user contains no backslash it is returned unchanged together
// with an empty domain.
func GetDomain(user string) (string, string) {
	sep := strings.Index(user, `\`)
	if sep < 0 {
		return user, ""
	}
	return user[sep+1:], user[:sep]
}
23
// Negotiator is an http.RoundTripper decorator that automatically converts
// Basic authentication into NTLM/Negotiate authentication when the server
// asks for it. A nil embedded RoundTripper means http.DefaultTransport is used.
type Negotiator struct {
	http.RoundTripper
}
27
28//RoundTrip sends the request to the server, handling any authentication
29//re-sends as needed.
30func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
31	// Use default round tripper if not provided
32	rt := l.RoundTripper
33	if rt == nil {
34		rt = http.DefaultTransport
35	}
36	// If it is not basic auth, just round trip the request as usual
37	reqauth := authheader(req.Header.Get("Authorization"))
38	if !reqauth.IsBasic() {
39		return rt.RoundTrip(req)
40	}
41	// Save request body
42	body := bytes.Buffer{}
43	if req.Body != nil {
44		_, err = body.ReadFrom(req.Body)
45		if err != nil {
46			return nil, err
47		}
48
49		req.Body.Close()
50		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
51	}
52	// first try anonymous, in case the server still finds us
53	// authenticated from previous traffic
54	req.Header.Del("Authorization")
55	res, err = rt.RoundTrip(req)
56	if err != nil {
57		return nil, err
58	}
59	if res.StatusCode != http.StatusUnauthorized {
60		return res, err
61	}
62
63	resauth := authheader(res.Header.Get("Www-Authenticate"))
64	if !resauth.IsNegotiate() && !resauth.IsNTLM() {
65		// Unauthorized, Negotiate not requested, let's try with basic auth
66		req.Header.Set("Authorization", string(reqauth))
67		io.Copy(ioutil.Discard, res.Body)
68		res.Body.Close()
69		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
70
71		res, err = rt.RoundTrip(req)
72		if err != nil {
73			return nil, err
74		}
75		if res.StatusCode != http.StatusUnauthorized {
76			return res, err
77		}
78		resauth = authheader(res.Header.Get("Www-Authenticate"))
79	}
80
81	if resauth.IsNegotiate() || resauth.IsNTLM() {
82		// 401 with request:Basic and response:Negotiate
83		io.Copy(ioutil.Discard, res.Body)
84		res.Body.Close()
85
86		// recycle credentials
87		u, p, err := reqauth.GetBasicCreds()
88		if err != nil {
89			return nil, err
90		}
91
92		// get domain from username
93		domain := ""
94		u, domain = GetDomain(u)
95
96		// send negotiate
97		negotiateMessage, err := NewNegotiateMessage(domain, "")
98		if err != nil {
99			return nil, err
100		}
101		if resauth.IsNTLM() {
102			req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
103		} else {
104			req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
105		}
106
107		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
108
109		res, err = rt.RoundTrip(req)
110		if err != nil {
111			return nil, err
112		}
113
114		// receive challenge?
115		resauth = authheader(res.Header.Get("Www-Authenticate"))
116		challengeMessage, err := resauth.GetData()
117		if err != nil {
118			return nil, err
119		}
120		if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
121			// Negotiation failed, let client deal with response
122			return res, nil
123		}
124		io.Copy(ioutil.Discard, res.Body)
125		res.Body.Close()
126
127		// send authenticate
128		authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
129		if err != nil {
130			return nil, err
131		}
132		if resauth.IsNTLM() {
133			req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
134		} else {
135			req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
136		}
137
138		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
139
140		return rt.RoundTrip(req)
141	}
142
143	return res, err
144}
145