1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "google_apis/gcm/base/socket_stream.h"

#include <stddef.h>

#include <cstring>
#include <utility>

#include "base/bind.h"
#include "base/callback.h"
#include "net/base/io_buffer.h"
#include "net/socket/stream_socket.h"
13 
14 namespace gcm {
15 
namespace {

// Capacity, in bytes, of the fixed buffers backing both the input and the
// output stream.
// TODO(zea): consider having dynamically-sized buffers if this becomes too
// expensive.
constexpr size_t kDefaultBufferSize = 8 * 1024;

}  // namespace
23 
SocketInputStream(mojo::ScopedDataPipeConsumerHandle stream)24 SocketInputStream::SocketInputStream(mojo::ScopedDataPipeConsumerHandle stream)
25     : stream_(std::move(stream)),
26       stream_watcher_(FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL),
27       read_size_(0),
28       io_buffer_(base::MakeRefCounted<net::IOBuffer>(kDefaultBufferSize)),
29       read_buffer_(
30           base::MakeRefCounted<net::DrainableIOBuffer>(io_buffer_,
31                                                        kDefaultBufferSize)),
32       next_pos_(0),
33       last_error_(net::OK) {
34   stream_watcher_.Watch(
35       stream_.get(),
36       MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED,
37       MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED,
38       base::BindRepeating(&SocketInputStream::ReadMore,
39                           base::Unretained(this)));
40 }
41 
~SocketInputStream()42 SocketInputStream::~SocketInputStream() {
43 }
44 
Next(const void ** data,int * size)45 bool SocketInputStream::Next(const void** data, int* size) {
46   if (GetState() != EMPTY && GetState() != READY) {
47     NOTREACHED() << "Invalid input stream read attempt.";
48     return false;
49   }
50 
51   if (GetState() == EMPTY) {
52     DVLOG(1) << "No unread data remaining, ending read.";
53     return false;
54   }
55 
56   DCHECK_EQ(GetState(), READY)
57       << " Input stream must have pending data before reading.";
58   DCHECK_LT(next_pos_, read_buffer_->BytesConsumed());
59   *data = io_buffer_->data() + next_pos_;
60   *size = UnreadByteCount();
61   next_pos_ = read_buffer_->BytesConsumed();
62   DVLOG(1) << "Consuming " << *size << " bytes in input buffer.";
63   return true;
64 }
65 
BackUp(int count)66 void SocketInputStream::BackUp(int count) {
67   DCHECK(GetState() == READY || GetState() == EMPTY);
68   // TODO(zea): investigating crbug.com/409985
69   CHECK_GT(count, 0);
70   CHECK_LE(count, next_pos_);
71 
72   next_pos_ -= count;
73   DVLOG(1) << "Backing up " << count << " bytes in input buffer. "
74            << "Current position now at " << next_pos_
75            << " of " << read_buffer_->BytesConsumed();
76 }
77 
Skip(int count)78 bool SocketInputStream::Skip(int count) {
79   NOTIMPLEMENTED();
80   return false;
81 }
82 
ByteCount() const83 int64_t SocketInputStream::ByteCount() const {
84   DCHECK_NE(GetState(), CLOSED);
85   DCHECK_NE(GetState(), READING);
86   return next_pos_;
87 }
88 
UnreadByteCount() const89 int SocketInputStream::UnreadByteCount() const {
90   DCHECK_NE(GetState(), CLOSED);
91   DCHECK_NE(GetState(), READING);
92   return read_buffer_->BytesConsumed() - next_pos_;
93 }
94 
Refresh(base::OnceClosure callback,int byte_limit)95 net::Error SocketInputStream::Refresh(base::OnceClosure callback,
96                                       int byte_limit) {
97   DCHECK(!read_callback_);
98   DCHECK_NE(GetState(), CLOSED);
99   DCHECK_NE(GetState(), READING);
100   DCHECK_GT(byte_limit, 0);
101 
102   if (byte_limit > read_buffer_->BytesRemaining()) {
103     LOG(ERROR) << "Out of buffer space, closing input stream.";
104     CloseStream(net::ERR_FILE_TOO_BIG);
105     return net::OK;
106   }
107 
108   read_size_ = byte_limit;
109   read_callback_ = std::move(callback);
110   stream_watcher_.ArmOrNotify();
111   last_error_ = net::ERR_IO_PENDING;
112   return net::ERR_IO_PENDING;
113 }
114 
ReadMore(MojoResult result,const mojo::HandleSignalsState &)115 void SocketInputStream::ReadMore(
116     MojoResult result,
117     const mojo::HandleSignalsState& /* ignored */) {
118   DCHECK(read_callback_);
119   DCHECK_NE(0u, read_size_);
120 
121   uint32_t num_bytes = read_size_;
122   if (result == MOJO_RESULT_OK) {
123     DVLOG(1) << "Refreshing input stream, limit of " << num_bytes << " bytes.";
124     result = stream_->ReadData(read_buffer_->data(), &num_bytes,
125                                MOJO_READ_DATA_FLAG_NONE);
126     DVLOG(1) << "Read returned mojo result" << result;
127   }
128 
129   if (result == MOJO_RESULT_SHOULD_WAIT) {
130     stream_watcher_.ArmOrNotify();
131     return;
132   }
133 
134   read_size_ = 0;
135   if (result != MOJO_RESULT_OK) {
136     CloseStream(net::ERR_FAILED);
137     std::move(read_callback_).Run();
138     return;
139   }
140 
141   // If an EOF has been received, close the stream.
142   if (result == MOJO_RESULT_OK && num_bytes == 0) {
143     CloseStream(net::ERR_CONNECTION_CLOSED);
144     std::move(read_callback_).Run();
145     return;
146   }
147 
148   // If an error occurred before the completion callback could complete, ignore
149   // the result.
150   if (GetState() == CLOSED)
151     return;
152 
153   last_error_ = net::OK;
154   read_buffer_->DidConsume(num_bytes);
155   // TODO(zea): investigating crbug.com/409985
156   CHECK_GT(UnreadByteCount(), 0);
157 
158   DVLOG(1) << "Refresh complete with " << num_bytes << " new bytes. "
159            << "Current position " << next_pos_ << " of "
160            << read_buffer_->BytesConsumed() << ".";
161 
162   std::move(read_callback_).Run();
163 }
164 
RebuildBuffer()165 void SocketInputStream::RebuildBuffer() {
166   DVLOG(1) << "Rebuilding input stream, consumed "
167            << next_pos_ << " bytes.";
168   DCHECK_NE(GetState(), READING);
169   DCHECK_NE(GetState(), CLOSED);
170 
171   int unread_data_size = 0;
172   const void* unread_data_ptr = nullptr;
173   Next(&unread_data_ptr, &unread_data_size);
174   ResetInternal();
175 
176   if (unread_data_ptr != io_buffer_->data()) {
177     DVLOG(1) << "Have " << unread_data_size
178              << " unread bytes remaining, shifting.";
179     // Move any remaining unread data to the start of the buffer;
180     std::memmove(io_buffer_->data(), unread_data_ptr, unread_data_size);
181   } else {
182     DVLOG(1) << "Have " << unread_data_size << " unread bytes remaining.";
183   }
184   read_buffer_->DidConsume(unread_data_size);
185   // TODO(zea): investigating crbug.com/409985
186   CHECK_GE(UnreadByteCount(), 0);
187 }
188 
last_error() const189 net::Error SocketInputStream::last_error() const {
190   return last_error_;
191 }
192 
GetState() const193 SocketInputStream::State SocketInputStream::GetState() const {
194   if (last_error_ < net::ERR_IO_PENDING)
195     return CLOSED;
196 
197   if (last_error_ == net::ERR_IO_PENDING)
198     return READING;
199 
200   DCHECK_EQ(last_error_, net::OK);
201   if (read_buffer_->BytesConsumed() == next_pos_)
202     return EMPTY;
203 
204   return READY;
205 }
206 
ResetInternal()207 void SocketInputStream::ResetInternal() {
208   read_buffer_->SetOffset(0);
209   next_pos_ = 0;
210   last_error_ = net::OK;
211   weak_ptr_factory_.InvalidateWeakPtrs();  // Invalidate any callbacks.
212 }
213 
CloseStream(net::Error error)214 void SocketInputStream::CloseStream(net::Error error) {
215   DCHECK_LT(error, net::ERR_IO_PENDING);
216   ResetInternal();
217   last_error_ = error;
218   LOG(ERROR) << "Closing stream with result " << error;
219 }
220 
SocketOutputStream(mojo::ScopedDataPipeProducerHandle stream)221 SocketOutputStream::SocketOutputStream(
222     mojo::ScopedDataPipeProducerHandle stream)
223     : stream_(std::move(stream)),
224       stream_watcher_(FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL),
225       io_buffer_(
226           base::MakeRefCounted<net::IOBufferWithSize>(kDefaultBufferSize)),
227       next_pos_(0),
228       last_error_(net::OK) {
229   stream_watcher_.Watch(
230       stream_.get(),
231       MOJO_HANDLE_SIGNAL_WRITABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED,
232       MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED,
233       base::BindRepeating(&SocketOutputStream::WriteMore,
234                           base::Unretained(this)));
235 }
236 
~SocketOutputStream()237 SocketOutputStream::~SocketOutputStream() {
238 }
239 
Next(void ** data,int * size)240 bool SocketOutputStream::Next(void** data, int* size) {
241   DCHECK_NE(GetState(), CLOSED);
242   DCHECK_NE(GetState(), FLUSHING);
243   if (next_pos_ == io_buffer_->size())
244     return false;
245 
246   *data = io_buffer_->data() + next_pos_;
247   *size = io_buffer_->size() - next_pos_;
248   next_pos_ = io_buffer_->size();
249   return true;
250 }
251 
BackUp(int count)252 void SocketOutputStream::BackUp(int count) {
253   DCHECK_GE(count, 0);
254   if (count > next_pos_)
255     next_pos_ = 0;
256   next_pos_ -= count;
257   DVLOG(1) << "Backing up " << count << " bytes in output buffer. "
258            << next_pos_ << " bytes used.";
259 }
260 
ByteCount() const261 int64_t SocketOutputStream::ByteCount() const {
262   DCHECK_NE(GetState(), CLOSED);
263   DCHECK_NE(GetState(), FLUSHING);
264   return next_pos_;
265 }
266 
Flush(base::OnceClosure callback)267 net::Error SocketOutputStream::Flush(base::OnceClosure callback) {
268   DCHECK(!write_callback_);
269   DCHECK_EQ(GetState(), READY);
270 
271   if (!write_buffer_) {
272     write_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(
273         io_buffer_.get(), next_pos_);
274   }
275 
276   last_error_ = net::ERR_IO_PENDING;
277   stream_watcher_.ArmOrNotify();
278   write_callback_ = std::move(callback);
279   return net::ERR_IO_PENDING;
280 }
281 
WriteMore(MojoResult result,const mojo::HandleSignalsState & state)282 void SocketOutputStream::WriteMore(MojoResult result,
283                                    const mojo::HandleSignalsState& state) {
284   DCHECK(write_callback_);
285   DCHECK(write_buffer_);
286 
287   uint32_t num_bytes = write_buffer_->BytesRemaining();
288   DVLOG(1) << "Flushing " << num_bytes << " bytes into socket.";
289   if (result == MOJO_RESULT_OK) {
290     result = stream_->WriteData(write_buffer_->data(), &num_bytes,
291                                 MOJO_WRITE_DATA_FLAG_NONE);
292   }
293   if (result == MOJO_RESULT_SHOULD_WAIT) {
294     stream_watcher_.ArmOrNotify();
295     return;
296   }
297   if (result != MOJO_RESULT_OK) {
298     LOG(ERROR) << "Failed to flush socket.";
299     last_error_ = net::ERR_FAILED;
300     std::move(write_callback_).Run();
301     return;
302   }
303   DVLOG(1) << "Wrote  " << num_bytes;
304   // If an error occurred before the completion callback could complete, ignore
305   // the result.
306   if (GetState() == CLOSED)
307     return;
308 
309   DCHECK_GE(num_bytes, 0u);
310   last_error_ = net::OK;
311   write_buffer_->DidConsume(num_bytes);
312   if (write_buffer_->BytesRemaining() > 0) {
313     DVLOG(1) << "Partial flush complete. Retrying.";
314     // Only a partial write was completed. Flush again to finish the write.
315     Flush(std::move(write_callback_));
316     return;
317   }
318   DVLOG(1) << "Socket flush complete.";
319   write_buffer_ = nullptr;
320   next_pos_ = 0;
321   std::move(write_callback_).Run();
322 }
323 
GetState() const324 SocketOutputStream::State SocketOutputStream::GetState() const{
325   if (last_error_ < net::ERR_IO_PENDING)
326     return CLOSED;
327 
328   if (last_error_ == net::ERR_IO_PENDING)
329     return FLUSHING;
330 
331   DCHECK_EQ(last_error_, net::OK);
332   if (next_pos_ == 0)
333     return EMPTY;
334 
335   return READY;
336 }
337 
last_error() const338 net::Error SocketOutputStream::last_error() const {
339   return last_error_;
340 }
341 
342 }  // namespace gcm
343