1// mgo - MongoDB driver for Go
2//
3// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
4//
5// All rights reserved.
6//
7// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions are met:
9//
10// 1. Redistributions of source code must retain the above copyright notice, this
11//    list of conditions and the following disclaimer.
12// 2. Redistributions in binary form must reproduce the above copyright notice,
13//    this list of conditions and the following disclaimer in the documentation
14//    and/or other materials provided with the distribution.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27package mgo
28
29import (
30	"errors"
31	"fmt"
32	"net"
33	"sync"
34	"time"
35
36	"gopkg.in/mgo.v2/bson"
37)
38
// replyFunc is the callback through which Query, readLoop and kill report
// the outcome of a request that expects a reply:
//   - err != nil, reply == nil, docNum == -1: the request failed or the
//     socket died before or while the reply was being read;
//   - err == nil, docNum == -1, docData == nil: the reply held no documents;
//   - err == nil, docNum >= 0: called once per returned document, with
//     docNum its index within the reply and docData its raw BSON bytes.
// A single request may therefore see several calls (e.g. one per document,
// followed by an error call if reading fails midway).
type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)

// mongoSocket is one connection to a MongoDB server, shared by reference
// counting (InitialAcquire/Acquire/Release). The embedded Mutex guards the
// mutable fields below; conn is only assigned in newSocket, which is why
// readLoop reads it without locking.
type mongoSocket struct {
	sync.Mutex
	// Also cleared by kill.
	server        *mongoServer // nil when cached
	conn          net.Conn
	// Deadline length applied by updateDeadline; <= 0 means no deadline.
	timeout       time.Duration
	addr          string // For debugging only.
	// Request-id counter; Query allocates ids above it, never handing out 0.
	nextRequestId uint32
	// Callbacks of requests awaiting a reply, keyed by request id.
	replyFuncs    map[uint32]replyFunc
	// Outstanding acquisitions; Release recycles the socket when it hits 0.
	references    int
	// Credential bookkeeping; presumably the credentials logged in on this
	// socket and those pending logout (managed outside this chunk) - confirm.
	creds         []Credential
	logout        []Credential
	// Nonce state. gotNonce uses the socket's Mutex as its Locker (see
	// newSocket) and is broadcast by kill; the nonce logic lives elsewhere.
	cachedNonce   string
	gotNonce      sync.Cond
	// Non-nil once the socket has been killed: the error it was killed with.
	dead          error
	// Server details recorded by InitialAcquire.
	serverInfo    *mongoServerInfo
}
57
// queryOpFlags holds the OP_QUERY flag bits that Query writes right after
// the message header.
type queryOpFlags uint32

// OP_QUERY flags. The blank identifier consumes bit 0 (reserved in the wire
// format), so flagTailable == 1<<1, flagSlaveOk == 1<<2, flagLogReplay ==
// 1<<3, flagNoCursorTimeout == 1<<4 and flagAwaitData == 1<<5.
// flagSlaveOk is also consulted by finalQuery when talking to a mongos.
const (
	_ queryOpFlags = 1 << iota
	flagTailable
	flagSlaveOk
	flagLogReplay
	flagNoCursorTimeout
	flagAwaitData
)
68
// queryOp describes an OP_QUERY message for Query to serialize.
type queryOp struct {
	// Namespace written as a C string; presumably "database.collection"
	// like insertOp's - confirm with callers.
	collection string
	// User query document. When nil an empty document is sent.
	query      interface{}
	// numberToSkip / numberToReturn, written verbatim.
	skip       int32
	limit      int32
	// Optional field selector; only sent when non-nil.
	selector   interface{}
	flags      queryOpFlags
	// Receives the reply (see replyFunc).
	replyFunc  replyFunc

	// Read mode; only used by finalQuery for slave-ok queries on a mongos.
	mode       Mode
	// Wrapped form sent instead of query when hasOptions is set;
	// finalQuery fills in options.Query.
	options    queryWrapper
	hasOptions bool
	// Tag sets forwarded by finalQuery as $readPreference.tags.
	serverTags []bson.D
}

