1// mgo - MongoDB driver for Go
2//
3// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
4//
5// All rights reserved.
6//
7// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions are met:
9//
10// 1. Redistributions of source code must retain the above copyright notice, this
11//    list of conditions and the following disclaimer.
12// 2. Redistributions in binary form must reproduce the above copyright notice,
13//    this list of conditions and the following disclaimer in the documentation
14//    and/or other materials provided with the distribution.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27package mgo
28
29import (
30	"crypto/md5"
31	"crypto/sha1"
32	"encoding/hex"
33	"errors"
34	"fmt"
35	"sync"
36
37	"gopkg.in/mgo.v2/bson"
38	"gopkg.in/mgo.v2/internal/scram"
39)
40
// authCmd is the command document sent by loginClassic. It is filled with
// Authenticate set to 1, the user name, the nonce obtained from getNonce,
// and the hex-encoded MD5 key derived from the nonce and credentials.
type authCmd struct {
	// NOTE(review): presumably must stay the first field so that it is the
	// first key of the marshalled command document — confirm before reordering.
	Authenticate int

	Nonce string
	User  string
	Key   string
}
48
// startSaslCmd is a document holding a single integer field marshalled
// under the "startSasl" key.
type startSaslCmd struct {
	StartSASL int `bson:"startSasl"`
}
52
// authResult is the reply document decoded by loginClassic, loginX509 and
// loginPlain. When Ok is false those functions return ErrMsg as the error.
type authResult struct {
	ErrMsg string
	Ok     bool
}
57
// getNonceCmd is the command document resetNonce sends (with GetNonce set
// to 1) to the admin.$cmd collection to request a fresh nonce.
type getNonceCmd struct {
	GetNonce int
}
61
// getNonceResult is the reply document decoded by resetNonce.
//
// Nonce carries the nonce handed out by the server. Err is read from the
// "$err" key of the reply and Code from its "code" key; resetNonce uses them
// to build an error message when Nonce is empty, and treats Code 13390 as
// "mongos without auth support".
type getNonceResult struct {
	Nonce string
	// The tag uses the conventional key:"value" form so that it is visible
	// through reflect.StructTag.Get and consistent with the other tags in
	// this file; the key it maps to is unchanged.
	Err  string `bson:"$err"`
	Code int
}
67
// logoutCmd is the command document flushLogout queues (with Logout set
// to 1) against the $cmd collection of each database pending logout.
type logoutCmd struct {
	Logout int
}
71
// saslCmd is the command document used for SASL exchanges by loginSASL and
// loginPlain. loginSASL sets Start to 1 on the first command and Continue
// to 1 on every later one; fields tagged omitempty are left out of the
// document when they hold their zero value.
type saslCmd struct {
	Start          int    `bson:"saslStart,omitempty"`
	Continue       int    `bson:"saslContinue,omitempty"`
	ConversationId int    `bson:"conversationId,omitempty"`
	Mechanism      string `bson:"mechanism,omitempty"`
	Payload        []byte
}
79
// saslResult is the reply document decoded after each SASL command sent by
// loginSASL. A step is treated as failed when Ok is false or NotOk is true.
// ConversationId and Payload are fed back into the next step, and Done
// reports that the server considers the conversation finished.
type saslResult struct {
	Ok    bool `bson:"ok"`
	NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?)
	Done  bool

	ConversationId int `bson:"conversationId"`
	Payload        []byte
	ErrMsg         string
}
89
// saslStepper is the client side of a SASL conversation as driven by
// loginSASL.
type saslStepper interface {
	// Step consumes the latest payload received from the server and returns
	// the payload to send next, whether the client side is finished, and
	// any error that aborts the conversation.
	Step(serverData []byte) (clientData []byte, done bool, err error)
	// Close is called (deferred) by loginSASL once the conversation ends.
	Close()
}
94
95func (socket *mongoSocket) getNonce() (nonce string, err error) {
96	socket.Lock()
97	for socket.cachedNonce == "" && socket.dead == nil {
98		debugf("Socket %p to %s: waiting for nonce", socket, socket.addr)
99		socket.gotNonce.Wait()
100	}
101	if socket.cachedNonce == "mongos" {
102		socket.Unlock()
103		return "", errors.New("Can't authenticate with mongos; see http://j.mp/mongos-auth")
104	}
105	debugf("Socket %p to %s: got nonce", socket, socket.addr)
106	nonce, err = socket.cachedNonce, socket.dead
107	socket.cachedNonce = ""
108	socket.Unlock()
109	if err != nil {
110		nonce = ""
111	}
112	return
113}
114
115func (socket *mongoSocket) resetNonce() {
116	debugf("Socket %p to %s: requesting a new nonce", socket, socket.addr)
117	op := &queryOp{}
118	op.query = &getNonceCmd{GetNonce: 1}
119	op.collection = "admin.$cmd"
120	op.limit = -1
121	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
122		if err != nil {
123			socket.kill(errors.New("getNonce: "+err.Error()), true)
124			return
125		}
126		result := &getNonceResult{}
127		err = bson.Unmarshal(docData, &result)
128		if err != nil {
129			socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true)
130			return
131		}
132		debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result)
133		if result.Code == 13390 {
134			// mongos doesn't yet support auth (see http://j.mp/mongos-auth)
135			result.Nonce = "mongos"
136		} else if result.Nonce == "" {
137			var msg string
138			if result.Err != "" {
139				msg = fmt.Sprintf("Got an empty nonce: %s (%d)", result.Err, result.Code)
140			} else {
141				msg = "Got an empty nonce"
142			}
143			socket.kill(errors.New(msg), true)
144			return
145		}
146		socket.Lock()
147		if socket.cachedNonce != "" {
148			socket.Unlock()
149			panic("resetNonce: nonce already cached")
150		}
151		socket.cachedNonce = result.Nonce
152		socket.gotNonce.Signal()
153		socket.Unlock()
154	}
155	err := socket.Query(op)
156	if err != nil {
157		socket.kill(errors.New("resetNonce: "+err.Error()), true)
158	}
159}
160
161func (socket *mongoSocket) Login(cred Credential) error {
162	socket.Lock()
163	if cred.Mechanism == "" && socket.serverInfo.MaxWireVersion >= 3 {
164		cred.Mechanism = "SCRAM-SHA-1"
165	}
166	for _, sockCred := range socket.creds {
167		if sockCred == cred {
168			debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username)
169			socket.Unlock()
170			return nil
171		}
172	}
173	if socket.dropLogout(cred) {
174		debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username)
175		socket.creds = append(socket.creds, cred)
176		socket.Unlock()
177		return nil
178	}
179	socket.Unlock()
180
181	debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username)
182
183	var err error
184	switch cred.Mechanism {
185	case "", "MONGODB-CR", "MONGO-CR": // Name changed to MONGODB-CR in SERVER-8501.
186		err = socket.loginClassic(cred)
187	case "PLAIN":
188		err = socket.loginPlain(cred)
189	case "MONGODB-X509":
190		err = socket.loginX509(cred)
191	default:
192		// Try SASL for everything else, if it is available.
193		err = socket.loginSASL(cred)
194	}
195
196	if err != nil {
197		debugf("Socket %p to %s: login error: %s", socket, socket.addr, err)
198	} else {
199		debugf("Socket %p to %s: login successful", socket, socket.addr)
200	}
201	return err
202}
203
204func (socket *mongoSocket) loginClassic(cred Credential) error {
205	// Note that this only works properly because this function is
206	// synchronous, which means the nonce won't get reset while we're
207	// using it and any other login requests will block waiting for a
208	// new nonce provided in the defer call below.
209	nonce, err := socket.getNonce()
210	if err != nil {
211		return err
212	}
213	defer socket.resetNonce()
214
215	psum := md5.New()
216	psum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
217
218	ksum := md5.New()
219	ksum.Write([]byte(nonce + cred.Username))
220	ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil))))
221
222	key := hex.EncodeToString(ksum.Sum(nil))
223
224	cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key}
225	res := authResult{}
226	return socket.loginRun(cred.Source, &cmd, &res, func() error {
227		if !res.Ok {
228			return errors.New(res.ErrMsg)
229		}
230		socket.Lock()
231		socket.dropAuth(cred.Source)
232		socket.creds = append(socket.creds, cred)
233		socket.Unlock()
234		return nil
235	})
236}
237
// authX509Cmd is the command document sent by loginX509, which fills it
// with Authenticate set to 1, the user name, and the "MONGODB-X509"
// mechanism.
type authX509Cmd struct {
	Authenticate int
	User         string
	Mechanism    string
}
243
244func (socket *mongoSocket) loginX509(cred Credential) error {
245	cmd := authX509Cmd{Authenticate: 1, User: cred.Username, Mechanism: "MONGODB-X509"}
246	res := authResult{}
247	return socket.loginRun(cred.Source, &cmd, &res, func() error {
248		if !res.Ok {
249			return errors.New(res.ErrMsg)
250		}
251		socket.Lock()
252		socket.dropAuth(cred.Source)
253		socket.creds = append(socket.creds, cred)
254		socket.Unlock()
255		return nil
256	})
257}
258
259func (socket *mongoSocket) loginPlain(cred Credential) error {
260	cmd := saslCmd{Start: 1, Mechanism: "PLAIN", Payload: []byte("\x00" + cred.Username + "\x00" + cred.Password)}
261	res := authResult{}
262	return socket.loginRun(cred.Source, &cmd, &res, func() error {
263		if !res.Ok {
264			return errors.New(res.ErrMsg)
265		}
266		socket.Lock()
267		socket.dropAuth(cred.Source)
268		socket.creds = append(socket.creds, cred)
269		socket.Unlock()
270		return nil
271	})
272}
273
// loginSASL authenticates cred through a multi-step SASL conversation.
//
// For the "SCRAM-SHA-1" mechanism the stepper comes from saslNewScram;
// for anything else it comes from saslNew, using cred.ServiceHost when set
// and the address of the socket's server otherwise. Each round feeds the
// last server payload to sasl.Step and sends the resulting client payload
// through loginRun, as a saslStart command the first time and as a
// saslContinue afterwards. Once both the local stepper and the server
// report completion, cred replaces any credential held for cred.Source.
//
// It returns the first error from creating the stepper, from a local step,
// from running the command, or from a server reply flagged as failed.
func (socket *mongoSocket) loginSASL(cred Credential) error {
	var sasl saslStepper
	var err error
	if cred.Mechanism == "SCRAM-SHA-1" {
		// SCRAM is handled without external libraries.
		sasl = saslNewScram(cred)
	} else if len(cred.ServiceHost) > 0 {
		sasl, err = saslNew(cred, cred.ServiceHost)
	} else {
		sasl, err = saslNew(cred, socket.Server().Addr)
	}
	if err != nil {
		return err
	}
	defer sasl.Close()

	// The goal of this logic is to carry a locked socket until the
	// local SASL step confirms the auth is valid; the socket needs to be
	// locked so that concurrent action doesn't leave the socket in an
	// auth state that doesn't reflect the operations that took place.
	// As a simple case, imagine inverting login=>logout to logout=>login.
	//
	// The logic below works because the lock func isn't called concurrently.
	locked := false
	// lock(b) brings the socket mutex to the requested state and is a no-op
	// when it is already there, so it is safe to call repeatedly.
	lock := func(b bool) {
		if locked != b {
			locked = b
			if b {
				socket.Lock()
			} else {
				socket.Unlock()
			}
		}
	}

	lock(true)
	defer lock(false)

	// start is 1 only for the first command, which makes it a saslStart;
	// it is zeroed afterwards so later commands are saslContinue.
	start := 1
	cmd := saslCmd{}
	res := saslResult{}
	for {
		payload, done, err := sasl.Step(res.Payload)
		if err != nil {
			return err
		}
		// Both sides are finished: record the credential while the socket
		// is still locked.
		if done && res.Done {
			socket.dropAuth(cred.Source)
			socket.creds = append(socket.creds, cred)
			break
		}
		// Release the socket for the duration of the round trip.
		lock(false)

		cmd = saslCmd{
			Start:          start,
			Continue:       1 - start,
			ConversationId: res.ConversationId,
			Mechanism:      cred.Mechanism,
			Payload:        payload,
		}
		start = 0
		err = socket.loginRun(cred.Source, &cmd, &res, func() error {
			// See the comment on lock for why this is necessary.
			lock(true)
			if !res.Ok || res.NotOk {
				return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg)
			}
			return nil
		})
		if err != nil {
			return err
		}
		// done is still the value returned by the Step call above, i.e.
		// from before this round trip; res.Done comes from the new reply.
		if done && res.Done {
			socket.dropAuth(cred.Source)
			socket.creds = append(socket.creds, cred)
			break
		}
	}

	return nil
}
355
356func saslNewScram(cred Credential) *saslScram {
357	credsum := md5.New()
358	credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
359	client := scram.NewClient(sha1.New, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
360	return &saslScram{cred: cred, client: client}
361}
362
// saslScram provides the Step and Close methods of saslStepper on top of a
// scram.Client. It is created by saslNewScram, which also stores the
// credential the client was built from.
type saslScram struct {
	cred   Credential
	client *scram.Client
}
367
// Close does nothing; it exists so that saslScram has the Close method
// required by saslStepper.
func (s *saslScram) Close() {}
369
370func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) {
371	more := s.client.Step(serverData)
372	return s.client.Out(), !more, s.client.Err()
373}
374
375func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error {
376	var mutex sync.Mutex
377	var replyErr error
378	mutex.Lock()
379
380	op := queryOp{}
381	op.query = query
382	op.collection = db + ".$cmd"
383	op.limit = -1
384	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
385		defer mutex.Unlock()
386
387		if err != nil {
388			replyErr = err
389			return
390		}
391
392		err = bson.Unmarshal(docData, result)
393		if err != nil {
394			replyErr = err
395		} else {
396			// Must handle this within the read loop for the socket, so
397			// that concurrent login requests are properly ordered.
398			replyErr = f()
399		}
400	}
401
402	err := socket.Query(&op)
403	if err != nil {
404		return err
405	}
406	mutex.Lock() // Wait.
407	return replyErr
408}
409
410func (socket *mongoSocket) Logout(db string) {
411	socket.Lock()
412	cred, found := socket.dropAuth(db)
413	if found {
414		debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db)
415		socket.logout = append(socket.logout, cred)
416	}
417	socket.Unlock()
418}
419
420func (socket *mongoSocket) LogoutAll() {
421	socket.Lock()
422	if l := len(socket.creds); l > 0 {
423		debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l)
424		socket.logout = append(socket.logout, socket.creds...)
425		socket.creds = socket.creds[0:0]
426	}
427	socket.Unlock()
428}
429
430func (socket *mongoSocket) flushLogout() (ops []interface{}) {
431	socket.Lock()
432	if l := len(socket.logout); l > 0 {
433		debugf("Socket %p to %s: logout all (flushing %d)", socket, socket.addr, l)
434		for i := 0; i != l; i++ {
435			op := queryOp{}
436			op.query = &logoutCmd{1}
437			op.collection = socket.logout[i].Source + ".$cmd"
438			op.limit = -1
439			ops = append(ops, &op)
440		}
441		socket.logout = socket.logout[0:0]
442	}
443	socket.Unlock()
444	return
445}
446
447func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) {
448	for i, sockCred := range socket.creds {
449		if sockCred.Source == db {
450			copy(socket.creds[i:], socket.creds[i+1:])
451			socket.creds = socket.creds[:len(socket.creds)-1]
452			return sockCred, true
453		}
454	}
455	return cred, false
456}
457
458func (socket *mongoSocket) dropLogout(cred Credential) (found bool) {
459	for i, sockCred := range socket.logout {
460		if sockCred == cred {
461			copy(socket.logout[i:], socket.logout[i+1:])
462			socket.logout = socket.logout[:len(socket.logout)-1]
463			return true
464		}
465	}
466	return false
467}
468