1 
2 #include "pyodbc.h"
3 #include "wrapper.h"
4 #include "textenc.h"
5 #include "connection.h"
6 #include "errors.h"
7 #include "pyodbcmodule.h"
8 
9 // Exceptions
10 
11 struct SqlStateMapping
12 {
13     char* prefix;
14     size_t prefix_len;
15     PyObject** pexc_class;      // Note: Double indirection (pexc_class) necessary because the pointer values are not
16                                 // initialized during startup
17 };
18 
19 static const struct SqlStateMapping sql_state_mapping[] =
20 {
21     { "01002", 5, &OperationalError },
22     { "08001", 5, &OperationalError },
23     { "08003", 5, &OperationalError },
24     { "08004", 5, &OperationalError },
25     { "08007", 5, &OperationalError },
26     { "08S01", 5, &OperationalError },
27     { "0A000", 5, &NotSupportedError },
28     { "28000", 5, &InterfaceError },
29     { "40002", 5, &IntegrityError },
30     { "22",    2, &DataError },
31     { "23",    2, &IntegrityError },
32     { "24",    2, &ProgrammingError },
33     { "25",    2, &ProgrammingError },
34     { "42",    2, &ProgrammingError },
35     { "HY001", 5, &OperationalError },
36     { "HY014", 5, &OperationalError },
37     { "HYT00", 5, &OperationalError },
38     { "HYT01", 5, &OperationalError },
39     { "IM001", 5, &InterfaceError },
40     { "IM002", 5, &InterfaceError },
41     { "IM003", 5, &InterfaceError },
42 };
43 
44 
ExceptionFromSqlState(const char * sqlstate)45 static PyObject* ExceptionFromSqlState(const char* sqlstate)
46 {
47     // Returns the appropriate Python exception class given a SQLSTATE value.
48 
49     if (sqlstate && *sqlstate)
50     {
51         for (size_t i = 0; i < _countof(sql_state_mapping); i++)
52             if (memcmp(sqlstate, sql_state_mapping[i].prefix, sql_state_mapping[i].prefix_len) == 0)
53                 return *sql_state_mapping[i].pexc_class;
54     }
55 
56     return Error;
57 }
58 
59 
RaiseErrorV(const char * sqlstate,PyObject * exc_class,const char * format,...)60 PyObject* RaiseErrorV(const char* sqlstate, PyObject* exc_class, const char* format, ...)
61 {
62     PyObject *pAttrs = 0, *pError = 0;
63 
64     if (!sqlstate || !*sqlstate)
65         sqlstate = "HY000";
66 
67     if (!exc_class)
68         exc_class = ExceptionFromSqlState(sqlstate);
69 
70     // Note: Don't use any native strprintf routines.  With Py_ssize_t, we need "%zd", but VC .NET doesn't support it.
71     // PyString_FromFormatV already takes this into account.
72 
73     va_list marker;
74     va_start(marker, format);
75     PyObject* pMsg = PyString_FromFormatV(format, marker);
76     va_end(marker);
77     if (!pMsg)
78     {
79         PyErr_NoMemory();
80         return 0;
81     }
82 
83     // Create an exception with a 'sqlstate' attribute (set to None if we don't have one) whose 'args' attribute is a
84     // tuple containing the message and sqlstate value.  The 'sqlstate' attribute ensures it is easy to access in
85     // Python (and more understandable to the reader than ex.args[1]), but putting it in the args ensures it shows up
86     // in logs because of the default repr/str.
87 
88     pAttrs = Py_BuildValue("(Os)", pMsg, sqlstate);
89     if (pAttrs)
90     {
91         pError = PyEval_CallObject(exc_class, pAttrs);
92         if (pError)
93             RaiseErrorFromException(pError);
94     }
95 
96     Py_DECREF(pMsg);
97     Py_XDECREF(pAttrs);
98     Py_XDECREF(pError);
99 
100     return 0;
101 }
102 
103 
// PyString_CompareWithASCIIString(obj, cstr) yields 0 when the Python string `obj` equals the C string `cstr`.
// Python 3 provides this directly as PyUnicode_CompareWithASCIIString.  The Python 2 fallback runs _strcmpi on the
// raw bytes, so on Python 2 the comparison ignores case.
#if PY_MAJOR_VERSION >= 3
#define PyString_CompareWithASCIIString PyUnicode_CompareWithASCIIString
#else
#define PyString_CompareWithASCIIString(lhs, rhs) _strcmpi(PyString_AS_STRING(lhs), (rhs))
#endif
109 
110 
HasSqlState(PyObject * ex,const char * szSqlState)111 bool HasSqlState(PyObject* ex, const char* szSqlState)
112 {
113     // Returns true if `ex` is an exception and has the given SQLSTATE.  It is safe to pass 0 for ex.
114 
115     bool has = false;
116 
117     if (ex)
118     {
119         PyObject* args = PyObject_GetAttrString(ex, "args");
120         if (args != 0)
121         {
122             PyObject* s = PySequence_GetItem(args, 1);
123             if (s != 0 && PyString_Check(s))
124             {
125                 // const char* sz = PyString_AsString(s);
126                 // if (sz && _strcmpi(sz, szSqlState) == 0)
127                 //     has = true;
128                 has = (PyString_CompareWithASCIIString(s, szSqlState) == 0);
129             }
130             Py_XDECREF(s);
131             Py_DECREF(args);
132         }
133     }
134 
135     return has;
136 }
137 
138 
GetError(const char * sqlstate,PyObject * exc_class,PyObject * pMsg)139 static PyObject* GetError(const char* sqlstate, PyObject* exc_class, PyObject* pMsg)
140 {
141     // pMsg
142     //   The error message.  This function takes ownership of this object, so we'll free it if we fail to create an
143     //   error.
144 
145     PyObject *pSqlState=0, *pAttrs=0, *pError=0;
146 
147     if (!sqlstate || !*sqlstate)
148         sqlstate = "HY000";
149 
150     if (!exc_class)
151         exc_class = ExceptionFromSqlState(sqlstate);
152 
153     pAttrs = PyTuple_New(2);
154     if (!pAttrs)
155     {
156         Py_DECREF(pMsg);
157         return 0;
158     }
159 
160     PyTuple_SetItem(pAttrs, 1, pMsg); // pAttrs now owns the pMsg reference; steals a reference, does not increment
161 
162     pSqlState = PyString_FromString(sqlstate);
163     if (!pSqlState)
164     {
165         Py_DECREF(pAttrs);
166         return 0;
167     }
168 
169     PyTuple_SetItem(pAttrs, 0, pSqlState); // pAttrs now owns the pSqlState reference
170 
171     pError = PyEval_CallObject(exc_class, pAttrs); // pError will incref pAttrs
172 
173     Py_XDECREF(pAttrs);
174 
175     return pError;
176 }
177 
178 
// Message used when the driver signals a failure but records no diagnostic text.
static const char DEFAULT_ERROR[] = "The driver did not supply an error!";
180 
RaiseErrorFromHandle(Connection * conn,const char * szFunction,HDBC hdbc,HSTMT hstmt)181 PyObject* RaiseErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc, HSTMT hstmt)
182 {
183     // The exception is "set" in the interpreter.  This function returns 0 so this can be used in a return statement.
184 
185     PyObject* pError = GetErrorFromHandle(conn, szFunction, hdbc, hstmt);
186 
187     if (pError)
188     {
189         RaiseErrorFromException(pError);
190         Py_DECREF(pError);
191     }
192 
193     return 0;
194 }
195 
196 
GetErrorFromHandle(Connection * conn,const char * szFunction,HDBC hdbc,HSTMT hstmt)197 PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc, HSTMT hstmt)
198 {
199     TRACE("In RaiseError(%s)!\n", szFunction);
200 
201     // Creates and returns an exception from ODBC error information.
202     //
203     // ODBC can generate a chain of errors which we concatenate into one error message.  We use the SQLSTATE from the
204     // first message, which seems to be the most detailed, to determine the class of exception.
205     //
206     // If the function fails, for example, if it runs out of memory, zero is returned.
207     //
208     // szFunction
209     //   The name of the function that failed.  Python generates a useful stack trace, but we often don't know where in
210     //   the C++ code we failed.
211 
212     SQLSMALLINT nHandleType;
213     SQLHANDLE   h;
214 
215     char sqlstate[6] = "";
216     SQLINTEGER nNativeError;
217     SQLSMALLINT cchMsg;
218 
219     ODBCCHAR sqlstateT[6];
220     SQLSMALLINT msgLen = 1023;
221     ODBCCHAR *szMsg = (ODBCCHAR*) pyodbc_malloc((msgLen + 1) * sizeof(ODBCCHAR));
222 
223     if (!szMsg) {
224         PyErr_NoMemory();
225         return 0;
226     }
227 
228     if (hstmt != SQL_NULL_HANDLE)
229     {
230         nHandleType = SQL_HANDLE_STMT;
231         h = hstmt;
232     }
233     else if (hdbc != SQL_NULL_HANDLE)
234     {
235         nHandleType = SQL_HANDLE_DBC;
236         h = hdbc;
237     }
238     else
239     {
240         nHandleType = SQL_HANDLE_ENV;
241         h = henv;
242     }
243 
244     // unixODBC + PostgreSQL driver 07.01.0003 (Fedora 8 binaries from RPMs) crash if you call SQLGetDiagRec more
245     // than once.  I hate to do this, but I'm going to only call it once for non-Windows platforms for now...
246 
247     SQLSMALLINT iRecord = 1;
248 
249     Object msg;
250 
251     for (;;)
252     {
253         szMsg[0]     = 0;
254         sqlstateT[0] = 0;
255         nNativeError = 0;
256         cchMsg       = 0;
257 
258         SQLRETURN ret;
259         Py_BEGIN_ALLOW_THREADS
260         ret = SQLGetDiagRecW(nHandleType, h, iRecord, (SQLWCHAR*)sqlstateT, &nNativeError, (SQLWCHAR*)szMsg, msgLen, &cchMsg);
261         Py_END_ALLOW_THREADS
262         if (!SQL_SUCCEEDED(ret))
263             break;
264 
265         // If needed, allocate a bigger error message buffer and retry.
266         if (cchMsg > msgLen - 1) {
267             msgLen = cchMsg + 1;
268             if (!pyodbc_realloc((BYTE**) &szMsg, (msgLen + 1) * sizeof(ODBCCHAR))) {
269                 PyErr_NoMemory();
270                 pyodbc_free(szMsg);
271                 return 0;
272             }
273             Py_BEGIN_ALLOW_THREADS
274             ret = SQLGetDiagRecW(nHandleType, h, iRecord, (SQLWCHAR*)sqlstateT, &nNativeError, (SQLWCHAR*)szMsg, msgLen, &cchMsg);
275             Py_END_ALLOW_THREADS
276             if (!SQL_SUCCEEDED(ret))
277                 break;
278         }
279 
280         // Not always NULL terminated (MS Access)
281         sqlstateT[5] = 0;
282 
283         // For now, default to UTF-16 if this is not in the context of a connection.
284         // Note that this will not work if the DM is using a different wide encoding (e.g. UTF-32).
285         const char *unicode_enc = conn ? conn->metadata_enc.name : ENCSTR_UTF16NE;
286         Object msgStr(PyUnicode_Decode((char*)szMsg, cchMsg * sizeof(ODBCCHAR), unicode_enc, "strict"));
287 
288         if (cchMsg != 0 && msgStr.Get())
289         {
290             if (iRecord == 1)
291             {
292                 // This is the first error message, so save the SQLSTATE for determining the
293                 // exception class and append the calling function name.
294                 CopySqlState(sqlstateT, sqlstate);
295                 msg = PyUnicode_FromFormat("[%s] %V (%ld) (%s)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError, szFunction);
296                 if (!msg) {
297                     PyErr_NoMemory();
298                     pyodbc_free(szMsg);
299                     return 0;
300                 }
301             }
302             else
303             {
304                 // This is not the first error message, so append to the existing one.
305                 Object more(PyUnicode_FromFormat("; [%s] %V (%ld)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError));
306                 if (!more)
307                     break;  // Something went wrong, but we'll return the msg we have so far
308 
309                 Object both(PyUnicode_Concat(msg, more));
310                 if (!both)
311                     break;
312 
313                 msg = both.Detach();
314             }
315         }
316 
317         iRecord++;
318 
319 #ifndef _MSC_VER
320         // See non-Windows comment above
321         break;
322 #endif
323     }
324 
325     // Raw message buffer not needed anymore
326     pyodbc_free(szMsg);
327 
328     if (!msg || PyUnicode_GetSize(msg.Get()) == 0)
329     {
330         // This only happens using unixODBC.  (Haven't tried iODBC yet.)  Either the driver or the driver manager is
331         // buggy and has signaled a fault without recording error information.
332         sqlstate[0] = '\0';
333         msg = PyString_FromString(DEFAULT_ERROR);
334         if (!msg)
335         {
336             PyErr_NoMemory();
337             return 0;
338         }
339     }
340 
341     return GetError(sqlstate, 0, msg.Detach());
342 }
343 
344 
GetSqlState(HSTMT hstmt,char * szSqlState)345 static bool GetSqlState(HSTMT hstmt, char* szSqlState)
346 {
347     SQLSMALLINT cchMsg;
348     SQLRETURN ret;
349 
350     Py_BEGIN_ALLOW_THREADS
351     ret = SQLGetDiagField(SQL_HANDLE_STMT, hstmt, 1, SQL_DIAG_SQLSTATE, (SQLCHAR*)szSqlState, 5, &cchMsg);
352     Py_END_ALLOW_THREADS
353     return SQL_SUCCEEDED(ret);
354 }
355 
356 
HasSqlState(HSTMT hstmt,const char * szSqlState)357 bool HasSqlState(HSTMT hstmt, const char* szSqlState)
358 {
359     char szActual[6];
360     if (!GetSqlState(hstmt, szActual))
361         return false;
362     return memcmp(szActual, szSqlState, 5) == 0;
363 }
364