1package native
2
3import (
4	"errors"
5	"github.com/ziutek/mymysql/mysql"
6	"log"
7	"math"
8	"strconv"
9)
10
// Result describes the outcome of a command: either a status-only reply or
// the header (field descriptions) of a result set, together with the status
// information reported by the server.
type Result struct {
	my          *Conn
	status_only bool // true if result doesn't contain result set
	binary      bool // Binary result expected

	field_count int
	fields      []*mysql.Field // Fields table
	fc_map      map[string]int // Maps field name to column number

	message       []byte
	affected_rows uint64

	// Primary key value (useful for AUTO_INCREMENT primary keys)
	insert_id uint64

	// Number of warnings during command execution.
	// You can use the SHOW WARNINGS query for details.
	warning_count int

	// MySQL server status immediately after the query execution
	status uint16

	// Set by GetRow if it returns nil row
	eor_returned bool
}
36
// StatusOnly returns true if this is a status result that includes no result
// set.
func (res *Result) StatusOnly() bool {
	return res.status_only
}
41
// Fields returns a table containing descriptions of the columns.
// The returned slice is the one stored in the Result, not a copy.
func (res *Result) Fields() []*mysql.Field {
	return res.fields
}
46
47// Returns index for given name or -1 if field of that name doesn't exist
48func (res *Result) Map(field_name string) int {
49	if fi, ok := res.fc_map[field_name]; ok {
50		return fi
51	}
52	return -1
53}
54
// Message returns the message bytes stored in the result, converted to a
// string.
func (res *Result) Message() string {
	return string(res.message)
}
58
// AffectedRows returns the affected rows count stored in the result.
func (res *Result) AffectedRows() uint64 {
	return res.affected_rows
}
62
// InsertId returns the insert id stored in the result (the primary key
// value, useful for AUTO_INCREMENT primary keys).
func (res *Result) InsertId() uint64 {
	return res.insert_id
}
66
// WarnCount returns the number of warnings stored in the result.
func (res *Result) WarnCount() int {
	return res.warning_count
}
70
// MakeRow allocates a new mysql.Row whose length equals the result's field
// count.
func (res *Result) MakeRow() mysql.Row {
	return make(mysql.Row, res.field_count)
}
74
75func (my *Conn) getResult(res *Result, row mysql.Row) *Result {
76loop:
77	pr := my.newPktReader() // New reader for next packet
78	pkt0 := pr.readByte()
79
80	if pkt0 == 255 {
81		// Error packet
82		my.getErrorPacket(pr)
83	}
84
85	if res == nil {
86		switch {
87		case pkt0 == 0:
88			// OK packet
89			return my.getOkPacket(pr)
90
91		case pkt0 > 0 && pkt0 < 251:
92			// Result set header packet
93			res = my.getResSetHeadPacket(pr)
94			// Read next packet
95			goto loop
96		case pkt0 == 251:
97			// Load infile response
98			// Handle response
99			goto loop
100		case pkt0 == 254:
101			// EOF packet (without body)
102			return nil
103		}
104	} else {
105		switch {
106		case pkt0 == 254:
107			// EOF packet
108			res.warning_count, res.status = my.getEofPacket(pr)
109			my.status = res.status
110			return res
111
112		case pkt0 > 0 && pkt0 < 251 && res.field_count < len(res.fields):
113			// Field packet
114			field := my.getFieldPacket(pr)
115			res.fields[res.field_count] = field
116			res.fc_map[field.Name] = res.field_count
117			// Increment field count
118			res.field_count++
119			// Read next packet
120			goto loop
121
122		case pkt0 < 254 && res.field_count == len(res.fields):
123			// Row Data Packet
124			if len(row) != res.field_count {
125				panic(mysql.ErrRowLength)
126			}
127			if res.binary {
128				my.getBinRowPacket(pr, res, row)
129			} else {
130				my.getTextRowPacket(pr, res, row)
131			}
132			return nil
133		}
134	}
135	panic(mysql.ErrUnkResultPkt)
136}
137
138func (my *Conn) getOkPacket(pr *pktReader) (res *Result) {
139	if my.Debug {
140		log.Printf("[%2d ->] OK packet:", my.seq-1)
141	}
142	res = new(Result)
143	res.status_only = true
144	res.my = my
145	// First byte was readed by getResult
146	res.affected_rows = pr.readLCB()
147	res.insert_id = pr.readLCB()
148	res.status = pr.readU16()
149	my.status = res.status
150	res.warning_count = int(pr.readU16())
151	res.message = pr.readAll()
152	pr.checkEof()
153
154	if my.Debug {
155		log.Printf(tab8s+"AffectedRows=%d InsertId=0x%x Status=0x%x "+
156			"WarningCount=%d Message=\"%s\"", res.affected_rows, res.insert_id,
157			res.status, res.warning_count, res.message,
158		)
159	}
160	return
161}
162
163func (my *Conn) getErrorPacket(pr *pktReader) {
164	if my.Debug {
165		log.Printf("[%2d ->] Error packet:", my.seq-1)
166	}
167	var err mysql.Error
168	err.Code = pr.readU16()
169	if pr.readByte() != '#' {
170		panic(mysql.ErrPkt)
171	}
172	pr.skipN(5)
173	err.Msg = pr.readAll()
174	pr.checkEof()
175
176	if my.Debug {
177		log.Printf(tab8s+"code=0x%x msg=\"%s\"", err.Code, err.Msg)
178	}
179	panic(&err)
180}
181
182func (my *Conn) getEofPacket(pr *pktReader) (warn_count int, status uint16) {
183	if my.Debug {
184		if pr.eof() {
185			log.Printf("[%2d ->] EOF packet without body", my.seq-1)
186		} else {
187			log.Printf("[%2d ->] EOF packet:", my.seq-1)
188		}
189	}
190	if pr.eof() {
191		return
192	}
193	warn_count = int(pr.readU16())
194	if pr.eof() {
195		return
196	}
197	status = pr.readU16()
198	pr.checkEof()
199
200	if my.Debug {
201		log.Printf(tab8s+"WarningCount=%d Status=0x%x", warn_count, status)
202	}
203	return
204}
205
206func (my *Conn) getResSetHeadPacket(pr *pktReader) (res *Result) {
207	if my.Debug {
208		log.Printf("[%2d ->] Result set header packet:", my.seq-1)
209	}
210	pr.unreadByte()
211
212	field_count := int(pr.readLCB())
213	pr.checkEof()
214
215	res = &Result{
216		my:     my,
217		fields: make([]*mysql.Field, field_count),
218		fc_map: make(map[string]int),
219	}
220
221	if my.Debug {
222		log.Printf(tab8s+"FieldCount=%d", field_count)
223	}
224	return
225}
226
227func (my *Conn) getFieldPacket(pr *pktReader) (field *mysql.Field) {
228	if my.Debug {
229		log.Printf("[%2d ->] Field packet:", my.seq-1)
230	}
231	pr.unreadByte()
232
233	field = new(mysql.Field)
234	if my.fullFieldInfo {
235		field.Catalog = string(pr.readBin())
236		field.Db = string(pr.readBin())
237		field.Table = string(pr.readBin())
238		field.OrgTable = string(pr.readBin())
239	} else {
240		pr.skipBin()
241		pr.skipBin()
242		pr.skipBin()
243		pr.skipBin()
244	}
245	field.Name = string(pr.readBin())
246	if my.fullFieldInfo {
247		field.OrgName = string(pr.readBin())
248	} else {
249		pr.skipBin()
250	}
251	pr.skipN(1 + 2)
252	//field.Charset= pr.readU16()
253	field.DispLen = pr.readU32()
254	field.Type = pr.readByte()
255	field.Flags = pr.readU16()
256	field.Scale = pr.readByte()
257	pr.skipN(2)
258	pr.checkEof()
259
260	if my.Debug {
261		log.Printf(tab8s+"Name=\"%s\" Type=0x%x", field.Name, field.Type)
262	}
263	return
264}
265
266func (my *Conn) getTextRowPacket(pr *pktReader, res *Result, row mysql.Row) {
267	if my.Debug {
268		log.Printf("[%2d ->] Text row data packet", my.seq-1)
269	}
270	pr.unreadByte()
271
272	for ii := 0; ii < res.field_count; ii++ {
273		bin, null := pr.readNullBin()
274		if null {
275			row[ii] = nil
276		} else {
277			row[ii] = bin
278		}
279	}
280	pr.checkEof()
281}
282
283func (my *Conn) getBinRowPacket(pr *pktReader, res *Result, row mysql.Row) {
284	if my.Debug {
285		log.Printf("[%2d ->] Binary row data packet", my.seq-1)
286	}
287	// First byte was readed by getResult
288
289	null_bitmap := make([]byte, (res.field_count+7+2)>>3)
290	pr.readFull(null_bitmap)
291
292	for ii, field := range res.fields {
293		null_byte := (ii + 2) >> 3
294		null_mask := byte(1) << uint(2+ii-(null_byte<<3))
295		if null_bitmap[null_byte]&null_mask != 0 {
296			// Null field
297			row[ii] = nil
298			continue
299		}
300		unsigned := (field.Flags & _FLAG_UNSIGNED) != 0
301		if my.narrowTypeSet {
302			row[ii] = readValueNarrow(pr, field.Type, unsigned)
303		} else {
304			row[ii] = readValue(pr, field.Type, unsigned)
305		}
306	}
307}
308
309func readValue(pr *pktReader, typ byte, unsigned bool) interface{} {
310	switch typ {
311	case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR,
312		MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB,
313		MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET,
314		MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY:
315		return pr.readBin()
316	case MYSQL_TYPE_TINY:
317		if unsigned {
318			return pr.readByte()
319		} else {
320			return int8(pr.readByte())
321		}
322	case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
323		if unsigned {
324			return pr.readU16()
325		} else {
326			return int16(pr.readU16())
327		}
328	case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24:
329		if unsigned {
330			return pr.readU32()
331		} else {
332			return int32(pr.readU32())
333		}
334	case MYSQL_TYPE_LONGLONG:
335		if unsigned {
336			return pr.readU64()
337		} else {
338			return int64(pr.readU64())
339		}
340	case MYSQL_TYPE_FLOAT:
341		return math.Float32frombits(pr.readU32())
342	case MYSQL_TYPE_DOUBLE:
343		return math.Float64frombits(pr.readU64())
344	case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL:
345		dec := string(pr.readBin())
346		r, err := strconv.ParseFloat(dec, 64)
347		if err != nil {
348			panic(errors.New("MySQL server returned wrong decimal value: " + dec))
349		}
350		return r
351	case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
352		return pr.readDate()
353	case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP:
354		return pr.readTime()
355	case MYSQL_TYPE_TIME:
356		return pr.readDuration()
357	}
358	panic(mysql.ErrUnkMySQLType)
359}
360
361func readValueNarrow(pr *pktReader, typ byte, unsigned bool) interface{} {
362	switch typ {
363	case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR,
364		MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB,
365		MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET,
366		MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY:
367		return pr.readBin()
368	case MYSQL_TYPE_TINY:
369		if unsigned {
370			return int64(pr.readByte())
371		}
372		return int64(int8(pr.readByte()))
373	case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
374		if unsigned {
375			return int64(pr.readU16())
376		}
377		return int64(int16(pr.readU16()))
378	case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24:
379		if unsigned {
380			return int64(pr.readU32())
381		}
382		return int64(int32(pr.readU32()))
383	case MYSQL_TYPE_LONGLONG:
384		v := pr.readU64()
385		if unsigned && v > math.MaxInt64 {
386			panic(errors.New("Value to large for int64 type"))
387		}
388		return int64(v)
389	case MYSQL_TYPE_FLOAT:
390		return float64(math.Float32frombits(pr.readU32()))
391	case MYSQL_TYPE_DOUBLE:
392		return math.Float64frombits(pr.readU64())
393	case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL:
394		dec := string(pr.readBin())
395		r, err := strconv.ParseFloat(dec, 64)
396		if err != nil {
397			panic("MySQL server returned wrong decimal value: " + dec)
398		}
399		return r
400	case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
401		return pr.readTime()
402	case MYSQL_TYPE_TIME:
403		return int64(pr.readDuration())
404	}
405	panic(mysql.ErrUnkMySQLType)
406}
407