1package native
2
3import (
4	"bufio"
5	"github.com/ziutek/mymysql/mysql"
6	"io"
7	"io/ioutil"
8)
9
// pktReader reads one logical message whose payload may be split across
// several consecutive packets, checking header sequence numbers as it
// crosses packet boundaries.
type pktReader struct {
	// Buffered source of raw packet bytes.
	rd *bufio.Reader

	// Sequence number expected in the next packet header; readHeader
	// compares against it and then increments it.
	seq *byte

	// Payload bytes of the current packet that have not been consumed yet.
	remain int

	// Set when the current packet is the final one (its length was not
	// the maximum 0xffffff).
	last bool

	// Scratch space. NOTE(review): not used by the code in this file;
	// presumably used by decoding helpers elsewhere in the package.
	buf [8]byte

	// Receives the 3-byte length field of a packet header.
	ibuf [3]byte
}
18
19func (my *Conn) newPktReader() *pktReader {
20	return &pktReader{rd: my.rd, seq: &my.seq}
21}
22
23func (pr *pktReader) readHeader() {
24	// Read next packet header
25	buf := pr.ibuf[:]
26	for {
27		n, err := pr.rd.Read(buf)
28		if err != nil {
29			panic(err)
30		}
31		buf = buf[n:]
32		if len(buf) == 0 {
33			break
34		}
35	}
36	pr.remain = int(DecodeU24(pr.ibuf[:]))
37	seq, err := pr.rd.ReadByte()
38	if err != nil {
39		panic(err)
40	}
41	// Chceck sequence number
42	if *pr.seq != seq {
43		panic(mysql.ErrSeq)
44	}
45	*pr.seq++
46	// Last packet?
47	pr.last = (pr.remain != 0xffffff)
48}
49
50func (pr *pktReader) readFull(buf []byte) {
51	for len(buf) > 0 {
52		if pr.remain == 0 {
53			if pr.last {
54				// No more packets
55				panic(io.EOF)
56			}
57			pr.readHeader()
58		}
59		n := len(buf)
60		if n > pr.remain {
61			n = pr.remain
62		}
63		n, err := pr.rd.Read(buf[:n])
64		pr.remain -= n
65		if err != nil {
66			panic(err)
67		}
68		buf = buf[n:]
69	}
70	return
71}
72
73func (pr *pktReader) readByte() byte {
74	if pr.remain == 0 {
75		if pr.last {
76			// No more packets
77			panic(io.EOF)
78		}
79		pr.readHeader()
80	}
81	b, err := pr.rd.ReadByte()
82	if err != nil {
83		panic(err)
84	}
85	pr.remain--
86	return b
87}
88
89func (pr *pktReader) readAll() (buf []byte) {
90	m := 0
91	for {
92		if pr.remain == 0 {
93			if pr.last {
94				break
95			}
96			pr.readHeader()
97		}
98		new_buf := make([]byte, m+pr.remain)
99		copy(new_buf, buf)
100		buf = new_buf
101		n, err := pr.rd.Read(buf[m:])
102		pr.remain -= n
103		m += n
104		if err != nil {
105			panic(err)
106		}
107	}
108	return
109}
110
111func (pr *pktReader) skipAll() {
112	for {
113		if pr.remain == 0 {
114			if pr.last {
115				break
116			}
117			pr.readHeader()
118		}
119		n, err := io.CopyN(ioutil.Discard, pr.rd, int64(pr.remain))
120		pr.remain -= int(n)
121		if err != nil {
122			panic(err)
123		}
124	}
125	return
126}
127
128func (pr *pktReader) skipN(n int) {
129	for n > 0 {
130		if pr.remain == 0 {
131			if pr.last {
132				panic(io.EOF)
133			}
134			pr.readHeader()
135		}
136		m := int64(n)
137		if n > pr.remain {
138			m = int64(pr.remain)
139		}
140		m, err := io.CopyN(ioutil.Discard, pr.rd, m)
141		pr.remain -= int(m)
142		n -= int(m)
143		if err != nil {
144			panic(err)
145		}
146	}
147	return
148}
149
150func (pr *pktReader) unreadByte() {
151	if err := pr.rd.UnreadByte(); err != nil {
152		panic(err)
153	}
154	pr.remain++
155}
156
157func (pr *pktReader) eof() bool {
158	return pr.remain == 0 && pr.last
159}
160
161func (pr *pktReader) checkEof() {
162	if !pr.eof() {
163		panic(mysql.ErrPktLong)
164	}
165}
166
// pktWriter splits an outgoing payload of known total size into packets,
// emitting each packet's header before its data.
type pktWriter struct {
	// Buffered destination; flushed once the whole payload has been written.
	wr *bufio.Writer

	// Sequence number stamped on the next header; incremented after each one.
	seq *byte

	// Payload bytes still to be written into the current packet.
	remain int

	// Payload bytes not yet assigned to any packet.
	to_write int

	// Set once the current packet is shorter than 0xffffff bytes, i.e. final.
	last bool

	// Scratch space used by writeByte and writeZeros.
	buf [23]byte

	// Holds the encoded 3-byte length field of a header.
	ibuf [3]byte
}
176
177func (my *Conn) newPktWriter(to_write int) *pktWriter {
178	return &pktWriter{wr: my.wr, seq: &my.seq, to_write: to_write}
179}
180
181func (pw *pktWriter) writeHeader(l int) {
182	buf := pw.ibuf[:]
183	EncodeU24(buf, uint32(l))
184	if _, err := pw.wr.Write(buf); err != nil {
185		panic(err)
186	}
187	if err := pw.wr.WriteByte(*pw.seq); err != nil {
188		panic(err)
189	}
190	// Update sequence number
191	*pw.seq++
192}
193
194func (pw *pktWriter) write(buf []byte) {
195	if len(buf) == 0 {
196		return
197	}
198	var nn int
199	for len(buf) != 0 {
200		if pw.remain == 0 {
201			if pw.to_write == 0 {
202				panic("too many data for write as packet")
203			}
204			if pw.to_write >= 0xffffff {
205				pw.remain = 0xffffff
206			} else {
207				pw.remain = pw.to_write
208				pw.last = true
209			}
210			pw.to_write -= pw.remain
211			pw.writeHeader(pw.remain)
212		}
213		nn = len(buf)
214		if nn > pw.remain {
215			nn = pw.remain
216		}
217		var err error
218		nn, err = pw.wr.Write(buf[0:nn])
219		pw.remain -= nn
220		if err != nil {
221			panic(err)
222		}
223		buf = buf[nn:]
224	}
225	if pw.remain+pw.to_write == 0 {
226		if !pw.last {
227			// Write  header for empty packet
228			pw.writeHeader(0)
229		}
230		// Flush bufio buffers
231		if err := pw.wr.Flush(); err != nil {
232			panic(err)
233		}
234	}
235	return
236}
237
238func (pw *pktWriter) writeByte(b byte) {
239	pw.buf[0] = b
240	pw.write(pw.buf[:1])
241}
242
243// n should be <= 23
244func (pw *pktWriter) writeZeros(n int) {
245	buf := pw.buf[:n]
246	for i := range buf {
247		buf[i] = 0
248	}
249	pw.write(buf)
250}
251