1 //
2 //
3 // Copyright 2020 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/ext/filters/http/message_compress/message_decompress_filter.h"
22 
23 #include <assert.h>
24 #include <string.h>
25 
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 
29 #include <grpc/compression.h>
30 #include <grpc/slice_buffer.h>
31 #include <grpc/support/alloc.h>
32 #include <grpc/support/log.h>
33 
34 #include "src/core/ext/filters/message_size/message_size_filter.h"
35 #include "src/core/lib/channel/channel_args.h"
36 #include "src/core/lib/compression/algorithm_metadata.h"
37 #include "src/core/lib/compression/compression_args.h"
38 #include "src/core/lib/compression/compression_internal.h"
39 #include "src/core/lib/compression/message_compress.h"
40 #include "src/core/lib/gpr/string.h"
41 #include "src/core/lib/slice/slice_internal.h"
42 #include "src/core/lib/slice/slice_string_helpers.h"
43 
44 namespace grpc_core {
45 namespace {
46 
// Per-channel state for the message decompression filter. It holds a single
// value: the receive size limit derived from the channel args.
class ChannelData {
 public:
  // Captures the result of GetMaxRecvSizeFromChannelArgs() for the channel
  // args carried by `args`.
  explicit ChannelData(const grpc_channel_element_args* args)
      : max_recv_size_(GetMaxRecvSizeFromChannelArgs(args->channel_args)) {}

  // Returns the limit captured at construction time.
  int max_recv_size() const { return max_recv_size_; }

 private:
  // Channel-level receive size limit. CallData only enforces a limit when the
  // effective value is non-negative.
  int max_recv_size_;
};
57 
58 class CallData {
59  public:
CallData(const grpc_call_element_args & args,const ChannelData * chand)60   CallData(const grpc_call_element_args& args, const ChannelData* chand)
61       : call_combiner_(args.call_combiner),
62         max_recv_message_length_(chand->max_recv_size()) {
63     // Initialize state for recv_initial_metadata_ready callback
64     GRPC_CLOSURE_INIT(&on_recv_initial_metadata_ready_,
65                       OnRecvInitialMetadataReady, this,
66                       grpc_schedule_on_exec_ctx);
67     // Initialize state for recv_message_ready callback
68     grpc_slice_buffer_init(&recv_slices_);
69     GRPC_CLOSURE_INIT(&on_recv_message_next_done_, OnRecvMessageNextDone, this,
70                       grpc_schedule_on_exec_ctx);
71     GRPC_CLOSURE_INIT(&on_recv_message_ready_, OnRecvMessageReady, this,
72                       grpc_schedule_on_exec_ctx);
73     // Initialize state for recv_trailing_metadata_ready callback
74     GRPC_CLOSURE_INIT(&on_recv_trailing_metadata_ready_,
75                       OnRecvTrailingMetadataReady, this,
76                       grpc_schedule_on_exec_ctx);
77     const MessageSizeParsedConfig* limits =
78         MessageSizeParsedConfig::GetFromCallContext(args.context);
79     if (limits != nullptr && limits->limits().max_recv_size >= 0 &&
80         (limits->limits().max_recv_size < max_recv_message_length_ ||
81          max_recv_message_length_ < 0)) {
82       max_recv_message_length_ = limits->limits().max_recv_size;
83     }
84   }
85 
~CallData()86   ~CallData() { grpc_slice_buffer_destroy_internal(&recv_slices_); }
87 
88   void DecompressStartTransportStreamOpBatch(
89       grpc_call_element* elem, grpc_transport_stream_op_batch* batch);
90 
91  private:
92   static void OnRecvInitialMetadataReady(void* arg, grpc_error_handle error);
93 
94   // Methods for processing a receive message event
95   void MaybeResumeOnRecvMessageReady();
96   static void OnRecvMessageReady(void* arg, grpc_error_handle error);
97   static void OnRecvMessageNextDone(void* arg, grpc_error_handle error);
98   grpc_error_handle PullSliceFromRecvMessage();
99   void ContinueReadingRecvMessage();
100   void FinishRecvMessage();
101   void ContinueRecvMessageReadyCallback(grpc_error_handle error);
102 
103   // Methods for processing a recv_trailing_metadata event
104   void MaybeResumeOnRecvTrailingMetadataReady();
105   static void OnRecvTrailingMetadataReady(void* arg, grpc_error_handle error);
106 
107   CallCombiner* call_combiner_;
108   // Overall error for the call
109   grpc_error_handle error_ = GRPC_ERROR_NONE;
110   // Fields for handling recv_initial_metadata_ready callback
111   grpc_closure on_recv_initial_metadata_ready_;
112   grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
113   grpc_metadata_batch* recv_initial_metadata_ = nullptr;
114   // Fields for handling recv_message_ready callback
115   bool seen_recv_message_ready_ = false;
116   int max_recv_message_length_;
117   grpc_message_compression_algorithm algorithm_ = GRPC_MESSAGE_COMPRESS_NONE;
118   grpc_closure on_recv_message_ready_;
119   grpc_closure* original_recv_message_ready_ = nullptr;
120   grpc_closure on_recv_message_next_done_;
121   OrphanablePtr<ByteStream>* recv_message_ = nullptr;
122   // recv_slices_ holds the slices read from the original recv_message stream.
123   // It is initialized during construction and reset when a new stream is
124   // created using it.
125   grpc_slice_buffer recv_slices_;
126   std::aligned_storage<sizeof(SliceBufferByteStream),
127                        alignof(SliceBufferByteStream)>::type
128       recv_replacement_stream_;
129   // Fields for handling recv_trailing_metadata_ready callback
130   bool seen_recv_trailing_metadata_ready_ = false;
131   grpc_closure on_recv_trailing_metadata_ready_;
132   grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
133   grpc_error_handle on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE;
134 };
135 
DecodeMessageCompressionAlgorithm(grpc_mdelem md)136 grpc_message_compression_algorithm DecodeMessageCompressionAlgorithm(
137     grpc_mdelem md) {
138   grpc_message_compression_algorithm algorithm =
139       grpc_message_compression_algorithm_from_slice(GRPC_MDVALUE(md));
140   if (algorithm == GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) {
141     char* md_c_str = grpc_slice_to_c_string(GRPC_MDVALUE(md));
142     gpr_log(GPR_ERROR,
143             "Invalid incoming message compression algorithm: '%s'. "
144             "Interpreting incoming data as uncompressed.",
145             md_c_str);
146     gpr_free(md_c_str);
147     return GRPC_MESSAGE_COMPRESS_NONE;
148   }
149   return algorithm;
150 }
151 
OnRecvInitialMetadataReady(void * arg,grpc_error_handle error)152 void CallData::OnRecvInitialMetadataReady(void* arg, grpc_error_handle error) {
153   CallData* calld = static_cast<CallData*>(arg);
154   if (error == GRPC_ERROR_NONE) {
155     grpc_linked_mdelem* grpc_encoding =
156         calld->recv_initial_metadata_->legacy_index()->named.grpc_encoding;
157     if (grpc_encoding != nullptr) {
158       calld->algorithm_ = DecodeMessageCompressionAlgorithm(grpc_encoding->md);
159     }
160   }
161   calld->MaybeResumeOnRecvMessageReady();
162   calld->MaybeResumeOnRecvTrailingMetadataReady();
163   grpc_closure* closure = calld->original_recv_initial_metadata_ready_;
164   calld->original_recv_initial_metadata_ready_ = nullptr;
165   Closure::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(error));
166 }
167 
MaybeResumeOnRecvMessageReady()168 void CallData::MaybeResumeOnRecvMessageReady() {
169   if (seen_recv_message_ready_) {
170     seen_recv_message_ready_ = false;
171     GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_message_ready_,
172                              GRPC_ERROR_NONE,
173                              "continue recv_message_ready callback");
174   }
175 }
176 
OnRecvMessageReady(void * arg,grpc_error_handle error)177 void CallData::OnRecvMessageReady(void* arg, grpc_error_handle error) {
178   CallData* calld = static_cast<CallData*>(arg);
179   if (error == GRPC_ERROR_NONE) {
180     if (calld->original_recv_initial_metadata_ready_ != nullptr) {
181       calld->seen_recv_message_ready_ = true;
182       GRPC_CALL_COMBINER_STOP(calld->call_combiner_,
183                               "Deferring OnRecvMessageReady until after "
184                               "OnRecvInitialMetadataReady");
185       return;
186     }
187     if (calld->algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) {
188       // recv_message can be NULL if trailing metadata is received instead of
189       // message, or it's possible that the message was not compressed.
190       if (*calld->recv_message_ == nullptr ||
191           (*calld->recv_message_)->length() == 0 ||
192           ((*calld->recv_message_)->flags() & GRPC_WRITE_INTERNAL_COMPRESS) ==
193               0) {
194         return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_NONE);
195       }
196       if (calld->max_recv_message_length_ >= 0 &&
197           (*calld->recv_message_)->length() >
198               static_cast<uint32_t>(calld->max_recv_message_length_)) {
199         GPR_DEBUG_ASSERT(calld->error_ == GRPC_ERROR_NONE);
200         calld->error_ = grpc_error_set_int(
201             GRPC_ERROR_CREATE_FROM_CPP_STRING(
202                 absl::StrFormat("Received message larger than max (%u vs. %d)",
203                                 (*calld->recv_message_)->length(),
204                                 calld->max_recv_message_length_)),
205             GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_RESOURCE_EXHAUSTED);
206         return calld->ContinueRecvMessageReadyCallback(
207             GRPC_ERROR_REF(calld->error_));
208       }
209       grpc_slice_buffer_destroy_internal(&calld->recv_slices_);
210       grpc_slice_buffer_init(&calld->recv_slices_);
211       return calld->ContinueReadingRecvMessage();
212     }
213   }
214   calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error));
215 }
216 
ContinueReadingRecvMessage()217 void CallData::ContinueReadingRecvMessage() {
218   while ((*recv_message_)
219              ->Next((*recv_message_)->length() - recv_slices_.length,
220                     &on_recv_message_next_done_)) {
221     grpc_error_handle error = PullSliceFromRecvMessage();
222     if (error != GRPC_ERROR_NONE) {
223       return ContinueRecvMessageReadyCallback(error);
224     }
225     // We have read the entire message.
226     if (recv_slices_.length == (*recv_message_)->length()) {
227       return FinishRecvMessage();
228     }
229   }
230 }
231 
PullSliceFromRecvMessage()232 grpc_error_handle CallData::PullSliceFromRecvMessage() {
233   grpc_slice incoming_slice;
234   grpc_error_handle error = (*recv_message_)->Pull(&incoming_slice);
235   if (error == GRPC_ERROR_NONE) {
236     grpc_slice_buffer_add(&recv_slices_, incoming_slice);
237   }
238   return error;
239 }
240 
OnRecvMessageNextDone(void * arg,grpc_error_handle error)241 void CallData::OnRecvMessageNextDone(void* arg, grpc_error_handle error) {
242   CallData* calld = static_cast<CallData*>(arg);
243   if (error != GRPC_ERROR_NONE) {
244     return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error));
245   }
246   error = calld->PullSliceFromRecvMessage();
247   if (error != GRPC_ERROR_NONE) {
248     return calld->ContinueRecvMessageReadyCallback(error);
249   }
250   if (calld->recv_slices_.length == (*calld->recv_message_)->length()) {
251     calld->FinishRecvMessage();
252   } else {
253     calld->ContinueReadingRecvMessage();
254   }
255 }
256 
FinishRecvMessage()257 void CallData::FinishRecvMessage() {
258   grpc_slice_buffer decompressed_slices;
259   grpc_slice_buffer_init(&decompressed_slices);
260   if (grpc_msg_decompress(algorithm_, &recv_slices_, &decompressed_slices) ==
261       0) {
262     GPR_DEBUG_ASSERT(error_ == GRPC_ERROR_NONE);
263     error_ = GRPC_ERROR_CREATE_FROM_CPP_STRING(
264         absl::StrCat("Unexpected error decompressing data for algorithm with "
265                      "enum value ",
266                      algorithm_));
267     grpc_slice_buffer_destroy_internal(&decompressed_slices);
268   } else {
269     uint32_t recv_flags =
270         ((*recv_message_)->flags() & (~GRPC_WRITE_INTERNAL_COMPRESS)) |
271         GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED;
272     // Swap out the original receive byte stream with our new one and send the
273     // batch down.
274     // Initializing recv_replacement_stream_ with decompressed_slices removes
275     // all the slices from decompressed_slices leaving it empty.
276     new (&recv_replacement_stream_)
277         SliceBufferByteStream(&decompressed_slices, recv_flags);
278     recv_message_->reset(
279         reinterpret_cast<SliceBufferByteStream*>(&recv_replacement_stream_));
280     recv_message_ = nullptr;
281   }
282   ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error_));
283 }
284 
ContinueRecvMessageReadyCallback(grpc_error_handle error)285 void CallData::ContinueRecvMessageReadyCallback(grpc_error_handle error) {
286   MaybeResumeOnRecvTrailingMetadataReady();
287   // The surface will clean up the receiving stream if there is an error.
288   grpc_closure* closure = original_recv_message_ready_;
289   original_recv_message_ready_ = nullptr;
290   Closure::Run(DEBUG_LOCATION, closure, error);
291 }
292 
MaybeResumeOnRecvTrailingMetadataReady()293 void CallData::MaybeResumeOnRecvTrailingMetadataReady() {
294   if (seen_recv_trailing_metadata_ready_) {
295     seen_recv_trailing_metadata_ready_ = false;
296     grpc_error_handle error = on_recv_trailing_metadata_ready_error_;
297     on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE;
298     GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_trailing_metadata_ready_,
299                              error, "Continuing OnRecvTrailingMetadataReady");
300   }
301 }
302 
OnRecvTrailingMetadataReady(void * arg,grpc_error_handle error)303 void CallData::OnRecvTrailingMetadataReady(void* arg, grpc_error_handle error) {
304   CallData* calld = static_cast<CallData*>(arg);
305   if (calld->original_recv_initial_metadata_ready_ != nullptr ||
306       calld->original_recv_message_ready_ != nullptr) {
307     calld->seen_recv_trailing_metadata_ready_ = true;
308     calld->on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_REF(error);
309     GRPC_CALL_COMBINER_STOP(
310         calld->call_combiner_,
311         "Deferring OnRecvTrailingMetadataReady until after "
312         "OnRecvInitialMetadataReady and OnRecvMessageReady");
313     return;
314   }
315   error = grpc_error_add_child(GRPC_ERROR_REF(error), calld->error_);
316   calld->error_ = GRPC_ERROR_NONE;
317   grpc_closure* closure = calld->original_recv_trailing_metadata_ready_;
318   calld->original_recv_trailing_metadata_ready_ = nullptr;
319   Closure::Run(DEBUG_LOCATION, closure, error);
320 }
321 
DecompressStartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)322 void CallData::DecompressStartTransportStreamOpBatch(
323     grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
324   // Handle recv_initial_metadata.
325   if (batch->recv_initial_metadata) {
326     recv_initial_metadata_ =
327         batch->payload->recv_initial_metadata.recv_initial_metadata;
328     original_recv_initial_metadata_ready_ =
329         batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
330     batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
331         &on_recv_initial_metadata_ready_;
332   }
333   // Handle recv_message
334   if (batch->recv_message) {
335     recv_message_ = batch->payload->recv_message.recv_message;
336     original_recv_message_ready_ =
337         batch->payload->recv_message.recv_message_ready;
338     batch->payload->recv_message.recv_message_ready = &on_recv_message_ready_;
339   }
340   // Handle recv_trailing_metadata
341   if (batch->recv_trailing_metadata) {
342     original_recv_trailing_metadata_ready_ =
343         batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
344     batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
345         &on_recv_trailing_metadata_ready_;
346   }
347   // Pass control down the stack.
348   grpc_call_next_op(elem, batch);
349 }
350 
DecompressStartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)351 void DecompressStartTransportStreamOpBatch(
352     grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
353   GPR_TIMER_SCOPE("decompress_start_transport_stream_op_batch", 0);
354   CallData* calld = static_cast<CallData*>(elem->call_data);
355   calld->DecompressStartTransportStreamOpBatch(elem, batch);
356 }
357 
DecompressInitCallElem(grpc_call_element * elem,const grpc_call_element_args * args)358 grpc_error_handle DecompressInitCallElem(grpc_call_element* elem,
359                                          const grpc_call_element_args* args) {
360   ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
361   new (elem->call_data) CallData(*args, chand);
362   return GRPC_ERROR_NONE;
363 }
364 
DecompressDestroyCallElem(grpc_call_element * elem,const grpc_call_final_info *,grpc_closure *)365 void DecompressDestroyCallElem(grpc_call_element* elem,
366                                const grpc_call_final_info* /*final_info*/,
367                                grpc_closure* /*ignored*/) {
368   CallData* calld = static_cast<CallData*>(elem->call_data);
369   calld->~CallData();
370 }
371 
DecompressInitChannelElem(grpc_channel_element * elem,grpc_channel_element_args * args)372 grpc_error_handle DecompressInitChannelElem(grpc_channel_element* elem,
373                                             grpc_channel_element_args* args) {
374   ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
375   new (chand) ChannelData(args);
376   return GRPC_ERROR_NONE;
377 }
378 
DecompressDestroyChannelElem(grpc_channel_element * elem)379 void DecompressDestroyChannelElem(grpc_channel_element* elem) {
380   ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
381   chand->~ChannelData();
382 }
383 
384 }  // namespace
385 
// Filter table for the message decompression filter. The initializer is
// positional, so the order of the entries must match the field order of
// grpc_channel_filter (declared elsewhere). It wires in the Decompress*
// functions and the CallData/ChannelData sizes defined above; the remaining
// entries (grpc_channel_next_op,
// grpc_call_stack_ignore_set_pollset_or_pollset_set,
// grpc_channel_next_get_info) come from outside this file.
// NOTE(review): "message_decompress" is presumably the filter's name --
// confirm against the grpc_channel_filter declaration.
const grpc_channel_filter MessageDecompressFilter = {
    DecompressStartTransportStreamOpBatch,
    grpc_channel_next_op,
    sizeof(CallData),
    DecompressInitCallElem,
    grpc_call_stack_ignore_set_pollset_or_pollset_set,
    DecompressDestroyCallElem,
    sizeof(ChannelData),
    DecompressInitChannelElem,
    DecompressDestroyChannelElem,
    grpc_channel_next_get_info,
    "message_decompress"};
398 }  // namespace grpc_core
399