1// mgo - MongoDB driver for Go
2//
3// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
4//
5// All rights reserved.
6//
7// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions are met:
9//
10// 1. Redistributions of source code must retain the above copyright notice, this
11//    list of conditions and the following disclaimer.
12// 2. Redistributions in binary form must reproduce the above copyright notice,
13//    this list of conditions and the following disclaimer in the documentation
14//    and/or other materials provided with the distribution.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27package mgo
28
29import (
30	"errors"
31	"net"
32	"sort"
33	"sync"
34	"time"
35
36	"gopkg.in/mgo.v2/bson"
37)
38
39// ---------------------------------------------------------------------------
40// Mongo server encapsulation.
41
// mongoServer tracks a single MongoDB server: its addresses, the pool of
// sockets opened to it, the server info last installed via SetInfo, and the
// ping time measured by pinger.
type mongoServer struct {
	// The embedded lock is what this file holds when touching the socket
	// lists, closed, abended, dial, info and pingValue.
	sync.RWMutex
	// Addr is the address as passed to newServer. ResolvedAddr is
	// tcpaddr.String() and is the key servers are sorted and searched by.
	Addr          string
	ResolvedAddr  string
	tcpaddr       *net.TCPAddr
	// unusedSockets caches sockets ready for reuse (appended by
	// RecycleSocket, popped by AcquireSocket). liveSockets holds the
	// sockets AcquireSocket connected that have not since been removed
	// by AbendSocket or Close.
	unusedSockets []*mongoSocket
	liveSockets   []*mongoSocket
	// closed is set by Close. abended is set by AbendSocket and reported
	// back to AcquireSocket callers.
	closed        bool
	abended       bool
	// sync receives a non-blocking signal from AbendSocket suggesting a
	// cluster sync. dial optionally replaces the plain TCP dial in Connect.
	sync          chan bool
	dial          dialer
	// pingValue is the largest recent ping sample. pinger keeps the samples
	// in the pingWindow ring buffer: pingIndex is the next slot to write and
	// pingCount the number of samples taken so far. newServer initializes
	// pingValue to time.Hour.
	pingValue     time.Duration
	pingIndex     int
	pingCount     uint32
	pingWindow    [6]time.Duration
	// info is the server info last installed by SetInfo. It starts as
	// &defaultServerInfo.
	info          *mongoServerInfo
}
59
60type dialer struct {
61	old func(addr net.Addr) (net.Conn, error)
62	new func(addr *ServerAddr) (net.Conn, error)
63}
64
65func (dial dialer) isSet() bool {
66	return dial.old != nil || dial.new != nil
67}
68
// mongoServerInfo describes a server as last installed via SetInfo.
// Within this file:
//   - Master and Mongos drive server selection in BestFit.
//   - Tags is matched by hasTags.
//   - Master is passed to stats.conn by Connect.
// NOTE(review): MaxWireVersion and SetName are not read here; presumably they
// are filled from the server's handshake reply elsewhere — confirm.
type mongoServerInfo struct {
	Master         bool
	Mongos         bool
	Tags           bson.D
	MaxWireVersion int
	SetName        string
}