// queryWrapper is the wrapped query form ({$query: ..., $orderby: ..., ...})
// sent when a query carries options. The bare struct tags give the BSON key
// of each field, optionally followed by flags such as omitempty.
type queryWrapper struct {
	Query          interface{} "$query"
	OrderBy        interface{} "$orderby,omitempty"
	Hint           interface{} "$hint,omitempty"
	Explain        bool        "$explain,omitempty"
	Snapshot       bool        "$snapshot,omitempty"
	ReadPreference bson.D      "$readPreference,omitempty"
	MaxScan        int         "$maxScan,omitempty"
	MaxTimeMS      int         "$maxTimeMS,omitempty"
	Comment        string      "$comment,omitempty"
}
95
96func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {
97	if op.flags&flagSlaveOk != 0 && socket.ServerInfo().Mongos {
98		var modeName string
99		switch op.mode {
100		case Strong:
101			modeName = "primary"
102		case Monotonic, Eventual:
103			modeName = "secondaryPreferred"
104		case PrimaryPreferred:
105			modeName = "primaryPreferred"
106		case Secondary:
107			modeName = "secondary"
108		case SecondaryPreferred:
109			modeName = "secondaryPreferred"
110		case Nearest:
111			modeName = "nearest"
112		default:
113			panic(fmt.Sprintf("unsupported read mode: %d", op.mode))
114		}
115		op.hasOptions = true
116		op.options.ReadPreference = make(bson.D, 0, 2)
117		op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"mode", modeName})
118		if len(op.serverTags) > 0 {
119			op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"tags", op.serverTags})
120		}
121	}
122	if op.hasOptions {
123		if op.query == nil {
124			var empty bson.D
125			op.options.Query = empty
126		} else {
127			op.options.Query = op.query
128		}
129		debugf("final query is %#v\n", &op.options)
130		return &op.options
131	}
132	return op.query
133}
134
// getMoreOp describes an OP_GET_MORE message requesting the next batch of
// an open cursor. Query writes collection, limit and cursorId, and
// registers replyFunc for the reply.
type getMoreOp struct {
	collection string
	limit      int32
	cursorId   int64
	replyFunc  replyFunc
}

// replyOp holds the fixed OP_REPLY fields that readLoop decodes from bytes
// 16-35 of each reply: the response flags word, the cursor id reported by
// the server, the starting position of this batch (firstDoc) and the
// number of documents that follow (replyDocs).
type replyOp struct {
	flags     uint32
	cursorId  int64
	firstDoc  int32
	replyDocs int32
}
148
// insertOp describes an OP_INSERT message. Query writes flags, then the
// collection name, then every document in order.
type insertOp struct {
	collection string        // "database.collection"
	documents  []interface{} // One or more documents to insert
	flags      uint32
}

// updateOp describes an OP_UPDATE message. Query only writes Collection,
// Flags, Selector and Update to the wire. The bson tags (which exclude
// Collection and Flags) suggest the struct is also marshaled as a BSON
// document elsewhere, presumably for the update write command - confirm.
type updateOp struct {
	Collection string      `bson:"-"` // "database.collection"
	Selector   interface{} `bson:"q"`
	Update     interface{} `bson:"u"`
	Flags      uint32      `bson:"-"`
	Multi      bool        `bson:"multi,omitempty"`
	Upsert     bool        `bson:"upsert,omitempty"`
}

// deleteOp describes an OP_DELETE message. Query only writes Collection,
// Flags and Selector; Limit is used only when the struct is marshaled as
// BSON (presumably for the delete write command - confirm).
type deleteOp struct {
	Collection string      `bson:"-"` // "database.collection"
	Selector   interface{} `bson:"q"`
	Flags      uint32      `bson:"-"`
	Limit      int         `bson:"limit"`
}

