1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package websocket 6 7import ( 8 "bufio" 9 "bytes" 10 "errors" 11 "fmt" 12 "io" 13 "io/ioutil" 14 "net" 15 "reflect" 16 "sync" 17 "testing" 18 "testing/iotest" 19 "time" 20) 21 22var _ net.Error = errWriteTimeout 23 24type fakeNetConn struct { 25 io.Reader 26 io.Writer 27} 28 29func (c fakeNetConn) Close() error { return nil } 30func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } 31func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } 32func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } 33func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } 34func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } 35 36type fakeAddr int 37 38var ( 39 localAddr = fakeAddr(1) 40 remoteAddr = fakeAddr(2) 41) 42 43func (a fakeAddr) Network() string { 44 return "net" 45} 46 47func (a fakeAddr) String() string { 48 return "str" 49} 50 51// newTestConn creates a connnection backed by a fake network connection using 52// default values for buffering. 53func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { 54 return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) 55} 56 57func TestFraming(t *testing.T) { 58 frameSizes := []int{ 59 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 60 // 65536, 65537 61 } 62 var readChunkers = []struct { 63 name string 64 f func(io.Reader) io.Reader 65 }{ 66 {"half", iotest.HalfReader}, 67 {"one", iotest.OneByteReader}, 68 {"asis", func(r io.Reader) io.Reader { return r }}, 69 } 70 writeBuf := make([]byte, 65537) 71 for i := range writeBuf { 72 writeBuf[i] = byte(i) 73 } 74 var writers = []struct { 75 name string 76 f func(w io.Writer, n int) (int, error) 77 }{ 78 {"iocopy", func(w io.Writer, n int) (int, error) { 79 nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) 80 return int(nn), err 81 }}, 82 {"write", func(w io.Writer, n int) (int, error) { 83 return w.Write(writeBuf[:n]) 84 }}, 85 {"string", func(w io.Writer, n int) (int, error) { 86 return io.WriteString(w, string(writeBuf[:n])) 87 }}, 88 } 89 90 for _, compress := range []bool{false, true} { 91 for _, isServer := range []bool{true, false} { 92 for _, chunker := range readChunkers { 93 94 var connBuf bytes.Buffer 95 wc := newTestConn(nil, &connBuf, isServer) 96 rc := newTestConn(chunker.f(&connBuf), nil, !isServer) 97 if compress { 98 wc.newCompressionWriter = compressNoContextTakeover 99 rc.newDecompressionReader = decompressNoContextTakeover 100 } 101 for _, n := range frameSizes { 102 for _, writer := range writers { 103 name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) 104 105 w, err := wc.NextWriter(TextMessage) 106 if err != nil { 107 t.Errorf("%s: wc.NextWriter() returned %v", name, err) 108 continue 109 } 110 nn, err := writer.f(w, n) 111 if err != nil || nn != n { 112 t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) 113 continue 114 } 115 err = w.Close() 116 if err != nil { 117 t.Errorf("%s: w.Close() returned %v", name, err) 118 continue 119 } 120 121 opCode, r, err := rc.NextReader() 122 if err != nil || opCode != TextMessage { 123 t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) 124 continue 125 } 126 127 t.Logf("frame size: %d", n) 128 rbuf, err := ioutil.ReadAll(r) 129 if err != nil { 130 t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) 131 continue 132 } 133 134 if len(rbuf) != n { 135 t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) 136 continue 137 } 138 139 for i, b := range rbuf { 140 if byte(i) != b { 141 t.Errorf("%s: bad byte at offset %d", name, i) 142 break 143 } 144 } 145 } 146 } 147 } 148 } 149 } 150} 151 152func TestControl(t *testing.T) { 153 const message = "this is a ping/pong messsage" 154 for _, isServer := range []bool{true, false} { 155 for _, isWriteControl := range []bool{true, false} { 156 name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) 157 var connBuf bytes.Buffer 158 wc := newTestConn(nil, &connBuf, isServer) 159 rc := newTestConn(&connBuf, nil, !isServer) 160 if isWriteControl { 161 wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) 162 } else { 163 w, err := wc.NextWriter(PongMessage) 164 if err != nil { 165 t.Errorf("%s: wc.NextWriter() returned %v", name, err) 166 continue 167 } 168 if _, err := w.Write([]byte(message)); err != nil { 169 t.Errorf("%s: w.Write() returned %v", name, err) 170 continue 171 } 172 if err := w.Close(); err != nil { 173 t.Errorf("%s: w.Close() returned %v", name, err) 174 continue 175 } 176 var actualMessage string 177 rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) 178 rc.NextReader() 179 if actualMessage != message { 180 t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) 181 continue 182 } 183 } 184 } 185 } 186} 187 188// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. 189type simpleBufferPool struct { 190 v interface{} 191} 192 193func (p *simpleBufferPool) Get() interface{} { 194 v := p.v 195 p.v = nil 196 return v 197} 198 199func (p *simpleBufferPool) Put(v interface{}) { 200 p.v = v 201} 202 203func TestWriteBufferPool(t *testing.T) { 204 const message = "Now is the time for all good people to come to the aid of the party." 205 206 var buf bytes.Buffer 207 var pool simpleBufferPool 208 rc := newTestConn(&buf, nil, false) 209 210 // Specify writeBufferSize smaller than message size to ensure that pooling 211 // works with fragmented messages. 212 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) 213 214 if wc.writeBuf != nil { 215 t.Fatal("writeBuf not nil after create") 216 } 217 218 // Part 1: test NextWriter/Write/Close 219 220 w, err := wc.NextWriter(TextMessage) 221 if err != nil { 222 t.Fatalf("wc.NextWriter() returned %v", err) 223 } 224 225 if wc.writeBuf == nil { 226 t.Fatal("writeBuf is nil after NextWriter") 227 } 228 229 writeBufAddr := &wc.writeBuf[0] 230 231 if _, err := io.WriteString(w, message); err != nil { 232 t.Fatalf("io.WriteString(w, message) returned %v", err) 233 } 234 235 if err := w.Close(); err != nil { 236 t.Fatalf("w.Close() returned %v", err) 237 } 238 239 if wc.writeBuf != nil { 240 t.Fatal("writeBuf not nil after w.Close()") 241 } 242 243 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 244 t.Fatal("writeBuf not returned to pool") 245 } 246 247 opCode, p, err := rc.ReadMessage() 248 if opCode != TextMessage || err != nil { 249 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 250 } 251 252 if s := string(p); s != message { 253 t.Fatalf("message is %s, want %s", s, message) 254 } 255 256 // Part 2: Test WriteMessage. 257 258 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 259 t.Fatalf("wc.WriteMessage() returned %v", err) 260 } 261 262 if wc.writeBuf != nil { 263 t.Fatal("writeBuf not nil after wc.WriteMessage()") 264 } 265 266 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 267 t.Fatal("writeBuf not returned to pool after WriteMessage") 268 } 269 270 opCode, p, err = rc.ReadMessage() 271 if opCode != TextMessage || err != nil { 272 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 273 } 274 275 if s := string(p); s != message { 276 t.Fatalf("message is %s, want %s", s, message) 277 } 278} 279 280// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. 281func TestWriteBufferPoolSync(t *testing.T) { 282 var buf bytes.Buffer 283 var pool sync.Pool 284 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) 285 rc := newTestConn(&buf, nil, false) 286 287 const message = "Hello World!" 288 for i := 0; i < 3; i++ { 289 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 290 t.Fatalf("wc.WriteMessage() returned %v", err) 291 } 292 opCode, p, err := rc.ReadMessage() 293 if opCode != TextMessage || err != nil { 294 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 295 } 296 if s := string(p); s != message { 297 t.Fatalf("message is %s, want %s", s, message) 298 } 299 } 300} 301 302// errorWriter is an io.Writer than returns an error on all writes. 303type errorWriter struct{} 304 305func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } 306 307// TestWriteBufferPoolError ensures that buffer is returned to pool after error 308// on write. 309func TestWriteBufferPoolError(t *testing.T) { 310 311 // Part 1: Test NextWriter/Write/Close 312 313 var pool simpleBufferPool 314 wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) 315 316 w, err := wc.NextWriter(TextMessage) 317 if err != nil { 318 t.Fatalf("wc.NextWriter() returned %v", err) 319 } 320 321 if wc.writeBuf == nil { 322 t.Fatal("writeBuf is nil after NextWriter") 323 } 324 325 writeBufAddr := &wc.writeBuf[0] 326 327 if _, err := io.WriteString(w, "Hello"); err != nil { 328 t.Fatalf("io.WriteString(w, message) returned %v", err) 329 } 330 331 if err := w.Close(); err == nil { 332 t.Fatalf("w.Close() did not return error") 333 } 334 335 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 336 t.Fatal("writeBuf not returned to pool") 337 } 338 339 // Part 2: Test WriteMessage 340 341 wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) 342 343 if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { 344 t.Fatalf("wc.WriteMessage did not return error") 345 } 346 347 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 348 t.Fatal("writeBuf not returned to pool") 349 } 350} 351 352func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { 353 const bufSize = 512 354 355 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} 356 357 var b1, b2 bytes.Buffer 358 wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) 359 rc := newTestConn(&b1, &b2, true) 360 361 w, _ := wc.NextWriter(BinaryMessage) 362 w.Write(make([]byte, bufSize+bufSize/2)) 363 wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) 364 w.Close() 365 366 op, r, err := rc.NextReader() 367 if op != BinaryMessage || err != nil { 368 t.Fatalf("NextReader() returned %d, %v", op, err) 369 } 370 _, err = io.Copy(ioutil.Discard, r) 371 if !reflect.DeepEqual(err, expectedErr) { 372 t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) 373 } 374 _, _, err = rc.NextReader() 375 if !reflect.DeepEqual(err, expectedErr) { 376 t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) 377 } 378} 379 380func TestEOFWithinFrame(t *testing.T) { 381 const bufSize = 64 382 383 for n := 0; ; n++ { 384 var b bytes.Buffer 385 wc := newTestConn(nil, &b, false) 386 rc := newTestConn(&b, nil, true) 387 388 w, _ := wc.NextWriter(BinaryMessage) 389 w.Write(make([]byte, bufSize)) 390 w.Close() 391 392 if n >= b.Len() { 393 break 394 } 395 b.Truncate(n) 396 397 op, r, err := rc.NextReader() 398 if err == errUnexpectedEOF { 399 continue 400 } 401 if op != BinaryMessage || err != nil { 402 t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) 403 } 404 _, err = io.Copy(ioutil.Discard, r) 405 if err != errUnexpectedEOF { 406 t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) 407 } 408 _, _, err = rc.NextReader() 409 if err != errUnexpectedEOF { 410 t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) 411 } 412 } 413} 414 415func TestEOFBeforeFinalFrame(t *testing.T) { 416 const bufSize = 512 417 418 var b1, b2 bytes.Buffer 419 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) 420 rc := newTestConn(&b1, &b2, true) 421 422 w, _ := wc.NextWriter(BinaryMessage) 423 w.Write(make([]byte, bufSize+bufSize/2)) 424 425 op, r, err := rc.NextReader() 426 if op != BinaryMessage || err != nil { 427 t.Fatalf("NextReader() returned %d, %v", op, err) 428 } 429 _, err = io.Copy(ioutil.Discard, r) 430 if err != errUnexpectedEOF { 431 t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) 432 } 433 _, _, err = rc.NextReader() 434 if err != errUnexpectedEOF { 435 t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) 436 } 437} 438 439func TestWriteAfterMessageWriterClose(t *testing.T) { 440 wc := newTestConn(nil, &bytes.Buffer{}, false) 441 w, _ := wc.NextWriter(BinaryMessage) 442 io.WriteString(w, "hello") 443 if err := w.Close(); err != nil { 444 t.Fatalf("unxpected error closing message writer, %v", err) 445 } 446 447 if _, err := io.WriteString(w, "world"); err == nil { 448 t.Fatalf("no error writing after close") 449 } 450 451 w, _ = wc.NextWriter(BinaryMessage) 452 io.WriteString(w, "hello") 453 454 // close w by getting next writer 455 _, err := wc.NextWriter(BinaryMessage) 456 if err != nil { 457 t.Fatalf("unexpected error getting next writer, %v", err) 458 } 459 460 if _, err := io.WriteString(w, "world"); err == nil { 461 t.Fatalf("no error writing after close") 462 } 463} 464 465func TestReadLimit(t *testing.T) { 466 t.Run("Test ReadLimit is enforced", func(t *testing.T) { 467 const readLimit = 512 468 message := make([]byte, readLimit+1) 469 470 var b1, b2 bytes.Buffer 471 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) 472 rc := newTestConn(&b1, &b2, true) 473 rc.SetReadLimit(readLimit) 474 475 // Send message at the limit with interleaved pong. 476 w, _ := wc.NextWriter(BinaryMessage) 477 w.Write(message[:readLimit-1]) 478 wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) 479 w.Write(message[:1]) 480 w.Close() 481 482 // Send message larger than the limit. 483 wc.WriteMessage(BinaryMessage, message[:readLimit+1]) 484 485 op, _, err := rc.NextReader() 486 if op != BinaryMessage || err != nil { 487 t.Fatalf("1: NextReader() returned %d, %v", op, err) 488 } 489 op, r, err := rc.NextReader() 490 if op != BinaryMessage || err != nil { 491 t.Fatalf("2: NextReader() returned %d, %v", op, err) 492 } 493 _, err = io.Copy(ioutil.Discard, r) 494 if err != ErrReadLimit { 495 t.Fatalf("io.Copy() returned %v", err) 496 } 497 }) 498 499 t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { 500 const readLimit = 1 501 502 var b1, b2 bytes.Buffer 503 rc := newTestConn(&b1, &b2, true) 504 rc.SetReadLimit(readLimit) 505 506 // First, send a non-final binary message 507 b1.Write([]byte("\x02\x81")) 508 509 // Mask key 510 b1.Write([]byte("\x00\x00\x00\x00")) 511 512 // First payload 513 b1.Write([]byte("A")) 514 515 // Next, send a negative-length, non-final continuation frame 516 b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) 517 518 // Mask key 519 b1.Write([]byte("\x00\x00\x00\x00")) 520 521 // Next, send a too long, final continuation frame 522 b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) 523 524 // Mask key 525 b1.Write([]byte("\x00\x00\x00\x00")) 526 527 // Too-long payload 528 b1.Write([]byte("BCDEF")) 529 530 op, r, err := rc.NextReader() 531 if op != BinaryMessage || err != nil { 532 t.Fatalf("1: NextReader() returned %d, %v", op, err) 533 } 534 535 var buf [10]byte 536 var read int 537 n, err := r.Read(buf[:]) 538 if err != nil && err != ErrReadLimit { 539 t.Fatalf("unexpected error testing read limit: %v", err) 540 } 541 read += n 542 543 n, err = r.Read(buf[:]) 544 if err != nil && err != ErrReadLimit { 545 t.Fatalf("unexpected error testing read limit: %v", err) 546 } 547 read += n 548 549 if err == nil && read > readLimit { 550 t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) 551 } 552 }) 553} 554 555func TestAddrs(t *testing.T) { 556 c := newTestConn(nil, nil, true) 557 if c.LocalAddr() != localAddr { 558 t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) 559 } 560 if c.RemoteAddr() != remoteAddr { 561 t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) 562 } 563} 564 565func TestUnderlyingConn(t *testing.T) { 566 var b1, b2 bytes.Buffer 567 fc := fakeNetConn{Reader: &b1, Writer: &b2} 568 c := newConn(fc, true, 1024, 1024, nil, nil, nil) 569 ul := c.UnderlyingConn() 570 if ul != fc { 571 t.Fatalf("Underlying conn is not what it should be.") 572 } 573} 574 575func TestBufioReadBytes(t *testing.T) { 576 // Test calling bufio.ReadBytes for value longer than read buffer size. 577 578 m := make([]byte, 512) 579 m[len(m)-1] = '\n' 580 581 var b1, b2 bytes.Buffer 582 wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) 583 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) 584 585 w, _ := wc.NextWriter(BinaryMessage) 586 w.Write(m) 587 w.Close() 588 589 op, r, err := rc.NextReader() 590 if op != BinaryMessage || err != nil { 591 t.Fatalf("NextReader() returned %d, %v", op, err) 592 } 593 594 br := bufio.NewReader(r) 595 p, err := br.ReadBytes('\n') 596 if err != nil { 597 t.Fatalf("ReadBytes() returned %v", err) 598 } 599 if len(p) != len(m) { 600 t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) 601 } 602} 603 604var closeErrorTests = []struct { 605 err error 606 codes []int 607 ok bool 608}{ 609 {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, 610 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, 611 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, 612 {errors.New("hello"), []int{CloseNormalClosure}, false}, 613} 614 615func TestCloseError(t *testing.T) { 616 for _, tt := range closeErrorTests { 617 ok := IsCloseError(tt.err, tt.codes...) 618 if ok != tt.ok { 619 t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) 620 } 621 } 622} 623 624var unexpectedCloseErrorTests = []struct { 625 err error 626 codes []int 627 ok bool 628}{ 629 {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, 630 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, 631 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, 632 {errors.New("hello"), []int{CloseNormalClosure}, false}, 633} 634 635func TestUnexpectedCloseErrors(t *testing.T) { 636 for _, tt := range unexpectedCloseErrorTests { 637 ok := IsUnexpectedCloseError(tt.err, tt.codes...) 638 if ok != tt.ok { 639 t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) 640 } 641 } 642} 643 644type blockingWriter struct { 645 c1, c2 chan struct{} 646} 647 648func (w blockingWriter) Write(p []byte) (int, error) { 649 // Allow main to continue 650 close(w.c1) 651 // Wait for panic in main 652 <-w.c2 653 return len(p), nil 654} 655 656func TestConcurrentWritePanic(t *testing.T) { 657 w := blockingWriter{make(chan struct{}), make(chan struct{})} 658 c := newTestConn(nil, w, false) 659 go func() { 660 c.WriteMessage(TextMessage, []byte{}) 661 }() 662 663 // wait for goroutine to block in write. 664 <-w.c1 665 666 defer func() { 667 close(w.c2) 668 if v := recover(); v != nil { 669 return 670 } 671 }() 672 673 c.WriteMessage(TextMessage, []byte{}) 674 t.Fatal("should not get here") 675} 676 677type failingReader struct{} 678 679func (r failingReader) Read(p []byte) (int, error) { 680 return 0, io.EOF 681} 682 683func TestFailedConnectionReadPanic(t *testing.T) { 684 c := newTestConn(failingReader{}, nil, false) 685 686 defer func() { 687 if v := recover(); v != nil { 688 return 689 } 690 }() 691 692 for i := 0; i < 20000; i++ { 693 c.ReadMessage() 694 } 695 t.Fatal("should not get here") 696} 697