1 /* Copyright (c) 2011, 2019, Oracle and/or its affiliates. All rights reserved.
2 
3    This program is free software; you can redistribute it and/or modify
4    it under the terms of the GNU General Public License, version 2.0,
5    as published by the Free Software Foundation.
6 
7    This program is also distributed with certain software (including
8    but not limited to OpenSSL) that is licensed under separate terms,
9    as designated in a particular file or component or in included license
10    documentation.  The authors of MySQL hereby grant you an additional
11    permission to link the program and your derivative works with the
12    separately licensed software that they have included with MySQL.
13 
14    This program is distributed in the hope that it will be useful,
15    but WITHOUT ANY WARRANTY; without even the implied warranty of
16    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17    GNU General Public License, version 2.0, for more details.
18 
19    You should have received a copy of the GNU General Public License
20    along with this program; if not, write to the Free Software
21    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA */
22 
23 #include "common.h"
24 #include <sddl.h>   // for ConvertSidToStringSid()
25 #include <secext.h> // for GetUserNameEx()
26 
27 
28 template <> void error_log_print<error_log_level::INFO>(const char *fmt, ...);
29 template <> void error_log_print<error_log_level::WARNING>(const char *fmt, ...);
30 template <> void error_log_print<error_log_level::ERROR>(const char *fmt, ...);
31 
/**
  Desired verbosity of the plugin's logging. Possible values:

    0 - no logging
    1 - log only error messages
    2 - additionally log warnings (this is the default)
    3 - additionally log info notes
    4 - also log debug messages

  The implementation of error_log_vprint() (see log_client.cc) should
  take this value into account.

  Note: No error or debug messages are logged in production code
  (see logging macros in common.h).
*/
int opt_auth_win_log_level = 2;
49 
50 
51 /** Connection class **************************************************/
52 
53 /**
54   Create connection out of an active MYSQL_PLUGIN_VIO object.
55 
56   @param[in] vio  pointer to a @c MYSQL_PLUGIN_VIO object used for
57                   connection - it can not be NULL
58 */
59 
Connection(MYSQL_PLUGIN_VIO * vio)60 Connection::Connection(MYSQL_PLUGIN_VIO *vio): m_vio(vio), m_error(0)
61 {
62   DBUG_ASSERT(vio);
63 }
64 
65 
66 /**
67   Write data to the connection.
68 
69   @param[in]  blob  data to be written
70 
71   @return 0 on success, VIO error code on failure.
72 
73   @note In case of error, VIO error code is stored in the connection object
74   and can be obtained with @c error() method.
75 */
76 
write(const Blob & blob)77 int Connection::write(const Blob &blob)
78 {
79   m_error= m_vio->write_packet(m_vio, blob.ptr(), static_cast<int>(blob.len()));
80 
81 #ifndef DBUG_OFF
82   if (m_error)
83     DBUG_PRINT("error", ("vio write error %d", m_error));
84 #endif
85 
86   return m_error;
87 }
88 
89 
90 /**
91   Read data from connection.
92 
93   @return A Blob containing read packet or null Blob in case of error.
94 
95   @note In case of error, VIO error code is stored in the connection object
96   and can be obtained with @c error() method.
97 */
98 
read()99 Blob Connection::read()
100 {
101   unsigned char *ptr;
102   int len= m_vio->read_packet(m_vio, &ptr);
103 
104   if (len < 0)
105   {
106     m_error= true;
107     return Blob();
108   }
109 
110   return Blob(ptr, len);
111 }
112 
113 
114 /** Sid class *****************************************************/
115 
116 
117 /**
118   Create Sid object corresponding to a given account name.
119 
120   @param[in]  account_name  name of a Windows account
121 
122   The account name can be in any form accepted by @c LookupAccountName()
123   function.
124 
125   @note In case of errors created object is invalid and its @c is_valid()
126   method returns @c false.
127 */
128 
Sid(const wchar_t * account_name)129 Sid::Sid(const wchar_t *account_name): m_data(NULL)
130 #ifndef DBUG_OFF
131 , m_as_string(NULL)
132 #endif
133 {
134   DWORD sid_size= 0, domain_size= 0;
135   bool success;
136 
137   // Determine required buffer sizes
138 
139   success= LookupAccountNameW(NULL, account_name, NULL, &sid_size,
140                              NULL, &domain_size, &m_type);
141 
142   if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
143   {
144 #ifndef DBUG_OFF
145     Error_message_buf error_buf;
146     DBUG_PRINT("error", ("Could not determine SID buffer size, "
147                          "LookupAccountName() failed with error %X (%s)",
148                          GetLastError(), get_last_error_message(error_buf)));
149 #endif
150     return;
151   }
152 
153   // Query for SID (domain is ignored)
154 
155   wchar_t *domain= new wchar_t[domain_size];
156   m_data= (TOKEN_USER*) new BYTE[sid_size + sizeof(TOKEN_USER)];
157   m_data->User.Sid= (BYTE*)m_data + sizeof(TOKEN_USER);
158 
159   success= LookupAccountNameW(NULL, account_name,
160                              m_data->User.Sid, &sid_size,
161                              domain, &domain_size,
162                              &m_type);
163 
164   if (!success || !is_valid())
165   {
166 #ifndef DBUG_OFF
167     Error_message_buf error_buf;
168     DBUG_PRINT("error", ("Could not determine SID of '%S', "
169                          "LookupAccountName() failed with error %X (%s)",
170                          account_name, GetLastError(),
171                          get_last_error_message(error_buf)));
172 #endif
173     goto fail;
174   }
175 
176   goto end;
177 
178 fail:
179   if (m_data)
180     delete [] m_data;
181   m_data= NULL;
182 
183 end:
184   if (domain)
185     delete [] domain;
186 }
187 
188 
189 /**
190   Create Sid object corresponding to a given security token.
191 
192   @param[in]  token   security token of a Windows account
193 
194   @note In case of errors created object is invalid and its @c is_valid()
195   method returns @c false.
196 */
197 
Sid(HANDLE token)198 Sid::Sid(HANDLE token): m_data(NULL)
199 #ifndef DBUG_OFF
200 , m_as_string(NULL)
201 #endif
202 {
203   DWORD             req_size= 0;
204   bool              success;
205 
206   // Determine required buffer size
207 
208   success= GetTokenInformation(token, TokenUser, NULL, 0, &req_size);
209   if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
210   {
211 #ifndef DBUG_OFF
212     Error_message_buf error_buf;
213     DBUG_PRINT("error", ("Could not determine SID buffer size, "
214                          "GetTokenInformation() failed with error %X (%s)",
215                          GetLastError(), get_last_error_message(error_buf)));
216 #endif
217     return;
218   }
219 
220   m_data= (TOKEN_USER*) new BYTE[req_size];
221   success= GetTokenInformation(token, TokenUser, m_data, req_size, &req_size);
222 
223   if (!success || !is_valid())
224   {
225     delete [] m_data;
226     m_data= NULL;
227 #ifndef DBUG_OFF
228     if (!success)
229     {
230       Error_message_buf error_buf;
231       DBUG_PRINT("error", ("Could not read SID from security token, "
232                            "GetTokenInformation() failed with error %X (%s)",
233                            GetLastError(), get_last_error_message(error_buf)));
234     }
235 #endif
236   }
237 }
238 
239 
~Sid()240 Sid::~Sid()
241 {
242   if (m_data)
243     delete [] m_data;
244 #ifndef DBUG_OFF
245   if (m_as_string)
246     LocalFree(m_as_string);
247 #endif
248 }
249 
250 /// Check if Sid object is valid.
is_valid(void) const251 bool Sid::is_valid(void) const
252 {
253   return m_data && m_data->User.Sid && IsValidSid(m_data->User.Sid);
254 }
255 
256 
257 #ifndef DBUG_OFF
258 
259 /**
260   Produces string representation of the SID.
261 
262   @return String representation of the SID or NULL in case of errors.
263 
264   @note Memory allocated for the string is automatically freed in Sid's
265   destructor.
266 */
267 
as_string()268 const char* Sid::as_string()
269 {
270   if (!m_data)
271     return NULL;
272 
273   if (!m_as_string)
274   {
275     bool success= ConvertSidToStringSid(m_data->User.Sid, &m_as_string);
276 
277     if (!success)
278     {
279 #ifndef DBUG_OFF
280       Error_message_buf error_buf;
281       DBUG_PRINT("error", ("Could not get textual representation of a SID, "
282                            "ConvertSidToStringSid() failed with error %X (%s)",
283                            GetLastError(), get_last_error_message(error_buf)));
284 #endif
285       m_as_string= NULL;
286       return NULL;
287     }
288   }
289 
290   return m_as_string;
291 }
292 
293 #endif
294 
295 
operator ==(const Sid & other)296 bool Sid::operator ==(const Sid &other)
297 {
298   if (!is_valid() || !other.is_valid())
299     return false;
300 
301   return EqualSid(m_data->User.Sid, other.m_data->User.Sid);
302 }
303 
304 
305 /** Generating User Principal Name *************************/
306 
307 /**
308   Call Windows API functions to get UPN of the current user and store it
309   in internal buffer.
310 */
311 
UPN()312 UPN::UPN(): m_buf(NULL)
313 {
314   wchar_t  buf1[MAX_SERVICE_NAME_LENGTH];
315 
316   // First we try to use GetUserNameEx.
317 
318   m_len= sizeof(buf1)/sizeof(wchar_t);
319 
320   if (!GetUserNameExW(NameUserPrincipal, buf1, (PULONG)&m_len))
321   {
322     if (GetLastError())
323     {
324 #ifndef DBUG_OFF
325       Error_message_buf error_buf;
326       DBUG_PRINT("note", ("When determining UPN"
327                           ", GetUserNameEx() failed with error %X (%s)",
328                           GetLastError(), get_last_error_message(error_buf)));
329 #endif
330       if (ERROR_MORE_DATA == GetLastError())
331         ERROR_LOG(INFO, ("Buffer overrun when determining UPN:"
332                          " need %ul characters but have %ul",
333                          m_len, sizeof(buf1)/sizeof(WCHAR)));
334     }
335 
336     m_len= 0;   // m_len == 0 indicates invalid UPN
337     return;
338   }
339 
340   /*
341     UPN is stored in buf1 in wide-char format - convert it to utf8
342     for sending over network.
343   */
344 
345   m_buf= wchar_to_utf8(buf1, &m_len);
346 
347   if(!m_buf)
348     ERROR_LOG(ERROR, ("Failed to convert UPN to utf8"));
349 
350   // Note: possible error would be indicated by the fact that m_buf is NULL.
351   return;
352 }
353 
354 
~UPN()355 UPN::~UPN()
356 {
357   if (m_buf)
358     free(m_buf);
359 }
360 
361 
362 /**
363   Convert a wide-char string to utf8 representation.
364 
365   @param[in]     string   null-terminated wide-char string to be converted
366   @param[in,out] len      length of the string to be converted or 0; on
367                           return length (in bytes, excluding terminating
368                           null character) of the converted string
369 
370   If len is 0 then the length of the string will be computed by this function.
371 
372   @return Pointer to a buffer containing utf8 representation or NULL in
373           case of error.
374 
375   @note The returned buffer must be freed with @c free() call.
376 */
377 
wchar_to_utf8(const wchar_t * string,size_t * len)378 char* wchar_to_utf8(const wchar_t *string, size_t *len)
379 {
380   char   *buf= NULL;
381   size_t  str_len= len && *len ? *len : wcslen(string);
382 
383   /*
384     A conversion from utf8 to wchar_t will never take more than 3 bytes per
385     character, so a buffer of length 3 * str_len schould be sufficient.
386     We check that assumption with an assertion later.
387   */
388 
389   size_t  buf_len= 3 * str_len;
390 
391   buf= (char*)malloc(buf_len + 1);
392   if (!buf)
393   {
394     DBUG_PRINT("error",("Out of memory when converting string '%S' to utf8",
395                         string));
396     return NULL;
397   }
398 
399   int res= WideCharToMultiByte(CP_UTF8,              // convert to UTF-8
400                                0,                    // conversion flags
401                                string,               // input buffer
402                                str_len,              // its length
403                                buf, buf_len,         // output buffer and its size
404                                NULL, NULL);          // default character (not used)
405 
406   if (res)
407   {
408     buf[res]= '\0';
409     if (len)
410       *len= res;
411     return buf;
412   }
413 
414   // res is 0 which indicates error
415 
416 #ifndef DBUG_OFF
417   Error_message_buf error_buf;
418   DBUG_PRINT("error", ("Could not convert string '%S' to utf8"
419                        ", WideCharToMultiByte() failed with error %X (%s)",
420                        string, GetLastError(),
421                        get_last_error_message(error_buf)));
422 #endif
423 
424   // Let's check our assumption about sufficient buffer size
425   DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError());
426 
427   return NULL;
428 }
429 
430 
431 /**
432   Convert an utf8 string to a wide-char string.
433 
434   @param[in]     string   null-terminated utf8 string to be converted
435   @param[in,out] len      length of the string to be converted or 0; on
436                           return length (in chars) of the converted string
437 
438   If len is 0 then the length of the string will be computed by this function.
439 
440   @return Pointer to a buffer containing wide-char representation or NULL in
441           case of error.
442 
443   @note The returned buffer must be freed with @c free() call.
444 */
445 
utf8_to_wchar(const char * string,size_t * len)446 wchar_t* utf8_to_wchar(const char *string, size_t *len)
447 {
448   size_t buf_len;
449 
450   /*
451     Note: length (in bytes) of an utf8 string is always bigger than the
452     number of characters in this string. Hence a buffer of size len will
453     be sufficient. We add 1 for the terminating null character.
454   */
455 
456   buf_len= len && *len ? *len : strlen(string);
457   wchar_t *buf=  (wchar_t*)malloc((buf_len+1)*sizeof(wchar_t));
458 
459   if (!buf)
460   {
461     DBUG_PRINT("error",("Out of memory when converting utf8 string '%s'"
462                         " to wide-char representation", string));
463     return NULL;
464   }
465 
466   size_t  res;
467   res= MultiByteToWideChar(CP_UTF8,            // convert from UTF-8
468                            0,                  // conversion flags
469                            string,             // input buffer
470                            buf_len,            // its size
471                            buf, buf_len);      // output buffer and its size
472   if (res)
473   {
474     buf[res]= '\0';
475     if (len)
476       *len= res;
477     return buf;
478   }
479 
480   // error in MultiByteToWideChar()
481 
482 #ifndef DBUG_OFF
483   Error_message_buf error_buf;
484   DBUG_PRINT("error", ("Could not convert UPN from UTF-8"
485                        ", MultiByteToWideChar() failed with error %X (%s)",
486                        GetLastError(), get_last_error_message(error_buf)));
487 #endif
488 
489   // Let's check our assumption about sufficient buffer size
490   DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError());
491 
492   return NULL;
493 }
494 
495 
496 /** Error handling ****************************************************/
497 
498 
499 /**
500   Returns error message corresponding to the last Windows error given
501   by GetLastError().
502 
503   @note Error message is overwritten by next call to
504   @c get_last_error_message().
505 */
506 
get_last_error_message(Error_message_buf buf)507 const char* get_last_error_message(Error_message_buf buf)
508 {
509   int error= GetLastError();
510 
511   buf[0]= '\0';
512   FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM,
513 		NULL, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
514 		(LPTSTR)buf, sizeof(Error_message_buf), NULL);
515 
516   return buf;
517 }
518