1package consul
2
3import (
4	"crypto/tls"
5	"fmt"
6	"io"
7	"net"
8	"strings"
9	"time"
10
11	"github.com/armon/go-metrics"
12	"github.com/hashicorp/consul/agent/consul/state"
13	"github.com/hashicorp/consul/agent/metadata"
14	"github.com/hashicorp/consul/agent/pool"
15	"github.com/hashicorp/consul/agent/structs"
16	"github.com/hashicorp/consul/lib"
17	memdb "github.com/hashicorp/go-memdb"
18	"github.com/hashicorp/memberlist"
19	"github.com/hashicorp/net-rpc-msgpackrpc"
20	"github.com/hashicorp/yamux"
21)
22
const (
	// maxQueryTime is the upper bound on how long a blocking query may wait.
	maxQueryTime = 10 * time.Minute

	// defaultQueryTime is how long a blocking query waits for a change when
	// the caller did not ask for a specific wait time. Older versions waited
	// for the full maxQueryTime instead.
	defaultQueryTime = 5 * time.Minute

	// jitterFraction limits the random jitter added to a caller-supplied
	// MaxQueryTime: the jitter is drawn using MaxQueryTime / jitterFraction
	// as its bound, so 16 corresponds to a 6.25% limit. The same fraction is
	// applied to waits derived from RPCHoldTimeout.
	jitterFraction = 16

	// raftWarnSize is the encoded Raft command size, in bytes, above which a
	// warning is logged. A command over 1 MiB is probably abusive.
	raftWarnSize = 1 << 20

	// enqueueLimit bounds how long we wait to enqueue a new Raft command.
	// Reaching it means something is probably wrong, but it keeps the
	// requesting goroutine from blocking forever.
	enqueueLimit = 30 * time.Second
)
47
48// listen is used to listen for incoming RPC connections
49func (s *Server) listen(listener net.Listener) {
50	for {
51		// Accept a connection
52		conn, err := listener.Accept()
53		if err != nil {
54			if s.shutdown {
55				return
56			}
57			s.logger.Printf("[ERR] consul.rpc: failed to accept RPC conn: %v", err)
58			continue
59		}
60
61		go s.handleConn(conn, false)
62		metrics.IncrCounter([]string{"rpc", "accept_conn"}, 1)
63	}
64}
65
66// logConn is a wrapper around memberlist's LogConn so that we format references
67// to "from" addresses in a consistent way. This is just a shorter name.
68func logConn(conn net.Conn) string {
69	return memberlist.LogConn(conn)
70}
71
72// handleConn is used to determine if this is a Raft or
73// Consul type RPC connection and invoke the correct handler
74func (s *Server) handleConn(conn net.Conn, isTLS bool) {
75	// Read a single byte
76	buf := make([]byte, 1)
77	if _, err := conn.Read(buf); err != nil {
78		if err != io.EOF {
79			s.logger.Printf("[ERR] consul.rpc: failed to read byte: %v %s", err, logConn(conn))
80		}
81		conn.Close()
82		return
83	}
84	typ := pool.RPCType(buf[0])
85
86	// Enforce TLS if VerifyIncoming is set
87	if s.config.VerifyIncoming && !isTLS && typ != pool.RPCTLS {
88		s.logger.Printf("[WARN] consul.rpc: Non-TLS connection attempted with VerifyIncoming set %s", logConn(conn))
89		conn.Close()
90		return
91	}
92
93	// Switch on the byte
94	switch typ {
95	case pool.RPCConsul:
96		s.handleConsulConn(conn)
97
98	case pool.RPCRaft:
99		metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
100		s.raftLayer.Handoff(conn)
101
102	case pool.RPCTLS:
103		if s.rpcTLS == nil {
104			s.logger.Printf("[WARN] consul.rpc: TLS connection attempted, server not configured for TLS %s", logConn(conn))
105			conn.Close()
106			return
107		}
108		conn = tls.Server(conn, s.rpcTLS)
109		s.handleConn(conn, true)
110
111	case pool.RPCMultiplexV2:
112		s.handleMultiplexV2(conn)
113
114	case pool.RPCSnapshot:
115		s.handleSnapshotConn(conn)
116
117	default:
118		if !s.handleEnterpriseRPCConn(typ, conn, isTLS) {
119			s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn))
120			conn.Close()
121		}
122	}
123}
124
125// handleMultiplexV2 is used to multiplex a single incoming connection
126// using the Yamux multiplexer
127func (s *Server) handleMultiplexV2(conn net.Conn) {
128	defer conn.Close()
129	conf := yamux.DefaultConfig()
130	conf.LogOutput = s.config.LogOutput
131	server, _ := yamux.Server(conn, conf)
132	for {
133		sub, err := server.Accept()
134		if err != nil {
135			if err != io.EOF {
136				s.logger.Printf("[ERR] consul.rpc: multiplex conn accept failed: %v %s", err, logConn(conn))
137			}
138			return
139		}
140		go s.handleConsulConn(sub)
141	}
142}
143
144// handleConsulConn is used to service a single Consul RPC connection
145func (s *Server) handleConsulConn(conn net.Conn) {
146	defer conn.Close()
147	rpcCodec := msgpackrpc.NewServerCodec(conn)
148	for {
149		select {
150		case <-s.shutdownCh:
151			return
152		default:
153		}
154
155		if err := s.rpcServer.ServeRequest(rpcCodec); err != nil {
156			if err != io.EOF && !strings.Contains(err.Error(), "closed") {
157				s.logger.Printf("[ERR] consul.rpc: RPC error: %v %s", err, logConn(conn))
158				metrics.IncrCounter([]string{"rpc", "request_error"}, 1)
159			}
160			return
161		}
162		metrics.IncrCounter([]string{"rpc", "request"}, 1)
163	}
164}
165
166// handleSnapshotConn is used to dispatch snapshot saves and restores, which
167// stream so don't use the normal RPC mechanism.
168func (s *Server) handleSnapshotConn(conn net.Conn) {
169	go func() {
170		defer conn.Close()
171		if err := s.handleSnapshotRequest(conn); err != nil {
172			s.logger.Printf("[ERR] consul.rpc: Snapshot RPC error: %v %s", err, logConn(conn))
173		}
174	}()
175}
176
177// canRetry returns true if the given situation is safe for a retry.
178func canRetry(args interface{}, err error) bool {
179	// No leader errors are always safe to retry since no state could have
180	// been changed.
181	if structs.IsErrNoLeader(err) {
182		return true
183	}
184
185	// Reads are safe to retry for stream errors, such as if a server was
186	// being shut down.
187	info, ok := args.(structs.RPCInfo)
188	if ok && info.IsRead() && lib.IsErrEOF(err) {
189		return true
190	}
191
192	return false
193}
194
// forward is used to forward to a remote DC or to forward to the local leader
// Returns a bool of if forwarding was performed, as well as any error
//
// Decision flow, as implemented below:
//   - A request for another datacenter is always sent via forwardDC and
//     reported as forwarded (true) together with forwardDC's error.
//   - A read that allows stale results is handled locally (false, nil) as
//     long as s.raft.LastContact() is non-zero, which this code treats as
//     "our local DB is initialized".
//   - Otherwise the request must reach the leader. If this server is the
//     leader it is handled locally (false, nil). If another leader is known,
//     the RPC is sent to it. A failure that canRetry accepts enters the retry
//     path; any other result (success or error) is returned with true.
//   - With no known leader, or after a retryable error, the request is held.
//     The leader is re-checked after waiting
//     lib.RandomStagger(RPCHoldTimeout/jitterFraction). This repeats until
//     RPCHoldTimeout has elapsed since the first retry. Closing leaveCh or
//     shutdownCh ends the wait early. The last error seen is returned with
//     true; it is structs.ErrNoLeader if no leader was found on the last
//     check.
//   - If leaveCh is already closed when the leader is (re)checked, the
//     function returns (true, structs.ErrNoLeader) immediately.
func (s *Server) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
	// firstCheck stays zero until the retry path is first entered, so the
	// hold timeout is measured from the first retry, not from the call.
	var firstCheck time.Time

	// Handle DC forwarding
	dc := info.RequestDatacenter()
	if dc != s.config.Datacenter {
		err := s.forwardDC(method, dc, args, reply)
		return true, err
	}

	// Check if we can allow a stale read, ensure our local DB is initialized
	if info.IsRead() && info.AllowStaleRead() && !s.raft.LastContact().IsZero() {
		return false, nil
	}

	// Re-entered via goto from the retry path below after each wait.