// killCursorsOp describes an OP_KILL_CURSORS message; Query writes the
// number of ids followed by each id.
type killCursorsOp struct {
	cursorIds []int64
}
174
// requestInfo records, for one serialized op that expects a reply, where its
// message starts in Query's buffer (its request id is patched at
// bufferPos+4) and the callback to register under that id.
type requestInfo struct {
	bufferPos int
	replyFunc replyFunc
}
179
180func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket {
181	socket := &mongoSocket{
182		conn:       conn,
183		addr:       server.Addr,
184		server:     server,
185		replyFuncs: make(map[uint32]replyFunc),
186	}
187	socket.gotNonce.L = &socket.Mutex
188	if err := socket.InitialAcquire(server.Info(), timeout); err != nil {
189		panic("newSocket: InitialAcquire returned error: " + err.Error())
190	}
191	stats.socketsAlive(+1)
192	debugf("Socket %p to %s: initialized", socket, socket.addr)
193	socket.resetNonce()
194	go socket.readLoop()
195	return socket
196}
197
198// Server returns the server that the socket is associated with.
199// It returns nil while the socket is cached in its respective server.
200func (socket *mongoSocket) Server() *mongoServer {
201	socket.Lock()
202	server := socket.server
203	socket.Unlock()
204	return server
205}
206
207// ServerInfo returns details for the server at the time the socket
208// was initially acquired.
209func (socket *mongoSocket) ServerInfo() *mongoServerInfo {
210	socket.Lock()
211	serverInfo := socket.serverInfo
212	socket.Unlock()
213	return serverInfo
214}
215
216// InitialAcquire obtains the first reference to the socket, either
217// right after the connection is made or once a recycled socket is
218// being put back in use.
219func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error {
220	socket.Lock()
221	if socket.references > 0 {
222		panic("Socket acquired out of cache with references")
223	}
224	if socket.dead != nil {
225		dead := socket.dead
226		socket.Unlock()
227		return dead
228	}
229	socket.references++
230	socket.serverInfo = serverInfo
231	socket.timeout = timeout
232	stats.socketsInUse(+1)
233	stats.socketRefs(+1)
234	socket.Unlock()
235	return nil
236}
237
238// Acquire obtains an additional reference to the socket.
239// The socket will only be recycled when it's released as many
240// times as it's been acquired.
241func (socket *mongoSocket) Acquire() (info *mongoServerInfo) {
242	socket.Lock()
243	if socket.references == 0 {
244		panic("Socket got non-initial acquire with references == 0")
245	}
246	// We'll track references to dead sockets as well.
247	// Caller is still supposed to release the socket.
248	socket.references++
249	stats.socketRefs(+1)
250	serverInfo := socket.serverInfo
251	socket.Unlock()
252	return serverInfo
253}
254
255// Release decrements a socket reference. The socket will be
256// recycled once its released as many times as it's been acquired.
257func (socket *mongoSocket) Release() {
258	socket.Lock()
259	if socket.references == 0 {
260		panic("socket.Release() with references == 0")
261	}
262	socket.references--
263	stats.socketRefs(-1)
264	if socket.references == 0 {
265		stats.socketsInUse(-1)
266		server := socket.server
267		socket.Unlock()
268		socket.LogoutAll()
269		// If the socket is dead server is nil.
270		if server != nil {
271			server.RecycleSocket(socket)
272		}
273	} else {
274		socket.Unlock()
275	}
276}
277
278// SetTimeout changes the timeout used on socket operations.
279func (socket *mongoSocket) SetTimeout(d time.Duration) {
280	socket.Lock()
281	socket.timeout = d
282	socket.Unlock()
283}
284
// deadlineType selects which connection deadline(s) updateDeadline sets.
// The values are bit flags: readDeadline|writeDeadline selects both.
type deadlineType int

