1 //
2 // Copyright (C) 2004-2006 Maciej Sobczak, Stephen Hutton, David Courtney
3 // Distributed under the Boost Software License, Version 1.0.
4 // (See accompanying file LICENSE_1_0.txt or copy at
5 // http://www.boost.org/LICENSE_1_0.txt)
6 //
7 
8 #define SOCI_ODBC_SOURCE
9 #include "soci/odbc/soci-odbc.h"
10 #include <cctype>
11 #include <sstream>
12 #include <cstring>
13 
14 using namespace soci;
15 using namespace soci::details;
16 
17 
odbc_statement_backend(odbc_session_backend & session)18 odbc_statement_backend::odbc_statement_backend(odbc_session_backend &session)
19     : session_(session), hstmt_(0), numRowsFetched_(0),
20       hasVectorUseElements_(false), boundByName_(false), boundByPos_(false),
21       rowsAffected_(-1LL)
22 {
23 }
24 
alloc()25 void odbc_statement_backend::alloc()
26 {
27     SQLRETURN rc;
28 
29     // Allocate environment handle
30     rc = SQLAllocHandle(SQL_HANDLE_STMT, session_.hdbc_, &hstmt_);
31     if (is_odbc_error(rc))
32     {
33         throw odbc_soci_error(SQL_HANDLE_DBC, session_.hdbc_,
34                               "allocating statement");
35     }
36 }
37 
clean_up()38 void odbc_statement_backend::clean_up()
39 {
40     rowsAffected_ = -1LL;
41 
42     SQLFreeHandle(SQL_HANDLE_STMT, hstmt_);
43 }
44 
45 
prepare(std::string const & query,statement_type)46 void odbc_statement_backend::prepare(std::string const & query,
47     statement_type /* eType */)
48 {
49     // rewrite the query by transforming all named parameters into
50     // the ODBC numbers ones (:abc -> $1, etc.)
51 
52     enum { eNormal, eInQuotes, eInName, eInAccessDate } state = eNormal;
53 
54     std::string name;
55     query_.reserve(query.length());
56 
57     for (std::string::const_iterator it = query.begin(), end = query.end();
58          it != end; ++it)
59     {
60         switch (state)
61         {
62         case eNormal:
63             if (*it == '\'')
64             {
65                 query_ += *it;
66                 state = eInQuotes;
67             }
68             else if (*it == '#')
69             {
70                 query_ += *it;
71                 state = eInAccessDate;
72             }
73             else if (*it == ':')
74             {
75                 state = eInName;
76             }
77             else // regular character, stay in the same state
78             {
79                 query_ += *it;
80             }
81             break;
82         case eInQuotes:
83             if (*it == '\'')
84             {
85                 query_ += *it;
86                 state = eNormal;
87             }
88             else // regular quoted character
89             {
90                 query_ += *it;
91             }
92             break;
93         case eInName:
94             if (std::isalnum(*it) || *it == '_')
95             {
96                 name += *it;
97             }
98             else // end of name
99             {
100                 names_.push_back(name);
101                 name.clear();
102                 query_ += "?";
103                 query_ += *it;
104                 state = eNormal;
105             }
106             break;
107         case eInAccessDate:
108             if (*it == '#')
109             {
110                 query_ += *it;
111                 state = eNormal;
112             }
113             else // regular quoted character
114             {
115                 query_ += *it;
116             }
117             break;
118         }
119     }
120 
121     if (state == eInName)
122     {
123         names_.push_back(name);
124         query_ += "?";
125     }
126 
127     SQLRETURN rc = SQLPrepare(hstmt_, sqlchar_cast(query_), (SQLINTEGER)query_.size());
128     if (is_odbc_error(rc))
129     {
130         std::ostringstream ss;
131         ss << "preparing query \"" << query_ << "\"";
132         throw odbc_soci_error(SQL_HANDLE_STMT, hstmt_, ss.str());
133     }
134 }
135 
136 statement_backend::exec_fetch_result
execute(int number)137 odbc_statement_backend::execute(int number)
138 {
139     // Store the number of rows processed by this call.
140     SQLULEN rows_processed = 0;
141     if (hasVectorUseElements_)
142     {
143         SQLSetStmtAttr(hstmt_, SQL_ATTR_PARAMS_PROCESSED_PTR, &rows_processed, 0);
144     }
145 
146     // if we are called twice for the same statement we need to close the open
147     // cursor or an "invalid cursor state" error will occur on execute
148     SQLCloseCursor(hstmt_);
149 
150     SQLRETURN rc = SQLExecute(hstmt_);
151     if (is_odbc_error(rc))
152     {
153         // Construct the error object immediately, before calling any other
154         // ODBC functions, in order to not lose the error message.
155         const odbc_soci_error err(SQL_HANDLE_STMT, hstmt_, "executing statement");
156 
157         // There is no universal way to determine the number of affected rows
158         // after a failed update.
159         rowsAffected_ = -1LL;
160 
161         // If executing bulk operation a partial
162         // number of rows affected may be available.
163         if (hasVectorUseElements_)
164         {
165             do
166             {
167                 SQLLEN res = 0;
168                 // SQLRowCount will return error after a partially executed statement.
169                 // SQL_DIAG_ROW_COUNT returns the same info but must be collected immediatelly after the execution.
170                 rc = SQLGetDiagField(SQL_HANDLE_STMT, hstmt_, 0, SQL_DIAG_ROW_COUNT, &res, 0, NULL);
171                 if (!is_odbc_error(rc) && res != -1)
172                 {
173                   if (rowsAffected_ == -1LL)
174                     rowsAffected_ = res;
175                   else
176                     rowsAffected_ += res;
177                 }
178                 --rows_processed; // Avoid unnecessary calls to SQLGetDiagField
179             }
180             // Move forward to the next result while there are rows processed.
181             while (rows_processed > 0 && SQLMoreResults(hstmt_) == SQL_SUCCESS);
182         }
183         throw err;
184     }
185     else if (hasVectorUseElements_)
186     {
187         // We already have the number of rows, no need to do anything.
188         rowsAffected_ = rows_processed;
189     }
190     else // We need to retrieve the number of rows affected explicitly.
191     {
192         SQLLEN res = 0;
193         rc = SQLRowCount(hstmt_, &res);
194         if (is_odbc_error(rc))
195         {
196             throw odbc_soci_error(SQL_HANDLE_STMT, hstmt_,
197                                   "getting number of affected rows");
198         }
199 
200         rowsAffected_ = res;
201     }
202     SQLSMALLINT colCount;
203     SQLNumResultCols(hstmt_, &colCount);
204 
205     if (number > 0 && colCount > 0)
206     {
207         return fetch(number);
208     }
209 
210     return ef_success;
211 }
212 
213 statement_backend::exec_fetch_result
fetch(int number)214 odbc_statement_backend::fetch(int number)
215 {
216     numRowsFetched_ = 0;
217     SQLULEN const row_array_size = static_cast<SQLULEN>(number);
218 
219     SQLSetStmtAttr(hstmt_, SQL_ATTR_ROW_BIND_TYPE, SQL_BIND_BY_COLUMN, 0);
220     SQLSetStmtAttr(hstmt_, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)row_array_size, 0);
221     SQLSetStmtAttr(hstmt_, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched_, 0);
222 
223     SQLRETURN rc = SQLFetch(hstmt_);
224 
225     if (SQL_NO_DATA == rc)
226     {
227         return ef_no_data;
228     }
229 
230     if (is_odbc_error(rc))
231     {
232         throw odbc_soci_error(SQL_HANDLE_STMT, hstmt_, "fetching data");
233     }
234 
235     return ef_success;
236 }
237 
get_affected_rows()238 long long odbc_statement_backend::get_affected_rows()
239 {
240     return rowsAffected_;
241 }
242 
get_number_of_rows()243 int odbc_statement_backend::get_number_of_rows()
244 {
245     return static_cast<int>(numRowsFetched_);
246 }
247 
get_parameter_name(int index) const248 std::string odbc_statement_backend::get_parameter_name(int index) const
249 {
250     return names_.at(index);
251 }
252 
rewrite_for_procedure_call(std::string const & query)253 std::string odbc_statement_backend::rewrite_for_procedure_call(
254     std::string const &query)
255 {
256     return query;
257 }
258 
prepare_for_describe()259 int odbc_statement_backend::prepare_for_describe()
260 {
261     SQLSMALLINT numCols;
262     SQLNumResultCols(hstmt_, &numCols);
263     return numCols;
264 }
265 
describe_column(int colNum,data_type & type,std::string & columnName)266 void odbc_statement_backend::describe_column(int colNum, data_type & type,
267                                           std::string & columnName)
268 {
269     SQLCHAR colNameBuffer[2048];
270     SQLSMALLINT colNameBufferOverflow;
271     SQLSMALLINT dataType;
272     SQLULEN colSize;
273     SQLSMALLINT decDigits;
274     SQLSMALLINT isNullable;
275 
276     SQLRETURN rc = SQLDescribeCol(hstmt_, static_cast<SQLUSMALLINT>(colNum),
277                                   colNameBuffer, 2048,
278                                   &colNameBufferOverflow, &dataType,
279                                   &colSize, &decDigits, &isNullable);
280 
281     if (is_odbc_error(rc))
282     {
283         std::ostringstream ss;
284         ss << "getting description of column at position " << colNum;
285         throw odbc_soci_error(SQL_HANDLE_STMT, hstmt_, ss.str());
286     }
287 
288     char const *name = reinterpret_cast<char const *>(colNameBuffer);
289     columnName.assign(name, std::strlen(name));
290 
291     switch (dataType)
292     {
293     case SQL_TYPE_DATE:
294     case SQL_TYPE_TIME:
295     case SQL_TYPE_TIMESTAMP:
296         type = dt_date;
297         break;
298     case SQL_DOUBLE:
299     case SQL_DECIMAL:
300     case SQL_REAL:
301     case SQL_FLOAT:
302     case SQL_NUMERIC:
303         type = dt_double;
304         break;
305     case SQL_TINYINT:
306     case SQL_SMALLINT:
307     case SQL_INTEGER:
308         type = dt_integer;
309         break;
310     case SQL_BIGINT:
311         type = dt_long_long;
312         break;
313     case SQL_CHAR:
314     case SQL_VARCHAR:
315     case SQL_LONGVARCHAR:
316     default:
317         type = dt_string;
318         break;
319     }
320 }
321 
column_size(int colNum)322 std::size_t odbc_statement_backend::column_size(int colNum)
323 {
324     SQLCHAR colNameBuffer[2048];
325     SQLSMALLINT colNameBufferOverflow;
326     SQLSMALLINT dataType;
327     SQLULEN colSize;
328     SQLSMALLINT decDigits;
329     SQLSMALLINT isNullable;
330 
331     SQLRETURN rc = SQLDescribeCol(hstmt_, static_cast<SQLUSMALLINT>(colNum),
332                                   colNameBuffer, 2048,
333                                   &colNameBufferOverflow, &dataType,
334                                   &colSize, &decDigits, &isNullable);
335 
336     if (is_odbc_error(rc))
337     {
338         std::ostringstream ss;
339         ss << "getting size of column at position " << colNum;
340         throw odbc_soci_error(SQL_HANDLE_STMT, hstmt_, ss.str());
341     }
342 
343     return colSize;
344 }
345 
make_into_type_backend()346 odbc_standard_into_type_backend * odbc_statement_backend::make_into_type_backend()
347 {
348     return new odbc_standard_into_type_backend(*this);
349 }
350 
make_use_type_backend()351 odbc_standard_use_type_backend * odbc_statement_backend::make_use_type_backend()
352 {
353     return new odbc_standard_use_type_backend(*this);
354 }
355 
356 odbc_vector_into_type_backend *
make_vector_into_type_backend()357 odbc_statement_backend::make_vector_into_type_backend()
358 {
359     return new odbc_vector_into_type_backend(*this);
360 }
361 
make_vector_use_type_backend()362 odbc_vector_use_type_backend * odbc_statement_backend::make_vector_use_type_backend()
363 {
364     hasVectorUseElements_ = true;
365     return new odbc_vector_use_type_backend(*this);
366 }
367