1 /*
2  * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License, version 2.0,
6  * as published by the Free Software Foundation.
7  *
8  * This program is also distributed with certain software (including
9  * but not limited to OpenSSL) that is licensed under separate terms,
10  * as designated in a particular file or component or in included license
11  * documentation.  The authors of MySQL hereby grant you an additional
12  * permission to link the program and your derivative works with the
13  * separately licensed software that they have included with MySQL.
14  *
15  * This program is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  * GNU General Public License, version 2.0, for more details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with this program; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
23  * 02110-1301  USA
24  */
25 
26 // Avoid warnings from includes of other project and protobuf
27 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
28 #pragma GCC diagnostic push
29 #pragma GCC diagnostic ignored "-Wshadow"
30 #pragma GCC diagnostic ignored "-Wunused-parameter"
31 #elif defined _MSC_VER
32 #pragma warning (push)
33 #pragma warning (disable : 4018 4996)
34 #endif
35 
36 #include "ngs_common/protocol_protobuf.h"
37 #include "mysqlx_protocol.h"
38 #include "mysqlx_resultset.h"
39 #include "mysqlx_row.h"
40 #include "mysqlx_error.h"
41 #include "mysqlx_version.h"
42 
43 #include "my_config.h"
44 #include "ngs_common/bind.h"
45 
46 #ifdef MYSQLXTEST_STANDALONE
47 #include "mysqlx/auth_mysql41.h"
48 #else
49 #include "password_hasher.h"
50 namespace mysqlx {
build_mysql41_authentication_response(const std::string & salt_data,const std::string & user,const std::string & password,const std::string & schema)51   std::string build_mysql41_authentication_response(const std::string &salt_data,
52     const std::string &user,
53     const std::string &password,
54     const std::string &schema)
55   {
56     std::string password_hash;
57     if (password.length())
58       password_hash = Password_hasher::get_password_from_salt(Password_hasher::scramble(salt_data.c_str(), password.c_str()));
59     std::string data;
60     data.append(schema).push_back('\0'); // authz
61     data.append(user).push_back('\0'); // authc
62     data.append(password_hash); // pass
63     return data;
64   }
65 }
66 #endif
67 
68 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
69 #pragma GCC diagnostic pop
70 #elif defined _MSC_VER
71 #pragma warning (pop)
72 #endif
73 
74 #include <iostream>
75 #ifndef WIN32
76 #include <netdb.h>
77 #include <sys/socket.h>
78 #endif // WIN32
79 #ifdef HAVE_SYS_UN_H
80 #include <sys/un.h>
81 #endif // HAVE_SYS_UN_H
82 #include <string>
83 #include <iostream>
84 #include <limits>
85 
86 #ifdef WIN32
87 #  define snprintf _snprintf
88 #  pragma push_macro("ERROR")
89 #  undef ERROR
90 #endif
91 
92 using namespace mysqlx;
93 
parse_mysql_connstring(const std::string & connstring,std::string & protocol,std::string & user,std::string & password,std::string & host,int & port,std::string & sock,std::string & db,int & pwd_found)94 bool mysqlx::parse_mysql_connstring(const std::string &connstring,
95                                     std::string &protocol, std::string &user, std::string &password,
96                                     std::string &host, int &port, std::string &sock,
97                                     std::string &db, int &pwd_found)
98 {
99   // format is [protocol://][user[:pass]]@host[:port][/db] or user[:pass]@::socket[/db], like what cmdline utilities use
100   pwd_found = 0;
101   std::string remaining = connstring;
102 
103   std::string::size_type p;
104   p = remaining.find("://");
105   if (p != std::string::npos)
106   {
107     protocol = connstring.substr(0, p);
108     remaining = remaining.substr(p + 3);
109   }
110 
111   std::string s = remaining;
112   p = remaining.find('/');
113   if (p != std::string::npos)
114   {
115     db = remaining.substr(p + 1);
116     s = remaining.substr(0, p);
117   }
118   p = s.rfind('@');
119   std::string user_part;
120   std::string server_part = (p == std::string::npos) ? s : s.substr(p + 1);
121 
122   if (p == std::string::npos)
123   {
124     // by default, connect using the current OS username
125 #ifdef _WIN32
126     char tmp_buffer[1024];
127     char *tmp = tmp_buffer;
128     DWORD tmp_size = sizeof(tmp_buffer);
129 
130     if (!GetUserNameA(tmp_buffer, &tmp_size))
131     {
132       tmp = NULL;
133     }
134 #else
135     const char *tmp = getenv("USER");
136 #endif
137     user_part = tmp ? tmp : "";
138   }
139   else
140     user_part = s.substr(0, p);
141 
142   if ((p = user_part.find(':')) != std::string::npos)
143   {
144     user = user_part.substr(0, p);
145     password = user_part.substr(p + 1);
146     pwd_found = 1;
147   }
148   else
149     user = user_part;
150 
151   p = server_part.find(':');
152   if (p != std::string::npos)
153   {
154     host = server_part.substr(0, p);
155     server_part = server_part.substr(p + 1);
156     p = server_part.find(':');
157     if (p != std::string::npos)
158       sock = server_part.substr(p + 1);
159     else
160       if (!sscanf(server_part.substr(0, p).c_str(), "%i", &port))
161         return false;
162   }
163   else
164     host = server_part;
165   return true;
166 }
167 
throw_server_error(const Mysqlx::Error & error)168 static void throw_server_error(const Mysqlx::Error &error)
169 {
170   throw Error(error.code(), error.msg());
171 }
172 
173 
XProtocol(const Ssl_config & ssl_config,const std::size_t timeout,const bool dont_wait_for_disconnect,const Internet_protocol ip_mode)174 XProtocol::XProtocol(const Ssl_config &ssl_config,
175                      const std::size_t timeout,
176                      const bool dont_wait_for_disconnect,
177                      const Internet_protocol ip_mode)
178 : m_sync_connection(ssl_config.key, ssl_config.ca, ssl_config.ca_path,
179                     ssl_config.cert, ssl_config.cipher, ssl_config.tls_version, timeout),
180   m_client_id(0),
181   m_trace_packets(false), m_closed(true),
182   m_dont_wait_for_disconnect(dont_wait_for_disconnect),
183   m_ip_mode(ip_mode)
184 {
185   if (getenv("MYSQLX_TRACE_CONNECTION"))
186     m_trace_packets = true;
187 }
188 
~XProtocol()189 XProtocol::~XProtocol()
190 {
191   try
192   {
193     close();
194   }
195   catch (Error &)
196   {
197     // ignore close errors
198   }
199 }
200 
connect(const std::string & uri,const std::string & pass,const bool cap_expired_password)201 void XProtocol::connect(const std::string &uri, const std::string &pass, const bool cap_expired_password)
202 {
203   std::string protocol, host, schema, user, password;
204   std::string sock;
205   int pwd_found = 0;
206   int port = MYSQLX_TCP_PORT;
207 
208   if (!parse_mysql_connstring(uri, protocol, user, password, host, port, sock, schema, pwd_found))
209     throw Error(CR_WRONG_HOST_INFO, "Unable to parse connection string");
210 
211   if (protocol != "mysqlx" && !protocol.empty())
212     throw Error(CR_WRONG_HOST_INFO, "Unsupported protocol "+protocol);
213 
214   if (!pass.empty())
215     password = pass;
216 
217   connect(host, port);
218 
219   if (cap_expired_password)
220     setup_capability("client.pwd_expire_ok", true);
221 
222   authenticate(user, pass.empty() ? password : pass, schema);
223 }
224 
connect(const std::string & host,int port)225 void XProtocol::connect(const std::string &host, int port)
226 {
227   struct addrinfo *res_lst, hints, *t_res;
228   int gai_errno;
229   Error error;
230   char port_buf[NI_MAXSERV];
231 
232   snprintf(port_buf, NI_MAXSERV, "%d", port);
233 
234   memset(&hints, 0, sizeof(hints));
235   hints.ai_socktype= SOCK_STREAM;
236   hints.ai_protocol= IPPROTO_TCP;
237   hints.ai_family= AF_UNSPEC;
238 
239   if (IPv6 == m_ip_mode)
240     hints.ai_family = AF_INET6;
241   else if (IPv4 == m_ip_mode)
242     hints.ai_family = AF_INET;
243 
244   gai_errno= getaddrinfo(host.c_str(), port_buf, &hints, &res_lst);
245   if (gai_errno != 0)
246     throw Error(CR_UNKNOWN_HOST, "No such host is known '" + host + "'");
247 
248   for (t_res= res_lst; t_res; t_res= t_res->ai_next)
249   {
250     error = m_sync_connection.connect((sockaddr*)t_res->ai_addr, t_res->ai_addrlen);
251 
252     if (!error)
253       break;
254   }
255   freeaddrinfo(res_lst);
256 
257   if (error)
258   {
259     std::string error_description = error.what();
260     throw Error(CR_CONNECTION_ERROR, error_description + " connecting to " + host + ":" + port_buf);
261   }
262 
263   m_closed = false;
264 }
265 
connect_to_localhost(const std::string & unix_socket_or_named_pipe)266 void XProtocol::connect_to_localhost(const std::string &unix_socket_or_named_pipe)
267 {
268   Error error = m_sync_connection.connect_to_localhost(unix_socket_or_named_pipe);
269 
270   if (error)
271   {
272     std::string error_description = error.what();
273     throw Error(CR_CONNECTION_ERROR, error_description + ", while connecting to "+unix_socket_or_named_pipe);
274   }
275 
276   m_closed = false;
277 }
278 
279 
280 
authenticate(const std::string & user,const std::string & pass,const std::string & schema)281 void XProtocol::authenticate(const std::string &user, const std::string &pass, const std::string &schema)
282 {
283   if (m_sync_connection.supports_ssl())
284   {
285     setup_capability("tls", true);
286 
287     enable_tls();
288     authenticate_plain(user, pass, schema);
289   }
290   else
291     authenticate_mysql41(user, pass, schema);
292 }
293 
fetch_capabilities()294 void XProtocol::fetch_capabilities()
295 {
296   send(Mysqlx::Connection::CapabilitiesGet());
297   int mid;
298   ngs::unique_ptr<Message> message(recv_raw(mid));
299   if (mid != Mysqlx::ServerMessages::CONN_CAPABILITIES)
300     throw Error(CR_COMMANDS_OUT_OF_SYNC, "Unexpected response received from server");
301   m_capabilities = *static_cast<Mysqlx::Connection::Capabilities*>(message.get());
302 }
303 
enable_tls()304 void XProtocol::enable_tls()
305 {
306   Error ec = m_sync_connection.activate_tls();
307 
308   if (ec)
309   {
310     // If ssl activation failed then
311     // server and client are in different states
312     // lets force disconnect
313     set_closed();
314 
315     throw ec;
316   }
317 }
318 
set_closed()319 void XProtocol::set_closed()
320 {
321   m_closed = true;
322 }
323 
close()324 void XProtocol::close()
325 {
326   if (!m_closed)
327   {
328     if (m_last_result)
329       m_last_result->buffer();
330 
331     send(Mysqlx::Session::Close());
332     m_closed = true;
333 
334     int mid;
335     try
336     {
337       ngs::unique_ptr<Message> message(recv_raw(mid));
338       if (mid != Mysqlx::ServerMessages::OK)
339         throw Error(CR_COMMANDS_OUT_OF_SYNC, "Unexpected message received in response to Session.Close");
340 
341       perform_close();
342     }
343     catch (...)
344     {
345       m_sync_connection.close();
346       throw;
347     }
348   }
349 }
350 
get_received_msg_counter(const std::string & id) const351 unsigned long XProtocol::get_received_msg_counter(const std::string &id) const
352 {
353   std::map<std::string, unsigned long>::const_iterator i =
354       m_received_msg_counters.find(id);
355   return i == m_received_msg_counters.end() ? 0ul : i->second;
356 }
357 
perform_close()358 void XProtocol::perform_close()
359 {
360   if (m_dont_wait_for_disconnect)
361   {
362     m_sync_connection.close();
363     return;
364   }
365 
366   int mid;
367   ngs::unique_ptr<Message> message(recv_raw(mid));
368   std::stringstream s;
369 
370   s << "Unexpected message received with id:" << mid << " while waiting for disconnection";
371 
372   throw Error(CR_COMMANDS_OUT_OF_SYNC, s.str());
373 }
374 
recv_result()375 ngs::shared_ptr<Result> XProtocol::recv_result()
376 {
377   return new_result(true);
378 }
379 
new_empty_result()380 ngs::shared_ptr<Result> XProtocol::new_empty_result()
381 {
382   ngs::shared_ptr<Result> empty_result(new Result(shared_from_this(), false, false));
383 
384   return empty_result;
385 }
386 
execute_sql(const std::string & sql)387 ngs::shared_ptr<Result> XProtocol::execute_sql(const std::string &sql)
388 {
389   {
390     Mysqlx::Sql::StmtExecute exec;
391     exec.set_namespace_("sql");
392     exec.set_stmt(sql);
393     send(exec);
394   }
395 
396   return new_result(true);
397 }
398 
execute_stmt(const std::string & ns,const std::string & sql,const std::vector<ArgumentValue> & args)399 ngs::shared_ptr<Result> XProtocol::execute_stmt(const std::string &ns, const std::string &sql, const std::vector<ArgumentValue> &args)
400 {
401   {
402     Mysqlx::Sql::StmtExecute exec;
403     exec.set_namespace_(ns);
404     exec.set_stmt(sql);
405 
406     for (std::vector<ArgumentValue>::const_iterator iter = args.begin();
407          iter != args.end(); ++iter)
408     {
409       Mysqlx::Datatypes::Any *any = exec.mutable_args()->Add();
410 
411       any->set_type(Mysqlx::Datatypes::Any::SCALAR);
412       switch (iter->type())
413       {
414         case ArgumentValue::TInteger:
415           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_SINT);
416           any->mutable_scalar()->set_v_signed_int(*iter);
417           break;
418         case ArgumentValue::TUInteger:
419           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_UINT);
420           any->mutable_scalar()->set_v_unsigned_int(*iter);
421           break;
422         case ArgumentValue::TNull:
423           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_NULL);
424           break;
425         case ArgumentValue::TDouble:
426           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_DOUBLE);
427           any->mutable_scalar()->set_v_double(*iter);
428           break;
429         case ArgumentValue::TFloat:
430           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_FLOAT);
431           any->mutable_scalar()->set_v_float(*iter);
432           break;
433         case ArgumentValue::TBool:
434           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_BOOL);
435           any->mutable_scalar()->set_v_bool(*iter);
436           break;
437         case ArgumentValue::TString:
438           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_STRING);
439           any->mutable_scalar()->mutable_v_string()->set_value(*iter);
440           break;
441         case ArgumentValue::TOctets:
442           any->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_OCTETS);
443           any->mutable_scalar()->mutable_v_octets()->set_value(*iter);
444           break;
445       }
446     }
447     send(exec);
448   }
449 
450   return new_result(true);
451 }
452 
execute_find(const Mysqlx::Crud::Find & m)453 ngs::shared_ptr<Result> XProtocol::execute_find(const Mysqlx::Crud::Find &m)
454 {
455   send(m);
456 
457   return new_result(true);
458 }
459 
execute_update(const Mysqlx::Crud::Update & m)460 ngs::shared_ptr<Result> XProtocol::execute_update(const Mysqlx::Crud::Update &m)
461 {
462   send(m);
463 
464   return new_result(false);
465 }
466 
execute_insert(const Mysqlx::Crud::Insert & m)467 ngs::shared_ptr<Result> XProtocol::execute_insert(const Mysqlx::Crud::Insert &m)
468 {
469   send(m);
470 
471   return new_result(false);
472 }
473 
execute_delete(const Mysqlx::Crud::Delete & m)474 ngs::shared_ptr<Result> XProtocol::execute_delete(const Mysqlx::Crud::Delete &m)
475 {
476   send(m);
477 
478   return new_result(false);
479 }
480 
setup_capability(const std::string & name,const bool value)481 void XProtocol::setup_capability(const std::string &name, const bool value)
482 {
483   Mysqlx::Connection::CapabilitiesSet capSet;
484   Mysqlx::Connection::Capability     *cap = capSet.mutable_capabilities()->add_capabilities();
485   ::Mysqlx::Datatypes::Scalar        *scalar = cap->mutable_value()->mutable_scalar();
486 
487   cap->set_name(name);
488   cap->mutable_value()->set_type(Mysqlx::Datatypes::Any_Type_SCALAR);
489   scalar->set_type(Mysqlx::Datatypes::Scalar_Type_V_BOOL);
490   scalar->set_v_bool(value);
491   send(capSet);
492 
493   if (m_last_result)
494     m_last_result->buffer();
495 
496   int mid;
497   ngs::unique_ptr<Message> msg(recv_raw(mid));
498 
499   if (Mysqlx::ServerMessages::ERROR == mid)
500     throw_server_error(*(Mysqlx::Error*)msg.get());
501   if (Mysqlx::ServerMessages::OK != mid)
502   {
503     if (getenv("MYSQLX_DEBUG"))
504     {
505       std::string out;
506       google::protobuf::TextFormat::PrintToString(*msg, &out);
507       std::cout << out << "\n";
508     }
509     throw Error(CR_MALFORMED_PACKET, "Unexpected message received from server during handshake");
510   }
511 }
512 
authenticate_mysql41(const std::string & user,const std::string & pass,const std::string & db)513 void XProtocol::authenticate_mysql41(const std::string &user, const std::string &pass, const std::string &db)
514 {
515   {
516     Mysqlx::Session::AuthenticateStart auth;
517 
518     auth.set_mech_name("MYSQL41");
519 
520     send(Mysqlx::ClientMessages::SESS_AUTHENTICATE_START, auth);
521   }
522 
523   {
524     int mid;
525     ngs::unique_ptr<Message> message(recv_raw(mid));
526     switch (mid)
527     {
528       case Mysqlx::ServerMessages::SESS_AUTHENTICATE_CONTINUE:
529       {
530         Mysqlx::Session::AuthenticateContinue &auth_continue = *static_cast<Mysqlx::Session::AuthenticateContinue*>(message.get());
531 
532         std::string data;
533 
534         if (!auth_continue.has_auth_data())
535           throw Error(CR_MALFORMED_PACKET, "Missing authentication data");
536 
537         std::string password_hash;
538 
539         Mysqlx::Session::AuthenticateContinue auth_continue_response;
540 
541 #ifdef MYSQLXTEST_STANDALONE
542         auth_continue_response.set_auth_data(build_mysql41_authentication_response(auth_continue.auth_data(), user, pass, db));
543 #else
544         if (pass.length())
545         {
546           password_hash = Password_hasher::scramble(auth_continue.auth_data().c_str(), pass.c_str());
547           password_hash = Password_hasher::get_password_from_salt(password_hash);
548         }
549 
550         data.append(db).push_back('\0'); // authz
551         data.append(user).push_back('\0'); // authc
552         data.append(password_hash); // pass
553         auth_continue_response.set_auth_data(data);
554 #endif
555 
556         send(Mysqlx::ClientMessages::SESS_AUTHENTICATE_CONTINUE, auth_continue_response);
557       }
558       break;
559 
560       case Mysqlx::ServerMessages::NOTICE:
561         dispatch_notice(static_cast<Mysqlx::Notice::Frame*>(message.get()));
562         break;
563 
564       case Mysqlx::ServerMessages::ERROR:
565         throw_server_error(*static_cast<Mysqlx::Error*>(message.get()));
566         break;
567 
568       default:
569         throw Error(CR_MALFORMED_PACKET, "Unexpected message received from server during authentication");
570         break;
571     }
572   }
573 
574   bool done = false;
575   while (!done)
576   {
577     int mid;
578     ngs::unique_ptr<Message> message(recv_raw(mid));
579     switch (mid)
580     {
581       case Mysqlx::ServerMessages::SESS_AUTHENTICATE_OK:
582         done = true;
583         break;
584 
585       case Mysqlx::ServerMessages::ERROR:
586         throw_server_error(*static_cast<Mysqlx::Error*>(message.get()));
587         break;
588 
589       case Mysqlx::ServerMessages::NOTICE:
590         dispatch_notice(static_cast<Mysqlx::Notice::Frame*>(message.get()));
591         break;
592 
593       default:
594         throw Error(CR_MALFORMED_PACKET, "Unexpected message received from server during authentication");
595         break;
596     }
597   }
598 }
599 
authenticate_plain(const std::string & user,const std::string & pass,const std::string & db)600 void XProtocol::authenticate_plain(const std::string &user, const std::string &pass, const std::string &db)
601 {
602   {
603     Mysqlx::Session::AuthenticateStart auth;
604 
605     auth.set_mech_name("PLAIN");
606     std::string data;
607 
608     data.append(db).push_back('\0'); // authz
609     data.append(user).push_back('\0'); // authc
610     data.append(pass); // pass
611 
612     auth.set_auth_data(data);
613     send(Mysqlx::ClientMessages::SESS_AUTHENTICATE_START, auth);
614   }
615 
616   bool done = false;
617   while (!done)
618   {
619     int mid;
620     ngs::unique_ptr<Message> message(recv_raw(mid));
621     switch (mid)
622     {
623       case Mysqlx::ServerMessages::SESS_AUTHENTICATE_OK:
624         done = true;
625         break;
626 
627       case Mysqlx::ServerMessages::ERROR:
628         throw_server_error(*static_cast<Mysqlx::Error*>(message.get()));
629         break;
630 
631       case Mysqlx::ServerMessages::NOTICE:
632         dispatch_notice(static_cast<Mysqlx::Notice::Frame*>(message.get()));
633         break;
634 
635       default:
636         throw Error(CR_MALFORMED_PACKET, "Unexpected message received from server during authentication");
637         break;
638     }
639   }
640 }
641 
send_bytes(const std::string & data)642 void XProtocol::send_bytes(const std::string &data)
643 {
644   Error error = m_sync_connection.write(data.data(), data.size());
645   throw_mysqlx_error(error);
646 }
647 
send(int mid,const Message & msg)648 void XProtocol::send(int mid, const Message &msg)
649 {
650   Error error;
651   union
652   {
653     uint8_t buf[5];                        // Must be properly aligned
654     longlong dummy;
655   };
656   /*
657     Use dummy, otherwise g++ 4.4 reports: unused variable 'dummy'
658     MY_ATTRIBUTE((unused)) did not work, so we must use it.
659   */
660   dummy= 0;
661 
662   uint32_t *buf_ptr = (uint32_t *)buf;
663   *buf_ptr = msg.ByteSize() + 1;
664 #ifdef WORDS_BIGENDIAN
665   std::swap(buf[0], buf[3]);
666   std::swap(buf[1], buf[2]);
667 #endif
668   buf[4] = mid;
669 
670   if (m_trace_packets)
671   {
672     std::string out;
673     google::protobuf::TextFormat::Printer p;
674     p.SetInitialIndentLevel(1);
675     p.PrintToString(msg, &out);
676     std::cout << ">>>> SEND " << msg.ByteSize()+1 << " " << msg.GetDescriptor()->full_name() << " {\n" << out << "}\n";
677   }
678 
679   error = m_sync_connection.write(buf, 5);
680   if (!error)
681   {
682     std::string mbuf;
683     msg.SerializeToString(&mbuf);
684 
685     if (0 != mbuf.length())
686       error = m_sync_connection.write(mbuf.data(), mbuf.length());
687   }
688 
689   throw_mysqlx_error(error);
690 }
691 
push_local_notice_handler(Local_notice_handler handler)692 void XProtocol::push_local_notice_handler(Local_notice_handler handler)
693 {
694   m_local_notice_handlers.push_back(handler);
695 }
696 
pop_local_notice_handler()697 void XProtocol::pop_local_notice_handler()
698 {
699   m_local_notice_handlers.pop_back();
700 }
701 
dispatch_notice(Mysqlx::Notice::Frame * frame)702 void XProtocol::dispatch_notice(Mysqlx::Notice::Frame *frame)
703 {
704   if (frame->scope() == Mysqlx::Notice::Frame::LOCAL)
705   {
706     for (std::list<Local_notice_handler>::iterator iter = m_local_notice_handlers.begin();
707          iter != m_local_notice_handlers.end(); ++iter)
708       if ((*iter)(frame->type(), frame->payload())) // handler returns true if the notice was handled
709         return;
710 
711     {
712       if (frame->type() == 3)
713       {
714         Mysqlx::Notice::SessionStateChanged change;
715         change.ParseFromString(frame->payload());
716         if (!change.IsInitialized())
717           std::cerr << "Invalid notice received from server " << change.InitializationErrorString() << "\n";
718         else
719         {
720           if (change.param() == Mysqlx::Notice::SessionStateChanged::ACCOUNT_EXPIRED)
721           {
722             std::cout << "NOTICE: Account password expired\n";
723             return;
724           }
725           else if (change.param() == Mysqlx::Notice::SessionStateChanged::CLIENT_ID_ASSIGNED)
726           {
727             if (!change.has_value() || change.value().type() != Mysqlx::Datatypes::Scalar::V_UINT)
728               std::cerr << "Invalid notice received from server. Client_id is of the wrong type\n";
729             else
730               m_client_id = change.value().v_unsigned_int();
731             return;
732           }
733         }
734       }
735       std::cout << "Unhandled local notice\n";
736     }
737   }
738   else
739   {
740     std::cout << "Unhandled global notice\n";
741   }
742 }
743 
recv_next(int & mid)744 Message *XProtocol::recv_next(int &mid)
745 {
746   for (;;)
747   {
748     Message *msg = recv_raw(mid);
749     if (mid != Mysqlx::ServerMessages::NOTICE)
750       return msg;
751 
752     dispatch_notice(static_cast<Mysqlx::Notice::Frame*>(msg));
753     delete msg;
754   }
755 }
756 
recv_raw_with_deadline(int & mid,const int deadline_milliseconds)757 Message *XProtocol::recv_raw_with_deadline(int &mid, const int deadline_milliseconds)
758 {
759   char header_buffer[5];
760   std::size_t data = sizeof(header_buffer);
761   Error error = m_sync_connection.read_with_timeout(header_buffer, data, deadline_milliseconds);
762 
763   if (0 == data)
764   {
765     m_closed = true;
766     return NULL;
767   }
768 
769   throw_mysqlx_error(error);
770 
771   return recv_message_with_header(mid, header_buffer, sizeof(header_buffer));
772 }
773 
recv_payload(const int mid,const std::size_t msglen)774 Message *XProtocol::recv_payload(const int mid, const std::size_t msglen)
775 {
776   Error error;
777   Message* ret_val = NULL;
778   char *mbuf = new char[msglen];
779 
780   if (0 < msglen)
781     error = m_sync_connection.read(mbuf, msglen);
782 
783   if (!error)
784   {
785     switch (mid)
786     {
787       case Mysqlx::ServerMessages::OK:
788         ret_val = new Mysqlx::Ok();
789         break;
790       case Mysqlx::ServerMessages::ERROR:
791         ret_val = new Mysqlx::Error();
792         break;
793       case Mysqlx::ServerMessages::NOTICE:
794         ret_val = new Mysqlx::Notice::Frame();
795         break;
796       case Mysqlx::ServerMessages::CONN_CAPABILITIES:
797         ret_val = new Mysqlx::Connection::Capabilities();
798         break;
799       case Mysqlx::ServerMessages::SESS_AUTHENTICATE_CONTINUE:
800         ret_val = new Mysqlx::Session::AuthenticateContinue();
801         break;
802       case Mysqlx::ServerMessages::SESS_AUTHENTICATE_OK:
803         ret_val = new Mysqlx::Session::AuthenticateOk();
804         break;
805       case Mysqlx::ServerMessages::RESULTSET_COLUMN_META_DATA:
806         ret_val = new Mysqlx::Resultset::ColumnMetaData();
807         break;
808       case Mysqlx::ServerMessages::RESULTSET_ROW:
809         ret_val = new Mysqlx::Resultset::Row();
810         break;
811       case Mysqlx::ServerMessages::RESULTSET_FETCH_DONE:
812         ret_val = new Mysqlx::Resultset::FetchDone();
813         break;
814       case Mysqlx::ServerMessages::RESULTSET_FETCH_DONE_MORE_RESULTSETS:
815         ret_val = new Mysqlx::Resultset::FetchDoneMoreResultsets();
816         break;
817       case Mysqlx::ServerMessages::SQL_STMT_EXECUTE_OK:
818         ret_val = new Mysqlx::Sql::StmtExecuteOk();
819         break;
820     }
821 
822     if (!ret_val)
823     {
824       delete[] mbuf;
825       std::stringstream ss;
826       ss << "Unknown message received from server ";
827       ss << mid;
828       throw Error(CR_MALFORMED_PACKET, ss.str());
829     }
830 
831     // Parses the received message
832     ret_val->ParseFromString(std::string(mbuf, msglen));
833 
834     if (m_trace_packets)
835     {
836       std::string out;
837       google::protobuf::TextFormat::Printer p;
838       p.SetInitialIndentLevel(1);
839       p.PrintToString(*ret_val, &out);
840       std::cout << "<<<< RECEIVE " << msglen << " " << ret_val->GetDescriptor()->full_name() << " {\n" << out << "}\n";
841     }
842 
843     if (!ret_val->IsInitialized())
844     {
845       std::string err("Message is not properly initialized: ");
846       err += ret_val->InitializationErrorString();
847 
848       delete[] mbuf;
849       delete ret_val;
850 
851       throw Error(CR_MALFORMED_PACKET, err);
852     }
853   }
854   else
855   {
856     delete[] mbuf;
857     throw_mysqlx_error(error);
858   }
859 
860   delete[] mbuf;
861   update_received_msg_counter(ret_val);
862   return ret_val;
863 }
864 
recv_raw(int & mid)865 Message *XProtocol::recv_raw(int &mid)
866 {
867   union
868   {
869     char buf[5];                                // Must be properly aligned
870     longlong dummy;
871   };
872 
873   /*
874     Use dummy, otherwise g++ 4.4 reports: unused variable 'dummy'
875     MY_ATTRIBUTE((unused)) did not work, so we must use it.
876   */
877   dummy= 0;
878   mid = 0;
879 
880   return recv_message_with_header(mid, buf, 0);
881 }
882 
recv_message_with_header(int & mid,char (& header_buffer)[5],const std::size_t header_offset)883 Message *XProtocol::recv_message_with_header(int &mid, char (&header_buffer)[5], const std::size_t header_offset)
884 {
885   Message* ret_val = NULL;
886   Error error;
887 
888   error = m_sync_connection.read(header_buffer + header_offset, 5 - header_offset);
889 
890 #ifdef WORDS_BIGENDIAN
891   std::swap(header_buffer[0], header_buffer[3]);
892   std::swap(header_buffer[1], header_buffer[2]);
893 #endif
894 
895   if (!error)
896   {
897   uint32_t msglen = *(uint32_t*)header_buffer - 1;
898   mid = header_buffer[4];
899 
900   ret_val = recv_payload(mid, msglen);
901   }
902   else
903   {
904     throw_mysqlx_error(error);
905   }
906 
907   return ret_val;
908 }
909 
throw_mysqlx_error(const Error & error)910 void XProtocol::throw_mysqlx_error(const Error &error)
911 {
912   if (!error)
913     return;
914 
915   throw error;
916 }
917 
new_result(bool expect_data)918 ngs::shared_ptr<Result> XProtocol::new_result(bool expect_data)
919 {
920   if (m_last_result)
921     m_last_result->buffer();
922 
923   m_last_result.reset(new Result(shared_from_this(), expect_data));
924 
925   return m_last_result;
926 }
927 
update_received_msg_counter(const Message * msg)928 void XProtocol::update_received_msg_counter(const Message* msg)
929 {
930   const std::string &id = msg->GetDescriptor()->full_name();
931   ++m_received_msg_counters[id];
932 
933   if (id != Mysqlx::Notice::Frame::descriptor()->full_name()) return;
934 
935   static const std::string *notice_type_id[] = {
936       &Mysqlx::Notice::Warning::descriptor()->full_name(),
937       &Mysqlx::Notice::SessionVariableChanged::descriptor()->full_name(),
938       &Mysqlx::Notice::SessionStateChanged::descriptor()->full_name()};
939   static const unsigned notice_type_id_size =
940       sizeof(notice_type_id) / sizeof(notice_type_id[0]);
941   const ::google::protobuf::uint32 notice_type =
942       static_cast<const Mysqlx::Notice::Frame *>(msg)->type() - 1u;
943   if (notice_type < notice_type_id_size)
944     ++m_received_msg_counters[*notice_type_id[notice_type]];
945 }
946 
947 #ifdef WIN32
948 #  pragma pop_macro("ERROR")
949 #endif
950