const (
	readDeadline  deadlineType = 1
	writeDeadline deadlineType = 2
)
291
292func (socket *mongoSocket) updateDeadline(which deadlineType) {
293	var when time.Time
294	if socket.timeout > 0 {
295		when = time.Now().Add(socket.timeout)
296	}
297	whichstr := ""
298	switch which {
299	case readDeadline | writeDeadline:
300		whichstr = "read/write"
301		socket.conn.SetDeadline(when)
302	case readDeadline:
303		whichstr = "read"
304		socket.conn.SetReadDeadline(when)
305	case writeDeadline:
306		whichstr = "write"
307		socket.conn.SetWriteDeadline(when)
308	default:
309		panic("invalid parameter to updateDeadline")
310	}
311	debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when)
312}
313
314// Close terminates the socket use.
315func (socket *mongoSocket) Close() {
316	socket.kill(errors.New("Closed explicitly"), false)
317}
318
319func (socket *mongoSocket) kill(err error, abend bool) {
320	socket.Lock()
321	if socket.dead != nil {
322		debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
323		socket.Unlock()
324		return
325	}
326	logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
327	socket.dead = err
328	socket.conn.Close()
329	stats.socketsAlive(-1)
330	replyFuncs := socket.replyFuncs
331	socket.replyFuncs = make(map[uint32]replyFunc)
332	server := socket.server
333	socket.server = nil
334	socket.gotNonce.Broadcast()
335	socket.Unlock()
336	for _, replyFunc := range replyFuncs {
337		logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
338		replyFunc(err, nil, -1, nil)
339	}
340	if abend {
341		server.AbendSocket(socket)
342	}
343}
344
345func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
346	var wait, change sync.Mutex
347	var replyDone bool
348	var replyData []byte
349	var replyErr error
350	wait.Lock()
351	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
352		change.Lock()
353		if !replyDone {
354			replyDone = true
355			replyErr = err
356			if err == nil {
357				replyData = docData
358			}
359		}
360		change.Unlock()
361		wait.Unlock()
362	}
363	err = socket.Query(op)
364	if err != nil {
365		return nil, err
366	}
367	wait.Lock()
368	change.Lock()
369	data = replyData
370	err = replyErr
371	change.Unlock()
372	return data, err
373}
374
375func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
376
377	if lops := socket.flushLogout(); len(lops) > 0 {
378		ops = append(lops, ops...)
379	}
380
381	buf := make([]byte, 0, 256)
382
383	// Serialize operations synchronously to avoid interrupting
384	// other goroutines while we can't really be sending data.
385	// Also, record id positions so that we can compute request
386	// ids at once later with the lock already held.
387	requests := make([]requestInfo, len(ops))
388	requestCount := 0
389
390	for _, op := range ops {
391		debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
392		if qop, ok := op.(*queryOp); ok {
393			if cmd, ok := qop.query.(*findCmd); ok {
394				debugf("Socket %p to %s: find command: %#v", socket, socket.addr, cmd)
395			}
396		}
397		start := len(buf)
398		var replyFunc replyFunc
399		switch op := op.(type) {
400
401		case *updateOp:
402			buf = addHeader(buf, 2001)
403			buf = addInt32(buf, 0) // Reserved
404			buf = addCString(buf, op.Collection)
405			buf = addInt32(buf, int32(op.Flags))
406			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector)
407			buf, err = addBSON(buf, op.Selector)
408			if err != nil {
409				return err
410			}
411			debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.Update)
412			buf, err = addBSON(buf, op.Update)
413			if err != nil {
414				return err
415			}
416
417		case *insertOp:
418			buf = addHeader(buf, 2002)
419			buf = addInt32(buf, int32(op.flags))
420			buf = addCString(buf, op.collection)
421			for _, doc := range op.documents {
422				debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
423				buf, err = addBSON(buf, doc)
424				if err != nil {
425					return err
426				}
427			}
428
429		case *queryOp:
430			buf = addHeader(buf, 2004)
431			buf = addInt32(buf, int32(op.flags))
432			buf = addCString(buf, op.collection)
433			buf = addInt32(buf, op.skip)
434			buf = addInt32(buf, op.limit)
435			buf, err = addBSON(buf, op.finalQuery(socket))
436			if err != nil {
437				return err
438			}
439			if op.selector != nil {
440				buf, err = addBSON(buf, op.selector)
441				if err != nil {
442					return err
443				}
444			}
445			replyFunc = op.replyFunc
446
447		case *getMoreOp:
448			buf = addHeader(buf, 2005)
449			buf = addInt32(buf, 0) // Reserved
450			buf = addCString(buf, op.collection)
451			buf = addInt32(buf, op.limit)
452			buf = addInt64(buf, op.cursorId)
453			replyFunc = op.replyFunc
454
455		case *deleteOp:
456			buf = addHeader(buf, 2006)
457			buf = addInt32(buf, 0) // Reserved
458			buf = addCString(buf, op.Collection)
459			buf = addInt32(buf, int32(op.Flags))
460			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector)
461			buf, err = addBSON(buf, op.Selector)
462			if err != nil {
463				return err
464			}
465
466		case *killCursorsOp:
467			buf = addHeader(buf, 2007)
468			buf = addInt32(buf, 0) // Reserved
469			buf = addInt32(buf, int32(len(op.cursorIds)))
470			for _, cursorId := range op.cursorIds {
471				buf = addInt64(buf, cursorId)
472			}
473
474		default:
475			panic("internal error: unknown operation type")
476		}
477
478		setInt32(buf, start, int32(len(buf)-start))
479
480		if replyFunc != nil {
481			request := &requests[requestCount]
482			request.replyFunc = replyFunc
483			request.bufferPos = start
484			requestCount++
485		}
486	}
487
488	// Buffer is ready for the pipe.  Lock, allocate ids, and enqueue.
489
490	socket.Lock()
491	if socket.dead != nil {
492		dead := socket.dead
493		socket.Unlock()
494		debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
495		// XXX This seems necessary in case the session is closed concurrently
496		// with a query being performed, but it's not yet tested:
497		for i := 0; i != requestCount; i++ {
498			request := &requests[i]
499			if request.replyFunc != nil {
500				request.replyFunc(dead, nil, -1, nil)
501			}
502		}
503		return dead
504	}
505
506	wasWaiting := len(socket.replyFuncs) > 0
507
508	// Reserve id 0 for requests which should have no responses.
509	requestId := socket.nextRequestId + 1
510	if requestId == 0 {
511		requestId++
512	}
513	socket.nextRequestId = requestId + uint32(requestCount)
514	for i := 0; i != requestCount; i++ {
515		request := &requests[i]
516		setInt32(buf, request.bufferPos+4, int32(requestId))
517		socket.replyFuncs[requestId] = request.replyFunc
518		requestId++
519	}
520
521	debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
522	stats.sentOps(len(ops))
523
524	socket.updateDeadline(writeDeadline)
525	_, err = socket.conn.Write(buf)
526	if !wasWaiting && requestCount > 0 {
527		socket.updateDeadline(readDeadline)
528	}
529	socket.Unlock()
530	return err
531}
532
// fill reads from r until b is completely filled.
//
// It returns nil once all len(b) bytes have been read, even if the final
// Read also reported an error: io.Reader permits returning data together
// with an error (crypto/tls, for instance, can return the last bytes along
// with io.EOF), and those bytes are valid. A persistent condition such as
// EOF is expected to be reported again by the next Read. Otherwise the
// first Read error is returned and b is only partially filled.
func fill(r net.Conn, b []byte) error {
	for n := 0; n < len(b); {
		m, err := r.Read(b[n:])
		n += m
		if err != nil {
			if n == len(b) {
				return nil
			}
			return err
		}
	}
	return nil
}
543
544// Estimated minimum cost per socket: 1 goroutine + memory for the largest
545// document ever seen.
546func (socket *mongoSocket) readLoop() {
547	p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
548	s := make([]byte, 4)
549	conn := socket.conn // No locking, conn never changes.
550	for {
551		err := fill(conn, p)
552		if err != nil {
553			socket.kill(err, true)
554			return
555		}
556
557		totalLen := getInt32(p, 0)
558		responseTo := getInt32(p, 8)
559		opCode := getInt32(p, 12)
560
561		// Don't use socket.server.Addr here.  socket is not
562		// locked and socket.server may go away.
563		debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
564
565		_ = totalLen
566
567		if opCode != 1 {
568			socket.kill(errors.New("opcode != 1, corrupted data?"), true)
569			return
570		}
571
572		reply := replyOp{
573			flags:     uint32(getInt32(p, 16)),
574			cursorId:  getInt64(p, 20),
575			firstDoc:  getInt32(p, 28),
576			replyDocs: getInt32(p, 32),
577		}
578
579		stats.receivedOps(+1)
580		stats.receivedDocs(int(reply.replyDocs))
581
582		socket.Lock()
583		replyFunc, ok := socket.replyFuncs[uint32(responseTo)]
584		if ok {
585			delete(socket.replyFuncs, uint32(responseTo))
586		}
587		socket.Unlock()
588
589		if replyFunc != nil && reply.replyDocs == 0 {
590			replyFunc(nil, &reply, -1, nil)
591		} else {
592			for i := 0; i != int(reply.replyDocs); i++ {
593				err := fill(conn, s)
594				if err != nil {
595					if replyFunc != nil {
596						replyFunc(err, nil, -1, nil)
597					}
598					socket.kill(err, true)
599					return
600				}
601
602				b := make([]byte, int(getInt32(s, 0)))
603
604				// copy(b, s) in an efficient way.
605				b[0] = s[0]
606				b[1] = s[1]
607				b[2] = s[2]
608				b[3] = s[3]
609
610				err = fill(conn, b[4:])
611				if err != nil {
612					if replyFunc != nil {
613						replyFunc(err, nil, -1, nil)
614					}
615					socket.kill(err, true)
616					return
617				}
618
619				if globalDebug && globalLogger != nil {
620					m := bson.M{}
621					if err := bson.Unmarshal(b, m); err == nil {
622						debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
623					}
624				}
625
626				if replyFunc != nil {
627					replyFunc(nil, &reply, i, b)
628				}
629
630				// XXX Do bound checking against totalLen.
631			}
632		}
633
634		socket.Lock()
635		if len(socket.replyFuncs) == 0 {
636			// Nothing else to read for now. Disable deadline.
637			socket.conn.SetReadDeadline(time.Time{})
638		} else {
639			socket.updateDeadline(readDeadline)
640		}
641		socket.Unlock()
642
643		// XXX Do bound checking against totalLen.
644	}
645}
646
// emptyHeader is the zeroed 16-byte message header template that addHeader
// appends.
var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}

