1package terminal 2 3import ( 4 "bytes" 5 "errors" 6 "net" 7 "testing" 8 "time" 9 10 "github.com/gorilla/websocket" 11) 12 13type testcase struct { 14 input *fakeConn 15 expected *fakeConn 16} 17 18type fakeConn struct { 19 // WebSocket message type 20 mt int 21 data []byte 22 err error 23} 24 25func (f *fakeConn) ReadMessage() (int, []byte, error) { 26 return f.mt, f.data, f.err 27} 28 29func (f *fakeConn) WriteMessage(mt int, data []byte) error { 30 f.mt = mt 31 f.data = data 32 return f.err 33} 34 35func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error { 36 f.mt = mt 37 f.data = data 38 return f.err 39} 40 41func (f *fakeConn) UnderlyingConn() net.Conn { 42 return nil 43} 44 45func fake(mt int, data []byte, err error) *fakeConn { 46 return &fakeConn{mt: mt, data: []byte(data), err: err} 47} 48 49var ( 50 msg = []byte("foo bar") 51 msgBase64 = []byte("Zm9vIGJhcg==") 52 kubeMsg = append([]byte{0}, msg...) 53 kubeMsgBase64 = append([]byte{'0'}, msgBase64...) 54 55 fakeErr = errors.New("fake error") 56 57 text = websocket.TextMessage 58 binary = websocket.BinaryMessage 59 other = 999 60 61 fakeOther = fake(other, []byte("foo"), nil) 62) 63 64func assertEqual(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) { 65 if expected.mt != actual.mt { 66 t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt) 67 t.Fatalf(msg, args...) 68 } 69 70 if bytes.Compare(expected.data, actual.data) != 0 { 71 t.Logf("data expected to be %q but was %q: ", expected.data, actual.data) 72 t.Fatalf(msg, args...) 73 } 74 75 if expected.err != actual.err { 76 t.Logf("error expected to be %v but was %v", expected.err, actual.err) 77 t.Fatalf(msg, args...) 78 } 79} 80 81func TestReadMessage(t *testing.T) { 82 testCases := map[string][]testcase{ 83 "channel.k8s.io": { 84 {fake(binary, kubeMsg, fakeErr), fake(binary, kubeMsg, fakeErr)}, 85 {fake(binary, kubeMsg, nil), fake(binary, msg, nil)}, 86 {fake(text, kubeMsg, nil), fake(binary, msg, nil)}, 87 {fakeOther, fakeOther}, 88 }, 89 "base64.channel.k8s.io": { 90 {fake(text, kubeMsgBase64, fakeErr), fake(text, kubeMsgBase64, fakeErr)}, 91 {fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)}, 92 {fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)}, 93 {fakeOther, fakeOther}, 94 }, 95 "terminal.gitlab.com": { 96 {fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)}, 97 {fake(binary, msg, nil), fake(binary, msg, nil)}, 98 {fake(text, msg, nil), fake(binary, msg, nil)}, 99 {fakeOther, fakeOther}, 100 }, 101 "base64.terminal.gitlab.com": { 102 {fake(text, msgBase64, fakeErr), fake(text, msgBase64, fakeErr)}, 103 {fake(text, msgBase64, nil), fake(binary, msg, nil)}, 104 {fake(binary, msgBase64, nil), fake(binary, msg, nil)}, 105 {fakeOther, fakeOther}, 106 }, 107 } 108 109 for subprotocol, cases := range testCases { 110 for i, tc := range cases { 111 conn := Wrap(tc.input, subprotocol) 112 mt, data, err := conn.ReadMessage() 113 actual := fake(mt, data, err) 114 assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i) 115 } 116 } 117} 118 119func TestWriteMessage(t *testing.T) { 120 testCases := map[string][]testcase{ 121 "channel.k8s.io": { 122 {fake(binary, msg, fakeErr), fake(binary, kubeMsg, fakeErr)}, 123 {fake(binary, msg, nil), fake(binary, kubeMsg, nil)}, 124 {fake(text, msg, nil), fake(binary, kubeMsg, nil)}, 125 {fakeOther, fakeOther}, 126 }, 127 "base64.channel.k8s.io": { 128 {fake(binary, msg, fakeErr), fake(text, kubeMsgBase64, fakeErr)}, 129 {fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)}, 130 {fake(text, msg, nil), fake(text, kubeMsgBase64, nil)}, 131 {fakeOther, fakeOther}, 132 }, 133 "terminal.gitlab.com": { 134 {fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)}, 135 {fake(binary, msg, nil), fake(binary, msg, nil)}, 136 {fake(text, msg, nil), fake(binary, msg, nil)}, 137 {fakeOther, fakeOther}, 138 }, 139 "base64.terminal.gitlab.com": { 140 {fake(binary, msg, fakeErr), fake(text, msgBase64, fakeErr)}, 141 {fake(binary, msg, nil), fake(text, msgBase64, nil)}, 142 {fake(text, msg, nil), fake(text, msgBase64, nil)}, 143 {fakeOther, fakeOther}, 144 }, 145 } 146 147 for subprotocol, cases := range testCases { 148 for i, tc := range cases { 149 actual := fake(0, nil, tc.input.err) 150 conn := Wrap(actual, subprotocol) 151 actual.err = conn.WriteMessage(tc.input.mt, tc.input.data) 152 assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i) 153 } 154 } 155} 156