1package rpc
2
3import (
4	"fmt"
5	"io"
6
7	"github.com/keybase/go-codec/codec"
8	"golang.org/x/net/context"
9)
10
11func newCodecMsgpackHandle() codec.Handle {
12	return &codec.MsgpackHandle{
13		WriteExt:    true,
14		RawToString: true,
15	}
16}
17
// writeBundle is a single write request handed to writerLoop over writeCh.
// NOTE: it is built with positional composite literals elsewhere in this
// file, so field order matters.
type writeBundle struct {
	// bytes is the already-framed payload passed to the underlying writer.
	bytes []byte
	// ch receives the error returned by the writer's Write call. Senders
	// create it with capacity 1 so writerLoop never blocks on it.
	ch    chan error
	// sn, when non-nil, is called by writerLoop just before bytes is written.
	sn    func()
}
23
// framedMsgpackEncoder encodes values as msgpack frames (a msgpack-encoded
// length followed by the msgpack-encoded body) and writes them to writer
// from a single background goroutine (writerLoop).
type framedMsgpackEncoder struct {
	// maxFrameLength is the largest encoded body, in bytes, that
	// encodeFrame accepts.
	maxFrameLength   int32
	// handle configures the msgpack codec used for every encode.
	handle           codec.Handle
	// writer is the destination of framed bytes; writerLoop calls Write on it.
	writer           io.Writer
	// writeCh is an unbuffered channel carrying write requests to writerLoop.
	writeCh          chan writeBundle
	// doneCh is closed by Close to stop writerLoop and to make pending
	// senders give up with io.EOF.
	doneCh           chan struct{}
	// closedCh is closed by writerLoop once it has observed doneCh and exited.
	closedCh         chan struct{}
	// compressorCacher supplies the compressor used by compressData for a
	// given CompressionType.
	compressorCacher *compressorCacher
}
33
34func newFramedMsgpackEncoder(maxFrameLength int32, writer io.Writer) *framedMsgpackEncoder {
35	e := &framedMsgpackEncoder{
36		maxFrameLength:   maxFrameLength,
37		handle:           newCodecMsgpackHandle(),
38		writer:           writer,
39		writeCh:          make(chan writeBundle),
40		doneCh:           make(chan struct{}),
41		closedCh:         make(chan struct{}),
42		compressorCacher: newCompressorCacher(),
43	}
44	go e.writerLoop()
45	return e
46}
47
48func encodeToBytes(enc *codec.Encoder, i interface{}) (v []byte, err error) {
49	enc.ResetBytes(&v)
50	err = enc.Encode(i)
51	return v, err
52}
53
54func (e *framedMsgpackEncoder) compressData(ctype CompressionType, i interface{}) (interface{}, error) {
55	c := e.compressorCacher.getCompressor(ctype)
56	if c == nil {
57		return i, nil
58	}
59	enc := codec.NewEncoderBytes(nil, e.handle)
60	content, err := encodeToBytes(enc, i)
61	if err != nil {
62		return nil, err
63	}
64	compressedContent, err := c.Compress(content)
65	if err != nil {
66		return nil, err
67	}
68	compressedI := interface{}(compressedContent)
69	return compressedI, nil
70}
71
72func (e *framedMsgpackEncoder) encodeFrame(i interface{}) ([]byte, error) {
73	enc := codec.NewEncoderBytes(nil, e.handle)
74	content, err := encodeToBytes(enc, i)
75	if err != nil {
76		return nil, err
77	}
78	if len(content) > int(e.maxFrameLength) {
79		return nil, fmt.Errorf("frame length too big: %d > %d", len(content), e.maxFrameLength)
80	}
81	length, err := encodeToBytes(enc, len(content))
82	if err != nil {
83		return nil, err
84	}
85	return append(length, content...), nil
86}
87
88// encodeAndWriteInternal is called directly by tests that need to
89// write invalid frames.
90func (e *framedMsgpackEncoder) encodeAndWriteInternal(ctx context.Context, frame interface{}, sendNotifier func()) (int64, <-chan error) {
91	bytes, err := e.encodeFrame(frame)
92	ch := make(chan error, 1)
93	if err != nil {
94		ch <- err
95		return 0, ch
96	}
97	select {
98	case <-e.doneCh:
99		ch <- io.EOF
100	case <-ctx.Done():
101		ch <- ctx.Err()
102	case e.writeCh <- writeBundle{bytes, ch, sendNotifier}:
103	}
104	return int64(len(bytes)), ch
105}
106
107func (e *framedMsgpackEncoder) EncodeAndWrite(ctx context.Context, frame []interface{}, sendNotifier func()) (int64, <-chan error) {
108	return e.encodeAndWriteInternal(ctx, frame, sendNotifier)
109}
110
111func (e *framedMsgpackEncoder) EncodeAndWriteAsync(frame []interface{}) (int64, <-chan error) {
112	bytes, err := e.encodeFrame(frame)
113	ch := make(chan error, 1)
114	if err != nil {
115		ch <- err
116		return 0, ch
117	}
118	select {
119	case <-e.doneCh:
120		ch <- io.EOF
121	case e.writeCh <- writeBundle{bytes, ch, nil}:
122	default:
123		go func() {
124			select {
125			case <-e.doneCh:
126				ch <- io.EOF
127			case e.writeCh <- writeBundle{bytes, ch, nil}:
128			}
129		}()
130	}
131	return int64(len(bytes)), ch
132}
133
134func (e *framedMsgpackEncoder) writerLoop() {
135	for {
136		select {
137		case <-e.doneCh:
138			close(e.closedCh)
139			return
140		case write := <-e.writeCh:
141			if write.sn != nil {
142				write.sn()
143			}
144			_, err := e.writer.Write(write.bytes)
145			write.ch <- err
146		}
147	}
148}
149
150func (e *framedMsgpackEncoder) Close() <-chan struct{} {
151	close(e.doneCh)
152	return e.closedCh
153}
154