1 /*
2  * Copyright (c) 2015, 2021, Oracle and/or its affiliates.
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License, version 2.0,
6  * as published by the Free Software Foundation.
7  *
8  * This program is also distributed with certain software (including
9  * but not limited to OpenSSL) that is licensed under separate terms,
10  * as designated in a particular file or component or in included license
11  * documentation.  The authors of MySQL hereby grant you an additional
12  * permission to link the program and your derivative works with the
13  * separately licensed software that they have included with MySQL.
14  *
15  * This program is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  * GNU General Public License, version 2.0, for more details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with this program; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
23  * 02110-1301  USA
24  */
25 
26 #ifndef _MYSQLX_PROTOCOL_H_
27 #define _MYSQLX_PROTOCOL_H_
28 
29 #undef ERROR //Needed to avoid conflict with ERROR in mysqlx.pb.h
30 
31 // Avoid warnings from includes of other project and protobuf
32 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
33 #pragma GCC diagnostic push
34 #pragma GCC diagnostic ignored "-Wshadow"
35 #pragma GCC diagnostic ignored "-Wunused-parameter"
36 #elif defined _MSC_VER
37 #pragma warning (push)
38 #pragma warning (disable : 4018 4996)
39 #endif
40 
41 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
42 #pragma GCC diagnostic pop
43 #elif defined _MSC_VER
44 #pragma warning (pop)
45 #endif
46 
#include "ngs_common/smart_ptr.h"
#include "ngs_common/bind.h"
#include <list>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
#include <assert.h>

