package transport

import (
	"fmt"
	"sync"
	"time"

	"golang.org/x/net/context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"

	"github.com/coreos/etcd/raft"
	"github.com/coreos/etcd/raft/raftpb"
	"github.com/docker/swarmkit/api"
	"github.com/docker/swarmkit/log"
	"github.com/docker/swarmkit/manager/state/raft/membership"
	"github.com/pkg/errors"
)

const (
	// GRPCMaxMsgSize is the max allowed gRPC message size for raft messages.
	GRPCMaxMsgSize = 4 << 20
)

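// peer tracks a single remote cluster member: its gRPC connection, the
// buffered queue of outgoing raft messages, and whether the member is
// currently reachable.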
type peer struct {
	id uint64

	tr *Transport

	msgc chan raftpb.Message

	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}

	mu      sync.Mutex
	cc      *grpc.ClientConn
	addr    string
	newAddr string

	active       bool
	becameActive time.Time
}

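// newPeer dials the given address and starts the goroutine that delivers
// queued raft messages to the peer.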
func newPeer(id uint64, addr string, tr *Transport) (*peer, error) {
	cc, err := tr.dial(addr)
	if err != nil {
		return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr)
	}
	ctx, cancel := context.WithCancel(tr.ctx)
	ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id))
	p := &peer{
		id:     id,
		addr:   addr,
		cc:     cc,
		tr:     tr,
		ctx:    ctx,
		cancel: cancel,
		msgc:   make(chan raftpb.Message, 4096),
		done:   make(chan struct{}),
	}
	go p.run(ctx)
	return p, nil
}

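// send queues a raft message for delivery to the peer. It does not block:
// if the queue is full, the peer is reported unreachable and an error is
// returned.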
func (p *peer) send(m raftpb.Message) (err error) {
	p.mu.Lock()
	defer func() {
		if err != nil {
			p.active = false
			p.becameActive = time.Time{}
		}
		p.mu.Unlock()
	}()
	select {
	case <-p.ctx.Done():
		return p.ctx.Err()
	default:
	}
	select {
	case p.msgc <- m:
	case <-p.ctx.Done():
		return p.ctx.Err()
	default:
		p.tr.config.ReportUnreachable(p.id)
		return errors.Errorf("peer is unreachable")
	}
	return nil
}

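// update dials the new address and, on success, immediately replaces the
// peer's current connection with the new one.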
func (p *peer) update(addr string) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.addr == addr {
		return nil
	}
	cc, err := p.tr.dial(addr)
	if err != nil {
		return err
	}

	p.cc.Close()
	p.cc = cc
	p.addr = addr
	return nil
}

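// updateAddr records the new address as a fallback; the connection is only
// switched to it after a send failure (see handleAddressChange).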
func (p *peer) updateAddr(addr string) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.addr == addr {
		return nil
	}
	log.G(p.ctx).Debugf("peer %x address updated to %s; it will be used if the old one fails", p.id, addr)
	p.newAddr = addr
	return nil
}

func (p *peer) conn() *grpc.ClientConn {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.cc
}

func (p *peer) address() string {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.addr
}

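// resolveAddr asks the peer to resolve the address of the raft member with
// the given ID.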
func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) {
	resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id})
	if err != nil {
		return "", errors.Wrap(err, "failed to resolve address")
	}
	return resp.Addr, nil
}

// raftMessageStructSize returns the size of the raft message struct (not
// including the payload) for the given raftpb.Message. The payload is
// typically the snapshot data or the append entries.
func raftMessageStructSize(m *raftpb.Message) int {
	return (&api.ProcessRaftMessageRequest{Message: m}).Size() - len(m.Snapshot.Data)
}

// raftMessagePayloadSize returns the max allowable payload size based on
// GRPCMaxMsgSize and the struct size for the given raftpb.Message.
func raftMessagePayloadSize(m *raftpb.Message) int {
	return GRPCMaxMsgSize - raftMessageStructSize(m)
}

// splitSnapshotData splits a large raft message into smaller messages.
// Currently this means splitting Snapshot.Data into chunks whose size is
// dictated by GRPCMaxMsgSize.
func splitSnapshotData(ctx context.Context, m *raftpb.Message) []api.StreamRaftMessageRequest {
	var messages []api.StreamRaftMessageRequest
	if m.Type != raftpb.MsgSnap {
		return messages
	}

	// Get the size of the data to be split.
	size := len(m.Snapshot.Data)

	// Get the max payload size.
	payloadSize := raftMessagePayloadSize(m)

	// Split the snapshot into smaller messages.
	for snapDataIndex := 0; snapDataIndex < size; {
		chunkSize := size - snapDataIndex
		if chunkSize > payloadSize {
			chunkSize = payloadSize
		}

		raftMsg := *m

		// Sub-slice for this snapshot chunk.
		raftMsg.Snapshot.Data = m.Snapshot.Data[snapDataIndex : snapDataIndex+chunkSize]

		snapDataIndex += chunkSize

		// Add the message to the list of messages to be sent.
		msg := api.StreamRaftMessageRequest{Message: &raftMsg}
		messages = append(messages, msg)
	}

	return messages
}

// needsSplitting returns true if the message must be split before streaming,
// i.e. it is a MsgSnap and its size is larger than GRPCMaxMsgSize.
func needsSplitting(m *raftpb.Message) bool {
	raftMsg := api.ProcessRaftMessageRequest{Message: m}
	return m.Type == raftpb.MsgSnap && raftMsg.Size() > GRPCMaxMsgSize
}

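// sendProcessMessage delivers a single raft message to the peer. Large
// snapshot messages are split and streamed via StreamRaftMessage; if the
// receiver does not support streaming, it falls back to ProcessRaftMessage.
// Snapshot status and unreachability are reported back to raft.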
func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error {
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()

	var err error
	var stream api.Raft_StreamRaftMessageClient
	stream, err = api.NewRaftClient(p.conn()).StreamRaftMessage(ctx)

	if err == nil {
		// Split the message if needed.
		// Currently only supported for MsgSnap.
		var msgs []api.StreamRaftMessageRequest
		if needsSplitting(&m) {
			msgs = splitSnapshotData(ctx, &m)
		} else {
			raftMsg := api.StreamRaftMessageRequest{Message: &m}
			msgs = append(msgs, raftMsg)
		}

		// Stream the messages to the peer.
		for _, msg := range msgs {
			err = stream.Send(&msg)
			if err != nil {
				log.G(ctx).WithError(err).Error("error streaming message to peer")
				stream.CloseAndRecv()
				break
			}
		}

		// Finished sending all the messages.
		// Close the stream and receive the response.
		if err == nil {
			_, err = stream.CloseAndRecv()

			if err != nil {
				log.G(ctx).WithError(err).Error("error receiving response")
			}
		}
	} else {
		log.G(ctx).WithError(err).Error("error sending message to peer")
	}

	// Try a regular RPC if the receiver doesn't support streaming.
	if grpc.Code(err) == codes.Unimplemented {
		_, err = api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m})
	}

	// Handle errors.
	if grpc.Code(err) == codes.NotFound && grpc.ErrorDesc(err) == membership.ErrMemberRemoved.Error() {
		p.tr.config.NodeRemoved()
	}
	if m.Type == raftpb.MsgSnap {
		if err != nil {
			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure)
		} else {
			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFinish)
		}
	}
	if err != nil {
		p.tr.config.ReportUnreachable(m.To)
		return err
	}
	return nil
}

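// healthCheckConn checks that the "Raft" service on the given connection
// reports itself as serving.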
func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error {
	resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"})
	if err != nil {
		return errors.Wrap(err, "failed to check health")
	}
	if resp.Status != api.HealthCheckResponse_SERVING {
		return errors.Errorf("health check returned status %s", resp.Status)
	}
	return nil
}

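// healthCheck runs healthCheckConn against the peer's current connection,
// bounded by the transport's SendTimeout.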
func (p *peer) healthCheck(ctx context.Context) error {
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()
	return healthCheckConn(ctx, p.conn())
}

func (p *peer) setActive() {
	p.mu.Lock()
	if !p.active {
		p.active = true
		p.becameActive = time.Now()
	}
	p.mu.Unlock()
}

func (p *peer) setInactive() {
	p.mu.Lock()
	p.active = false
	p.becameActive = time.Time{}
	p.mu.Unlock()
}

func (p *peer) activeTime() time.Time {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.becameActive
}

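// drain sends any messages still queued after the run loop has stopped,
// giving up after 16 seconds.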
func (p *peer) drain() error {
	ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second)
	defer cancel()
	for {
		select {
		case m, ok := <-p.msgc:
			if !ok {
				// All messages have been processed.
				return nil
			}
			if err := p.sendProcessMessage(ctx, m); err != nil {
				return errors.Wrap(err, "send drain message")
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

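// handleAddressChange switches the connection to the previously recorded
// fallback address, if any, after verifying it with a health check.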
func (p *peer) handleAddressChange(ctx context.Context) error {
	p.mu.Lock()
	newAddr := p.newAddr
	p.newAddr = ""
	p.mu.Unlock()
	if newAddr == "" {
		return nil
	}
	cc, err := p.tr.dial(newAddr)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()
	if err := healthCheckConn(ctx, cc); err != nil {
		cc.Close()
		return err
	}
	// There is a possibility of a race if the host changes its address too
	// quickly, but that's unlikely and things should eventually settle.
	p.mu.Lock()
	p.cc.Close()
	p.cc = cc
	p.addr = newAddr
	p.tr.config.UpdateNode(p.id, p.addr)
	p.mu.Unlock()
	return nil
}

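// run is the peer's main loop: it performs an initial health check, then
// sends queued messages, marking the peer inactive and trying the fallback
// address whenever a send fails.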
func (p *peer) run(ctx context.Context) {
	defer func() {
		p.mu.Lock()
		p.active = false
		p.becameActive = time.Time{}
		// At this point we can be sure that nobody will write to msgc.
		if p.msgc != nil {
			close(p.msgc)
		}
		p.mu.Unlock()
		if err := p.drain(); err != nil {
			log.G(ctx).WithError(err).Error("failed to drain message queue")
		}
		close(p.done)
	}()
	if err := p.healthCheck(ctx); err == nil {
		p.setActive()
	}
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		select {
		case m := <-p.msgc:
			// We do not propagate the context here, because this operation
			// should either finish or time out on its own for raft to work
			// correctly.
			err := p.sendProcessMessage(context.Background(), m)
			if err != nil {
				log.G(ctx).WithError(err).Debugf("failed to send message %s", m.Type)
				p.setInactive()
				if err := p.handleAddressChange(ctx); err != nil {
					log.G(ctx).WithError(err).Error("failed to change address after failure")
				}
				continue
			}
			p.setActive()
		case <-ctx.Done():
			return
		}
	}
}

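// stop cancels the peer's context and waits for the run loop, including the
// final drain, to finish.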
func (p *peer) stop() {
	p.cancel()
	<-p.done
}