1package native 2 3import ( 4 "bufio" 5 "github.com/ziutek/mymysql/mysql" 6 "io" 7 "io/ioutil" 8) 9 10type pktReader struct { 11 rd *bufio.Reader 12 seq *byte 13 remain int 14 last bool 15 buf [8]byte 16 ibuf [3]byte 17} 18 19func (my *Conn) newPktReader() *pktReader { 20 return &pktReader{rd: my.rd, seq: &my.seq} 21} 22 23func (pr *pktReader) readHeader() { 24 // Read next packet header 25 buf := pr.ibuf[:] 26 for { 27 n, err := pr.rd.Read(buf) 28 if err != nil { 29 panic(err) 30 } 31 buf = buf[n:] 32 if len(buf) == 0 { 33 break 34 } 35 } 36 pr.remain = int(DecodeU24(pr.ibuf[:])) 37 seq, err := pr.rd.ReadByte() 38 if err != nil { 39 panic(err) 40 } 41 // Chceck sequence number 42 if *pr.seq != seq { 43 panic(mysql.ErrSeq) 44 } 45 *pr.seq++ 46 // Last packet? 47 pr.last = (pr.remain != 0xffffff) 48} 49 50func (pr *pktReader) readFull(buf []byte) { 51 for len(buf) > 0 { 52 if pr.remain == 0 { 53 if pr.last { 54 // No more packets 55 panic(io.EOF) 56 } 57 pr.readHeader() 58 } 59 n := len(buf) 60 if n > pr.remain { 61 n = pr.remain 62 } 63 n, err := pr.rd.Read(buf[:n]) 64 pr.remain -= n 65 if err != nil { 66 panic(err) 67 } 68 buf = buf[n:] 69 } 70 return 71} 72 73func (pr *pktReader) readByte() byte { 74 if pr.remain == 0 { 75 if pr.last { 76 // No more packets 77 panic(io.EOF) 78 } 79 pr.readHeader() 80 } 81 b, err := pr.rd.ReadByte() 82 if err != nil { 83 panic(err) 84 } 85 pr.remain-- 86 return b 87} 88 89func (pr *pktReader) readAll() (buf []byte) { 90 m := 0 91 for { 92 if pr.remain == 0 { 93 if pr.last { 94 break 95 } 96 pr.readHeader() 97 } 98 new_buf := make([]byte, m+pr.remain) 99 copy(new_buf, buf) 100 buf = new_buf 101 n, err := pr.rd.Read(buf[m:]) 102 pr.remain -= n 103 m += n 104 if err != nil { 105 panic(err) 106 } 107 } 108 return 109} 110 111func (pr *pktReader) skipAll() { 112 for { 113 if pr.remain == 0 { 114 if pr.last { 115 break 116 } 117 pr.readHeader() 118 } 119 n, err := io.CopyN(ioutil.Discard, pr.rd, int64(pr.remain)) 120 pr.remain -= int(n) 121 if err != nil { 122 panic(err) 123 } 124 } 125 return 126} 127 128func (pr *pktReader) skipN(n int) { 129 for n > 0 { 130 if pr.remain == 0 { 131 if pr.last { 132 panic(io.EOF) 133 } 134 pr.readHeader() 135 } 136 m := int64(n) 137 if n > pr.remain { 138 m = int64(pr.remain) 139 } 140 m, err := io.CopyN(ioutil.Discard, pr.rd, m) 141 pr.remain -= int(m) 142 n -= int(m) 143 if err != nil { 144 panic(err) 145 } 146 } 147 return 148} 149 150func (pr *pktReader) unreadByte() { 151 if err := pr.rd.UnreadByte(); err != nil { 152 panic(err) 153 } 154 pr.remain++ 155} 156 157func (pr *pktReader) eof() bool { 158 return pr.remain == 0 && pr.last 159} 160 161func (pr *pktReader) checkEof() { 162 if !pr.eof() { 163 panic(mysql.ErrPktLong) 164 } 165} 166 167type pktWriter struct { 168 wr *bufio.Writer 169 seq *byte 170 remain int 171 to_write int 172 last bool 173 buf [23]byte 174 ibuf [3]byte 175} 176 177func (my *Conn) newPktWriter(to_write int) *pktWriter { 178 return &pktWriter{wr: my.wr, seq: &my.seq, to_write: to_write} 179} 180 181func (pw *pktWriter) writeHeader(l int) { 182 buf := pw.ibuf[:] 183 EncodeU24(buf, uint32(l)) 184 if _, err := pw.wr.Write(buf); err != nil { 185 panic(err) 186 } 187 if err := pw.wr.WriteByte(*pw.seq); err != nil { 188 panic(err) 189 } 190 // Update sequence number 191 *pw.seq++ 192} 193 194func (pw *pktWriter) write(buf []byte) { 195 if len(buf) == 0 { 196 return 197 } 198 var nn int 199 for len(buf) != 0 { 200 if pw.remain == 0 { 201 if pw.to_write == 0 { 202 panic("too many data for write as packet") 203 } 204 if pw.to_write >= 0xffffff { 205 pw.remain = 0xffffff 206 } else { 207 pw.remain = pw.to_write 208 pw.last = true 209 } 210 pw.to_write -= pw.remain 211 pw.writeHeader(pw.remain) 212 } 213 nn = len(buf) 214 if nn > pw.remain { 215 nn = pw.remain 216 } 217 var err error 218 nn, err = pw.wr.Write(buf[0:nn]) 219 pw.remain -= nn 220 if err != nil { 221 panic(err) 222 } 223 buf = buf[nn:] 224 } 225 if pw.remain+pw.to_write == 0 { 226 if !pw.last { 227 // Write header for empty packet 228 pw.writeHeader(0) 229 } 230 // Flush bufio buffers 231 if err := pw.wr.Flush(); err != nil { 232 panic(err) 233 } 234 } 235 return 236} 237 238func (pw *pktWriter) writeByte(b byte) { 239 pw.buf[0] = b 240 pw.write(pw.buf[:1]) 241} 242 243// n should be <= 23 244func (pw *pktWriter) writeZeros(n int) { 245 buf := pw.buf[:n] 246 for i := range buf { 247 buf[i] = 0 248 } 249 pw.write(buf) 250} 251