package transport

import (
	"fmt"
	"io"
	"sync"
	"time"

	"golang.org/x/net/context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"

	"github.com/coreos/etcd/raft"
	"github.com/coreos/etcd/raft/raftpb"
	"github.com/docker/swarmkit/api"
	"github.com/docker/swarmkit/log"
	"github.com/docker/swarmkit/manager/state/raft/membership"
	"github.com/pkg/errors"
)

const (
	// GRPCMaxMsgSize is the max allowed gRPC message size for raft messages.
	GRPCMaxMsgSize = 4 << 20
)

type peer struct {
	id uint64

	tr *Transport

	// msgc buffers outgoing raft messages; send enqueues and run drains.
	msgc chan raftpb.Message

	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}

	mu      sync.Mutex
	cc      *grpc.ClientConn
	addr    string
	newAddr string

	active       bool
	becameActive time.Time
}

func newPeer(id uint64, addr string, tr *Transport) (*peer, error) {
	cc, err := tr.dial(addr)
	if err != nil {
		return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr)
	}
	ctx, cancel := context.WithCancel(tr.ctx)
	ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id))
	p := &peer{
		id:     id,
		addr:   addr,
		cc:     cc,
		tr:     tr,
		ctx:    ctx,
		cancel: cancel,
		msgc:   make(chan raftpb.Message, 4096),
		done:   make(chan struct{}),
	}
	go p.run(ctx)
	return p, nil
}

// send enqueues a raft message for delivery without blocking. If the queue
// is full, the peer is reported unreachable and the message is dropped.
func (p *peer) send(m raftpb.Message) (err error) {
	p.mu.Lock()
	defer func() {
		if err != nil {
			p.active = false
			p.becameActive = time.Time{}
		}
		p.mu.Unlock()
	}()
	select {
	case <-p.ctx.Done():
		return p.ctx.Err()
	default:
	}
	select {
	case p.msgc <- m:
	case <-p.ctx.Done():
		return p.ctx.Err()
	default:
		p.tr.config.ReportUnreachable(p.id)
		return errors.Errorf("peer is unreachable")
	}
	return nil
}

// update dials the new address immediately and replaces the current
// connection with it.
func (p *peer) update(addr string) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.addr == addr {
		return nil
	}
	cc, err := p.tr.dial(addr)
	if err != nil {
		return err
	}

	p.cc.Close()
	p.cc = cc
	p.addr = addr
	return nil
}

// updateAddr records the new address; it is only dialed if the current
// connection fails (see handleAddressChange).
func (p *peer) updateAddr(addr string) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.addr == addr {
		return nil
	}
	log.G(p.ctx).Debugf("peer %x updated to address %s, it will be used if the old one fails", p.id, addr)
	p.newAddr = addr
	return nil
}

func (p *peer) conn() *grpc.ClientConn {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.cc
}

func (p *peer) address() string {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.addr
}

func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) {
	resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id})
	if err != nil {
		return "", errors.Wrap(err, "failed to resolve address")
	}
	return resp.Addr, nil
}

// raftMessageStructSize returns the raft message struct size (not including
// the payload size) for the given raftpb.Message.
// The payload is typically the snapshot or append entries.
func raftMessageStructSize(m *raftpb.Message) int {
	return (&api.ProcessRaftMessageRequest{Message: m}).Size() - len(m.Snapshot.Data)
}

// raftMessagePayloadSize returns the max allowable payload based on
// GRPCMaxMsgSize and the struct size for the given raftpb.Message.
func raftMessagePayloadSize(m *raftpb.Message) int {
	return GRPCMaxMsgSize - raftMessageStructSize(m)
}
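
// Illustrative arithmetic (the overhead figure is an assumption for the
// example, not a measured value): with GRPCMaxMsgSize = 4<<20 (4194304
// bytes) and roughly 100 bytes of struct overhead, raftMessagePayloadSize
// would return 4194204, so a 10 MiB (10485760 byte) snapshot streams as
// three chunks below: 4194204 + 4194204 + 2097352 bytes.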

// splitSnapshotData splits a large raft message into smaller messages.
// Currently this means splitting Snapshot.Data into chunks whose size
// is dictated by GRPCMaxMsgSize.
func splitSnapshotData(ctx context.Context, m *raftpb.Message) []api.StreamRaftMessageRequest {
	var messages []api.StreamRaftMessageRequest
	if m.Type != raftpb.MsgSnap {
		return messages
	}

	// Get the size of the data to be split.
	size := len(m.Snapshot.Data)

	// Get the max payload size.
	payloadSize := raftMessagePayloadSize(m)

	// Split the snapshot into smaller messages.
	for snapDataIndex := 0; snapDataIndex < size; {
		chunkSize := size - snapDataIndex
		if chunkSize > payloadSize {
			chunkSize = payloadSize
		}

		// Copy the message and sub-slice its data for this snapshot chunk.
		raftMsg := *m
		raftMsg.Snapshot.Data = m.Snapshot.Data[snapDataIndex : snapDataIndex+chunkSize]

		snapDataIndex += chunkSize

		// Add the message to the list of messages to be sent.
		messages = append(messages, api.StreamRaftMessageRequest{Message: &raftMsg})
	}

	return messages
}

// needsSplitting reports whether this message needs to be split to be
// streamed, i.e. it is a MsgSnap whose encoded size is larger than
// GRPCMaxMsgSize.
func needsSplitting(m *raftpb.Message) bool {
	raftMsg := api.ProcessRaftMessageRequest{Message: m}
	return m.Type == raftpb.MsgSnap && raftMsg.Size() > GRPCMaxMsgSize
}

func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error {
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()

	var err error
	var stream api.Raft_StreamRaftMessageClient
	stream, err = api.NewRaftClient(p.conn()).StreamRaftMessage(ctx)

	if err == nil {
		// Split the message if needed.
		// Currently only supported for MsgSnap.
		var msgs []api.StreamRaftMessageRequest
		if needsSplitting(&m) {
			msgs = splitSnapshotData(ctx, &m)
		} else {
			msgs = append(msgs, api.StreamRaftMessageRequest{Message: &m})
		}

		// Stream the messages.
		for _, msg := range msgs {
			err = stream.Send(&msg)
			if err != nil {
				log.G(ctx).WithError(err).Error("error streaming message to peer")
				// Send returns io.EOF once the server has closed the
				// stream; in that case the real error must be retrieved
				// with CloseAndRecv, e.g. so that codes.Unimplemented can
				// be detected below.
				if err == io.EOF {
					_, err = stream.CloseAndRecv()
				} else {
					stream.CloseAndRecv()
				}
				break
			}
		}

		// Finished sending all the messages.
		// Close the stream and receive the response.
		if err == nil {
			_, err = stream.CloseAndRecv()
			if err != nil {
				log.G(ctx).WithError(err).Error("error receiving response")
			}
		}
	} else {
		log.G(ctx).WithError(err).Error("error sending message to peer")
	}

	// Fall back to a regular RPC if the receiver doesn't support streaming.
	if grpc.Code(err) == codes.Unimplemented {
		_, err = api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m})
	}

	// Handle errors.
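	// The checks below are independent concerns: a NotFound error carrying
	// membership.ErrMemberRemoved means this node itself was removed from
	// the cluster; raft must be told the outcome of every MsgSnap attempt
	// via ReportSnapshot so it can resume or abort snapshot-based catch-up
	// for the follower; and any failure also reports the target as
	// unreachable.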
	if grpc.Code(err) == codes.NotFound && grpc.ErrorDesc(err) == membership.ErrMemberRemoved.Error() {
		p.tr.config.NodeRemoved()
	}
	if m.Type == raftpb.MsgSnap {
		if err != nil {
			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure)
		} else {
			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFinish)
		}
	}
	if err != nil {
		p.tr.config.ReportUnreachable(m.To)
		return err
	}
	return nil
}

func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error {
	resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"})
	if err != nil {
		return errors.Wrap(err, "failed to check health")
	}
	if resp.Status != api.HealthCheckResponse_SERVING {
		return errors.Errorf("health check returned status %s", resp.Status)
	}
	return nil
}

func (p *peer) healthCheck(ctx context.Context) error {
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()
	return healthCheckConn(ctx, p.conn())
}

func (p *peer) setActive() {
	p.mu.Lock()
	if !p.active {
		p.active = true
		p.becameActive = time.Now()
	}
	p.mu.Unlock()
}

func (p *peer) setInactive() {
	p.mu.Lock()
	p.active = false
	p.becameActive = time.Time{}
	p.mu.Unlock()
}

func (p *peer) activeTime() time.Time {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.becameActive
}

// drain delivers any messages still queued when the peer is being stopped,
// subject to an overall timeout.
func (p *peer) drain() error {
	ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second)
	defer cancel()
	for {
		select {
		case m, ok := <-p.msgc:
			if !ok {
				// All queued messages have been processed.
				return nil
			}
			if err := p.sendProcessMessage(ctx, m); err != nil {
				return errors.Wrap(err, "send drain message")
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// handleAddressChange switches to the address stored by updateAddr, if any,
// after verifying it with a health check.
func (p *peer) handleAddressChange(ctx context.Context) error {
	p.mu.Lock()
	newAddr := p.newAddr
	p.newAddr = ""
	p.mu.Unlock()
	if newAddr == "" {
		return nil
	}
	cc, err := p.tr.dial(newAddr)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()
	if err := healthCheckConn(ctx, cc); err != nil {
		cc.Close()
		return err
	}
	// There is a possibility of a race if the host changes its address too
	// quickly, but that's unlikely and things should eventually settle.
	p.mu.Lock()
	p.cc.Close()
	p.cc = cc
	p.addr = newAddr
	p.tr.config.UpdateNode(p.id, p.addr)
	p.mu.Unlock()
	return nil
}
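
// run is the peer's delivery loop: it performs one initial health check to
// set the active state, then sends queued messages one at a time, marking
// the peer inactive and trying any address stored by updateAddr after a
// failure. When the context is cancelled it closes msgc, drains whatever
// was already queued, and signals done.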
func (p *peer) run(ctx context.Context) {
	defer func() {
		p.mu.Lock()
		p.active = false
		p.becameActive = time.Time{}
		// At this point we can be sure that nobody will write to msgc.
		if p.msgc != nil {
			close(p.msgc)
		}
		p.mu.Unlock()
		if err := p.drain(); err != nil {
			log.G(ctx).WithError(err).Error("failed to drain message queue")
		}
		close(p.done)
	}()
	if err := p.healthCheck(ctx); err == nil {
		p.setActive()
	}
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		select {
		case m := <-p.msgc:
			// We do not propagate the loop context here, because this
			// operation should either finish or time out on its own for
			// raft to work correctly.
			err := p.sendProcessMessage(context.Background(), m)
			if err != nil {
				log.G(ctx).WithError(err).Debugf("failed to send message %s", m.Type)
				p.setInactive()
				if err := p.handleAddressChange(ctx); err != nil {
					log.G(ctx).WithError(err).Error("failed to change address after failure")
				}
				continue
			}
			p.setActive()
		case <-ctx.Done():
			return
		}
	}
}

// stop cancels the peer's context and waits until run has drained and
// exited.
func (p *peer) stop() {
	p.cancel()
	<-p.done
}
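
// Usage sketch (illustrative only; assumes a Transport tr wired up as in
// the rest of this package, with config supplying SendTimeout and the
// ReportUnreachable/ReportSnapshot/NodeRemoved/UpdateNode callbacks):
//
//	p, err := newPeer(id, "10.0.0.2:4242", tr)
//	if err != nil {
//		return err
//	}
//	defer p.stop() // cancel run, close msgc, and wait for the final drain
//	if err := p.send(msg); err != nil {
//		// Queue full (raft was already told the peer is unreachable)
//		// or the peer was stopped.
//	}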