1 //
2 //
3 // Copyright 2020 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/ext/filters/http/message_compress/message_decompress_filter.h"
22
23 #include <assert.h>
24 #include <string.h>
25
26 #include "absl/strings/str_cat.h"
27
28 #include <grpc/compression.h>
29 #include <grpc/slice_buffer.h>
30 #include <grpc/support/alloc.h>
31 #include <grpc/support/log.h>
32
33 #include "absl/strings/str_format.h"
34 #include "src/core/ext/filters/message_size/message_size_filter.h"
35 #include "src/core/lib/channel/channel_args.h"
36 #include "src/core/lib/compression/algorithm_metadata.h"
37 #include "src/core/lib/compression/compression_args.h"
38 #include "src/core/lib/compression/compression_internal.h"
39 #include "src/core/lib/compression/message_compress.h"
40 #include "src/core/lib/gpr/string.h"
41 #include "src/core/lib/slice/slice_internal.h"
42 #include "src/core/lib/slice/slice_string_helpers.h"
43
44 namespace grpc_core {
45 namespace {
46
47 class ChannelData {
48 public:
ChannelData(const grpc_channel_element_args * args)49 explicit ChannelData(const grpc_channel_element_args* args)
50 : max_recv_size_(GetMaxRecvSizeFromChannelArgs(args->channel_args)) {}
51
max_recv_size() const52 int max_recv_size() const { return max_recv_size_; }
53
54 private:
55 int max_recv_size_;
56 };
57
58 class CallData {
59 public:
CallData(const grpc_call_element_args & args,const ChannelData * chand)60 CallData(const grpc_call_element_args& args, const ChannelData* chand)
61 : call_combiner_(args.call_combiner),
62 max_recv_message_length_(chand->max_recv_size()) {
63 // Initialize state for recv_initial_metadata_ready callback
64 GRPC_CLOSURE_INIT(&on_recv_initial_metadata_ready_,
65 OnRecvInitialMetadataReady, this,
66 grpc_schedule_on_exec_ctx);
67 // Initialize state for recv_message_ready callback
68 grpc_slice_buffer_init(&recv_slices_);
69 GRPC_CLOSURE_INIT(&on_recv_message_next_done_, OnRecvMessageNextDone, this,
70 grpc_schedule_on_exec_ctx);
71 GRPC_CLOSURE_INIT(&on_recv_message_ready_, OnRecvMessageReady, this,
72 grpc_schedule_on_exec_ctx);
73 // Initialize state for recv_trailing_metadata_ready callback
74 GRPC_CLOSURE_INIT(&on_recv_trailing_metadata_ready_,
75 OnRecvTrailingMetadataReady, this,
76 grpc_schedule_on_exec_ctx);
77 const MessageSizeParsedConfig* limits =
78 MessageSizeParsedConfig::GetFromCallContext(args.context);
79 if (limits != nullptr && limits->limits().max_recv_size >= 0 &&
80 (limits->limits().max_recv_size < max_recv_message_length_ ||
81 max_recv_message_length_ < 0)) {
82 max_recv_message_length_ = limits->limits().max_recv_size;
83 }
84 }
85
~CallData()86 ~CallData() { grpc_slice_buffer_destroy_internal(&recv_slices_); }
87
88 void DecompressStartTransportStreamOpBatch(
89 grpc_call_element* elem, grpc_transport_stream_op_batch* batch);
90
91 private:
92 static void OnRecvInitialMetadataReady(void* arg, grpc_error* error);
93
94 // Methods for processing a receive message event
95 void MaybeResumeOnRecvMessageReady();
96 static void OnRecvMessageReady(void* arg, grpc_error* error);
97 static void OnRecvMessageNextDone(void* arg, grpc_error* error);
98 grpc_error* PullSliceFromRecvMessage();
99 void ContinueReadingRecvMessage();
100 void FinishRecvMessage();
101 void ContinueRecvMessageReadyCallback(grpc_error* error);
102
103 // Methods for processing a recv_trailing_metadata event
104 void MaybeResumeOnRecvTrailingMetadataReady();
105 static void OnRecvTrailingMetadataReady(void* arg, grpc_error* error);
106
107 CallCombiner* call_combiner_;
108 // Overall error for the call
109 grpc_error* error_ = GRPC_ERROR_NONE;
110 // Fields for handling recv_initial_metadata_ready callback
111 grpc_closure on_recv_initial_metadata_ready_;
112 grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
113 grpc_metadata_batch* recv_initial_metadata_ = nullptr;
114 // Fields for handling recv_message_ready callback
115 bool seen_recv_message_ready_ = false;
116 int max_recv_message_length_;
117 grpc_message_compression_algorithm algorithm_ = GRPC_MESSAGE_COMPRESS_NONE;
118 grpc_closure on_recv_message_ready_;
119 grpc_closure* original_recv_message_ready_ = nullptr;
120 grpc_closure on_recv_message_next_done_;
121 OrphanablePtr<ByteStream>* recv_message_ = nullptr;
122 // recv_slices_ holds the slices read from the original recv_message stream.
123 // It is initialized during construction and reset when a new stream is
124 // created using it.
125 grpc_slice_buffer recv_slices_;
126 std::aligned_storage<sizeof(SliceBufferByteStream),
127 alignof(SliceBufferByteStream)>::type
128 recv_replacement_stream_;
129 // Fields for handling recv_trailing_metadata_ready callback
130 bool seen_recv_trailing_metadata_ready_ = false;
131 grpc_closure on_recv_trailing_metadata_ready_;
132 grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
133 grpc_error* on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE;
134 };
135
DecodeMessageCompressionAlgorithm(grpc_mdelem md)136 grpc_message_compression_algorithm DecodeMessageCompressionAlgorithm(
137 grpc_mdelem md) {
138 grpc_message_compression_algorithm algorithm =
139 grpc_message_compression_algorithm_from_slice(GRPC_MDVALUE(md));
140 if (algorithm == GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) {
141 char* md_c_str = grpc_slice_to_c_string(GRPC_MDVALUE(md));
142 gpr_log(GPR_ERROR,
143 "Invalid incoming message compression algorithm: '%s'. "
144 "Interpreting incoming data as uncompressed.",
145 md_c_str);
146 gpr_free(md_c_str);
147 return GRPC_MESSAGE_COMPRESS_NONE;
148 }
149 return algorithm;
150 }
151
OnRecvInitialMetadataReady(void * arg,grpc_error * error)152 void CallData::OnRecvInitialMetadataReady(void* arg, grpc_error* error) {
153 CallData* calld = static_cast<CallData*>(arg);
154 if (error == GRPC_ERROR_NONE) {
155 grpc_linked_mdelem* grpc_encoding =
156 calld->recv_initial_metadata_->idx.named.grpc_encoding;
157 if (grpc_encoding != nullptr) {
158 calld->algorithm_ = DecodeMessageCompressionAlgorithm(grpc_encoding->md);
159 }
160 }
161 calld->MaybeResumeOnRecvMessageReady();
162 calld->MaybeResumeOnRecvTrailingMetadataReady();
163 grpc_closure* closure = calld->original_recv_initial_metadata_ready_;
164 calld->original_recv_initial_metadata_ready_ = nullptr;
165 Closure::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(error));
166 }
167
MaybeResumeOnRecvMessageReady()168 void CallData::MaybeResumeOnRecvMessageReady() {
169 if (seen_recv_message_ready_) {
170 seen_recv_message_ready_ = false;
171 GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_message_ready_,
172 GRPC_ERROR_NONE,
173 "continue recv_message_ready callback");
174 }
175 }
176
OnRecvMessageReady(void * arg,grpc_error * error)177 void CallData::OnRecvMessageReady(void* arg, grpc_error* error) {
178 CallData* calld = static_cast<CallData*>(arg);
179 if (error == GRPC_ERROR_NONE) {
180 if (calld->original_recv_initial_metadata_ready_ != nullptr) {
181 calld->seen_recv_message_ready_ = true;
182 GRPC_CALL_COMBINER_STOP(calld->call_combiner_,
183 "Deferring OnRecvMessageReady until after "
184 "OnRecvInitialMetadataReady");
185 return;
186 }
187 if (calld->algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) {
188 // recv_message can be NULL if trailing metadata is received instead of
189 // message, or it's possible that the message was not compressed.
190 if (*calld->recv_message_ == nullptr ||
191 (*calld->recv_message_)->length() == 0 ||
192 ((*calld->recv_message_)->flags() & GRPC_WRITE_INTERNAL_COMPRESS) ==
193 0) {
194 return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_NONE);
195 }
196 if (calld->max_recv_message_length_ >= 0 &&
197 (*calld->recv_message_)->length() >
198 static_cast<uint32_t>(calld->max_recv_message_length_)) {
199 std::string message_string = absl::StrFormat(
200 "Received message larger than max (%u vs. %d)",
201 (*calld->recv_message_)->length(), calld->max_recv_message_length_);
202 GPR_DEBUG_ASSERT(calld->error_ == GRPC_ERROR_NONE);
203 calld->error_ = grpc_error_set_int(
204 GRPC_ERROR_CREATE_FROM_COPIED_STRING(message_string.c_str()),
205 GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_RESOURCE_EXHAUSTED);
206 return calld->ContinueRecvMessageReadyCallback(
207 GRPC_ERROR_REF(calld->error_));
208 }
209 grpc_slice_buffer_destroy_internal(&calld->recv_slices_);
210 grpc_slice_buffer_init(&calld->recv_slices_);
211 return calld->ContinueReadingRecvMessage();
212 }
213 }
214 calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error));
215 }
216
ContinueReadingRecvMessage()217 void CallData::ContinueReadingRecvMessage() {
218 while ((*recv_message_)
219 ->Next((*recv_message_)->length() - recv_slices_.length,
220 &on_recv_message_next_done_)) {
221 grpc_error* error = PullSliceFromRecvMessage();
222 if (error != GRPC_ERROR_NONE) {
223 return ContinueRecvMessageReadyCallback(error);
224 }
225 // We have read the entire message.
226 if (recv_slices_.length == (*recv_message_)->length()) {
227 return FinishRecvMessage();
228 }
229 }
230 }
231
PullSliceFromRecvMessage()232 grpc_error* CallData::PullSliceFromRecvMessage() {
233 grpc_slice incoming_slice;
234 grpc_error* error = (*recv_message_)->Pull(&incoming_slice);
235 if (error == GRPC_ERROR_NONE) {
236 grpc_slice_buffer_add(&recv_slices_, incoming_slice);
237 }
238 return error;
239 }
240
OnRecvMessageNextDone(void * arg,grpc_error * error)241 void CallData::OnRecvMessageNextDone(void* arg, grpc_error* error) {
242 CallData* calld = static_cast<CallData*>(arg);
243 if (error != GRPC_ERROR_NONE) {
244 return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error));
245 }
246 error = calld->PullSliceFromRecvMessage();
247 if (error != GRPC_ERROR_NONE) {
248 return calld->ContinueRecvMessageReadyCallback(error);
249 }
250 if (calld->recv_slices_.length == (*calld->recv_message_)->length()) {
251 calld->FinishRecvMessage();
252 } else {
253 calld->ContinueReadingRecvMessage();
254 }
255 }
256
FinishRecvMessage()257 void CallData::FinishRecvMessage() {
258 grpc_slice_buffer decompressed_slices;
259 grpc_slice_buffer_init(&decompressed_slices);
260 if (grpc_msg_decompress(algorithm_, &recv_slices_, &decompressed_slices) ==
261 0) {
262 GPR_DEBUG_ASSERT(error_ == GRPC_ERROR_NONE);
263 error_ = GRPC_ERROR_CREATE_FROM_COPIED_STRING(
264 absl::StrCat("Unexpected error decompressing data for algorithm with "
265 "enum value ",
266 algorithm_)
267 .c_str());
268 grpc_slice_buffer_destroy_internal(&decompressed_slices);
269 } else {
270 uint32_t recv_flags =
271 ((*recv_message_)->flags() & (~GRPC_WRITE_INTERNAL_COMPRESS)) |
272 GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED;
273 // Swap out the original receive byte stream with our new one and send the
274 // batch down.
275 // Initializing recv_replacement_stream_ with decompressed_slices removes
276 // all the slices from decompressed_slices leaving it empty.
277 new (&recv_replacement_stream_)
278 SliceBufferByteStream(&decompressed_slices, recv_flags);
279 recv_message_->reset(
280 reinterpret_cast<SliceBufferByteStream*>(&recv_replacement_stream_));
281 recv_message_ = nullptr;
282 }
283 ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error_));
284 }
285
ContinueRecvMessageReadyCallback(grpc_error * error)286 void CallData::ContinueRecvMessageReadyCallback(grpc_error* error) {
287 MaybeResumeOnRecvTrailingMetadataReady();
288 // The surface will clean up the receiving stream if there is an error.
289 grpc_closure* closure = original_recv_message_ready_;
290 original_recv_message_ready_ = nullptr;
291 Closure::Run(DEBUG_LOCATION, closure, error);
292 }
293
MaybeResumeOnRecvTrailingMetadataReady()294 void CallData::MaybeResumeOnRecvTrailingMetadataReady() {
295 if (seen_recv_trailing_metadata_ready_) {
296 seen_recv_trailing_metadata_ready_ = false;
297 grpc_error* error = on_recv_trailing_metadata_ready_error_;
298 on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE;
299 GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_trailing_metadata_ready_,
300 error, "Continuing OnRecvTrailingMetadataReady");
301 }
302 }
303
OnRecvTrailingMetadataReady(void * arg,grpc_error * error)304 void CallData::OnRecvTrailingMetadataReady(void* arg, grpc_error* error) {
305 CallData* calld = static_cast<CallData*>(arg);
306 if (calld->original_recv_initial_metadata_ready_ != nullptr ||
307 calld->original_recv_message_ready_ != nullptr) {
308 calld->seen_recv_trailing_metadata_ready_ = true;
309 calld->on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_REF(error);
310 GRPC_CALL_COMBINER_STOP(
311 calld->call_combiner_,
312 "Deferring OnRecvTrailingMetadataReady until after "
313 "OnRecvInitialMetadataReady and OnRecvMessageReady");
314 return;
315 }
316 error = grpc_error_add_child(GRPC_ERROR_REF(error), calld->error_);
317 calld->error_ = GRPC_ERROR_NONE;
318 grpc_closure* closure = calld->original_recv_trailing_metadata_ready_;
319 calld->original_recv_trailing_metadata_ready_ = nullptr;
320 Closure::Run(DEBUG_LOCATION, closure, error);
321 }
322
DecompressStartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)323 void CallData::DecompressStartTransportStreamOpBatch(
324 grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
325 // Handle recv_initial_metadata.
326 if (batch->recv_initial_metadata) {
327 recv_initial_metadata_ =
328 batch->payload->recv_initial_metadata.recv_initial_metadata;
329 original_recv_initial_metadata_ready_ =
330 batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
331 batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
332 &on_recv_initial_metadata_ready_;
333 }
334 // Handle recv_message
335 if (batch->recv_message) {
336 recv_message_ = batch->payload->recv_message.recv_message;
337 original_recv_message_ready_ =
338 batch->payload->recv_message.recv_message_ready;
339 batch->payload->recv_message.recv_message_ready = &on_recv_message_ready_;
340 }
341 // Handle recv_trailing_metadata
342 if (batch->recv_trailing_metadata) {
343 original_recv_trailing_metadata_ready_ =
344 batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
345 batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
346 &on_recv_trailing_metadata_ready_;
347 }
348 // Pass control down the stack.
349 grpc_call_next_op(elem, batch);
350 }
351
DecompressStartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)352 void DecompressStartTransportStreamOpBatch(
353 grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
354 GPR_TIMER_SCOPE("decompress_start_transport_stream_op_batch", 0);
355 CallData* calld = static_cast<CallData*>(elem->call_data);
356 calld->DecompressStartTransportStreamOpBatch(elem, batch);
357 }
358
DecompressInitCallElem(grpc_call_element * elem,const grpc_call_element_args * args)359 grpc_error* DecompressInitCallElem(grpc_call_element* elem,
360 const grpc_call_element_args* args) {
361 ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
362 new (elem->call_data) CallData(*args, chand);
363 return GRPC_ERROR_NONE;
364 }
365
DecompressDestroyCallElem(grpc_call_element * elem,const grpc_call_final_info *,grpc_closure *)366 void DecompressDestroyCallElem(grpc_call_element* elem,
367 const grpc_call_final_info* /*final_info*/,
368 grpc_closure* /*ignored*/) {
369 CallData* calld = static_cast<CallData*>(elem->call_data);
370 calld->~CallData();
371 }
372
DecompressInitChannelElem(grpc_channel_element * elem,grpc_channel_element_args * args)373 grpc_error* DecompressInitChannelElem(grpc_channel_element* elem,
374 grpc_channel_element_args* args) {
375 ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
376 new (chand) ChannelData(args);
377 return GRPC_ERROR_NONE;
378 }
379
DecompressDestroyChannelElem(grpc_channel_element * elem)380 void DecompressDestroyChannelElem(grpc_channel_element* elem) {
381 ChannelData* chand = static_cast<ChannelData*>(elem->channel_data);
382 chand->~ChannelData();
383 }
384
385 } // namespace
386
// Channel filter that decompresses incoming messages. The algorithm comes
// from the grpc-encoding initial metadata, and only messages flagged
// GRPC_WRITE_INTERNAL_COMPRESS are decompressed. The initializer is
// positional, so entries must stay in grpc_channel_filter's field order.
const grpc_channel_filter MessageDecompressFilter = {
    DecompressStartTransportStreamOpBatch,
    // Channel-level ops are not intercepted.
    grpc_channel_next_op,
    sizeof(CallData),
    DecompressInitCallElem,
    grpc_call_stack_ignore_set_pollset_or_pollset_set,
    DecompressDestroyCallElem,
    sizeof(ChannelData),
    DecompressInitChannelElem,
    DecompressDestroyChannelElem,
    grpc_channel_next_get_info,
    "message_decompress"};
399 } // namespace grpc_core
400