1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8	"bufio"
9	"bytes"
10	"errors"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"net"
15	"reflect"
16	"sync"
17	"testing"
18	"testing/iotest"
19	"time"
20)
21
// Compile-time check that errWriteTimeout implements net.Error.
var _ net.Error = errWriteTimeout
23
24type fakeNetConn struct {
25	io.Reader
26	io.Writer
27}
28
29func (c fakeNetConn) Close() error                       { return nil }
30func (c fakeNetConn) LocalAddr() net.Addr                { return localAddr }
31func (c fakeNetConn) RemoteAddr() net.Addr               { return remoteAddr }
32func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
33func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
34func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
35
// fakeAddr is a trivial net.Addr implementation. Its integer value only keeps
// the local and remote addresses distinct; Network and String are constant.
type fakeAddr int

// Fixed addresses reported by fakeNetConn's LocalAddr and RemoteAddr.
var (
	localAddr  = fakeAddr(1)
	remoteAddr = fakeAddr(2)
)

// Network always returns "net".
func (fakeAddr) Network() string { return "net" }

// String always returns "str".
func (fakeAddr) String() string { return "str" }
50
51// newTestConn creates a connnection backed by a fake network connection using
52// default values for buffering.
53func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
54	return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
55}
56
57func TestFraming(t *testing.T) {
58	frameSizes := []int{
59		0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
60		// 65536, 65537
61	}
62	var readChunkers = []struct {
63		name string
64		f    func(io.Reader) io.Reader
65	}{
66		{"half", iotest.HalfReader},
67		{"one", iotest.OneByteReader},
68		{"asis", func(r io.Reader) io.Reader { return r }},
69	}
70	writeBuf := make([]byte, 65537)
71	for i := range writeBuf {
72		writeBuf[i] = byte(i)
73	}
74	var writers = []struct {
75		name string
76		f    func(w io.Writer, n int) (int, error)
77	}{
78		{"iocopy", func(w io.Writer, n int) (int, error) {
79			nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
80			return int(nn), err
81		}},
82		{"write", func(w io.Writer, n int) (int, error) {
83			return w.Write(writeBuf[:n])
84		}},
85		{"string", func(w io.Writer, n int) (int, error) {
86			return io.WriteString(w, string(writeBuf[:n]))
87		}},
88	}
89
90	for _, compress := range []bool{false, true} {
91		for _, isServer := range []bool{true, false} {
92			for _, chunker := range readChunkers {
93
94				var connBuf bytes.Buffer
95				wc := newTestConn(nil, &connBuf, isServer)
96				rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
97				if compress {
98					wc.newCompressionWriter = compressNoContextTakeover
99					rc.newDecompressionReader = decompressNoContextTakeover
100				}
101				for _, n := range frameSizes {
102					for _, writer := range writers {
103						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
104
105						w, err := wc.NextWriter(TextMessage)
106						if err != nil {
107							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
108							continue
109						}
110						nn, err := writer.f(w, n)
111						if err != nil || nn != n {
112							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
113							continue
114						}
115						err = w.Close()
116						if err != nil {
117							t.Errorf("%s: w.Close() returned %v", name, err)
118							continue
119						}
120
121						opCode, r, err := rc.NextReader()
122						if err != nil || opCode != TextMessage {
123							t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
124							continue
125						}
126
127						t.Logf("frame size: %d", n)
128						rbuf, err := ioutil.ReadAll(r)
129						if err != nil {
130							t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
131							continue
132						}
133
134						if len(rbuf) != n {
135							t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
136							continue
137						}
138
139						for i, b := range rbuf {
140							if byte(i) != b {
141								t.Errorf("%s: bad byte at offset %d", name, i)
142								break
143							}
144						}
145					}
146				}
147			}
148		}
149	}
150}
151
152func TestControl(t *testing.T) {
153	const message = "this is a ping/pong messsage"
154	for _, isServer := range []bool{true, false} {
155		for _, isWriteControl := range []bool{true, false} {
156			name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
157			var connBuf bytes.Buffer
158			wc := newTestConn(nil, &connBuf, isServer)
159			rc := newTestConn(&connBuf, nil, !isServer)
160			if isWriteControl {
161				wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
162			} else {
163				w, err := wc.NextWriter(PongMessage)
164				if err != nil {
165					t.Errorf("%s: wc.NextWriter() returned %v", name, err)
166					continue
167				}
168				if _, err := w.Write([]byte(message)); err != nil {
169					t.Errorf("%s: w.Write() returned %v", name, err)
170					continue
171				}
172				if err := w.Close(); err != nil {
173					t.Errorf("%s: w.Close() returned %v", name, err)
174					continue
175				}
176				var actualMessage string
177				rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
178				rc.NextReader()
179				if actualMessage != message {
180					t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
181					continue
182				}
183			}
184		}
185	}
186}
187
// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
// It holds at most one value: Put overwrites the single slot and Get empties
// it. Tests inspect the slot directly through the v field.
type simpleBufferPool struct {
	v interface{}
}

