1package terminal 2 3import ( 4 "encoding/base64" 5 "io" 6 "net" 7 "time" 8 9 "github.com/gorilla/websocket" 10 "golang.org/x/text/encoding" 11 "golang.org/x/text/encoding/unicode" 12) 13 14func Wrap(conn Connection, subprotocol string) Connection { 15 switch subprotocol { 16 case "channel.k8s.io": 17 return &kubeWrapper{base64: false, conn: conn} 18 case "base64.channel.k8s.io": 19 return &kubeWrapper{base64: true, conn: conn} 20 case "terminal.gitlab.com": 21 return &gitlabWrapper{base64: false, conn: conn} 22 case "base64.terminal.gitlab.com": 23 return &gitlabWrapper{base64: true, conn: conn} 24 } 25 26 return conn 27} 28 29func NewIOWrapper(conn Connection) *ioWrapper { 30 return &ioWrapper{ 31 Connection: conn, 32 messageType: websocket.BinaryMessage, 33 encoder: unicode.UTF8.NewEncoder(), 34 decoder: unicode.UTF8.NewDecoder(), 35 } 36} 37 38type kubeWrapper struct { 39 base64 bool 40 conn Connection 41} 42 43type gitlabWrapper struct { 44 base64 bool 45 conn Connection 46} 47 48type ioWrapper struct { 49 Connection 50 messageType int 51 52 encoder *encoding.Encoder 53 decoder *encoding.Decoder 54} 55 56func (w *gitlabWrapper) ReadMessage() (int, []byte, error) { 57 mt, data, err := w.conn.ReadMessage() 58 if err != nil { 59 return mt, data, err 60 } 61 62 if isData(mt) { 63 mt = websocket.BinaryMessage 64 if w.base64 { 65 data, err = decodeBase64(data) 66 if err != nil { 67 } 68 } 69 } 70 71 return mt, data, err 72} 73 74func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error { 75 if isData(mt) { 76 if w.base64 { 77 mt = websocket.TextMessage 78 data = encodeBase64(data) 79 } else { 80 mt = websocket.BinaryMessage 81 } 82 } 83 84 return w.conn.WriteMessage(mt, data) 85} 86 87func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error { 88 return w.conn.WriteControl(mt, data, deadline) 89} 90 91func (w *gitlabWrapper) Close() error { 92 return w.conn.UnderlyingConn().Close() 93} 94 95func (w *gitlabWrapper) UnderlyingConn() net.Conn { 96 return w.conn.UnderlyingConn() 97} 98 99// Coalesces all wsstreams into a single stream. In practice, we should only 100// receive data on stream 1. 101func (w *kubeWrapper) ReadMessage() (int, []byte, error) { 102 mt, data, err := w.conn.ReadMessage() 103 if err != nil { 104 return mt, data, err 105 } 106 107 if isData(mt) { 108 mt = websocket.BinaryMessage 109 110 // Remove the WSStream channel number, decode to raw 111 if len(data) > 0 { 112 data = data[1:] 113 if w.base64 { 114 data, err = decodeBase64(data) 115 } 116 } 117 } 118 119 return mt, data, err 120} 121 122// Always sends to wsstream 0 123func (w *kubeWrapper) WriteMessage(mt int, data []byte) error { 124 if isData(mt) { 125 if w.base64 { 126 mt = websocket.TextMessage 127 data = append([]byte{'0'}, encodeBase64(data)...) 128 } else { 129 mt = websocket.BinaryMessage 130 data = append([]byte{0}, data...) 131 } 132 } 133 134 return w.conn.WriteMessage(mt, data) 135} 136 137func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error { 138 return w.conn.WriteControl(mt, data, deadline) 139} 140 141func (w *kubeWrapper) UnderlyingConn() net.Conn { 142 return w.conn.UnderlyingConn() 143} 144 145// encodes the given data as utf-8 and writes it to the websocket 146func (w *ioWrapper) Write(data []byte) (n int, err error) { 147 n = len(data) 148 if w.messageType != websocket.BinaryMessage { 149 utf8, err := w.encoder.String(string(data)) 150 if err != nil { 151 return 0, err 152 } 153 data = []byte(utf8) 154 } 155 err = w.WriteMessage(w.messageType, data) 156 return n, err 157} 158 159// decodes utf-8 encoded data from the websocket 160func (w *ioWrapper) Read(out []byte) (n int, err error) { 161 mt, data, err := w.ReadMessage() 162 if mt != websocket.BinaryMessage { 163 switch err { 164 case nil: 165 data, err = w.decoder.Bytes(data) 166 case io.EOF: 167 return 0, io.EOF 168 } 169 } 170 if err != nil { 171 return 0, err 172 } 173 w.messageType = mt 174 return copy(out, data), nil 175} 176 177func isData(mt int) bool { 178 return mt == websocket.BinaryMessage || mt == websocket.TextMessage 179} 180 181func encodeBase64(data []byte) []byte { 182 buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) 183 base64.StdEncoding.Encode(buf, data) 184 185 return buf 186} 187 188func decodeBase64(data []byte) ([]byte, error) { 189 buf := make([]byte, base64.StdEncoding.DecodedLen(len(data))) 190 n, err := base64.StdEncoding.Decode(buf, data) 191 return buf[:n], err 192} 193