1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8	"bytes"
9	"crypto/rand"
10	"fmt"
11	"io"
12	"log"
13	"net"
14	"net/http"
15	"net/http/httptest"
16	"net/url"
17	"reflect"
18	"runtime"
19	"strings"
20	"sync"
21	"testing"
22	"time"
23)
24
// Shared state for the package-wide test server. serverAddr holds the
// listener address recorded by startServer; once guards startServer so every
// test that needs the server can call once.Do(startServer).
var (
	serverAddr string
	once       sync.Once
)
27
28func echoServer(ws *Conn) {
29	defer ws.Close()
30	io.Copy(ws, ws)
31}
32
// Count is the JSON message exchanged with countServer. For each Count it
// receives, the server replies with N incremented and S repeated N times.
type Count struct {
	S string // text the server repeats
	N int    // counter the server increments
}
37
38func countServer(ws *Conn) {
39	defer ws.Close()
40	for {
41		var count Count
42		err := JSON.Receive(ws, &count)
43		if err != nil {
44			return
45		}
46		count.N++
47		count.S = strings.Repeat(count.S, count.N)
48		err = JSON.Send(ws, count)
49		if err != nil {
50			return
51		}
52	}
53}
54
55type testCtrlAndDataHandler struct {
56	hybiFrameHandler
57}
58
59func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
60	h.hybiFrameHandler.conn.wio.Lock()
61	defer h.hybiFrameHandler.conn.wio.Unlock()
62	w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
63	if err != nil {
64		return 0, err
65	}
66	n, err := w.Write(b)
67	w.Close()
68	return n, err
69}
70
71func ctrlAndDataServer(ws *Conn) {
72	defer ws.Close()
73	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
74	ws.frameHandler = h
75
76	go func() {
77		for i := 0; ; i++ {
78			var b []byte
79			if i%2 != 0 { // with or without payload
80				b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
81			}
82			if _, err := h.WritePing(b); err != nil {
83				break
84			}
85			if _, err := h.WritePong(b); err != nil { // unsolicited pong
86				break
87			}
88			time.Sleep(10 * time.Millisecond)
89		}
90	}()
91
92	b := make([]byte, 128)
93	for {
94		n, err := ws.Read(b)
95		if err != nil {
96			break
97		}
98		if _, err := ws.Write(b[:n]); err != nil {
99			break
100		}
101	}
102}
103
104func subProtocolHandshake(config *Config, req *http.Request) error {
105	for _, proto := range config.Protocol {
106		if proto == "chat" {
107			config.Protocol = []string{proto}
108			return nil
109		}
110	}
111	return ErrBadWebSocketProtocol
112}
113
114func subProtoServer(ws *Conn) {
115	for _, proto := range ws.Config().Protocol {
116		io.WriteString(ws, proto)
117	}
118}
119
120func startServer() {
121	http.Handle("/echo", Handler(echoServer))
122	http.Handle("/count", Handler(countServer))
123	http.Handle("/ctrldata", Handler(ctrlAndDataServer))
124	subproto := Server{
125		Handshake: subProtocolHandshake,
126		Handler:   Handler(subProtoServer),
127	}
128	http.Handle("/subproto", subproto)
129	server := httptest.NewServer(nil)
130	serverAddr = server.Listener.Addr().String()
131	log.Print("Test WebSocket server listening on ", serverAddr)
132}
133
134func newConfig(t *testing.T, path string) *Config {
135	config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
136	return config
137}
138
139func TestEcho(t *testing.T) {
140	once.Do(startServer)
141
142	// websocket.Dial()
143	client, err := net.Dial("tcp", serverAddr)
144	if err != nil {
145		t.Fatal("dialing", err)
146	}
147	conn, err := NewClient(newConfig(t, "/echo"), client)
148	if err != nil {
149		t.Errorf("WebSocket handshake error: %v", err)
150		return
151	}
152
153	msg := []byte("hello, world\n")
154	if _, err := conn.Write(msg); err != nil {
155		t.Errorf("Write: %v", err)
156	}
157	var actual_msg = make([]byte, 512)
158	n, err := conn.Read(actual_msg)
159	if err != nil {
160		t.Errorf("Read: %v", err)
161	}
162	actual_msg = actual_msg[0:n]
163	if !bytes.Equal(msg, actual_msg) {
164		t.Errorf("Echo: expected %q got %q", msg, actual_msg)
165	}
166	conn.Close()
167}
168
169func TestAddr(t *testing.T) {
170	once.Do(startServer)
171
172	// websocket.Dial()
173	client, err := net.Dial("tcp", serverAddr)
174	if err != nil {
175		t.Fatal("dialing", err)
176	}
177	conn, err := NewClient(newConfig(t, "/echo"), client)
178	if err != nil {
179		t.Errorf("WebSocket handshake error: %v", err)
180		return
181	}
182
183	ra := conn.RemoteAddr().String()
184	if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
185		t.Errorf("Bad remote addr: %v", ra)
186	}
187	la := conn.LocalAddr().String()
188	if !strings.HasPrefix(la, "http://") {
189		t.Errorf("Bad local addr: %v", la)
190	}
191	conn.Close()
192}
193
194func TestCount(t *testing.T) {
195	once.Do(startServer)
196
197	// websocket.Dial()
198	client, err := net.Dial("tcp", serverAddr)
199	if err != nil {
200		t.Fatal("dialing", err)
201	}
202	conn, err := NewClient(newConfig(t, "/count"), client)
203	if err != nil {
204		t.Errorf("WebSocket handshake error: %v", err)
205		return
206	}
207
208	var count Count
209	count.S = "hello"
210	if err := JSON.Send(conn, count); err != nil {
211		t.Errorf("Write: %v", err)
212	}
213	if err := JSON.Receive(conn, &count); err != nil {
214		t.Errorf("Read: %v", err)
215	}
216	if count.N != 1 {
217		t.Errorf("count: expected %d got %d", 1, count.N)
218	}
219	if count.S != "hello" {
220		t.Errorf("count: expected %q got %q", "hello", count.S)
221	}
222	if err := JSON.Send(conn, count); err != nil {
223		t.Errorf("Write: %v", err)
224	}
225	if err := JSON.Receive(conn, &count); err != nil {
226		t.Errorf("Read: %v", err)
227	}
228	if count.N != 2 {
229		t.Errorf("count: expected %d got %d", 2, count.N)
230	}
231	if count.S != "hellohello" {
232		t.Errorf("count: expected %q got %q", "hellohello", count.S)
233	}
234	conn.Close()
235}
236
237func TestWithQuery(t *testing.T) {
238	once.Do(startServer)
239
240	client, err := net.Dial("tcp", serverAddr)
241	if err != nil {
242		t.Fatal("dialing", err)
243	}
244
245	config := newConfig(t, "/echo")
246	config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
247	if err != nil {
248		t.Fatal("location url", err)
249	}
250
251	ws, err := NewClient(config, client)
252	if err != nil {
253		t.Errorf("WebSocket handshake: %v", err)
254		return
255	}
256	ws.Close()
257}
258
259func testWithProtocol(t *testing.T, subproto []string) (string, error) {
260	once.Do(startServer)
261
262	client, err := net.Dial("tcp", serverAddr)
263	if err != nil {
264		t.Fatal("dialing", err)
265	}
266
267	config := newConfig(t, "/subproto")
268	config.Protocol = subproto
269
270	ws, err := NewClient(config, client)
271	if err != nil {
272		return "", err
273	}
274	msg := make([]byte, 16)
275	n, err := ws.Read(msg)
276	if err != nil {
277		return "", err
278	}
279	ws.Close()
280	return string(msg[:n]), nil
281}
282
283func TestWithProtocol(t *testing.T) {
284	proto, err := testWithProtocol(t, []string{"chat"})
285	if err != nil {
286		t.Errorf("SubProto: unexpected error: %v", err)
287	}
288	if proto != "chat" {
289		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
290	}
291}
292
293func TestWithTwoProtocol(t *testing.T) {
294	proto, err := testWithProtocol(t, []string{"test", "chat"})
295	if err != nil {
296		t.Errorf("SubProto: unexpected error: %v", err)
297	}
298	if proto != "chat" {
299		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
300	}
301}
302
303func TestWithBadProtocol(t *testing.T) {
304	_, err := testWithProtocol(t, []string{"test"})
305	if err != ErrBadStatus {
306		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
307	}
308}
309
310func TestHTTP(t *testing.T) {
311	once.Do(startServer)
312
313	// If the client did not send a handshake that matches the protocol
314	// specification, the server MUST return an HTTP response with an
315	// appropriate error code (such as 400 Bad Request)
316	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
317	if err != nil {
318		t.Errorf("Get: error %#v", err)
319		return
320	}
321	if resp == nil {
322		t.Error("Get: resp is null")
323		return
324	}
325	if resp.StatusCode != http.StatusBadRequest {
326		t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
327	}
328}
329
330func TestTrailingSpaces(t *testing.T) {
331	// http://code.google.com/p/go/issues/detail?id=955
332	// The last runs of this create keys with trailing spaces that should not be
333	// generated by the client.
334	once.Do(startServer)
335	config := newConfig(t, "/echo")
336	for i := 0; i < 30; i++ {
337		// body
338		ws, err := DialConfig(config)
339		if err != nil {
340			t.Errorf("Dial #%d failed: %v", i, err)
341			break
342		}
343		ws.Close()
344	}
345}
346
347func TestDialConfigBadVersion(t *testing.T) {
348	once.Do(startServer)
349	config := newConfig(t, "/echo")
350	config.Version = 1234
351
352	_, err := DialConfig(config)
353
354	if dialerr, ok := err.(*DialError); ok {
355		if dialerr.Err != ErrBadProtocolVersion {
356			t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
357		}
358	}
359}
360
361func TestDialConfigWithDialer(t *testing.T) {
362	once.Do(startServer)
363	config := newConfig(t, "/echo")
364	config.Dialer = &net.Dialer{
365		Deadline: time.Now().Add(-time.Minute),
366	}
367	_, err := DialConfig(config)
368	dialerr, ok := err.(*DialError)
369	if !ok {
370		t.Fatalf("DialError expected, got %#v", err)
371	}
372	neterr, ok := dialerr.Err.(*net.OpError)
373	if !ok {
374		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
375	}
376	if !neterr.Timeout() {
377		t.Fatalf("expected timeout error, got %#v", neterr)
378	}
379}
380
381func TestSmallBuffer(t *testing.T) {
382	// http://code.google.com/p/go/issues/detail?id=1145
383	// Read should be able to handle reading a fragment of a frame.
384	once.Do(startServer)
385
386	// websocket.Dial()
387	client, err := net.Dial("tcp", serverAddr)
388	if err != nil {
389		t.Fatal("dialing", err)
390	}
391	conn, err := NewClient(newConfig(t, "/echo"), client)
392	if err != nil {
393		t.Errorf("WebSocket handshake error: %v", err)
394		return
395	}
396
397	msg := []byte("hello, world\n")
398	if _, err := conn.Write(msg); err != nil {
399		t.Errorf("Write: %v", err)
400	}
401	var small_msg = make([]byte, 8)
402	n, err := conn.Read(small_msg)
403	if err != nil {
404		t.Errorf("Read: %v", err)
405	}
406	if !bytes.Equal(msg[:len(small_msg)], small_msg) {
407		t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
408	}
409	var second_msg = make([]byte, len(msg))
410	n, err = conn.Read(second_msg)
411	if err != nil {
412		t.Errorf("Read: %v", err)
413	}
414	second_msg = second_msg[0:n]
415	if !bytes.Equal(msg[len(small_msg):], second_msg) {
416		t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
417	}
418	conn.Close()
419}
420
// parseAuthorityTests pairs an input URL with the authority parseAuthority is
// expected to return. For "ws" and "wss", a missing port is expected to
// default to 80 and 443 respectively. For other schemes, the host is
// expected back unchanged.
var parseAuthorityTests = []struct {
	in  *url.URL
	out string
}{
	{in: &url.URL{Scheme: "ws", Host: "www.google.com"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com"}, out: "www.google.com:443"},
	{in: &url.URL{Scheme: "ws", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com:443"}, out: "www.google.com:443"},
	// Invalid inputs for parseAuthority: it only treats schemes it knows
	// about specially, so these hosts pass through untouched.
	{in: &url.URL{Scheme: "http", Host: "www.google.com"}, out: "www.google.com"},
	{in: &url.URL{Scheme: "http", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "asdf", Host: "127.0.0.1"}, out: "127.0.0.1"},
	{in: &url.URL{Scheme: "asdf", Host: "www.google.com"}, out: "www.google.com"},
}
484
485func TestParseAuthority(t *testing.T) {
486	for _, tt := range parseAuthorityTests {
487		out := parseAuthority(tt.in)
488		if out != tt.out {
489			t.Errorf("got %v; want %v", out, tt.out)
490		}
491	}
492}
493
// closerConn wraps a net.Conn and counts calls to Close, so a test can check
// that closing a websocket also closes the underlying connection.
type closerConn struct {
	net.Conn
	closed int // number of Close calls so far
}

// Close bumps the call counter first and then closes the wrapped
// connection, returning its error.
func (cc *closerConn) Close() error {
	cc.closed++
	inner := cc.Conn
	return inner.Close()
}
503
504func TestClose(t *testing.T) {
505	if runtime.GOOS == "plan9" {
506		t.Skip("see golang.org/issue/11454")
507	}
508
509	once.Do(startServer)
510
511	conn, err := net.Dial("tcp", serverAddr)
512	if err != nil {
513		t.Fatal("dialing", err)
514	}
515
516	cc := closerConn{Conn: conn}
517
518	client, err := NewClient(newConfig(t, "/echo"), &cc)
519	if err != nil {
520		t.Fatalf("WebSocket handshake: %v", err)
521	}
522
523	// set the deadline to ten minutes ago, which will have expired by the time
524	// client.Close sends the close status frame.
525	conn.SetDeadline(time.Now().Add(-10 * time.Minute))
526
527	if err := client.Close(); err == nil {
528		t.Errorf("ws.Close(): expected error, got %v", err)
529	}
530	if cc.closed < 1 {
531		t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
532	}
533}
534
// originTests pairs a handshake request with the origin URL Origin is
// expected to extract. A request without an Origin header expects a nil URL.
var originTests = []struct {
	req    *http.Request
	origin *url.URL
}{
	{
		req:    &http.Request{Header: http.Header{"Origin": {"http://www.example.com"}}},
		origin: &url.URL{Scheme: "http", Host: "www.example.com"},
	},
	{req: &http.Request{}}, // no Origin header: origin stays nil
}
554
555func TestOrigin(t *testing.T) {
556	conf := newConfig(t, "/echo")
557	conf.Version = ProtocolVersionHybi13
558	for i, tt := range originTests {
559		origin, err := Origin(conf, tt.req)
560		if err != nil {
561			t.Error(err)
562			continue
563		}
564		if !reflect.DeepEqual(origin, tt.origin) {
565			t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
566			continue
567		}
568	}
569}
570
571func TestCtrlAndData(t *testing.T) {
572	once.Do(startServer)
573
574	c, err := net.Dial("tcp", serverAddr)
575	if err != nil {
576		t.Fatal(err)
577	}
578	ws, err := NewClient(newConfig(t, "/ctrldata"), c)
579	if err != nil {
580		t.Fatal(err)
581	}
582	defer ws.Close()
583
584	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
585	ws.frameHandler = h
586
587	b := make([]byte, 128)
588	for i := 0; i < 2; i++ {
589		data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
590		if _, err := ws.Write(data); err != nil {
591			t.Fatalf("#%d: %v", i, err)
592		}
593		var ctrl []byte
594		if i%2 != 0 { // with or without payload
595			ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
596		}
597		if _, err := h.WritePing(ctrl); err != nil {
598			t.Fatalf("#%d: %v", i, err)
599		}
600		n, err := ws.Read(b)
601		if err != nil {
602			t.Fatalf("#%d: %v", i, err)
603		}
604		if !bytes.Equal(b[:n], data) {
605			t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
606		}
607	}
608}
609
610func TestCodec_ReceiveLimited(t *testing.T) {
611	const limit = 2048
612	var payloads [][]byte
613	for _, size := range []int{
614		1024,
615		2048,
616		4096, // receive of this message would be interrupted due to limit
617		2048, // this one is to make sure next receive recovers discarding leftovers
618	} {
619		b := make([]byte, size)
620		rand.Read(b)
621		payloads = append(payloads, b)
622	}
623	handlerDone := make(chan struct{})
624	limitedHandler := func(ws *Conn) {
625		defer close(handlerDone)
626		ws.MaxPayloadBytes = limit
627		defer ws.Close()
628		for i, p := range payloads {
629			t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
630			var recv []byte
631			err := Message.Receive(ws, &recv)
632			switch err {
633			case nil:
634			case ErrFrameTooLarge:
635				if len(p) <= limit {
636					t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
637				}
638				continue
639			default:
640				t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
641			}
642			if len(recv) > limit {
643				t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
644			}
645			if !bytes.Equal(p, recv) {
646				t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
647			}
648		}
649	}
650	server := httptest.NewServer(Handler(limitedHandler))
651	defer server.CloseClientConnections()
652	defer server.Close()
653	addr := server.Listener.Addr().String()
654	ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
655	if err != nil {
656		t.Fatal(err)
657	}
658	defer ws.Close()
659	for i, p := range payloads {
660		if err := Message.Send(ws, p); err != nil {
661			t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
662		}
663	}
664	<-handlerDone
665}
666