#include "ngs_common/protocol_protobuf.h"
#include "mysqlx_connection.h"
54 
55 
56 namespace mysqlx
57 {
58   typedef google::protobuf::Message Message;
59   typedef ngs::function<bool (int,std::string)> Local_notice_handler;
60 
61   class Result;
62 
63   class ArgumentValue
64   {
65   public:
66     enum Type
67     {
68       TInteger,
69       TUInteger,
70       TNull,
71       TDouble,
72       TFloat,
73       TBool,
74       TString,
75       TOctets,
76     };
77 
ArgumentValue(const ArgumentValue & other)78     ArgumentValue(const ArgumentValue &other)
79     {
80       m_type = other.m_type;
81       m_value = other.m_value;
82       if (m_type == TString || m_type == TOctets)
83         m_value.s = new std::string(*other.m_value.s);
84     }
85 
86     ArgumentValue &operator = (const ArgumentValue &other)
87     {
88       if (&other == this)
89         return *this;
90 
91       m_type = other.m_type;
92       m_value = other.m_value;
93       if (m_type == TString || m_type == TOctets)
94         m_value.s = new std::string(*other.m_value.s);
95 
96       return *this;
97     }
98 
99     explicit ArgumentValue(const std::string &s, Type type = TString)
100     {
101       assert(type == TOctets || type == TString);
102       m_type = type;
103       m_value.s = new std::string(s);
104     }
105 
ArgumentValue(int64_t n)106     explicit ArgumentValue(int64_t n)
107     {
108       m_type = TInteger;
109       m_value.i = n;
110     }
111 
ArgumentValue(uint64_t n)112     explicit ArgumentValue(uint64_t n)
113     {
114       m_type = TUInteger;
115       m_value.ui = n;
116     }
117 
ArgumentValue(double n)118     explicit ArgumentValue(double n)
119     {
120       m_type = TDouble;
121       m_value.d = n;
122     }
123 
ArgumentValue(float n)124     explicit ArgumentValue(float n)
125     {
126       m_type = TFloat;
127       m_value.f = n;
128     }
129 
ArgumentValue(bool n)130     explicit ArgumentValue(bool n)
131     {
132       m_type = TBool;
133       m_value.b = n;
134     }
135 
ArgumentValue()136     explicit ArgumentValue()
137     {
138       m_type = TNull;
139     }
140 
~ArgumentValue()141     ~ArgumentValue()
142     {
143       if (m_type == TString || m_type == TOctets)
144         delete m_value.s;
145     }
146 
type()147     inline Type type() const { return m_type; }
148 
uint64_t()149     inline operator uint64_t () const
150     {
151       if (m_type != TUInteger)
152         throw std::logic_error("type error");
153       return m_value.ui;
154     }
155 
int64_t()156     inline operator int64_t () const
157     {
158       if (m_type != TInteger)
159         throw std::logic_error("type error");
160       return m_value.i;
161     }
162 
163     inline operator double() const
164     {
165       if (m_type != TDouble)
166         throw std::logic_error("type error");
167       return m_value.d;
168     }
169 
170     inline operator float() const
171     {
172       if (m_type != TFloat)
173         throw std::logic_error("type error");
174       return m_value.f;
175     }
176 
177     inline operator bool() const
178     {
179       if (m_type != TBool)
180         throw std::logic_error("type error");
181       return m_value.b;
182     }
183 
184     inline operator const std::string & () const
185     {
186       if (m_type != TString && m_type != TOctets)
187         throw std::logic_error("type error");
188       return *m_value.s;
189     }
190 
191   private:
192     Type m_type;
193     union
194     {
195       std::string *s;
196       int64_t i;
197       uint64_t ui;
198       double d;
199       float f;
200       bool b;
201     } m_value;
202   };
203 
204   struct Ssl_config
205   {
Ssl_configSsl_config206     Ssl_config()
207     : key(NULL),
208       ca(NULL),
209       ca_path(NULL),
210       cert(NULL),
211       cipher(NULL),
212       tls_version(NULL)
213     {
214     }
215 
216     const char *key;
217     const char *ca;
218     const char *ca_path;
219     const char *cert;
220     const char *cipher;
221     const char *tls_version;
222   };
223 
224   enum Internet_protocol
225   {
226     IP_any = 0,
227     IPv4,
228     IPv6,
229   };
230 
231   class XProtocol : public ngs::enable_shared_from_this<XProtocol>
232   {
233   public:
234     XProtocol(const Ssl_config &ssl_config, const std::size_t timeout, const bool dont_wait_for_disconnect = true, const Internet_protocol ip_mode = IPv4);
235     ~XProtocol();
236 
client_id()237     uint64_t client_id() const { return m_client_id; }
capabilities()238     const Mysqlx::Connection::Capabilities &capabilities() const { return m_capabilities; }
239 
240     void push_local_notice_handler(Local_notice_handler handler);
241     void pop_local_notice_handler();
242 
243     void connect(const std::string &uri, const std::string &pass, const bool cap_expired_password = false); //XXX capabilities flags
244     void connect(const std::string &host, int port);
245     void connect_to_localhost(const std::string &unix_socket_or_named_pipe);
246 
247     void close();
248     void set_closed();
is_closed()249     bool is_closed() const { return m_closed; }
250 
251     void enable_tls();
252 
253     void send(int mid, const Message &msg);
254     Message *recv_next(int &mid);
255 
256     Message *recv_raw(int &mid);
257     Message *recv_payload(const int mid, const std::size_t msglen);
258     Message *recv_raw_with_deadline(int &mid, const int deadline_milliseconds);
259 
260     ngs::shared_ptr<Result> recv_result();
261     ngs::shared_ptr<Result> new_empty_result();
262 
263     // Overrides for Client Session Messages
send(const Mysqlx::Session::AuthenticateStart & m)264     void send(const Mysqlx::Session::AuthenticateStart &m) { send(Mysqlx::ClientMessages::SESS_AUTHENTICATE_START, m); };
send(const Mysqlx::Session::AuthenticateContinue & m)265     void send(const Mysqlx::Session::AuthenticateContinue &m) { send(Mysqlx::ClientMessages::SESS_AUTHENTICATE_CONTINUE, m); };
send(const Mysqlx::Session::Reset & m)266     void send(const Mysqlx::Session::Reset &m) { send(Mysqlx::ClientMessages::SESS_RESET, m); };
send(const Mysqlx::Session::Close & m)267     void send(const Mysqlx::Session::Close &m) { send(Mysqlx::ClientMessages::SESS_CLOSE, m); };
268 
269     // Overrides for SQL Messages
send(const Mysqlx::Sql::StmtExecute & m)270     void send(const Mysqlx::Sql::StmtExecute &m) { send(Mysqlx::ClientMessages::SQL_STMT_EXECUTE, m); };
271 
272     // Overrides for CRUD operations
send(const Mysqlx::Crud::Find & m)273     void send(const Mysqlx::Crud::Find &m) { send(Mysqlx::ClientMessages::CRUD_FIND, m); };
send(const Mysqlx::Crud::Insert & m)274     void send(const Mysqlx::Crud::Insert &m) { send(Mysqlx::ClientMessages::CRUD_INSERT, m); };
send(const Mysqlx::Crud::Update & m)275     void send(const Mysqlx::Crud::Update &m) { send(Mysqlx::ClientMessages::CRUD_UPDATE, m); };
send(const Mysqlx::Crud::Delete & m)276     void send(const Mysqlx::Crud::Delete &m) { send(Mysqlx::ClientMessages::CRUD_DELETE, m); };
277 
278     // Overrides for Connection
send(const Mysqlx::Connection::CapabilitiesGet & m)279     void send(const Mysqlx::Connection::CapabilitiesGet &m) { send(Mysqlx::ClientMessages::CON_CAPABILITIES_GET, m); };
send(const Mysqlx::Connection::CapabilitiesSet & m)280     void send(const Mysqlx::Connection::CapabilitiesSet &m) { send(Mysqlx::ClientMessages::CON_CAPABILITIES_SET, m); };
send(const Mysqlx::Connection::Close & m)281     void send(const Mysqlx::Connection::Close &m) { send(Mysqlx::ClientMessages::CON_CLOSE, m); };
282 
283   public:
284     ngs::shared_ptr<Result> execute_sql(const std::string &sql);
285     ngs::shared_ptr<Result> execute_stmt(const std::string &ns, const std::string &sql, const std::vector<ArgumentValue> &args);
286 
287     ngs::shared_ptr<Result> execute_find(const Mysqlx::Crud::Find &m);
288     ngs::shared_ptr<Result> execute_update(const Mysqlx::Crud::Update &m);
289     ngs::shared_ptr<Result> execute_insert(const Mysqlx::Crud::Insert &m);
290     ngs::shared_ptr<Result> execute_delete(const Mysqlx::Crud::Delete &m);
291 
292     void fetch_capabilities();
293     void setup_capability(const std::string &name, const bool value);
294 
295     void authenticate(const std::string &user, const std::string &pass, const std::string &schema);
296     void authenticate_plain(const std::string &user, const std::string &pass, const std::string &db);
297     void authenticate_mysql41(const std::string &user, const std::string &pass, const std::string &db);
298 
299     void send_bytes(const std::string &data);
300 
set_trace_protocol(bool flag)301     void set_trace_protocol(bool flag) { m_trace_packets = flag; }
302     unsigned long get_received_msg_counter(const std::string &id) const;
303 
304   private:
305     void perform_close();
306     void dispatch_notice(Mysqlx::Notice::Frame *frame);
307     Message *recv_message_with_header(int &mid, char (&header_buffer)[5], const std::size_t header_offset);
308     void throw_mysqlx_error(const Error &ec);
309     ngs::shared_ptr<Result> new_result(bool expect_data);
310     void update_received_msg_counter(const Message* msg);
311   private:
312     std::list<Local_notice_handler> m_local_notice_handlers;
313     Mysqlx::Connection::Capabilities m_capabilities;
314 
315     Connection m_sync_connection;
316     uint64_t m_client_id;
317     bool m_trace_packets;
318     bool m_closed;
319     const bool m_dont_wait_for_disconnect;
320     const Internet_protocol m_ip_mode;
321     ngs::shared_ptr<Result> m_last_result;
322     std::map<std::string, unsigned long> m_received_msg_counters;
323   };
324 
325   bool parse_mysql_connstring(const std::string &connstring,
326                               std::string &protocol, std::string &user, std::string &password,
327                               std::string &host, int &port, std::string &sock,
328                               std::string &db, int &pwd_found);
329 } // namespace mysqlx
330 
331 #endif // _MYSQLX_PROTOCOL_H_
332