1package rpc 2 3import ( 4 "fmt" 5 "io" 6 7 "github.com/keybase/go-codec/codec" 8 "golang.org/x/net/context" 9) 10 11func newCodecMsgpackHandle() codec.Handle { 12 return &codec.MsgpackHandle{ 13 WriteExt: true, 14 RawToString: true, 15 } 16} 17 18type writeBundle struct { 19 bytes []byte 20 ch chan error 21 sn func() 22} 23 24type framedMsgpackEncoder struct { 25 maxFrameLength int32 26 handle codec.Handle 27 writer io.Writer 28 writeCh chan writeBundle 29 doneCh chan struct{} 30 closedCh chan struct{} 31 compressorCacher *compressorCacher 32} 33 34func newFramedMsgpackEncoder(maxFrameLength int32, writer io.Writer) *framedMsgpackEncoder { 35 e := &framedMsgpackEncoder{ 36 maxFrameLength: maxFrameLength, 37 handle: newCodecMsgpackHandle(), 38 writer: writer, 39 writeCh: make(chan writeBundle), 40 doneCh: make(chan struct{}), 41 closedCh: make(chan struct{}), 42 compressorCacher: newCompressorCacher(), 43 } 44 go e.writerLoop() 45 return e 46} 47 48func encodeToBytes(enc *codec.Encoder, i interface{}) (v []byte, err error) { 49 enc.ResetBytes(&v) 50 err = enc.Encode(i) 51 return v, err 52} 53 54func (e *framedMsgpackEncoder) compressData(ctype CompressionType, i interface{}) (interface{}, error) { 55 c := e.compressorCacher.getCompressor(ctype) 56 if c == nil { 57 return i, nil 58 } 59 enc := codec.NewEncoderBytes(nil, e.handle) 60 content, err := encodeToBytes(enc, i) 61 if err != nil { 62 return nil, err 63 } 64 compressedContent, err := c.Compress(content) 65 if err != nil { 66 return nil, err 67 } 68 compressedI := interface{}(compressedContent) 69 return compressedI, nil 70} 71 72func (e *framedMsgpackEncoder) encodeFrame(i interface{}) ([]byte, error) { 73 enc := codec.NewEncoderBytes(nil, e.handle) 74 content, err := encodeToBytes(enc, i) 75 if err != nil { 76 return nil, err 77 } 78 if len(content) > int(e.maxFrameLength) { 79 return nil, fmt.Errorf("frame length too big: %d > %d", len(content), e.maxFrameLength) 80 } 81 length, err := encodeToBytes(enc, len(content)) 82 if err != nil { 83 return nil, err 84 } 85 return append(length, content...), nil 86} 87 88// encodeAndWriteInternal is called directly by tests that need to 89// write invalid frames. 90func (e *framedMsgpackEncoder) encodeAndWriteInternal(ctx context.Context, frame interface{}, sendNotifier func()) (int64, <-chan error) { 91 bytes, err := e.encodeFrame(frame) 92 ch := make(chan error, 1) 93 if err != nil { 94 ch <- err 95 return 0, ch 96 } 97 select { 98 case <-e.doneCh: 99 ch <- io.EOF 100 case <-ctx.Done(): 101 ch <- ctx.Err() 102 case e.writeCh <- writeBundle{bytes, ch, sendNotifier}: 103 } 104 return int64(len(bytes)), ch 105} 106 107func (e *framedMsgpackEncoder) EncodeAndWrite(ctx context.Context, frame []interface{}, sendNotifier func()) (int64, <-chan error) { 108 return e.encodeAndWriteInternal(ctx, frame, sendNotifier) 109} 110 111func (e *framedMsgpackEncoder) EncodeAndWriteAsync(frame []interface{}) (int64, <-chan error) { 112 bytes, err := e.encodeFrame(frame) 113 ch := make(chan error, 1) 114 if err != nil { 115 ch <- err 116 return 0, ch 117 } 118 select { 119 case <-e.doneCh: 120 ch <- io.EOF 121 case e.writeCh <- writeBundle{bytes, ch, nil}: 122 default: 123 go func() { 124 select { 125 case <-e.doneCh: 126 ch <- io.EOF 127 case e.writeCh <- writeBundle{bytes, ch, nil}: 128 } 129 }() 130 } 131 return int64(len(bytes)), ch 132} 133 134func (e *framedMsgpackEncoder) writerLoop() { 135 for { 136 select { 137 case <-e.doneCh: 138 close(e.closedCh) 139 return 140 case write := <-e.writeCh: 141 if write.sn != nil { 142 write.sn() 143 } 144 _, err := e.writer.Write(write.bytes) 145 write.ch <- err 146 } 147 } 148} 149 150func (e *framedMsgpackEncoder) Close() <-chan struct{} { 151 close(e.doneCh) 152 return e.closedCh 153} 154