1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_deflate_stream.h"
6
7 #include <stdint.h>
8
9 #include <algorithm>
10 #include <string>
11 #include <utility>
12 #include <vector>
13
14 #include "base/bind.h"
15 #include "base/logging.h"
16 #include "base/memory/scoped_refptr.h"
17 #include "net/base/io_buffer.h"
18 #include "net/base/net_errors.h"
19 #include "net/websockets/websocket_deflate_parameters.h"
20 #include "net/websockets/websocket_deflate_predictor.h"
21 #include "net/websockets/websocket_deflater.h"
22 #include "net/websockets/websocket_errors.h"
23 #include "net/websockets/websocket_frame.h"
24 #include "net/websockets/websocket_inflater.h"
25 #include "net/websockets/websocket_stream.h"
26
27 class GURL;
28
29 namespace net {
30
namespace {

// Window size exponent (log2 of the LZ77 window) handed to the inflater's
// Initialize() for incoming messages.
constexpr int kWindowBits = 15;

// Chunk size in bytes. Inflated output is emitted in frames of at most this
// many bytes, compressed output is flushed into a frame once this much has
// accumulated, and the value is passed (twice) to the inflater's constructor
// (presumably as its buffer capacities — confirm in WebSocketInflater).
constexpr size_t kChunkSize = 4 * 1024;

}  // namespace
37
WebSocketDeflateStream(std::unique_ptr<WebSocketStream> stream,const WebSocketDeflateParameters & params,std::unique_ptr<WebSocketDeflatePredictor> predictor)38 WebSocketDeflateStream::WebSocketDeflateStream(
39 std::unique_ptr<WebSocketStream> stream,
40 const WebSocketDeflateParameters& params,
41 std::unique_ptr<WebSocketDeflatePredictor> predictor)
42 : stream_(std::move(stream)),
43 deflater_(params.client_context_take_over_mode()),
44 inflater_(kChunkSize, kChunkSize),
45 reading_state_(NOT_READING),
46 writing_state_(NOT_WRITING),
47 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
48 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
49 predictor_(std::move(predictor)) {
50 DCHECK(stream_);
51 DCHECK(params.IsValidAsResponse());
52 int client_max_window_bits = 15;
53 if (params.is_client_max_window_bits_specified()) {
54 DCHECK(params.has_client_max_window_bits_value());
55 client_max_window_bits = params.client_max_window_bits();
56 }
57 deflater_.Initialize(client_max_window_bits);
58 inflater_.Initialize(kWindowBits);
59 }
60
61 WebSocketDeflateStream::~WebSocketDeflateStream() = default;
62
ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)63 int WebSocketDeflateStream::ReadFrames(
64 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
65 CompletionOnceCallback callback) {
66 read_callback_ = std::move(callback);
67 inflater_outputs_.clear();
68 int result = stream_->ReadFrames(
69 frames, base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
70 base::Unretained(this), base::Unretained(frames)));
71 if (result < 0)
72 return result;
73 DCHECK_EQ(OK, result);
74 DCHECK(!frames->empty());
75
76 return InflateAndReadIfNecessary(frames);
77 }
78
WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)79 int WebSocketDeflateStream::WriteFrames(
80 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
81 CompletionOnceCallback callback) {
82 deflater_outputs_.clear();
83 int result = Deflate(frames);
84 if (result != OK)
85 return result;
86 if (frames->empty())
87 return OK;
88 return stream_->WriteFrames(frames, std::move(callback));
89 }
90
Close()91 void WebSocketDeflateStream::Close() { stream_->Close(); }
92
GetSubProtocol() const93 std::string WebSocketDeflateStream::GetSubProtocol() const {
94 return stream_->GetSubProtocol();
95 }
96
GetExtensions() const97 std::string WebSocketDeflateStream::GetExtensions() const {
98 return stream_->GetExtensions();
99 }
100
OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>> * frames,int result)101 void WebSocketDeflateStream::OnReadComplete(
102 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
103 int result) {
104 if (result != OK) {
105 frames->clear();
106 std::move(read_callback_).Run(result);
107 return;
108 }
109
110 int r = InflateAndReadIfNecessary(frames);
111 if (r != ERR_IO_PENDING)
112 std::move(read_callback_).Run(r);
113 }
114
Deflate(std::vector<std::unique_ptr<WebSocketFrame>> * frames)115 int WebSocketDeflateStream::Deflate(
116 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
117 std::vector<std::unique_ptr<WebSocketFrame>> frames_to_write;
118 // Store frames of the currently processed message if writing_state_ equals to
119 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
120 std::vector<std::unique_ptr<WebSocketFrame>> frames_of_message;
121 for (size_t i = 0; i < frames->size(); ++i) {
122 DCHECK(!(*frames)[i]->header.reserved1);
123 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
124 frames_to_write.push_back(std::move((*frames)[i]));
125 continue;
126 }
127 if (writing_state_ == NOT_WRITING)
128 OnMessageStart(*frames, i);
129
130 std::unique_ptr<WebSocketFrame> frame(std::move((*frames)[i]));
131 predictor_->RecordInputDataFrame(frame.get());
132
133 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
134 if (frame->header.final)
135 writing_state_ = NOT_WRITING;
136 predictor_->RecordWrittenDataFrame(frame.get());
137 frames_to_write.push_back(std::move(frame));
138 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
139 } else {
140 if (frame->payload &&
141 !deflater_.AddBytes(
142 frame->payload,
143 static_cast<size_t>(frame->header.payload_length))) {
144 DVLOG(1) << "WebSocket protocol error. "
145 << "deflater_.AddBytes() returns an error.";
146 return ERR_WS_PROTOCOL_ERROR;
147 }
148 if (frame->header.final && !deflater_.Finish()) {
149 DVLOG(1) << "WebSocket protocol error. "
150 << "deflater_.Finish() returns an error.";
151 return ERR_WS_PROTOCOL_ERROR;
152 }
153
154 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
155 if (deflater_.CurrentOutputSize() >= kChunkSize ||
156 frame->header.final) {
157 int result = AppendCompressedFrame(frame->header, &frames_to_write);
158 if (result != OK)
159 return result;
160 }
161 if (frame->header.final)
162 writing_state_ = NOT_WRITING;
163 } else {
164 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
165 bool final = frame->header.final;
166 frames_of_message.push_back(std::move(frame));
167 if (final) {
168 int result = AppendPossiblyCompressedMessage(&frames_of_message,
169 &frames_to_write);
170 if (result != OK)
171 return result;
172 frames_of_message.clear();
173 writing_state_ = NOT_WRITING;
174 }
175 }
176 }
177 }
178 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
179 frames->swap(frames_to_write);
180 return OK;
181 }
182
OnMessageStart(const std::vector<std::unique_ptr<WebSocketFrame>> & frames,size_t index)183 void WebSocketDeflateStream::OnMessageStart(
184 const std::vector<std::unique_ptr<WebSocketFrame>>& frames,
185 size_t index) {
186 WebSocketFrame* frame = frames[index].get();
187 current_writing_opcode_ = frame->header.opcode;
188 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
189 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
190 WebSocketDeflatePredictor::Result prediction =
191 predictor_->Predict(frames, index);
192
193 switch (prediction) {
194 case WebSocketDeflatePredictor::DEFLATE:
195 writing_state_ = WRITING_COMPRESSED_MESSAGE;
196 return;
197 case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
198 writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
199 return;
200 case WebSocketDeflatePredictor::TRY_DEFLATE:
201 writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
202 return;
203 }
204 NOTREACHED();
205 }
206
AppendCompressedFrame(const WebSocketFrameHeader & header,std::vector<std::unique_ptr<WebSocketFrame>> * frames_to_write)207 int WebSocketDeflateStream::AppendCompressedFrame(
208 const WebSocketFrameHeader& header,
209 std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
210 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
211 scoped_refptr<IOBufferWithSize> compressed_payload =
212 deflater_.GetOutput(deflater_.CurrentOutputSize());
213 if (!compressed_payload.get()) {
214 DVLOG(1) << "WebSocket protocol error. "
215 << "deflater_.GetOutput() returns an error.";
216 return ERR_WS_PROTOCOL_ERROR;
217 }
218 deflater_outputs_.push_back(compressed_payload);
219 auto compressed = std::make_unique<WebSocketFrame>(opcode);
220 compressed->header.CopyFrom(header);
221 compressed->header.opcode = opcode;
222 compressed->header.final = header.final;
223 compressed->header.reserved1 =
224 (opcode != WebSocketFrameHeader::kOpCodeContinuation);
225 compressed->payload = compressed_payload->data();
226 compressed->header.payload_length = compressed_payload->size();
227
228 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
229 predictor_->RecordWrittenDataFrame(compressed.get());
230 frames_to_write->push_back(std::move(compressed));
231 return OK;
232 }
233
AppendPossiblyCompressedMessage(std::vector<std::unique_ptr<WebSocketFrame>> * frames,std::vector<std::unique_ptr<WebSocketFrame>> * frames_to_write)234 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
235 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
236 std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
237 DCHECK(!frames->empty());
238
239 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
240 scoped_refptr<IOBufferWithSize> compressed_payload =
241 deflater_.GetOutput(deflater_.CurrentOutputSize());
242 if (!compressed_payload.get()) {
243 DVLOG(1) << "WebSocket protocol error. "
244 << "deflater_.GetOutput() returns an error.";
245 return ERR_WS_PROTOCOL_ERROR;
246 }
247 deflater_outputs_.push_back(compressed_payload);
248
249 uint64_t original_payload_length = 0;
250 for (size_t i = 0; i < frames->size(); ++i) {
251 WebSocketFrame* frame = (*frames)[i].get();
252 // Asserts checking that frames represent one whole data message.
253 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
254 DCHECK_EQ(i == 0,
255 WebSocketFrameHeader::kOpCodeContinuation !=
256 frame->header.opcode);
257 DCHECK_EQ(i == frames->size() - 1, frame->header.final);
258 original_payload_length += frame->header.payload_length;
259 }
260 if (original_payload_length <=
261 static_cast<uint64_t>(compressed_payload->size())) {
262 // Compression is not effective. Use the original frames.
263 for (size_t i = 0; i < frames->size(); ++i) {
264 std::unique_ptr<WebSocketFrame> frame = std::move((*frames)[i]);
265 predictor_->RecordWrittenDataFrame(frame.get());
266 frames_to_write->push_back(std::move(frame));
267 }
268 frames->clear();
269 return OK;
270 }
271 auto compressed = std::make_unique<WebSocketFrame>(opcode);
272 compressed->header.CopyFrom((*frames)[0]->header);
273 compressed->header.opcode = opcode;
274 compressed->header.final = true;
275 compressed->header.reserved1 = true;
276 compressed->payload = compressed_payload->data();
277 compressed->header.payload_length = compressed_payload->size();
278
279 predictor_->RecordWrittenDataFrame(compressed.get());
280 frames_to_write->push_back(std::move(compressed));
281 return OK;
282 }
283
Inflate(std::vector<std::unique_ptr<WebSocketFrame>> * frames)284 int WebSocketDeflateStream::Inflate(
285 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
286 std::vector<std::unique_ptr<WebSocketFrame>> frames_to_output;
287 std::vector<std::unique_ptr<WebSocketFrame>> frames_passed;
288 frames->swap(frames_passed);
289 for (size_t i = 0; i < frames_passed.size(); ++i) {
290 std::unique_ptr<WebSocketFrame> frame(std::move(frames_passed[i]));
291 frames_passed[i] = nullptr;
292 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
293 << " final=" << frame->header.final
294 << " reserved1=" << frame->header.reserved1
295 << " payload_length=" << frame->header.payload_length;
296
297 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
298 frames_to_output.push_back(std::move(frame));
299 continue;
300 }
301
302 if (reading_state_ == NOT_READING) {
303 if (frame->header.reserved1)
304 reading_state_ = READING_COMPRESSED_MESSAGE;
305 else
306 reading_state_ = READING_UNCOMPRESSED_MESSAGE;
307 current_reading_opcode_ = frame->header.opcode;
308 } else {
309 if (frame->header.reserved1) {
310 DVLOG(1) << "WebSocket protocol error. "
311 << "Receiving a non-first frame with RSV1 flag set.";
312 return ERR_WS_PROTOCOL_ERROR;
313 }
314 }
315
316 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
317 if (frame->header.final)
318 reading_state_ = NOT_READING;
319 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
320 frames_to_output.push_back(std::move(frame));
321 } else {
322 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
323 if (frame->payload &&
324 !inflater_.AddBytes(
325 frame->payload,
326 static_cast<size_t>(frame->header.payload_length))) {
327 DVLOG(1) << "WebSocket protocol error. "
328 << "inflater_.AddBytes() returns an error.";
329 return ERR_WS_PROTOCOL_ERROR;
330 }
331 if (frame->header.final) {
332 if (!inflater_.Finish()) {
333 DVLOG(1) << "WebSocket protocol error. "
334 << "inflater_.Finish() returns an error.";
335 return ERR_WS_PROTOCOL_ERROR;
336 }
337 }
338 // TODO(yhirano): Many frames can be generated by the inflater and
339 // memory consumption can grow.
340 // We could avoid it, but avoiding it makes this class much more
341 // complicated.
342 while (inflater_.CurrentOutputSize() >= kChunkSize ||
343 frame->header.final) {
344 size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
345 auto inflated =
346 std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodeText);
347 scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
348 inflater_outputs_.push_back(data);
349 bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
350 if (!data.get()) {
351 DVLOG(1) << "WebSocket protocol error. "
352 << "inflater_.GetOutput() returns an error.";
353 return ERR_WS_PROTOCOL_ERROR;
354 }
355 inflated->header.CopyFrom(frame->header);
356 inflated->header.opcode = current_reading_opcode_;
357 inflated->header.final = is_final;
358 inflated->header.reserved1 = false;
359 inflated->payload = data->data();
360 inflated->header.payload_length = data->size();
361 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
362 << " final=" << inflated->header.final
363 << " reserved1=" << inflated->header.reserved1
364 << " payload_length=" << inflated->header.payload_length;
365 frames_to_output.push_back(std::move(inflated));
366 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
367 if (is_final)
368 break;
369 }
370 if (frame->header.final)
371 reading_state_ = NOT_READING;
372 }
373 }
374 frames->swap(frames_to_output);
375 return frames->empty() ? ERR_IO_PENDING : OK;
376 }
377
InflateAndReadIfNecessary(std::vector<std::unique_ptr<WebSocketFrame>> * frames)378 int WebSocketDeflateStream::InflateAndReadIfNecessary(
379 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
380 int result = Inflate(frames);
381 while (result == ERR_IO_PENDING) {
382 DCHECK(frames->empty());
383
384 result = stream_->ReadFrames(
385 frames,
386 base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
387 base::Unretained(this), base::Unretained(frames)));
388 if (result < 0)
389 break;
390 DCHECK_EQ(OK, result);
391 DCHECK(!frames->empty());
392
393 result = Inflate(frames);
394 }
395 if (result < 0)
396 frames->clear();
397 return result;
398 }
399
400 } // namespace net
401