1// Copyright 2010 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package io_test 6 7import ( 8 "bytes" 9 "crypto/sha1" 10 "errors" 11 "fmt" 12 . "io" 13 "io/ioutil" 14 "runtime" 15 "strings" 16 "testing" 17 "time" 18) 19 20func TestMultiReader(t *testing.T) { 21 var mr Reader 22 var buf []byte 23 nread := 0 24 withFooBar := func(tests func()) { 25 r1 := strings.NewReader("foo ") 26 r2 := strings.NewReader("") 27 r3 := strings.NewReader("bar") 28 mr = MultiReader(r1, r2, r3) 29 buf = make([]byte, 20) 30 tests() 31 } 32 expectRead := func(size int, expected string, eerr error) { 33 nread++ 34 n, gerr := mr.Read(buf[0:size]) 35 if n != len(expected) { 36 t.Errorf("#%d, expected %d bytes; got %d", 37 nread, len(expected), n) 38 } 39 got := string(buf[0:n]) 40 if got != expected { 41 t.Errorf("#%d, expected %q; got %q", 42 nread, expected, got) 43 } 44 if gerr != eerr { 45 t.Errorf("#%d, expected error %v; got %v", 46 nread, eerr, gerr) 47 } 48 buf = buf[n:] 49 } 50 withFooBar(func() { 51 expectRead(2, "fo", nil) 52 expectRead(5, "o ", nil) 53 expectRead(5, "bar", nil) 54 expectRead(5, "", EOF) 55 }) 56 withFooBar(func() { 57 expectRead(4, "foo ", nil) 58 expectRead(1, "b", nil) 59 expectRead(3, "ar", nil) 60 expectRead(1, "", EOF) 61 }) 62 withFooBar(func() { 63 expectRead(5, "foo ", nil) 64 }) 65} 66 67func TestMultiWriter(t *testing.T) { 68 sink := new(bytes.Buffer) 69 // Hide bytes.Buffer's WriteString method: 70 testMultiWriter(t, struct { 71 Writer 72 fmt.Stringer 73 }{sink, sink}) 74} 75 76func TestMultiWriter_String(t *testing.T) { 77 testMultiWriter(t, new(bytes.Buffer)) 78} 79 80// Test that a multiWriter.WriteString calls results in at most 1 allocation, 81// even if multiple targets don't support WriteString. 
func TestMultiWriter_WriteStringSingleAlloc(t *testing.T) {
	t.Skip("skipping on gccgo until we have escape analysis")
	var sink1, sink2 bytes.Buffer
	type simpleWriter struct { // hide bytes.Buffer's WriteString
		Writer
	}
	mw := MultiWriter(simpleWriter{&sink1}, simpleWriter{&sink2})
	allocs := int(testing.AllocsPerRun(1000, func() {
		WriteString(mw, "foo")
	}))
	if allocs != 1 {
		t.Errorf("num allocations = %d; want 1", allocs)
	}
}

// writeStringChecker is a Writer that records whether its WriteString
// method has been called.
type writeStringChecker struct{ called bool }

// WriteString sets c.called and reports that all of s was written.
func (c *writeStringChecker) WriteString(s string) (n int, err error) {
	c.called = true
	return len(s), nil
}

// Write reports that all of p was written; it does not touch c.called.
func (c *writeStringChecker) Write(p []byte) (n int, err error) {
	return len(p), nil
}

// TestMultiWriter_StringCheckCall checks that WriteString on a MultiWriter
// reaches the WriteString method of the underlying writer.
func TestMultiWriter_StringCheckCall(t *testing.T) {
	var c writeStringChecker
	mw := MultiWriter(&c)
	WriteString(mw, "foo")
	if !c.called {
		t.Error("did not see WriteString call to writeStringChecker")
	}
}

// testMultiWriter copies a fixed string through a MultiWriter that fans out
// to a SHA-1 hash and to sink, then checks the number of bytes copied, the
// resulting digest and the contents reported by sink.String().
func testMultiWriter(t *testing.T, sink interface {
	Writer
	fmt.Stringer
}) {
	t.Helper()

	// Named h rather than sha1 so the crypto/sha1 package is not shadowed.
	h := sha1.New()
	mw := MultiWriter(h, sink)

	sourceString := "My input text."
	source := strings.NewReader(sourceString)
	written, err := Copy(mw, source)

	if written != int64(len(sourceString)) {
		t.Errorf("short write of %d, not %d", written, len(sourceString))
	}

	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	sha1hex := fmt.Sprintf("%x", h.Sum(nil))
	if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
		t.Error("incorrect sha1 value")
	}

	if sink.String() != sourceString {
		t.Errorf("expected %q; got %q", sourceString, sink.String())
	}
}

