1package terminal
2
3import (
4	"bytes"
5	"errors"
6	"net"
7	"testing"
8	"time"
9
10	"github.com/gorilla/websocket"
11)
12
13type testcase struct {
14	input    *fakeConn
15	expected *fakeConn
16}
17
// fakeConn is an in-memory stand-in for a WebSocket connection. Reads hand
// back the stored message unchanged; writes overwrite the stored message so a
// test can inspect what was written last. Every read/write method returns the
// configured err.
type fakeConn struct {
	// mt is the WebSocket message type.
	mt int

	// data is the payload returned by reads and recorded by writes.
	data []byte

	// err is returned by ReadMessage, WriteMessage and WriteControl.
	err error
}

// ReadMessage returns the stored message type, payload and error as-is.
func (f *fakeConn) ReadMessage() (int, []byte, error) {
	return f.mt, f.data, f.err
}

// WriteMessage records mt and a private copy of data, then returns the
// configured error. Copying means a caller that reuses its buffer after the
// write cannot alter what the test later asserts on.
func (f *fakeConn) WriteMessage(mt int, data []byte) error {
	f.mt = mt
	// data[:0:0] keeps nil-ness (nil stays nil, empty stays empty) while
	// forcing append to allocate a fresh backing array for non-empty input.
	f.data = append(data[:0:0], data...)
	return f.err
}

// WriteControl records a control message exactly like WriteMessage; the
// deadline argument is ignored.
func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error {
	f.mt = mt
	f.data = append(data[:0:0], data...)
	return f.err
}

// UnderlyingConn always returns nil: the fake has no network connection.
func (f *fakeConn) UnderlyingConn() net.Conn {
	return nil
}
44
45func fake(mt int, data []byte, err error) *fakeConn {
46	return &fakeConn{mt: mt, data: []byte(data), err: err}
47}
48
49var (
50	msg           = []byte("foo bar")
51	msgBase64     = []byte("Zm9vIGJhcg==")
52	kubeMsg       = append([]byte{0}, msg...)
53	kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
54
55	fakeErr = errors.New("fake error")
56
57	text   = websocket.TextMessage
58	binary = websocket.BinaryMessage
59	other  = 999
60
61	fakeOther = fake(other, []byte("foo"), nil)
62)
63
64func assertEqual(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) {
65	if expected.mt != actual.mt {
66		t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
67		t.Fatalf(msg, args...)
68	}
69
70	if bytes.Compare(expected.data, actual.data) != 0 {
71		t.Logf("data expected to be %q but was %q: ", expected.data, actual.data)
72		t.Fatalf(msg, args...)
73	}
74
75	if expected.err != actual.err {
76		t.Logf("error expected to be %v but was %v", expected.err, actual.err)
77		t.Fatalf(msg, args...)
78	}
79}
80
81func TestReadMessage(t *testing.T) {
82	testCases := map[string][]testcase{
83		"channel.k8s.io": {
84			{fake(binary, kubeMsg, fakeErr), fake(binary, kubeMsg, fakeErr)},
85			{fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
86			{fake(text, kubeMsg, nil), fake(binary, msg, nil)},
87			{fakeOther, fakeOther},
88		},
89		"base64.channel.k8s.io": {
90			{fake(text, kubeMsgBase64, fakeErr), fake(text, kubeMsgBase64, fakeErr)},
91			{fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
92			{fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
93			{fakeOther, fakeOther},
94		},
95		"terminal.gitlab.com": {
96			{fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)},
97			{fake(binary, msg, nil), fake(binary, msg, nil)},
98			{fake(text, msg, nil), fake(binary, msg, nil)},
99			{fakeOther, fakeOther},
100		},
101		"base64.terminal.gitlab.com": {
102			{fake(text, msgBase64, fakeErr), fake(text, msgBase64, fakeErr)},
103			{fake(text, msgBase64, nil), fake(binary, msg, nil)},
104			{fake(binary, msgBase64, nil), fake(binary, msg, nil)},
105			{fakeOther, fakeOther},
106		},
107	}
108
109	for subprotocol, cases := range testCases {
110		for i, tc := range cases {
111			conn := Wrap(tc.input, subprotocol)
112			mt, data, err := conn.ReadMessage()
113			actual := fake(mt, data, err)
114			assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i)
115		}
116	}
117}
118
119func TestWriteMessage(t *testing.T) {
120	testCases := map[string][]testcase{
121		"channel.k8s.io": {
122			{fake(binary, msg, fakeErr), fake(binary, kubeMsg, fakeErr)},
123			{fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
124			{fake(text, msg, nil), fake(binary, kubeMsg, nil)},
125			{fakeOther, fakeOther},
126		},
127		"base64.channel.k8s.io": {
128			{fake(binary, msg, fakeErr), fake(text, kubeMsgBase64, fakeErr)},
129			{fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
130			{fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
131			{fakeOther, fakeOther},
132		},
133		"terminal.gitlab.com": {
134			{fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)},
135			{fake(binary, msg, nil), fake(binary, msg, nil)},
136			{fake(text, msg, nil), fake(binary, msg, nil)},
137			{fakeOther, fakeOther},
138		},
139		"base64.terminal.gitlab.com": {
140			{fake(binary, msg, fakeErr), fake(text, msgBase64, fakeErr)},
141			{fake(binary, msg, nil), fake(text, msgBase64, nil)},
142			{fake(text, msg, nil), fake(text, msgBase64, nil)},
143			{fakeOther, fakeOther},
144		},
145	}
146
147	for subprotocol, cases := range testCases {
148		for i, tc := range cases {
149			actual := fake(0, nil, tc.input.err)
150			conn := Wrap(actual, subprotocol)
151			actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
152			assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i)
153		}
154	}
155}
156