1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package wsstream
18
19import (
20	"encoding/base64"
21	"fmt"
22	"io"
23	"net/http"
24	"regexp"
25	"strings"
26	"time"
27
28	"golang.org/x/net/websocket"
29	"k8s.io/klog/v2"
30
31	"k8s.io/apimachinery/pkg/util/runtime"
32)
33
// ChannelWebSocketProtocol is the Websocket subprotocol "channel.k8s.io". It prepends each
// binary message with a byte indicating the channel number (zero indexed) the message was sent
// on. Messages in both directions should prefix their messages with this channel byte. When
// used for remote execution, the channel numbers are by convention defined to match the POSIX
// file-descriptors assigned to STDIN, STDOUT, and STDERR (0, 1, and 2). No other conversion is
// performed on the raw subprotocol - writes are sent as they are received by the server.
//
// Example client session:
//
//    CONNECT http://server.com with subprotocol "channel.k8s.io"
//    WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN)
//    READ  []byte{1, 10}                # receive "\n" on channel 1 (STDOUT)
//    CLOSE
//
const ChannelWebSocketProtocol = "channel.k8s.io"
49
// Base64ChannelWebSocketProtocol is the Websocket subprotocol "base64.channel.k8s.io". Each
// message starts with a character indicating the channel number (zero indexed) the message was
// sent on, followed by the base64 encoded payload. Messages in both directions should prefix
// their messages with this channel char. When used for remote execution, the channel numbers
// are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and
// STDERR ('0', '1', and '2'). The data received on the server is base64 decoded (and must be
// valid base64) and data written by the server to the client is base64 encoded.
//
// Example client session:
//
//    CONNECT http://server.com with subprotocol "base64.channel.k8s.io"
//    WRITE []byte{48, 90, 109, 57, 118, 67, 103, 61, 61} # send "foo\n" (base64: "Zm9vCg==") on channel '0' (STDIN)
//    READ  []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
//    CLOSE
//
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
65
// codecType selects how frame payloads are encoded on the wire.
type codecType int

const (
	// rawCodec sends the channel number as a raw byte followed by the unmodified payload.
	rawCodec codecType = iota
	// base64Codec sends the channel number as an ASCII digit followed by the base64 encoded
	// payload.
	base64Codec
)
72
// ChannelType describes the duplex mode of a single channel. Read and Write refer to whether
// the channel returned to the caller can be used as a Reader or a Writer.
type ChannelType int

const (
	// IgnoreChannel creates a channel that is neither readable nor writable: incoming frames
	// are dropped, reads report io.EOF and writes are silently discarded.
	IgnoreChannel ChannelType = iota
	// ReadChannel creates a channel that delivers frames received from the socket to readers;
	// writes to it are silently discarded.
	ReadChannel
	// WriteChannel creates a channel whose writes are sent over the socket; incoming frames
	// are dropped and reads report io.EOF.
	WriteChannel
	// ReadWriteChannel creates a channel that is both readable and writable.
	ReadWriteChannel
)
81
var (
	// connectionUpgradeRegex matches any lower-cased Connection header value whose
	// comma-separated token list includes "upgrade".
	connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)")
)

// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
// for WebSockets: an Upgrade header equal to "websocket" (case-insensitive) and a Connection
// header that lists the "upgrade" token.
func IsWebSocketRequest(req *http.Request) bool {
	if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
		return false
	}
	// The Connection tokens may be split across several header lines (equivalent to a single
	// comma-separated line), so every value has to be inspected rather than only the first.
	for _, value := range req.Header[http.CanonicalHeaderKey("Connection")] {
		if connectionUpgradeRegex.MatchString(strings.ToLower(value)) {
			return true
		}
	}
	return false
}
95
96// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
97// read and write deadlines are pushed every time a new message is received.
98func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
99	defer runtime.HandleCrash()
100	var data []byte
101	for {
102		resetTimeout(ws, timeout)
103		if err := websocket.Message.Receive(ws, &data); err != nil {
104			return
105		}
106	}
107}
108
109// handshake ensures the provided user protocol matches one of the allowed protocols. It returns
110// no error if no protocol is specified.
111func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
112	protocols := config.Protocol
113	if len(protocols) == 0 {
114		protocols = []string{""}
115	}
116
117	for _, protocol := range protocols {
118		for _, allow := range allowed {
119			if allow == protocol {
120				config.Protocol = []string{protocol}
121				return nil
122			}
123		}
124	}
125
126	return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
127}
128
// ChannelProtocolConfig describes a websocket subprotocol with channels.
type ChannelProtocolConfig struct {
	// Binary selects the raw codec when true; when false payloads are base64 encoded.
	Binary   bool
	// Channels lists the type of each channel; the slice index is the channel number.
	Channels []ChannelType
}
134
135// NewDefaultChannelProtocols returns a channel protocol map with the
136// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
137// channels.
138func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
139	return map[string]ChannelProtocolConfig{
140		"":                             {Binary: true, Channels: channels},
141		ChannelWebSocketProtocol:       {Binary: true, Channels: channels},
142		Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
143	}
144}
145
// Conn supports sending multiple binary channels over a websocket connection.
type Conn struct {
	// protocols maps each supported subprotocol name to its channel configuration.
	protocols        map[string]ChannelProtocolConfig
	// selectedProtocol is the negotiated subprotocol; set by initialize.
	selectedProtocol string
	// channels holds one entry per configured channel, indexed by channel number; set by
	// initialize.
	channels         []*websocketChannel
	// codec is the wire encoding derived from the selected protocol's Binary flag.
	codec            codecType
	// ready is closed by initialize once the fields above are populated; Open and Close
	// block on it.
	ready            chan struct{}
	// ws is the underlying websocket connection; set by initialize.
	ws               *websocket.Conn
	// timeout is the idle timeout applied by resetTimeout; values <= 0 disable it.
	timeout          time.Duration
}
156
157// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
158// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
159// future use. The channel types for each channel are passed as an array, supporting the different
160// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
161//
162// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
163// name is used if websocket.Config.Protocol is empty.
164func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
165	return &Conn{
166		ready:     make(chan struct{}),
167		protocols: protocols,
168	}
169}
170
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified
// (or not positive), no deadline is applied to the connection. The value is stored without
// synchronization, so it should presumably be set before Open is called.
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
	conn.timeout = duration
}
176
177// Open the connection and create channels for reading and writing. It returns
178// the selected subprotocol, a slice of channels and an error.
179func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
180	go func() {
181		defer runtime.HandleCrash()
182		defer conn.Close()
183		websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
184	}()
185	<-conn.ready
186	rwc := make([]io.ReadWriteCloser, len(conn.channels))
187	for i := range conn.channels {
188		rwc[i] = conn.channels[i]
189	}
190	return conn.selectedProtocol, rwc, nil
191}
192
193func (conn *Conn) initialize(ws *websocket.Conn) {
194	negotiated := ws.Config().Protocol
195	conn.selectedProtocol = negotiated[0]
196	p := conn.protocols[conn.selectedProtocol]
197	if p.Binary {
198		conn.codec = rawCodec
199	} else {
200		conn.codec = base64Codec
201	}
202	conn.ws = ws
203	conn.channels = make([]*websocketChannel, len(p.Channels))
204	for i, t := range p.Channels {
205		switch t {
206		case ReadChannel:
207			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
208		case WriteChannel:
209			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
210		case ReadWriteChannel:
211			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
212		case IgnoreChannel:
213			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
214		}
215	}
216
217	close(conn.ready)
218}
219
220func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
221	supportedProtocols := make([]string, 0, len(conn.protocols))
222	for p := range conn.protocols {
223		supportedProtocols = append(supportedProtocols, p)
224	}
225	return handshake(config, req, supportedProtocols)
226}
227
228func (conn *Conn) resetTimeout() {
229	if conn.timeout > 0 {
230		conn.ws.SetDeadline(time.Now().Add(conn.timeout))
231	}
232}
233
234// Close is only valid after Open has been called
235func (conn *Conn) Close() error {
236	<-conn.ready
237	for _, s := range conn.channels {
238		s.Close()
239	}
240	conn.ws.Close()
241	return nil
242}
243
244// handle implements a websocket handler.
245func (conn *Conn) handle(ws *websocket.Conn) {
246	defer conn.Close()
247	conn.initialize(ws)
248
249	for {
250		conn.resetTimeout()
251		var data []byte
252		if err := websocket.Message.Receive(ws, &data); err != nil {
253			if err != io.EOF {
254				klog.Errorf("Error on socket receive: %v", err)
255			}
256			break
257		}
258		if len(data) == 0 {
259			continue
260		}
261		channel := data[0]
262		if conn.codec == base64Codec {
263			channel = channel - '0'
264		}
265		data = data[1:]
266		if int(channel) >= len(conn.channels) {
267			klog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
268			continue
269		}
270		if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
271			klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
272			continue
273		}
274	}
275}
276
277// write multiplexes the specified channel onto the websocket
278func (conn *Conn) write(num byte, data []byte) (int, error) {
279	conn.resetTimeout()
280	switch conn.codec {
281	case rawCodec:
282		frame := make([]byte, len(data)+1)
283		frame[0] = num
284		copy(frame[1:], data)
285		if err := websocket.Message.Send(conn.ws, frame); err != nil {
286			return 0, err
287		}
288	case base64Codec:
289		frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
290		if err := websocket.Message.Send(conn.ws, frame); err != nil {
291			return 0, err
292		}
293	}
294	return len(data), nil
295}
296
// websocketChannel represents a channel in a connection
type websocketChannel struct {
	// conn is the connection this channel is multiplexed onto.
	conn *Conn
	// num is the channel number sent as the frame prefix.
	num  byte
	// r and w are the two ends of an in-memory pipe: data received from the socket is
	// written to w and handed to callers of Read through r.
	r    io.Reader
	w    io.WriteCloser

	// read and write record whether the channel may be used as a Reader and as a Writer.
	read, write bool
}
306
307// newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe
308// prior to the connection being opened. It may be no, half, or full duplex depending on
309// read and write.
310func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
311	r, w := io.Pipe()
312	return &websocketChannel{conn, num, r, w, read, write}
313}
314
315func (p *websocketChannel) Write(data []byte) (int, error) {
316	if !p.write {
317		return len(data), nil
318	}
319	return p.conn.write(p.num, data)
320}
321
322// DataFromSocket is invoked by the connection receiver to move data from the connection
323// into a specific channel.
324func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
325	if !p.read {
326		return len(data), nil
327	}
328
329	switch p.conn.codec {
330	case rawCodec:
331		return p.w.Write(data)
332	case base64Codec:
333		dst := make([]byte, len(data))
334		n, err := base64.StdEncoding.Decode(dst, data)
335		if err != nil {
336			return 0, err
337		}
338		return p.w.Write(dst[:n])
339	}
340	return 0, nil
341}
342
343func (p *websocketChannel) Read(data []byte) (int, error) {
344	if !p.read {
345		return 0, io.EOF
346	}
347	return p.r.Read(data)
348}
349
// Close closes the write end of the channel's pipe and returns the result of that close. The
// read end is left untouched.
func (p *websocketChannel) Close() error {
	return p.w.Close()
}
353