// writerFunc is an io.Writer implemented by the underlying func.
type writerFunc func(p []byte) (int, error)

// Write calls f(p).
func (f writerFunc) Write(p []byte) (int, error) {
	return f(p)
}

// Test that MultiWriter properly flattens chained multiWriters.
func TestMultiWriterSingleChainFlatten(t *testing.T) {
	pc := make([]uintptr, 1000) // 1000 should fit the full stack
	n := runtime.Callers(0, pc)
	var myDepth = callDepth(pc[:n])
	var writeDepth int // accumulates the depth from which writerFunc.Write was called
	var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) {
		n := runtime.Callers(1, pc)
		writeDepth += callDepth(pc[:n])
		return 0, nil
	}))

	mw := w
	// chain a bunch of multiWriters; each iteration wraps the previous
	// result so that a real chain is built.
	for i := 0; i < 100; i++ {
		mw = MultiWriter(mw)
	}

	mw = MultiWriter(w, mw, w, mw)
	mw.Write(nil) // don't care about errors, just want to check the call-depth for Write

	// The underlying writerFunc is expected to run four times, once per
	// argument above, each time at the same depth.
	if writeDepth != 4*(myDepth+2) { // 2 should be multiWriter.Write and writerFunc.Write
		t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d",
			4*(myDepth+2), writeDepth)
	}
}

// Test that MultiWriter stops at the first writer that returns an error and
// reports that writer's byte count and error.
func TestMultiWriterError(t *testing.T) {
	f1 := writerFunc(func(p []byte) (int, error) {
		return len(p) / 2, ErrShortWrite
	})
	f2 := writerFunc(func(p []byte) (int, error) {
		t.Errorf("MultiWriter called f2.Write")
		return len(p), nil
	})
	w := MultiWriter(f1, f2)
	n, err := w.Write(make([]byte, 100))
	if n != 50 || err != ErrShortWrite {
		t.Errorf("Write = %d, %v, want 50, ErrShortWrite", n, err)
	}
}

// Test that MultiReader copies the input slice and is insulated from future modification.
func TestMultiReaderCopy(t *testing.T) {
	slice := []Reader{strings.NewReader("hello world")}
	r := MultiReader(slice...)
	slice[0] = nil
	data, err := ioutil.ReadAll(r)
	if err != nil || string(data) != "hello world" {
		t.Errorf("ReadAll() = %q, %v, want %q, nil", data, err, "hello world")
	}
}

// Test that MultiWriter copies the input slice and is insulated from future modification.
func TestMultiWriterCopy(t *testing.T) {
	var buf bytes.Buffer
	slice := []Writer{&buf}
	w := MultiWriter(slice...)
	slice[0] = nil
	n, err := w.Write([]byte("hello world"))
	if err != nil || n != 11 {
		t.Errorf("Write(`hello world`) = %d, %v, want 11, nil", n, err)
	}
	if buf.String() != "hello world" {
		t.Errorf("buf.String() = %q, want %q", buf.String(), "hello world")
	}
}

// readerFunc is an io.Reader implemented by the underlying func.
type readerFunc func(p []byte) (int, error)

// Read calls f(p).
func (f readerFunc) Read(p []byte) (int, error) {
	return f(p)
}

// callDepth returns the logical call depth for the given PCs.
// It counts one level per call to frames.Next, including the final call
// that reports no further frames.
func callDepth(callers []uintptr) (depth int) {
	frames := runtime.CallersFrames(callers)
	more := true
	for more {
		_, more = frames.Next()
		depth++
	}
	return
}

