1 use super::EncoderWriter;
2 use crate::tests::random_config;
3 use crate::{encode_config, encode_config_buf, STANDARD_NO_PAD, URL_SAFE};
4 
5 use std::io::{Cursor, Write};
6 use std::{cmp, io, str};
7 
8 use rand::Rng;
9 
#[test]
fn encode_three_bytes() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);
    assert_eq!(writer.write(b"abc").unwrap(), 3);
    // Dropping the encoder finishes it, so everything it produced lands in `cursor`.
    drop(writer);

    let expected = encode_config("abc", URL_SAFE);
    assert_eq!(&cursor.get_ref()[..], expected.as_bytes());
}
21 
#[test]
fn encode_nine_bytes_two_writes() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);

    // Each call reports the full length of its input as consumed.
    for &(chunk, expected_len) in &[(&b"abcdef"[..], 6), (&b"ghi"[..], 3)] {
        assert_eq!(writer.write(chunk).unwrap(), expected_len);
    }
    drop(writer);

    let expected = encode_config("abcdefghi", URL_SAFE);
    assert_eq!(&cursor.get_ref()[..], expected.as_bytes());
}
38 
#[test]
fn encode_one_then_two_bytes() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);

    let first = writer.write(b"a").unwrap();
    assert_eq!(first, 1);
    let second = writer.write(b"bc").unwrap();
    assert_eq!(second, 2);

    // Finish the encoder (via Drop) before inspecting what reached the cursor.
    drop(writer);
    assert_eq!(&cursor.get_ref()[..], encode_config("abc", URL_SAFE).as_bytes());
}
52 
#[test]
fn encode_one_then_five_bytes() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);

    for &(chunk, expected_len) in &[(&b"a"[..], 1), (&b"bcdef"[..], 5)] {
        assert_eq!(writer.write(chunk).unwrap(), expected_len);
    }
    drop(writer);

    let expected = encode_config("abcdef", URL_SAFE);
    assert_eq!(&cursor.get_ref()[..], expected.as_bytes());
}
69 
#[test]
fn encode_1_2_3_bytes() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);

    // Writes of 1, 2 and 3 bytes, each fully accepted.
    let pieces = [(&b"a"[..], 1), (&b"bc"[..], 2), (&b"def"[..], 3)];
    for &(chunk, expected_len) in &pieces {
        assert_eq!(writer.write(chunk).unwrap(), expected_len);
    }
    drop(writer);

    let expected = encode_config("abcdef", URL_SAFE);
    assert_eq!(&cursor.get_ref()[..], expected.as_bytes());
}
88 
#[test]
fn encode_with_padding() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);
    writer.write_all(b"abcd").unwrap();
    writer.flush().unwrap();
    drop(writer);

    let expected = encode_config("abcd", URL_SAFE);
    assert_eq!(&cursor.get_ref()[..], expected.as_bytes());
}
101 
#[test]
fn encode_with_padding_multiple_writes() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, URL_SAFE);

    // 1 + 2 + 3 + 1 = 7 bytes, so the final chunk is a padded partial one.
    let pieces = [(&b"a"[..], 1), (&b"bc"[..], 2), (&b"def"[..], 3), (&b"g"[..], 1)];
    for &(chunk, expected_len) in &pieces {
        assert_eq!(expected_len, writer.write(chunk).unwrap());
    }
    writer.flush().unwrap();
    drop(writer);

    let expected = encode_config("abcdefg", URL_SAFE);
    assert_eq!(&cursor.get_ref()[..], expected.as_bytes());
}
120 
#[test]
fn finish_writes_extra_byte() {
    let mut c = Cursor::new(Vec::new());
    {
        let mut enc = EncoderWriter::new(&mut c, URL_SAFE);

        assert_eq!(6, enc.write(b"abcdef").unwrap());

        // will be in extra
        assert_eq!(1, enc.write(b"g").unwrap());

        // 1 trailing byte = 2 encoded chars, plus 2 padding chars because URL_SAFE pads
        let _ = enc.finish().unwrap();
    }
    assert_eq!(
        &c.get_ref()[..],
        encode_config("abcdefg", URL_SAFE).as_bytes()
    );
    // 2 complete chunks (8 chars) + the padded trailing chunk (4 chars)
    assert_eq!(12, c.get_ref().len());
}
140 
#[test]
fn write_partial_chunk_encodes_partial_chunk() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);

    // Fewer than 3 bytes: accepted, but nothing is encoded yet...
    assert_eq!(2, writer.write(b"ab").unwrap());
    // ...the partial chunk is only encoded when the encoder is finished.
    let _ = writer.finish().unwrap();
    drop(writer);

    let output = &cursor.get_ref()[..];
    assert_eq!(output, encode_config("ab", STANDARD_NO_PAD).as_bytes());
    assert_eq!(3, output.len());
}
158 
#[test]
fn write_1_chunk_encodes_complete_chunk() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);
    assert_eq!(3, writer.write(b"abc").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded: &[u8] = cursor.get_ref();
    assert_eq!(encoded, encode_config("abc", STANDARD_NO_PAD).as_bytes());
    assert_eq!(4, encoded.len());
}
174 
#[test]
fn write_1_chunk_and_partial_encodes_only_complete_chunk() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);

    // Only the complete chunk "abc" is consumed; the trailing "d" is not.
    let consumed = writer.write(b"abcd").unwrap();
    assert_eq!(3, consumed);
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded: &[u8] = cursor.get_ref();
    assert_eq!(encoded, encode_config("abc", STANDARD_NO_PAD).as_bytes());
    assert_eq!(4, encoded.len());
}
191 
#[test]
fn write_2_partials_to_exactly_complete_chunk_encodes_complete_chunk() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);

    // "a" then "bc": two partial writes that together make exactly one chunk.
    for &(chunk, expected_len) in &[(&b"a"[..], 1), (&b"bc"[..], 2)] {
        assert_eq!(expected_len, writer.write(chunk).unwrap());
    }
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded = &cursor.get_ref()[..];
    assert_eq!(encoded, encode_config("abc", STANDARD_NO_PAD).as_bytes());
    assert_eq!(4, encoded.len());
}
208 
#[test]
fn write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining(
) {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);

    assert_eq!(1, writer.write(b"a").unwrap());
    // Only "bc" is needed to complete the pending chunk; the trailing "d" is left unconsumed.
    assert_eq!(2, writer.write(b"bcd").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let written = &cursor.get_ref()[..];
    assert_eq!(written, encode_config("abc", STANDARD_NO_PAD).as_bytes());
    assert_eq!(4, written.len());
}
227 
#[test]
fn write_partial_then_enough_to_complete_chunk_and_another_chunk_encodes_complete_chunks() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);

    let first = writer.write(b"a").unwrap();
    assert_eq!(1, first);
    // "bc" completes the pending chunk and "def" forms a second full chunk.
    let second = writer.write(b"bcdef").unwrap();
    assert_eq!(5, second);
    let _ = writer.finish().unwrap();
    drop(writer);

    let written = &cursor.get_ref()[..];
    assert_eq!(written, encode_config("abcdef", STANDARD_NO_PAD).as_bytes());
    assert_eq!(8, written.len());
}
245 
#[test]
fn write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks(
) {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);

    assert_eq!(1, writer.write(b"a").unwrap());
    // "bc" completes the pending chunk, "def" is a second full chunk, and the trailing "e"
    // (an incomplete chunk) is not consumed.
    assert_eq!(5, writer.write(b"bcdefe").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let written = &cursor.get_ref()[..];
    assert_eq!(written, encode_config("abcdef", STANDARD_NO_PAD).as_bytes());
    assert_eq!(8, written.len());
}
265 
#[test]
fn drop_calls_finish_for_you() {
    let mut cursor = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut cursor, STANDARD_NO_PAD);
    assert_eq!(1, writer.write(b"a").unwrap());
    // No explicit `finish()`: dropping the encoder must emit the buffered byte.
    drop(writer);

    let written = &cursor.get_ref()[..];
    assert_eq!(written, encode_config("a", STANDARD_NO_PAD).as_bytes());
    assert_eq!(2, written.len());
}
279 
#[test]
fn every_possible_split_of_input() {
    let mut rng = rand::thread_rng();
    let size = 5_000;

    // Scratch buffers reused across iterations.
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    for boundary in 0..size {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        orig_data.extend((0..size).map(|_| rng.gen::<u8>()));

        let config = random_config(&mut rng);
        encode_config_buf(&orig_data, config, &mut normal_encoded);

        // Feed everything before `boundary` in one write, then the remainder in another.
        let (head, tail) = orig_data.split_at(boundary);
        let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, config);
        stream_encoder.write_all(head).unwrap();
        stream_encoder.write_all(tail).unwrap();
        // Dropping the encoder finishes it so all output lands in `stream_encoded`.
        drop(stream_encoder);

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
311 
#[test]
fn encode_random_config_matches_normal_encode_reasonable_input_len() {
    // Chunk lengths are drawn below twice the encoder's buffer size, so roughly half of the
    // writes are large enough to fill that buffer.
    let max_input_len = super::encoder::BUF_SIZE * 2;
    do_encode_random_config_matches_normal_encode(max_input_len);
}
317 
#[test]
fn encode_random_config_matches_normal_encode_tiny_input_len() {
    // Very short chunks: every write is fewer than 10 bytes.
    let max_input_len = 10;
    do_encode_random_config_matches_normal_encode(max_input_len);
}
322 
#[test]
fn retrying_writes_that_error_with_interrupted_works() {
    let mut rng = rand::thread_rng();
    // Scratch buffers, cleared and refilled on every round.
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    for _ in 0..1_000 {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        let orig_len: usize = rng.gen_range(100, 20_000);
        orig_data.extend((0..orig_len).map(|_| rng.gen::<u8>()));

        // Reference output from the one-shot encoder.
        let config = random_config(&mut rng);
        encode_config_buf(&orig_data, config, &mut normal_encoded);

        // Same input through the streaming encoder, whose delegate fails with
        // `Interrupted` on a large fraction of calls.
        {
            let mut interrupt_rng = rand::thread_rng();
            let mut interrupting_writer = InterruptingWriter {
                w: &mut stream_encoded,
                rng: &mut interrupt_rng,
                fraction: 0.8,
            };
            let mut stream_encoder = EncoderWriter::new(&mut interrupting_writer, config);

            let mut bytes_consumed = 0;
            while bytes_consumed < orig_len {
                // Short inputs keep `extra` in heavy use, and `extra` is the state that has
                // to be rolled back when the delegate errors.
                let remaining = orig_len - bytes_consumed;
                let input_len: usize = cmp::min(rng.gen_range(0, 10), remaining);
                let chunk = &orig_data[bytes_consumed..bytes_consumed + input_len];

                retry_interrupted_write_all(&mut stream_encoder, chunk).unwrap();
                bytes_consumed += input_len;
            }

            // `finish` can be interrupted too: retry on `Interrupted`, bail on anything else.
            loop {
                match stream_encoder.finish() {
                    Ok(_) => break,
                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
                    Err(e) => Err(e).unwrap(),
                }
            }

            assert_eq!(orig_len, bytes_consumed);
        }

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
386 
#[test]
fn writes_that_only_write_part_of_input_and_sometimes_interrupt_produce_correct_encoded_data() {
    let mut rng = rand::thread_rng();
    // Scratch buffers, cleared and refilled on every round.
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    for _ in 0..1_000 {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        let orig_len: usize = rng.gen_range(100, 20_000);
        orig_data.extend((0..orig_len).map(|_| rng.gen::<u8>()));

        // Reference output from the one-shot encoder.
        let config = random_config(&mut rng);
        encode_config_buf(&orig_data, config, &mut normal_encoded);

        // Same input through the streaming encoder, on top of a delegate that usually
        // interrupts and usually writes only part of what it is given.
        {
            let mut partial_rng = rand::thread_rng();
            let mut partial_writer = PartialInterruptingWriter {
                w: &mut stream_encoded,
                rng: &mut partial_rng,
                full_input_fraction: 0.1,
                no_interrupt_fraction: 0.1,
            };
            let mut stream_encoder = EncoderWriter::new(&mut partial_writer, config);

            let mut bytes_consumed = 0;
            while bytes_consumed < orig_len {
                // At most medium-length inputs, to drive the retry logic hard.
                let remaining = orig_len - bytes_consumed;
                let input_len: usize = cmp::min(rng.gen_range(0, 100), remaining);
                let chunk = &orig_data[bytes_consumed..bytes_consumed + input_len];

                match stream_encoder.write(chunk) {
                    Ok(len) => bytes_consumed += len,
                    // Interrupted: nothing was consumed, so go around again.
                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
                    Err(_) => panic!("should not see other errors"),
                }
            }

            stream_encoder.finish().unwrap();

            assert_eq!(orig_len, bytes_consumed);
        }

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
447 
/// Calls `write` repeatedly until all of `buf` has been accepted.
///
/// `Interrupted` errors are swallowed and the write is retried. Any other error is returned
/// immediately. Empty input returns `Ok(())` without calling `write` at all.
fn retry_interrupted_write_all<W: Write>(w: &mut W, buf: &[u8]) -> io::Result<()> {
    let mut written = 0;
    while written < buf.len() {
        match w.write(&buf[written..]) {
            Ok(n) => written += n,
            // Interrupted writes consumed nothing; simply try again.
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
466 
do_encode_random_config_matches_normal_encode(max_input_len: usize)467 fn do_encode_random_config_matches_normal_encode(max_input_len: usize) {
468     let mut rng = rand::thread_rng();
469     let mut orig_data = Vec::<u8>::new();
470     let mut stream_encoded = Vec::<u8>::new();
471     let mut normal_encoded = String::new();
472 
473     for _ in 0..1_000 {
474         orig_data.clear();
475         stream_encoded.clear();
476         normal_encoded.clear();
477 
478         let orig_len: usize = rng.gen_range(100, 20_000);
479         for _ in 0..orig_len {
480             orig_data.push(rng.gen());
481         }
482 
483         // encode the normal way
484         let config = random_config(&mut rng);
485         encode_config_buf(&orig_data, config, &mut normal_encoded);
486 
487         // encode via the stream encoder
488         {
489             let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, config);
490             let mut bytes_consumed = 0;
491             while bytes_consumed < orig_len {
492                 let input_len: usize =
493                     cmp::min(rng.gen_range(0, max_input_len), orig_len - bytes_consumed);
494 
495                 // write a little bit of the data
496                 stream_encoder
497                     .write_all(&orig_data[bytes_consumed..bytes_consumed + input_len])
498                     .unwrap();
499 
500                 bytes_consumed += input_len;
501             }
502 
503             stream_encoder.finish().unwrap();
504 
505             assert_eq!(orig_len, bytes_consumed);
506         }
507 
508         assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
509     }
510 }
511 
512 /// A `Write` implementation that returns Interrupted some fraction of the time, randomly.
513 struct InterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {
514     w: &'a mut W,
515     rng: &'a mut R,
516     /// In [0, 1]. If a random number in [0, 1] is  `<= threshold`, `Write` methods will return
517     /// an `Interrupted` error
518     fraction: f64,
519 }
520 
521 impl<'a, W: Write, R: Rng> Write for InterruptingWriter<'a, W, R> {
write(&mut self, buf: &[u8]) -> io::Result<usize>522     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
523         if self.rng.gen_range(0.0, 1.0) <= self.fraction {
524             return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
525         }
526 
527         self.w.write(buf)
528     }
529 
flush(&mut self) -> io::Result<()>530     fn flush(&mut self) -> io::Result<()> {
531         if self.rng.gen_range(0.0, 1.0) <= self.fraction {
532             return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
533         }
534 
535         self.w.flush()
536     }
537 }
538 
539 /// A `Write` implementation that sometimes will only write part of its input.
540 struct PartialInterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {
541     w: &'a mut W,
542     rng: &'a mut R,
543     /// In [0, 1]. If a random number in [0, 1] is  `<= threshold`, `write()` will write all its
544     /// input. Otherwise, it will write a random substring
545     full_input_fraction: f64,
546     no_interrupt_fraction: f64,
547 }
548 
549 impl<'a, W: Write, R: Rng> Write for PartialInterruptingWriter<'a, W, R> {
write(&mut self, buf: &[u8]) -> io::Result<usize>550     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
551         if self.rng.gen_range(0.0, 1.0) > self.no_interrupt_fraction {
552             return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
553         }
554 
555         if self.rng.gen_range(0.0, 1.0) <= self.full_input_fraction || buf.len() == 0 {
556             // pass through the buf untouched
557             self.w.write(buf)
558         } else {
559             // only use a prefix of it
560             self.w
561                 .write(&buf[0..(self.rng.gen_range(0, buf.len() - 1))])
562         }
563     }
564 
flush(&mut self) -> io::Result<()>565     fn flush(&mut self) -> io::Result<()> {
566         self.w.flush()
567     }
568 }
569