1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"database/sql/driver"
13	"io"
14)
15
// mysqlField is the per-column metadata kept for a result set.
//
// In this file only tableName and name are read: Columns reports name, and
// prefixes it with "tableName." when the connection's ColumnsWithAlias option
// is enabled and tableName is non-empty.
//
// NOTE(review): flags, fieldType and decimals are not used here; presumably
// they hold the column flags, MySQL type code and decimal count taken from
// the column definition packet — confirm against the packet parsing code.
type mysqlField struct {
	tableName string
	name      string
	flags     fieldFlag
	fieldType byte
	decimals  byte
}
23
// resultSet tracks the state of the result set currently being read.
//
// columns is the column metadata of the current result set. done reports
// whether the rows of the current result set have already been consumed
// from the stream (or the set was empty); when done is false, Close and
// nextResultSet first drain the remaining packets with readUntilEOF.
type resultSet struct {
	columns []mysqlField
	done    bool
}
28
// mysqlRows holds the state shared by the binary and text row iterators.
//
// mc is the connection the rows are read from. It is set to nil once the
// rows are closed or the last result set has been passed; a nil mc makes
// Next return io.EOF, Close return nil and HasNextResultSet return false.
// rs is the state of the current result set.
type mysqlRows struct {
	mc *mysqlConn
	rs resultSet
}
33
// binaryRows is the row iterator whose NextResultSet caches column metadata
// on the owning statement, so repeated executions can reuse it.
// NOTE(review): presumably used for prepared statements (binary protocol) —
// confirm at the construction site.
type binaryRows struct {
	mysqlRows
	// stmtCols points at the statement's cached column metadata, one entry
	// per result set position. NextResultSet appends to it the first time a
	// position is reached and reads from it afterwards.
	stmtCols *[][]mysqlField
	// i is the position, within stmtCols, of the result set that the next
	// call to NextResultSet will look up; NextResultSet increments it each
	// time it advances.
	i int
}
43
// textRows is the row iterator whose NextResultSet always reads column
// metadata from the stream (no per-statement cache, unlike binaryRows).
// NOTE(review): presumably used for plain text-protocol queries — confirm at
// the construction site.
type textRows struct {
	mysqlRows
}
47
48func (rows *mysqlRows) Columns() []string {
49	columns := make([]string, len(rows.rs.columns))
50	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
51		for i := range columns {
52			if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
53				columns[i] = tableName + "." + rows.rs.columns[i].name
54			} else {
55				columns[i] = rows.rs.columns[i].name
56			}
57		}
58	} else {
59		for i := range columns {
60			columns[i] = rows.rs.columns[i].name
61		}
62	}
63	return columns
64}
65
66func (rows *mysqlRows) Close() (err error) {
67	mc := rows.mc
68	if mc == nil {
69		return nil
70	}
71	if mc.netConn == nil {
72		return ErrInvalidConn
73	}
74
75	// Remove unread packets from stream
76	if !rows.rs.done {
77		err = mc.readUntilEOF()
78	}
79	if err == nil {
80		if err = mc.discardResults(); err != nil {
81			return err
82		}
83	}
84
85	rows.mc = nil
86	return err
87}
88
89func (rows *mysqlRows) HasNextResultSet() (b bool) {
90	if rows.mc == nil {
91		return false
92	}
93	return rows.mc.status&statusMoreResultsExists != 0
94}
95
96func (rows *mysqlRows) nextResultSet() (int, error) {
97	if rows.mc == nil {
98		return 0, io.EOF
99	}
100	if rows.mc.netConn == nil {
101		return 0, ErrInvalidConn
102	}
103
104	// Remove unread packets from stream
105	if !rows.rs.done {
106		if err := rows.mc.readUntilEOF(); err != nil {
107			return 0, err
108		}
109		rows.rs.done = true
110	}
111
112	if !rows.HasNextResultSet() {
113		rows.mc = nil
114		return 0, io.EOF
115	}
116	rows.rs = resultSet{}
117	return rows.mc.readResultSetHeaderPacket()
118}
119
120func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
121	for {
122		resLen, err := rows.nextResultSet()
123		if err != nil {
124			return 0, err
125		}
126
127		if resLen > 0 {
128			return resLen, nil
129		}
130
131		rows.rs.done = true
132	}
133}
134
135func (rows *binaryRows) NextResultSet() (err error) {
136	resLen, err := rows.nextNotEmptyResultSet()
137	if err != nil {
138		return err
139	}
140
141	// get columns, if not cached, read them and cache them.
142	if rows.i >= len(*rows.stmtCols) {
143		rows.rs.columns, err = rows.mc.readColumns(resLen)
144		*rows.stmtCols = append(*rows.stmtCols, rows.rs.columns)
145	} else {
146		rows.rs.columns = (*rows.stmtCols)[rows.i]
147		if err := rows.mc.readUntilEOF(); err != nil {
148			return err
149		}
150	}
151
152	rows.i++
153	return nil
154}
155
156func (rows *binaryRows) Next(dest []driver.Value) error {
157	if mc := rows.mc; mc != nil {
158		if mc.netConn == nil {
159			return ErrInvalidConn
160		}
161
162		// Fetch next row from stream
163		return rows.readRow(dest)
164	}
165	return io.EOF
166}
167
168func (rows *textRows) NextResultSet() (err error) {
169	resLen, err := rows.nextNotEmptyResultSet()
170	if err != nil {
171		return err
172	}
173
174	rows.rs.columns, err = rows.mc.readColumns(resLen)
175	return err
176}
177
178func (rows *textRows) Next(dest []driver.Value) error {
179	if mc := rows.mc; mc != nil {
180		if mc.netConn == nil {
181			return ErrInvalidConn
182		}
183
184		// Fetch next row from stream
185		return rows.readRow(dest)
186	}
187	return io.EOF
188}
189