1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 use std::cmp;
19 use std::io;
20 use std::io::{Read, Write};
21 
22 use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory};
23 
/// Default capacity of the read buffer in bytes.
const READ_CAPACITY: usize = 4096;

/// Default capacity of the write buffer in bytes.
const WRITE_CAPACITY: usize = 4096;

/// Transport that reads bytes from a wrapped channel via an internal buffer.
///
/// A `TBufferedReadTransport` maintains a fixed-size internal read buffer.
/// Whenever that buffer is exhausted, a single `read` call is issued on the
/// wrapped channel to refill it (with at most `capacity` bytes). Calls to
/// `TBufferedReadTransport::read(...)` are serviced from the internal buffer,
/// refilling it as often as necessary until the caller's buffer is full or the
/// wrapped channel returns `0` bytes.
///
/// No framing is applied: bytes are passed through exactly as they arrive.
///
/// # Examples
///
/// Create and use a `TBufferedReadTransport`.
///
/// ```no_run
/// use std::io::Read;
/// use thrift::transport::{TBufferedReadTransport, TTcpChannel};
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut t = TBufferedReadTransport::new(c);
///
/// t.read(&mut vec![0u8; 1]).unwrap();
/// ```
#[derive(Debug)]
pub struct TBufferedReadTransport<C>
where
    C: Read,
{
    /// Internal read buffer; its length is the transport's capacity.
    buf: Box<[u8]>,
    /// Index of the next unread byte in `buf`.
    pos: usize,
    /// Number of valid bytes in `buf`; bytes `pos..cap` are still unread.
    cap: usize,
    /// The wrapped channel that bytes are read from.
    chan: C,
}

impl<C> TBufferedReadTransport<C>
where
    C: Read,
{
    /// Create a `TBufferedReadTransport` with a default-sized internal read
    /// buffer (`READ_CAPACITY` bytes) that wraps the given channel.
    pub fn new(channel: C) -> TBufferedReadTransport<C> {
        TBufferedReadTransport::with_capacity(READ_CAPACITY, channel)
    }

    /// Create a `TBufferedReadTransport` with an internal read buffer of
    /// `read_capacity` bytes that wraps the given channel.
    ///
    /// # Panics
    ///
    /// Panics if `read_capacity` is `0`. A zero-sized buffer can never hold
    /// any bytes, so every read would silently report end-of-stream.
    pub fn with_capacity(read_capacity: usize, channel: C) -> TBufferedReadTransport<C> {
        assert!(
            read_capacity > 0,
            "read buffer size must be a positive integer"
        );

        TBufferedReadTransport {
            buf: vec![0; read_capacity].into_boxed_slice(),
            pos: 0,
            cap: 0,
            chan: channel,
        }
    }

    /// Return the unread bytes in the internal buffer. If the buffer is
    /// exhausted, first refill it with a single read from the wrapped channel.
    ///
    /// An empty slice means the channel returned `0` bytes.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped channel. The buffer position is
    /// only reset after a *successful* channel read. A failed refill therefore
    /// leaves the buffer marked as exhausted, and bytes that were already
    /// consumed are never handed out again.
    fn get_bytes(&mut self) -> io::Result<&[u8]> {
        if self.pos == self.cap {
            let filled = self.chan.read(&mut self.buf)?;
            self.pos = 0;
            self.cap = filled;
        }

        Ok(&self.buf[self.pos..self.cap])
    }

    /// Mark `consumed` bytes of the internal buffer as read.
    ///
    /// The position is clamped so it never moves past the end of the valid
    /// data.
    fn consume(&mut self, consumed: usize) {
        self.pos = cmp::min(self.cap, self.pos + consumed);
    }
}

