1// mgo - MongoDB driver for Go
2//
3// Copyright (c) 2014 - Gustavo Niemeyer <gustavo@niemeyer.net>
4//
5// All rights reserved.
6//
7// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions are met:
9//
10// 1. Redistributions of source code must retain the above copyright notice, this
11//    list of conditions and the following disclaimer.
12// 2. Redistributions in binary form must reproduce the above copyright notice,
13//    this list of conditions and the following disclaimer in the documentation
14//    and/or other materials provided with the distribution.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
// Package scram implements a SCRAM-{SHA-1,etc} client per RFC 5802.
28//
29// http://tools.ietf.org/html/rfc5802
30//
31package scram
32
33import (
34	"bytes"
35	"crypto/hmac"
36	"crypto/rand"
37	"encoding/base64"
38	"fmt"
39	"hash"
40	"strconv"
41	"strings"
42)
43
// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
//
// A Client may be used within a SASL conversation with logic resembling:
//
//    var in []byte
//    var client = scram.NewClient(sha1.New, user, pass)
//    for !client.Step(in) {
//            out := client.Out()
//            // send out to server
//            in = serverOut
//    }
//    if client.Err() != nil {
//            // auth failed
//    }
//
// Step returns false while the exchange is still in progress, and true once
// it has finished or failed, hence the negated loop condition above.
type Client struct {
	// newHash constructs the hash used for both H() and HMAC().
	newHash func() hash.Hash

	user string       // username; escaped before being sent in step1
	pass string       // password; HMAC key when deriving the salted password
	step int          // last protocol step run: 0 before any Step, at most 3
	out  bytes.Buffer // message staged for the server by the last Step
	err  error        // first error encountered; stops further steps

	clientNonce []byte       // client nonce: from SetNonce or generated in step1
	serverNonce []byte       // full nonce taken from the server-first-message
	saltedPass  []byte       // SaltedPassword derived in step2
	authMsg     bytes.Buffer // AuthMessage accumulated across step1 and step2
}
73
74// NewClient returns a new SCRAM-* client with the provided hash algorithm.
75//
76// For SCRAM-SHA-1, for example, use:
77//
78//    client := scram.NewClient(sha1.New, user, pass)
79//
80func NewClient(newHash func() hash.Hash, user, pass string) *Client {
81	c := &Client{
82		newHash: newHash,
83		user:    user,
84		pass:    pass,
85	}
86	c.out.Grow(256)
87	c.authMsg.Grow(256)
88	return c
89}
90
91// Out returns the data to be sent to the server in the current step.
92func (c *Client) Out() []byte {
93	if c.out.Len() == 0 {
94		return nil
95	}
96	return c.out.Bytes()
97}
98
99// Err returns the error that ocurred, or nil if there were no errors.
100func (c *Client) Err() error {
101	return c.err
102}
103
104// SetNonce sets the client nonce to the provided value.
105// If not set, the nonce is generated automatically out of crypto/rand on the first step.
106func (c *Client) SetNonce(nonce []byte) {
107	c.clientNonce = nonce
108}
109
// escaper encodes a username for the "n=" attribute: "=" becomes "=3D" and
// "," becomes "=2C", so the name cannot break the comma-separated message.
var escaper = strings.NewReplacer(
	",", "=2C",
	"=", "=3D",
)
111
112// Step processes the incoming data from the server and makes the
113// next round of data for the server available via Client.Out.
114// Step returns false if there are no errors and more data is
115// still expected.
116func (c *Client) Step(in []byte) bool {
117	c.out.Reset()
118	if c.step > 2 || c.err != nil {
119		return false
120	}
121	c.step++
122	switch c.step {
123	case 1:
124		c.err = c.step1(in)
125	case 2:
126		c.err = c.step2(in)
127	case 3:
128		c.err = c.step3(in)
129	}
130	return c.step > 2 || c.err != nil
131}
132
133func (c *Client) step1(in []byte) error {
134	if len(c.clientNonce) == 0 {
135		const nonceLen = 6
136		buf := make([]byte, nonceLen + b64.EncodedLen(nonceLen))
137		if _, err := rand.Read(buf[:nonceLen]); err != nil {
138			return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err)
139		}
140		c.clientNonce = buf[nonceLen:]
141		b64.Encode(c.clientNonce, buf[:nonceLen])
142	}
143	c.authMsg.WriteString("n=")
144	escaper.WriteString(&c.authMsg, c.user)
145	c.authMsg.WriteString(",r=")
146	c.authMsg.Write(c.clientNonce)
147
148	c.out.WriteString("n,,")
149	c.out.Write(c.authMsg.Bytes())
150	return nil
151}
152
// b64 is the padded standard base64 encoding used here for the client
// nonce, the server salt, the client proof and the server signature.
var (
	b64 = base64.StdEncoding
)
154
155func (c *Client) step2(in []byte) error {
156	c.authMsg.WriteByte(',')
157	c.authMsg.Write(in)
158
159	fields := bytes.Split(in, []byte(","))
160	if len(fields) != 3 {
161		return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in)
162	}
163	if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
164		return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0])
165	}
166	if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
167		return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1])
168	}
169	if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
170		return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
171	}
172
173	c.serverNonce = fields[0][2:]
174	if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
175		return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
176	}
177
178	salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
179	n, err := b64.Decode(salt, fields[1][2:])
180	if err != nil {
181		return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1])
182	}
183	salt = salt[:n]
184	iterCount, err := strconv.Atoi(string(fields[2][2:]))
185	if err != nil {
186		return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
187	}
188	c.saltPassword(salt, iterCount)
189
190	c.authMsg.WriteString(",c=biws,r=")
191	c.authMsg.Write(c.serverNonce)
192
193	c.out.WriteString("c=biws,r=")
194	c.out.Write(c.serverNonce)
195	c.out.WriteString(",p=")
196	c.out.Write(c.clientProof())
197	return nil
198}
199
200func (c *Client) step3(in []byte) error {
201	var isv, ise bool
202	var fields = bytes.Split(in, []byte(","))
203	if len(fields) == 1 {
204		isv = bytes.HasPrefix(fields[0], []byte("v="))
205		ise = bytes.HasPrefix(fields[0], []byte("e="))
206	}
207	if ise {
208		return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:])
209	} else if !isv {
210		return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in)
211	}
212	if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
213		return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:])
214	}
215	return nil
216}
217
218func (c *Client) saltPassword(salt []byte, iterCount int) {
219	mac := hmac.New(c.newHash, []byte(c.pass))
220	mac.Write(salt)
221	mac.Write([]byte{0, 0, 0, 1})
222	ui := mac.Sum(nil)
223	hi := make([]byte, len(ui))
224	copy(hi, ui)
225	for i := 1; i < iterCount; i++ {
226		mac.Reset()
227		mac.Write(ui)
228		mac.Sum(ui[:0])
229		for j, b := range ui {
230			hi[j] ^= b
231		}
232	}
233	c.saltedPass = hi
234}
235
236func (c *Client) clientProof() []byte {
237	mac := hmac.New(c.newHash, c.saltedPass)
238	mac.Write([]byte("Client Key"))
239	clientKey := mac.Sum(nil)
240	hash := c.newHash()
241	hash.Write(clientKey)
242	storedKey := hash.Sum(nil)
243	mac = hmac.New(c.newHash, storedKey)
244	mac.Write(c.authMsg.Bytes())
245	clientProof := mac.Sum(nil)
246	for i, b := range clientKey {
247		clientProof[i] ^= b
248	}
249	clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
250	b64.Encode(clientProof64, clientProof)
251	return clientProof64
252}
253
254func (c *Client) serverSignature() []byte {
255	mac := hmac.New(c.newHash, c.saltedPass)
256	mac.Write([]byte("Server Key"))
257	serverKey := mac.Sum(nil)
258
259	mac = hmac.New(c.newHash, serverKey)
260	mac.Write(c.authMsg.Bytes())
261	serverSignature := mac.Sum(nil)
262
263	encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
264	b64.Encode(encoded, serverSignature)
265	return encoded
266}
267