1package terminal
2
3import (
4	"encoding/base64"
5	"io"
6	"net"
7	"time"
8
9	"github.com/gorilla/websocket"
10	"golang.org/x/text/encoding"
11	"golang.org/x/text/encoding/unicode"
12)
13
14func Wrap(conn Connection, subprotocol string) Connection {
15	switch subprotocol {
16	case "channel.k8s.io":
17		return &kubeWrapper{base64: false, conn: conn}
18	case "base64.channel.k8s.io":
19		return &kubeWrapper{base64: true, conn: conn}
20	case "terminal.gitlab.com":
21		return &gitlabWrapper{base64: false, conn: conn}
22	case "base64.terminal.gitlab.com":
23		return &gitlabWrapper{base64: true, conn: conn}
24	}
25
26	return conn
27}
28
29func NewIOWrapper(conn Connection) *ioWrapper {
30	return &ioWrapper{
31		Connection:  conn,
32		messageType: websocket.BinaryMessage,
33		encoder:     unicode.UTF8.NewEncoder(),
34		decoder:     unicode.UTF8.NewDecoder(),
35	}
36}
37
38type kubeWrapper struct {
39	base64 bool
40	conn   Connection
41}
42
43type gitlabWrapper struct {
44	base64 bool
45	conn   Connection
46}
47
48type ioWrapper struct {
49	Connection
50	messageType int
51
52	encoder *encoding.Encoder
53	decoder *encoding.Decoder
54}
55
56func (w *gitlabWrapper) ReadMessage() (int, []byte, error) {
57	mt, data, err := w.conn.ReadMessage()
58	if err != nil {
59		return mt, data, err
60	}
61
62	if isData(mt) {
63		mt = websocket.BinaryMessage
64		if w.base64 {
65			data, err = decodeBase64(data)
66			if err != nil {
67			}
68		}
69	}
70
71	return mt, data, err
72}
73
74func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error {
75	if isData(mt) {
76		if w.base64 {
77			mt = websocket.TextMessage
78			data = encodeBase64(data)
79		} else {
80			mt = websocket.BinaryMessage
81		}
82	}
83
84	return w.conn.WriteMessage(mt, data)
85}
86
87func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
88	return w.conn.WriteControl(mt, data, deadline)
89}
90
91func (w *gitlabWrapper) Close() error {
92	return w.conn.UnderlyingConn().Close()
93}
94
95func (w *gitlabWrapper) UnderlyingConn() net.Conn {
96	return w.conn.UnderlyingConn()
97}
98
99// Coalesces all wsstreams into a single stream. In practice, we should only
100// receive data on stream 1.
101func (w *kubeWrapper) ReadMessage() (int, []byte, error) {
102	mt, data, err := w.conn.ReadMessage()
103	if err != nil {
104		return mt, data, err
105	}
106
107	if isData(mt) {
108		mt = websocket.BinaryMessage
109
110		// Remove the WSStream channel number, decode to raw
111		if len(data) > 0 {
112			data = data[1:]
113			if w.base64 {
114				data, err = decodeBase64(data)
115			}
116		}
117	}
118
119	return mt, data, err
120}
121
122// Always sends to wsstream 0
123func (w *kubeWrapper) WriteMessage(mt int, data []byte) error {
124	if isData(mt) {
125		if w.base64 {
126			mt = websocket.TextMessage
127			data = append([]byte{'0'}, encodeBase64(data)...)
128		} else {
129			mt = websocket.BinaryMessage
130			data = append([]byte{0}, data...)
131		}
132	}
133
134	return w.conn.WriteMessage(mt, data)
135}
136
137func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
138	return w.conn.WriteControl(mt, data, deadline)
139}
140
141func (w *kubeWrapper) UnderlyingConn() net.Conn {
142	return w.conn.UnderlyingConn()
143}
144
145// encodes the given data as utf-8 and writes it to the websocket
146func (w *ioWrapper) Write(data []byte) (n int, err error) {
147	n = len(data)
148	if w.messageType != websocket.BinaryMessage {
149		utf8, err := w.encoder.String(string(data))
150		if err != nil {
151			return 0, err
152		}
153		data = []byte(utf8)
154	}
155	err = w.WriteMessage(w.messageType, data)
156	return n, err
157}
158
159// decodes utf-8 encoded data from the websocket
160func (w *ioWrapper) Read(out []byte) (n int, err error) {
161	mt, data, err := w.ReadMessage()
162	if mt != websocket.BinaryMessage {
163		switch err {
164		case nil:
165			data, err = w.decoder.Bytes(data)
166		case io.EOF:
167			return 0, io.EOF
168		}
169	}
170	if err != nil {
171		return 0, err
172	}
173	w.messageType = mt
174	return copy(out, data), nil
175}
176
177func isData(mt int) bool {
178	return mt == websocket.BinaryMessage || mt == websocket.TextMessage
179}
180
// encodeBase64 returns the standard, padded base64 encoding of data in a newly
// allocated slice.
func encodeBase64(data []byte) []byte {
	enc := base64.StdEncoding
	encoded := make([]byte, enc.EncodedLen(len(data)))
	enc.Encode(encoded, data)
	return encoded
}
187
// decodeBase64 decodes standard, padded base64 input. On malformed input it
// returns whatever bytes were decoded before the problem, together with the
// decoding error.
func decodeBase64(data []byte) ([]byte, error) {
	enc := base64.StdEncoding
	decoded := make([]byte, enc.DecodedLen(len(data)))
	written, err := enc.Decode(decoded, data)
	return decoded[:written], err
}
193