1package rpc 2 3import ( 4 "fmt" 5 "io" 6 "net" 7 "sync" 8) 9 10type WrapErrorFunc func(error) interface{} 11 12type Transporter interface { 13 // IsConnected returns false when incoming packets have 14 // finished processing. 15 // 16 // TODO: Use a better name. 17 IsConnected() bool 18 19 registerProtocol(p Protocol) error 20 21 getDispatcher() (dispatcher, error) 22 getReceiver() (receiver, error) 23 24 // receiveFrames starts processing incoming frames in a 25 // background goroutine, if it's not already happening. 26 // Returns the result of done(), for convenience. 27 receiveFrames() <-chan struct{} 28 29 // Returns a channel that's closed when incoming frames have 30 // finished processing, either due to an error or the 31 // underlying connection being closed. Successive calls to 32 // done() return the same value. 33 done() <-chan struct{} 34 35 // err returns a non-nil error value after done() is closed. 36 // After done() is closed, successive calls to err return the 37 // same value. 38 err() error 39 40 // Close closes the transport and releases resources. 41 Close() 42} 43 44var _ Transporter = (*transport)(nil) 45 46type transport struct { 47 c net.Conn 48 enc *framedMsgpackEncoder 49 dispatcher dispatcher 50 receiver receiver 51 packetizer *packetizer 52 protocols *protocolHandler 53 calls *callContainer 54 log LogInterface 55 closeOnce sync.Once 56 startOnce sync.Once 57 stopCh chan struct{} 58 59 // Filled in right before stopCh is closed. 60 stopErr error 61} 62 63// DefaultMaxFrameLength (100 MiB) is a reasonable default value for 64// the maxFrameLength parameter in NewTransporter. 65const DefaultMaxFrameLength = 100 * 1024 * 1024 66 67// NewTransport creates a new Transporter from the given connection 68// and parameters. Both sides of a connection should use the same 69// number for maxFrameLength. 70func NewTransport(c net.Conn, l LogFactory, instrumenterStorage NetworkInstrumenterStorage, wef WrapErrorFunc, maxFrameLength int32) Transporter { 71 if maxFrameLength <= 0 { 72 panic(fmt.Sprintf("maxFrameLength must be positive: got %d", maxFrameLength)) 73 } 74 75 if l == nil { 76 l = NewSimpleLogFactory(nil, nil) 77 } 78 log := l.NewLog(c.RemoteAddr()) 79 if instrumenterStorage == nil { 80 instrumenterStorage = NewDummyInstrumentationStorage() 81 } 82 83 ret := &transport{ 84 c: c, 85 log: log, 86 stopCh: make(chan struct{}), 87 protocols: newProtocolHandler(wef), 88 calls: newCallContainer(), 89 } 90 enc := newFramedMsgpackEncoder(maxFrameLength, c) 91 ret.enc = enc 92 ret.dispatcher = newDispatch(enc, ret.calls, log, instrumenterStorage) 93 ret.receiver = newReceiveHandler(enc, ret.protocols, log) 94 ret.packetizer = newPacketizer(maxFrameLength, c, ret.protocols, ret.calls, log, instrumenterStorage) 95 return ret 96} 97 98func (t *transport) Close() { 99 t.closeOnce.Do(func() { 100 // Since the receiver might require the transport, we have to 101 // close it before terminating our loops 102 close(t.stopCh) 103 t.dispatcher.Close() 104 <-t.receiver.Close() 105 106 // First inform the encoder that it should close 107 encoderClosed := t.enc.Close() 108 // Unblock any remaining writes 109 t.c.Close() 110 // Wait for the encoder to finish handling the now unblocked writes 111 <-encoderClosed 112 }) 113} 114 115func (t *transport) IsConnected() bool { 116 select { 117 case <-t.stopCh: 118 return false 119 default: 120 return true 121 } 122} 123 124func (t *transport) receiveFrames() <-chan struct{} { 125 t.startOnce.Do(func() { 126 go t.receiveFramesLoop() 127 }) 128 return t.stopCh 129} 130 131func (t *transport) done() <-chan struct{} { 132 return t.stopCh 133} 134 135func (t *transport) err() error { 136 select { 137 case <-t.stopCh: 138 return t.stopErr 139 default: 140 return nil 141 } 142} 143 144func (t *transport) receiveFramesLoop() { 145 // Packetize: do work 146 var err error 147 for shouldContinue(err) { 148 var rpc rpcMessage 149 if rpc, err = t.packetizer.NextFrame(); shouldReceive(rpc) { 150 if rerr := t.receiver.Receive(rpc); rerr != nil { 151 t.log.Info("error on Receive: %v", rerr) 152 } 153 } 154 } 155 156 // Log packetizer completion 157 t.log.TransportError(err) 158 159 // This must happen before stopCh is closed to have a correct 160 // ordering. 161 t.stopErr = err 162 163 t.Close() 164} 165 166func (t *transport) getDispatcher() (dispatcher, error) { 167 if !t.IsConnected() { 168 return nil, io.EOF 169 } 170 return t.dispatcher, nil 171} 172 173func (t *transport) getReceiver() (receiver, error) { 174 if !t.IsConnected() { 175 return nil, io.EOF 176 } 177 return t.receiver, nil 178} 179 180func (t *transport) registerProtocol(p Protocol) error { 181 return t.protocols.registerProtocol(p) 182} 183 184func shouldContinue(err error) bool { 185 err = unboxRPCError(err) 186 switch err.(type) { 187 case nil: 188 return true 189 case CallNotFoundError: 190 return true 191 case MethodNotFoundError: 192 return true 193 case ProtocolNotFoundError: 194 return true 195 default: 196 return false 197 } 198} 199 200func shouldReceive(rpc rpcMessage) bool { 201 if rpc == nil { 202 return false 203 } 204 switch rpc.Err().(type) { 205 case nil: 206 return true 207 case MethodNotFoundError: 208 return true 209 case ProtocolNotFoundError: 210 return true 211 default: 212 return false 213 } 214} 215