1 /*
2  * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License, version 2.0,
6  * as published by the Free Software Foundation.
7  *
8  * This program is also distributed with certain software (including
9  * but not limited to OpenSSL) that is licensed under separate terms,
10  * as designated in a particular file or component or in included license
11  * documentation.  The authors of MySQL hereby grant you an additional
12  * permission to link the program and your derivative works with the
13  * separately licensed software that they have included with MySQL.
14  *
15  * This program is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  * GNU General Public License, version 2.0, for more details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with this program; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
23  */
24 
25 #include "plugin/x/tests/driver/parsers/message_parser.h"
26 
27 #include <memory>
28 
29 #include "plugin/x/tests/driver/common/utils_string_parsing.h"
30 #include "plugin/x/tests/driver/connector/mysqlx_all_msgs.h"
31 
32 using Message = xcl::XProtocol::Message;
33 
34 namespace parser {
35 
36 namespace details {
37 
38 class Error_dumper : public ::google::protobuf::io::ErrorCollector {
39   std::stringstream m_out;
40 
41  public:
AddError(int line,int column,const std::string & message)42   void AddError(int line, int column, const std::string &message) override {
43     m_out << "ERROR in message: line " << line + 1 << ": column " << column
44           << ": " << message << "\n";
45   }
46 
AddWarning(int line,int column,const std::string & message)47   void AddWarning(int line, int column, const std::string &message) override {
48     m_out << "WARNING in message: line " << line + 1 << ": column " << column
49           << ": " << message << "\n";
50   }
51 
str()52   std::string str() { return m_out.str(); }
53 };
54 
parse_mesage(const std::string & text_message,const std::string & text_name,Message * message,std::string * out_error,const bool allow_partial_messaged)55 bool parse_mesage(const std::string &text_message, const std::string &text_name,
56                   Message *message, std::string *out_error,
57                   const bool allow_partial_messaged) {
58   google::protobuf::TextFormat::Parser parser;
59   Error_dumper dumper;
60   parser.RecordErrorsTo(&dumper);
61   parser.AllowPartialMessage(allow_partial_messaged);
62   if (!parser.ParseFromString(text_message, message)) {
63     if (nullptr != out_error) {
64       *out_error = "Invalid message in input: " + text_name + '\n';
65       int i = 1;
66       for (std::string::size_type p = 0, n = text_message.find('\n', p + 1);
67            p != std::string::npos; p = (n == std::string::npos ? n : n + 1),
68                                   n = text_message.find('\n', p + 1), ++i) {
69         *out_error +=
70             std::to_string(i) + ": " + text_message.substr(p, n - p) + '\n';
71       }
72       *out_error += "\n" + dumper.str() + '\n';
73     }
74 
75     return false;
76   }
77 
78   return true;
79 }
80 
81 template <typename MSG>
parse_serialize_message(const std::string & text_payload,std::string * out_error,const bool allow_partial_messaged)82 Message *parse_serialize_message(const std::string &text_payload,
83                                  std::string *out_error,
84                                  const bool allow_partial_messaged) {
85   std::unique_ptr<MSG> msg{new MSG()};
86 
87   if (!parse_mesage(text_payload, "", msg.get(), out_error,
88                     allow_partial_messaged))
89     return {};
90 
91   return msg.release();
92 }
93 
get_notice_payload_from_text(const Mysqlx::Notice::Frame_Type type,const std::string & text_payload,std::string * out_binary_payload,const bool allow_partial_messaged)94 bool get_notice_payload_from_text(const Mysqlx::Notice::Frame_Type type,
95                                   const std::string &text_payload,
96                                   std::string *out_binary_payload,
97                                   const bool allow_partial_messaged) {
98   std::string error;
99   std::unique_ptr<Message> msg{parser::get_notice_message_from_text(
100       type, text_payload, &error, allow_partial_messaged)};
101 
102   if (nullptr == msg) {
103     // Fail when there is a payload, still we received a null message
104     return text_payload.empty();
105   }
106 
107   if (allow_partial_messaged)
108     return msg->SerializePartialToString(out_binary_payload);
109 
110   return msg->SerializeToString(out_binary_payload);
111 }
112 
113 }  // namespace details
114 
get_notice_message_from_text(const Mysqlx::Notice::Frame_Type type,const std::string & text_payload,std::string * out_error,const bool allow_partial_messaged)115 Message *get_notice_message_from_text(const Mysqlx::Notice::Frame_Type type,
116                                       const std::string &text_payload,
117                                       std::string *out_error,
118                                       const bool allow_partial_messaged) {
119   switch (type) {
120     case Mysqlx::Notice::Frame_Type_WARNING:
121       return details::parse_serialize_message<Mysqlx::Notice::Warning>(
122           text_payload, out_error, allow_partial_messaged);
123     case Mysqlx::Notice::Frame_Type_SESSION_VARIABLE_CHANGED:
124       return details::parse_serialize_message<
125           Mysqlx::Notice::SessionVariableChanged>(text_payload, out_error,
126                                                   allow_partial_messaged);
127     case Mysqlx::Notice::Frame_Type_SESSION_STATE_CHANGED:
128       return details::parse_serialize_message<
129           Mysqlx::Notice::SessionStateChanged>(text_payload, out_error,
130                                                allow_partial_messaged);
131     case Mysqlx::Notice::Frame_Type_GROUP_REPLICATION_STATE_CHANGED:
132       return details::parse_serialize_message<
133           Mysqlx::Notice::GroupReplicationStateChanged>(text_payload, out_error,
134                                                         allow_partial_messaged);
135     default:
136       return nullptr;
137   }
138 }
139 
get_name_and_body_from_text(const std::string & text_message,std::string * out_full_message_name,std::string * out_message_body,const bool is_body_full)140 bool get_name_and_body_from_text(const std::string &text_message,
141                                  std::string *out_full_message_name,
142                                  std::string *out_message_body,
143                                  const bool is_body_full) {
144   const auto separator = text_message.find("{");
145 
146   if (std::string::npos == separator) {
147     return false;
148   }
149 
150   if (nullptr != out_full_message_name) {
151     *out_full_message_name = text_message.substr(0, separator);
152     aux::trim(*out_full_message_name);
153   }
154 
155   auto body = text_message.substr(separator);
156 
157   if (is_body_full) {
158     aux::trim(body, " \t\n\r");
159 
160     if (body.size() < 2) return false;
161 
162     if (body[0] != '{') return false;
163     if (body[body.size() - 1] != '}') return false;
164 
165     body = body.substr(1, body.size() - 2);
166   }
167 
168   if (nullptr != out_message_body) {
169     *out_message_body = body;
170   }
171 
172   return true;
173 }
174 
get_client_message_from_text(const std::string & name,const std::string & data,xcl::XProtocol::Client_message_type_id * msg_id,std::string * out_error,const bool allow_partial_messaged)175 Message *get_client_message_from_text(
176     const std::string &name, const std::string &data,
177     xcl::XProtocol::Client_message_type_id *msg_id, std::string *out_error,
178     const bool allow_partial_messaged) {
179   std::string find_by = name;
180   Message *message;
181 
182   if (find_by.empty()) {
183     *out_error = "Message name is empty";
184     return nullptr;
185   }
186 
187   while (true) {
188     auto msg = client_msgs_by_name.find(find_by);
189 
190     if (msg == client_msgs_by_name.end()) {
191       if (client_msgs_by_full_name.count(name) &&
192           find_by != client_msgs_by_full_name[name]) {
193         find_by = client_msgs_by_full_name[name];
194         continue;
195       }
196       *out_error = "Invalid message type " + name;
197       return nullptr;
198     }
199 
200     message = msg->second.first();
201     *msg_id = msg->second.second;
202     break;
203   }
204 
205   if (!details::parse_mesage(data, name, message, out_error,
206                              allow_partial_messaged)) {
207     delete message;
208     return nullptr;
209   }
210 
211   return message;
212 }
213 
get_server_message_from_text(const std::string & name,const std::string & data,xcl::XProtocol::Server_message_type_id * msg_id,std::string * out_error,const bool allow_partial_messaged)214 Message *get_server_message_from_text(
215     const std::string &name, const std::string &data,
216     xcl::XProtocol::Server_message_type_id *msg_id, std::string *out_error,
217     const bool allow_partial_messaged) {
218   std::string find_by = name;
219   Message *message;
220 
221   while (true) {
222     auto msg = server_msgs_by_name.find(find_by);
223 
224     if (msg == server_msgs_by_name.end()) {
225       if (server_msgs_by_full_name.count(name) &&
226           find_by != server_msgs_by_full_name[name]) {
227         find_by = server_msgs_by_full_name[name];
228         continue;
229       }
230       *out_error = "Invalid message type " + name;
231       return nullptr;
232     }
233 
234     message = msg->second.first();
235     *msg_id = msg->second.second;
236     break;
237   }
238 
239   if (!details::parse_mesage(data, name, message, out_error,
240                              allow_partial_messaged)) {
241     delete message;
242     return nullptr;
243   }
244 
245   if (Mysqlx::ServerMessages::NOTICE == *msg_id) {
246     auto notice = reinterpret_cast<Mysqlx::Notice::Frame *>(message);
247 
248     std::string out_payload;
249     if (!details::get_notice_payload_from_text(
250             static_cast<Mysqlx::Notice::Frame_Type>(notice->type()),
251             notice->payload(), &out_payload, allow_partial_messaged)) {
252       *out_error = "Invalid notice payload: " + notice->payload();
253       return nullptr;
254     }
255 
256     notice->set_payload(out_payload);
257   }
258 
259   return message;
260 }
261 
262 }  // namespace parser
263