1package sockjs
2
3import (
4	"fmt"
5	"net/http"
6	"strings"
7	"time"
8
9	"github.com/gorilla/websocket"
10)
11
// Buffer sizes passed to websocket.Upgrade when the handler options do not
// provide a custom WebsocketUpgrader.
// https://github.com/gorilla/websocket/blob/master/server.go#L230
var (
	// WebSocketReadBufSize is the read buffer size used for the default
	// WebSocket upgrade.
	WebSocketReadBufSize = 4096

	// WebSocketWriteBufSize is the write buffer size used for the default
	// WebSocket upgrade.
	WebSocketWriteBufSize = 4096
)
19
20func (h *handler) sockjsWebsocket(rw http.ResponseWriter, req *http.Request) {
21	var conn *websocket.Conn
22	var err error
23	if h.options.WebsocketUpgrader != nil {
24		conn, err = h.options.WebsocketUpgrader.Upgrade(rw, req, nil)
25	} else {
26		// use default as before, so that those 2 buffer size variables are used as before
27		conn, err = websocket.Upgrade(rw, req, nil, WebSocketReadBufSize, WebSocketWriteBufSize)
28	}
29	if _, ok := err.(websocket.HandshakeError); ok {
30		http.Error(rw, `Can "Upgrade" only to "WebSocket".`, http.StatusBadRequest)
31		return
32	} else if err != nil {
33		rw.WriteHeader(http.StatusInternalServerError)
34		return
35	}
36	sessID, _ := h.parseSessionID(req.URL)
37	sess := newSession(req, sessID, h.options.DisconnectDelay, h.options.HeartbeatDelay)
38	receiver := newWsReceiver(conn, h.options.WebsocketWriteTimeout)
39	sess.attachReceiver(receiver)
40	if h.handlerFunc != nil {
41		go h.handlerFunc(sess)
42	}
43	readCloseCh := make(chan struct{})
44	go func() {
45		var d []string
46		for {
47			err := conn.ReadJSON(&d)
48			if err != nil {
49				close(readCloseCh)
50				return
51			}
52			sess.accept(d...)
53		}
54	}()
55
56	select {
57	case <-readCloseCh:
58	case <-receiver.doneNotify():
59	}
60	sess.close()
61	conn.Close()
62}
63
64type wsReceiver struct {
65	conn         *websocket.Conn
66	closeCh      chan struct{}
67	writeTimeout time.Duration
68}
69
70func newWsReceiver(conn *websocket.Conn, writeTimeout time.Duration) *wsReceiver {
71	return &wsReceiver{
72		conn:         conn,
73		closeCh:      make(chan struct{}),
74		writeTimeout: writeTimeout,
75	}
76}
77
78func (w *wsReceiver) sendBulk(messages ...string) {
79	if len(messages) > 0 {
80		w.sendFrame(fmt.Sprintf("a[%s]", strings.Join(transform(messages, quote), ",")))
81	}
82}
83
84func (w *wsReceiver) sendFrame(frame string) {
85	if w.writeTimeout != 0 {
86		w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout))
87	}
88	if err := w.conn.WriteMessage(websocket.TextMessage, []byte(frame)); err != nil {
89		w.close()
90	}
91}
92
93func (w *wsReceiver) close() {
94	select {
95	case <-w.closeCh: // already closed
96	default:
97		close(w.closeCh)
98	}
99}
100func (w *wsReceiver) canSend() bool {
101	select {
102	case <-w.closeCh: // already closed
103		return false
104	default:
105		return true
106	}
107}
108func (w *wsReceiver) doneNotify() <-chan struct{}        { return w.closeCh }
109func (w *wsReceiver) interruptedNotify() <-chan struct{} { return nil }
110