1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <folly/Demangle.h>
18 #include <folly/Portability.h>
19 
20 #include <thrift/lib/cpp2/GeneratedCodeHelper.h>
21 #include <thrift/lib/cpp2/protocol/BinaryProtocol.h>
22 #include <thrift/lib/cpp2/protocol/CompactProtocol.h>
23 #include <thrift/lib/cpp2/protocol/Protocol.h>
24 
25 using namespace std;
26 using namespace folly;
27 using namespace apache::thrift;
28 using namespace apache::thrift::protocol;
29 using namespace apache::thrift::transport;
30 
31 namespace apache {
32 namespace thrift {
33 
34 namespace detail {
35 
THRIFT_PLUGGABLE_FUNC_REGISTER(bool,includeInRecentRequestsCount,const std::string_view)36 THRIFT_PLUGGABLE_FUNC_REGISTER(
37     bool, includeInRecentRequestsCount, const std::string_view /*methodName*/) {
38   // Users of the module will override the behavior
39   return true;
40 }
41 
42 } // namespace detail
43 
44 namespace detail {
45 namespace ac {
46 
throw_app_exn(char const * const msg)47 [[noreturn]] void throw_app_exn(char const* const msg) {
48   throw TApplicationException(msg);
49 }
50 } // namespace ac
51 } // namespace detail
52 
53 namespace detail {
54 namespace ap {
55 
56 template <typename ProtocolReader, typename ProtocolWriter>
write_exn(bool includeEnvelope,const char * method,ProtocolWriter * prot,int32_t protoSeqId,ContextStack * ctx,const TApplicationException & x)57 std::unique_ptr<folly::IOBuf> helper<ProtocolReader, ProtocolWriter>::write_exn(
58     bool includeEnvelope,
59     const char* method,
60     ProtocolWriter* prot,
61     int32_t protoSeqId,
62     ContextStack* ctx,
63     const TApplicationException& x) {
64   IOBufQueue queue(IOBufQueue::cacheChainLength());
65   size_t bufSize =
66       apache::thrift::detail::serializedExceptionBodySizeZC(prot, &x);
67   bufSize += prot->serializedMessageSize(method);
68   prot->setOutput(&queue, bufSize);
69   if (ctx) {
70     ctx->handlerErrorWrapped(exception_wrapper(x));
71   }
72   if (includeEnvelope) {
73     prot->writeMessageBegin(method, MessageType::T_EXCEPTION, protoSeqId);
74   }
75   apache::thrift::detail::serializeExceptionBody(prot, &x);
76   if (includeEnvelope) {
77     prot->writeMessageEnd();
78   }
79   return std::move(queue).move();
80 }
81 
82 template <typename ProtocolReader, typename ProtocolWriter>
process_exn(const char * func,const TApplicationException::TApplicationExceptionType type,const string & msg,ResponseChannelRequest::UniquePtr req,Cpp2RequestContext * ctx,EventBase * eb,int32_t protoSeqId)83 void helper<ProtocolReader, ProtocolWriter>::process_exn(
84     const char* func,
85     const TApplicationException::TApplicationExceptionType type,
86     const string& msg,
87     ResponseChannelRequest::UniquePtr req,
88     Cpp2RequestContext* ctx,
89     EventBase* eb,
90     int32_t protoSeqId) {
91   ProtocolWriter oprot;
92   if (req) {
93     LOG(ERROR) << msg << " in function " << func;
94     TApplicationException x(type, msg);
95     auto payload = THeader::transform(
96         helper_w<ProtocolWriter>::write_exn(
97             req->includeEnvelope(), func, &oprot, protoSeqId, nullptr, x),
98         ctx->getHeader()->getWriteTransforms());
99     eb->runInEventBaseThread(
100         [payload = move(payload), request = move(req)]() mutable {
101           if (request->isStream()) {
102             request->sendStreamReply(
103                 ResponsePayload::create(std::move(payload)),
104                 detail::ServerStreamFactory{nullptr});
105           } else if (request->isSink()) {
106 #if FOLLY_HAS_COROUTINES
107             request->sendSinkReply(
108                 ResponsePayload::create(std::move(payload)),
109                 detail::SinkConsumerImpl{});
110 #else
111             DCHECK(false);
112 #endif
113           } else {
114             request->sendReply(ResponsePayload::create(std::move(payload)));
115           }
116         });
117   } else {
118     LOG(ERROR) << msg << " in oneway function " << func;
119   }
120 }
121 
// Explicit instantiations of `helper` for the Binary and Compact protocol
// pairs. These emit its out-of-line members (write_exn, process_exn) in this
// translation unit. They sit after those member definitions so the members
// are defined at the point of instantiation.
template struct helper<BinaryProtocolReader, BinaryProtocolWriter>;
template struct helper<CompactProtocolReader, CompactProtocolWriter>;
124 
125 template <typename ProtocolReader>
setupRequestContextWithMessageBegin(const MessageBegin::Metadata & msgBegin,ResponseChannelRequest::UniquePtr & req,Cpp2RequestContext * ctx,folly::EventBase * eb)126 static bool setupRequestContextWithMessageBegin(
127     const MessageBegin::Metadata& msgBegin,
128     ResponseChannelRequest::UniquePtr& req,
129     Cpp2RequestContext* ctx,
130     folly::EventBase* eb) {
131   using h = helper_r<ProtocolReader>;
132   const char* fn = "process";
133   if (!msgBegin.isValid) {
134     LOG(ERROR) << "received invalid message from client: "
135                << msgBegin.errMessage;
136     auto type = TApplicationException::TApplicationExceptionType::UNKNOWN;
137     const char* msg = "invalid message from client";
138     h::process_exn(fn, type, msg, std::move(req), ctx, eb, msgBegin.seqId);
139     return false;
140   }
141   if (msgBegin.msgType != MessageType::T_CALL &&
142       msgBegin.msgType != MessageType::T_ONEWAY) {
143     LOG(ERROR) << "received invalid message of type "
144                << folly::to_underlying(msgBegin.msgType);
145     auto type =
146         TApplicationException::TApplicationExceptionType::INVALID_MESSAGE_TYPE;
147     const char* msg = "invalid message arguments";
148     h::process_exn(fn, type, msg, std::move(req), ctx, eb, msgBegin.seqId);
149     return false;
150   }
151 
152   ctx->setProtoSeqId(msgBegin.seqId);
153   return true;
154 }
155 
setupRequestContextWithMessageBegin(const MessageBegin::Metadata & msgBegin,protocol::PROTOCOL_TYPES protType,ResponseChannelRequest::UniquePtr & req,Cpp2RequestContext * ctx,folly::EventBase * eb)156 bool setupRequestContextWithMessageBegin(
157     const MessageBegin::Metadata& msgBegin,
158     protocol::PROTOCOL_TYPES protType,
159     ResponseChannelRequest::UniquePtr& req,
160     Cpp2RequestContext* ctx,
161     folly::EventBase* eb) {
162   switch (protType) {
163     case protocol::T_BINARY_PROTOCOL:
164       return setupRequestContextWithMessageBegin<BinaryProtocolReader>(
165           msgBegin, req, ctx, eb);
166     case protocol::T_COMPACT_PROTOCOL:
167       return setupRequestContextWithMessageBegin<CompactProtocolReader>(
168           msgBegin, req, ctx, eb);
169     default:
170       LOG(ERROR) << "invalid protType: " << folly::to_underlying(protType);
171       return false;
172   }
173 }
174 
deserializeMessageBegin(const folly::IOBuf & buf,protocol::PROTOCOL_TYPES protType)175 MessageBegin deserializeMessageBegin(
176     const folly::IOBuf& buf, protocol::PROTOCOL_TYPES protType) {
177   MessageBegin msgBegin;
178   auto& meta = msgBegin.metadata;
179   try {
180     switch (protType) {
181       case protocol::T_COMPACT_PROTOCOL: {
182         CompactProtocolReader iprot;
183         iprot.setInput(&buf);
184         iprot.readMessageBegin(msgBegin.methodName, meta.msgType, meta.seqId);
185         meta.size = iprot.getCursorPosition();
186         break;
187       }
188       case protocol::T_BINARY_PROTOCOL: {
189         BinaryProtocolReader iprot;
190         iprot.setInput(&buf);
191         iprot.readMessageBegin(msgBegin.methodName, meta.msgType, meta.seqId);
192         meta.size = iprot.getCursorPosition();
193         break;
194       }
195       default:
196         break;
197     }
198   } catch (const TException& ex) {
199     meta.isValid = false;
200     meta.errMessage = ex.what();
201     LOG(ERROR) << "received invalid message from client: " << ex.what();
202   }
203   return msgBegin;
204 }
205 } // namespace ap
206 } // namespace detail
207 
208 namespace detail {
209 namespace si {
// Returns the message used for calls to a method that has no implementation:
// "Function <methodName> is unimplemented".
std::string formatUnimplementedMethodException(std::string_view methodName) {
  std::string message("Function ");
  message += methodName;
  message += " is unimplemented";
  return message;
}
213 
create_app_exn_unimplemented(const char * name)214 TApplicationException create_app_exn_unimplemented(const char* name) {
215   return TApplicationException(formatUnimplementedMethodException(name));
216 }
217 
throw_app_exn_unimplemented(char const * const name)218 [[noreturn]] void throw_app_exn_unimplemented(char const* const name) {
219   throw create_app_exn_unimplemented(name);
220 }
221 } // namespace si
222 } // namespace detail
223 
224 namespace {
225 
226 constexpr size_t kMaxUexwSize = 1024;
227 
setUserExceptionHeader(Cpp2RequestContext & ctx,std::string exType,std::string exReason,bool setClientCode)228 void setUserExceptionHeader(
229     Cpp2RequestContext& ctx,
230     std::string exType,
231     std::string exReason,
232     bool setClientCode) {
233   auto header = ctx.getHeader();
234   if (!header) {
235     return;
236   }
237 
238   if (setClientCode) {
239     header->setHeader(std::string(detail::kHeaderEx), kAppClientErrorCode);
240   }
241 
242   header->setHeader(std::string(detail::kHeaderUex), std::move(exType));
243   header->setHeader(
244       std::string(detail::kHeaderUexw),
245       exReason.size() > kMaxUexwSize ? exReason.substr(0, kMaxUexwSize)
246                                      : std::move(exReason));
247 }
248 
249 } // namespace
250 
251 namespace util {
252 
appendExceptionToHeader(const folly::exception_wrapper & ew,Cpp2RequestContext & ctx)253 void appendExceptionToHeader(
254     const folly::exception_wrapper& ew, Cpp2RequestContext& ctx) {
255   auto* ex = ew.get_exception();
256   if (const auto* aex = dynamic_cast<const AppBaseError*>(ex)) {
257     setUserExceptionHeader(
258         ctx,
259         std::string(aex->name()),
260         std::string(aex->what()),
261         aex->isClientError());
262     return;
263   }
264 
265   const auto what = ew.what();
266   folly::StringPiece whatsp(what);
267   auto typeName = ew.class_name();
268 
269   ew.with_exception([&](const ExceptionMetadataOverrideBase& emob) {
270     if (auto type = emob.type()) {
271       typeName = folly::demangle(*type);
272     }
273   });
274 
275   whatsp.removePrefix(typeName);
276   whatsp.removePrefix(": ");
277 
278   auto exName = typeName.toStdString();
279   auto exWhat = whatsp.str();
280 
281   setUserExceptionHeader(ctx, std::move(exName), std::move(exWhat), false);
282 }
283 
toTApplicationException(const folly::exception_wrapper & ew)284 TApplicationException toTApplicationException(
285     const folly::exception_wrapper& ew) {
286   auto& ex = *ew.get_exception();
287   auto msg = folly::exceptionStr(ex).toStdString();
288 
289   if (auto* ae =
290           dynamic_cast<const AppBaseError*>(&ex)) { // customized app errors
291     return TApplicationException(
292         TApplicationException::TApplicationExceptionType::UNKNOWN, ex.what());
293   } else {
294     if (auto* te = dynamic_cast<const TApplicationException*>(&ex)) {
295       return *te;
296     } else {
297       return TApplicationException(
298           TApplicationException::TApplicationExceptionType::UNKNOWN,
299           std::move(msg));
300     }
301   }
302 }
303 
includeInRecentRequestsCount(const std::string_view methodName)304 bool includeInRecentRequestsCount(const std::string_view methodName) {
305   return apache::thrift::detail::includeInRecentRequestsCount(methodName);
306 }
307 
308 } // namespace util
309 
310 } // namespace thrift
311 } // namespace apache
312