// defaultServerInfo is the zero-valued info every new server points at (see
// newServer) until SetInfo installs its own. All such servers share this one
// variable, so it must never be modified through Info().
var defaultServerInfo mongoServerInfo
78
79func newServer(addr string, tcpaddr *net.TCPAddr, sync chan bool, dial dialer) *mongoServer {
80	server := &mongoServer{
81		Addr:         addr,
82		ResolvedAddr: tcpaddr.String(),
83		tcpaddr:      tcpaddr,
84		sync:         sync,
85		dial:         dial,
86		info:         &defaultServerInfo,
87		pingValue:    time.Hour, // Push it back before an actual ping.
88	}
89	go server.pinger(true)
90	return server
91}
92
// Sentinel errors returned by mongoServer.AcquireSocket.
var (
	errPoolLimit    = errors.New("per-server connection limit reached")
	errServerClosed = errors.New("server was closed")
)
95
96// AcquireSocket returns a socket for communicating with the server.
97// This will attempt to reuse an old connection, if one is available. Otherwise,
98// it will establish a new one. The returned socket is owned by the call site,
99// and will return to the cache when the socket has its Release method called
100// the same number of times as AcquireSocket + Acquire were called for it.
101// If the poolLimit argument is greater than zero and the number of sockets in
102// use in this server is greater than the provided limit, errPoolLimit is
103// returned.
104func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) {
105	for {
106		server.Lock()
107		abended = server.abended
108		if server.closed {
109			server.Unlock()
110			return nil, abended, errServerClosed
111		}
112		n := len(server.unusedSockets)
113		if poolLimit > 0 && len(server.liveSockets)-n >= poolLimit {
114			server.Unlock()
115			return nil, false, errPoolLimit
116		}
117		if n > 0 {
118			socket = server.unusedSockets[n-1]
119			server.unusedSockets[n-1] = nil // Help GC.
120			server.unusedSockets = server.unusedSockets[:n-1]
121			info := server.info
122			server.Unlock()
123			err = socket.InitialAcquire(info, timeout)
124			if err != nil {
125				continue
126			}
127		} else {
128			server.Unlock()
129			socket, err = server.Connect(timeout)
130			if err == nil {
131				server.Lock()
132				// We've waited for the Connect, see if we got
133				// closed in the meantime
134				if server.closed {
135					server.Unlock()
136					socket.Release()
137					socket.Close()
138					return nil, abended, errServerClosed
139				}
140				server.liveSockets = append(server.liveSockets, socket)
141				server.Unlock()
142			}
143		}
144		return
145	}
146	panic("unreachable")
147}
148
149// Connect establishes a new connection to the server. This should
150// generally be done through server.AcquireSocket().
151func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) {
152	server.RLock()
153	master := server.info.Master
154	dial := server.dial
155	server.RUnlock()
156
157	logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout)
158	var conn net.Conn
159	var err error
160	switch {
161	case !dial.isSet():
162		// Cannot do this because it lacks timeout support. :-(
163		//conn, err = net.DialTCP("tcp", nil, server.tcpaddr)
164		conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout)
165		if tcpconn, ok := conn.(*net.TCPConn); ok {
166			tcpconn.SetKeepAlive(true)
167		} else if err == nil {
168			panic("internal error: obtained TCP connection is not a *net.TCPConn!?")
169		}
170	case dial.old != nil:
171		conn, err = dial.old(server.tcpaddr)
172	case dial.new != nil:
173		conn, err = dial.new(&ServerAddr{server.Addr, server.tcpaddr})
174	default:
175		panic("dialer is set, but both dial.old and dial.new are nil")
176	}
177	if err != nil {
178		logf("Connection to %s failed: %v", server.Addr, err.Error())
179		return nil, err
180	}
181	logf("Connection to %s established.", server.Addr)
182
183	stats.conn(+1, master)
184	return newSocket(server, conn, timeout), nil
185}
186
187// Close forces closing all sockets that are alive, whether
188// they're currently in use or not.
189func (server *mongoServer) Close() {
190	server.Lock()
191	server.closed = true
192	liveSockets := server.liveSockets
193	unusedSockets := server.unusedSockets
194	server.liveSockets = nil
195	server.unusedSockets = nil
196	server.Unlock()
197	logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets))
198	for i, s := range liveSockets {
199		s.Close()
200		liveSockets[i] = nil
201	}
202	for i := range unusedSockets {
203		unusedSockets[i] = nil
204	}
205}
206
207// RecycleSocket puts socket back into the unused cache.
208func (server *mongoServer) RecycleSocket(socket *mongoSocket) {
209	server.Lock()
210	if !server.closed {
211		server.unusedSockets = append(server.unusedSockets, socket)
212	}
213	server.Unlock()
214}
215
216func removeSocket(sockets []*mongoSocket, socket *mongoSocket) []*mongoSocket {
217	for i, s := range sockets {
218		if s == socket {
219			copy(sockets[i:], sockets[i+1:])
220			n := len(sockets) - 1
221			sockets[n] = nil
222			sockets = sockets[:n]
223			break
224		}
225	}
226	return sockets
227}
228
229// AbendSocket notifies the server that the given socket has terminated
230// abnormally, and thus should be discarded rather than cached.
231func (server *mongoServer) AbendSocket(socket *mongoSocket) {
232	server.Lock()
233	server.abended = true
234	if server.closed {
235		server.Unlock()
236		return
237	}
238	server.liveSockets = removeSocket(server.liveSockets, socket)
239	server.unusedSockets = removeSocket(server.unusedSockets, socket)
240	server.Unlock()
241	// Maybe just a timeout, but suggest a cluster sync up just in case.
242	select {
243	case server.sync <- true:
244	default:
245	}
246}
247
248func (server *mongoServer) SetInfo(info *mongoServerInfo) {
249	server.Lock()
250	server.info = info
251	server.Unlock()
252}
253
254func (server *mongoServer) Info() *mongoServerInfo {
255	server.Lock()
256	info := server.info
257	server.Unlock()
258	return info
259}
260
261func (server *mongoServer) hasTags(serverTags []bson.D) bool {
262NextTagSet:
263	for _, tags := range serverTags {
264	NextReqTag:
265		for _, req := range tags {
266			for _, has := range server.info.Tags {
267				if req.Name == has.Name {
268					if req.Value == has.Value {
269						continue NextReqTag
270					}
271					continue NextTagSet
272				}
273			}
274			continue NextTagSet
275		}
276		return true
277	}
278	return false
279}
280
// pingDelay is how long pinger sleeps between background pings. It is also
// the timeout used when acquiring the socket for each ping. Tests may change
// it, which is why pinger reads it under globalMutex when the race detector
// is enabled.
var pingDelay = 15 * time.Second
282
283func (server *mongoServer) pinger(loop bool) {
284	var delay time.Duration
285	if raceDetector {
286		// This variable is only ever touched by tests.
287		globalMutex.Lock()
288		delay = pingDelay
289		globalMutex.Unlock()
290	} else {
291		delay = pingDelay
292	}
293	op := queryOp{
294		collection: "admin.$cmd",
295		query:      bson.D{{"ping", 1}},
296		flags:      flagSlaveOk,
297		limit:      -1,
298	}
299	for {
300		if loop {
301			time.Sleep(delay)
302		}
303		op := op
304		socket, _, err := server.AcquireSocket(0, delay)
305		if err == nil {
306			start := time.Now()
307			_, _ = socket.SimpleQuery(&op)
308			delay := time.Now().Sub(start)
309
310			server.pingWindow[server.pingIndex] = delay
311			server.pingIndex = (server.pingIndex + 1) % len(server.pingWindow)
312			server.pingCount++
313			var max time.Duration
314			for i := 0; i < len(server.pingWindow) && uint32(i) < server.pingCount; i++ {
315				if server.pingWindow[i] > max {
316					max = server.pingWindow[i]
317				}
318			}
319			socket.Release()
320			server.Lock()
321			if server.closed {
322				loop = false
323			}
324			server.pingValue = max
325			server.Unlock()
326			logf("Ping for %s is %d ms", server.Addr, max/time.Millisecond)
327		} else if err == errServerClosed {
328			return
329		}
330		if !loop {
331			return
332		}
333	}
334}
335
336type mongoServerSlice []*mongoServer
337
338func (s mongoServerSlice) Len() int {
339	return len(s)
340}
341
342func (s mongoServerSlice) Less(i, j int) bool {
343	return s[i].ResolvedAddr < s[j].ResolvedAddr
344}
345
346func (s mongoServerSlice) Swap(i, j int) {
347	s[i], s[j] = s[j], s[i]
348}
349
350func (s mongoServerSlice) Sort() {
351	sort.Sort(s)
352}
353
354func (s mongoServerSlice) Search(resolvedAddr string) (i int, ok bool) {
355	n := len(s)
356	i = sort.Search(n, func(i int) bool {
357		return s[i].ResolvedAddr >= resolvedAddr
358	})
359	return i, i != n && s[i].ResolvedAddr == resolvedAddr
360}
361
// mongoServers is a set of servers kept sorted by ResolvedAddr (Add re-sorts
// after appending), so that Search and Remove can binary-search it. It has no
// lock of its own.
type mongoServers struct {
	slice mongoServerSlice
}
365
366func (servers *mongoServers) Search(resolvedAddr string) (server *mongoServer) {
367	if i, ok := servers.slice.Search(resolvedAddr); ok {
368		return servers.slice[i]
369	}
370	return nil
371}
372
373func (servers *mongoServers) Add(server *mongoServer) {
374	servers.slice = append(servers.slice, server)
375	servers.slice.Sort()
376}
377
378func (servers *mongoServers) Remove(other *mongoServer) (server *mongoServer) {
379	if i, found := servers.slice.Search(other.ResolvedAddr); found {
380		server = servers.slice[i]
381		copy(servers.slice[i:], servers.slice[i+1:])
382		n := len(servers.slice) - 1
383		servers.slice[n] = nil // Help GC.
384		servers.slice = servers.slice[:n]
385	}
386	return
387}
388
389func (servers *mongoServers) Slice() []*mongoServer {
390	return ([]*mongoServer)(servers.slice)
391}
392
393func (servers *mongoServers) Get(i int) *mongoServer {
394	return servers.slice[i]
395}
396
397func (servers *mongoServers) Len() int {
398	return len(servers.slice)
399}
400
401func (servers *mongoServers) Empty() bool {
402	return len(servers.slice) == 0
403}
404
405func (servers *mongoServers) HasMongos() bool {
406	for _, s := range servers.slice {
407		if s.Info().Mongos {
408			return true
409		}
410	}
411	return false
412}
413
414// BestFit returns the best guess of what would be the most interesting
415// server to perform operations on at this point in time.
416func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServer {
417	var best *mongoServer
418	for _, next := range servers.slice {
419		if best == nil {
420			best = next
421			best.RLock()
422			if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) {
423				best.RUnlock()
424				best = nil
425			}
426			continue
427		}
428		next.RLock()
429		swap := false
430		switch {
431		case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags):
432			// Must have requested tags.
433		case mode == Secondary && next.info.Master && !next.info.Mongos:
434			// Must be a secondary or mongos.
435		case next.info.Master != best.info.Master && mode != Nearest:
436			// Prefer slaves, unless the mode is PrimaryPreferred.
437			swap = (mode == PrimaryPreferred) != best.info.Master
438		case absDuration(next.pingValue-best.pingValue) > 15*time.Millisecond:
439			// Prefer nearest server.
440			swap = next.pingValue < best.pingValue
441		case len(next.liveSockets)-len(next.unusedSockets) < len(best.liveSockets)-len(best.unusedSockets):
442			// Prefer servers with less connections.
443			swap = true
444		}
445		if swap {
446			best.RUnlock()
447			best = next
448		} else {
449			next.RUnlock()
450		}
451	}
452	if best != nil {
453		best.RUnlock()
454	}
455	return best
456}
457
// absDuration returns the magnitude of d. As with plain negation, the most
// negative Duration has no positive counterpart and is returned unchanged.
func absDuration(d time.Duration) time.Duration {
	if d >= 0 {
		return d
	}
	return -d
}
464