CHECK_LEADER:
	// Fail fast if we are in the process of leaving
	select {
	case <-s.leaveCh:
		return true, structs.ErrNoLeader
	default:
	}

	// Find the leader
	isLeader, leader := s.getLeader()

	// Handle the case we are the leader
	if isLeader {
		return false, nil
	}

	// Handle the case of a known leader. rpcErr defaults to ErrNoLeader so
	// that falling through to RETRY with no leader reports that error.
	rpcErr := structs.ErrNoLeader
	if leader != nil {
		rpcErr = s.connPool.RPC(s.config.Datacenter, leader.Addr,
			leader.Version, method, leader.UseTLS, args, reply)
		// info is passed as canRetry's args; it already implements
		// structs.RPCInfo, so canRetry's read check applies to it.
		if rpcErr != nil && canRetry(info, rpcErr) {
			goto RETRY
		}
		return true, rpcErr
	}

RETRY:
	// Gate the request until there is a leader
	if firstCheck.IsZero() {
		firstCheck = time.Now()
	}
	if time.Since(firstCheck) < s.config.RPCHoldTimeout {
		jitter := lib.RandomStagger(s.config.RPCHoldTimeout / jitterFraction)
		select {
		case <-time.After(jitter):
			goto CHECK_LEADER
		case <-s.leaveCh:
		case <-s.shutdownCh:
		}
	}

	// Reached when the hold time is exceeded, or when leaveCh/shutdownCh
	// interrupted the wait; report the last error seen.
	return true, rpcErr
}
257
258// getLeader returns if the current node is the leader, and if not then it
259// returns the leader which is potentially nil if the cluster has not yet
260// elected a leader.
261func (s *Server) getLeader() (bool, *metadata.Server) {
262	// Check if we are the leader
263	if s.IsLeader() {
264		return true, nil
265	}
266
267	// Get the leader
268	leader := s.raft.Leader()
269	if leader == "" {
270		return false, nil
271	}
272
273	// Lookup the server
274	server := s.serverLookup.Server(leader)
275
276	// Server could be nil
277	return false, server
278}
279
280// forwardDC is used to forward an RPC call to a remote DC, or fail if no servers
281func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{}) error {
282	manager, server, ok := s.router.FindRoute(dc)
283	if !ok {
284		s.logger.Printf("[WARN] consul.rpc: RPC request for DC %q, no path found", dc)
285		return structs.ErrNoDCPath
286	}
287
288	metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1,
289		[]metrics.Label{{Name: "datacenter", Value: dc}})
290	if err := s.connPool.RPC(dc, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil {
291		manager.NotifyFailedServer(server)
292		s.logger.Printf("[ERR] consul: RPC failed to server %s in DC %q: %v", server.Addr, dc, err)
293		return err
294	}
295
296	return nil
297}
298
299// globalRPC is used to forward an RPC request to one server in each datacenter.
300// This will only error for RPC-related errors. Otherwise, application-level
301// errors can be sent in the response objects.
302func (s *Server) globalRPC(method string, args interface{},
303	reply structs.CompoundResponse) error {
304
305	// Make a new request into each datacenter
306	dcs := s.router.GetDatacenters()
307
308	replies, total := 0, len(dcs)
309	errorCh := make(chan error, total)
310	respCh := make(chan interface{}, total)
311
312	for _, dc := range dcs {
313		go func(dc string) {
314			rr := reply.New()
315			if err := s.forwardDC(method, dc, args, &rr); err != nil {
316				errorCh <- err
317				return
318			}
319			respCh <- rr
320		}(dc)
321	}
322
323	for replies < total {
324		select {
325		case err := <-errorCh:
326			return err
327		case rr := <-respCh:
328			reply.Add(rr)
329			replies++
330		}
331	}
332	return nil
333}
334
335// raftApply is used to encode a message, run it through raft, and return
336// the FSM response along with any errors
337func (s *Server) raftApply(t structs.MessageType, msg interface{}) (interface{}, error) {
338	buf, err := structs.Encode(t, msg)
339	if err != nil {
340		return nil, fmt.Errorf("Failed to encode request: %v", err)
341	}
342
343	// Warn if the command is very large
344	if n := len(buf); n > raftWarnSize {
345		s.logger.Printf("[WARN] consul: Attempting to apply large raft entry (%d bytes)", n)
346	}
347
348	future := s.raft.Apply(buf, enqueueLimit)
349	if err := future.Error(); err != nil {
350		return nil, err
351	}
352
353	return future.Response(), nil
354}
355
356// queryFn is used to perform a query operation. If a re-query is needed, the
357// passed-in watch set will be used to block for changes. The passed-in state
358// store should be used (vs. calling fsm.State()) since the given state store
359// will be correctly watched for changes if the state store is restored from
360// a snapshot.
361type queryFn func(memdb.WatchSet, *state.Store) error
362
// blockingQuery is used to process a potentially blocking query operation.
//
// fn is always run at least once. If queryOpts.MinQueryIndex is zero the
// query is non-blocking: fn runs with a nil watch set and its error is
// returned. Otherwise the query blocks. As long as fn succeeds and
// queryMeta.Index is still <= MinQueryIndex, it waits on the watch set until
// something fires or the timeout expires, then re-runs fn.
//
// Side effects visible to the caller:
//   - queryOpts.MaxQueryTime is modified in place for blocking queries. It
//     is clamped to maxQueryTime, defaulted to defaultQueryTime when <= 0,
//     and then has jitter added.
//   - queryMeta is refreshed via setQueryMeta on every run. Its Index is
//     raised to 1 whenever fn succeeds with an Index below 1.
//
// When queryOpts.RequireConsistent is set, consistentRead is called before
// every run and its error, if any, is returned.
func (s *Server) blockingQuery(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta,
	fn queryFn) error {
	// timeout stays nil for non-blocking queries; it is only dereferenced
	// on the MinQueryIndex > 0 path below.
	var timeout *time.Timer

	// Fast path right to the non-blocking query.
	if queryOpts.MinQueryIndex == 0 {
		goto RUN_QUERY
	}

	// Restrict the max query time, and ensure there is always one.
	if queryOpts.MaxQueryTime > maxQueryTime {
		queryOpts.MaxQueryTime = maxQueryTime
	} else if queryOpts.MaxQueryTime <= 0 {
		queryOpts.MaxQueryTime = defaultQueryTime
	}

	// Apply a small amount of jitter to the request.
	queryOpts.MaxQueryTime += lib.RandomStagger(queryOpts.MaxQueryTime / jitterFraction)

	// Setup a query timeout. A single timer covers all re-runs, so the
	// total blocking time is bounded by MaxQueryTime, not per attempt.
	timeout = time.NewTimer(queryOpts.MaxQueryTime)
	defer timeout.Stop()

	// Re-entered via goto below each time the watch set fires.
RUN_QUERY:
	// Update the query metadata.
	s.setQueryMeta(queryMeta)

	// If the read must be consistent we verify that we are still the leader.
	if queryOpts.RequireConsistent {
		if err := s.consistentRead(); err != nil {
			return err
		}
	}

	// Run the query.
	metrics.IncrCounter([]string{"rpc", "query"}, 1)

	// Operate on a consistent set of state. This makes sure that the
	// abandon channel goes with the state that the caller is using to
	// build watches.
	state := s.fsm.State()

	// We can skip all watch tracking if this isn't a blocking query.
	var ws memdb.WatchSet
	if queryOpts.MinQueryIndex > 0 {
		ws = memdb.NewWatchSet()

		// This channel will be closed if a snapshot is restored and the
		// whole state store is abandoned.
		ws.Add(state.AbandonCh())
	}

	// Block up to the timeout if we didn't see anything fresh.
	err := fn(ws, state)
	// Note we check queryOpts.MinQueryIndex is greater than zero to determine if
	// blocking was requested by client, NOT meta.Index since the state function
	// might return zero if something is not initialised and care wasn't taken to
	// handle that special case (in practice this happened a lot so fixing it
	// systematically here beats trying to remember to add zero checks in every
	// state method). We also need to ensure that unless there is an error, we
	// return an index > 0 otherwise the client will never block and burn CPU and
	// requests.
	if err == nil && queryMeta.Index < 1 {
		queryMeta.Index = 1
	}
	if err == nil && queryOpts.MinQueryIndex > 0 && queryMeta.Index <= queryOpts.MinQueryIndex {
		// Watch blocks until a watched channel fires or timeout.C does.
		// expired == true means the timeout won, and the current (stale)
		// result is returned to the caller.
		if expired := ws.Watch(timeout.C); !expired {
			// If a restore may have woken us up then bail out from
			// the query immediately. This is slightly race-ey since
			// this might have been interrupted for other reasons,
			// but it's OK to kick it back to the caller in either
			// case.
			select {
			case <-state.AbandonCh():
			default:
				goto RUN_QUERY
			}
		}
	}
	return err
}
445
446// setQueryMeta is used to populate the QueryMeta data for an RPC call
447func (s *Server) setQueryMeta(m *structs.QueryMeta) {
448	if s.IsLeader() {
449		m.LastContact = 0
450		m.KnownLeader = true
451	} else {
452		m.LastContact = time.Since(s.raft.LastContact())
453		m.KnownLeader = (s.raft.Leader() != "")
454	}
455}
456
457// consistentRead is used to ensure we do not perform a stale
458// read. This is done by verifying leadership before the read.
459func (s *Server) consistentRead() error {
460	defer metrics.MeasureSince([]string{"rpc", "consistentRead"}, time.Now())
461	future := s.raft.VerifyLeader()
462	if err := future.Error(); err != nil {
463		return err //fail fast if leader verification fails
464	}
465	// poll consistent read readiness, wait for up to RPCHoldTimeout milliseconds
466	if s.isReadyForConsistentReads() {
467		return nil
468	}
469	jitter := lib.RandomStagger(s.config.RPCHoldTimeout / jitterFraction)
470	deadline := time.Now().Add(s.config.RPCHoldTimeout)
471
472	for time.Now().Before(deadline) {
473
474		select {
475		case <-time.After(jitter):
476			// Drop through and check before we loop again.
477
478		case <-s.shutdownCh:
479			return fmt.Errorf("shutdown waiting for leader")
480		}
481
482		if s.isReadyForConsistentReads() {
483			return nil
484		}
485	}
486
487	return structs.ErrNotReadyForConsistentReads
488}
489