// addHeader appends a 16-byte message header to b and sets its opcode field
// (bytes 12-15, little-endian). The length (bytes 0-3), request id (4-7)
// and response-to (8-11) fields are left zero; Query patches the length and
// the request id afterwards. All four opcode bytes are written, so any
// int32 opcode is encoded correctly (the previous version wrote only the
// low 16 bits).
func addHeader(b []byte, opcode int) []byte {
	i := len(b)
	b = append(b, emptyHeader...)
	b[i+12] = byte(opcode)
	b[i+13] = byte(opcode >> 8)
	b[i+14] = byte(opcode >> 16)
	b[i+15] = byte(opcode >> 24)
	return b
}
657
// addInt32 appends i to b as 4 little-endian bytes (two's complement).
func addInt32(b []byte, i int32) []byte {
	u := uint32(i)
	return append(b, byte(u), byte(u>>8), byte(u>>16), byte(u>>24))
}
661
// addInt64 appends i to b as 8 little-endian bytes (two's complement).
func addInt64(b []byte, i int64) []byte {
	u := uint64(i)
	for shift := uint(0); shift < 64; shift += 8 {
		b = append(b, byte(u>>shift))
	}
	return b
}
666
// addCString appends s to b followed by a terminating NUL byte. s is copied
// verbatim: an embedded NUL is not rejected and would presumably cut the
// string short for whoever decodes it.
func addCString(b []byte, s string) []byte {
	return append(append(b, s...), 0)
}
672
673func addBSON(b []byte, doc interface{}) ([]byte, error) {
674	if doc == nil {
675		return append(b, 5, 0, 0, 0, 0), nil
676	}
677	data, err := bson.Marshal(doc)
678	if err != nil {
679		return b, err
680	}
681	return append(b, data...), nil
682}
683
// setInt32 overwrites b[pos:pos+4] with i in little-endian byte order.
// Bytes are written low to high; it panics if the range is out of bounds.
func setInt32(b []byte, pos int, i int32) {
	u := uint32(i)
	for k := 0; k < 4; k++ {
		b[pos+k] = byte(u >> (8 * uint(k)))
	}
}
690
// getInt32 decodes the little-endian, two's-complement int32 stored at
// b[pos:pos+4].
func getInt32(b []byte, pos int) int32 {
	var u uint32
	for k := 0; k < 4; k++ {
		u |= uint32(b[pos+k]) << (8 * uint(k))
	}
	return int32(u)
}
697
// getInt64 decodes the little-endian, two's-complement int64 stored at
// b[pos:pos+8].
func getInt64(b []byte, pos int) int64 {
	var u uint64
	for k := 0; k < 8; k++ {
		u |= uint64(b[pos+k]) << (8 * uint(k))
	}
	return int64(u)
}
708