1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef THRIFT_TRANSPORT_THEADER_H_
18 #define THRIFT_TRANSPORT_THEADER_H_ 1
19 
#include <bitset>
#include <chrono>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>

#include <folly/Optional.h>
#include <folly/String.h>
#include <folly/Utility.h>
#include <folly/portability/Unistd.h>
#include <thrift/lib/cpp/concurrency/Thread.h>
#include <thrift/lib/cpp/protocol/TProtocolTypes.h>
#include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
36 
37 // These are local to this build and never appear on the wire.
38 enum CLIENT_TYPE {
39   THRIFT_HEADER_CLIENT_TYPE = 0,
40   THRIFT_FRAMED_DEPRECATED = 1,
41   THRIFT_UNFRAMED_DEPRECATED = 2,
42   THRIFT_HTTP_SERVER_TYPE = 3,
43   THRIFT_HTTP_CLIENT_TYPE = 4,
44   THRIFT_FRAMED_COMPACT = 5,
45   THRIFT_ROCKET_CLIENT_TYPE = 6,
46   THRIFT_HTTP_GET_CLIENT_TYPE = 7,
47   THRIFT_UNFRAMED_COMPACT_DEPRECATED = 8,
48   THRIFT_HTTP2_CLIENT_TYPE = 9,
49   // This MUST always be last and have the largest value!
50   THRIFT_UNKNOWN_CLIENT_TYPE = 10,
51 };
52 
53 #define CLIENT_TYPES_LEN THRIFT_UNKNOWN_CLIENT_TYPE
54 
55 // These appear on the wire.
56 enum HEADER_FLAGS {
57   HEADER_FLAG_SUPPORT_OUT_OF_ORDER = 0x01,
58   // Set for reverse messages (server->client requests, client->server replies)
59   HEADER_FLAG_DUPLEX_REVERSE = 0x08,
60 };
61 
62 namespace folly {
63 class IOBuf;
64 class IOBufQueue;
65 } // namespace folly
66 
67 namespace apache {
68 namespace thrift {
69 namespace util {
70 class THttpClientParser;
71 }
72 } // namespace thrift
73 } // namespace apache
74 
75 namespace apache {
76 namespace thrift {
77 namespace transport {
78 
namespace detail {
/**
 * This is a helper class to facilitate transport upgrade from header to rocket
 * for non-TLS services. The socket stored in header channel is a shared_ptr
 * while the socket in rocket is a unique_ptr. The goal is to transfer the
 * socket from header to rocket, by managing the lifetime using this custom
 * deleter, which makes it possible for the unique_ptr stolen by stealPtr()
 * to outlive the shared_ptr.
 *
 * The wrapped Deleter is forwarded (moved for value deleters, re-bound for
 * reference deleters) instead of copied, so move-only deleters are supported
 * and stateful deleters are not duplicated.
 */
template <typename T, typename Deleter>
class ReleaseDeleter {
 public:
  /**
   * Takes over both the object and the deleter held by uPtr.
   */
  explicit ReleaseDeleter(std::unique_ptr<T, Deleter> uPtr)
      : ptr_(uPtr.release()),
        deleter_(std::forward<Deleter>(uPtr.get_deleter())) {}

  /**
   * Destroys the managed object with the wrapped deleter, unless it has
   * already been handed out through stealPtr() (in which case this is a
   * no-op).
   *
   * @param obj pointer passed by the owning shared_ptr; only used for the
   *            debug consistency check.
   */
  void operator()(T* obj) {
    (void)obj; // only referenced by DCHECK, which may compile to nothing
    // Clear ptr_ before deleting so a repeated invocation on this instance
    // cannot delete the object a second time.
    T* const owned = std::exchange(ptr_, nullptr);
    if (owned) {
      DCHECK(obj == owned);
      deleter_(owned);
    }
  }

  /**
   * Steal the unique_ptr stored in this deleter.
   *
   * After this call the deleter no longer owns anything: operator() becomes
   * a no-op and the wrapped deleter has been transferred to the returned
   * unique_ptr. Must be called at most once (checked with DCHECK).
   */
  std::unique_ptr<T, Deleter> stealPtr() {
    DCHECK(ptr_);
    return std::unique_ptr<T, Deleter>(
        std::exchange(ptr_, nullptr), std::forward<Deleter>(deleter_));
  }

 private:
  T* ptr_; // owned object; nullptr once stolen or deleted
  Deleter deleter_; // deleter taken from the original unique_ptr
};

/**
 * Converts a unique_ptr into a shared_ptr whose deleter is a ReleaseDeleter,
 * so that ownership can later be taken back out of the shared_ptr via
 * std::get_deleter<ReleaseDeleter<T, Deleter>>(ptr)->stealPtr().
 *
 * @param uPtr the unique_ptr to convert; ownership is transferred.
 * @return a shared_ptr pointing at the same object.
 */
template <typename T, typename Deleter>
std::shared_ptr<T> convertToShared(std::unique_ptr<T, Deleter> uPtr) {
  // Read the raw pointer first: constructing the ReleaseDeleter releases uPtr.
  auto ptr = uPtr.get();
  auto deleter = ReleaseDeleter<T, Deleter>(std::move(uPtr));
  // Move (not copy) the deleter into the shared_ptr control block.
  return std::shared_ptr<T>(ptr, std::move(deleter));
}
} // namespace detail
124 
125 using apache::thrift::protocol::T_BINARY_PROTOCOL;
126 using apache::thrift::protocol::T_COMPACT_PROTOCOL;
127 
128 /**
129  * Class that will take an IOBuf and wrap it in some thrift headers.
130  * see thrift/doc/HeaderFormat.txt for details.
131  *
132  * Supports transforms: zlib snappy zstd
133  * Supports headers: http-style key/value per request and per connection
134  * other: Protocol Id and seq ID in header.
135  *
136  * Backwards compatible with TFramed format, and unframed format, assuming
137  * your server transport is compatible (some server types require 4-byte size
138  * at the start).
139  */
140 class THeader final {
141  public:
142   enum {
143     ALLOW_BIG_FRAMES = 1 << 0,
144   };
145 
146   explicit THeader(int options = 0);
147 
setClientType(CLIENT_TYPE ct)148   void setClientType(CLIENT_TYPE ct) { this->clientType_ = ct; }
149   // Force using specified client type when using legacy client types
150   // i.e. sniffing out client type is disabled.
forceClientType(bool enable)151   void forceClientType(bool enable) { forceClientType_ = enable; }
getClientType()152   CLIENT_TYPE getClientType() const { return clientType_; }
153 
getProtocolId()154   uint16_t getProtocolId() const { return protoId_; }
setProtocolId(uint16_t protoId)155   void setProtocolId(uint16_t protoId) { this->protoId_ = protoId; }
156 
157   int8_t getProtocolVersion() const;
setProtocolVersion(uint8_t ver)158   void setProtocolVersion(uint8_t ver) { this->protoVersion_ = ver; }
159 
160   void resetProtocol();
161 
getFlags()162   uint16_t getFlags() const { return flags_; }
setFlags(uint16_t flags)163   void setFlags(uint16_t flags) { flags_ = flags; }
164 
165   // Info headers
166   typedef std::map<std::string, std::string> StringToStringMap;
167 
168   /**
169    * We know we got a packet in header format here, try to parse the header
170    *
171    * @param IObuf of the header + data.  Untransforms the data as appropriate.
172    * @return Just the data section in an IOBuf
173    */
174   std::unique_ptr<folly::IOBuf> readHeaderFormat(
175       std::unique_ptr<folly::IOBuf>, StringToStringMap& persistentReadHeaders);
176 
177   /**
178    * Untransform the data based on the received header flags
179    * On conclusion of function, setReadBuffer is called with the
180    * untransformed data.
181    *
182    * @param IOBuf input data section
183    * @return IOBuf output data section
184    */
185   static std::unique_ptr<folly::IOBuf> untransform(
186       std::unique_ptr<folly::IOBuf>, std::vector<uint16_t>& readTrans);
187 
188   /**
189    * Transform the data based on our write transform flags
190    * At conclusion of function the write buffer is set to the
191    * transformed data.
192    *
193    * @param IOBuf to transform.  Returns transformed IOBuf (or chain)
194    * @return transformed data IOBuf
195    */
196   static std::unique_ptr<folly::IOBuf> transform(
197       std::unique_ptr<folly::IOBuf>,
198       std::vector<uint16_t>& writeTrans,
199       size_t minCompressBytes = 0);
200 
201   /**
202    * Copy metadata, but not headers.
203    */
204   void copyMetadataFrom(const THeader& src);
205 
getNumTransforms(const std::vector<uint16_t> & transforms)206   static uint16_t getNumTransforms(const std::vector<uint16_t>& transforms) {
207     return folly::to_narrow(transforms.size());
208   }
209 
210   void setTransform(uint16_t transId);
211   void setReadTransform(uint16_t transId);
setTransforms(const std::vector<uint16_t> & trans)212   void setTransforms(const std::vector<uint16_t>& trans) {
213     writeTrans_ = trans;
214   }
getTransforms()215   const std::vector<uint16_t>& getTransforms() const { return readTrans_; }
getWriteTransforms()216   std::vector<uint16_t>& getWriteTransforms() { return writeTrans_; }
217 
218   void setClientMetadata(const ClientMetadata& clientMetadata);
219   std::optional<ClientMetadata> extractClientMetadata();
220 
221   // these work with write headers
222   void setHeader(const std::string& key, const std::string& value);
223   void setHeader(const std::string& key, std::string&& value);
224   void setHeader(
225       const char* key, size_t keyLength, const char* value, size_t valueLength);
226   void setHeaders(StringToStringMap&&);
227   void clearHeaders();
228   bool isWriteHeadersEmpty() const;
229   StringToStringMap& mutableWriteHeaders();
230   StringToStringMap releaseWriteHeaders();
231   StringToStringMap extractAllWriteHeaders();
232   const StringToStringMap& getWriteHeaders() const;
233 
234   // these work with read headers
235   void setReadHeaders(StringToStringMap&&);
236   void setReadHeader(const std::string& key, std::string&& value);
237   void eraseReadHeader(const std::string& key);
238   const StringToStringMap& getHeaders() const;
239   StringToStringMap releaseHeaders();
240 
setExtraWriteHeaders(StringToStringMap * extraWriteHeaders)241   void setExtraWriteHeaders(StringToStringMap* extraWriteHeaders) {
242     extraWriteHeaders_ = extraWriteHeaders;
243   }
getExtraWriteHeaders()244   StringToStringMap* getExtraWriteHeaders() const { return extraWriteHeaders_; }
245 
246   std::string getPeerIdentity() const;
247   void setIdentity(const std::string& identity);
248 
249   // accessors for seqId
getSequenceNumber()250   uint32_t getSequenceNumber() const { return seqId_; }
setSequenceNumber(uint32_t sid)251   void setSequenceNumber(uint32_t sid) { this->seqId_ = sid; }
252 
253   enum TRANSFORMS {
254     NONE = 0x00,
255     ZLIB_TRANSFORM = 0x01,
256     // HMAC_TRANSFORM = 0x02, Deprecated and no longer supported
257     // SNAPPY_TRANSFORM = 0x03, Deprecated and no longer supported
258     // QLZ_TRANSFORM = 0x04, Deprecated and no longer supported
259     ZSTD_TRANSFORM = 0x05,
260 
261     // DO NOT USE. Sentinel value for enum count. Always keep as last value.
262     TRANSFORM_LAST_FIELD = 0x06,
263   };
264 
265   /* IOBuf interface */
266 
267   /**
268    * Adds the header based on the type of transport:
269    * unframed - does nothing.
270    * framed - prepends frame size
271    * header - prepends header, optionally appends mac
272    * http - only supported for sync case, prepends http header.
273    *
274    * @return IOBuf chain with header _and_ data.  Data is not copied
275    */
276   std::unique_ptr<folly::IOBuf> addHeader(
277       std::unique_ptr<folly::IOBuf>,
278       StringToStringMap& persistentWriteHeaders,
279       bool transform = true);
280   /**
281    * Given an IOBuf Chain, remove the header.  Supports unframed (sync
282    * only), framed, header, and http (sync case only).  This doesn't
283    * check if the client type implied by the header is valid.
284    * isSupportedClient() or checkSupportedClient() should be called
285    * after.
286    *
287    * @param IOBufQueue - queue to try to read message from.
288    *
289    * @param needed - if the return is nullptr (i.e. we didn't read a full
290    *                 message), needed is set to the number of bytes needed
291    *                 before you should call removeHeader again.
292    *
293    * @return IOBuf - the message chain.  May be shared, may be chained.
294    *                 If nullptr, we didn't get enough data for a whole message,
295    *                 call removeHeader again after reading needed more bytes.
296    */
297   std::unique_ptr<folly::IOBuf> removeHeader(
298       folly::IOBufQueue*,
299       size_t& needed,
300       StringToStringMap& persistentReadHeaders);
301 
setDesiredCompressionConfig(CompressionConfig compressionConfig)302   void setDesiredCompressionConfig(CompressionConfig compressionConfig) {
303     compressionConfig_ = compressionConfig;
304   }
305 
getDesiredCompressionConfig()306   folly::Optional<CompressionConfig> getDesiredCompressionConfig() const {
307     return compressionConfig_;
308   }
309 
setCrc32c(folly::Optional<uint32_t> crc32c)310   void setCrc32c(folly::Optional<uint32_t> crc32c) { crc32c_ = crc32c; }
311 
getCrc32c()312   folly::Optional<uint32_t> getCrc32c() const { return crc32c_; }
313 
setServerLoad(folly::Optional<int64_t> load)314   void setServerLoad(folly::Optional<int64_t> load) { serverLoad_ = load; }
315 
getServerLoad()316   folly::Optional<int64_t> getServerLoad() const { return serverLoad_; }
317 
318   apache::thrift::concurrency::PRIORITY getCallPriority() const;
319 
320   std::chrono::milliseconds getTimeoutFromHeader(
321       const std::string& header) const;
322 
323   std::chrono::milliseconds getClientTimeout() const;
324 
325   std::chrono::milliseconds getClientQueueTimeout() const;
326 
327   // Overall queue timeout (either set by client or server override)
328   // This is set on the server side in responses.
329   folly::Optional<std::chrono::milliseconds> getServerQueueTimeout() const;
330 
331   // This is populated by the server and reflects the time the request spent
332   // in the queue prior to processing.
333   folly::Optional<std::chrono::milliseconds> getProcessDelay() const;
334 
335   const folly::Optional<std::string>& clientId() const;
336   const folly::Optional<std::string>& serviceTraceMeta() const;
337 
338   void setHttpClientParser(
339       std::shared_ptr<apache::thrift::util::THttpClientParser>);
340 
341   void setClientTimeout(std::chrono::milliseconds timeout);
342   void setClientQueueTimeout(std::chrono::milliseconds timeout);
343   void setServerQueueTimeout(std::chrono::milliseconds timeout);
344   void setProcessDelay(std::chrono::milliseconds timeQueued);
345   void setCallPriority(apache::thrift::concurrency::PRIORITY priority);
346   void setClientId(const std::string& clientId);
347   void setServiceTraceMeta(const std::string& serviceTraceMeta);
348 
349   // Utility method for converting TRANSFORMS enum to string
350   static const folly::StringPiece getStringTransform(
351       const TRANSFORMS transform);
352 
353   static CLIENT_TYPE tryGetClientType(const folly::IOBuf& data);
354 
setRoutingData(std::shared_ptr<void> data)355   void setRoutingData(std::shared_ptr<void> data) {
356     routingData_ = std::move(data);
357   }
releaseRoutingData()358   std::shared_ptr<void> releaseRoutingData() { return std::move(routingData_); }
359 
360   // 0 and 16th bits must be 0 to differentiate from framed & unframed
361   static const uint32_t HEADER_MAGIC = 0x0FFF0000;
362   static const uint32_t HEADER_MASK = 0xFFFF0000;
363   static const uint32_t FLAGS_MASK = 0x0000FFFF;
364   static const uint32_t HTTP_SERVER_MAGIC = 0x504F5354; // POST
365   static const uint32_t HTTP_CLIENT_MAGIC = 0x48545450; // HTTP
366   static const uint32_t HTTP_GET_CLIENT_MAGIC = 0x47455420; // GET
367   static const uint32_t HTTP_HEAD_CLIENT_MAGIC = 0x48454144; // HEAD
368   static const uint32_t BIG_FRAME_MAGIC = 0x42494746; // BIGF
369 
370   static const uint32_t MAX_FRAME_SIZE = 0x3FFFFFFF;
371   static const std::string PRIORITY_HEADER;
372   static const std::string& CLIENT_TIMEOUT_HEADER;
373   static const std::string QUEUE_TIMEOUT_HEADER;
374   static const std::string QUERY_LOAD_HEADER;
375   static const std::string kClientId;
376   static const std::string kServiceTraceMeta;
377   static constexpr std::string_view CLIENT_METADATA_HEADER = "client_metadata";
378 
379  private:
380   static bool isFramed(CLIENT_TYPE clientType);
381 
382   // Use first 64 bits to determine client protocol
383   static folly::Optional<CLIENT_TYPE> analyzeFirst32bit(uint32_t w);
384   static CLIENT_TYPE analyzeSecond32bit(uint32_t w);
385 
386   // Calls appropriate method based on client type
387   // returns nullptr if Header of Unknown type
388   std::unique_ptr<folly::IOBuf> removeNonHeader(
389       folly::IOBufQueue* queue,
390       size_t& needed,
391       CLIENT_TYPE clientType,
392       uint32_t sz);
393 
394   template <
395       template <class BaseProt>
396       class ProtocolClass,
397       protocol::PROTOCOL_TYPES ProtocolID>
398   std::unique_ptr<folly::IOBuf> removeUnframed(
399       folly::IOBufQueue* queue, size_t& needed);
400   std::unique_ptr<folly::IOBuf> removeHttpServer(folly::IOBufQueue* queue);
401   std::unique_ptr<folly::IOBuf> removeHttpClient(
402       folly::IOBufQueue* queue, size_t& needed);
403   std::unique_ptr<folly::IOBuf> removeFramed(
404       uint32_t sz, folly::IOBufQueue* queue);
405 
406   /**
407    * Returns the maximum number of bytes that write k/v headers can take
408    */
409   size_t getMaxWriteHeadersSize(
410       const StringToStringMap& persistentWriteHeaders) const;
411 
412   /**
413    * Returns whether the 1st byte of the protocol payload should be hadled
414    * as compact framed.
415    */
416   static bool compactFramed(uint32_t magic);
417 
418   std::optional<std::string> extractHeader(std::string_view key);
419   StringToStringMap& ensureReadHeaders();
420   StringToStringMap& ensureWriteHeaders();
421 
422   // Http client parser
423   std::shared_ptr<apache::thrift::util::THttpClientParser> httpClientParser_;
424 
425   int16_t protoId_;
426   int8_t protoVersion_;
427   CLIENT_TYPE clientType_;
428   bool forceClientType_;
429   uint32_t seqId_;
430   uint16_t flags_;
431   std::string identity_;
432 
433   std::vector<uint16_t> readTrans_;
434   std::vector<uint16_t> writeTrans_;
435 
436   // Map to use for headers
437   std::optional<StringToStringMap> readHeaders_;
438   std::optional<StringToStringMap> writeHeaders_;
439 
440   // Won't be cleared when flushing
441   StringToStringMap* extraWriteHeaders_{nullptr};
442 
443   // If these values are set, they are used instead of looking inside
444   // the header map.
445   folly::Optional<std::chrono::milliseconds> clientTimeout_;
446   folly::Optional<std::chrono::milliseconds> queueTimeout_;
447   folly::Optional<std::chrono::milliseconds> processDelay_;
448   folly::Optional<std::chrono::milliseconds> serverQueueTimeout_;
449   folly::Optional<apache::thrift::concurrency::PRIORITY> priority_;
450   folly::Optional<std::string> clientId_;
451   folly::Optional<std::string> serviceTraceMeta_;
452 
453   static const std::string IDENTITY_HEADER;
454   static const std::string ID_VERSION_HEADER;
455   static const std::string ID_VERSION;
456 
457   bool allowBigFrames_;
458   folly::Optional<CompressionConfig> compressionConfig_;
459 
460   std::shared_ptr<void> routingData_;
461 
462   struct infoIdType {
463     enum idType {
464       // start at 1 to avoid confusing header padding for an infoId
465       KEYVALUE = 1,
466       // for persistent header
467       PKEYVALUE = 2,
468       END // signal the end of infoIds we can handle
469     };
470   };
471 
472   // CRC32C of message payload for checksum.
473   folly::Optional<uint32_t> crc32c_;
474   folly::Optional<int64_t> serverLoad_;
475 };
476 
477 } // namespace transport
478 } // namespace thrift
479 } // namespace apache
480 
481 #endif // #ifndef THRIFT_TRANSPORT_THEADER_H_
482