// Test that MultiReader properly flattens chained multiReaders when Read is called
func TestMultiReaderFlatten(t *testing.T) {
	pc := make([]uintptr, 1000) // 1000 should fit the full stack
	n := runtime.Callers(0, pc)
	var myDepth = callDepth(pc[:n])
	var readDepth int // will contain the depth from which readerFunc.Read was called
	var r Reader = MultiReader(readerFunc(func(p []byte) (int, error) {
		n := runtime.Callers(1, pc)
		readDepth = callDepth(pc[:n])
		return 0, errors.New("irrelevant")
	}))

	// chain a bunch of multiReaders
	for i := 0; i < 100; i++ {
		r = MultiReader(r)
	}

	r.Read(nil) // don't care about errors, just want to check the call-depth for Read

	if readDepth != myDepth+2 { // 2 should be multiReader.Read and readerFunc.Read
		t.Errorf("multiReader did not flatten chained multiReaders: expected readDepth %d, got %d",
			myDepth+2, readDepth)
	}
}

// byteAndEOFReader is a Reader which reads one byte (the underlying
// byte) and io.EOF at once in its Read call.
type byteAndEOFReader byte

// Read stores the underlying byte in p[0] and returns 1, EOF.
// It panics when p is empty.
func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		// Read(0 bytes) is useless. We expect no such useless
		// calls in this test.
		panic("unexpected call")
	}
	p[0] = byte(b)
	return 1, EOF
}

// This used to yield bytes forever; issue 16795.
func TestMultiReaderSingleByteWithEOF(t *testing.T) {
	got, err := ioutil.ReadAll(LimitReader(MultiReader(byteAndEOFReader('a'), byteAndEOFReader('b')), 10))
	if err != nil {
		t.Fatal(err)
	}
	const want = "ab"
	if string(got) != want {
		t.Errorf("got %q; want %q", got, want)
	}
}

// Test that a reader returning (n, EOF) at the end of a MultiReader
// chain continues to return EOF on its final read, rather than
// yielding a (0, EOF).
func TestMultiReaderFinalEOF(t *testing.T) {
	r := MultiReader(bytes.NewReader(nil), byteAndEOFReader('a'))
	buf := make([]byte, 2)
	n, err := r.Read(buf)
	if n != 1 || err != EOF {
		t.Errorf("got %v, %v; want 1, EOF", n, err)
	}
}

// Test that a MultiReader drops its reference to a source once that source
// is exhausted, observed through a finalizer set on the first source.
func TestMultiReaderFreesExhaustedReaders(t *testing.T) {
	if runtime.Compiler == "gccgo" {
		t.Skip("skipping finalizer test on gccgo with conservative GC")
	}

	var mr Reader
	closed := make(chan struct{})
	// The closure ensures that we don't have a live reference to buf1
	// on our stack after MultiReader is inlined (Issue 18819). This
	// is a work around for a limitation in liveness analysis.
	func() {
		buf1 := bytes.NewReader([]byte("foo"))
		buf2 := bytes.NewReader([]byte("bar"))
		mr = MultiReader(buf1, buf2)
		runtime.SetFinalizer(buf1, func(*bytes.Reader) {
			close(closed)
		})
	}()

	// Read all of buf1 plus the first byte of buf2.
	buf := make([]byte, 4)
	if n, err := ReadFull(mr, buf); err != nil || string(buf) != "foob" {
		t.Fatalf(`ReadFull = %d (%q), %v; want 4, "foob", nil`, n, buf[:n], err)
	}

	runtime.GC()
	select {
	case <-closed:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for collection of buf1")
	}

	if n, err := ReadFull(mr, buf[:2]); err != nil || string(buf[:2]) != "ar" {
		t.Fatalf(`ReadFull = %d (%q), %v; want 2, "ar", nil`, n, buf[:n], err)
	}
}

// Test that reading through an outer MultiReader and then through the inner
// MultiReader it wraps yields the remaining data without panicking.
func TestInterleavedMultiReader(t *testing.T) {
	r1 := strings.NewReader("123")
	r2 := strings.NewReader("45678")

	mr1 := MultiReader(r1, r2)
	mr2 := MultiReader(mr1)

	buf := make([]byte, 4)

	// Have mr2 use mr1's []Readers.
	// Consume r1 (and clear it for GC to handle) and consume part of r2.
	n, err := ReadFull(mr2, buf)
	if got := string(buf[:n]); got != "1234" || err != nil {
		t.Errorf(`ReadFull(mr2) = (%q, %v), want ("1234", nil)`, got, err)
	}

	// Consume the rest of r2 via mr1.
	// This should not panic even though mr2 cleared r1.
	n, err = ReadFull(mr1, buf)
	if got := string(buf[:n]); got != "5678" || err != nil {
		t.Errorf(`ReadFull(mr1) = (%q, %v), want ("5678", nil)`, got, err)
	}
}