// Get takes the stored value out of the pool, returning nil if it is empty.
func (p *simpleBufferPool) Get() interface{} {
	var out interface{}
	out, p.v = p.v, nil
	return out
}

// Put stores v, replacing whatever the pool held before.
func (p *simpleBufferPool) Put(v interface{}) {
	p.v = v
}
202
203func TestWriteBufferPool(t *testing.T) {
204	const message = "Now is the time for all good people to come to the aid of the party."
205
206	var buf bytes.Buffer
207	var pool simpleBufferPool
208	rc := newTestConn(&buf, nil, false)
209
210	// Specify writeBufferSize smaller than message size to ensure that pooling
211	// works with fragmented messages.
212	wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
213
214	if wc.writeBuf != nil {
215		t.Fatal("writeBuf not nil after create")
216	}
217
218	// Part 1: test NextWriter/Write/Close
219
220	w, err := wc.NextWriter(TextMessage)
221	if err != nil {
222		t.Fatalf("wc.NextWriter() returned %v", err)
223	}
224
225	if wc.writeBuf == nil {
226		t.Fatal("writeBuf is nil after NextWriter")
227	}
228
229	writeBufAddr := &wc.writeBuf[0]
230
231	if _, err := io.WriteString(w, message); err != nil {
232		t.Fatalf("io.WriteString(w, message) returned %v", err)
233	}
234
235	if err := w.Close(); err != nil {
236		t.Fatalf("w.Close() returned %v", err)
237	}
238
239	if wc.writeBuf != nil {
240		t.Fatal("writeBuf not nil after w.Close()")
241	}
242
243	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
244		t.Fatal("writeBuf not returned to pool")
245	}
246
247	opCode, p, err := rc.ReadMessage()
248	if opCode != TextMessage || err != nil {
249		t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
250	}
251
252	if s := string(p); s != message {
253		t.Fatalf("message is %s, want %s", s, message)
254	}
255
256	// Part 2: Test WriteMessage.
257
258	if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
259		t.Fatalf("wc.WriteMessage() returned %v", err)
260	}
261
262	if wc.writeBuf != nil {
263		t.Fatal("writeBuf not nil after wc.WriteMessage()")
264	}
265
266	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
267		t.Fatal("writeBuf not returned to pool after WriteMessage")
268	}
269
270	opCode, p, err = rc.ReadMessage()
271	if opCode != TextMessage || err != nil {
272		t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
273	}
274
275	if s := string(p); s != message {
276		t.Fatalf("message is %s, want %s", s, message)
277	}
278}
279
280// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
281func TestWriteBufferPoolSync(t *testing.T) {
282	var buf bytes.Buffer
283	var pool sync.Pool
284	wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
285	rc := newTestConn(&buf, nil, false)
286
287	const message = "Hello World!"
288	for i := 0; i < 3; i++ {
289		if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
290			t.Fatalf("wc.WriteMessage() returned %v", err)
291		}
292		opCode, p, err := rc.ReadMessage()
293		if opCode != TextMessage || err != nil {
294			t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
295		}
296		if s := string(p); s != message {
297			t.Fatalf("message is %s, want %s", s, message)
298		}
299	}
300}
301
// errorWriter is an io.Writer that returns an error on all writes. Each Write
// reports zero bytes written and a fresh error whose text is "error".
type errorWriter struct{}