impl<C> Read for TBufferedReadTransport<C>
where
    C: Read,
{
    /// Read bytes into `buf`, refilling the internal buffer from the wrapped
    /// channel as many times as needed.
    ///
    /// Returns once `buf` is full or the wrapped channel returns `0` bytes.
    /// A zero-length `buf` returns `Ok(0)` without touching the channel.
    ///
    /// # Errors
    ///
    /// If the wrapped channel fails before any byte was copied into `buf`,
    /// the error is returned. If it fails after some bytes were already
    /// copied, those bytes are reported as a short read instead: they have
    /// already been removed from the internal buffer and would otherwise be
    /// lost. The next call retries the channel read.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Nothing requested: avoid a possibly-blocking channel read.
        if buf.is_empty() {
            return Ok(0);
        }

        let mut bytes_read = 0;

        while bytes_read < buf.len() {
            let nread = match self.get_bytes() {
                Ok(avail_bytes) => {
                    let count = cmp::min(buf.len() - bytes_read, avail_bytes.len());
                    buf[bytes_read..bytes_read + count].copy_from_slice(&avail_bytes[..count]);
                    count
                }
                // Don't discard bytes already delivered into `buf`.
                Err(_) if bytes_read > 0 => break,
                Err(e) => return Err(e),
            };

            if nread == 0 {
                // The wrapped channel had nothing more to give.
                break;
            }

            self.consume(nread);
            bytes_read += nread;
        }

        Ok(bytes_read)
    }
}
129 
130 /// Factory for creating instances of `TBufferedReadTransport`.
131 #[derive(Default)]
132 pub struct TBufferedReadTransportFactory;
133 
134 impl TBufferedReadTransportFactory {
new() -> TBufferedReadTransportFactory135     pub fn new() -> TBufferedReadTransportFactory {
136         TBufferedReadTransportFactory {}
137     }
138 }
139 
140 impl TReadTransportFactory for TBufferedReadTransportFactory {
141     /// Create a `TBufferedReadTransport`.
create(&self, channel: Box<dyn Read + Send>) -> Box<dyn TReadTransport + Send>142     fn create(&self, channel: Box<dyn Read + Send>) -> Box<dyn TReadTransport + Send> {
143         Box::new(TBufferedReadTransport::new(channel))
144     }
145 }
146 
/// Transport that writes bytes to a wrapped channel via an internal buffer.
///
/// A `TBufferedWriteTransport` maintains a fixed-size internal write buffer.
/// Writes are appended to this buffer. The buffered bytes are sent to the
/// wrapped channel when `TBufferedWriteTransport::flush()` is called, or
/// automatically when a write finds the buffer already full. A flush writes
/// the buffered bytes as-is, adding no header or framing, and then flushes the
/// wrapped channel.
///
/// # Examples
///
/// Create and use a `TBufferedWriteTransport`.
///
/// ```no_run
/// use std::io::Write;
/// use thrift::transport::{TBufferedWriteTransport, TTcpChannel};
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut t = TBufferedWriteTransport::new(c);
///
/// t.write(&[0x00]).unwrap();
/// t.flush().unwrap();
/// ```
#[derive(Debug)]
pub struct TBufferedWriteTransport<C>
where
    C: Write,
{
    /// Bytes written but not yet sent to `channel`; never longer than `cap`.
    buf: Vec<u8>,
    /// Maximum number of bytes held in `buf` before an automatic flush.
    cap: usize,
    /// The wrapped channel that buffered bytes are written to.
    channel: C,
}
180 
181 impl<C> TBufferedWriteTransport<C>
182 where
183     C: Write,
184 {
185     /// Create a `TBufferedTransport` with default-sized internal read and
186     /// write buffers that wraps the given `TIoChannel`.
new(channel: C) -> TBufferedWriteTransport<C>187     pub fn new(channel: C) -> TBufferedWriteTransport<C> {
188         TBufferedWriteTransport::with_capacity(WRITE_CAPACITY, channel)
189     }
190 
191     /// Create a `TBufferedTransport` with an internal read buffer of size
192     /// `read_capacity` and an internal write buffer of size
193     /// `write_capacity` that wraps the given `TIoChannel`.
with_capacity(write_capacity: usize, channel: C) -> TBufferedWriteTransport<C>194     pub fn with_capacity(write_capacity: usize, channel: C) -> TBufferedWriteTransport<C> {
195         assert!(
196             write_capacity > 0,
197             "write buffer size must be a positive integer"
198         );
199 
200         TBufferedWriteTransport {
201             buf: Vec::with_capacity(write_capacity),
202             cap: write_capacity,
203             channel,
204         }
205     }
206 }
207 
208 impl<C> Write for TBufferedWriteTransport<C>
209 where
210     C: Write,
211 {
write(&mut self, buf: &[u8]) -> io::Result<usize>212     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
213         if !buf.is_empty() {
214             let mut avail_bytes;
215 
216             loop {
217                 avail_bytes = cmp::min(buf.len(), self.cap - self.buf.len());
218 
219                 if avail_bytes == 0 {
220                     self.flush()?;
221                 } else {
222                     break;
223                 }
224             }
225 
226             let avail_bytes = avail_bytes;
227 
228             self.buf.extend_from_slice(&buf[..avail_bytes]);
229             assert!(self.buf.len() <= self.cap, "copy overflowed buffer");
230 
231             Ok(avail_bytes)
232         } else {
233             Ok(0)
234         }
235     }
236 
flush(&mut self) -> io::Result<()>237     fn flush(&mut self) -> io::Result<()> {
238         self.channel.write_all(&self.buf)?;
239         self.channel.flush()?;
240         self.buf.clear();
241         Ok(())
242     }
243 }
244 
245 /// Factory for creating instances of `TBufferedWriteTransport`.
246 #[derive(Default)]
247 pub struct TBufferedWriteTransportFactory;
248 
249 impl TBufferedWriteTransportFactory {
new() -> TBufferedWriteTransportFactory250     pub fn new() -> TBufferedWriteTransportFactory {
251         TBufferedWriteTransportFactory {}
252     }
253 }
254 
255 impl TWriteTransportFactory for TBufferedWriteTransportFactory {
256     /// Create a `TBufferedWriteTransport`.
create(&self, channel: Box<dyn Write + Send>) -> Box<dyn TWriteTransport + Send>257     fn create(&self, channel: Box<dyn Write + Send>) -> Box<dyn TWriteTransport + Send> {
258         Box::new(TBufferedWriteTransport::new(channel))
259     }
260 }
261 
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};

    use super::*;
    use crate::transport::TBufferChannel;

    #[test]
    fn must_return_zero_if_read_buffer_is_empty() {
        let mem = TBufferChannel::with_capacity(10, 0);
        let mut t = TBufferedReadTransport::with_capacity(10, mem);

        let mut b = vec![0; 10];
        let read_result = t.read(&mut b);

        assert_eq!(read_result.unwrap(), 0);
    }

    #[test]
    fn must_return_zero_if_caller_reads_into_zero_capacity_buffer() {
        let mem = TBufferChannel::with_capacity(10, 0);
        let mut t = TBufferedReadTransport::with_capacity(10, mem);

        let read_result = t.read(&mut []);

        assert_eq!(read_result.unwrap(), 0);
    }

    #[test]
    fn must_return_zero_if_nothing_more_can_be_read() {
        let mem = TBufferChannel::with_capacity(4, 0);
        let mut t = TBufferedReadTransport::with_capacity(4, mem);

        t.chan.set_readable_bytes(&[0, 1, 2, 3]);

        // read buffer is exactly the same size as bytes available
        let mut buf = vec![0u8; 4];
        let read_result = t.read(&mut buf);

        // we've read exactly 4 bytes
        assert_eq!(read_result.unwrap(), 4);
        assert_eq!(&buf, &[0, 1, 2, 3]);

        // try read again, this time into a fresh buffer
        let mut buf_again = vec![0u8; 4];
        let read_result = t.read(&mut buf_again);

        // this time, 0 bytes and we haven't changed the buffer
        assert_eq!(read_result.unwrap(), 0);
        assert_eq!(&buf_again, &[0, 0, 0, 0])
    }

    #[test]
    fn must_fill_user_buffer_with_only_as_many_bytes_as_available() {
        let mem = TBufferChannel::with_capacity(4, 0);
        let mut t = TBufferedReadTransport::with_capacity(4, mem);

        t.chan.set_readable_bytes(&[0, 1, 2, 3]);

        // read buffer is much larger than the bytes available
        let mut buf = vec![0u8; 8];
        let read_result = t.read(&mut buf);

        // we've read exactly 4 bytes
        assert_eq!(read_result.unwrap(), 4);
        assert_eq!(&buf[..4], &[0, 1, 2, 3]);

        // try read again
        let read_result = t.read(&mut buf[4..]);

        // this time, 0 bytes and we haven't changed the buffer
        assert_eq!(read_result.unwrap(), 0);
        assert_eq!(&buf, &[0, 1, 2, 3, 0, 0, 0, 0])
    }

    #[test]
    fn must_read_successfully() {
        // this test involves a few loops within the buffered transport
        // itself where it has to drain the underlying transport in order
        // to service a read

        // we have a much smaller buffer than the
        // underlying transport has bytes available
        let mem = TBufferChannel::with_capacity(10, 0);
        let mut t = TBufferedReadTransport::with_capacity(2, mem);

        // fill the underlying transport's byte buffer
        let mut readable_bytes = [0u8; 10];
        for (i, b) in readable_bytes.iter_mut().enumerate() {
            *b = i as u8;
        }

        t.chan.set_readable_bytes(&readable_bytes);

        // we ask to read into a buffer that's much larger
        // than the one the buffered transport has; as a result
        // it's going to have to keep asking the underlying
        // transport for more bytes
        let mut buf = [0u8; 8];
        let read_result = t.read(&mut buf);

        // we should have read 8 bytes
        assert_eq!(read_result.unwrap(), 8);
        assert_eq!(&buf, &[0, 1, 2, 3, 4, 5, 6, 7]);

        // let's clear out the buffer and try read again
        for b in &mut buf {
            *b = 0;
        }
        let read_result = t.read(&mut buf);

        // this time we were only able to read 2 bytes
        // (all that's remaining from the underlying transport)
        // let's also check that the remaining bytes are untouched
        assert_eq!(read_result.unwrap(), 2);
        assert_eq!(&buf[0..2], &[8, 9]);
        assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]);

        // try read again (we should get 0)
        // and all the existing bytes were untouched
        let read_result = t.read(&mut buf);
        assert_eq!(read_result.unwrap(), 0);
        assert_eq!(&buf[0..2], &[8, 9]);
        assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn must_return_error_when_nothing_can_be_written_to_underlying_channel() {
        let mem = TBufferChannel::with_capacity(0, 0);
        let mut t = TBufferedWriteTransport::with_capacity(1, mem);

        let b = vec![0; 10];
        let r = t.write(&b);

        // should have written 1 byte
        assert_eq!(r.unwrap(), 1);

        // let's try again...
        let r = t.write(&b[1..]);

        // this time we'll error out because the auto-flush failed
        assert!(r.is_err());
    }

    #[test]
    fn must_return_zero_if_caller_calls_write_with_empty_buffer() {
        let mem = TBufferChannel::with_capacity(0, 10);
        let mut t = TBufferedWriteTransport::with_capacity(10, mem);

        let r = t.write(&[]);
        let expected: [u8; 0] = [];

        assert_eq!(r.unwrap(), 0);
        assert_eq_transport_written_bytes!(t, expected);
    }

    #[test]
    fn must_auto_flush_if_write_buffer_full() {
        let mem = TBufferChannel::with_capacity(0, 8);
        let mut t = TBufferedWriteTransport::with_capacity(4, mem);

        let b0 = [0x00, 0x01, 0x02, 0x03];
        let b1 = [0x04, 0x05, 0x06, 0x07];

        // write the first 4 bytes; we've now filled the transport's write buffer
        let r = t.write(&b0);
        assert_eq!(r.unwrap(), 4);

        // try write the next 4 bytes; this causes the transport to auto-flush the first 4 bytes
        let r = t.write(&b1);
        assert_eq!(r.unwrap(), 4);

        // check that in writing the second 4 bytes we auto-flushed the first 4 bytes
        assert_eq_transport_num_written_bytes!(t, 4);
        assert_eq_transport_written_bytes!(t, b0);
        t.channel.empty_write_buffer();

        // now flush the transport to push the second 4 bytes to the underlying channel
        assert!(t.flush().is_ok());

        // check that we wrote out the second 4 bytes
        assert_eq_transport_written_bytes!(t, b1);
    }

    #[test]
    fn must_write_to_inner_transport_on_flush() {
        let mem = TBufferChannel::with_capacity(10, 10);
        let mut t = TBufferedWriteTransport::new(mem);

        let b: [u8; 5] = [0, 1, 2, 3, 4];
        assert_eq!(t.write(&b).unwrap(), 5);
        assert_eq_transport_num_written_bytes!(t, 0);

        assert!(t.flush().is_ok());

        assert_eq_transport_written_bytes!(t, b);
    }

    #[test]
    fn must_write_successfully_after_flush() {
        let mem = TBufferChannel::with_capacity(0, 5);
        let mut t = TBufferedWriteTransport::with_capacity(5, mem);

        // write and flush
        let b: [u8; 5] = [0, 1, 2, 3, 4];
        assert_eq!(t.write(&b).unwrap(), 5);
        assert!(t.flush().is_ok());

        // check the flushed bytes
        assert_eq_transport_written_bytes!(t, b);

        // reset our underlying transport
        t.channel.empty_write_buffer();

        // write and flush again
        assert_eq!(t.write(&b).unwrap(), 5);
        assert!(t.flush().is_ok());

        // check the flushed bytes
        assert_eq_transport_written_bytes!(t, b);
    }
}
484