1 //
2 //
3 // Copyright 2020 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/ext/filters/http/message_compress/message_decompress_filter.h"
22
23 #include <assert.h>
24 #include <string.h>
25
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28
29 #include <grpc/compression.h>
30 #include <grpc/slice_buffer.h>
31 #include <grpc/support/alloc.h>
32 #include <grpc/support/log.h>
33
34 #include "src/core/ext/filters/message_size/message_size_filter.h"
35 #include "src/core/lib/channel/channel_args.h"
36 #include "src/core/lib/compression/algorithm_metadata.h"
37 #include "src/core/lib/compression/compression_args.h"
38 #include "src/core/lib/compression/compression_internal.h"
39 #include "src/core/lib/compression/message_compress.h"
40 #include "src/core/lib/gpr/string.h"
41 #include "src/core/lib/slice/slice_internal.h"
42 #include "src/core/lib/slice/slice_string_helpers.h"
43
44 namespace grpc_core {
45 namespace {
46
47 class ChannelData {
48 public:
ChannelData(const grpc_channel_element_args * args)49 explicit ChannelData(const grpc_channel_element_args* args)
50 : max_recv_size_(GetMaxRecvSizeFromChannelArgs(args->channel_args)) {}
51
max_recv_size() const52 int max_recv_size() const { return max_recv_size_; }
53
54 private:
55 int max_recv_size_;
56 };
57
58 class CallData {
59 public:
CallData(const grpc_call_element_args & args,const ChannelData * chand)60 CallData(const grpc_call_element_args& args, const ChannelData* chand)
61 : call_combiner_(args.call_combiner),
62 max_recv_message_length_(chand->max_recv_size()) {
63 // Initialize state for recv_initial_metadata_ready callback
64 GRPC_CLOSURE_INIT(&on_recv_initial_metadata_ready_,
65 OnRecvInitialMetadataReady, this,
66 grpc_schedule_on_exec_ctx);
67 // Initialize state for recv_message_ready callback
68 grpc_slice_buffer_init(&recv_slices_);
69 GRPC_CLOSURE_INIT(&on_recv_message_next_done_, OnRecvMessageNextDone, this,
70 grpc_schedule_on_exec_ctx);
71 GRPC_CLOSURE_INIT(&on_recv_message_ready_, OnRecvMessageReady, this,
72 grpc_schedule_on_exec_ctx);
73 // Initialize state for recv_trailing_metadata_ready callback
74 GRPC_CLOSURE_INIT(&on_recv_trailing_metadata_ready_,
75 OnRecvTrailingMetadataReady, this,
76 grpc_schedule_on_exec_ctx);
77 const MessageSizeParsedConfig* limits =
78 MessageSizeParsedConfig::GetFromCallContext(args.context);
79 if (limits != nullptr && limits->limits().max_recv_size >= 0 &&
80 (limits->limits().max_recv_size < max_recv_message_length_ ||
81 max_recv_message_length_ < 0)) {
82 max_recv_message_length_ = limits->limits().max_recv_size;
83 }
84 }
85
~CallData()86 ~CallData() { grpc_slice_buffer_destroy_internal(&recv_slices_); }
87
88 void DecompressStartTransportStreamOpBatch(
89 grpc_call_element* elem, grpc_transport_stream_op_batch* batch);
90
91 private:
92 static void OnRecvInitialMetadataReady(void* arg, grpc_error_handle error);
93
94 // Methods for processing a receive message event
95 void MaybeResumeOnRecvMessageReady();
96 static void OnRecvMessageReady(void* arg, grpc_error_handle error);
97 static void OnRecvMessageNextDone(void* arg, grpc_error_handle error);
98 grpc_error_handle PullSliceFromRecvMessage();
99 void ContinueReadingRecvMessage();
100 void FinishRecvMessage();
101 void ContinueRecvMessageReadyCallback(grpc_error_handle error);
102
103 // Methods for processing a recv_trailing_metadata event
104 void MaybeResumeOnRecvTrailingMetadataReady();
105 static void OnRecvTrailingMetadataReady(void* arg, grpc_error_handle error);
106
107 CallCombiner* call_combiner_;
108 // Overall error for the call
109 grpc_error_handle error_ = GRPC_ERROR_NONE;
110 // Fields for handling recv_initial_metadata_ready callback
111 grpc_closure on_recv_initial_metadata_ready_;
112 grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
113 grpc_metadata_batch* recv_initial_metadata_ = nullptr;
114 // Fields for handling recv_message_ready callback
115 bool seen_recv_message_ready_ = false;
116 int max_recv_message_length_;
117 grpc_message_compression_algorithm algorithm_ = GRPC_MESSAGE_COMPRESS_NONE;
118 grpc_closure on_recv_message_ready_;
119 grpc_closure* original_recv_message_ready_ = nullptr;
120 grpc_closure on_recv_message_next_done_;
121 OrphanablePtr<ByteStream>* recv_message_ = nullptr;
122 // recv_slices_ holds the slices read from the original recv_message stream.
123 // It is initialized during construction and reset when a new stream is
124 // created using it.
125 grpc_slice_buffer recv_slices_;
126 std::aligned_storage<sizeof(SliceBufferByteStream),
127 alignof(SliceBufferByteStream)>::type
128 recv_replacement_stream_;
129 // Fields for handling recv_trailing_metadata_ready callback
130 bool seen_recv_trailing_metadata_ready_ = false;
131 grpc_closure on_recv_trailing_metadata_ready_;
132 grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
133 grpc_error_handle on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE;
134 };
135
DecodeMessageCompressionAlgorithm(grpc_mdelem md)136 grpc_message_compression_algorithm DecodeMessageCompressionAlgorithm(
137 grpc_mdelem md) {
138 grpc_message_compression_algorithm algorithm =
139 grpc_message_compression_algorithm_from_slice(GRPC_MDVALUE(md));
140 if (algorithm == GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) {
141 char* md_c_str = grpc_slice_to_c_string(GRPC_MDVALUE(md));
142 gpr_log(GPR_ERROR,
143 "Invalid incoming message compression algorithm: '%s'. "
144 "Interpreting incoming data as uncompressed.",
145 md_c_str);
146 gpr_free(md_c_str);
147 return GRPC_MESSAGE_COMPRESS_NONE;
148 }
149 return algorithm;
150 }
151
OnRecvInitialMetadataReady(void * arg,grpc_error_handle error)152 void CallData::OnRecvInitialMetadataReady(void* arg, grpc_error_handle error) {
153 CallData* calld = static_cast<CallData*>(arg);
154 if (error == GRPC_ERROR_NONE) {
155 grpc_linked_mdelem* grpc_encoding =
156 calld->recv_initial_metadata_->legacy_index()->named.grpc_encoding;
157 if (grpc_encoding != nullptr) {
158 calld->algorithm_ = DecodeMessageCompressionAlgorithm(grpc_encoding->md);
159 }
160 }
161 calld->MaybeResumeOnRecvMessageReady();
162 calld->MaybeResumeOnRecvTrailingMetadataReady();
163 grpc_closure* closure = calld->original_recv_initial_metadata_ready_;
164 calld->original_recv_initial_metadata_ready_ = nullptr;
165 Closure::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(error));
166 }
167
MaybeResumeOnRecvMessageReady()168 void CallData::MaybeResumeOnRecvMessageReady() {
169 if (seen_recv_message_ready_) {
170 seen_recv_message_ready_ = false;
171 GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_message_ready_,
172 GRPC_ERROR_NONE,
173 "continue recv_message_ready callback");
174 }
175 }
176
OnRecvMessageReady(void * arg,grpc_error_handle error)177 void CallData::OnRecvMessageReady(void* arg, grpc_error_handle error) {
178 CallData* calld = static_cast<CallData*>(arg);
179 if (error == GRPC_ERROR_NONE) {
180 if (calld->original_recv_initial_metadata_ready_ != nullptr) {
181 calld->seen_recv_message_ready_ = true;
182 GRPC_CALL_COMBINER_STOP(calld->call_combiner_,
183 "Deferring OnRecvMessageReady until after "
184 "OnRecvInitialMetadataReady");
185 return;
186 }
187 if (calld->algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) {
188 // recv_message can be NULL if trailing metadata is received instead of
189 // message, or it's possible that the message was not compressed.
190 if (*calld->recv_message_ == nullptr ||
191 (*calld->recv_message_)->length() == 0 ||
192 ((*calld->recv_message_)->flags() & GRPC_WRITE_INTERNAL_COMPRESS) ==
193 0) {
194 return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_NONE);
195 }
196 if (calld->max_recv_message_length_ >= 0 &&
197 (*calld->recv_message_)->length() >
198 static_cast<uint32_t>(calld->max_recv_message_length_)) {
199 GPR_DEBUG_ASSERT(calld->error_ == GRPC_ERROR_NONE);
200 calld->error_ = grpc_error_set_int(
201 GRPC_ERROR_CREATE_FROM_CPP_STRING(
202 absl::StrFormat("Received message larger than max (%u vs. %d)",
203 (*calld->recv_message_)->length(),
204 calld->max_recv_message_length_)),
205 GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_RESOURCE_EXHAUSTED);
206 return calld->ContinueRecvMessageReadyCallback(
207 GRPC_ERROR_REF(calld->error_));
208 }
209 grpc_slice_buffer_destroy_internal(&calld->recv_slices_);
210 grpc_slice_buffer_init(&calld->recv_slices_);
211 return calld->ContinueReadingRecvMessage();
212 }
213 }
214 calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error));
215 }
216
ContinueReadingRecvMessage()217 void CallData::ContinueReadingRecvMessage() {
218 while ((*recv_message_)
219 ->Next((*recv_message_)->length() - recv_slices_.length,
220 &on_recv_message_next_done_)) {
221 grpc_error_handle error = PullSliceFromRecvMessage();
222 if (error != GRPC_ERROR_NONE) {
223 return ContinueRecvMessageReadyCallback(error);
224 }
225 // We have read the entire message.
226 if (recv_slices_.length == (*recv_message_)->length()) {
227 return FinishRecvMessage();
228 }
229 }
230 }
231
PullSliceFromRecvMessage()232 grpc_error_handle CallData::PullSliceFromRecvMessage() {
233 grpc_slice incoming_slice;
234 grpc_error_handle error = (*recv_message_)->Pull(&incoming_slice);
235 if (error == GRPC_ERROR_NONE) {
236 grpc_slice_buffer_add(&recv_slices_, incoming_slice);
237 }
238 return error;
239 }
240
OnRecvMessageNextDone(void * arg,grpc_error_handle error)241 void CallData::OnRecvMessageNextDone(void* arg, grpc_error_handle error) {
242 CallData* calld = static_cast<CallData*>(arg);
243 if (error != GRPC_ERROR_NONE) {
244 return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error));
245 }
246 error = calld->PullSliceFromRecvMessage();
247 if (error != GRPC_ERROR_NONE) {
248 return calld->ContinueRecvMessageReadyCallback(error);
249 }
250 if (calld->recv_slices_.length == (*calld->recv_message_)->length()) {
251 calld->FinishRecvMessage();
252 } else {
253 calld->ContinueReadingRecvMessage();
254 }
255 }
256
FinishRecvMessage()257 void CallData::FinishRecvMessage() {
258 grpc_slice_buffer decompressed_slices;
259 grpc_slice_buffer_init(&decompressed_slices);
260 if (grpc_msg_decompress(algorithm_, &recv_slices_, &decompressed_slices) ==
261 0) {
262 GPR_DEBUG_ASSERT(error_ == GRPC_ERROR_NONE);
263 error_ = GRPC_ERROR_CREATE_FROM_CPP_STRING(
264 absl::StrCat("Unexpected error decompressing data for algorithm with "
265 "enum value ",
266 algorithm_));
267 grpc_slice_buffer_destroy_internal(&decompressed_slices);
268 } else {
269 uint32_t recv_flags =
270 ((*recv_message_)->flags() & (~GRPC_WRITE_INTERNAL_COMPRESS)) |
271 GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED;
272 // Swap out the original receive byte stream with our new one and send the
273 // batch down.
274 // Initializing recv_replacement_stream_ with decompressed_slices removes
275 // all the slices from decompressed_slices leaving it empty.
276 new (&recv_replacement_stream_)
277 SliceBufferByteStream(&decompressed_slices, recv_flags);
278 recv_message_->reset(
279 reinterpret_cast<SliceBufferByteStream*>(&recv_replacement_stream_));
280 recv_message_ = nullptr;
281 }
282 ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error_));
283 }
284
ContinueRecvMessageReadyCallback(grpc_error_handle error)285 void CallData::ContinueRecvMessageReadyCallback(grpc_error_handle error) {
286 MaybeResumeOnRecvTrailingMetadataReady();
287 // The surface will clean up the receiving stream if there is an error.
288 grpc_closure* closure = original_recv_message_ready_;
289 original_recv_message_ready_ = nullptr;
290 Closure::Run(DEBUG_LOCATION, closure, error);
291 }
292
MaybeResumeOnRecvTrailingMetadataReady()293 void CallData::MaybeResumeOnRecvTrailingMetadataReady() {
294 if (seen_recv_trailing_metadata_ready_) {
295 seen_recv_trailing_metadata_ready_ = false;
296 grpc_error_handle error = on_recv_trailing_metadata_ready_error_;
297 on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE;
298 GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_trailing_metadata_ready_,
299 error, "Continuing OnRecvTrailingMetadataReady");
300 }
301 }
302
OnRecvTrailingMetadataReady(void * arg,grpc_error_handle error)303 void CallData::OnRecvTrailingMetadataReady(void* arg, grpc_error_handle error) {
304 CallData* calld = static_cast<CallData*>(arg);
305 if (calld->original_recv_initial_metadata_ready_ != nullptr ||
306 calld->original_recv_message_ready_ != nullptr) {
307 calld->seen_recv_trailing_metadata_ready_ = true;
308 calld->on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_REF(error);
309 GRPC_CALL_COMBINER_STOP(
310 calld->call_combiner_,
311 "Deferring OnRecvTrailingMetadataReady until after "
312 "OnRecvInitialMetadataReady and OnRecvMessageReady");
313 return;
314 }
315 error = grpc_error_add_child(GRPC_ERROR_REF(error), calld->error_);
316 calld->error_ = GRPC_ERROR_NONE;
317 grpc_closure* closure = calld->original_recv_trailing_metadata_ready_;
318 calld->original_recv_trailing_metadata_ready_ = nullptr;
319 Closure::Run(DEBUG_LOCATION, closure, error);
320 }
321
DecompressStartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)322 void CallData::DecompressStartTransportStreamOpBatch(
323 grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
324 // Handle recv_initial_metadata.
325 if (batch->recv_initial_metadata) {
326 recv_initial_metadata_ =
327 batch->payload->recv_initial_metadata.recv_initial_metadata;
328 original_recv_initial_metadata_ready_ =
329 batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
330 batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
331 &on_recv_initial_metadata_ready_;
332 }
333 // Handle recv_message
334 if (batch->recv_message) {
335 recv_message_ = batch->payload->recv_message.recv_message;
336 original_recv_message_ready_ =
337 batch->payload->recv_message.recv_message_ready;
338 batch->payload->recv_message.recv_message_ready = &on_recv_message_ready_;
339 }
340 // Handle recv_trailing_metadata
341 if (batch->recv_trailing_metadata) {
342 original_recv_trailing_metadata_ready_ =
343 batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
344 batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
345 &on_recv_trailing_metadata_ready_;
346 }
347 // Pass control down the stack.
348 grpc_call_next_op(elem, batch);
349 }
350
DecompressStartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)351 void DecompressStartTransportStreamOpBatch(
352 grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
353 GPR_TIMER_SCOPE("decompress_start_transport_stream_op_batch", 0);
354 CallData* calld = static_cast<CallData*>(elem->call_data);
355 calld->DecompressStartTransportStreamOpBatch(elem, batch);
356 }
357
DecompressInitCallElem(grpc_call_element * elem,const grpc_call_element_args * args)358 grpc_error_handle DecompressInitCallElem(grpc_call_element* elem,
359 const grpc_call_element_args* args) {
360 ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
361 new (elem->call_data) CallData(*args, chand);
362 return GRPC_ERROR_NONE;
363 }
364
DecompressDestroyCallElem(grpc_call_element * elem,const grpc_call_final_info *,grpc_closure *)365 void DecompressDestroyCallElem(grpc_call_element* elem,
366 const grpc_call_final_info* /*final_info*/,
367 grpc_closure* /*ignored*/) {
368 CallData* calld = static_cast<CallData*>(elem->call_data);
369 calld->~CallData();
370 }
371
DecompressInitChannelElem(grpc_channel_element * elem,grpc_channel_element_args * args)372 grpc_error_handle DecompressInitChannelElem(grpc_channel_element* elem,
373 grpc_channel_element_args* args) {
374 ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
375 new (chand) ChannelData(args);
376 return GRPC_ERROR_NONE;
377 }
378
DecompressDestroyChannelElem(grpc_channel_element * elem)379 void DecompressDestroyChannelElem(grpc_channel_element* elem) {
380 ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
381 chand->~ChannelData();
382 }
383
384 } // namespace
385
// Filter table for the message decompression filter ("message_decompress").
// Only stream op batches are intercepted. Channel ops and get_info use the
// generic grpc_channel_next_* helpers, and pollset binding uses the "ignore"
// helper. Per-call and per-channel storage are sized for CallData and
// ChannelData, which the init/destroy functions construct and destroy in
// place.
const grpc_channel_filter MessageDecompressFilter = {
    DecompressStartTransportStreamOpBatch,
    grpc_channel_next_op,
    sizeof(CallData),
    DecompressInitCallElem,
    grpc_call_stack_ignore_set_pollset_or_pollset_set,
    DecompressDestroyCallElem,
    sizeof(ChannelData),
    DecompressInitChannelElem,
    DecompressDestroyChannelElem,
    grpc_channel_next_get_info,
    "message_decompress"};
398 } // namespace grpc_core
399