func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") }
306
307// TestWriteBufferPoolError ensures that buffer is returned to pool after error
308// on write.
309func TestWriteBufferPoolError(t *testing.T) {
310
311	// Part 1: Test NextWriter/Write/Close
312
313	var pool simpleBufferPool
314	wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
315
316	w, err := wc.NextWriter(TextMessage)
317	if err != nil {
318		t.Fatalf("wc.NextWriter() returned %v", err)
319	}
320
321	if wc.writeBuf == nil {
322		t.Fatal("writeBuf is nil after NextWriter")
323	}
324
325	writeBufAddr := &wc.writeBuf[0]
326
327	if _, err := io.WriteString(w, "Hello"); err != nil {
328		t.Fatalf("io.WriteString(w, message) returned %v", err)
329	}
330
331	if err := w.Close(); err == nil {
332		t.Fatalf("w.Close() did not return error")
333	}
334
335	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
336		t.Fatal("writeBuf not returned to pool")
337	}
338
339	// Part 2: Test WriteMessage
340
341	wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
342
343	if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
344		t.Fatalf("wc.WriteMessage did not return error")
345	}
346
347	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
348		t.Fatal("writeBuf not returned to pool")
349	}
350}
351
352func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
353	const bufSize = 512
354
355	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
356
357	var b1, b2 bytes.Buffer
358	wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
359	rc := newTestConn(&b1, &b2, true)
360
361	w, _ := wc.NextWriter(BinaryMessage)
362	w.Write(make([]byte, bufSize+bufSize/2))
363	wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
364	w.Close()
365
366	op, r, err := rc.NextReader()
367	if op != BinaryMessage || err != nil {
368		t.Fatalf("NextReader() returned %d, %v", op, err)
369	}
370	_, err = io.Copy(ioutil.Discard, r)
371	if !reflect.DeepEqual(err, expectedErr) {
372		t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
373	}
374	_, _, err = rc.NextReader()
375	if !reflect.DeepEqual(err, expectedErr) {
376		t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
377	}
378}
379
380func TestEOFWithinFrame(t *testing.T) {
381	const bufSize = 64
382
383	for n := 0; ; n++ {
384		var b bytes.Buffer
385		wc := newTestConn(nil, &b, false)
386		rc := newTestConn(&b, nil, true)
387
388		w, _ := wc.NextWriter(BinaryMessage)
389		w.Write(make([]byte, bufSize))
390		w.Close()
391
392		if n >= b.Len() {
393			break
394		}
395		b.Truncate(n)
396
397		op, r, err := rc.NextReader()
398		if err == errUnexpectedEOF {
399			continue
400		}
401		if op != BinaryMessage || err != nil {
402			t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
403		}
404		_, err = io.Copy(ioutil.Discard, r)
405		if err != errUnexpectedEOF {
406			t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
407		}
408		_, _, err = rc.NextReader()
409		if err != errUnexpectedEOF {
410			t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
411		}
412	}
413}
414
415func TestEOFBeforeFinalFrame(t *testing.T) {
416	const bufSize = 512
417
418	var b1, b2 bytes.Buffer
419	wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
420	rc := newTestConn(&b1, &b2, true)
421
422	w, _ := wc.NextWriter(BinaryMessage)
423	w.Write(make([]byte, bufSize+bufSize/2))
424
425	op, r, err := rc.NextReader()
426	if op != BinaryMessage || err != nil {
427		t.Fatalf("NextReader() returned %d, %v", op, err)
428	}
429	_, err = io.Copy(ioutil.Discard, r)
430	if err != errUnexpectedEOF {
431		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
432	}
433	_, _, err = rc.NextReader()
434	if err != errUnexpectedEOF {
435		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
436	}
437}
438
439func TestWriteAfterMessageWriterClose(t *testing.T) {
440	wc := newTestConn(nil, &bytes.Buffer{}, false)
441	w, _ := wc.NextWriter(BinaryMessage)
442	io.WriteString(w, "hello")
443	if err := w.Close(); err != nil {
444		t.Fatalf("unxpected error closing message writer, %v", err)
445	}
446
447	if _, err := io.WriteString(w, "world"); err == nil {
448		t.Fatalf("no error writing after close")
449	}
450
451	w, _ = wc.NextWriter(BinaryMessage)
452	io.WriteString(w, "hello")
453
454	// close w by getting next writer
455	_, err := wc.NextWriter(BinaryMessage)
456	if err != nil {
457		t.Fatalf("unexpected error getting next writer, %v", err)
458	}
459
460	if _, err := io.WriteString(w, "world"); err == nil {
461		t.Fatalf("no error writing after close")
462	}
463}
464
465func TestReadLimit(t *testing.T) {
466	t.Run("Test ReadLimit is enforced", func(t *testing.T) {
467		const readLimit = 512
468		message := make([]byte, readLimit+1)
469
470		var b1, b2 bytes.Buffer
471		wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
472		rc := newTestConn(&b1, &b2, true)
473		rc.SetReadLimit(readLimit)
474
475		// Send message at the limit with interleaved pong.
476		w, _ := wc.NextWriter(BinaryMessage)
477		w.Write(message[:readLimit-1])
478		wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
479		w.Write(message[:1])
480		w.Close()
481
482		// Send message larger than the limit.
483		wc.WriteMessage(BinaryMessage, message[:readLimit+1])
484
485		op, _, err := rc.NextReader()
486		if op != BinaryMessage || err != nil {
487			t.Fatalf("1: NextReader() returned %d, %v", op, err)
488		}
489		op, r, err := rc.NextReader()
490		if op != BinaryMessage || err != nil {
491			t.Fatalf("2: NextReader() returned %d, %v", op, err)
492		}
493		_, err = io.Copy(ioutil.Discard, r)
494		if err != ErrReadLimit {
495			t.Fatalf("io.Copy() returned %v", err)
496		}
497	})
498
499	t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
500		const readLimit = 1
501
502		var b1, b2 bytes.Buffer
503		rc := newTestConn(&b1, &b2, true)
504		rc.SetReadLimit(readLimit)
505
506		// First, send a non-final binary message
507		b1.Write([]byte("\x02\x81"))
508
509		// Mask key
510		b1.Write([]byte("\x00\x00\x00\x00"))
511
512		// First payload
513		b1.Write([]byte("A"))
514
515		// Next, send a negative-length, non-final continuation frame
516		b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
517
518		// Mask key
519		b1.Write([]byte("\x00\x00\x00\x00"))
520
521		// Next, send a too long, final continuation frame
522		b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
523
524		// Mask key
525		b1.Write([]byte("\x00\x00\x00\x00"))
526
527		// Too-long payload
528		b1.Write([]byte("BCDEF"))
529
530		op, r, err := rc.NextReader()
531		if op != BinaryMessage || err != nil {
532			t.Fatalf("1: NextReader() returned %d, %v", op, err)
533		}
534
535		var buf [10]byte
536		var read int
537		n, err := r.Read(buf[:])
538		if err != nil && err != ErrReadLimit {
539			t.Fatalf("unexpected error testing read limit: %v", err)
540		}
541		read += n
542
543		n, err = r.Read(buf[:])
544		if err != nil && err != ErrReadLimit {
545			t.Fatalf("unexpected error testing read limit: %v", err)
546		}
547		read += n
548
549		if err == nil && read > readLimit {
550			t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
551		}
552	})
553}
554
555func TestAddrs(t *testing.T) {
556	c := newTestConn(nil, nil, true)
557	if c.LocalAddr() != localAddr {
558		t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
559	}
560	if c.RemoteAddr() != remoteAddr {
561		t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
562	}
563}
564
565func TestUnderlyingConn(t *testing.T) {
566	var b1, b2 bytes.Buffer
567	fc := fakeNetConn{Reader: &b1, Writer: &b2}
568	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
569	ul := c.UnderlyingConn()
570	if ul != fc {
571		t.Fatalf("Underlying conn is not what it should be.")
572	}
573}
574
575func TestBufioReadBytes(t *testing.T) {
576	// Test calling bufio.ReadBytes for value longer than read buffer size.
577
578	m := make([]byte, 512)
579	m[len(m)-1] = '\n'
580
581	var b1, b2 bytes.Buffer
582	wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
583	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
584
585	w, _ := wc.NextWriter(BinaryMessage)
586	w.Write(m)
587	w.Close()
588
589	op, r, err := rc.NextReader()
590	if op != BinaryMessage || err != nil {
591		t.Fatalf("NextReader() returned %d, %v", op, err)
592	}
593
594	br := bufio.NewReader(r)
595	p, err := br.ReadBytes('\n')
596	if err != nil {
597		t.Fatalf("ReadBytes() returned %v", err)
598	}
599	if len(p) != len(m) {
600		t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
601	}
602}
603
604var closeErrorTests = []struct {
605	err   error
606	codes []int
607	ok    bool
608}{
609	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
610	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
611	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
612	{errors.New("hello"), []int{CloseNormalClosure}, false},
613}
614
615func TestCloseError(t *testing.T) {
616	for _, tt := range closeErrorTests {
617		ok := IsCloseError(tt.err, tt.codes...)
618		if ok != tt.ok {
619			t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
620		}
621	}
622}
623
624var unexpectedCloseErrorTests = []struct {
625	err   error
626	codes []int
627	ok    bool
628}{
629	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
630	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
631	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
632	{errors.New("hello"), []int{CloseNormalClosure}, false},
633}
634
635func TestUnexpectedCloseErrors(t *testing.T) {
636	for _, tt := range unexpectedCloseErrorTests {
637		ok := IsUnexpectedCloseError(tt.err, tt.codes...)
638		if ok != tt.ok {
639			t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
640		}
641	}
642}
643
// blockingWriter is an io.Writer used to hold a write in progress. Write
// signals that it has started by closing c1, then blocks until c2 is closed,
// and finally reports the whole buffer as written. Write must be called at
// most once per value, because closing c1 a second time panics.
type blockingWriter struct {
	c1, c2 chan struct{}
}

func (w blockingWriter) Write(p []byte) (int, error) {
	close(w.c1) // let the test goroutine proceed
	<-w.c2      // hold here until the test releases us
	return len(p), nil
}
655
656func TestConcurrentWritePanic(t *testing.T) {
657	w := blockingWriter{make(chan struct{}), make(chan struct{})}
658	c := newTestConn(nil, w, false)
659	go func() {
660		c.WriteMessage(TextMessage, []byte{})
661	}()
662
663	// wait for goroutine to block in write.
664	<-w.c1
665
666	defer func() {
667		close(w.c2)
668		if v := recover(); v != nil {
669			return
670		}
671	}()
672
673	c.WriteMessage(TextMessage, []byte{})
674	t.Fatal("should not get here")
675}
676
677type failingReader struct{}
678
679func (r failingReader) Read(p []byte) (int, error) {
680	return 0, io.EOF
681}
682
683func TestFailedConnectionReadPanic(t *testing.T) {
684	c := newTestConn(failingReader{}, nil, false)
685
686	defer func() {
687		if v := recover(); v != nil {
688			return
689		}
690	}()
691
692	for i := 0; i < 20000; i++ {
693		c.ReadMessage()
694	}
695	t.Fatal("should not get here")
696}
697