1 /*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include "test/cpp/util/proto_file_parser.h"
20
21 #include <algorithm>
22 #include <iostream>
23 #include <sstream>
24 #include <unordered_set>
25
26 #include "absl/memory/memory.h"
27
28 #include <grpcpp/support/config.h>
29
30 namespace grpc {
31 namespace testing {
32 namespace {
33
// Returns true when the user-supplied `input` names the method whose
// fully-qualified name is `full_name` (e.g. "pkg.Service.Method").
//
// `input` may use '/' or '.' as separators and may carry one leading separator,
// so "Method", "Service.Method", "pkg.Service/Method" and
// "/pkg.Service/Method" all match. Leading name components may be omitted, but
// the match must cover whole components: "Method" matches
// "pkg.Service.Method" but not "pkg.Service.OtherMethod". An empty input
// matches nothing.
bool MethodNameMatch(const std::string& full_name, const std::string& input) {
  std::string clean_input = input;
  std::replace(clean_input.begin(), clean_input.end(), '/', '.');
  // A leading separator (as in "/pkg.Service/Method") is not part of any name.
  if (!clean_input.empty() && clean_input.front() == '.') {
    clean_input.erase(0, 1);
  }
  if (clean_input.empty() || clean_input.size() > full_name.size()) {
    return false;
  }
  const std::string::size_type offset = full_name.size() - clean_input.size();
  if (full_name.compare(offset, clean_input.size(), clean_input) != 0) {
    return false;
  }
  // Only accept suffixes that begin at a name-component boundary; otherwise
  // "Echo" would also match "pkg.Service.RequestEcho".
  return offset == 0 || full_name[offset - 1] == '.';
}
44 } // namespace
45
46 class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector {
47 public:
ErrorPrinter(ProtoFileParser * parser)48 explicit ErrorPrinter(ProtoFileParser* parser) : parser_(parser) {}
49
AddError(const std::string & filename,int line,int column,const std::string & message)50 void AddError(const std::string& filename, int line, int column,
51 const std::string& message) override {
52 std::ostringstream oss;
53 oss << "error " << filename << " " << line << " " << column << " "
54 << message << "\n";
55 parser_->LogError(oss.str());
56 }
57
AddWarning(const std::string & filename,int line,int column,const std::string & message)58 void AddWarning(const std::string& filename, int line, int column,
59 const std::string& message) override {
60 std::cerr << "warning " << filename << " " << line << " " << column << " "
61 << message << std::endl;
62 }
63
64 private:
65 ProtoFileParser* parser_; // not owned
66 };
67
ProtoFileParser(const std::shared_ptr<grpc::Channel> & channel,const std::string & proto_path,const std::string & protofiles)68 ProtoFileParser::ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel,
69 const std::string& proto_path,
70 const std::string& protofiles)
71 : has_error_(false),
72 dynamic_factory_(new protobuf::DynamicMessageFactory()) {
73 std::vector<std::string> service_list;
74 if (channel) {
75 reflection_db_ =
76 absl::make_unique<grpc::ProtoReflectionDescriptorDatabase>(channel);
77 reflection_db_->GetServices(&service_list);
78 }
79
80 std::unordered_set<std::string> known_services;
81 if (!protofiles.empty()) {
82 source_tree_.MapPath("", proto_path);
83 error_printer_ = absl::make_unique<ErrorPrinter>(this);
84 importer_ = absl::make_unique<protobuf::compiler::Importer>(
85 &source_tree_, error_printer_.get());
86
87 std::string file_name;
88 std::stringstream ss(protofiles);
89 while (std::getline(ss, file_name, ',')) {
90 const auto* file_desc = importer_->Import(file_name);
91 if (file_desc) {
92 for (int i = 0; i < file_desc->service_count(); i++) {
93 service_desc_list_.push_back(file_desc->service(i));
94 known_services.insert(file_desc->service(i)->full_name());
95 }
96 } else {
97 std::cerr << file_name << " not found" << std::endl;
98 }
99 }
100
101 file_db_ =
102 absl::make_unique<protobuf::DescriptorPoolDatabase>(*importer_->pool());
103 }
104
105 if (!reflection_db_ && !file_db_) {
106 LogError("No available proto database");
107 return;
108 }
109
110 if (!reflection_db_) {
111 desc_db_ = std::move(file_db_);
112 } else if (!file_db_) {
113 desc_db_ = std::move(reflection_db_);
114 } else {
115 desc_db_ = absl::make_unique<protobuf::MergedDescriptorDatabase>(
116 reflection_db_.get(), file_db_.get());
117 }
118
119 desc_pool_ = absl::make_unique<protobuf::DescriptorPool>(desc_db_.get());
120
121 for (auto it = service_list.begin(); it != service_list.end(); it++) {
122 if (known_services.find(*it) == known_services.end()) {
123 if (const protobuf::ServiceDescriptor* service_desc =
124 desc_pool_->FindServiceByName(*it)) {
125 service_desc_list_.push_back(service_desc);
126 known_services.insert(*it);
127 }
128 }
129 }
130 }
131
~ProtoFileParser()132 ProtoFileParser::~ProtoFileParser() {}
133
GetFullMethodName(const std::string & method)134 std::string ProtoFileParser::GetFullMethodName(const std::string& method) {
135 has_error_ = false;
136
137 if (known_methods_.find(method) != known_methods_.end()) {
138 return known_methods_[method];
139 }
140
141 const protobuf::MethodDescriptor* method_descriptor = nullptr;
142 for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
143 it++) {
144 const auto* service_desc = *it;
145 for (int j = 0; j < service_desc->method_count(); j++) {
146 const auto* method_desc = service_desc->method(j);
147 if (MethodNameMatch(method_desc->full_name(), method)) {
148 if (method_descriptor) {
149 std::ostringstream error_stream;
150 error_stream << "Ambiguous method names: ";
151 error_stream << method_descriptor->full_name() << " ";
152 error_stream << method_desc->full_name();
153 LogError(error_stream.str());
154 }
155 method_descriptor = method_desc;
156 }
157 }
158 }
159 if (!method_descriptor) {
160 LogError("Method name not found");
161 }
162 if (has_error_) {
163 return "";
164 }
165
166 known_methods_[method] = method_descriptor->full_name();
167
168 return method_descriptor->full_name();
169 }
170
GetFormattedMethodName(const std::string & method)171 std::string ProtoFileParser::GetFormattedMethodName(const std::string& method) {
172 has_error_ = false;
173 std::string formatted_method_name = GetFullMethodName(method);
174 if (has_error_) {
175 return "";
176 }
177 size_t last_dot = formatted_method_name.find_last_of('.');
178 if (last_dot != std::string::npos) {
179 formatted_method_name[last_dot] = '/';
180 }
181 formatted_method_name.insert(formatted_method_name.begin(), '/');
182 return formatted_method_name;
183 }
184
GetMessageTypeFromMethod(const std::string & method,bool is_request)185 std::string ProtoFileParser::GetMessageTypeFromMethod(const std::string& method,
186 bool is_request) {
187 has_error_ = false;
188 std::string full_method_name = GetFullMethodName(method);
189 if (has_error_) {
190 return "";
191 }
192 const protobuf::MethodDescriptor* method_desc =
193 desc_pool_->FindMethodByName(full_method_name);
194 if (!method_desc) {
195 LogError("Method not found");
196 return "";
197 }
198
199 return is_request ? method_desc->input_type()->full_name()
200 : method_desc->output_type()->full_name();
201 }
202
IsStreaming(const std::string & method,bool is_request)203 bool ProtoFileParser::IsStreaming(const std::string& method, bool is_request) {
204 has_error_ = false;
205
206 std::string full_method_name = GetFullMethodName(method);
207 if (has_error_) {
208 return false;
209 }
210
211 const protobuf::MethodDescriptor* method_desc =
212 desc_pool_->FindMethodByName(full_method_name);
213 if (!method_desc) {
214 LogError("Method not found");
215 return false;
216 }
217
218 return is_request ? method_desc->client_streaming()
219 : method_desc->server_streaming();
220 }
221
GetSerializedProtoFromMethod(const std::string & method,const std::string & formatted_proto,bool is_request,bool is_json_format)222 std::string ProtoFileParser::GetSerializedProtoFromMethod(
223 const std::string& method, const std::string& formatted_proto,
224 bool is_request, bool is_json_format) {
225 has_error_ = false;
226 std::string message_type_name = GetMessageTypeFromMethod(method, is_request);
227 if (has_error_) {
228 return "";
229 }
230 return GetSerializedProtoFromMessageType(message_type_name, formatted_proto,
231 is_json_format);
232 }
233
GetFormattedStringFromMethod(const std::string & method,const std::string & serialized_proto,bool is_request,bool is_json_format)234 std::string ProtoFileParser::GetFormattedStringFromMethod(
235 const std::string& method, const std::string& serialized_proto,
236 bool is_request, bool is_json_format) {
237 has_error_ = false;
238 std::string message_type_name = GetMessageTypeFromMethod(method, is_request);
239 if (has_error_) {
240 return "";
241 }
242 return GetFormattedStringFromMessageType(message_type_name, serialized_proto,
243 is_json_format);
244 }
245
GetSerializedProtoFromMessageType(const std::string & message_type_name,const std::string & formatted_proto,bool is_json_format)246 std::string ProtoFileParser::GetSerializedProtoFromMessageType(
247 const std::string& message_type_name, const std::string& formatted_proto,
248 bool is_json_format) {
249 has_error_ = false;
250 std::string serialized;
251 const protobuf::Descriptor* desc =
252 desc_pool_->FindMessageTypeByName(message_type_name);
253 if (!desc) {
254 LogError("Message type not found");
255 return "";
256 }
257 std::unique_ptr<grpc::protobuf::Message> msg(
258 dynamic_factory_->GetPrototype(desc)->New());
259
260 bool ok;
261 if (is_json_format) {
262 ok = grpc::protobuf::json::JsonStringToMessage(formatted_proto, msg.get())
263 .ok();
264 if (!ok) {
265 LogError("Failed to convert json format to proto.");
266 return "";
267 }
268 } else {
269 ok = protobuf::TextFormat::ParseFromString(formatted_proto, msg.get());
270 if (!ok) {
271 LogError("Failed to convert text format to proto.");
272 return "";
273 }
274 }
275
276 ok = msg->SerializeToString(&serialized);
277 if (!ok) {
278 LogError("Failed to serialize proto.");
279 return "";
280 }
281 return serialized;
282 }
283
GetFormattedStringFromMessageType(const std::string & message_type_name,const std::string & serialized_proto,bool is_json_format)284 std::string ProtoFileParser::GetFormattedStringFromMessageType(
285 const std::string& message_type_name, const std::string& serialized_proto,
286 bool is_json_format) {
287 has_error_ = false;
288 const protobuf::Descriptor* desc =
289 desc_pool_->FindMessageTypeByName(message_type_name);
290 if (!desc) {
291 LogError("Message type not found");
292 return "";
293 }
294 std::unique_ptr<grpc::protobuf::Message> msg(
295 dynamic_factory_->GetPrototype(desc)->New());
296 if (!msg->ParseFromString(serialized_proto)) {
297 LogError("Failed to deserialize proto.");
298 return "";
299 }
300 std::string formatted_string;
301
302 if (is_json_format) {
303 grpc::protobuf::json::JsonPrintOptions jsonPrintOptions;
304 jsonPrintOptions.add_whitespace = true;
305 if (!grpc::protobuf::json::MessageToJsonString(*msg, &formatted_string,
306 jsonPrintOptions)
307 .ok()) {
308 LogError("Failed to print proto message to json format");
309 return "";
310 }
311 } else {
312 if (!protobuf::TextFormat::PrintToString(*msg, &formatted_string)) {
313 LogError("Failed to print proto message to text format");
314 return "";
315 }
316 }
317 return formatted_string;
318 }
319
LogError(const std::string & error_msg)320 void ProtoFileParser::LogError(const std::string& error_msg) {
321 if (!error_msg.empty()) {
322 std::cerr << error_msg << std::endl;
323 }
324 has_error_ = true;
325 }
326
327 } // namespace testing
328 } // namespace grpc
329